@@ -111,6 +111,10 @@ def list_dirs(self, fs_path):
     def touch(self, fs_path, exist_ok=True):
         raise NotImplementedError
 
+    @abc.abstractmethod
+    def cat(self, fs_path=None):
+        raise NotImplementedError
+
 
 class LocalFS(FS):
     """
@@ -676,14 +680,35 @@ def is_exist(self, fs_path):
 
         return True
 
+    def upload_dir(self, local_dir, dest_dir, overwrite=False):
+        """
+        Upload a local directory to HDFS.
+
+        Args:
+            local_dir(str): the local directory
+            dest_dir(str): the destination directory on HDFS
+            overwrite(bool): whether to overwrite an existing remote directory
+        """
+        local_dir = local_dir.rstrip("/")
+        dest_dir = dest_dir.rstrip("/")
+        local_basename = os.path.basename(local_dir)
+        if self.is_exist(dest_dir + "/" + local_basename) and overwrite:
+            self.delete(dest_dir + "/" + local_basename)
+        if not self.is_exist(dest_dir):
+            self.mkdirs(dest_dir)
+        self._try_upload(local_dir, dest_dir)
+
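+    # a minimal usage sketch (hypothetical paths, assuming a configured
+    # HDFSClient instance named `client`):
+    #     client.upload_dir("./ckpt", "hdfs:/models", overwrite=True)
+    # after which the directory is available at hdfs:/models/ckpt
+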
     # can't retry
-    def upload(self, local_path, fs_path):
+    def upload(self, local_path, fs_path, multi_processes=1, overwrite=False):
         """
         Upload the local path to remote HDFS.
 
         Args:
             local_path(str): The local path.
             fs_path(str): The HDFS path.
+            multi_processes(int|1): the number of processes used to upload data concurrently. Default: 1.
+            overwrite(bool|False): whether to overwrite the remote path if it already exists. Default: False.
 
         Examples:
 
@@ -700,35 +725,83 @@ def upload(self, local_path, fs_path):
                 client = HDFSClient(hadoop_home, configs)
                 client.upload("test_hdfs_client", "hdfs:/test_hdfs_client")
         """
-        if self.is_exist(fs_path):
-            raise FSFileExistsError("{} exists".format(fs_path))
+
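+        # worker run in each child process: sequentially `put` its shard of
+        # local files into the target HDFS directory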
+        def __subprocess_upload(hdfs_path_single, datas):
+            for data in datas:
+                self._try_upload(data, hdfs_path_single)
+
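+        # note: only the immediate entries of `path` are listed; this helper
+        # does not recurse into subdirectories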
+        def get_local_files(path):
+            """
+            Get the file list under a local path.
+            Args:
+                path(str): the local path
+            Returns:
+                a list of local file paths
+            """
+            rlist = []
+
+            if not os.path.exists(path):
+                return rlist
+
+            if os.path.isdir(path):
+                for file in os.listdir(path):
+                    t = os.path.join(path, file)
+                    rlist.append(t)
+            else:
+                rlist.append(path)
+            return rlist
 
         local = LocalFS()
         if not local.is_exist(local_path):
             raise FSFileNotExistsError("{} not exists".format(local_path))
+        # if local_path is a directory, upload it as a whole
+        if local.is_dir(local_path):
+            self.upload_dir(local_path, fs_path, overwrite=overwrite)
+            return
+        # otherwise, upload the individual files under local_path
+        all_files = get_local_files(local_path)
+        if not all_files:
+            print("there is nothing to upload, exiting")
+            return
+
+        if self.is_exist(fs_path) and overwrite:
+            self.delete(fs_path)
+            self.mkdirs(fs_path)
+
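+        # shard the local file list across `multi_processes` worker processes
+        # and upload each shard in its own process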
+        procs = []
+        for i in range(multi_processes):
+            process_datas = self._split_files(all_files, i, multi_processes)
+            p = multiprocessing.Process(
+                target=__subprocess_upload, args=(fs_path, process_datas))
+            procs.append(p)
+            p.start()
 
-        return self._try_upload(local_path, fs_path)
+        # wait for all upload processes to finish
+        for proc in procs:
+            proc.join()
 
     @_handle_errors()
     def _try_upload(self, local_path, fs_path):
         cmd = "put {} {}".format(local_path, fs_path)
         ret = 0
         try:
-            ret, lines = self._run_cmd(cmd)
+            ret, _ = self._run_cmd(cmd)
             if ret != 0:
                 raise ExecuteError(cmd)
         except Exception as e:
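+            # clean up the partially uploaded remote file before re-raising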
             self.delete(fs_path)
             raise e
 
     # can't retry
-    def download(self, fs_path, local_path):
+    def download(self, fs_path, local_path, multi_processes=1, overwrite=False):
         """
         Download remote HDFS path to the local.
 
         Args:
             fs_path(str): The HDFS path.
             local_path(str): The local path.
+            multi_processes(int|1): the number of processes used to download data concurrently. Default: 1.
+            overwrite(bool): whether to overwrite. Default: False.
 
         Examples:
 
@@ -745,17 +818,43 @@ def download(self, fs_path, local_path):
                 client = HDFSClient(hadoop_home, configs)
                 client.download("hdfs:/test_hdfs_client", "./")
         """
+
+        def __subprocess_download(local_path, datas):
+            """
+            Download files from HDFS.
+            Args:
+                local_path(str): the local file path
+                datas(list): the list of HDFS file paths
+            """
+            for data in datas:
+                self._try_download(data, local_path)
+
         if not self.is_exist(fs_path):
             raise FSFileNotExistsError("{} not exists".format(fs_path))
-
-        return self._try_download(fs_path, local_path)
+        # download a single file directly
+        if self.is_file(fs_path):
+            return self._try_download(fs_path, local_path)
+        # download a directory: list its files first
+        _, all_files = self.ls_dir(fs_path)
+
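+        # shard the remote file list across `multi_processes` worker
+        # processes, mirroring the multi-process upload path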
+        procs = []
+        for i in range(multi_processes):
+            process_datas = self._split_files(all_files, i, multi_processes)
+            p = multiprocessing.Process(
+                target=__subprocess_download, args=(local_path, process_datas))
+            procs.append(p)
+            p.start()
+
+        # wait for all download processes to finish
+        for proc in procs:
+            proc.join()
 
     @_handle_errors()
     def _try_download(self, fs_path, local_path):
         cmd = "get {} {}".format(fs_path, local_path)
         ret = 0
         try:
-            ret, lines = self._run_cmd(cmd)
+            ret, _ = self._run_cmd(cmd)
             if ret != 0:
                 raise ExecuteError(cmd)
         except Exception as e:
@@ -803,7 +902,7 @@ def mkdirs(self, fs_path):
 
         if out_hdfs and not self.is_exist(fs_path):
             cmd = "mkdir -p {}".format(fs_path)
-            ret, lines = self._run_cmd(cmd)
+            ret, _ = self._run_cmd(cmd)
             if ret != 0:
                 raise ExecuteError(cmd)
 
@@ -939,7 +1038,71 @@ def _touchz(self, fs_path):
         cmd = "touchz {}".format(fs_path)
         ret, _ = self._run_cmd(cmd)
         if ret != 0:
-            raise ExecuteError
+            raise ExecuteError(cmd)
 
     def need_upload_download(self):
         return True
+
+    def cat(self, fs_path=None):
+        """
+        Cat a remote HDFS file.
+
+        Args:
+            fs_path(str): The HDFS file path.
+
+        Returns:
+            The file content as a string.
+
+        Examples:
+
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                client.cat("hdfs:/test_hdfs_client")
+        """
+        if self.is_file(fs_path):
+            output = self._try_cat(fs_path)
+            return "\n".join(output)
+        else:
+            return ""
+
+    @_handle_errors()
+    def _try_cat(self, fs_path):
+        cmd = "cat {}".format(fs_path)
+        ret, output = self._run_cmd(cmd)
+        if ret != 0:
+            raise ExecuteError(cmd)
+        return output
+
+    def _split_files(self, files, trainer_id, trainers):
+        """
+        Split the file list so that each trainer gets its own shard.
+        Args:
+            files(list): the file list
+            trainer_id(int): the trainer's MPI rank id
+            trainers(int): the total number of trainers
+        Returns:
+            filelist(list): the file list of the current trainer
+        """
+        remainder = len(files) % trainers
+        blocksize = len(files) // trainers
+
+        blocks = [blocksize] * trainers
+        for i in range(remainder):
+            blocks[i] += 1
+
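+        # every element of trainer_files is reassigned in the loop below, so
+        # the shared-reference aliasing from [[]] * trainers is harmless here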
+        trainer_files = [[]] * trainers
+        begin = 0
+        for i in range(trainers):
+            trainer_files[i] = files[begin:begin + blocks[i]]
+            begin += blocks[i]
+
+        return trainer_files[trainer_id]
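+
+    # a worked example of the split: with files=["a", "b", "c", "d", "e"]
+    # and trainers=2, blocks becomes [3, 2], so trainer 0 receives
+    # ["a", "b", "c"] and trainer 1 receives ["d", "e"]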