1313
1414Usage:
1515 sync_to_s3.py [-v... | --verbose...] [-r | --recursive] [-d | --delete]
16- [-e <extension> | --extension <extension>]
16+ [-e <extension> | --extension <extension>]...
1717 [--upload-with-metadata <key>=<value>]...
1818 [--]
1919 SOURCE_PATH DESTINATION_S3_URL
109109
110110import os
111111import sys
112- from typing import Mapping , TypedDict , Optional
112+ from typing import Mapping , TypedDict
113113from pathlib import Path
114114from urllib .parse import urlparse
115115import hashlib
@@ -145,7 +145,7 @@ class LocalFileData(GenericFileData):
145145
146146def get_local_files (
147147 local_path : str ,
148- file_extension : str ,
148+ file_extensions : [ str ] ,
149149 recursive : bool ,
150150) -> Mapping [str , LocalFileData ]:
151151 """
@@ -157,7 +157,8 @@ def get_local_files(
157157 Args:
158158 local_path (str): The local path to search in/lookup.
159159
160- file_extension (Optional[str]): The file_extension to search for.
160+ file_extensions ([str]): The file_extensions to search for, or empty
161+ list if this filter should not be applied.
161162
162163 recursive (bool): Whether to search recursively or not.
163164
@@ -172,7 +173,7 @@ def get_local_files(
172173 if recursive :
173174 return _get_recursive_local_files (
174175 local_path ,
175- f"**/* { file_extension } " ,
176+ file_extensions ,
176177 )
177178 return _get_single_local_file (
178179 local_path ,
@@ -181,7 +182,7 @@ def get_local_files(
181182
182183def _get_recursive_local_files (
183184 local_path : str ,
184- glob : Optional [str ],
185+ file_extensions : [str ],
185186) -> Mapping [str , LocalFileData ]:
186187 """
187188 Retrieve the files that are in the relative path local_path.
@@ -191,9 +192,11 @@ def _get_recursive_local_files(
191192 Args:
192193 local_path (str): The local files to search in.
193194
194- glob (Optional[str]): The glob to search for files inside a specific
195- path. For example, "**/*.yml" will return all YAML files,
196- including those in sub directories.
195+ file_extensions ([str]): The file_extensions to search for, or empty
196+ list if this filter should not be applied. This will be converted
197+ to a glob search, where the extension ".yml" will match files with
198+ the glob search "**/*.yml", returning any YAML file that ends with
199+ .yml. Including those in subdirectories.
197200
198201 Returns:
199202 Mapping[str, LocalFileData]: The map of the Local File Data objects
@@ -206,12 +209,22 @@ def _get_recursive_local_files(
206209 LOGGER .debug (
207210 "Searching for local files in %s matching %s" ,
208211 str (path ),
209- glob ,
212+ file_extensions ,
210213 )
211214 local_files = {}
212- for file_path in path .glob (glob ):
213- local_file_data = _get_local_file_data (file_path , path )
214- local_files [local_file_data ['key' ]] = local_file_data
215+ globs_to_match = [
216+ f"**/*{ ext } "
217+ for ext in (
218+ # File extensions or a list of an empty string, so it either
219+ # generates "**/*{ext}" for each extension in file_extensions
220+ # or it generates "**/*"
221+ file_extensions or ["" ]
222+ )
223+ ]
224+ for glob in globs_to_match :
225+ for file_path in path .glob (glob ):
226+ local_file_data = _get_local_file_data (file_path , path )
227+ local_files [local_file_data ['key' ]] = local_file_data
215228
216229 LOGGER .debug (
217230 "Found %d local files: %s" ,
@@ -299,7 +312,7 @@ def get_s3_objects(
299312 s3_client : any ,
300313 s3_bucket : str ,
301314 s3_prefix : str ,
302- file_extension : str ,
315+ file_extensions : [ str ] ,
303316 recursive : bool ,
304317):
305318 """
@@ -314,7 +327,8 @@ def get_s3_objects(
314327 s3_bucket (str): The bucket name.
315328 s3_prefix (str): The prefix under which the objects are stored in
316329 the bucket.
317- file_extension (str): The file extension of objects that would match.
330+ file_extensions ([str]): The file extensions of objects that would
331+ match.
318332 recursive (bool): Whether to search recursively or not.
319333
320334 Returns:
@@ -326,7 +340,7 @@ def get_s3_objects(
326340 s3_client ,
327341 s3_bucket ,
328342 s3_prefix ,
329- file_extension ,
343+ file_extensions ,
330344 )
331345
332346 return _get_single_s3_object (
@@ -340,7 +354,7 @@ def _get_recursive_s3_objects(
340354 s3_client : any ,
341355 s3_bucket : str ,
342356 s3_prefix : str ,
343- file_extension : str ,
357+ file_extensions : [ str ] ,
344358) -> Mapping [str , GenericFileData ]:
345359 """
346360 Retrieve the objects that are stored inside the S3 bucket, which keys
@@ -352,7 +366,8 @@ def _get_recursive_s3_objects(
352366 s3_bucket (str): The bucket name.
353367 s3_prefix (str): The prefix under which the objects are stored in
354368 the bucket.
355- file_extension (str): The file extension of objects that would match.
369+ file_extensions ([str]): The file extension of objects that would
370+ match.
356371
357372 Returns:
358373 Mapping[str, GenericFileData]: The map of the S3 objects that were
@@ -374,7 +389,16 @@ def _get_recursive_s3_objects(
374389 s3_objects = {}
375390 for response_data in s3_object_iterator :
376391 for obj in response_data .get ("Contents" , []):
377- if not obj .get ("Key" ).endswith (file_extension ):
392+ matched_extensions = list (
393+ # The filter matches its Key against the file_extensions
394+ # to see if it ends with that specific extension.
395+ # This will return an empty list if it did not match or
396+ # if the file_extensions is empty.
397+ filter (obj .get ("Key" ).endswith , file_extensions )
398+ )
399+ if file_extensions and not matched_extensions :
400+ # If we should filter on extensions and we did not match
401+ # with any, we should skip this object.
378402 continue
379403 index_key = convert_to_local_key (obj .get ("Key" ), s3_prefix )
380404 s3_objects [index_key ] = _get_s3_object_data (
@@ -647,7 +671,7 @@ def convert_to_local_key(s3_key, s3_prefix):
647671
648672def ensure_valid_input (
649673 local_path : str ,
650- file_extension : Optional [str ],
674+ file_extensions : [str ],
651675 s3_url : str ,
652676 s3_bucket : str ,
653677 s3_prefix : str ,
@@ -695,7 +719,7 @@ def ensure_valid_input(
695719 )
696720 sys .exit (5 )
697721
698- if file_extension and not recursive :
722+ if file_extensions and not recursive :
699723 LOGGER .warning ("Input warning: Ignoring file_extension filter." )
700724 LOGGER .warning (
701725 "Input warning: The file_extension filter is not applied "
@@ -709,7 +733,7 @@ def ensure_valid_input(
709733def sync_files (
710734 s3_client : any ,
711735 local_path : str ,
712- file_extension : str ,
736+ file_extensions : [ str ] ,
713737 s3_url : str ,
714738 recursive : bool ,
715739 delete : bool ,
@@ -722,9 +746,9 @@ def sync_files(
722746 Args:
723747 s3_client (Boto3.Client): The Boto3 S3 Client to interact with when
724748 a file needs to be deleted.
725- file_extension ( str): The extension to search with for files inside a
726- specific path. For example, ".yml" will return all YAML files,
727- including those in sub directories.
749+ file_extensions ([ str] ): The extensions to search for files inside a
750+ specific path. For example, [ ".yml", ".yaml"] will return all
751+ YAML files, including those in sub directories.
728752 s3_url (str): The S3 URL to use, for example
729753 S3://bucket/specific/prefix.
730754 """
@@ -734,20 +758,20 @@ def sync_files(
734758
735759 ensure_valid_input (
736760 local_path ,
737- file_extension ,
761+ file_extensions ,
738762 s3_url ,
739763 s3_bucket ,
740764 s3_prefix ,
741765 recursive ,
742766 )
743767
744- local_files = get_local_files (local_path , file_extension , recursive )
768+ local_files = get_local_files (local_path , file_extensions , recursive )
745769
746770 s3_objects = get_s3_objects (
747771 s3_client ,
748772 s3_bucket ,
749773 s3_prefix ,
750- file_extension ,
774+ file_extensions ,
751775 recursive ,
752776 )
753777
@@ -784,7 +808,8 @@ def main(): # pylint: disable=R0915
784808 LOGGER .debug ("Input arguments: %s" , options )
785809
786810 local_path = options .get ('SOURCE_PATH' )
787- file_extension = options .get ('--extension' ) or ""
811+ # Remove duplicates from file extension list if there are any
812+ file_extensions = list (set (options .get ('--extension' )))
788813 s3_url = options .get ('DESTINATION_S3_URL' )
789814 recursive = options .get ('--recursive' , False )
790815 delete = options .get ('--delete' , False )
@@ -802,7 +827,7 @@ def main(): # pylint: disable=R0915
802827 sync_files (
803828 s3_client ,
804829 local_path ,
805- file_extension ,
830+ file_extensions ,
806831 s3_url ,
807832 recursive ,
808833 delete ,
0 commit comments