|
44 | 44 | # used to compute the scaling factor for quantization. |
45 | 45 | torch.ops.aten.max.default, |
46 | 46 | torch._higher_order_ops.flex_attention, |
| 47 | + torch._higher_order_ops.inductor_compiled_code, |
47 | 48 | } |
48 | | -# Add optional ops if available (requires newer PyTorch) |
49 | | -try: |
50 | | - _op_sac_save_list.add(torch._higher_order_ops.inductor_compiled_code) |
51 | | -except AttributeError: |
52 | | - pass |
53 | | - |
54 | | -# Add DeepEP custom ops to SAC save list |
55 | | -try: |
56 | | - import torchtitan.distributed.deepep.deepep # noqa: F401 |
57 | | - _op_sac_save_list.add(torch.ops.deepep.dispatch.default) |
58 | | - _op_sac_save_list.add(torch.ops.deepep.combine.default) |
59 | | -except (ImportError, AttributeError): |
60 | | - pass |
61 | | - |
62 | 49 |
|
63 | 50 | # Adapted from llama4/infra/parallelize.py |
64 | 51 | def parallelize_deepseekv3( |
@@ -115,11 +102,17 @@ def parallelize_deepseekv3( |
115 | 102 | job_config.parallelism.expert_parallel_comm_backend == "deepep" |
116 | 103 | and not parallel_dims.ep_enabled |
117 | 104 | ): |
| 105 | + use_deepep = False |
118 | 106 | logger.warning( |
119 | 107 | "expert_parallel_comm_backend='deepep' has no effect when EP=1. " |
120 | 108 | "Using standard communication." |
121 | 109 | ) |
122 | 110 |
|
| 111 | + if use_deepep: |
| 112 | + import torchtitan.distributed.deepep.deepep # noqa: F401 |
| 113 | + _op_sac_save_list.add(torch.ops.deepep.dispatch.default) |
| 114 | + _op_sac_save_list.add(torch.ops.deepep.combine.default) |
| 115 | + |
123 | 116 | # DeepEP + ETP is not supported yet |
124 | 117 | if use_deepep and parallel_dims.etp_enabled: |
125 | 118 | raise NotImplementedError( |
|
0 commit comments