Source code for trackstream.stream

# -*- coding: utf-8 -*-

"""Core functions: the stellar-stream container.

Defines :class:`Stream`, which holds the normalized data table of a stellar
stream and caches the result of fitting a track to it, together with the
helper :class:`StreamArmDescriptor` used for the per-arm accessors.
"""

# Only `Stream` is part of the public (star-import) API of this module.
__all__ = ["Stream"]


##############################################################################
# IMPORTS

# STDLIB
import typing as T
import weakref

# THIRD PARTY
import astropy.coordinates as coord
import astropy.units as u
import numpy as np
from astropy.table import Column, QTable, Table
from astropy.utils.decorators import lazyproperty

# LOCAL
from trackstream._type_hints import CoordinateType, FrameType
from trackstream.preprocess.som import SelfOrganizingMap1D
from trackstream.utils.path import path_moments

##############################################################################
# CODE
##############################################################################


class StreamArmDescriptor:
    """Descriptor giving access to a single arm (tail) of a stream.

    An instance is declared as a class attribute on a parent class. The
    attribute name (recorded in ``__set_name__``) doubles as the label that is
    compared against the parent's ``data["tail"]`` column to select the rows
    belonging to this arm.

    When accessed through a parent instance, a per-instance copy of the
    descriptor is created, stored in the instance ``__dict__`` under the same
    attribute name, and bound to the instance with a weak reference.
    """

    def __init__(self) -> None:
        # references to parent class and instance
        self._parent_attr = None  # set in __set_name__
        self._parent_cls = None
        self._parent_ref = None

    @property
    def _parent(self):
        """Parent instance if one is bound, else the parent class.

        If the weakly-referenced parent instance has been garbage collected
        this evaluates to `None`.
        """
        return self._parent_ref() if self._parent_ref is not None else self._parent_cls

    # ------------------------------------

    def __set_name__(self, objcls, name):
        """Record the attribute name this descriptor was assigned to."""
        self._parent_attr = name

    def __get__(self, obj, objcls=None):
        """Return the descriptor bound to ``obj`` (or ``self`` for the class).

        Parameters
        ----------
        obj : object or None
            The parent instance, or `None` when accessed from the class.
        objcls : type, optional
            The parent class.

        Returns
        -------
        StreamArmDescriptor
            ``self`` when accessed from the class, otherwise the per-instance
            descriptor stored in ``obj.__dict__``.
        """
        # accessed from a class
        if obj is None:
            self._parent_cls = objcls
            return self

        # accessed from an obj: reuse the per-instance copy when it exists.
        # Previously the local was only bound inside the creation branch, so
        # a repeated ``__get__`` for the same instance raised
        # ``UnboundLocalError``.
        descriptor = obj.__dict__.get(self._parent_attr)
        if descriptor is None:  # hasn't been created on the obj
            descriptor = self.__class__()
            descriptor._parent_cls = obj.__class__
            descriptor._parent_attr = self._parent_attr
            obj.__dict__[self._parent_attr] = descriptor

        # We set `_parent_ref` on every call, since if one makes copies of objs,
        # 'descriptor' will be copied as well, which will lose the reference.
        descriptor._parent_ref = weakref.ref(obj)
        return descriptor

    # ------------------------------------

    @property
    def index(self) -> "Column":
        """Element-wise mask: parent ``data["tail"]`` equal to this arm's name."""
        return self._parent.data["tail"] == self._parent_attr

    @property
    def has_data(self):
        """Whether any entry of the parent's data is labelled as this arm."""
        return any(self.index)

    @property
    def data(self) -> "Column":
        """The subset of the parent's data belonging to this arm.

        Raises
        ------
        Exception
            If no data is labelled as this arm.
        """
        if not self.has_data:
            # TODO! specific exception
            # Name the arm actually being accessed; the message used to say
            # "arm 1" regardless of which arm was missing.
            raise Exception(f"no data for arm {self._parent_attr!r}")
        return self._parent.data[self.index]

    @property
    def coords(self) -> "coord.SkyCoord":
        """The coordinates of the arm."""
        arm: coord.SkyCoord
        arm = self._parent.coords[self.index]
        return arm


# /class


class Stream:
    """A Stellar Stream.

    Parameters
    ----------
    data : `~astropy.table.Table`
        The stream data. Must contain a "tail" column labelling each row with
        the arm it belongs to ("arm1" or "arm2"), the error columns
        ["x_err", "y_err", "z_err"], and either a "coord" column or columns
        from which a `~astropy.coordinates.SkyCoord` can be guessed.
    origin : `~astropy.coordinates.ICRS`
        The origin point of the rotated reference frame.
    data_err : `~astropy.table.QTable` (optional)
        The data_err must have (at least) column names
        ["x_err", "y_err", "z_err"].
        NOTE: currently unused; the error columns are read from ``data``.
    frame : `~astropy.coordinates.BaseCoordinateFrame` or None (optional, keyword-only)
        The stream frame. Locally linearizes the data.
        If None (default), the frame needs to be fit.
    """

    arm1 = StreamArmDescriptor()
    arm2 = StreamArmDescriptor()

    # ===============================================================

    def __init__(
        self,
        data: QTable,
        origin: FrameType,
        data_err: T.Optional[Table] = None,
        *,
        frame: T.Optional[CoordinateType] = None,
    ):
        # system attributes
        self.origin: coord.SkyCoord = coord.SkyCoord(origin, copy=False)
        self._system_frame: T.Optional[FrameType] = frame

        # holds the fit results ("frame", "tracker", "track")
        self._cache = dict()  # TODO! improve

        # ----------
        # process the data

        # seed values set in functions
        self._original_data: coord.SkyCoord = None

        # processed data -> QTable
        self.data: QTable = self._normalize_data(data)
        self._data_max_lines = 10

    # -----------------------------------------------------

    @property
    def system_frame(self) -> T.Optional[coord.BaseCoordinateFrame]:
        """A system-centric frame.

        Determined from the argument ``frame`` at initialization.
        If None (default) and the method ``fit`` has been called,
        then a system frame has been found and cached.
        """
        frame: T.Optional[coord.BaseCoordinateFrame]
        if self._system_frame is not None:
            frame = self._system_frame
        else:
            frame = self._cache.get("frame")

        return frame

    @property
    def frame(self) -> T.Optional[coord.BaseCoordinateFrame]:
        """Alias for ``system_frame``."""
        return self.system_frame

    @lazyproperty
    def number_of_tails(self) -> int:
        """Number of tidal tails.

        Returns
        -------
        number_of_tails : int
            There can only be 1, or 2 tidal tails.
        """
        return 2 if (self.arm1.has_data and self.arm2.has_data) else 1

    @property  # TODO! make lazy
    def coords(self) -> coord.SkyCoord:
        """Coordinates.

        The data coordinates transformed to the system frame, or to the
        frame of the data when no system frame is available.
        """
        if self.system_frame is not None:
            frame = self.system_frame
        else:
            frame = self.data_frame

        return self.data_coords.transform_to(frame)

    # ===============================================================

    @property
    def data_coords(self) -> coord.SkyCoord:
        """Get ``coord`` from data table."""
        return self.data["coord"]

    @property
    def data_frame(self) -> FrameType:
        """The frame of the data."""
        return self.data_coords.frame.replicate_without_data()

    # ===============================================================
    # Data normalization

    def _normalize_data(self, original: Table) -> QTable:
        """Normalize data table.

        Just calls other functions.

        Parameters
        ----------
        original : :class:`~astropy.table.Table`

        Returns
        -------
        data : :class:`~astropy.table.QTable`
        """
        data = QTable()  # going to be assigned in-place

        # 1) data probability
        self._normalize_data_probability(original, data, default_weight=1)

        # 2) coordinates. `data` modded in-place
        self._normalize_data_coordinates(original, data)

        # 3) SOM ordering
        self._normalize_data_arm_index(original, data)

        # Metadata
        # TODO? selective, or just copy over?
        data.meta = original.meta.copy()  # TODO? deepcopy?

        return data

    def _normalize_data_probability(
        self,
        original: Table,
        data: QTable,
        default_weight: T.Union[float, u.Quantity] = 1.0,
    ) -> None:
        """Data probability. Units of percent. Default is 100%.

        Parameters
        ----------
        original : |Table|
            The original data. A column named "pmemb" (matched
            case-insensitively) is used if present.
        data : |QTable|
            The normalized data. Column "Pmemb" is added in-place.
        default_weight : float
            The default membership probability.
            If float, then range 0-1 maps to 0-100%.
            If has unit of percent, then unchanged
        """
        colns = [n.lower() for n in original.colnames]

        if "pmemb" in colns:
            Pmemb = original[original.colnames[colns.index("pmemb")]]
        else:
            Pmemb = np.ones(len(original)) * default_weight  # non-scalar

        data["Pmemb"] = u.Quantity(Pmemb).to(u.percent)  # in %

    def _normalize_data_coordinates(
        self,
        original: Table,
        data: QTable,
    ) -> None:
        """Parse the coordinates (and errors) of the data table.

        Adds to ``data``, in-place, the columns "tail", "coord", one column
        per Cartesian component of the coordinates, and the error columns
        "x_err", "y_err", "z_err". The coordinates as originally given are
        stored on ``self._original_data``.

        Parameters
        ----------
        original : |Table|
            The stream data. Must have the columns "tail", "x_err", "y_err"
            and "z_err". If there is no "coord" column the coordinates are
            guessed from the table.
        data : |QTable|
            The normalized data, modified in-place.
        """
        # TODO!!! better
        data["tail"] = original["tail"]

        # ----------
        # 1) the data

        # First look for a column "coord"
        if "coord" in original.colnames:
            sc = osc = coord.SkyCoord(original["coord"], copy=False)
        else:
            sc = osc = coord.SkyCoord.guess_from_table(original)

        self._original_data = osc

        # Convert frame and representation type
        frame = (
            self.system_frame if self.system_frame is not None else osc.frame.replicate_without_data()
        )
        sc = sc.transform_to(frame)
        sc.representation_type = coord.CartesianRepresentation

        # it's now clean and can be added
        data["coord"] = sc

        # also want to store the components, for plotting
        compnames = sc.representation_component_names.keys()
        for n in compnames:
            data[n] = getattr(sc, n)

        # ----------
        # 2) the error
        # TODO! want errors in frame of the data

        err_cols = ["x_err", "y_err", "z_err"]
        for n in err_cols:
            data[n] = original[n]  # transfer

        # ----------

        # NOTE(review): this only rebinds the local name ``data``; the table
        # held by the caller is presumably neither grouped nor indexed as a
        # result. Left as-is because making it effective would reorder rows
        # relative to the "SOM" column added afterwards -- confirm intent.
        data = data.group_by("tail")
        data.add_index("tail")

    def _normalize_data_arm_index(
        self,
        original: Table,
        data: QTable,
    ) -> None:
        """Transfer the SOM ordering column, if present.

        Parameters
        ----------
        original : |Table|
            The original data. Column "SOM" is copied if it exists.
        data : |QTable|
            The normalized data. Column "SOM" is added in-place; it is set to
            `None` when ``original`` has no such column.
        """
        if "SOM" in original.colnames:
            data["SOM"] = original["SOM"]
        else:
            data["SOM"] = None

    # ===============================================================
    # Fitting

    @property
    def track(self) -> "StreamTrack":  # noqa: F821
        """Stream track.

        Raises
        ------
        ValueError
            If track is not fit.
        """
        track = self._cache.get("track")
        if track is None:
            raise ValueError("need to fit track.")
        return track

    def fit_track(
        self,
        arm1SOM: T.Optional[SelfOrganizingMap1D] = None,
        arm2SOM: T.Optional[SelfOrganizingMap1D] = None,
        *,
        force: bool = False,
        **kwargs,
    ) -> "StreamTrack":  # noqa: F821
        """Make a stream track.

        Parameters
        ----------
        arm1SOM, arm2SOM : `~trackstream.preprocess.SelfOrganizingMap` (optional)
            Fiducial SOMs for stream arms 1 and 2, respectively.
        force : bool (optional, keyword-only)
            Whether to force a fit, even if already fit.
        **kwargs
            Passed to :meth:`trackstream.TrackStream.fit`.

        Returns
        -------
        `trackstream.StreamTrack`

        Raises
        ------
        Exception
            If a track was already fit and ``force`` is False.
        """
        if not force and "tracker" in self._cache:
            raise Exception("already fit. use ``force`` to re-fit.")

        # LOCAL
        from trackstream.core import StreamTrack, TrackStream

        self._cache["tracker"] = tracker = TrackStream(arm1SOM=arm1SOM, arm2SOM=arm2SOM)

        track: StreamTrack = tracker.fit(self, **kwargs)
        self._cache["track"] = track

        # Add SOM ordering to data
        self.data["SOM"] = np.empty(len(self.data), dtype=int)
        self.data["SOM"][self.arm1.index] = tracker._cache["arm1_visit_order"]
        self.data["SOM"][self.arm2.index] = tracker._cache["arm2_visit_order"]

        return track

    # ===============================================================
    # Math on the Track (requires fitting track)

    def predict_track(
        self,
        affine: T.Optional[u.Quantity] = None,
        angular: bool = False,
    ) -> path_moments:
        """Evaluate the fitted track.

        Parameters
        ----------
        affine : |Quantity| or None, optional
            Forwarded to the track.
        angular : bool, optional
            Forwarded to the track.

        Returns
        -------
        `trackstream.utils.path.path_moments`

        Raises
        ------
        ValueError
            If track is not fit.
        """
        # Forward the arguments; they used to be accepted but silently
        # dropped. NOTE(review): assumes the fitted track is callable with
        # ``affine`` and ``angular`` keywords -- confirm against StreamTrack.
        return self.track(affine=affine, angular=angular)

    # ===============================================================
    # Misc

    def _base_repr_(self, max_lines=None):
        """mirroring implementation in astropy Table.

        Parameters
        ----------
        max_lines : int or None, optional
            Maximum number of lines of the data table to show. If None
            (default) ``self._data_max_lines`` is used.
        """
        # honor the argument (it used to be ignored)
        if max_lines is None:
            max_lines = self._data_max_lines

        header: str = super().__repr__()
        frame: str = repr(self.frame)

        datarep: str = self.data._base_repr_(
            html=False,
            max_width=None,
            max_lines=max_lines,
        )
        # drop the first line of the table repr and indent the rest
        table: str = "\n\t".join(datarep.split("\n")[1:])

        return header + "\n  Frame:\n\t" + frame + "\n  Data:\n\t" + table

    def __repr__(self) -> str:
        return self._base_repr_(max_lines=self._data_max_lines)
############################################################################## # END