# -*- coding: utf-8 -*-
"""Fit a Rotated ICRS reference frame."""
# Public names exported by ``from <module> import *``.
# NOTE(review): ``FitResult`` (returned by ``RotatedFrameFitter.fit``) is
# defined in this module but not listed here -- confirm whether intentional.
__all__ = [
    "RotatedFrameFitter",
    "cartesian_model",
    "residual",
]
##############################################################################
# IMPORTS
# STDLIB
import copy
import functools
import typing as T
from types import MappingProxyType
# THIRD PARTY
import astropy.coordinates as coord
import astropy.units as u
import numpy as np
import scipy.optimize as opt
from astropy.utils.decorators import lazyproperty
# LOCAL
from trackstream.config import conf
from trackstream.setup_package import HAS_LMFIT
from trackstream.utils import cartesian_to_spherical, reference_to_skyoffset_matrix
if HAS_LMFIT:
# THIRD PARTY
import lmfit as lf
##############################################################################
# PARAMETERS
# Generic type of the decorated function; the decorator returns it unchanged.
FT = T.TypeVar("FT")

##############################################################################
# CODE
##############################################################################


@T.overload
def scipy_residual_to_lmfit(
    function: None = None, *, param_order: T.Sequence[str]
) -> functools.partial:
    ...


@T.overload
def scipy_residual_to_lmfit(function: FT, *, param_order: T.Sequence[str]) -> FT:  # noqa: F811
    ...


def scipy_residual_to_lmfit(function=None, *, param_order):  # noqa: F811
    """Decorator to make scipy residual functions compatible with lmfit.

    The decorated function is returned unchanged, except that an ``lmfit``
    attribute is attached to it. ``function.lmfit(params, *args, **kwargs)``
    collects ``params[name].value`` for each ``name`` in ``param_order``
    into a list ``variables`` and returns
    ``function(variables, *args, **kwargs)``.

    Parameters
    ----------
    function : callable or None, optional
        The residual function, called as ``function(variables, *args,
        **kwargs)``. If None, a decorator is returned so this can be used
        as ``@scipy_residual_to_lmfit(param_order=[...])``.
    param_order : sequence of str (keyword-only)
        Names of the lmfit parameters, in the same order in which
        ``function`` expects them in ``variables``. The order is captured
        when the function is decorated.

    Returns
    -------
    callable
        ``function`` itself, with the ``lmfit`` attribute attached; or, if
        ``function`` is None, a :func:`functools.partial` decorator.
    """
    # allow for @-syntax with arguments: @scipy_residual_to_lmfit(param_order=...)
    if function is None:
        return functools.partial(scipy_residual_to_lmfit, param_order=param_order)

    # Freeze the order so later mutation of the caller's list cannot
    # silently change which parameters the adapter reads.
    order = tuple(param_order)

    def lmfit(params: T.Mapping[str, T.Any], *args: T.Any, **kwargs: T.Any) -> T.Any:
        """:mod:`lmfit` version of the residual function.

        Parameters
        ----------
        params : `~lmfit.Parameters`
            Mapping of name to an object with a ``.value`` attribute.
        *args, **kwargs : Any
            Forwarded unchanged to the wrapped function.
        """
        variables: T.List[T.Any] = [params[name].value for name in order]
        return function(variables, *args, **kwargs)

    # attach lmfit version to original function
    function.lmfit = lmfit
    return function
# -------------------------------------------------------------------
def cartesian_model(
    data: coord.CartesianRepresentation,
    *,
    lon: T.Union[u.Quantity, float],
    lat: T.Union[u.Quantity, float],
    rotation: T.Union[u.Quantity, float],
    deg: bool = True,
) -> T.Tuple:
    """Model from Cartesian Coordinates.

    Rotates ``data`` with ``reference_to_skyoffset_matrix(lon, lat, rotation)``
    and converts the rotated Cartesian positions to spherical coordinates.

    Parameters
    ----------
    data : |CartesianRep|
        Cartesian representation of the data. Scalar, 1-D, or N-D.
    lon, lat : float or |Angle| or |Quantity| instance
        The longitude and latitude origin for the reference frame.
        If float, assumed degrees.
    rotation : float or |Angle| or |Quantity| instance
        The final rotation of the frame about the ``origin``. The sign of
        the rotation is the left-hand rule. That is, an object at a
        particular position angle in the un-rotated system will be sent to
        the positive latitude (z) direction in the final frame.
        If float, assumed degrees.

    Returns
    -------
    r, lon, lat : array_like
        Radius, longitude, and latitude in the rotated frame, in that order.
        Computed from ``data.xyz.value``, so ``r`` carries no unit.
        Expected to match the shape of ``data`` (assumes
        ``cartesian_to_spherical`` is elementwise -- confirm).

    Other Parameters
    ----------------
    deg : bool
        Whether to return `lat` and `lon` as degrees (default True) or
        radians. Passed through to ``cartesian_to_spherical``.
    """
    rot_matrix = reference_to_skyoffset_matrix(lon, lat, rotation)

    # ``data.xyz`` has shape (3, *data.shape). Flatten the trailing axes so the
    # 3x3 rotation acts on the leading (x, y, z) axis, then restore the shape.
    # This also supports scalar and empty data, where ``len(data)`` fails.
    xyz = data.xyz.value
    rot_xyz = np.dot(rot_matrix, xyz.reshape(3, -1)).reshape(xyz.shape)

    rot_lon, rot_lat, r = cartesian_to_spherical(*rot_xyz, deg=deg)
    return r, rot_lon, rot_lat
# -------------------------------------------------------------------
@scipy_residual_to_lmfit(param_order=["rotation", "lon", "lat"])
def residual(
    variables: T.Sequence,
    data: coord.CartesianRepresentation,
    scalar: bool = False,
) -> T.Union[float, T.Sequence]:
    r"""How close phi2, the rotated latitude (dec), is to flat.

    Written for :mod:`scipy.optimize`. The decorator also attaches
    ``residual.lmfit``, which accepts lmfit ``Parameters`` named
    "rotation", "lon", "lat" in place of ``variables``.

    Parameters
    ----------
    variables : Sequence[float]
        ``(rotation, lon, lat)``, all in degrees.

        - rotation : the final rotation of the frame about the origin. The
          sign of the rotation is the left-hand rule: an object at a
          particular position angle in the un-rotated system is sent to the
          positive latitude (z) direction in the final frame.
        - lon, lat : origin of the rotated frame. If ICRS, equivalent to
          ra & dec.
    data : |CartesianRep|
        e.g. ``ICRS.cartesian``.
    scalar : bool, optional
        Whether to sum the residual into a single value. This is a regular
        positional-or-keyword argument, because :mod:`scipy.optimize` passes
        it positionally through ``args=(data, scalar)``.

    Returns
    -------
    res : float or Sequence
        :math:`|\rm{lat} - 0|` for every point, or its sum if `scalar` is
        True.
    """
    # unpack in the order declared by ``param_order`` in the decorator
    rotation, lon, lat = variables[0], variables[1], variables[2]

    rotated = cartesian_model(data, lon=lon, lat=lat, rotation=rotation, deg=True)
    phi2 = rotated[2]  # latitude in the rotated frame

    # distance of each point from the rotated equator (phi2 - 0)
    deviation = np.abs(phi2 - 0.0)
    return np.sum(deviation) if scalar else deviation
#####################################################################
class RotatedFrameFitter(object):
    """Class to Fit Rotated Frames.

    .. todo::

        include errors.

    Parameters
    ----------
    data : :class:`~astropy.coordinates.BaseCoordinateFrame`
        In ICRS coordinates.
    origin : :class:`~astropy.coordinates.ICRS`
        location of point on sky about which to rotate.

    Other Parameters
    ----------------
    rot_lower, rot_upper : |Quantity|, (optional, keyword-only)
        The lower and upper bounds in degrees.
        Default is [-180, 180] degree.
    origin_lim : |Quantity|, (optional, keyword-only)
        The symmetric lower and upper bounds on origin in degrees.
        Default is 0.005 degree.
    fix_origin : bool (optional, keyword-only)
        Whether to fix the origin point. Default is False.
    use_lmfit : bool or None, (optional, keyword-only)
        Whether to use ``lmfit`` package.
        None (default) falls back to config file.
    leastsquares : bool (optional, keyword-only)
        If `use_lmfit` is False, whether to use
        :func:`~scipy.optimize.least_squares` (True) or
        :func:`~scipy.optimize.minimize` (False).
        Default is False.
    align_v : bool or None (optional, keyword-only)
        Whether to align by the velocity. If None (default), it is enabled
        when ``data`` has velocity ("s") differentials.
    **kwargs
        Remaining keyword arguments are stored in ``fitter_kwargs`` and
        forwarded to the minimizer by `fit`. A "rot0" entry is used as the
        default initial rotation guess in `fit` and is not forwarded.

    Raises
    ------
    ValueError
        If ``align_v`` is True but ``data`` has no velocity ("s")
        differentials.
    """

    def __init__(self, data: coord.BaseCoordinateFrame, origin: coord.ICRS, **kwargs):
        super().__init__()
        self.data = data
        self.origin = origin

        # -------------
        # create bounds
        bounds_args = {
            k: kwargs.pop(k) for k in ("rot_lower", "rot_upper", "origin_lim") if k in kwargs
        }
        self.set_bounds(**bounds_args)

        # -------------
        # process options
        self._default_options = dict(
            fix_origin=kwargs.pop("fix_origin", False),
            use_lmfit=kwargs.pop("use_lmfit", None),
            leastsquares=kwargs.pop("leastsquares", False),
        )

        # determine whether velocity exists to break the +/- 180 degree
        # degeneracy. If it does, enable the `align_v` option in `fit`.
        has_velocity = "s" in self.data.data.differentials
        align_v = kwargs.pop("align_v", None)
        if align_v and not has_velocity:
            raise ValueError("`align_v` requires `data` with velocity ('s') differentials.")
        if align_v is None and has_velocity:
            align_v = True
        self._default_options["align_v"] = align_v

        # Minimizer kwargs are the leftovers
        self.fitter_kwargs = kwargs

    # /def

    @property
    def default_fit_options(self):
        """Read-only view of the default options merged with ``fitter_kwargs``."""
        return MappingProxyType(dict(**self._default_options, **self.fitter_kwargs))

    #######################################################

    # @u.quantity_input(rot_lower=u.deg, rot_upper=u.deg, origin_lim=u.deg)
    def set_bounds(
        self,
        rot_lower: u.Quantity = -180.0 * u.deg,
        rot_upper: u.Quantity = 180.0 * u.deg,
        origin_lim: u.Quantity = 0.005 * u.deg,
    ) -> None:
        """Set bounds on the rotation and origin parameters.

        Stores a (3, 2) array in ``self.bounds``, in degrees. The rows are
        (rotation, lon, lat) and the columns are (lower, upper).

        Parameters
        ----------
        rot_lower, rot_upper : |Quantity|, optional
            The lower and upper bounds in degrees.
        origin_lim : |Quantity|, optional
            The symmetric lower and upper bounds on origin in degrees.
        """
        origin = self.origin.data.represent_as(coord.UnitSphericalRepresentation)

        rotation_bounds = (rot_lower.to_value(u.deg), rot_upper.to_value(u.deg))
        # longitude bounds (ra in ICRS).
        lon_bounds = (origin.lon + (-1, 1) * origin_lim).to_value(u.deg)
        # latitude bounds (dec in ICRS).
        lat_bounds = (origin.lat + (-1, 1) * origin_lim).to_value(u.deg)

        # stack bounds so rows are bounds.
        self.bounds = np.c_[rotation_bounds, lon_bounds, lat_bounds].T

    # /def

    def align_v_positive_lon(
        self,
        fit_values: T.Dict[str, T.Any],
        subsel: T.Union[type(Ellipsis), T.Sequence, slice] = Ellipsis,
    ) -> T.Dict[str, T.Any]:
        """Align the velocity along the positive Longitudinal direction.

        Parameters
        ----------
        fit_values : dict
            The "rotation" and "origin", passed to
            `~astropy.coordinates.SkyOffsetFrame`. Not modified.
        subsel : slice
            sub-select a portion of the `pm_lon_coslat` for determining
            the average velocity.

        Returns
        -------
        values : dict
            A copy of `fit_values`. If the median rotated ``pm_lon_coslat``
            is negative, "rotation" is increased by 180 degrees.
        """
        values = copy.deepcopy(fit_values)  # copy for safety

        # transform the data into the fitted frame to get the rotated velocity
        frame = coord.SkyOffsetFrame(**values)
        frame.differential_type = coord.SphericalCosLatDifferential
        rot_data = self.data.transform_to(frame)

        # get average velocity to determine whether need to rotate.
        avg = np.median(rot_data.pm_lon_coslat[subsel])
        if avg < 0:  # need to flip
            # Store the flipped rotation in the returned values. Previously it
            # went into a local only, so the flip was lost.
            values["rotation"] = values["rotation"] + 180 * u.deg

        return values

    # /def

    #######################################################
    # Fitting

    def residual(self, rotation, *, scalar: bool = False):
        r"""How close phi2, the rotated latitude (dec), is to flat.

        The origin is held at ``self.origin`` (read as ICRS ra/dec).

        Parameters
        ----------
        rotation : float
            The final rotation of the frame about the ``origin``. The sign of
            the rotation is the left-hand rule. That is, an object at a
            particular position angle in the un-rotated system will be sent to
            the positive latitude (z) direction in the final frame.
            In degrees.

        Returns
        -------
        res : float or Sequence
            :math:`\rm{lat} - 0`.
            If `scalar` is True, then sum array_like to return float.

        Other Parameters
        ----------------
        scalar : bool (optional, keyword-only)
            Whether to sum `res` into a float.
            Note that if `res` is also a float, it is unaffected.
        """
        variables = (
            rotation,
            self.origin.ra.to_value(u.deg),
            self.origin.dec.to_value(u.deg),
        )
        return residual(variables, self.data.cartesian, scalar=scalar)

    # /def

    def _fit_representation_scipy(
        self,
        data: coord.CartesianRepresentation,
        x0: T.Sequence[float],
        *,
        bounds: np.ndarray,
        fix_origin: bool,
        use_leastsquares: bool,
        **kw,
    ):
        """Fit with :func:`~scipy.optimize.least_squares` or :func:`~scipy.optimize.minimize`.

        Returns
        -------
        res : `~scipy.optimize.OptimizeResult`
        values : |Quantity|
            Best-fit (rotation, lon, lat) in degrees.

        Raises
        ------
        NotImplementedError
            If ``fix_origin`` is True (not yet supported with scipy).
        """
        if fix_origin:
            # Raise before touching ``bounds``. The caller's array (usually
            # ``self.bounds``) was previously mutated and then left corrupted.
            raise NotImplementedError("TODO")

        bounds = np.asarray(bounds, dtype=float)  # (3, 2): rows are parameters

        if use_leastsquares:
            method = kw.pop("method", "trf")
            res = opt.least_squares(
                residual,
                x0=x0,
                args=(data, False),
                method=method,
                bounds=bounds.T,  # least_squares wants (lower, upper)
                **kw,
            )
        else:
            method = kw.pop("method", "slsqp")
            res = opt.minimize(
                residual,
                x0=x0,
                args=(data, True),
                method=method,
                bounds=bounds,  # minimize wants per-parameter (min, max)
                **kw,
            )

        values = res.x * u.deg
        return res, values

    # /def

    def _fit_representation_lmfit(
        self,
        data: coord.CartesianRepresentation,
        x0: T.Sequence[float],
        *,
        bounds: np.ndarray,
        fix_origin: bool,
        **kw,
    ):
        """Fit with :func:`lmfit.minimize`.

        ``bounds`` is either a single (min, max) pair applied to every
        parameter, or a (3, 2) array with rows (rotation, lon, lat).

        Returns
        -------
        res : `~lmfit.minimizer.MinimizerResult`
        values : |Quantity|
            Best-fit (rotation, lon, lat) in degrees.

        Raises
        ------
        ValueError
            If ``bounds`` has neither shape (2,) nor (3, 2).
        """
        shape = np.shape(bounds)
        if shape == (2,):
            rot_bnd = lon_bnd = lat_bnd = bounds
        elif shape == (3, 2):
            rot_bnd, lon_bnd, lat_bnd = bounds
        else:
            raise ValueError(f"`bounds` must have shape (2,) or (3, 2), not {shape}.")

        params = lf.Parameters()
        params.add_many(
            ("rotation", x0[0], True, rot_bnd[0], rot_bnd[1]),
            ("lon", x0[1], not fix_origin, lon_bnd[0], lon_bnd[1]),
            ("lat", x0[2], not fix_origin, lat_bnd[0], lat_bnd[1]),
        )

        method = kw.pop("method", "powell")
        res = lf.minimize(
            residual.lmfit,
            params,
            kws=dict(data=data, scalar=False),
            method=method,
            calc_covar=True,
            **kw,
        )
        # valuesdict preserves insertion order: rotation, lon, lat
        values = np.array(tuple(res.params.valuesdict().values())) * u.deg
        return res, values

    # /def

    # @u.quantity_input(rot0=u.deg)
    def fit(
        self,
        rot0: T.Optional[u.Quantity] = None,
        bounds: T.Optional[T.Sequence] = None,
        *,
        fix_origin: T.Optional[bool] = None,
        use_lmfit: T.Optional[bool] = None,
        leastsquares: T.Optional[bool] = None,
        align_v: T.Optional[bool] = None,
        **kwargs,
    ):
        """Find Best-Fit Rotated Frame.

        Parameters
        ----------
        rot0 : |Quantity|, optional
            Initial guess for rotation. If None, the "rot0" given at
            construction is used.
        bounds : array-like, optional
            Parameter bounds. Defaults to ``self.bounds``.

            ::

                [[rot_low, rot_up],
                 [lon_low, lon_up],
                 [lat_low, lat_up]]

        Returns
        -------
        FitResult
            Holds the data transformed into the best-fit frame, the best-fit
            "rotation" and "origin", and the raw minimizer output as
            ``fitresult``.

        Other Parameters
        ----------------
        fix_origin : bool (optional, keyword-only)
            Whether to fix the origin.
        use_lmfit : bool (optional, keyword-only)
            Whether to use ``lmfit`` package. None falls back to the
            constructor option, then to the config file.
        leastsquares : bool (optional, keyword-only)
            If `use_lmfit` is False, whether to use
            :func:`~scipy.optimize.least_squares` or
            :func:`~scipy.optimize.minimize` (default)
        align_v : bool (optional, keyword-only)
            Whether to align velocity to be in positive direction
        **kwargs
            Merged over ``fitter_kwargs`` and passed into whatever
            minimization package / function is used. A "subsel" entry is
            instead passed to `align_v_positive_lon`.

        Raises
        ------
        ValueError
            If no ``rot0`` is given here or at construction.
        ImportError
            If ``use_lmfit`` and :mod:`lmfit` is not installed.
        NotImplementedError
            If ``fix_origin`` with the scipy fitter.
        """
        # -----------------------------
        # Prepare

        # kwargs, preferring newer. "rot0" and "subsel" are options of this
        # method, so they are removed rather than forwarded to the minimizer.
        kwargs = {**self.fitter_kwargs, **kwargs}
        stored_rot0 = kwargs.pop("rot0", None)
        subsel = kwargs.pop("subsel", Ellipsis)

        if rot0 is None:
            rot0 = stored_rot0
            if rot0 is None:
                raise ValueError("no prespecified `rot0`; Need to provide one.")

        # copy so the fitters can never alter ``self.bounds``; also lets
        # nested lists be used.
        bounds = np.array(self.bounds if bounds is None else bounds, dtype=float)

        if fix_origin is None:
            fix_origin = self._default_options["fix_origin"]
        if use_lmfit is None:
            use_lmfit = self._default_options["use_lmfit"]
            if use_lmfit is None:  # still None
                use_lmfit = conf.use_lmfit
        if leastsquares is None:
            leastsquares = self._default_options["leastsquares"]
        if align_v is None:
            align_v = self._default_options["align_v"]

        # -----------------------------
        # Origin
        # We work with a SphericalRepresentation, but the best-fit origin is
        # rebuilt in the original origin's frame class.
        origin_frame = self.origin.__class__
        origin = self.origin.represent_as(coord.SphericalRepresentation)
        x0 = u.Quantity([rot0, origin.lon, origin.lat]).to_value(u.deg)

        if use_lmfit:  # lmfit
            if not HAS_LMFIT:
                raise ImportError("`lmfit` package not available.")
            fit_result, values = self._fit_representation_lmfit(
                self.data.cartesian,
                x0=x0,
                bounds=bounds,
                fix_origin=fix_origin,
                **kwargs,
            )
        else:  # scipy
            fit_result, values = self._fit_representation_scipy(
                self.data.cartesian,
                x0=x0,
                bounds=bounds,
                fix_origin=fix_origin,
                use_leastsquares=leastsquares,
                **kwargs,
            )

        # -----------------------------
        best_rot = values[0]
        best_origin = coord.UnitSphericalRepresentation(
            lon=values[1],
            lat=values[2],  # TODO re-add distance
        )
        best_origin = origin_frame(best_origin)

        values = dict(rotation=best_rot, origin=best_origin)
        if align_v:
            values = self.align_v_positive_lon(values, subsel=subsel)

        return FitResult(self.data, fitresult=fit_result, **values)

    # /def

    #######################################################
    # Plot

    def plot_data(self):
        """Scatter-plot ``data`` ra vs dec on the current matplotlib axes."""
        # THIRD PARTY
        import matplotlib.pyplot as plt

        plt.scatter(self.data.ra, self.data.dec)
        # plt.ylim(-90, 90)
        # return fig

    # /def

    def plot_residual(
        self,
        fitresult=None,
        num_rots: int = 3600,
        scalar: bool = True,
    ):
        """Plot Residual as a function of rotation angle.

        If ``fitresult`` is given, its best-fit point is over-plotted.
        Returns the figure from ``plot_rotation_frame_residual``.
        """
        # LOCAL
        from .plot import plot_rotation_frame_residual

        fig = plot_rotation_frame_residual(
            self.data,
            self.origin,
            num_rots=num_rots,
            scalar=scalar,
        )

        if fitresult is not None:
            fitresult.plot_on_residual(scalar=scalar)

        return fig

    # /def


# /class
# -------------------------------------------------------------------
class FitResult:
    """Result of fitting a rotated frame.

    Parameters
    ----------
    data : |Frame|
        The fitted data, e.g. in ICRS coordinates. It is transformed into
        the fitted |SkyOffsetFrame| on construction.
    origin : |Frame|
        Origin of the fitted frame.
    rotation : |Quantity|
        Rotation of the fitted frame about ``origin``.
    fitresult : Any, optional
        Raw output of the minimizer.

    Attributes
    ----------
    data : |Frame|
        ``data`` transformed to `frame`.
    fitresult : Any or None
    origin, rotation
        Read-only fitted values.
    fit_values : MappingProxy
        Read-only mapping with keys "origin" and "rotation".
    frame : |SkyOffsetFrame|
    residual
        Absolute rotated latitude of every point.
    residual_scalar
        Sum of `residual`.
    lon_order : ndarray
        Indices sorting `data` by rotated longitude.

    Methods
    -------
    plot_data
    plot_on_residual
    """

    def __init__(self, data, origin, rotation, fitresult=None):
        self._origin, self._rotation = origin, rotation
        self.fitresult = fitresult
        # ``frame`` is built lazily from origin/rotation, so those must be
        # set before the transform below.
        self.data = data.transform_to(self.frame)

    # /def

    @property
    def origin(self):
        """Origin of the fitted frame (read-only)."""
        return self._origin

    # /def

    @property
    def rotation(self):
        """Rotation of the fitted frame about `origin` (read-only)."""
        return self._rotation

    # /def

    @lazyproperty
    def fit_values(self):
        """Read-only mapping of the fitted "origin" and "rotation"."""
        values = {"origin": self.origin, "rotation": self.rotation}
        return MappingProxyType(values)

    # /def

    @lazyproperty
    def frame(self):
        """`~astropy.coordinates.SkyOffsetFrame` defined by `fit_values`."""
        # TODO ensure same as `make_frame`
        skyoffset = coord.SkyOffsetFrame(**self.fit_values)
        skyoffset.differential_type = coord.SphericalCosLatDifferential
        return skyoffset

    # /def

    @lazyproperty
    def residual(self):
        """Fit result residual: absolute rotated latitude of each point."""
        offset_from_equator = self.data.lat - 0.0
        return np.abs(offset_from_equator)

    # /def

    @property
    def residual_scalar(self):
        """Sum of `residual` over all points."""
        return np.sum(self.residual)

    # /def

    @lazyproperty
    def lon_order(self):
        """Order data by longitude.

        Returns
        -------
        order : ndarray
            Indices that sort ``data.lon``.
        """
        return np.argsort(self.data.lon)

    # /def

    # ---------------------

    def __repr__(self):
        return f"FitResult({self.fit_values})"

    # /def

    # ---------------------

    def plot_data(self):
        """Scatter-plot rotated lon vs lat on the current axes."""
        # THIRD PARTY
        from matplotlib import pyplot as plt

        plt.scatter(self.data.lon, self.data.lat)
        plt.ylim(-90, 90)

    # /def

    def plot_on_residual(self, scalar: bool = True):
        """Mark the best-fit (rotation, summed residual) point in red.

        Raises
        ------
        NotImplementedError
            If ``scalar`` is False.
        """
        # THIRD PARTY
        import matplotlib.pyplot as plt

        if not scalar:
            raise NotImplementedError

        theta = self.fit_values["rotation"]
        plt.scatter(theta, self.residual_scalar, c="r")

    # /def


# /class
##############################################################################
# END