Skip to content

Commit 9004459

Browse files
authored
[SDK] Fix trainer error: Update the version of base image and add "num_labels" for downloading pretrained models (kubeflow/trainer#2230)
* fix trainer error
  Signed-off-by: helenxie-bit <helenxiehz@gmail.com>
* rerun tests
  Signed-off-by: helenxie-bit <helenxiehz@gmail.com>
* update the process of num_labels in trainer
  Signed-off-by: helenxie-bit <helenxiehz@gmail.com>
* rerun tests
  Signed-off-by: helenxie-bit <helenxiehz@gmail.com>
* adjust the default value of 'num_labels'
  Signed-off-by: helenxie-bit <helenxiehz@gmail.com>
---------
Signed-off-by: helenxie-bit <helenxiehz@gmail.com>
1 parent b48147e commit 9004459

4 files changed

Lines changed: 22 additions & 9 deletions

File tree

python/kubeflow/storage_initializer/hugging_face.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class HuggingFaceModelParams:
3737
model_uri: str
3838
transformer_type: TRANSFORMER_TYPES
3939
access_token: str = None
40+
num_labels: Optional[int] = None
4041

4142
def __post_init__(self):
4243
# Custom checks or validations can be added here

python/kubeflow/trainer/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Use an official PyTorch runtime as a parent image
2-
FROM nvcr.io/nvidia/pytorch:23.10-py3
2+
FROM nvcr.io/nvidia/pytorch:24.06-py3
33

44
# Set the working directory in the container
55
WORKDIR /app

python/kubeflow/trainer/hf_llm_training.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,26 @@
2828
logger.setLevel(logging.INFO)
2929

3030

31-
def setup_model_and_tokenizer(model_uri, transformer_type, model_dir):
31+
def setup_model_and_tokenizer(model_uri, transformer_type, model_dir, num_labels):
3232
# Set up the model and tokenizer
3333
parsed_uri = urlparse(model_uri)
3434
model_name = parsed_uri.netloc + parsed_uri.path
3535

36-
model = transformer_type.from_pretrained(
37-
pretrained_model_name_or_path=model_name,
38-
cache_dir=model_dir,
39-
local_files_only=True,
40-
trust_remote_code=True,
41-
)
36+
if num_labels != "None":
37+
model = transformer_type.from_pretrained(
38+
pretrained_model_name_or_path=model_name,
39+
cache_dir=model_dir,
40+
local_files_only=True,
41+
trust_remote_code=True,
42+
num_labels=int(num_labels),
43+
)
44+
else:
45+
model = transformer_type.from_pretrained(
46+
pretrained_model_name_or_path=model_name,
47+
cache_dir=model_dir,
48+
local_files_only=True,
49+
trust_remote_code=True,
50+
)
4251

4352
tokenizer = AutoTokenizer.from_pretrained(
4453
pretrained_model_name_or_path=model_name,
@@ -151,6 +160,7 @@ def parse_arguments():
151160

152161
parser.add_argument("--model_uri", help="model uri")
153162
parser.add_argument("--transformer_type", help="model transformer type")
163+
parser.add_argument("--num_labels", default="None", help="number of classes")
154164
parser.add_argument("--model_dir", help="directory containing model")
155165
parser.add_argument("--dataset_dir", help="directory containing dataset")
156166
parser.add_argument("--lora_config", help="lora_config")
@@ -178,7 +188,7 @@ def parse_arguments():
178188

179189
logger.info("Setup model and tokenizer")
180190
model, tokenizer = setup_model_and_tokenizer(
181-
args.model_uri, transformer_type, args.model_dir
191+
args.model_uri, transformer_type, args.model_dir, args.num_labels
182192
)
183193

184194
logger.info("Preprocess dataset")

python/kubeflow/training/api/training_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ def train(
265265
model_provider_parameters.model_uri,
266266
"--transformer_type",
267267
model_provider_parameters.transformer_type.__name__,
268+
"--num_labels",
269+
str(model_provider_parameters.num_labels),
268270
"--model_dir",
269271
VOLUME_PATH_MODEL,
270272
"--dataset_dir",

0 commit comments

Comments
 (0)