Source code for trackstream.tests.test_stream
# -*- coding: utf-8 -*-
"""Testing :mod:`~trackstream.stream`."""
__all__ = [
"TestStream",
]
##############################################################################
# IMPORTS
# THIRD PARTY
import astropy.coordinates as coord
import astropy.table as table
import numpy as np
import pytest
# LOCAL
from trackstream.example_data import get_example_pal5
from trackstream.stream import Stream
##############################################################################
# TESTS
##############################################################################
[docs]class TestStream:
"""Test :class:`trackstream.stream.Stream`."""
[docs] @classmethod
def setup_class(self):
"""Setup fixtures for testing."""
self.stream_cls = Stream
self.data = get_example_pal5()
self.origin = self.data.meta["origin"]
self.data_err = None # TODO?
self.frame = None
[docs] @pytest.fixture
def stream_cls(self):
"""Stream class."""
return self.stream_cls
[docs] @pytest.fixture(params=[None, True])
def stream(self, stream_cls, request):
"""Stream instance."""
frame = self.frame if request.param is True else request.param
return stream_cls(self.data, self.origin, data_err=self.data_err, frame=frame)
# ===============================================================
[docs] def test_init_fail_numargs(self, stream_cls):
"""Test with wrong number of arguments. A reminder to include ``origin``."""
with pytest.raises(TypeError, match="origin"):
stream_cls(self.data)
[docs] def test_init(self, stream_cls):
"""Test initialization."""
stream = stream_cls(self.data, self.origin, data_err=self.data_err, frame=self.frame)
# origin
assert isinstance(stream.origin, coord.SkyCoord)
assert stream.origin == self.origin
# frame
assert stream._system_frame is self.frame
# cache
assert stream._cache == dict()
# data
# assert stream._original_data is None # NOT! it's processed
assert isinstance(stream.data, table.Table)
# -------------------------------------------
# arm1
[docs] def test_arm1_index(self, stream):
expected = stream.data["tail"] == "arm1"
got = stream.arm1.index
assert all(got == expected)
[docs] def test_arm1_has_data(self, stream):
expected = any(stream.data["tail"] == "arm1")
got = stream.arm1.has_data
assert got == expected
[docs] def test_arm1_data(self, stream):
if not stream.arm1.has_data:
with pytest.raises(Exception, match="no arm 1"):
stream.arm1.data
else:
index = stream.data["tail"] == "arm1"
expected = stream.data[index]
got = stream.arm1.data
assert all(got == expected)
[docs] def test_arm1_coords(self, stream):
index = stream.data["tail"] == "arm1"
expected = stream.coords[index]
got = stream.arm1.coords
assert all(got == expected)
# -------------------------------------------
# arm2
[docs] def test_arm2_index(self, stream):
expected = stream.data["tail"] == "arm2"
got = stream.arm2.index
assert all(got == expected)
[docs] def test_arm2_has_data(self, stream):
expected = any(stream.data["tail"] == "arm2")
got = stream.arm2.has_data
assert got == expected
[docs] def test_arm2_data(self, stream):
if not stream.arm2.has_data:
with pytest.raises(Exception, match="no arm 1"):
stream.arm2.data
else:
index = stream.data["tail"] == "arm2"
expected = stream.data[index]
got = stream.arm2.data
assert all(got == expected)
[docs] def test_arm2_coords(self, stream):
index = stream.data["tail"] == "arm2"
expected = stream.coords[index]
got = stream.arm2.coords
assert all(got == expected)
# -------------------------------------------
[docs] def test_system_frame(self, stream):
"""Test system-centric frame."""
# if passed a frame at initialization
if stream._system_frame is not None: # if passed frame at init
assert stream.system_frame is self._system_frame
else:
assert stream.system_frame is None # == stream._cache.get("frame", None)
# TODO! have test for fit streams, where it isn't None
[docs] def test_frame(self, stream):
"""Test attribute ``frame``."""
assert stream.frame is stream.system_frame
[docs] def test_number_of_tails(self, stream):
expect = 2 if (stream.arm1.has_data and stream.arm2.has_data) else 1
assert stream.number_of_tails == expect
[docs] def test_coords(self, stream):
assert isinstance(stream.coords, coord.SkyCoord)
frame = stream.system_frame if stream.system_frame is not None else stream.data_frame
assert np.all(stream.coords == stream.data_coords.transform_to(frame))
# -------------------------------------------
[docs] def test_data_coords(self, stream):
assert np.all(stream.data_coords == stream.data["coord"])
[docs] def test_data_frame(self, stream):
assert np.all(stream.data_frame == stream.data_coords.frame.replicate_without_data())
# -------------------------------------------
[docs] @pytest.mark.skip("TODO!")
def test_normalize_data(self, stream_cls):
assert False
# -------------------------------------------
[docs] def test_track(self, stream):
# Different test if not already fit a track
if "track" not in stream._cache:
with pytest.raises(ValueError, match="need to fit track"):
stream.track
stream.fit_track()
# now guaranteed to have a working track
# assert isinstance(stream.track, StreamTrack)
# TODO! more tests
[docs] def test_fit_track(self, stream):
"""Test fit stream track."""
if "track" not in stream._cache:
stream.fit_track(force=True)
# a track is already fit
# not force
with pytest.raises(Exception, match="already fit"):
stream.fit_track()
# testing a forced fit
track = stream.fit_track(force=True)
assert track is stream._cache["track"]
# TODO! more tests
# ===============================================================
# Test Usage
[docs] @pytest.mark.skip("TODO!")
def test_loading_pal5(self, stream_cls):
data = get_example_pal5()
stream = Stream(data, data.meta["origin"])
assert stream
assert False
# /class
##############################################################################
# END