
Commit 2cf3ab3

geetu040 authored and kmehant committed
🚨🚨🚨 Fix sdpa in SAM and refactor relative position embeddings (huggingface#36422)
* fall back to eager if output_attentions
* improve relative position embeddings
* run modular on got_ocr2
* run-slow: sam
* fix run-length encoding
* fix tf processor errors
* update tf_sam
* fix compile error
* re-run tests

Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 9e94801 commit 2cf3ab3

File tree: 3 files changed, +90 -4 lines changed


src/transformers/integrations/tensor_parallel.py

Lines changed: 76 additions & 0 deletions
@@ -417,6 +417,76 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         parameter = DTensor.from_local(parameter, device_mesh, [Shard(-1)], run_check=False)
         return nn.Parameter(parameter)
 
+class ReplicateParallel(TensorParallelLayer):
+    """
+    Replicate a nn.Module.
+    Users can compose it together with other parallel styles like RowwiseParallel to achieve a fully distributed model.
+    Fully distributed model is needed for gradient clipping.
+
+    Keyword Args:
+        input_layouts (Placement, optional):
+            The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
+            become a DTensor. If not specified, we assume the input tensor to be replicated.
+        output_layouts (Placement, optional):
+            The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
+            with the user desired layout. If not specified, we assume the output tensor to be replicated.
+        use_local_output (bool, optional):
+            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
+    Returns:
+        A :class:`ParallelStyle` object that represents replication of nn.Module.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, ReplicateParallel
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> ...
+        >>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>>
+        >>> # By default, the input and output of the "w1" Linear will be converted to Replicated DTensor
+        >>>
+        >>> replicated_mod = parallelize_module(m, tp_mesh, {"w1": ReplicateParallel()})
+        >>> ...
+    """
+
+    def __init__(
+        self,
+        *,
+        input_layouts: Optional[Placement] = None,
+        output_layouts: Optional[Placement] = None,
+        use_local_output: bool = True,
+        use_dtensor=True,
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Replicate(),)
+        self.output_layouts = (output_layouts or Replicate(),)
+        self.desired_input_layouts = (Replicate(),)
+        self.use_local_output = use_local_output
+        self.use_dtensor = use_dtensor
+
+    @staticmethod
+    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        # since nn.Linear and nn.Embedding have single input
+        # we may extend support to other modules since its replicate.
+        input_tensor = inputs[0]
+        if isinstance(input_tensor, torch.distributed._functional_collectives.AsyncCollectiveTensor):
+            input_tensor = input_tensor.trigger_wait()
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
+        return input_tensor
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        return outputs.to_local() if use_local_output else outputs
 
 SUPPORTED_TP_STYLES = {
     "colwise",

@@ -428,6 +498,8 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
     "local",
     "gather",
     "local_packed_rowwise",
+    "replicate",
+    "replicate_output_dtensor"
 }

@@ -459,6 +531,10 @@ def translate_to_torch_parallel_style(style: str):
         return GatherParallel()
     elif style == "local_packed_rowwise":
         return PackedRowwiseParallel(use_dtensor=False)
+    elif style == "replicate":
+        return ReplicateParallel()
+    elif style == "replicate_output_dtensor":
+        return ReplicateParallel(use_local_output=False)
     else:
         raise ValueError(f"Unsupported parallel style value: {style}")
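For context on the two new style strings: "replicate" resolves to ReplicateParallel() with the default use_local_output=True, so module outputs come back as plain torch.Tensor, while "replicate_output_dtensor" keeps the outputs as DTensors. A minimal sketch of resolving the styles, assuming this commit's tree is importable as transformers.integrations.tensor_parallel:

from transformers.integrations.tensor_parallel import (
    ReplicateParallel,
    translate_to_torch_parallel_style,
)

# "replicate" -> ReplicateParallel(); outputs are converted back to local torch.Tensor
replicate = translate_to_torch_parallel_style("replicate")
assert isinstance(replicate, ReplicateParallel) and replicate.use_local_output

# "replicate_output_dtensor" -> ReplicateParallel(use_local_output=False); outputs stay DTensors
replicate_dtensor = translate_to_torch_parallel_style("replicate_output_dtensor")
assert not replicate_dtensor.use_local_output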

src/transformers/models/granite/configuration_granite.py

Lines changed: 6 additions & 2 deletions
@@ -117,10 +117,14 @@ class GraniteConfig(PretrainedConfig):
         "layers.*.self_attn.q_proj": "colwise",
         "layers.*.self_attn.k_proj": "colwise",
         "layers.*.self_attn.v_proj": "colwise",
-        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.self_attn.o_proj": "rowwise_output_dtensor",
         "layers.*.mlp.gate_proj": "colwise",
         "layers.*.mlp.up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
+        "layers.*.mlp.down_proj": "rowwise_output_dtensor",
+        "embed_tokens": "replicate_output_dtensor",
+        "layers.*.post_attention_layernorm": "replicate_output_dtensor",
+        "layers.*.input_layernorm": "replicate_output_dtensor",
+        "norm": "replicate_output_dtensor",
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),

src/transformers/trainer.py

Lines changed: 8 additions & 2 deletions
@@ -235,6 +235,7 @@
         AutocastKwargs,
         DistributedDataParallelKwargs,
         DistributedType,
+        TorchTensorParallelPlugin,
         load_fsdp_model,
         load_fsdp_optimizer,
         save_fsdp_model,

@@ -2317,7 +2318,9 @@ def _inner_training_loop(
             else:
                 debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+        delay_optimizer_creation = (
+            is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled
+        )
 
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:

@@ -2372,7 +2375,10 @@ def _inner_training_loop(
                 if self.use_apex:
                     model = self.accelerator.prepare(self.model)
                 else:
-                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+                    if delay_optimizer_creation:
+                        self.optimizer = self.accelerator.prepare(self.optimizer)
+                    else:
+                        model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
             else:
                 # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                 model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
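The trainer change treats tensor parallelism like the existing SageMaker MP and FSDP cases: optimizer creation is delayed, and when it is delayed only the optimizer is passed through accelerator.prepare here. A standalone sketch of that control flow, not Trainer code; tp_enabled stands in for self.is_tp_enabled and the model/optimizer are toy placeholders:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(16, 16)

tp_enabled = True  # stand-in for self.is_tp_enabled
delay_optimizer_creation = tp_enabled  # mirrors: sagemaker_mp or fsdp_xla or fsdp or tp

if delay_optimizer_creation:
    # Prepare the model first, then build the optimizer on the prepared parameters
    # and pass only the optimizer through prepare(), as in the diff above.
    model = accelerator.prepare(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    optimizer = accelerator.prepare(optimizer)
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)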
