diff --git a/examples/gptoss/convert_mcore_bf16_checkpoint_from_hf.py b/examples/gptoss/convert_mcore_bf16_checkpoint_from_hf.py new file mode 100644 index 0000000000..0d1b20eaf9 --- /dev/null +++ b/examples/gptoss/convert_mcore_bf16_checkpoint_from_hf.py @@ -0,0 +1,278 @@ +from argparse import ArgumentParser + +from megatron.bridge import AutoBridge +from megatron.bridge.utils.common_utils import get_last_rank, print_rank_0 +from megatron.bridge.training.model_load_save import load_megatron_model, save_megatron_model, load_tokenizer + +from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer + +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + +import os + +from megatron.core import parallel_state +from megatron.core import parallel_state as mpu +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + +import torch +import torch.distributed as dist +from transformers import AutoModelForCausalLM, AutoTokenizer + +# load pretrain/SFT model info, only bf16 supported for the moment +MODEL="gpt-oss-20b-BF16" + +# create soft links to /workspace/models +MODEL_DIR="/workspace/models" + +HF_MODEL_DIR=f"{MODEL_DIR}/{MODEL}" + +# Specify model partitions, we use parallel folding strategy to separate EP for MLP from pp-tp-cp-dp for Attention +TP=int(os.environ.get("TP", 8)) +PP=int(os.environ.get("PP", 1)) +CP=int(os.environ.get("CP", 1)) + +# Assume a single node setup in this script +EP=int(os.environ.get("EDP", 8 // PP)) # distributed evenly among all gpu cards +# ETP can only be 1 for GptOSS for the moment with Mcore backend +ETP=1 + +SAVER="mcore_bridge" + +SEED=42 + +# adpated from megatron bridge examples/ +class SingleBatchIterator: + """Iterator that yields a single batch of data for text generation. + Required by the forward_backward_func function. + + This class creates an iterator that yields exactly one batch containing + input tokens, position IDs, and attention mask, then raises StopIteration. + Used for single-step inference in the forward pass. + """ + + def __init__(self, input_ids, position_ids, attention_mask): + self.batch = dict( + tokens=input_ids, + position_ids=position_ids, + # attention_mask=attention_mask, + ) + self._yielded = False + + def __iter__(self): + return self + + def __next__(self): + if self._yielded: + raise StopIteration + self._yielded = True + return self.batch + + +def text_forward_step(data_iterator, model, **kwargs) -> torch.Tensor: + """Forward step function for text generation. + Required by the forward_backward_func function. + + Extracts a batch from the data iterator and runs the model forward pass + with the provided input tokens, position IDs, and attention mask. 
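+    The returned loss function is an identity passthrough; it exists only to satisfy the
+    forward_backward_func interface, since this forward step is used for inference rather than training.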
+ + Args: + data_iterator: Iterator providing batches of input data + model: The Megatron model to run forward pass on + **kwargs: Additional keyword arguments (unused) + + Returns: + Tuple of (model_output, loss_function) + """ + batch = next(data_iterator) + forward_args = { + "input_ids": batch["tokens"], + "position_ids": batch["position_ids"], + "attention_mask": batch.get("attention_mask", None), + } + + def loss_func(x, **kwargs): + return x + + return model(**forward_args), loss_func + + +def export(checkpoint=True): + # gptoss bf16 recipe for post training + dtype="bf16" + + # using Megatron Bridge provider API + bridge = AutoBridge.from_hf_pretrained(f"{HF_MODEL_DIR}", trust_remote_code=True) + + provider = bridge.to_megatron_provider() + + provider.tensor_model_parallel_size = TP + provider.pipeline_model_parallel_size = PP + provider.context_parallel_size = CP + + # sparse model + provider.expert_model_parallel_size = EP + provider.expert_tensor_parallel_size = ETP + + provider.finalize() + + model = provider.provide_distributed_model(wrap_with_ddp=False) + + # output info + OUTPUT=f"{MODEL_DIR}/{MODEL}-to-{SAVER}-tp{TP}-pp{PP}-cp{CP}-ep{EP}-{dtype}" + + if not checkpoint: + # to huggingface + bridge.save_hf_pretrained(model, f"{OUTPUT}") + else: + # to megatron checkpoint + save_megatron_model(model, f"{OUTPUT}", hf_tokenizer_path=f"{HF_MODEL_DIR}") + OUTPUT = f"{OUTPUT}/iter_0000000" + + return model, OUTPUT + + +def _verify_tokenizer_and_hfmodel(hf_tokenizer, model): + texts = ["Once upon the time",] + messages = [ + {"role": "user", "content": text} for text in texts + ] + + prompts = hf_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True) + + model_inputs = hf_tokenizer([prompts], return_tensors="pt").to(model.device) + + outputs_ids = model.generate(**model_inputs, max_new_tokens=16) + + outputs_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, outputs_ids) + ] + + response = hf_tokenizer.batch_decode(outputs_ids, skip_special_tokens=True)[0] + print(f"[Rank#{torch.distributed.get_rank()}] response : {response}") + +def verify_tokenizer_and_hfmodel(hf_tokenizer_path, model): + hf_tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_path) + + _verify_tokenizer_and_hfmodel(hf_tokenizer, model) + +def verify_megatron_fwd(tokenizer_path, model, max_length=16): + tokenizer = load_tokenizer(tokenizer_path) + + assert isinstance(tokenizer, _HuggingFaceTokenizer), "update script to adapt to mcore tokenizer (I am using legacy huggingface tokenizer)" + + model = [m.cuda() for m in model] + for m in model: + m.eval() + + prompt = "Once upon the time" + token_ids = tokenizer.tokenize(prompt) + + with torch.no_grad(): + input_batch = torch.tensor([token_ids]).cuda() + + output_ids = input_batch.clone() + + fwd_bwd_function = get_forward_backward_func() + + for i in range(max_length - len(token_ids)): + position_ids = torch.arange(output_ids.size(1), dtype=torch.long, device=output_ids.device) + attention_mask = torch.ones_like(output_ids, dtype=torch.bool) + + data_iterator = SingleBatchIterator(output_ids, position_ids, attention_mask) + + output = fwd_bwd_function( + forward_step_func=text_forward_step, + data_iterator=data_iterator, + model=model, + num_microbatches=1, + forward_only=True, + seq_length=input_batch.size(1), + micro_batch_size=1, + collect_non_loss_data=True, + ) + + if isinstance(output, list) and len(output) > 0: + output = output[0] + + if parallel_state.is_pipeline_last_stage(): 
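+                # Last pipeline stage: all-gather the vocab-parallel logits across TP ranks,
+                # concatenate along the vocab dimension, and greedily take the argmax of the
+                # final position as the next token.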
+ world_size = parallel_state.get_tensor_model_parallel_world_size() + gathered_tensors = [torch.zeros_like(output) for _ in range(world_size)] + + dist.all_gather(gathered_tensors, output, group=parallel_state.get_tensor_model_parallel_group()) + + logits = torch.cat(gathered_tensors, dim=2) + next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1) + else: + next_token_id = torch.ones((1, 1), device=output_ids.device, dtype=output_ids.dtype) + + torch.distributed.broadcast(next_token_id, get_last_rank()) + output_ids = torch.cat([output_ids, next_token_id], dim=1) + + if next_token_id.item() == tokenizer._tokenizer.eos_token_id: + break + + response = tokenizer._tokenizer.decode(output_ids[0].cpu().numpy(), skip_special_tokens=True) + print_rank_0(f"Rank#{torch.distributed.get_rank()} Response : {response}") + + +if __name__ == "__main__": + parser = ArgumentParser() + + parser.add_argument( + "--source_model", default=None, type=str, required=False, help="source model." + ) + parser.add_argument( + "--output_hf_dir", default=None, type=str, required=False, help="Where to save the converted model." + ) + parser.add_argument( + "--output_ckpt_dir", default=None, type=str, required=False, help="Where to save the converted model." + ) + args = parser.parse_args() + + if args.source_model: + MODEL_DIR = args.source_model + HF_MODEL_DIR=f"{MODEL_DIR}/{MODEL}" + + if args.output_hf_dir: + OUTPUT_DIR = args.output_hf_dir + + model = AutoModelForCausalLM.from_pretrained(OUTPUT_DIR, + torch_dtype="auto", + trust_remote_code=True) + + verify_tokenizer_and_hfmodel(OUTPUT_DIR, model) + elif args.output_ckpt_dir: + OUTPUT_DIR = f"{args.output_ckpt_dir}/iter_0000000" + + bridge = AutoBridge.from_hf_pretrained(f"{HF_MODEL_DIR}", trust_remote_code=True) + + provider = bridge.to_megatron_provider() + + provider.tensor_model_parallel_size = TP + provider.pipeline_model_parallel_size = PP + provider.context_parallel_size = CP + + # sparse model + provider.expert_model_parallel_size = EP + + # provider.sequence_parallel = True + + provider.finalize() + provider.initialize_model_parallel(seed=SEED) + + model = load_megatron_model(OUTPUT_DIR) + + verify_megatron_fwd(OUTPUT_DIR, model) + + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + else: + model, OUTPUT_DIR = export() + + verify_megatron_fwd(OUTPUT_DIR, model) + + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() \ No newline at end of file diff --git a/examples/gptoss/training_gptoss_20b_h100_bf16_fp8.sh b/examples/gptoss/training_gptoss_20b_h100_bf16_fp8.sh new file mode 100644 index 0000000000..1eac4e5edd --- /dev/null +++ b/examples/gptoss/training_gptoss_20b_h100_bf16_fp8.sh @@ -0,0 +1,286 @@ +#!/usr/bin/bash +ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../.." 
&& pwd )" + +set -e +set -x + +## NCCL config +export NCCL_IB_HCA=mlx5_0,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_9,mlx5_10,mlx5_11 +# export NCCL_IB_HCA=mlx5 +# export NCCL_TOPO_DUMP_FILE=topo.xml +# traffic class for QoS tunning +export NCCL_IB_TC=136 +# service level that maps virtual lane +export NCCL_IB_SL=5 +export NCCL_IB_GID_INDEX=3 +export NCCL_IB_CUDA_SUPPORT=1 +export NCCL_IB_TIMEOUT=22 +# for HKUST supper pod, and this sets the TCP/IP-based interface for fallback of socket-based NCCL communication +# 'ibp154s0'(tcp), 'ibp170s0f0'(tcp), 'ibp192s0'(tcp), 'ibp206s0'(tcp), 'ibp220s0'(tcp), 'ibp24s0'(tcp), 'ibp41s0f0'(tcp), 'ibp64s0'(tcp), 'ibp79s0'(tcp), 'ibp94s0'(tcp) +# NOTE(yiakwy) : see ib device and roce device mapping via ibdev2netdev +export NCCL_SOCKET_IFNAME=ibp24s0,ibp41s0f0,ibp64s0,ibp79s0,ibp94s0,ibp154s0,ibp170s0f0,ibp192s0 +# export UCX_NET_DEVICES=$NCCL_SOCKET_IFNAME +export NCCL_SOCKET_IFNAME=ibp #/*NOTE*/ +export NCCL_DEBUG=DEBUG + +export SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 +export NCCL_COLLNET_ENABLE=1 + +export NCCL_IB_DISABLE=1 + +DIST_ENV=${DIST_ENV:-dsw} +echo "DIST_ENV : $DIST_ENV" + + +MEGATRON_PATH=$ROOT # ${MEGATRON_PATCH_PATH} + + +# add python path +export PYTHONPATH=${MEGATRON_PATH}:$PYTHONPATH + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +if [ $DIST_ENV = dsw ]; then + +MASTER_ADDR=localhost +MASTER_PORT=$(shuf -n 1 -i 10000-65535) +NNODES=1 +NODE_RANK=0 +GPUS_PER_NODE=8 + +elif [ $DIST_ENV = dlc ]; then + +NNODES=${WORLD_SIZE} +NODE_RANK=${RANK} +GPUS_PER_NODE=${KUBERNETES_CONTAINER_RESOURCE_GPU} + +fi + +# megatron world size def +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +## Model Parallel +TP=8 +PP=1 +CP=1 +EDP=8 +ETP=1 + +# TODO (yiakwy) : rename to DIST_OPT, --use-distributed-optimizer +DO=${DIST_OPT:-true} + +TE=${TE:-true} # ${TE:-false} + +PRETRAIN_CHECKPOINT_PATH="/raid/gpt-oss-20b-BF16-to-mcore_bridge-tp8-pp1-cp1-ep8-bf16/iter_0000000" + +# PRETRAIN_CHECKPOINT_PATH="/raid/gpt-oss-20b-BF16-to-mcore_bridge-tp1-pp1-cp1-ep8-bf16/iter_0000000" +# PRETRAIN_CHECKPOINT_PATH="/raid/gpt-oss-20b-BF16-to-mcore_bridge-tp1-pp2-cp1-ep4-bf16/iter_0000000" + +OUTPUT_BASEPATH=output + + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ +" + +EXTRA_ARGS=" + --finetune \ + --no-load-optim \ + --no-load-rng +" + +load_options=" + --load $PRETRAIN_CHECKPOINT_PATH" + +DATASET_PATH=${DATASET_PATH} +echo "DATASET_PATH : $DATASET_PATH" +if [ -z "$DATASET_PATH" ];then + echo "WARN : DATASET_PATH is not set, using mocked dataset" + load_dataset="--mock-data" +else + load_dataset="--train-data-path ${DATASET_PATH}" +fi + + +if [ $DO = true ]; then + do_options=" \ + --use-distributed-optimizer" +elif [ $DO = false ]; then + do_options=" \ + " +fi + + +if [ $TE = true ]; then + te_options=" \ + --transformer-impl transformer_engine" +elif [ $TE = false ]; then + te_options=" \ + --transformer-impl local" +fi + + +# --fp8-param-gather \ +pr_options=" \ + --bf16 + --fp8-format hybrid \ + --fp8-param-gather \ + --fp8-amax-compute-algo max \ + --fp8-amax-history-len 1024" + + +# pr_options=" \ +# --bf16 +# --fp8-format hybrid" + + +# NOTE (yiakwy) : wierd options, see https://github.com/NVIDIA/Megatron-LM/commit/a2d8c806b35bc708b13e6c069e19e5dfb49b8481#r171154625 +# --seq-length 4096 \ +# --max-position-embeddings 40960 \ + +# --seq-length 131072 \ +GPT_OSS_SFT_MODEL_ARGS=" \ + --no-masked-softmax-fusion 
\ + --untie-embeddings-and-output-weights \ + --no-rope-fusion \ + --normalization RMSNorm \ + --num-layers 24 \ + --hidden-size 2880 \ + --ffn-hidden-size 2880 \ + --num-attention-heads 64 \ + --group-query-attention \ + --num-query-groups 8 \ + --kv-channels 64 \ + --seq-length 4096 \ + --max-position-embeddings 4096 \ + --make-vocab-size-divisible-by 128 \ + --use-mcore-models \ + --rotary-percent 1.0 \ + --rotary-base 150000 \ + --no-bias-gelu-fusion \ + --sequence-parallel \ + --export-force-local-attention \ + --no-bias-dropout-fusion \ + --padded-vocab-size 201088 \ + --quick-geglu \ + --glu-linear-offset 1.0 \ + --softmax-type learnable \ + --window-attn-skip-freq 2 \ + --activation-func-clamp-value 7.0 \ + --window-size 128,0 \ + --enable-gpt-oss \ +" + +MODEL_ARGS=( + --init-method-std 0.01 + --hidden-dropout 0.0 + --attention-dropout 0.0 + --rope-type yarn + --position-embedding-type yarn +) + +# --moe-aux-loss-coeff 1e-2 +# --moe-router-pre-softmax +MOE_ARGS=( + --num-experts 32 + --moe-router-load-balancing-type none # options: aux_loss, sinkhorn, None. Default is aux_loss. + --moe-router-topk 4 + --moe-aux-loss-coeff 0.0 + --moe-grouped-gemm + --moe-permute-fusion + --moe-ffn-hidden-size 2880 + --moe-router-dtype fp32 + --moe-token-dispatcher-type alltoall + --moe-router-score-function softmax +) + +mkdir -p $OUTPUT_BASEPATH/index_mapping + +DATA_ARGS=( + $load_dataset + --data-cache-path $OUTPUT_BASEPATH/index_mapping + --num-workers 5 + --tokenizer-type HuggingFaceTokenizer + --tokenizer-model "$PRETRAIN_CHECKPOINT_PATH/tokenizer" +) + +TRAINING_ARGS=( + --micro-batch-size 1 + --global-batch-size 128 + --lr 1e-4 + --train-iters 500000 + --lr-decay-iters 320000 + --lr-decay-style cosine + --min-lr 1.0e-5 + --weight-decay 0.1 + --lr-warmup-iters 500 + --clip-grad 1.0 + --overlap-grad-reduce + --overlap-param-gather +) + +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size $TP + --pipeline-model-parallel-size $PP + --context-parallel-size $CP + --expert-model-parallel-size $EDP + --expert-tensor-parallel-size $ETP + --sequence-parallel + --use-distributed-optimizer + --disable-gloo-process-groups + --enable-gpt-oss # see options added by modelopts +) + +mkdir -p "${OUTPUT_BASEPATH}/checkpoint/" +mkdir -p "${OUTPUT_BASEPATH}/tensorboard/" +mkdir -p "${OUTPUT_BASEPATH}/log/" +current_time=$(date "+%Y.%m.%d-%H.%M.%S") + +OUTPUT_ARGS=" + --log-interval 1 \ + --eval-interval 1000 \ + --eval-iters 0 \ + --save-interval ${SAVE_INTERVAL:-50} + --save "${OUTPUT_BASEPATH}/checkpoint" \ +" + +LOGGING_ARGS=" + --timing-log-level 2 \ + --log-throughput + --log-timers-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + --tensorboard-queue-size 1 \ + --tensorboard-dir "${OUTPUT_BASEPATH}/tensorboard" \ +" + +if [ -n "${WANDB_API_KEY}" ]; then + LOGGING_ARGS+=( + --wandb-project ${WANDB_PROJECT:-"GptOSS-20b-bf16-SFT"} + --wandb-exp-name ${WANDB_NAME:-"gptoss-20b-bf16"} + ) +fi + +export PYTORCH_ALLOC_CONF=expandable_segments:True +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# TODO (yiakwy) : use modelopt finetune +torchrun ${DISTRIBUTED_ARGS[@]} $MEGATRON_PATH/pretrain_gpt.py \ + $load_options \ + $GPT_OSS_SFT_MODEL_ARGS \ + ${MODEL_ARGS[@]} \ + $do_options \ + $te_options \ + ${MOE_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + --recompute-activations \ + $pr_options \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${OUTPUT_ARGS} \ + ${LOGGING_ARGS} 2>&1 | tee ${OUTPUT_BASEPATH}/log/megatron_trainning.Rank_${RANK}.log diff --git 
a/examples/llama/convert_megatron_bf16_checkpoint_from_hf.sh b/examples/llama/convert_megatron_bf16_checkpoint_from_hf.sh new file mode 100644 index 0000000000..3faf966440 --- /dev/null +++ b/examples/llama/convert_megatron_bf16_checkpoint_from_hf.sh @@ -0,0 +1,43 @@ +#/usr/bin/bash +SCRIPT_ROOT=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +PROJECT_ROOT=$SCRIPT_ROOT/../.. + +export PYTHONPATH=$PROJECT_ROOT:$PYTHONPATH + +TP=${TP:-1} +PP=${PP:-1} + +MODEL=Llama-2-7b-hf +MODEL_TYPE=llama2-7B + +# create soft links to /workspace/models +MODEL_DIR=/workspace/models + +SAVER=megatron + +dtype="bf16" + +HF_MODEL_DIR=$MODEL_DIR/$MODEL +OUTPUT=$MODEL_DIR/$MODEL-to-$SAVER-tp$TP-pp$PP-$dtype + +# old transformer use this path +# TOKENIZER_MODEL=$HF_MODEL_DIR/tokenizer.model + +TOKENIZER_MODEL=$HF_MODEL_DIR + +dtype_opt=( + --$dtype +) + +python $PROJECT_ROOT/tools/checkpoint/convert.py \ + --model-type GPT \ + --loader llama_mistral \ + --model-size $MODEL_TYPE \ + --saver $SAVER \ + --target-tensor-parallel-size ${TP} \ + --target-pipeline-parallel-size ${PP} \ + --checkpoint-type hf \ + --load-dir $HF_MODEL_DIR \ + --save-dir $OUTPUT \ + --tokenizer-model ${TOKENIZER_MODEL} \ + ${dtype_opt[@]} \ No newline at end of file diff --git a/examples/llama/train_llama2_7b_h100_fp16_fp8.sh b/examples/llama/train_llama2_7b_h100_fp16_fp8.sh new file mode 100644 index 0000000000..fd79771211 --- /dev/null +++ b/examples/llama/train_llama2_7b_h100_fp16_fp8.sh @@ -0,0 +1,445 @@ +#!/usr/bin/bash +ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../.." && pwd )" + +set -e +set -x + +# TOOD (yiakwy) : add NCCL args +export NCCL_IB_HCA=mlx5 +export NCCL_IB_TC=136 +export NCCL_IB_SL=5 +export NCCL_IB_GID_INDEX=3 +export NCCL_IB_TIMEOUT=22 +export NCCL_IB_SPLIT_DATA_ON_QPS=1 +export NCCL_IB_QPS_PER_CONNECTION=8 +export NCCL_IB_RETRY_CNT=13 +export NCCL_SOCKET_IFNAME=ens #/*NOTE*/ +export NCCL_DEBUG=INFO + + +# export NCCL_IB_HCA=ibp +# export UCX_NET_DEVICES=ibp0:1,ibp1:1,ibp2:1,ibp3:1,ibp4:1,ibp5:1,ibp6:1,ibp7:1 +export SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 +export NCCL_COLLNET_ENABLE=0 + + +# This is needed to avoid NCCL to use ifiniband, which the cluster is not ready +export NCCL_IB_DISABLE=1 + + +DIST_ENV=${DIST_ENV:-dsw} +echo "DIST_ENV : $DIST_ENV" + + +MEGATRON_PATCH_PATH=$MEGATRON_PATCH_PATH +echo "MEGATRON_PATCH_PATH : $MEGATRON_PATCH_PATH" +if [ -z "$MEGATRON_PATCH_PATH" ]; then + echo "Error : MEGATRON_PATCH_PATH is not set" + MEGATRON_PATCH_PATH=$ROOT + echo "setting MEGATRON_PATCH_PATH to $MEGATRON_PATCH_PATH" +else + echo "Ok" +fi + + +MEGATRON_PATH=${MEGATRON_PATCH_PATH} + + +# add python path +export PYTHONPATH=${MEGATRON_PATH}:${MEGATRON_PATCH_PATH}:$PYTHONPATH + + +# TODO (yaikwy) : this is required by sequence parallelism , consider to deprecate the option in favor of context parallel +export CUDA_DEVICE_MAX_CONNECTIONS=1 + + +if [ $DIST_ENV = dsw ]; then +# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export CUDA_VISIBLE_DEVICES=4,5,6,7 +MASTER_ADDR=localhost +MASTER_PORT=$(shuf -n 1 -i 10000-65535) +NNODES=1 +NODE_RANK=0 +# GPUS_PER_NODE=8 +GPUS_PER_NODE=4 + + +elif [ $DIST_ENV = dlc ]; then + + +NNODES=${WORLD_SIZE} +NODE_RANK=${RANK} +GPUS_PER_NODE=${KUBERNETES_CONTAINER_RESOURCE_GPU} + + +fi + + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + + +# TODO (yiakwy) : decoupled from llama-2-7b model, used to define GPT control args : +# --num-layers, +# 
--hidden-size, +# --num-attention-heads, +# --ffn-hidden-size +MODEL_SIZE=${MODEL_SIZE:-7B} +echo "MODEL_SIZE : $MODEL_SIZE" +if [ -z "$MODEL_SIZE" ]; then + echo "Error : MODEL_SIZE is not set" + exit 1 +fi + + +# TODO (yiakwy) : renamed to MICRO_BSZ, this is micro batch size NOT batch size +BATCH_SIZE=${MICRO_BATCH_SIZE:-1} +echo "MICRO BATCH SIZE: $BATCH_SIZE" + + +# TODO (yiakwy) : renamed to GLOBAL_BSZ +GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE:-2048} +echo "GLOBAL BATCH SIZE: $GLOBAL_BATCH_SIZE" + + +# peak lr +LR=${LR:-9e-7} +# minimum lr +MIN_LR=${MIN_LR:-9e-8} + + +# TODO (yiakwy) : finetune adam beta +B1=0.9 +B2=0.95 + + +# 2K +SEQ_LEN=${SEQ_LEN:-2048} +echo "SEQ_LEN : $SEQ_LEN" + + +# TODO (yiakwy) : add new tokens to vocabulary to scale up, standard llama has vocabulary size of 32k +EXTRA_VOCAB_SIZE=${EXTRA_VOCAB_SIZE:-0} + + +# precision +PR=${PR:-bf16} + + +TP=${TP:-1} +echo "TP : $TP" + + +PP=${PP:-1} +echo "PP : $PP" + + +# TODO (yiakwy) : add DP, ACC +DP=$((NNODES * GPUS_PER_NODE / TP / PP)) +ACC=$((GLOBAL_BATCH_SIZE / DP / BATCH_SIZE)) +echo "DP : $DP" +echo "ACC : $ACC" + + +# TODO (yiakwy) : rename to AUTO_RECOMPUTE_OPT +AC=${AUTO_RECOMPUTE_OPT:-sel} +echo "AUTO_RECOMPUTE_OPT : $AC" + + +# TODO (yiakwy) : rename to DIST_OPT, --use-distributed-optimizer +DO=${DIST_OPT:-true} + + +# TODO (yiakwy) : rename to USE_FLASH_ATTN, --use-flash-attn +FL=${USE_FLASH_ATTN:-true} + + +# TODO (yiakwy) : disable SP, in favor of CP +SP=${SP:-false} + + +# TODO (yiakwy) : add support TE equivalent functions (FP8) in ROCm +TE=${TE:-false} + + +SAVE_INTERVAL=${SAVE_INTERVAL:-50} + + +# megatron dataset path +DATASET_PATH=${DATASET_PATH} +echo "DATASET_PATH : $DATASET_PATH" +if [ -z "$DATASET_PATH" ];then + echo "WARN : DATASET_PATH is not set, using mocked dataset" + load_dataset="--mock-data" +else + load_dataset="--train-data-path ${DATASET_PATH}" +fi + + +# TODO (yiakwy) : add control, this is only used in SFT task +PRETRAIN_CHECKPOINT_PATH=${PRETRAIN_CHECKPOINT_PATH} +echo "PRETRAIN_CHECKPOINT_PATH : $PRETRAIN_CHECKPOINT_PATH" +if [ -z "$PRETRAIN_CHECKPOINT_PATH" ];then + echo "NOTE : PRETRAIN_CHECKPOINT_PATH is not set, switch to pretrain mode" + EXTRA_ARGS="" + + load_options="" +else + EXTRA_ARGS=" + --finetune \ + --no-load-optim + " + + load_options=" + --load $PRETRAIN_CHECKPOINT_PATH" +fi + + +# TODO (yiakwy) : add support blending dataset to scale up to 1 trillion tokens, in this test alibaba recommends 1/10 B tokens +TRAIN_TOKENS=${TRAIN_TOKENS:-10000000} +echo "TRAIN_TOKENS : $TRAIN_TOKENS" + + +WARMUP_TOKENS=${WARMUP_TOKENS:-1000} + + +DEFAULT_OUTPUT_BASEPATH="/workspace/logs/llama-2-7b_tp${TP}_pp${PP}_dp${DP}_MBSZ${BATCH_SIZE}_ACC${ACC}-profiling" + + +OUTPUT_BASEPATH=$DEFAULT_OUTPUT_BASEPATH + + +if [ $MODEL_SIZE = 7B ]; then + + +NUM_LAYERS=32 +HIDDEN_SIZE=4096 +NUM_ATTN_HEADS=32 +INTERMEDIATE_SIZE=11008 +MAX_POSITION_EMBEDDINGS=4096 + + +gqa_options="" + + +elif [ $MODEL_SIZE = 13B ]; then + + +NUM_LAYERS=40 +HIDDEN_SIZE=5120 +NUM_ATTN_HEADS=40 +INTERMEDIATE_SIZE=13824 +MAX_POSITION_EMBEDDINGS=4096 + + +gqa_options="" + + +elif [ $MODEL_SIZE = 70B ]; then + + +NUM_LAYERS=80 +HIDDEN_SIZE=8192 +NUM_ATTN_HEADS=64 +INTERMEDIATE_SIZE=28672 +MAX_POSITION_EMBEDDINGS=4096 + + +gqa_options=" \ + --group-query-attention \ + --num-query-groups 8" +fi + + +if [ $AC = full ]; then + activation_checkpoint_options=" \ + --recompute-method uniform \ + --recompute-granularity full" +elif [ $AC = sel ]; then + activation_checkpoint_options=" \ + --recompute-activations" +elif [ $AC = none ]; 
then + activation_checkpoint_options=" \ + " +fi + + +if [ $PR = fp16 ]; then + pr_options=" \ + --fp16" +elif [ $PR = bf16 ]; then + pr_options=" \ + --bf16" +elif [ $PR = fp8 ]; then + pr_options=" \ + --bf16 + --fp8-hybrid \ + --fp8-amax-compute-algo max \ + --fp8-amax-history-len 1024 \ + --transformer-impl transformer_engine" +fi + + +if [ $DO = true ]; then + do_options=" \ + --use-distributed-optimizer" + + +elif [ $DO = false ]; then + do_options=" \ + " +fi + + +if [ $FL = true ]; then + flash_options=" \ + --use-flash-attn" + + +elif [ $FL = false ]; then + flash_options=" \ + " +fi + + +if [ $TE = true ]; then + te_options=" \ + --transformer-impl transformer_engine" + + +elif [ $TE = false ]; then + te_options=" \ + --transformer-impl local" +fi + + +if [ $SP = true ] && [ $TP -gt 1 ]; then + sp_options=" \ + --sequence-parallel" + + +elif [ $SP = false ]; then + sp_options=" \ + " +fi + + +mkdir -p $OUTPUT_BASEPATH/index_mapping + + +DATA_ARGS=" + $load_dataset \ + --data-cache-path $OUTPUT_BASEPATH/index_mapping \ + --num-workers 5 +" + +MODEL=Llama-2-7b-hf +TOK=Llama2Tokenizer +TOKENIZER_MODEL=/workspace/models/$MODEL/tokenizer.model + + +TRAIN_ITERS=$(( ${TRAIN_TOKENS} / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} )) +LR_WARMUP_ITERS=$(( ${WARMUP_TOKENS} / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} )) +LR_DECAY_ITERS=$(( ${TRAIN_TOKENS} / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} )) + + +NAME="${DIST_ENV}-pretrain-megatron-llama-2-7b-${MODEL_SIZE}-lr-${LR}-bs-${BATCH_SIZE}-seqlen-${SEQ_LEN}-pr-${PR}-tp-${TP}-pp-${PP}-ac-${AC}-do-${DO}-sp-${SP}-tt-${TRAIN_TOKENS}-wt-${WARMUP_TOKENS}" +mkdir -p "${OUTPUT_BASEPATH}/tensorboard/" +mkdir -p "${OUTPUT_BASEPATH}/checkpoint/" +mkdir -p "${OUTPUT_BASEPATH}/log/" +current_time=$(date "+%Y.%m.%d-%H.%M.%S") +TENSORBOARD_DIR="${OUTPUT_BASEPATH}/tensorboard/${NAME}_${current_time}" +mkdir -p ${TENSORBOARD_DIR} + + +SAVED_PRETRAIN_CHECKPOINT_PATH="${OUTPUT_BASEPATH}/checkpoint/${NAME}" + + +megatron_options=" \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_ATTN_HEADS} \ + --ffn-hidden-size ${INTERMEDIATE_SIZE} \ + --swiglu \ + --normalization RMSNorm \ + --optimizer adam \ + --micro-batch-size ${BATCH_SIZE} \ + --global-batch-size ${GLOBAL_BATCH_SIZE} \ + --lr ${LR} \ + --min-lr ${MIN_LR} \ + --lr-decay-style linear \ + --adam-beta1 $B1 \ + --adam-beta2 $B2 \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --init-method-std 0.01 \ + --lr-decay-iters ${LR_DECAY_ITERS} \ + --lr-warmup-iters ${LR_WARMUP_ITERS} \ + --train-iters ${TRAIN_ITERS} \ + --seed 1234 \ + --tokenizer-type $TOK \ + --tokenizer-model $TOKENIZER_MODEL \ + --use-rotary-position-embeddings \ + --rotary-percent 1.0 \ + --no-load-rng \ + --no-masked-softmax-fusion \ + --position-embedding-type rope \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --initial-loss-scale 4096 \ + --attention-softmax-in-fp32 \ + --use-legacy-models + " + + +OUTPUT_ARGS=" + --log-interval 1 \ + --eval-interval 10000 \ + --eval-iters 0 \ + --save-interval ${SAVE_INTERVAL} +" + +# --log-batch-size-to-tensorboard \ +LOGGING_ARGS=" + --timing-log-level 2 \ + --log-throughput + --log-timers-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + --tensorboard-queue-size 1 \ + --tensorboard-dir ${TENSORBOARD_DIR} +" + + +cat $0 > ${OUTPUT_BASEPATH}/launch_script.sh +git config --global --add 
safe.directory $ROOT +echo "COMMIT_ID=$(git rev-parse HEAD)" >> ${OUTPUT_BASEPATH}/commit_id.txt + + +torchrun $DISTRIBUTED_ARGS $MEGATRON_PATH/pretrain_gpt.py \ + ${megatron_options} \ + ${pr_options} \ + ${load_options} \ + ${te_options} \ + ${activation_checkpoint_options} \ + ${do_options} \ + ${flash_options} \ + ${sp_options} \ + ${gqa_options} \ + ${EXTRA_ARGS} \ + ${DATA_ARGS} \ + ${OUTPUT_ARGS} \ + ${LOGGING_ARGS} \ + &> ${OUTPUT_BASEPATH}/log/${NODE_RANK}.log + + +set +e +set +x \ No newline at end of file diff --git a/gpt_builders.py b/gpt_builders.py index 9fa1aff72c..c2c948dbdd 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -28,6 +28,22 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None): config = core_transformer_config_from_yaml(args, "language_model") else: config = core_transformer_config_from_args(args) + + # Handle GPT-OSS mode with YaRN RoPE configuration + if hasattr(args, 'enable_gpt_oss') and args.enable_gpt_oss: + print_rank_0("GPT-OSS mode enabled: Configuring YaRN RoPE parameters") + + # Set GPT-OSS YaRN values directly on the config + # These defaults are based on Huggingface GPT-OSS configurations + config.position_embedding_type = "yarn" + config.yarn_rotary_scaling_factor = 32.0 + config.yarn_original_max_position_embeddings = 131072 + config.yarn_beta_fast = 32.0 + config.yarn_beta_slow = 1.0 + config.yarn_mscale = 1.0 + config.yarn_mscale_all_dim = 0.0 + config.yarn_correction_range_round_to_int = False + if args.use_legacy_models: model = megatron.legacy.model.GPTModel( config, diff --git a/megatron/core/jit.py b/megatron/core/jit.py index b67810f2e3..cb989c35b7 100644 --- a/megatron/core/jit.py +++ b/megatron/core/jit.py @@ -1,5 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import os import torch from megatron.core.utils import is_torch_min_version @@ -7,6 +8,7 @@ jit_fuser = torch.jit.script # nvFuser is deprecated in PyTorch JIT starting from 2.2 +use_noop_decoator = os.environ.get('TORCHINDUCTOR_DISABLE', '0') == '1' def noop_decorator(func): '''No-op decorator''' @@ -16,12 +18,15 @@ def noop_decorator(func): def enable_jit_fuser(): '''Enable the JIT fuser''' global jit_fuser - try: - if is_torch_min_version("2.2.0a0"): - jit_fuser = torch.compile - except ImportError: + if use_noop_decoator: jit_fuser = noop_decorator + else: + try: + if is_torch_min_version("2.2.0a0"): + jit_fuser = torch.compile + except ImportError: + jit_fuser = noop_decorator def disable_jit_fuser(): diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index e31fcd2577..745bcc8fd3 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -146,8 +146,8 @@ class ModelParallelConfig: gradient_accumulation_fusion: bool = False """If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install - APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" - --global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. Otherwise, you + APEX with \"--cpp_ext\" and \"--cuda_ext\". For example: "pip install --global-option=\"--cpp_ext\" + --global-option=\"--cuda_ext\" . ". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion. 
""" diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 1e41bf9d8c..8266b41fd2 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -527,7 +527,7 @@ def initialize_model_parallel( order: str = "tp-cp-ep-dp-pp", get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, - create_gloo_process_groups: bool = True, + create_gloo_process_groups: bool = False, high_priority_stream_groups: Optional[List[str]] = None, sharp_enabled_group: Optional[str] = None, ) -> None: diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 221f3327e5..83a203f56d 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -678,8 +678,8 @@ def linear_with_grad_accumulation_and_async_allreduce( accumulation fusion, requires the custom CUDA extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with - --cpp_ext and --cuda_ext. For example: "pip install - --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" + \"--cpp_ext\" and \"--cuda_ext\". For example: "pip install + --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" . " Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion." @@ -938,9 +938,9 @@ def __init__( "ColumnParallelLinear was called with gradient_accumulation_fusion set " "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda " "module is not found. To use gradient_accumulation_fusion you must " - "install APEX with --cpp_ext and --cuda_ext. For example: " - 'pip install --global-option="--cpp_ext" --global-option="--cuda_ext ." ' - "Note that the extension requires CUDA>=11. Otherwise, you must turn off " + "install APEX with --cpp_ext and --cuda_ext. For example: \'" + 'pip install --global-option="--cpp_ext" --global-option="--cuda_ext" . ' + "\'. Note that the extension requires CUDA>=11. Otherwise, you must turn off " "gradient accumulation fusion." ) self.gradient_accumulation_fusion = config.gradient_accumulation_fusion diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 7391bcaf12..515f7e26e4 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -857,8 +857,11 @@ def forward( output (torch.Tensor): The output of the local experts. 
""" tokens_per_expert = tokens_per_expert.tolist() + + actual_tokens_per_expert = tokens_per_expert + acutal_permuted_probs = permuted_probs.unsqueeze(-1) + if self.config.fp8 or self.config.fp4: - actual_tokens_per_expert = tokens_per_expert permuted_local_hidden_states, tokens_per_expert = self.quantization_padding( permuted_local_hidden_states, tokens_per_expert ) @@ -977,7 +980,7 @@ def glu(x): if self.config.fp8 or self.config.fp4: output = self.quantization_unpadding(output, actual_tokens_per_expert) - output = self._apply_bias(output, output_bias, tokens_per_expert, permuted_probs) + output = self._apply_bias(output, output_bias, actual_tokens_per_expert, acutal_permuted_probs) output_bias = None return output, output_bias diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index bb1b17e9ba..993951bd41 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1045,7 +1045,7 @@ def validate_args(args, defaults={}): # Would just need to add 'NoPE' as a position_embedding_type to support this, but for now # don't allow it to keep things simple - if not args.add_position_embedding and args.position_embedding_type != 'rope': + if not args.add_position_embedding and args.position_embedding_type not in ('rope', 'yarn'): raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type') # Relative position embeddings arguments @@ -1140,7 +1140,7 @@ def validate_args(args, defaults={}): if args.dist_ckpt_optim_fully_reshardable: assert not args.distrib_optim_fully_reshardable_mem_efficient, \ - '--distrib-optim-fully-reshardable-mem-efficient requires -enable-gloo-process-groups' + '--distrib-optim-fully-reshardable-mem-efficient requires --enable-gloo-process-groups' if args.fake_process_group: # Disable nan check for fake process group @@ -1347,8 +1347,9 @@ def core_transformer_config_from_args(args, config_class=None): # Pop 'rope_type' to let the config class use the default value. kw_args.pop('rope_type', None) else: - assert (args.multi_latent_attention or args.rope_type == 'rope'), ( - f'Common attention only support rope_type="rope", but got {args.rope_type}.' + # GptOSS defaults to yarn, but it possible to use many others and yarn has already been supported in Mcore + assert (not args.multi_latent_attention or args.rope_type != 'yarn'), ( + f'If using MLA attention (sft deepseek V3), rope_type is expected to be "yarn", but got {args.rope_type}.' ) if len(args.cp_comm_type) == 1: @@ -1677,7 +1678,7 @@ def _add_network_size_args(parser): help='Maximum number of position embeddings to use. ' 'This is the size of position embedding.') group.add_argument('--position-embedding-type', type=str, default='learned_absolute', - choices=['learned_absolute', 'rope', 'mrope', 'relative', 'none'], + choices=['learned_absolute', 'rope', 'yarn', 'mrope', 'relative', 'none'], help='Position embedding type.') group.add_argument('--relative-attention-num-buckets', type=int, default=32, help='Number of buckets for relative position embeddings.') diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py index e9eb7e99b6..7293500ecb 100644 --- a/tools/checkpoint/convert.py +++ b/tools/checkpoint/convert.py @@ -107,7 +107,6 @@ def load_plugin(plugin_type, name): return plugin def main(): - import argparse parser = argparse.ArgumentParser(description="Megatron Checkpoint Converter Arguments", allow_abbrev=False, conflict_handler='resolve')