Skip to content

Commit 991d1d1

Browse files
committed
add functional test
Signed-off-by: ruit <[email protected]>
1 parent bdefc4e commit 991d1d1

2 files changed

Lines changed: 50 additions & 0 deletions

File tree

tests/functional/L1_Functional_Tests_GPU.sh
tests/functional/dpo_automodel_lora.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh
3232
time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
3333
time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh
3434
time uv run --no-sync bash ./tests/functional/grpo_sglang.sh
35+
time uv run --no-sync bash ./tests/functional/dpo_automodel_lora.sh
36+
time uv run --no-sync bash ./tests/functional/dpo_megatron.sh
3537
time uv run --no-sync bash ./tests/functional/dpo.sh
3638
time uv run --no-sync bash ./tests/functional/rm.sh
3739
time uv run --no-sync bash ./tests/functional/eval.sh
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#!/bin/bash
# Functional test: DPO training with LoRA enabled on the automodel (DTensor v2)
# path. Runs a short 3-step DPO job on Qwen/Qwen3-0.6B using 2 GPUs, dumps the
# TensorBoard metrics to JSON, and asserts the step-3 training loss is < 0.8.

# clean up checkpoint directory on exit (fires on every exit path, incl. failure)
trap 'rm -rf /tmp/lora_dpo_checkpoints' EXIT

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
PROJECT_ROOT=$(realpath "$SCRIPT_DIR/../..")
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory "$PROJECT_ROOT"

set -euo pipefail

# Derive experiment layout from this script's own name.
EXP_NAME=$(basename "$0" .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

# Start from a clean experiment directory.
rm -rf "$EXP_DIR" "$LOG_DIR"
mkdir -p "$EXP_DIR" "$LOG_DIR"

cd "$PROJECT_ROOT"
# Run training under coverage; extra CLI overrides may be appended via "$@".
# pipefail (set above) ensures a training failure is not masked by tee.
uv run coverage run -a --data-file="$PROJECT_ROOT/tests/.coverage" --source="$PROJECT_ROOT/nemo_rl" \
    "$PROJECT_ROOT/examples/run_dpo.py" \
    policy.model_name=Qwen/Qwen3-0.6B \
    cluster.gpus_per_node=2 \
    dpo.max_num_steps=3 \
    dpo.val_batches=1 \
    dpo.val_global_batch_size=8 \
    ++policy.dtensor_cfg._v2=true \
    policy.train_global_batch_size=8 \
    policy.dtensor_cfg.lora_cfg.enabled=true \
    logger.tensorboard_enabled=true \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=false \
    logger.monitor_gpus=true \
    checkpointing.enabled=true \
    checkpointing.save_period=3 \
    checkpointing.checkpoint_dir=/tmp/lora_dpo_checkpoints \
    "$@" \
    2>&1 | tee "$RUN_LOG"

# Convert TensorBoard event files into a JSON metrics dump.
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

# Assert the training loss at step 3 dropped below the expected threshold.
uv run tests/check_metrics.py "$JSON_METRICS" \
    'data["train/loss"]["3"] < 0.8'
48+

0 commit comments

Comments (0)