# -*- coding: utf-8 -*-
"""Core Functions."""
__all__ = [
"TrackStream",
"StreamTrack",
]
##############################################################################
# IMPORTS
# STDLIB
import typing as T
# THIRD PARTY
import astropy.coordinates as coord
import astropy.units as u
import numpy as np
from astropy.table import Table
from astropy.utils.metadata import MetaAttribute, MetaData
from astropy.utils.misc import indent
from scipy.linalg import block_diag
# LOCAL
from . import _type_hints as TH
from .stream import Stream
from trackstream.preprocess.rotated_frame import FitResult, RotatedFrameFitter
from trackstream.preprocess.som import SelfOrganizingMap1D, order_data, reorder_visits
from trackstream.process.kalman import KalmanFilter, kalman_output
from trackstream.process.utils import make_dts, make_F, make_H, make_Q, make_R
from trackstream.utils.misc import intermix_arrays
from trackstream.utils.path import Path, path_moments
##############################################################################
# CODE
##############################################################################
class TrackStream:
    """Track a Stream in ICRS coordinates.

    When run (see :meth:`fit`), produces a `StreamTrack`.

    Parameters
    ----------
    arm1SOM, arm2SOM : `~trackstream.preprocess.SelfOrganizingMap` or None (optional, keyword-only)
        Fiducial SOMs for stream arms 1 and 2, respectively.
    """

    def __init__(self, *, arm1SOM=None, arm2SOM=None):
        # results of the individual fitting stages, keyed by name
        self._cache: T.Dict[str, object] = {}

        # SOM
        self._arm1_SOM = arm1SOM
        self._arm2_SOM = arm2SOM

    @property
    def track(self) -> "StreamTrack":
        """The `StreamTrack` produced by the most recent call to :meth:`fit`.

        Raises
        ------
        AttributeError
            If :meth:`fit` has not been run yet.
        """
        track = self._cache.get("track", None)
        if track is None:
            raise AttributeError("no track available; call `fit` first.")
        return track

    # ===============================================================
    # Fit

    def _fit_rotated_frame(
        self,
        stream: Stream,
        rot0: T.Optional[u.Quantity] = 0 * u.deg,
        bounds: T.Optional[T.Sequence] = None,
        **kwargs,
    ):
        """Fit a rotated frame in ICRS coordinates.

        Parameters
        ----------
        stream : Stream
            Its ``data_coords`` and ``origin`` are given to the
            `RotatedFrameFitter`.
        rot0 : |Quantity| or None.
            Initial guess for rotation.
        bounds : array-like or None, optional
            Parameter bounds, passed to ``RotatedFrameFitter.fit``.

            ::

                [[rot_low, rot_up],
                 [lon_low, lon_up],
                 [lat_low, lat_up]]

        Other Parameters
        ----------------
        **kwargs
            Forwarded unchanged to the `RotatedFrameFitter` constructor.
            The original documentation lists ``rot_lower``, ``rot_upper``,
            ``origin_lim``, ``fix_origin``, ``use_lmfit``, ``leastsquares``
            and ``align_v`` as options; see `RotatedFrameFitter` for their
            meaning and defaults.

        Returns
        -------
        frame
            The ``frame`` attribute of the fit result.
        fitted
            The full fit result.
        """
        fitter = RotatedFrameFitter(
            data=stream.data_coords,
            origin=stream.origin,
            **kwargs,
        )
        fitted = fitter.fit(rot0=rot0, bounds=bounds)
        return fitted.frame, fitted

    # -------------------------------------------

    def _fit_SOM(
        self,
        arm,
        som=None,
        *,
        learning_rate: float = 0.1,
        sigma: float = 1.0,
        iterations: int = 10_000,
        random_seed: T.Optional[int] = None,
        reorder: T.Optional[int] = None,
        progress: bool = False,
        nlattice: T.Optional[int] = None,
        **kwargs,
    ):
        """Reorder data by SOM.

        .. todo::

            - iterative training

        Parameters
        ----------
        arm : SkyCoord
        som : object or None (optional)
            The self-organizing map. If None, one is constructed,
            initialized and trained on ``arm``.
        learning_rate : float (optional, keyword-only)
        sigma : float (optional, keyword-only)
        iterations : int (optional, keyword-only)
            Passed to ``som.train``.
        random_seed : int or None (optional, keyword-only)
        reorder : int or None (optional, keyword-only)
            If not None, passed to `reorder_visits` as ``start_ind``.
        progress : bool (optional, keyword-only)
            Whether to show progress bar.
        nlattice : int or None (optional, keyword-only)
            Number of lattice points of a newly constructed SOM.
            If None, one tenth of the number of data points is used.
        **kwargs
            ``weight_init_method`` (default ``"binned_weights_init"``) names
            the SOM method used to initialize the weights; every remaining
            keyword is passed to that method.

        Returns
        -------
        visit_order
            Ordering of the data as determined by `order_data`.
        som
            The (possibly newly constructed) self-organizing map.

        Raises
        ------
        ValueError
            If a SOM must be constructed but would have zero lattice points.
        """
        # TODO! iterative training
        rep = arm.represent_as(coord.SphericalRepresentation)

        # (N, 3) array of lon [deg], lat [deg], distance. Converting with
        # ``to_value`` instead of scaling raw values keeps this correct
        # regardless of the angular unit the representation is stored in.
        data = np.column_stack(
            [
                np.ravel(rep.lon.to_value(u.deg)),
                np.ravel(rep.lat.to_value(u.deg)),
                np.ravel(rep.distance.value),
            ]
        )

        if som is None:
            data_len, nfeature = data.shape
            if nlattice is None:
                nlattice = data_len // 10  # allows to be variable
            if nlattice == 0:
                raise ValueError(
                    "cannot build a SOM with 0 lattice points; "
                    "pass `nlattice` or provide more data.",
                )

            som = SelfOrganizingMap1D(
                nlattice,
                nfeature,
                sigma=sigma,
                learning_rate=learning_rate,
                # decay_function=None,
                neighborhood_function="gaussian",
                activation_distance="euclidean",
                random_seed=random_seed,
            )

            # call method to initialize SOM weights. The method name is
            # popped so that it is not forwarded to the method itself.
            weight_init_method = kwargs.pop("weight_init_method", "binned_weights_init")
            getattr(som, weight_init_method)(data, **kwargs)

            # NOTE(review): only a newly built SOM is trained here; a SOM
            # supplied by the caller is used as-is — confirm this nesting
            # matches the intended behavior.
            som.train(
                data,
                iterations,
                verbose=False,
                random_order=False,
                progress=progress,
            )

        # get the ordering by "vote" of the Prototypes
        visit_order = order_data(som, data)

        # Reorder
        if reorder is not None:
            visit_order = reorder_visits(rep, visit_order, start_ind=reorder)

        # TODO! transition matrix

        return visit_order, som

    def _fit_kalman_filter(
        self,
        data: coord.SkyCoord,
        w0: T.Optional[np.ndarray] = None,
    ) -> T.Tuple[kalman_output, KalmanFilter, np.ndarray]:
        """Fit data with Kalman filter.

        Parameters
        ----------
        data : SkyCoord
            The (ordered) coordinates to fit. Their Cartesian ``xyz``
            values are what is filtered.
        w0 : array or None (optional)
            The starting point of the Kalman filter. If None, it is built
            from the mean of the first three points, with zero "velocity".

        Returns
        -------
        mean_path
            The result of ``KalmanFilter.fit``.
        kalman_filter : KalmanFilter
        dts
            The output of `make_dts` that was used in the fit.
        """
        arr = data.cartesian.xyz.T.value
        dts = make_dts(arr, dt0=0.5, N=6, axis=1, plot=False)

        # starting point
        if w0 is None:  # need to determine a good starting point
            # Instead of choosing the first point as the starting point,
            # since the stream is in its frame, instead choose the locus of
            # points near the origin.
            x = arr[:3].mean(axis=0)  # mean of the first three points
            v = [0, 0, 0]  # guess for "velocity"
            w0 = intermix_arrays(x, v)

        # TODO! as options
        p = np.array([[0.0001, 0], [0, 1]])
        P0 = block_diag(p, p, p)

        H0 = make_H()
        R0 = make_R([0.05, 0.05, 0.003])[0]  # TODO! actual errors

        kf = KalmanFilter(
            w0,
            P0,
            F0=make_F,
            Q0=make_Q,
            H0=H0,
            R0=R0,
            q_kw=dict(var=0.01, n_dims=3),  # TODO! as options
        )

        smooth_mean_path = kf.fit(
            arr,
            dts,
            method="stepupdate",
            use_filterpy=None,
        )

        return smooth_mean_path, kf, dts

    # -------------------------------------------

    def _order_arm(self, name, arm, fiducial_som, origin, *, fit_if_needed, fit_kw):
        """Order one arm's coordinates along the stream, starting near ``origin``.

        The SOM and visit order are taken, in order of preference, from the
        fiducial SOM, the cache, a new SOM fit, or a sort on longitude.
        Both are cached under ``"<name>_visit_order"`` and ``"<name>_SOM"``.
        """
        som = fiducial_som
        visit_order = None

        # 1) try to get from cache
        if som is None:
            visit_order = self._cache.get(f"{name}_visit_order", None)
            som = self._cache.get(f"{name}_SOM", None)

        # 2) fit, if still None
        if visit_order is None and fit_if_needed:
            visit_order, som = self._fit_SOM(arm, som=som, **fit_kw)

        # 3) if it's still None, give up
        if visit_order is None:
            visit_order = np.argsort(arm.lon)

        visit_order = np.array(visit_order, dtype=int)

        # the visit order can be backward so need to detect proximity to origin
        # TODO! more careful if closest point not end point. & adjust SOM!
        end_points = arm[visit_order[[0, -1]]]
        if np.argmin(end_points.separation_3d(origin)) == 1:
            visit_order = visit_order[::-1]

        self._cache[f"{name}_visit_order"] = visit_order
        self._cache[f"{name}_SOM"] = som

        return arm[visit_order]

    def _kalman_track(self, arm, frame, **kalman_fit_kw):
        """Kalman-filter one ordered arm; return its path, affine and width."""
        mean, kf, _ = self._fit_kalman_filter(arm, **kalman_fit_kw)

        # TODO! make sure get the frame and units right
        r = coord.CartesianRepresentation(mean.Xs[:, ::2].T, unit=u.kpc)
        c = frame.realize_frame(r)  # (not interpolated)

        # The affine parameter is the cumulative point-to-point separation.
        # Its start is offset by a tiny epsilon so that the two arms do not
        # both start at exactly 0.
        sp2p = c[:-1].separation(c[1:])
        affine = np.concatenate(([min(1e-10 * sp2p.unit, 1e-10 * sp2p[0])], sp2p.cumsum()))

        # covariance matrix. select only the phase-space positions
        cov = mean.Ps[:, ::2, ::2]
        var = np.diagonal(cov, axis1=1, axis2=2)
        # NOTE(review): `var` holds the diagonal of the covariance, yet it is
        # squared again before the square root. Kept as in the original;
        # confirm whether ``np.sqrt(np.sum(var, axis=-1))`` was intended.
        sigma = np.sqrt(np.sum(np.square(var), axis=-1)) * u.kpc

        return mean, kf, c, affine, sigma

    def fit(
        self,
        stream,
        *,
        fit_frame_if_needed: bool = True,
        rotated_frame_fit_kw: T.Optional[dict] = None,
        fit_SOM_if_needed: bool = True,
        som_fit_kw: T.Optional[dict] = None,
        kalman_fit_kw: T.Optional[dict] = None,
    ):
        """Fit a track to the data.

        Parameters
        ----------
        stream : Stream
        fit_frame_if_needed : bool (optional, keyword-only)
            Whether to fit a rotated frame when the stream does not
            already have a system frame. The fit frame is ICRS always.

            .. todo::

                make fitting work in the frame of the data

        rotated_frame_fit_kw : dict or None (optional, keyword-only)
            Keyword arguments for the rotated-frame fit.
        fit_SOM_if_needed : bool (optional, keyword-only)
            Whether to fit a SOM for an arm when no visit order is
            available from the cache.
        som_fit_kw : dict or None (optional, keyword-only)
            Keyword arguments for the SOM fit (used for both arms).
        kalman_fit_kw : dict or None (optional, keyword-only)
            Keyword arguments for the Kalman filter fit (both arms).

        Returns
        -------
        StreamTrack instance
            Also stores as ``.track``
        """
        som_fit_kw = som_fit_kw or {}
        kalman_fit_kw = kalman_fit_kw or {}

        # -------------------
        # Fit Rotated Frame
        # this step applies to all arms. In fact, it will perform better if both
        # arms are present, limiting the influence of the tails on the frame
        # orientation.

        # 1) Already provided or in cache.
        #    Either way, don't need to repeat the process.
        frame: T.Optional[coord.BaseCoordinateFrame] = stream._system_frame
        frame_fit: T.Optional[FitResult] = self._cache.get("frame_fit", None)

        # 2) Fit (& cache), if still None.
        #    This can be turned off using `fit_frame_if_needed`, but is
        #    probably more important than the following step of the SOM.
        if frame is None and fit_frame_if_needed:
            frame, frame_fit = self._fit_rotated_frame(stream, **(rotated_frame_fit_kw or {}))

        # 3) if it's still None, give up and use the frame of the data
        if frame is None:
            # NOTE(review): `_fit_rotated_frame` reads `stream.data_coords`
            # (plural) whereas this reads `stream.data_coord` — confirm both
            # attributes exist on Stream.
            frame = stream.data_coord.frame.replicate_without_data()
            frame_fit = None

        # Cache the fit frame on the stream. This is used for transforming the
        # coordinates into the system frame (if that wasn't provided to the
        # stream on initialization).
        self._cache["frame"] = frame  # SkyOffsetICRS
        self._cache["frame_fit"] = frame_fit
        stream._cache["frame"] = frame

        # get arms, in frame
        # do this after caching b/c coords can use the cache
        arm1: coord.SkyCoord = stream.arm1.coords
        arm2: coord.SkyCoord = stream.arm2.coords

        # -------------------
        # Self-Organizing Map
        # Unlike the previous step, we must do this for both arms.

        arm1 = self._order_arm(
            "arm1",
            arm1,
            self._arm1_SOM,
            stream.origin,
            fit_if_needed=fit_SOM_if_needed,
            fit_kw=som_fit_kw,
        )

        if arm2 is not None:
            arm2 = self._order_arm(
                "arm2",
                arm2,
                self._arm2_SOM,
                stream.origin,
                fit_if_needed=fit_SOM_if_needed,
                fit_kw=som_fit_kw,
            )
        else:
            # cache (even if None)
            self._cache["arm2_visit_order"] = None
            self._cache["arm2_SOM"] = self._arm2_SOM

        # -------------------
        # Kalman Filter
        # both arms start at 0 displacement wrt themselves, but not each other.
        # e.g. the progenitor is cut out.

        # Arm 1 (never None)
        mean1, kf1, c1, affine1, sigma1 = self._kalman_track(arm1, frame, **kalman_fit_kw)
        self._cache["arm1_mean_path"] = mean1
        self._cache["arm1_kalman"] = kf1

        # Arm 2
        if arm2 is None:
            mean2 = kf2 = None
        else:
            mean2, kf2, c2, affine2, sigma2 = self._kalman_track(arm2, frame, **kalman_fit_kw)
        # cache (even if None)
        self._cache["arm2_mean_path"] = mean2
        self._cache["arm2_kalman"] = kf2

        # -------------------
        # Combine together into a single path
        # Need to reverse order of one arm to be indexed toward origin, not away

        if arm2 is None:
            affine, c, sigma = affine1, c1, sigma1
        else:
            affine = np.concatenate((-affine2[::-1], affine1))
            c = coord.concatenate((c2[::-1], c1))
            sigma = np.concatenate((sigma2[::-1], sigma1))

        path = Path(
            path=c,
            width=sigma,
            affine=affine,
            frame=frame,
        )

        # construct interpolation
        track = StreamTrack(
            path,
            stream_data=stream.data,
            origin=stream.origin,
            # frame=frame,
            # metadata
            frame_fit=frame_fit,
            # visit_order=visit_order,  # TODO! not combined
            som=dict(
                arm1=self._cache.get("arm1_SOM", None),  # TODO! fix ordering
                arm2=self._cache.get("arm2_SOM", None),
            ),
            kalman=dict(arm1=kf1, arm2=kf2),
        )

        # store, so that `.track` and `predict` work after fitting
        self._cache["track"] = track

        return track

    # ===============================================================

    def predict(self, affine):
        """Predict from a fit.

        Parameters
        ----------
        affine
            Passed to the fitted track, which is called with it.

        Returns
        -------
        object
            Whatever calling the fitted `StreamTrack` returns.

        Raises
        ------
        AttributeError
            If :meth:`fit` has not been run yet.
        """
        return self.track(affine)

    def fit_predict(self, stream, affine, **fit_kwargs):
        """Fit and Predict.

        Equivalent to calling :meth:`fit` with ``stream`` and
        ``fit_kwargs``, followed by :meth:`predict` with ``affine``.
        """
        self.fit(stream, **fit_kwargs)
        return self.predict(affine)
##############################################################################
class StreamTrack:
    """A stream track interpolation as function of arc length.

    The track is Callable; calling it evaluates the underlying path.

    Parameters
    ----------
    path : `~trackstream.utils.path.Path`
    stream_data
        Original stream data
    origin
        of the coordinate system (often the progenitor)
    **meta
        Keys matching one of the class's ``MetaAttribute`` names
        (``frame_fit``, ``visit_order``, ``som``, ``kalman``) are set as
        those attributes; all remaining keys are stored in ``meta``.

    Raises
    ------
    TypeError
        If ``path`` is not a `Path`, or ``origin`` is neither a SkyCoord
        nor a coordinate frame.
    """

    meta = MetaData()
    frame_fit = MetaAttribute()
    visit_order = MetaAttribute()
    som = MetaAttribute()
    kalman = MetaAttribute()

    def __init__(
        self,
        path: Path,
        stream_data: T.Union[Table, TH.CoordinateType, None],
        origin: TH.CoordinateType,
        # frame: T.Optional[TH.FrameLikeType] = None,
        **meta,
    ):
        # validation of types
        if not isinstance(path, Path):
            raise TypeError("`path` must be <Path>.")
        if not isinstance(origin, (coord.SkyCoord, coord.BaseCoordinateFrame)):
            raise TypeError("`origin` must be <|SkyCoord|, |Frame|>.")

        # assign
        self._path: Path = path
        self._origin = origin
        # self._frame = resolve_framelike(frame)
        self._stream_data = stream_data

        # set the MetaAttribute(s). Iterate over a copy of the keys since
        # matched entries are popped from ``meta`` as we go.
        for attr in list(meta):
            descr = getattr(self.__class__, attr, None)
            if isinstance(descr, MetaAttribute):
                setattr(self, attr, meta.pop(attr))
        # and the (remaining) meta
        self.meta.update(meta)

    @property
    def path(self):
        """The `Path` this track was built from."""
        return self._path

    @property
    def track(self):
        """The path's central track."""
        return self._path.data

    @property
    def affine(self):
        """The path's affine parameter."""
        return self._path.affine

    @property
    def stream_data(self):
        """The original stream data (may be None)."""
        return self._stream_data

    @property
    def origin(self):
        """The origin of the coordinate system."""
        return self._origin

    @property
    def frame(self):
        """The path's frame."""
        return self._path.frame

    #######################################################
    # Math on the Track

    def __call__(
        self,
        affine: T.Optional[u.Quantity] = None,
        *,
        angular: bool = False,
    ) -> path_moments:
        """Get discrete points along interpolated stream track.

        Parameters
        ----------
        affine : `~astropy.units.Quantity` array-like or None, optional
            The affine interpolation parameter. If None (default), return
            path moments evaluated at all "tick" interpolation points.
        angular : bool, optional keyword-only
            Whether to compute on-sky or real-space.

        Returns
        -------
        `trackstream.utils.path.path_moments`
            Realized from the ``.path`` attribute.
        """
        return self.path(affine=affine, angular=angular)

    def probability(
        self,
        point: coord.SkyCoord,
        background_model=None,
        *,
        angular: bool = False,
        affine: T.Optional[u.Quantity] = None,
    ):
        """Probability point is part of the stream.

        .. todo:: angular probability

        Raises
        ------
        NotImplementedError
            Always; this method is not implemented yet.
        """
        # TODO: combine a background probability (from ``background_model``)
        # with a Gaussian in the separation of ``point`` from the closest
        # point on the path, scaled by the path width there.
        raise NotImplementedError("TODO!")

    #######################################################
    # misc

    def __repr__(self):
        """String representation."""
        frame_name = self.frame.__class__.__name__
        rep_name = self.track.representation_type.__name__

        text = super().__repr__()
        text = text.replace("StreamTrack", f"StreamTrack ({frame_name}|{rep_name})")

        # Append the data's repr, stripped of its first and last character.
        # Skipped when there is no data: ``repr(None)[1:-1]`` would
        # otherwise contribute the meaningless string "on".
        if self._stream_data is not None:
            text += "\n" + indent(repr(self._stream_data)[1:-1])

        return text
##############################################################################
# END