Skip to content

Commit c6734a9

Browse files
Update get_tensor_shapes function whose signature was refactored (#14594)
* Update get_tensor_shapes function whose signature changed and wasn't refactored
  Signed-off-by: Asha Anoosheh <[email protected]>
* Bump Mcore commit to latest on 0.14.0 branch
  Signed-off-by: Charlie Truong <[email protected]>
* Bump Mcore
  Signed-off-by: Charlie Truong <[email protected]>
* Set flux fsdp test to optional
  Signed-off-by: Charlie Truong <[email protected]>
* Fix flux test to skip
  Signed-off-by: Charlie Truong <[email protected]>

---------

Signed-off-by: Asha Anoosheh <[email protected]>
Signed-off-by: Charlie Truong <[email protected]>
Co-authored-by: Charlie Truong <[email protected]>
Signed-off-by: Charlie Truong <[email protected]>
1 parent ee0cce8 commit c6734a9

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

.github/workflows/cicd-main-nemo2.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ jobs:
281281
script: L2_NeMo_2_Flux_ControlNet_Training_DDP_Test
282282
- runner: self-hosted-azure
283283
script: L2_NeMo_2_Flux_ControlNet_Training_FSDP_Test
284+
is-optional: true
284285

285286

286287
needs: [build]

nemo/collections/llm/modelopt/distill/utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from megatron.core.dist_checkpointing.validation import StrictHandling, parse_strict_flag
2424
from megatron.core.pipeline_parallel.schedules import get_tensor_shapes
2525
from megatron.core.transformer import TransformerLayer
26-
from megatron.core.utils import get_model_config, get_model_type
26+
from megatron.core.utils import get_model_config
2727

2828
from nemo import lightning as nl
2929
from nemo.collections import llm
@@ -259,25 +259,25 @@ def get_tensor_shapes_adjust_fn_for_distillation(
259259
return None
260260

261261
def adjust_tensor_shapes(recv_tensor_shapes: List[Tuple[int, ...]], send_tensor_shapes: List[Tuple[int, ...]]):
262-
rank = parallel_state.get_pipeline_model_parallel_rank()
263262
teacher_config = get_model_config(model.teacher_model)
264-
teacher_model_type = get_model_type(model.teacher_model)
263+
tp_group = parallel_state.get_tensor_model_parallel_group()
264+
cp_group = parallel_state.get_context_parallel_group()
265265

266266
teacher_recv_tensor_shapes = get_tensor_shapes(
267-
rank=rank - 1,
268-
model_type=teacher_model_type,
269267
seq_length=seq_length,
270268
micro_batch_size=micro_batch_size,
271269
decoder_seq_length=decoder_seq_length,
272270
config=teacher_config,
271+
tp_group=tp_group,
272+
cp_group=cp_group,
273273
)
274274
teacher_send_tensor_shapes = get_tensor_shapes(
275-
rank=rank,
276-
model_type=teacher_model_type,
277275
seq_length=seq_length,
278276
micro_batch_size=micro_batch_size,
279277
decoder_seq_length=decoder_seq_length,
280278
config=teacher_config,
279+
tp_group=tp_group,
280+
cp_group=cp_group,
281281
)
282282
model.set_student_input_tensor_shape(recv_tensor_shapes)
283283

requirements/manifest.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
},
1212
"megatron-lm": {
1313
"repo": "https://github.com/NVIDIA/Megatron-LM",
14-
"ref": "7f7439f543288f50f134e44832069192a3e1d98e"
14+
"ref": "53cad7137aacf56ffc44a8672b9340f560ec6572"
1515
},
1616
"trt-llm": {
1717
"repo": "https://github.com/NVIDIA/TensorRT-LLM.git",

0 commit comments

Comments (0)