Skip to content

Commit

Permalink
Make StokesFrame use new StokesCoord from Astropy (#452)
Browse files Browse the repository at this point in the history
  • Loading branch information
nden authored Jun 11, 2023
2 parents dd12637 + d737e7b commit e7056a9
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 80 deletions.
75 changes: 11 additions & 64 deletions gwcs/coordinate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from astropy.wcs.wcsapi.low_level_api import (validate_physical_types,
VALID_UCDS)
from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1
from astropy.coordinates import StokesCoord

__all__ = ['Frame2D', 'CelestialFrame', 'SpectralFrame', 'CompositeFrame',
'CoordinateFrame', 'TemporalFrame']
'CoordinateFrame', 'TemporalFrame', 'StokesFrame']


def _ucd1_to_ctype_name_mapping(ctype_to_ucd, allowed_ucd_duplicates):
Expand Down Expand Up @@ -703,68 +704,16 @@ def _world_axis_object_classes(self):
return dict(self._wao_renamed_classes_iter)


class StokesProfile(str):
# This list of profiles in Table 7 in Greisen & Calabretta (2002)
# modified to be 0 indexed
profiles = {
'I': 0,
'Q': 1,
'U': 2,
'V': 3,
'RR': -1,
'LL': -2,
'RL': -3,
'LR': -4,
'XX': -5,
'YY': -6,
'XY': -7,
'YX': -8,
}

@classmethod
def from_index(cls, indexes):
"""
Construct a StokesProfile object from a numerical index.
Parameters
----------
indexes : `int`, `numpy.ndarray`
An index or array of indices to construct StokesProfile objects from.
"""

nans = np.isnan(indexes)
indexes = np.asarray(indexes, dtype=int)
out = np.empty_like(indexes, dtype=object)

for profile, index in cls.profiles.items():
out[indexes == index] = cls(profile)

out[nans] = np.nan

if out.size == 1 and not nans:
return StokesProfile(out.item())
elif nans.all():
return np.array(out, dtype=float)
return out

def __new__(cls, content):
content = str(content)
if content not in cls.profiles.keys():
raise ValueError(f"The profile name must be one of {cls.profiles.keys()} not {content}")
return str.__new__(cls, content)

def value(self):
return self.profiles[self]


class StokesFrame(CoordinateFrame):
"""
A coordinate frame for representing stokes polarisation states
A coordinate frame for representing Stokes polarisation states.
Parameters
----------
name : str
Name of this frame.
axes_order : tuple
A dimension in the data that corresponds to this axis.
"""

def __init__(self, axes_order=(0,), name=None):
Expand All @@ -775,10 +724,10 @@ def __init__(self, axes_order=(0,), name=None):
@property
def _world_axis_object_classes(self):
return {'stokes': (
StokesProfile,
StokesCoord,
(),
{},
StokesProfile.from_index)}
)}

@property
def _world_axis_object_components(self):
Expand All @@ -790,14 +739,12 @@ def coordinates(self, *args):
else:
arg = args[0]

return StokesProfile.from_index(arg)
return StokesCoord(arg)

def coordinate_to_quantity(self, *coords):
if isinstance(coords[0], str):
if coords[0] in StokesProfile.profiles.keys():
return StokesProfile.profiles[coords[0]] * u.one
else:
return coords[0]
if isinstance(coords[0], StokesCoord):
return coords[0].value << u.one
return coords[0]


class Frame2D(CoordinateFrame):
Expand Down
2 changes: 1 addition & 1 deletion gwcs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def gwcs_simple_imaging_units():

@pytest.fixture
def gwcs_stokes_lookup():
transform = models.Tabular1D([0, 1, 2, 3] * u.pix, [0, 1, 2, 3] * u.one,
transform = models.Tabular1D([0, 1, 2, 3] * u.pix, [1, 2, 3, 4] * u.one,
method="nearest", fill_value=np.nan, bounds_error=False)
frame = cf.StokesFrame()

Expand Down
1 change: 1 addition & 0 deletions gwcs/tests/data/stokes.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
WCSAXES = 4 / Number of coordinate axes CRPIX1 = 126.0 / Pixel coordinate of reference point CRPIX2 = 126.0 / Pixel coordinate of reference point CRPIX3 = 1.0 / Pixel coordinate of reference point CRPIX4 = 1.0 / Pixel coordinate of reference point CDELT1 = -2.777777777778E-05 / [deg] Coordinate increment at reference point CDELT2 = 2.777777777778E-05 / [deg] Coordinate increment at reference point CDELT3 = 20000205938.09 / [Hz] Coordinate increment at reference point CDELT4 = 1.0 / Coordinate increment at reference point CUNIT1 = 'deg' / Units of coordinate increment and value CUNIT2 = 'deg' / Units of coordinate increment and value CUNIT3 = 'Hz' / Units of coordinate increment and value CTYPE1 = 'RA---SIN' / Right ascension, orthographic/synthesis projectCTYPE2 = 'DEC--SIN' / Declination, orthographic/synthesis projection CTYPE3 = 'FREQ' / Frequency (linear) CTYPE4 = 'STOKES' / Coordinate type code CRVAL1 = 202.78453375 / [deg] Coordinate value at reference point CRVAL2 = 30.50915555556 / [deg] Coordinate value at reference point CRVAL3 = 233000102969.0 / [Hz] Coordinate value at reference point CRVAL4 = 1.0 / Coordinate value at reference point PV2_1 = 0.0 / SIN projection parameter PV2_2 = 0.0 / SIN projection parameter LONPOLE = 180.0 / [deg] Native longitude of celestial pole LATPOLE = 30.50915555556 / [deg] Native latitude of celestial pole RESTFRQ = 224000000000.1 / [Hz] Line rest frequency TIMESYS = 'UTC' / Time scale MJDREF = 0.0 / [d] MJD of fiducial time DATE-OBS= '2014-07-01T21:36:05.280000' / ISO-8601 time of observation MJD-OBS = 56839.900061111 / [d] MJD of observation OBSGEO-X= 2225142.180269 / [m] observatory X-coordinate OBSGEO-Y= -5440307.370349 / [m] observatory Y-coordinate OBSGEO-Z= -2481029.851874 / [m] observatory Z-coordinate RADESYS = 'FK5' / Equatorial coordinate system EQUINOX = 2000.0 / [yr] Equinox of equatorial coordinates SPECSYS = 'TOPOCENT' / Reference frame of spectral coordinates END
18 changes: 11 additions & 7 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ def _compare_frame_output(wc1, wc2):
elif isinstance(wc1, str):
assert wc1 == wc2

elif isinstance(wc1, coord.StokesCoord):
assert wc1 == wc2

else:
assert False, f"Can't Compare {type(wc1)}"

Expand Down Expand Up @@ -294,29 +297,30 @@ def test_stokes_wrapper(gwcs_stokes_lookup):

out = hlvl.pixel_to_world(pixel_input*u.pix)

expected = np.array([['I', 'Q', 'U', 'V'],
['I', 'Q', 'U', 'V'],
['I', 'Q', 'U', 'V'],
['I', 'Q', 'U', 'V']], dtype=object)
expected = coord.StokesCoord([['I', 'Q', 'U', 'V'],
['I', 'Q', 'U', 'V'],
['I', 'Q', 'U', 'V'],
['I', 'Q', 'U', 'V']])

assert (out == expected).all()

pixel_input = [-1, 4]

out = hlvl.pixel_to_world(pixel_input*u.pix)

assert np.isnan(out).all()
assert np.isnan(out.value).all()

pixel_input = [[-1, 4],
[1, 2]]

out = hlvl.pixel_to_world(pixel_input*u.pix)

assert np.isnan(np.array(out[0], dtype=float)).all()
assert (out[1] == np.array(['Q', 'U'], dtype=object)).all()
assert np.isnan(out[0].value).all()
assert (out[1] == ['Q', 'U']).all()

out = hlvl.pixel_to_world(1*u.pix)

assert isinstance(out, coord.StokesCoord)
assert out == 'Q'


Expand Down
12 changes: 5 additions & 7 deletions gwcs/tests/test_coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from astropy.tests.helper import assert_quantity_allclose
from astropy.modeling import models as m
from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1
from astropy.coordinates import StokesCoord

from .. import WCS
from .. import coordinate_frames as cf
Expand Down Expand Up @@ -281,13 +282,10 @@ def test_coordinate_to_quantity_composite(inp):
def test_stokes_frame():
sf = cf.StokesFrame()

assert sf.coordinates(0) == 'I'
assert sf.coordinates(0 * u.pix) == 'I'
assert sf.coordinate_to_quantity('I') == 0 * u.one
assert sf.coordinate_to_quantity(0) == 0

def test_stokes_profile():
assert (cf.StokesProfile.from_index(np.arange(-8, 4) * u.one) == np.array(['YX', 'XY', 'YY,', 'XX', 'LR', 'RL', 'LL', 'RR', 'I', 'Q', 'U', 'V'], dtype="U2")).all()
assert sf.coordinates(1) == 'I'
assert sf.coordinates(1 * u.pix) == 'I'
assert sf.coordinate_to_quantity(StokesCoord('I')) == 1 * u.one
assert sf.coordinate_to_quantity(1) == 1


@pytest.mark.parametrize('inp', [
Expand Down
46 changes: 46 additions & 0 deletions gwcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
focal = cf.Frame2D(name='focal', axes_order=(0, 1), unit=(u.m, u.m))
spec = cf.SpectralFrame(name='wave', unit=[u.m, ], axes_order=(2, ), axes_names=('lambda', ))
time = cf.TemporalFrame(name='time', unit=[u.s, ], axes_order=(3, ), axes_names=('time', ), reference_frame=Time("2020-01-01"))
stokes = cf.StokesFrame(axes_order=(2,))

pipe = [wcs.Step(detector, m1),
wcs.Step(focal, m2),
Expand Down Expand Up @@ -223,6 +224,15 @@ def test_return_coordinates():
output_quant = w.output_frame.coordinate_to_quantity(*num_plus_output)
assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result)

# CompositeFrame - [celestial, Stokes]
output_frame = cf.CompositeFrame(frames=[icrs, stokes])
transform = m1 & models.Identity(1)
w = wcs.WCS(forward_transform=transform, output_frame=output_frame)
numerical_result = transform(x, y, y)
num_plus_output = w(x, y, y, with_units=True)
output_quant = w.output_frame.coordinate_to_quantity(*num_plus_output)
assert_allclose(utils.get_values(w.unit, *output_quant), numerical_result)


def test_from_fiducial_sky():
sky = coord.SkyCoord(1.63 * u.radian, -72.4 * u.deg, frame='fk5')
Expand Down Expand Up @@ -1270,3 +1280,39 @@ def test_sip_roundtrip():
atol=0.0,
rtol=1.0e-8 * 10**(i + j)
)


def test_spatial_spectral_stokes():
""" Converts a FITS WCS to GWCS and compares results."""
hdr = fits.Header.fromfile(get_pkg_data_filename("data/stokes.txt"))
aw = astwcs.WCS(hdr)
crpix = aw.wcs.crpix
crval = aw.wcs.crval
cdelt = aw.wcs.cdelt

fk5 = cf.CelestialFrame(reference_frame=coord.FK5(), name='FK5')
detector = cf.Frame2D(name='detector', axes_order=(0, 1))
spec = cf.SpectralFrame(name='FREQ', unit=[u.Hz, ], axes_order=(2, ), axes_names=('freq', ))
stokes = cf.StokesFrame(axes_order=(3,))
world = cf.CompositeFrame(frames=[fk5, spec, stokes])

det2sky = (models.Shift(-crpix[0]) & models.Shift(-crpix[1]) |
models.Scale(cdelt[0]) & models.Scale(cdelt[1]) |
models.Pix2Sky_SIN() | models.RotateNative2Celestial(crval[0], crval[1], 180))
det2freq = models.Shift(-crpix[2]) | models.Scale(cdelt[2]) | models.Shift(crval[2])
det2stokes = models.Shift(-crpix[3]) | models.Scale(cdelt[3]) | models.Shift(crval[3])

gw = wcs.WCS([wcs.Step(detector, det2sky & det2freq & det2stokes),
wcs.Step(world, None)]
)

x1 = np.array([0, 0, 0, 0, 0])
x2 = np.array([0, 1, 2, 3, 4])

gw_sky, gw_spec, gw_stokes = gw.pixel_to_world(x1+1, x1+1, x1+1, x2+1)
aw_sky, aw_spec, aw_stokes = aw.pixel_to_world(x1, x1, x1, x2)

assert_allclose(gw_sky.data.lon, aw_sky.data.lon)
assert_allclose(gw_sky.data.lat, aw_sky.data.lat)
assert_allclose(gw_spec.value, aw_spec.value)
assert_allclose(gw_stokes.value, aw_stokes.value)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ authors = [
]
dependencies = [
"asdf >= 2.8.1",
"astropy >= 5.1",
"astropy >= 5.3",
"numpy",
"scipy",
"asdf_wcs_schemas",
Expand Down

0 comments on commit e7056a9

Please sign in to comment.