@@ -94,6 +94,9 @@ def _prepare_weights(
9494 load_format = self .load_config .load_format
9595 use_safetensors = False
9696 index_file = DIFFUSION_MODEL_WEIGHTS_INDEX
97+ index_file_with_subfolder = (
98+ f"{ subfolder } /{ index_file } " if subfolder else index_file
99+ )
97100
98101 # only hf is supported currently
99102 if load_format == "auto" :
@@ -129,8 +132,8 @@ def _prepare_weights(
129132 for pattern in allow_patterns :
130133 hf_weights_files += glob .glob (os .path .join (hf_folder , pattern ))
131134 if len (hf_weights_files ) > 0 :
132- if pattern == "*.safetensors" :
133- use_safetensors = True
135+ # Decide by actual files rather than pattern name (patterns may include subfolders).
136+ use_safetensors = any ( f . endswith ( ".safetensors" ) for f in hf_weights_files )
134137 break
135138
136139 if use_safetensors :
@@ -142,11 +145,15 @@ def _prepare_weights(
142145 if not is_local :
143146 download_safetensors_index_file_from_hf (
144147 model_name_or_path ,
145- index_file ,
148+ index_file_with_subfolder ,
146149 self .load_config .download_dir ,
147150 revision ,
148151 )
149- hf_weights_files = filter_duplicate_safetensors_files (hf_weights_files , hf_folder , index_file )
152+ hf_weights_files = filter_duplicate_safetensors_files (
153+ hf_weights_files ,
154+ hf_folder ,
155+ index_file_with_subfolder ,
156+ )
150157 else :
151158 hf_weights_files = filter_files_not_needed_for_inference (hf_weights_files )
152159
@@ -188,8 +195,9 @@ def get_all_weights(
188195
189196 def download_model (self , model_config : ModelConfig ) -> None :
190197 self ._prepare_weights (
191- model_config .model ,
192- model_config .revision ,
198+ model_name_or_path = model_config .model ,
199+ subfolder = None ,
200+ revision = model_config .revision ,
193201 fall_back_to_pt = True ,
194202 allow_patterns_overrides = None ,
195203 )
0 commit comments