Skip to content

Commit c09d5dc

Browse files
sanjaydasgupta, sd-buildwhiz, and pre-commit-ci[bot]
authored
Save ludwig-config with model-weights in output directory (#3965)
Co-authored-by: Sanjay <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 606c732 commit c09d5dc

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

ludwig/utils/upload_utils.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from abc import ABC, abstractmethod
66

77
from huggingface_hub import HfApi, login
8+
from huggingface_hub.hf_api import CommitInfo
9+
10+
from ludwig.globals import MODEL_HYPERPARAMETERS_FILE_NAME
811

912
logger = logging.getLogger(__name__)
1013

@@ -193,17 +196,20 @@ def _validate_upload_parameters(
193196
</Alex(12/10/2023): TODO>
194197
"""
195198
files = set(os.listdir(trained_model_artifacts_path))
196-
acceptable_model_artifact_file_nanes: set[str] = {
199+
acceptable_model_artifact_file_names: set[str] = {
197200
"pytorch_model.bin",
198201
"adapter_model.bin", # Delete per formal deprecation policy TBD (per above comment).
199202
"adapter_model.safetensors", # New format as of PEFT version "0.7.0" (per above comment).
200203
}
201-
if not (files & acceptable_model_artifact_file_nanes):
204+
if not (files & acceptable_model_artifact_file_names):
202205
raise ValueError(
203206
f"Can't find model weights at {trained_model_artifacts_path}. Trained model weights should "
204207
"either be saved as `pytorch_model.bin` for regular model training, or have `adapter_model.bin`"
205208
"or `adapter_model.safetensors` if using parameter efficient fine-tuning methods like LoRA."
206209
)
210+
model_hyperparameters_path: str = os.path.join(model_path, "model")
211+
if MODEL_HYPERPARAMETERS_FILE_NAME not in os.listdir(model_hyperparameters_path):
212+
raise ValueError(f"Can't find '{MODEL_HYPERPARAMETERS_FILE_NAME}' at {model_hyperparameters_path}.")
207213

208214
def upload(
209215
self,
@@ -256,17 +262,37 @@ def upload(
256262
)
257263

258264
# Upload all artifacts in model weights folder
259-
upload_path = self.api.upload_folder(
265+
commit_message_weights: str | None = f"{commit_message} (weights)" if commit_message else commit_message
266+
commit_description_weights: str | None = (
267+
f"{commit_description} (weights)" if commit_description else commit_description
268+
)
269+
upload_path_weights: CommitInfo = self.api.upload_folder(
260270
folder_path=os.path.join(model_path, "model", "model_weights"),
261271
repo_id=repo_id,
262272
repo_type=repo_type,
263-
commit_message=commit_message,
264-
commit_description=commit_description,
273+
commit_message=commit_message_weights,
274+
commit_description=commit_description_weights,
265275
)
266276

267-
if upload_path:
268-
logger.info(f"Model uploaded to `{upload_path}` with repository name `{repo_id}`")
269-
return True
277+
if upload_path_weights:
278+
logger.info(f"Model weights uploaded to `{upload_path_weights}` with repository name `{repo_id}`")
279+
# Upload the ludwig configuration file
280+
commit_message_config: str | None = f"{commit_message} (config)" if commit_message else commit_message
281+
commit_description_config: str | None = (
282+
f"{commit_description} (config)" if commit_description else commit_description
283+
)
284+
upload_path_config: CommitInfo = self.api.upload_file(
285+
path_or_fileobj=os.path.join(model_path, "model", MODEL_HYPERPARAMETERS_FILE_NAME),
286+
path_in_repo="ludwig_config.json",
287+
repo_id=repo_id,
288+
repo_type=repo_type,
289+
commit_message=commit_message_config,
290+
commit_description=commit_description_config,
291+
)
292+
293+
if upload_path_config:
294+
logger.info(f"Model config uploaded to `{upload_path_config}` with repository name `{repo_id}`")
295+
return True
270296

271297
return False
272298

tests/ludwig/utils/test_upload_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88

9+
from ludwig.globals import MODEL_HYPERPARAMETERS_FILE_NAME
910
from ludwig.utils.upload_utils import HuggingFaceHub
1011

1112
logger = logging.getLogger(__name__)
@@ -26,14 +27,15 @@ def _build_fake_model_repo(
2627
names must be leaf file names, not paths).
2728
"""
2829
# Create a temporary folder designating training output directory.
29-
model_directory: str = pathlib.Path(destination_directory) / experiment_name / model_directory_name
30-
model_weights_directory: str = model_directory / model_weights_directory_name
30+
model_directory: pathlib.Path = pathlib.Path(destination_directory) / experiment_name / model_directory_name
31+
model_weights_directory: pathlib.Path = model_directory / model_weights_directory_name
3132
model_weights_directory.mkdir(parents=True, exist_ok=True)
3233

3334
# Create files within the "model_weights" subdirectory.
3435
file_name: str
3536
for file_name in file_names:
3637
pathlib.Path(model_weights_directory / file_name).touch()
38+
pathlib.Path(model_directory / MODEL_HYPERPARAMETERS_FILE_NAME).touch()
3739

3840

3941
@pytest.fixture

0 commit comments

Comments
 (0)