Commit c9f2538

Merge pull request #3488 from snbianco/get-cloud-missions
Get cloud missions dynamically and fix cloud download workflow
2 parents 2f2c903 + de342ea commit c9f2538

8 files changed (+691, -196 lines)

.readthedocs.yaml

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ python:
       path: .
       extra_requirements:
         - docs
+        - all

 # Don't build any extra formats
 formats: []
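
In effect, the docs build now installs the package with both extras (roughly ``pip install ".[docs,all]"``), so that optional dependencies pulled in by ``all``, such as the ``boto3``/``botocore`` packages used in ``astroquery/mast/cloud.py`` below, are available when the documentation is built.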

CHANGES.rst

Lines changed: 10 additions & 0 deletions
@@ -31,6 +31,13 @@ vizier

 - Methods ``get_catalog``, ``get_catalog_async`` and ``query_*`` now always return UCD1+ instead of UCD1. [#3458]

+mast
+^^^^
+- ``utils.mast_relative_path`` is now deprecated in favor of ``utils.get_cloud_paths``. [#3488]
+- When cloud access is enabled, ``Observations.download_file`` and ``Observations.download_products``
+  now check all requested products against cloud storage. As a result, setting ``cloud_only=True`` will skip
+  any products that are not available in the cloud, rather than falling back to on-prem downloads.
+
 Service fixes and enhancements
 ------------------------------

@@ -83,6 +90,9 @@ mast

 - Added full support for the International Ultraviolet Explorer (IUE) mission in ``MastMissions``. [#3517]

+- Added a new ``Observations.list_cloud_datasets()`` method for querying cloud-supported MAST datasets, alongside
+  improvements to cloud download handling. [#3488]
+
 jplspec
 ^^^^^^^

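
Taken together, these entries imply a workflow along the following lines (a minimal sketch, not part of the diff; the query values are illustrative and ``list_cloud_datasets`` is assumed to take no required arguments):

    from astroquery.mast import Observations

    # Serve downloads from the public AWS S3 copy of MAST data
    Observations.enable_cloud_dataset(provider='AWS')

    # New in this PR: list which datasets actually have data in the cloud
    print(Observations.list_cloud_datasets())

    # Every requested product is now checked against cloud storage;
    # with cloud_only=True, products missing from S3 are skipped instead
    # of falling back to on-prem downloads.
    obs = Observations.query_criteria(obs_collection='HST', proposal_id='12345')
    products = Observations.get_product_list(obs)
    Observations.download_products(products, cloud_only=True)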

astroquery/mast/cloud.py

Lines changed: 118 additions & 82 deletions
@@ -13,8 +13,9 @@
 from astroquery import log
 from astropy.utils.console import ProgressBarOrSpinner
 from astropy.utils.exceptions import AstropyDeprecationWarning
+from botocore.exceptions import ClientError, BotoCoreError

-from ..exceptions import NoResultsWarning
+from ..exceptions import RemoteServiceError, NoResultsWarning

 from . import utils

@@ -52,39 +53,75 @@ def __init__(self, provider="AWS", profile=None, verbose=False):
         import boto3
         import botocore

-        self.supported_missions = ["mast:hst/product", "mast:tess/product", "mast:kepler", "mast:galex", "mast:ps1",
-                                   "mast:jwst/product"]
-
         self.boto3 = boto3
         self.botocore = botocore
         self.config = botocore.client.Config(signature_version=botocore.UNSIGNED)
-
         self.pubdata_bucket = "stpubdata"
+        self.s3_client = self.boto3.client('s3', config=self.config)
+
+        # Cached list of datasets available in the cloud
+        self._supported_datasets = self._fetch_supported_datasets()

         if verbose:
             log.info("Using the S3 STScI public dataset")

-    def is_supported(self, data_product):
+    def _fetch_supported_datasets(self):
         """
-        Given a data product, determines if it is in a mission available in the cloud.
+        Returns the list of datasets that have data available in the cloud.

-        Parameters
-        ----------
-        data_product : `~astropy.table.Row`
-            Product to be validated.
+        Returns
+        -------
+        response : list
+            List of supported datasets.
+        """
+        try:
+            datasets = []
+
+            # Top-level prefixes in the bucket
+            response = self.s3_client.list_objects_v2(
+                Bucket=self.pubdata_bucket,
+                Delimiter='/'  # Use a delimiter to treat the S3 structure like folders
+            )
+
+            for prefix_info in response.get('CommonPrefixes', []):
+                prefix = prefix_info['Prefix'].rstrip('/')
+
+                if prefix == 'mast':
+                    # 'mast/' contains sub-prefixes for different high-level science products
+                    mast_response = self.s3_client.list_objects_v2(
+                        Bucket=self.pubdata_bucket,
+                        Prefix='mast/hlsp/',
+                        Delimiter='/',
+                    )
+                    datasets.extend(
+                        cp['Prefix'].rstrip('/')
+                        for cp in mast_response.get('CommonPrefixes', [])
+                    )
+                else:
+                    datasets.append(prefix)
+
+            return datasets
+
+        except (ClientError, BotoCoreError) as e:
+            log.error('Failed to retrieve supported datasets from S3 bucket %s: %s', self.pubdata_bucket, e)
+            return []
+
+    def get_supported_datasets(self):
+        """
+        Returns the list of datasets that have data available in the cloud.

         Returns
         -------
-        response : bool
-            Is the product from a supported mission.
+        response : list
+            List of supported datasets.
         """
-        return any(data_product['dataURI'].lower().startswith(mission) for mission in self.supported_missions)
+        return list(self._supported_datasets)

     def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):
         """
         For a given data product, returns the associated cloud URI.
-        If the product is from a mission that does not support cloud access an
-        exception is raised. If the mission is supported but the product
+        If the product is from a dataset that does not support cloud access an
+        exception is raised. If the dataset is supported but the product
         cannot be found in the cloud, the returned path is None.

         Parameters
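
Here ``Delimiter='/'`` makes ``list_objects_v2`` group keys by their first path segment and report the groups under ``CommonPrefixes``, so the flat bucket can be walked like a directory tree. A standalone sketch of the same pattern (anonymous access to the public ``stpubdata`` bucket, as configured above; the printed prefixes are illustrative):

    import boto3
    from botocore import UNSIGNED
    from botocore.client import Config

    # Unsigned (anonymous) client, matching the configuration in __init__ above
    s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))

    # Delimiter='/' yields one CommonPrefixes entry per top-level "folder"
    resp = s3.list_objects_v2(Bucket='stpubdata', Delimiter='/')
    for cp in resp.get('CommonPrefixes', []):
        print(cp['Prefix'])  # e.g. 'hst/', 'tess/', 'mast/'
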
@@ -112,7 +149,7 @@ def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):

         # Making sure we got at least 1 URI from the query above.
         if not uri_list or uri_list[0] is None:
-            warnings.warn("Unable to locate file {}.".format(data_product), NoResultsWarning)
+            return
         else:
             # Output from ``get_cloud_uri_list`` is always a list even when it's only 1 URI
             return uri_list[0]
@@ -141,11 +178,8 @@ def get_cloud_uri_list(self, data_products, *, include_bucket=True, full_url=False):
            List of URIs generated from the data products, list may contain entries that are None
            if data_products includes products not found in the cloud.
         """
-        s3_client = self.boto3.client('s3', config=self.config)
         data_uris = data_products if isinstance(data_products, list) else data_products['dataURI']
-        paths = utils.mast_relative_path(data_uris, verbose=verbose)
-        if isinstance(paths, str):  # Handle the case where only one product was requested
-            paths = [paths]
+        paths = utils.get_cloud_paths(data_uris, verbose=verbose)

         uri_list = []
         for path in paths:
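
Note the contract change at this call site: the deprecated ``utils.mast_relative_path`` returned a bare string when a single product was requested (hence the removed ``isinstance`` check), whereas ``utils.get_cloud_paths`` evidently always returns a list, with ``None`` entries for products that have no cloud path.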
@@ -154,7 +188,7 @@ def get_cloud_uri_list(self, data_products, *, include_bucket=True, full_url=False):
             else:
                 try:
                     # Use `head_object` to verify that the product is available on S3 (not all products are)
-                    s3_client.head_object(Bucket=self.pubdata_bucket, Key=path)
+                    self.s3_client.head_object(Bucket=self.pubdata_bucket, Key=path)
                     if include_bucket:
                         s3_path = "s3://{}/{}".format(self.pubdata_bucket, path)
                         uri_list.append(s3_path)
@@ -167,76 +201,78 @@ def get_cloud_uri_list(self, data_products, *, include_bucket=True, full_url=False):
                     if e.response['Error']['Code'] != "404":
                         raise
                     if verbose:
-                        warnings.warn("Unable to locate file {}.".format(path), NoResultsWarning)
+                        warnings.warn(f"Failed to retrieve cloud path for {path}", NoResultsWarning)
                     uri_list.append(None)

         return uri_list

-    def download_file(self, data_product, local_path, cache=True, verbose=True):
+    def download_file_from_cloud(self, data_product, local_path, cache=True, verbose=True):
         """
-        Takes a data product in the form of an `~astropy.table.Row` and downloads it from the cloud into
-        the given directory.
+        Download a data product from MAST cloud storage (S3) to a local file.

         Parameters
         ----------
-        data_product : `~astropy.table.Row`
-            Product to download.
+        data_product : str
+            MAST product URI (e.g. ``mast:JWST/product.fits``) or S3 URI (e.g. ``s3://<bucket>/path/to/product.fits``).
         local_path : str
-            The local filename to which toe downloaded file will be saved.
-        cache : bool
-            Default is True. If file is found on disc it will not be downloaded again.
+            Local filename where the downloaded file will be saved.
+        cache : bool, optional
+            Default is True. If True, and the file already exists locally with the expected size,
+            the download is skipped.
         verbose : bool, optional
             Default is True. Whether to show download progress in the console.
         """
-
-        s3 = self.boto3.resource('s3', config=self.config)
-        s3_client = self.boto3.client('s3', config=self.config)
-        bkt = s3.Bucket(self.pubdata_bucket)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            bucket_path = self.get_cloud_uri(data_product, False)
-        if not bucket_path:
-            raise Exception("Unable to locate file {}.".format(data_product['dataURI']))
-
-        # Ask the webserver (in this case S3) what the expected content length is and use that.
-        info_lookup = s3_client.head_object(Bucket=self.pubdata_bucket, Key=bucket_path)
-        length = info_lookup["ContentLength"]
-
-        if cache and os.path.exists(local_path):
-            if length is not None:
-                statinfo = os.stat(local_path)
-                if statinfo.st_size != length:
-                    log.warning("Found cached file {0} with size {1} that is "
-                                "different from expected size {2}"
-                                .format(local_path,
-                                        statinfo.st_size,
-                                        length))
-                else:
-                    log.info("Found cached file {0} with expected size {1}."
-                             .format(local_path, statinfo.st_size))
-                    return
-
-        if verbose:
-            with ProgressBarOrSpinner(length, ('Downloading URL s3://{0}/{1} to {2} ...'.format(
-                    self.pubdata_bucket, bucket_path, local_path))) as pb:
-
-                # Bytes read tracks how much data has been received so far
-                # This variable will be updated in multiple threads below
-                global bytes_read
-                bytes_read = 0
-
-                progress_lock = threading.Lock()
-
-                def progress_callback(numbytes):
-                    # Boto3 calls this from multiple threads pulling the data from S3
-                    global bytes_read
-
-                    # This callback can be called in multiple threads
-                    # Access to updating the console needs to be locked
-                    with progress_lock:
-                        bytes_read += numbytes
-                        pb.update(bytes_read)
-
-                bkt.download_file(bucket_path, local_path, Callback=progress_callback)
+        # TODO: Function that checks if a particular product, by dataURI, can be found in the cloud
+        # Normalize to an S3 key (no bucket)
+        if data_product.strip().startswith('s3://'):
+            s3_key = data_product.replace(f's3://{self.pubdata_bucket}/', '', 1)
         else:
-            bkt.download_file(bucket_path, local_path)
+            s3_key = self.get_cloud_uri_list([data_product], include_bucket=False, verbose=False)[0]
+
+        # If s3_key is None, the product was not found in the cloud
+        if s3_key is None:
+            raise RemoteServiceError(f'The product {data_product} was not found in the cloud.')
+
+        # Query S3 for expected file size
+        head = self.s3_client.head_object(Bucket=self.pubdata_bucket, Key=s3_key)
+        expected_size = head.get('ContentLength')
+
+        # Cache check
+        if cache and os.path.exists(local_path) and expected_size is not None:
+            local_size = os.path.getsize(local_path)
+            if local_size == expected_size:
+                log.info("Using cached file {0} with expected size {1}."
+                         .format(local_path, local_size))
+                return
+            else:
+                log.warning("Found cached file {0} with size {1} that is "
+                            "different from expected size {2}"
+                            .format(local_path,
+                                    local_size,
+                                    expected_size))

+        # Proceed with download
+        bucket = self.boto3.resource('s3', config=self.config).Bucket(self.pubdata_bucket)
+        if not verbose:
+            bucket.download_file(s3_key, local_path)
+            return
+
+        # Progress-aware download
+        bytes_read = 0
+        progress_lock = threading.Lock()
+
+        def progress_callback(numbytes):
+            # Boto3 calls this from multiple threads pulling the data from S3
+            nonlocal bytes_read
+
+            # This callback can be called in multiple threads
+            # Access to updating the console needs to be locked
+            with progress_lock:
+                bytes_read += numbytes
+                pb.update(bytes_read)
+
+        with ProgressBarOrSpinner(
+            expected_size,
+            f'Downloading s3://{self.pubdata_bucket}/{s3_key} to {local_path} ...'
+        ) as pb:
+            bucket.download_file(s3_key, local_path, Callback=progress_callback)
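
A minimal usage sketch of the renamed method (not part of the diff; assumes the enclosing class is ``CloudAccess`` from ``astroquery.mast.cloud``, and the URI and local path are illustrative):

    from astroquery.mast.cloud import CloudAccess

    # Anonymous access to the public 'stpubdata' bucket (see __init__ above)
    cloud = CloudAccess(provider='AWS')

    # Accepts either a MAST product URI or a full S3 URI; raises
    # RemoteServiceError if the product cannot be found in the cloud
    cloud.download_file_from_cloud('mast:JWST/product.fits', './product.fits')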
