|
5 | 5 | from abc import ABC, abstractmethod |
6 | 6 |
|
7 | 7 | from huggingface_hub import HfApi, login |
| 8 | +from huggingface_hub.hf_api import CommitInfo |
| 9 | + |
| 10 | +from ludwig.globals import MODEL_HYPERPARAMETERS_FILE_NAME |
8 | 11 |
|
9 | 12 | logger = logging.getLogger(__name__) |
10 | 13 |
|
@@ -193,17 +196,20 @@ def _validate_upload_parameters( |
193 | 196 | </Alex(12/10/2023): TODO> |
194 | 197 | """ |
195 | 198 | files = set(os.listdir(trained_model_artifacts_path)) |
196 | | - acceptable_model_artifact_file_nanes: set[str] = { |
| 199 | + acceptable_model_artifact_file_names: set[str] = { |
197 | 200 | "pytorch_model.bin", |
198 | 201 | "adapter_model.bin", # Delete per formal deprecation policy TBD (per above comment). |
199 | 202 | "adapter_model.safetensors", # New format as of PEFT version "0.7.0" (per above comment). |
200 | 203 | } |
201 | | - if not (files & acceptable_model_artifact_file_nanes): |
| 204 | + if not (files & acceptable_model_artifact_file_names): |
202 | 205 | raise ValueError( |
203 | 206 | f"Can't find model weights at {trained_model_artifacts_path}. Trained model weights should " |
204 | 207 | "either be saved as `pytorch_model.bin` for regular model training, or have `adapter_model.bin`" |
205 | 208 | "or `adapter_model.safetensors` if using parameter efficient fine-tuning methods like LoRA." |
206 | 209 | ) |
| 210 | + model_hyperparameters_path: str = os.path.join(model_path, "model") |
| 211 | + if MODEL_HYPERPARAMETERS_FILE_NAME not in os.listdir(model_hyperparameters_path): |
| 212 | + raise ValueError(f"Can't find '{MODEL_HYPERPARAMETERS_FILE_NAME}' at {model_hyperparameters_path}.") |
207 | 213 |
|
208 | 214 | def upload( |
209 | 215 | self, |
@@ -256,17 +262,37 @@ def upload( |
256 | 262 | ) |
257 | 263 |
|
258 | 264 | # Upload all artifacts in model weights folder |
259 | | - upload_path = self.api.upload_folder( |
| 265 | + commit_message_weights: str | None = f"{commit_message} (weights)" if commit_message else commit_message |
| 266 | + commit_description_weights: str | None = ( |
| 267 | + f"{commit_description} (weights)" if commit_description else commit_description |
| 268 | + ) |
| 269 | + upload_path_weights: CommitInfo = self.api.upload_folder( |
260 | 270 | folder_path=os.path.join(model_path, "model", "model_weights"), |
261 | 271 | repo_id=repo_id, |
262 | 272 | repo_type=repo_type, |
263 | | - commit_message=commit_message, |
264 | | - commit_description=commit_description, |
| 273 | + commit_message=commit_message_weights, |
| 274 | + commit_description=commit_description_weights, |
265 | 275 | ) |
266 | 276 |
|
267 | | - if upload_path: |
268 | | - logger.info(f"Model uploaded to `{upload_path}` with repository name `{repo_id}`") |
269 | | - return True |
| 277 | + if upload_path_weights: |
| 278 | + logger.info(f"Model weights uploaded to `{upload_path_weights}` with repository name `{repo_id}`") |
| 279 | + # Upload the ludwig configuration file |
| 280 | + commit_message_config: str | None = f"{commit_message} (config)" if commit_message else commit_message |
| 281 | + commit_description_config: str | None = ( |
| 282 | + f"{commit_description} (config)" if commit_description else commit_description |
| 283 | + ) |
| 284 | + upload_path_config: CommitInfo = self.api.upload_file( |
| 285 | + path_or_fileobj=os.path.join(model_path, "model", MODEL_HYPERPARAMETERS_FILE_NAME), |
| 286 | + path_in_repo="ludwig_config.json", |
| 287 | + repo_id=repo_id, |
| 288 | + repo_type=repo_type, |
| 289 | + commit_message=commit_message_config, |
| 290 | + commit_description=commit_description_config, |
| 291 | + ) |
| 292 | + |
| 293 | + if upload_path_config: |
| 294 | + logger.info(f"Model config uploaded to `{upload_path_config}` with repository name `{repo_id}`") |
| 295 | + return True |
270 | 296 |
|
271 | 297 | return False |
272 | 298 |
|
|
0 commit comments