Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions torchtrain/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,14 @@
import os
import torch

try:
import tomllib
except ModuleNotFoundError:
import tomli as tomllib

from torchtrain.logging_utils import rank0_log

_config_file = "./torchtrain/train_config.toml"


def get_config_from_toml(config_path: str = _config_file) -> dict:
"""
Reads a config file in TOML format and returns a dictionary.
"""
with open(config_path, "rb") as f:
config = tomllib.load(f)
return config
from torchtrain.tt_config.config_utils import get_config


@contextlib.contextmanager
def maybe_run_profiler(*pos_args, **kwargs):
config = get_config_from_toml()
config = get_config()

# get user defined profiler settings
run_profiler = config["profiling"].get("run_profiler", False)
Expand Down
1 change: 1 addition & 0 deletions torchtrain/tt_config/__init_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# __init__.py
13 changes: 13 additions & 0 deletions torchtrain/tt_config/config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
try:
import tomllib
except ModuleNotFoundError:
import tomli as tomllib


def get_config(config_path: str = "./torchtrain/tt_config/train_config.toml") -> dict:
"""
Reads a config file in TOML format and returns a dictionary.
"""
with open(config_path, "rb") as f:
config = tomllib.load(f)
return config
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
[global]
dump_folder = "./torchtrain/outputs"

[compile]
use_compile = true

[profiling]
run_profiler = true
save_traces_folder = "profiling/traces"
Expand Down
17 changes: 14 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
import os
from dataclasses import dataclass, field
from typing import List
import logging
from logging import getLogger
import sys # for logging

# torch imports
import torch
Expand All @@ -24,6 +21,8 @@
pad_batch_to_longest_seq,
)

from torchtrain.tt_config.config_utils import get_config


@dataclass
class TrainState:
Expand Down Expand Up @@ -56,6 +55,10 @@ def main(args):
device_type, (dp_degree, args.tp_degree), mesh_dim_names=("dp", "tp")
)

# load config
tt_config = get_config()
_use_compile = tt_config["compile"]["use_compile"]

model_name = args.model
# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
Expand Down Expand Up @@ -85,6 +88,14 @@ def main(args):

# TODO: apply parallelisms, e.g. fsdp/tp
# TODO: add metrics

# torch.compile model for improved performance
if _use_compile:
rank0_log(f"Compiling model {model_name} with torch.compile...")
torch.compile(
model,
)

train_state = TrainState()

# train loop
Expand Down