Source code for trackstream.tests.test_core

# -*- coding: utf-8 -*-

"""Tests for :mod:`trackstream.core`.

Covers :class:`~trackstream.core.TrackStream` and
:class:`~trackstream.core.StreamTrack`.
"""

__all__ = ["Test_TrackStream", "Test_StreamTrack"]


##############################################################################
# IMPORTS

# THIRD PARTY
import astropy.coordinates as coord
import astropy.units as u
import numpy as np
import pytest
from astropy.tests.helper import assert_quantity_allclose

# LOCAL
from trackstream.core import StreamTrack, TrackStream
from trackstream.example_data import get_example_pal5
from trackstream.preprocess.som import SelfOrganizingMap1D
from trackstream.stream import Stream

##############################################################################
# TESTS
##############################################################################


class Test_TrackStream:
    """Test :class:`~trackstream.core.TrackStream`."""

    @classmethod
    def setup_class(cls):
        """Set up class-level data shared by every test in this class.

        Loads the Palomar 5 example data and stores it, its ``origin`` (taken
        from the data's ``meta``), and the stream class under test on the class.
        """
        # NOTE(review): assigning ``cls.stream_cls`` replaces the ``stream_cls``
        # fixture method (defined below) on the class once setup has run; the
        # fixture body itself reads this attribute. The name is kept for
        # backward compatibility with any subclass that overrides it, but the
        # shared name is fragile -- consider renaming.
        cls.stream_cls = Stream
        cls.data = get_example_pal5()
        cls.origin = cls.data.meta["origin"]
        cls.data_err = None  # TODO? provide example uncertainties
        cls.frame = None

    @pytest.fixture(params=[(None, None), (True, None), (None, True), (True, True)])
    def tracker(self, request):
        """TrackStream instance, parametrized over the per-arm SOM arguments.

        A ``True`` entry presumably stands for "use a pre-made SOM" for that
        arm. That is not implemented yet, so it is replaced by ``None``, and
        all four parametrizations currently build the same tracker.
        """
        arm1SOM, arm2SOM = request.param
        if arm1SOM is True:  # TODO! have a test for a pre-made SOM
            arm1SOM = None
        if arm2SOM is True:  # TODO! have a test for a pre-made SOM
            arm2SOM = None
        return TrackStream(arm1SOM=arm1SOM, arm2SOM=arm2SOM)

    @pytest.fixture
    def tracker_cls(self):
        """The :class:`~trackstream.core.TrackStream` class under test."""
        return TrackStream

    @pytest.fixture
    def stream_cls(self):
        """Stream class (the value stored on the class by ``setup_class``)."""
        return self.stream_cls

    @pytest.fixture
    def stream(self, stream_cls):
        """Stream instance built from the example data loaded in ``setup_class``."""
        return stream_cls(
            self.data, self.origin, data_err=self.data_err, frame=self.frame
        )

    # ===============================================================
    # Method tests

    def test_init(self, tracker):
        """Test instantiation: cache is a dict and each arm SOM is None or a 1D SOM."""
        assert isinstance(tracker._cache, dict)
        assert isinstance(tracker._arm1_SOM, (type(None), SelfOrganizingMap1D))
        assert isinstance(tracker._arm2_SOM, (type(None), SelfOrganizingMap1D))

    # -------------------------------

    def test_fit(self, tracker, stream):
        """Test method ``fit`` returns a :class:`~trackstream.core.StreamTrack`."""
        track = tracker.fit(stream)
        assert isinstance(track, StreamTrack)

        # TODO! a lot more tests

    def test_predict(self, tracker):
        """Test method ``predict``.

        The test expects ``predict`` to be unavailable on the tracker, so
        calling it must raise :class:`AttributeError`.
        """
        arclength = np.linspace(0, 1)
        with pytest.raises(AttributeError):  # can't call what don't have
            tracker.predict(arclength)

    def test_fit_predict(self, tracker, stream):
        """Test method ``fit_predict``.

        The test expects ``fit_predict`` to be unavailable on the tracker, so
        calling it must raise :class:`AttributeError`.
        """
        arclength = np.linspace(0, 1)
        with pytest.raises(AttributeError):  # can't call what don't have
            tracker.fit_predict(stream, arclength)
# /class

# -------------------------------------------------------------------


class Test_StreamTrack:
    """Test :class:`~trackstream.core.StreamTrack`."""

    # NOTE(review): the commented-out ``setup_class`` below appears superseded
    # by the ``origin`` / ``track`` fixtures, which take ``scrd``, ``num``,
    # ``path`` and ``frame`` from fixtures defined elsewhere (presumably a
    # conftest). Kept for reference.
    # @classmethod
    # def setup_class(self):
    #     """Setup fixtures for testing."""
    #     # TODO! move to
    #     num = 40
    #     self.arclength = np.linspace(0, 10, num=num) * u.deg
    #
    #     lon = np.linspace(0, 25, num=num) * u.deg
    #     lat = np.linspace(-10, 10, num=num) * u.deg
    #     distance = np.linspace(8, 15, num=num) * u.kpc
    #
    #     self.data = coord.ICRS(
    #         coord.SphericalRepresentation(lon=lon, lat=lat, distance=distance),
    #     )
    #     self.interps = dict(
    #         lon=IUSU(self.arclength, lon),
    #         lat=IUSU(self.arclength, lat),
    #         distance=IUSU(self.arclength, distance),
    #     )
    #
    #     # origin
    #     i = num // 2
    #     self.origin = coord.ICRS(ra=lon[i], dec=lat[i])

    @pytest.fixture
    def origin(self, scrd, num):
        """Stream origin: the element of ``scrd`` at index ``num // 2``.

        Presumably ``num`` is the number of points in ``scrd``, making this
        the midpoint -- confirm against the fixture definitions.
        """
        return scrd[num // 2]

    @pytest.fixture
    def track_cls(self):
        """The :class:`~trackstream.core.StreamTrack` class under test."""
        return StreamTrack

    @pytest.fixture
    def track(self, track_cls, path, scrd, origin, frame):
        """StreamTrack instance.

        ``path`` and ``stream_data`` don't have to match up this nicely.
        """
        return track_cls(path, stream_data=scrd, origin=origin, frame=frame)

    @staticmethod
    def _expected_meta_attribute(track, name):
        """Return ``track.meta["__attributes__"][name]``, or None if either key is absent."""
        if "__attributes__" in track.meta and name in track.meta["__attributes__"]:
            return track.meta["__attributes__"][name]
        return None

    # ===============================================================
    # Method tests

    # def test_init(self, track_cls, frame):
    #     """Test instantiation."""
    #     track = StreamTrack(self.interps, stream_data=self.data,
    #                         origin=self.origin, frame=frame)
    #
    #     assert hasattr(track, "_data")
    #     assert hasattr(track, "_track")
    #     assert hasattr(track, "origin")
    #
    #     # --------------
    #     # Different argument types
    #
    #     # The data is an ICRS object
    #     # we must also test passing in a BaseRepresentation
    #     rep = self.data.represent_as(coord.SphericalRepresentation)
    #
    #     track = track_cls(self.interps, stream_data=rep, origin=self.origin)
    #     assert isinstance(track._data_frame, coord.BaseCoordinateFrame)
    #     assert track._data_rep == self.data.representation_type
    #
    #     # and a failed input type
    #     with pytest.raises(TypeError) as e:
    #         track_cls(None, None, None)
    #
    #     assert f"`stream_data` type <{type(None)}> is wrong." in str(e.value)

    def test_path(self, track):
        """Test ``path`` is the stored ``_path`` object."""
        assert track.path is track._path

    def test_track(self, track):
        """Test ``track`` is the data of ``path``."""
        assert track.track is track.path.data

    def test_affine(self, track):
        """Test ``affine`` is the affine parameter of ``path``."""
        assert track.affine is track.path.affine

    def test_stream_data(self, track):
        """Test ``stream_data`` is the stored ``_stream_data`` object."""
        assert track.stream_data is track._stream_data

    def test_origin(self, track):
        """Test ``origin`` is the stored ``_origin`` object."""
        assert track.origin is track._origin

    def test_frame(self, track):
        """Test ``frame`` is the frame of the stored path."""
        assert track.frame is track._path.frame

    def test_frame_fit(self, track):
        """Test ``frame_fit`` matches ``meta["__attributes__"]``, else None."""
        assert track.frame_fit is self._expected_meta_attribute(track, "frame_fit")

    def test_visit_order(self, track):
        """Test ``visit_order`` matches ``meta["__attributes__"]``, else None."""
        assert track.visit_order is self._expected_meta_attribute(track, "visit_order")

    def test_som(self, track):
        """Test ``som`` matches ``meta["__attributes__"]``, else None."""
        assert track.som is self._expected_meta_attribute(track, "som")

    def test_kalman(self, track):
        """Test ``kalman`` matches ``meta["__attributes__"]``, else None."""
        # Previously checked for the "som" key before reading "kalman".
        assert track.kalman is self._expected_meta_attribute(track, "kalman")

    def test_call(self, track, scrd, affine):
        """Test call method: evaluating at ``affine`` reproduces ``scrd``."""
        mean, _width = track(affine)

        assert isinstance(mean.frame, coord.ICRS)
        assert mean.representation_type == coord.SphericalRepresentation
        assert_quantity_allclose(mean.ra, scrd.ra, atol=1e-15 * u.deg)
        assert_quantity_allclose(mean.dec, scrd.dec, atol=1e-15 * u.deg)
        assert_quantity_allclose(mean.distance, scrd.distance, atol=1e-15 * u.kpc)

    def test_repr(self, track):
        """Test that the modified __repr__ method works."""
        s = track.__repr__()

        frame_name = track.frame.__class__.__name__
        rep_name = track.track.representation_type.__name__
        assert f"StreamTrack ({frame_name}|{rep_name})" in s
############################################################################## # END