 from astroquery import log
 from astropy.utils.console import ProgressBarOrSpinner
 from astropy.utils.exceptions import AstropyDeprecationWarning
+from botocore.exceptions import ClientError, BotoCoreError

-from ..exceptions import NoResultsWarning
+from ..exceptions import RemoteServiceError, NoResultsWarning

 from . import utils

@@ -52,39 +53,75 @@ def __init__(self, provider="AWS", profile=None, verbose=False):
         import boto3
         import botocore

-        self.supported_missions = ["mast:hst/product", "mast:tess/product", "mast:kepler", "mast:galex", "mast:ps1",
-                                   "mast:jwst/product"]
-
         self.boto3 = boto3
         self.botocore = botocore
         self.config = botocore.client.Config(signature_version=botocore.UNSIGNED)
-
         self.pubdata_bucket = "stpubdata"
+        self.s3_client = self.boto3.client('s3', config=self.config)
+
+        # Cached list of datasets available in the cloud
+        self._supported_datasets = self._fetch_supported_datasets()

         if verbose:
             log.info("Using the S3 STScI public dataset")

-    def is_supported(self, data_product):
+    def _fetch_supported_datasets(self):
         """
-        Given a data product, determines if it is in a mission available in the cloud.
+        Returns the list of datasets that have data available in the cloud.

-        Parameters
-        ----------
-        data_product : `~astropy.table.Row`
-            Product to be validated.
+        Returns
+        -------
+        response : list
+            List of supported datasets.
+        """
+        try:
+            datasets = []
+
+            # Top-level prefixes in the bucket
+            response = self.s3_client.list_objects_v2(
+                Bucket=self.pubdata_bucket,
+                Delimiter='/'  # Use a delimiter to treat the S3 structure like folders
+            )
+
+            for prefix_info in response.get('CommonPrefixes', []):
+                prefix = prefix_info['Prefix'].rstrip('/')
+
+                if prefix == 'mast':
+                    # 'mast/' contains sub-prefixes for different high-level science products
+                    mast_response = self.s3_client.list_objects_v2(
+                        Bucket=self.pubdata_bucket,
+                        Prefix='mast/hlsp/',
+                        Delimiter='/',
+                    )
+                    datasets.extend(
+                        cp['Prefix'].rstrip('/')
+                        for cp in mast_response.get('CommonPrefixes', [])
+                    )
+                else:
+                    datasets.append(prefix)
+
+            return datasets
+
+        except (ClientError, BotoCoreError) as e:
+            log.error('Failed to retrieve supported datasets from S3 bucket %s: %s', self.pubdata_bucket, e)
+            return []
+
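For context on the pattern used in `_fetch_supported_datasets`: `list_objects_v2` with `Delimiter='/'` returns the matching "folder" prefixes under `CommonPrefixes` instead of listing individual object keys. A minimal standalone sketch of that call (the unsigned client and the `stpubdata` bucket match the configuration above; the printed prefixes are illustrative, not a verified listing):

```python
import boto3
from botocore import UNSIGNED
from botocore.client import Config

# Anonymous (unsigned) client for the public bucket, as in __init__ above
s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))

# Delimiter='/' groups keys into folder-like prefixes rather than
# returning every object in the bucket
resp = s3.list_objects_v2(Bucket='stpubdata', Delimiter='/')
for cp in resp.get('CommonPrefixes', []):
    print(cp['Prefix'])  # e.g. 'hst/', 'tess/', 'mast/' (illustrative)
```

Note that `list_objects_v2` returns at most 1000 entries per call; a handful of top-level prefixes fits in one response, but a paginator would be needed for larger listings.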
+    def get_supported_datasets(self):
+        """
+        Returns the list of datasets that have data available in the cloud.

         Returns
         -------
-        response : bool
-            Is the product from a supported mission.
+        response : list
+            List of supported datasets.
         """
-        return any(data_product['dataURI'].lower().startswith(mission) for mission in self.supported_missions)
+        return list(self._supported_datasets)

     def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):
         """
         For a given data product, returns the associated cloud URI.
-        If the product is from a mission that does not support cloud access an
-        exception is raised. If the mission is supported but the product
+        If the product is from a dataset that does not support cloud access, an
+        exception is raised. If the dataset is supported but the product
         cannot be found in the cloud, the returned path is None.

         Parameters
@@ -112,7 +149,7 @@ def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):

         # Making sure we got at least 1 URI from the query above.
         if not uri_list or uri_list[0] is None:
-            warnings.warn("Unable to locate file {}.".format(data_product), NoResultsWarning)
+            return
         else:
             # Output from ``get_cloud_uri_list`` is always a list even when it's only 1 URI
             return uri_list[0]
@@ -141,11 +178,8 @@ def get_cloud_uri_list(self, data_products, *, include_bucket=True, full_url=Fal
             List of URIs generated from the data products, which may contain entries that are None
             if data_products includes products not found in the cloud.
143180 """
144- s3_client = self .boto3 .client ('s3' , config = self .config )
145181 data_uris = data_products if isinstance (data_products , list ) else data_products ['dataURI' ]
146- paths = utils .mast_relative_path (data_uris , verbose = verbose )
147- if isinstance (paths , str ): # Handle the case where only one product was requested
148- paths = [paths ]
182+ paths = utils .get_cloud_paths (data_uris , verbose = verbose )
149183
150184 uri_list = []
151185 for path in paths :
@@ -154,7 +188,7 @@ def get_cloud_uri_list(self, data_products, *, include_bucket=True, full_url=Fal
             else:
                 try:
                     # Use `head_object` to verify that the product is available on S3 (not all products are)
-                    s3_client.head_object(Bucket=self.pubdata_bucket, Key=path)
+                    self.s3_client.head_object(Bucket=self.pubdata_bucket, Key=path)
                     if include_bucket:
                         s3_path = "s3://{}/{}".format(self.pubdata_bucket, path)
                         uri_list.append(s3_path)
@@ -167,76 +201,78 @@ def get_cloud_uri_list(self, data_products, *, include_bucket=True, full_url=Fal
                     if e.response['Error']['Code'] != "404":
                         raise
                     if verbose:
-                        warnings.warn("Unable to locate file {}.".format(path), NoResultsWarning)
+                        warnings.warn(f"Failed to retrieve cloud path for {path}", NoResultsWarning)
                     uri_list.append(None)

         return uri_list
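The existence check above relies on boto3 raising `ClientError` with a 404 code when a key is missing, while any other failure (e.g. 403 Forbidden or throttling) is re-raised. A minimal sketch of that idiom in isolation (the `key_exists` helper name is hypothetical, not part of this module):

```python
from botocore.exceptions import ClientError

def key_exists(s3_client, bucket, key):
    """Return True if `key` exists in `bucket`, re-raising non-404 errors."""
    try:
        # head_object fetches metadata only; no object body is transferred
        s3_client.head_object(Bucket=bucket, Key=key)
        return True
    except ClientError as e:
        if e.response['Error']['Code'] != '404':
            raise  # a 403 or throttling error is a real problem, not "missing"
        return False
```

Treating only 404 as "not in the cloud" keeps genuine access or service errors visible instead of silently mapping them to `None` entries.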

-    def download_file(self, data_product, local_path, cache=True, verbose=True):
+    def download_file_from_cloud(self, data_product, local_path, cache=True, verbose=True):
         """
-        Takes a data product in the form of an `~astropy.table.Row` and downloads it from the cloud into
-        the given directory.
+        Download a data product from MAST cloud storage (S3) to a local file.

         Parameters
         ----------
-        data_product : `~astropy.table.Row`
-            Product to download.
+        data_product : str
+            MAST product URI (e.g. ``mast:JWST/product.fits``) or S3 URI (e.g. ``s3://<bucket>/path/to/product.fits``).
         local_path : str
-            The local filename to which toe downloaded file will be saved.
-        cache : bool
-            Default is True. If file is found on disc it will not be downloaded again.
+            Local filename where the downloaded file will be saved.
+        cache : bool, optional
+            Default is True. If True, and the file already exists locally with the expected size,
+            the download is skipped.
         verbose : bool, optional
             Default is True. Whether to show download progress in the console.
         """
-
-        s3 = self.boto3.resource('s3', config=self.config)
-        s3_client = self.boto3.client('s3', config=self.config)
-        bkt = s3.Bucket(self.pubdata_bucket)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            bucket_path = self.get_cloud_uri(data_product, False)
-        if not bucket_path:
-            raise Exception("Unable to locate file {}.".format(data_product['dataURI']))
-
-        # Ask the webserver (in this case S3) what the expected content length is and use that.
-        info_lookup = s3_client.head_object(Bucket=self.pubdata_bucket, Key=bucket_path)
-        length = info_lookup["ContentLength"]
-
-        if cache and os.path.exists(local_path):
-            if length is not None:
-                statinfo = os.stat(local_path)
-                if statinfo.st_size != length:
-                    log.warning("Found cached file {0} with size {1} that is "
-                                "different from expected size {2}"
-                                .format(local_path,
-                                        statinfo.st_size,
-                                        length))
-                else:
-                    log.info("Found cached file {0} with expected size {1}."
-                             .format(local_path, statinfo.st_size))
-                    return
-
-        if verbose:
-            with ProgressBarOrSpinner(length, ('Downloading URL s3://{0}/{1} to {2} ...'.format(
-                    self.pubdata_bucket, bucket_path, local_path))) as pb:
-
-                # Bytes read tracks how much data has been received so far
-                # This variable will be updated in multiple threads below
-                global bytes_read
-                bytes_read = 0
-
-                progress_lock = threading.Lock()
-
-                def progress_callback(numbytes):
-                    # Boto3 calls this from multiple threads pulling the data from S3
-                    global bytes_read
-
-                    # This callback can be called in multiple threads
-                    # Access to updating the console needs to be locked
-                    with progress_lock:
-                        bytes_read += numbytes
-                        pb.update(bytes_read)
-
-                bkt.download_file(bucket_path, local_path, Callback=progress_callback)
+        # TODO: Function that checks if a particular product, by dataURI, can be found in the cloud
+        # Normalize to an S3 key (no bucket)
+        if data_product.strip().startswith('s3://'):
+            s3_key = data_product.replace(f's3://{self.pubdata_bucket}/', '', 1)
         else:
-            bkt.download_file(bucket_path, local_path)
+            s3_key = self.get_cloud_uri_list([data_product], include_bucket=False, verbose=False)[0]
+
+        # If s3_key is None, the product was not found in the cloud
+        if s3_key is None:
+            raise RemoteServiceError(f'The product {data_product} was not found in the cloud.')
+
+        # Query S3 for the expected file size
+        head = self.s3_client.head_object(Bucket=self.pubdata_bucket, Key=s3_key)
+        expected_size = head.get('ContentLength')
+
+        # Cache check
+        if cache and os.path.exists(local_path) and expected_size is not None:
+            local_size = os.path.getsize(local_path)
+            if local_size == expected_size:
+                log.info("Using cached file {0} with expected size {1}."
+                         .format(local_path, local_size))
+                return
+            else:
+                log.warning("Found cached file {0} with size {1} that is "
+                            "different from expected size {2}"
+                            .format(local_path,
+                                    local_size,
+                                    expected_size))
+
+        # Proceed with the download
+        bucket = self.boto3.resource('s3', config=self.config).Bucket(self.pubdata_bucket)
+        if not verbose:
+            bucket.download_file(s3_key, local_path)
+            return
+
+        # Progress-aware download
+        bytes_read = 0
+        progress_lock = threading.Lock()
+
+        def progress_callback(numbytes):
+            # Boto3 calls this from multiple threads pulling the data from S3
+            nonlocal bytes_read
+
+            # This callback can be called in multiple threads;
+            # access to updating the console needs to be locked
+            with progress_lock:
+                bytes_read += numbytes
+                pb.update(bytes_read)
+
+        with ProgressBarOrSpinner(
+                expected_size,
+                f'Downloading s3://{self.pubdata_bucket}/{s3_key} to {local_path} ...'
+        ) as pb:
+            bucket.download_file(s3_key, local_path, Callback=progress_callback)
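A hypothetical usage sketch of the API after this change; the class import path, product URI, and local filename below are illustrative assumptions, not taken from the diff:

```python
# Assumed entry point; the class holding these methods is not named in the diff
from astroquery.mast.cloud import CloudAccess

cloud = CloudAccess(provider="AWS", verbose=True)
print(cloud.get_supported_datasets())  # e.g. ['hst', 'tess', ...] (illustrative)

# Accepts either a MAST product URI or a full S3 URI; raises
# RemoteServiceError if the product cannot be found in the cloud.
cloud.download_file_from_cloud('mast:HST/product/example_flt.fits',
                               local_path='example_flt.fits',
                               cache=True, verbose=True)
```

Note the design choice in the rewrite: `nonlocal bytes_read` replaces the old `global`, so concurrent downloads no longer share one module-level counter, and the `threading.Lock` serializes progress-bar updates from boto3's transfer threads.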