Skip to content

Commit 71aebc4

Browse files
authored
Merge pull request #3492 from snbianco/select-all-cols
More powerful column selection in `MastMissions`
2 parents 003f177 + 85e64c5 commit 71aebc4

6 files changed

Lines changed: 196 additions & 28 deletions

File tree

CHANGES.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ mast
7171

7272
- Improved robustness of PanSTARRS column metadata parsing. This prevents metadata-related query errors. [#3485]
7373

74+
- The ``select_cols`` parameter in ``MastMissions`` query functions now accepts an iterable of column names, a comma-delimited
75+
string of column names, or the special values 'all' or '*' to return all available columns. [#3492]
76+
7477
jplspec
7578
^^^^^^^
7679

astroquery/mast/missions.py

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import difflib
1010
import warnings
11+
from collections.abc import Iterable
1112
from json import JSONDecodeError
1213
from pathlib import Path
1314
from urllib.parse import quote
@@ -21,7 +22,7 @@
2122
from astroquery import log
2223
from astroquery.utils import commons, async_to_sync
2324
from astroquery.utils.class_or_instance import class_or_instance
24-
from astroquery.exceptions import InvalidQueryError, MaxResultsWarning, NoResultsWarning
25+
from astroquery.exceptions import InputWarning, InvalidQueryError, MaxResultsWarning, NoResultsWarning
2526

2627
from astroquery.mast import utils
2728
from astroquery.mast.core import MastQueryWithLogin
@@ -43,7 +44,7 @@ class MastMissionsClass(MastQueryWithLogin):
4344
_list_products = 'post_list_products'
4445

4546
# Workaround so that observation_id is returned in ULLYSES queries that do not specify columns
46-
_default_ullyses_cols = ['target_name_ulysses', 'target_classification', 'targ_ra', 'targ_dec', 'host_galaxy_name',
47+
_default_ullyses_cols = ['target_name_ullyses', 'target_classification', 'targ_ra', 'targ_dec', 'host_galaxy_name',
4748
'spectral_type', 'bmv0_mag', 'u_mag', 'b_mag', 'v_mag', 'gaia_g_mean_mag', 'star_mass',
4849
'instrument', 'grating', 'filter', 'observation_id']
4950

@@ -197,6 +198,71 @@ def _build_params_from_criteria(self, params, **criteria):
197198
value = [value]
198199
params[prop] = value
199200

201+
def _parse_select_cols(self, select_cols):
202+
"""
203+
Parse the select_cols parameter to ensure it is in the correct format.
204+
205+
Parameters
206+
----------
207+
select_cols : iterable or str or None
208+
The select_cols parameter to parse.
209+
210+
Returns
211+
-------
212+
list
213+
A list of column names to select.
214+
215+
Raises
216+
------
217+
InvalidQueryError
218+
If select_cols is not an iterable of strings, a comma-separated string, 'all', or '*'.
219+
If any individual column name is not a string.
220+
"""
221+
if select_cols is None:
222+
if self.mission == 'ullyses':
223+
select_cols = self._default_ullyses_cols
224+
return select_cols
225+
226+
# Handle special string cases first
227+
all_columns = self.get_column_list()['name'].value.tolist()
228+
if isinstance(select_cols, str):
229+
if (select_cols.lower() == 'all' or select_cols == '*'):
230+
return all_columns
231+
# Comma-separated string
232+
select_cols = select_cols.split(',')
233+
234+
# Handle an iterable
235+
elif isinstance(select_cols, Iterable):
236+
# Convert to list so we can iterate multiple times safely
237+
select_cols = list(select_cols)
238+
239+
else:
240+
raise InvalidQueryError(
241+
"`select_cols` must be an iterable of column names, a comma-separated string, "
242+
"'all', or '*'."
243+
)
244+
245+
# Validate the column names
246+
valid_select_cols = []
247+
for col in select_cols:
248+
if not isinstance(col, str):
249+
raise InvalidQueryError(
250+
"`select_cols` must contain only strings (column names)."
251+
)
252+
col = col.strip()
253+
if col not in all_columns:
254+
closest_match = difflib.get_close_matches(col, all_columns, n=1)
255+
suggestion = f' Did you mean "{closest_match[0]}"?' if closest_match else ''
256+
warnings.warn(f"Column '{col}' not found.{suggestion}", InputWarning)
257+
else:
258+
valid_select_cols.append(col)
259+
260+
# Dataset ID column should always be returned
261+
dataset_col = self.dataset_kwds.get(self.mission, None)
262+
if dataset_col and dataset_col not in valid_select_cols:
263+
valid_select_cols.append(dataset_col)
264+
return valid_select_cols
265+
200266
@class_or_instance
201267
def query_region_async(self, coordinates, *, radius=3*u.arcmin, limit=5000, offset=0,
202268
select_cols=None, **criteria):
@@ -217,9 +283,11 @@ def query_region_async(self, coordinates, *, radius=3*u.arcmin, limit=5000, offs
217283
Default is 5000. The maximum number of dataset IDs in the results.
218284
offset : int
219285
Default is 0. The number of records you wish to skip before selecting records.
220-
select_cols: list, optional
286+
select_cols: iterable or str or None, optional
221287
Default is None. Names of columns that will be included in the result table.
222288
If None, a default set of columns will be returned.
289+
Can either be an iterable of column names, a comma-separated string of column names,
290+
or 'all'/'*' to return all available columns.
223291
**criteria
224292
Other mission-specific criteria arguments.
225293
All valid filters can be found using `~astroquery.mast.missions.MastMissionsClass.get_column_list`
@@ -255,19 +323,13 @@ def query_region_async(self, coordinates, *, radius=3*u.arcmin, limit=5000, offs
255323
f"Query radius too large. Must be ≤{self._max_query_radius}, got {radius}."
256324
)
257325

258-
# Dataset ID column should always be returned
259-
if select_cols:
260-
select_cols.append(self.dataset_kwds.get(self.mission, None))
261-
elif self.mission == 'ullyses':
262-
select_cols = self._default_ullyses_cols
263-
264326
# Basic params
265327
params = {'target': [f"{coordinates.ra.deg} {coordinates.dec.deg}"],
266328
'radius': radius.arcsec,
267329
'radius_units': 'arcseconds',
268330
'limit': limit,
269331
'offset': offset,
270-
'select_cols': select_cols}
332+
'select_cols': self._parse_select_cols(select_cols)}
271333

272334
self._build_params_from_criteria(params, **criteria)
273335

@@ -295,9 +357,11 @@ def query_criteria_async(self, *, coordinates=None, objectname=None, radius=3*u.
295357
Default is 5000. The maximum number of dataset IDs in the results.
296358
offset : int
297359
Default is 0. The number of records you wish to skip before selecting records.
298-
select_cols: list, optional
360+
select_cols: iterable or str or None, optional
299361
Default is None. Names of columns that will be included in the result table.
300362
If None, a default set of columns will be returned.
363+
Can either be an iterable of column names, a comma-separated string of column names,
364+
or 'all'/'*' to return all available columns.
301365
resolver : str, optional
302366
Default is None. The resolver to use when resolving a named target into coordinates. Valid options are
303367
"SIMBAD" and "NED". If not specified, the default resolver order will be used. Please see the
@@ -344,14 +408,8 @@ def query_criteria_async(self, *, coordinates=None, objectname=None, radius=3*u.
344408
f"Query radius too large. Must be ≤{self._max_query_radius}, got {radius}."
345409
)
346410

347-
# Dataset ID column should always be returned
348-
if select_cols:
349-
select_cols.append(self.dataset_kwds.get(self.mission, None))
350-
elif self.mission == 'ullyses':
351-
select_cols = self._default_ullyses_cols
352-
353411
# build query
354-
params = {"limit": self.limit, "offset": offset, 'select_cols': select_cols}
412+
params = {"limit": self.limit, "offset": offset, 'select_cols': self._parse_select_cols(select_cols)}
355413
if coordinates:
356414
params["target"] = [f"{coordinates.ra.deg} {coordinates.dec.deg}"]
357415
params["radius"] = radius.arcsec
@@ -382,9 +440,11 @@ def query_object_async(self, objectname, *, radius=3*u.arcmin, limit=5000, offse
382440
Default is 5000. The maximum number of dataset IDs in the results.
383441
offset : int
384442
Default is 0. The number of records you wish to skip before selecting records.
385-
select_cols: list, optional
443+
select_cols: iterable or str or None, optional
386444
Default is None. Names of columns that will be included in the result table.
387445
If None, a default set of columns will be returned.
446+
Can either be an iterable of column names, a comma-separated string of column names,
447+
or 'all'/'*' to return all available columns.
388448
resolver : str, optional
389449
Default is None. The resolver to use when resolving a named target into coordinates. Valid options are
390450
"SIMBAD" and "NED". If not specified, the default resolver order will be used. Please see the

astroquery/mast/services.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ..query import BaseQuery
1919
from ..utils import async_to_sync
2020
from ..utils.class_or_instance import class_or_instance
21-
from ..exceptions import InvalidQueryError, TimeoutError, NoResultsWarning
21+
from ..exceptions import BlankResponseWarning, InvalidQueryError, TimeoutError, NoResultsWarning
2222

2323
from . import conf
2424

@@ -97,13 +97,43 @@ def _json_to_table(json_obj, data_key='data'):
9797
# no consistent way to make the mask because np.equal fails on ''
9898
# and array == value fails with None
9999
if col_type == 'str':
100-
col_mask = (col_data == ignore_value)
100+
ignore_mask = (col_data == ignore_value)
101101
else:
102-
col_mask = np.equal(col_data, ignore_value)
102+
ignore_mask = np.equal(col_data, ignore_value)
103103

104104
# add the column if it does not exist already
105105
if col_name not in data_table.colnames:
106-
data_table.add_column(MaskedColumn(col_data.astype(col_type), name=col_name, mask=col_mask))
106+
try:
107+
# Try to coerce entire column at once
108+
coerced = col_data.astype(col_type)
109+
data_table.add_column(
110+
MaskedColumn(coerced, name=col_name, mask=ignore_mask)
111+
)
112+
except (ValueError, TypeError):
113+
# Fallback to coercing values one by one
114+
out = np.empty(len(col_data), dtype=col_type)
115+
fail_mask = np.zeros(len(col_data), dtype=bool)
116+
for i, val in enumerate(col_data):
117+
if val == ignore_value:
118+
# Ignored values are already masked by ignore_mask
119+
continue
120+
121+
try:
122+
out[i] = col_type(val)
123+
except (ValueError, TypeError, OverflowError):
124+
# Could not coerce value, mask it
125+
fail_mask[i] = True
126+
127+
# Combined mask of ignored values + failed coercions
128+
combined_mask = ignore_mask | fail_mask
129+
if np.any(fail_mask):
130+
warnings.warn(
131+
f"Column '{col_name}': {np.sum(fail_mask)} values could not be coerced to {col_type} "
132+
"and were masked.", BlankResponseWarning
133+
)
134+
data_table.add_column(
135+
MaskedColumn(out, name=col_name, mask=combined_mask)
136+
)
107137

108138
return data_table
109139

astroquery/mast/tests/test_mast.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from astropy.utils.exceptions import AstropyDeprecationWarning
2121
from astroquery.mast.services import _json_to_table
2222
from astroquery.utils.mocks import MockResponse
23-
from astroquery.exceptions import (InvalidQueryError, InputWarning, MaxResultsWarning, NoResultsWarning,
24-
RemoteServiceError, ResolverError)
23+
from astroquery.exceptions import (BlankResponseWarning, InvalidQueryError, InputWarning, MaxResultsWarning,
24+
NoResultsWarning, RemoteServiceError, ResolverError)
2525

2626
from astroquery import mast
2727

@@ -302,6 +302,56 @@ def test_missions_query_criteria(patch_post):
302302
)
303303

304304

305+
def test_missions_parse_select_cols(patch_post):
306+
# Default columns
307+
cols = mast.MastMissions._parse_select_cols(None) # Default columns for HST
308+
assert cols is None
309+
310+
# All columns
311+
all_cols = mast.MastMissions._parse_select_cols('all')
312+
assert all_cols == mast.MastMissions.get_column_list()['name'].value.tolist()
313+
314+
# Comma-separated string
315+
string_cols = mast.MastMissions._parse_select_cols('sci_pep_id, sci_instrume')
316+
for col in ['sci_pep_id', 'sci_instrume', 'sci_data_set_name']:
317+
assert col in string_cols
318+
319+
# List of columns
320+
list_cols = mast.MastMissions._parse_select_cols(['sci_pep_id', 'sci_instrume'])
321+
for col in ['sci_pep_id', 'sci_instrume', 'sci_data_set_name']:
322+
assert col in list_cols
323+
324+
# Tuple of columns
325+
tuple_cols = mast.MastMissions._parse_select_cols(('sci_pep_id', 'sci_instrume'))
326+
for col in ['sci_pep_id', 'sci_instrume', 'sci_data_set_name']:
327+
assert col in tuple_cols
328+
329+
# Generator of columns
330+
gen_cols = mast.MastMissions._parse_select_cols(col for col in ['sci_pep_id', 'sci_instrume'])
331+
for col in ['sci_pep_id', 'sci_instrume', 'sci_data_set_name']:
332+
assert col in gen_cols
333+
334+
# Error if invalid type
335+
with pytest.raises(InvalidQueryError, match="`select_cols` must be an iterable of column names"):
336+
mast.MastMissions._parse_select_cols(123)
337+
338+
# Error if an individual column is not a string
339+
with pytest.raises(InvalidQueryError, match="`select_cols` must contain only strings"):
340+
mast.MastMissions._parse_select_cols(['sci_pep_id', 123])
341+
342+
# Warning for invalid column names
343+
with pytest.warns(InputWarning, match="Column 'invalid_column' not found."):
344+
valid_cols = mast.MastMissions._parse_select_cols(['sci_pep_id', 'invalid_column'])
345+
assert 'sci_pep_id' in valid_cols
346+
assert 'invalid_column' not in valid_cols
347+
348+
# Workaround for Ullyses mission default columns
349+
ullyses_mission = mast.MastMissions(mission='ullyses')
350+
ullyses_cols = ullyses_mission._parse_select_cols(None)
351+
for col in mast.MastMissions._default_ullyses_cols:
352+
assert col in ullyses_cols
353+
354+
305355
def test_missions_get_product_list_async(patch_post):
306356
# String input
307357
result = mast.MastMissions.get_product_list_async('Z14Z0104T')
@@ -1485,3 +1535,25 @@ def test_parse_input_location(patch_post):
14851535
with pytest.warns(InputWarning, match="Resolver is only used when resolving object names"):
14861536
loc = mast.utils.parse_input_location(coordinates=coord, resolver="SIMBAD")
14871537
assert isinstance(loc, SkyCoord)
1538+
1539+
1540+
def test_json_to_table_fallback_type_coercion(patch_post):
1541+
json_obj = {'info': [{'name': 'test_int', 'type': 'int'}],
1542+
'data': [['1'], ['2'], ['not_an_int'], ['3'], [-999]]}
1543+
1544+
with pytest.warns(BlankResponseWarning):
1545+
table = _json_to_table(json_obj)
1546+
1547+
# Column exists
1548+
assert 'test_int' in table.colnames
1549+
col = table['test_int']
1550+
assert col.dtype == np.int64
1551+
1552+
# Good values survived
1553+
assert col[0] == 1
1554+
assert col[1] == 2
1555+
assert col[3] == 3
1556+
1557+
# Bad + ignored values are masked
1558+
assert col.mask[2] # 'not_an_int'
1559+
assert col.mask[4] # ignore_value

astroquery/mast/tests/test_mast_remote.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,10 @@ def test_missions_query_criteria(self):
153153
result = MastMissions.query_criteria(objectname='NGC6121',
154154
radius=0.1,
155155
sci_start_time='<2012',
156-
sci_actual_duration='0..200'
157-
)
156+
sci_actual_duration='0..200',
157+
select_cols='*')
158158
assert len(result) == 3
159+
assert result.colnames == MastMissions.get_column_list()['name'].value.tolist()
159160
assert (result['ang_sep'].data.data.astype('float') < 0.1).all()
160161
assert (result['sci_start_time'] < '2012').all()
161162
assert ((result['sci_actual_duration'] >= 0) & (result['sci_actual_duration'] <= 200)).all()

docs/mast/mast_missions.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ Keyword arguments can also be used to refine results further. The following para
7070
- ``sort_desc``: A boolean or list of booleans (one for each field specified in ``sort_by``),
7171
describing if each field should be sorted in descending order (``True``) or ascending order (``False``).
7272

73-
- ``select_cols``: A list of columns to be returned in the response.
73+
- ``select_cols``: Columns to include in the result table. If not specified, a default set of columns
74+
is returned. This parameter may be given as an iterable of column names, a comma-separated string, or the special
75+
values ``'all'`` or ``'*'`` to return all available columns.
7476

7577

7678
Mission Positional Queries

0 commit comments

Comments
 (0)