2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -9,4 +9,4 @@ NGPU=8
 MP=4
 
 torchrun --nproc_per_node=${NGPU} \
-train.py --steps 10
+train.py --steps 10 --compile
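
Note: this diff does not show where train.py defines the new flag. A minimal argparse sketch of how a boolean --compile option is typically wired up follows (the parser name, help strings, and the --steps default are assumptions; only the flag names come from the script above):

import argparse

parser = argparse.ArgumentParser(description="torchtrain launcher args")
parser.add_argument("--steps", type=int, default=10, help="number of training steps to run")
# action="store_true" means args.compile is False unless --compile is passed
parser.add_argument("--compile", action="store_true", help="compile the model with torch.compile")
args = parser.parse_args()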
2 changes: 1 addition & 1 deletion torchtrain/profiling.py
@@ -12,7 +12,7 @@
 
 from torchtrain.logging_utils import rank0_log
 
-_config_file = "./torchtrain/train_config.toml"
+_config_file = "./torchtrain/train_config/train_config.toml"
 
 
 def get_config_from_toml(config_path: str = _config_file) -> dict:
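
The body of get_config_from_toml is collapsed in this view. For reference, a minimal sketch of such a helper, assuming the standard-library tomllib parser (the real implementation may use a different TOML library):

import tomllib  # stdlib TOML parser (Python 3.11+)

_config_file = "./torchtrain/train_config/train_config.toml"

def get_config_from_toml(config_path: str = _config_file) -> dict:
    # parse the training config from the relocated TOML file and return it as a plain dict
    with open(config_path, "rb") as f:  # tomllib requires a binary file handle
        return tomllib.load(f)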
1 change: 1 addition & 0 deletions torchtrain/train_config/__init_.py
@@ -0,0 +1 @@
+# __init__.py
torchtrain/train_config/train_config.toml
@@ -3,7 +3,7 @@
 dump_folder = "./torchtrain/outputs"
 
 [profiling]
-run_profiler = true
+run_profiler = false
 save_traces_folder = "profiling/traces"
 # profiling frequency - example: 10 means every 10th iter will be profiled
 profile_every_x_iter = 10
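
With run_profiler now defaulting to false, trace collection is opt-in. A self-contained sketch of how these keys are typically consumed in a training loop (the loop itself is illustrative; only the TOML keys come from the file above):

import tomllib

with open("./torchtrain/train_config/train_config.toml", "rb") as f:
    config = tomllib.load(f)

prof = config["profiling"]
for step in range(20):
    # profile only when enabled, and only on every Nth iteration
    if prof["run_profiler"] and step % prof["profile_every_x_iter"] == 0:
        print(f"step {step}: capture a trace into {prof['save_traces_folder']}")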
13 changes: 10 additions & 3 deletions train.py
@@ -2,9 +2,6 @@
 import os
 from dataclasses import dataclass, field
 from typing import List
-import logging
-from logging import getLogger
-import sys # for logging
 
 # torch imports
 import torch
@@ -56,6 +53,7 @@ def main(args):
         device_type, (dp_degree, args.tp_degree), mesh_dim_names=("dp", "tp")
     )
 
+
     model_name = args.model
     # build tokenizer
     tokenizer_type = model_name_to_tokenizer[model_name]
@@ -85,6 +83,15 @@ def main(args):
 
     # TODO: apply parallelisms, e.g. fsdp/tp
     # TODO: add metrics
+
+    # torch.compile model for improved performance
+
+    if args.compile:
+        rank0_log(f"Compiling model {model_name} with torch.compile...")
+        model = torch.compile(
+            model,
+        )
+
     train_state = TrainState()
 
     # train loop
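
For reference, torch.compile returns an optimized wrapper around the module; the PR calls it with the defaults, but the API also accepts options such as backend and mode. A standalone sketch of the same pattern (the toy model and chosen options are illustrative, not part of this PR):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# backend defaults to "inductor"; mode trades compile time against runtime speed
compiled_model = torch.compile(model, mode="default")

x = torch.randn(8, 16)
out = compiled_model(x)  # the first call triggers compilation; later calls reuse the compiled graph
print(out.shape)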