
Commit f583d30

inkcherry authored and gyou2021 committed

autotp training (fix dco) (deepspeedai#7004)

Same as [this PR](deepspeedai#6922) ([affeb88](deepspeedai@affeb88)). I noticed that CI recently updated the DCO check, and using the suggested rebase method for sign-off would reintroduce many conflicts, so I opted for a squash merge with sign-off instead. Thanks :)

Signed-off-by: inkcherry <[email protected]>
Signed-off-by: gyou2021 <[email protected]>
1 parent a531625 commit f583d30

File tree

17 files changed: +1662 additions, -164 deletions


deepspeed/__init__.py

Lines changed: 32 additions & 1 deletion
@@ -37,7 +37,7 @@
 from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
 from .runtime.activation_checkpointing import checkpointing
 from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
-from .module_inject import replace_transformer_layer, revert_transformer_layer
+from .module_inject import replace_transformer_layer, revert_transformer_layer, set_autotp_mode

 from .utils import log_dist, OnDevice, logger
 from .comm.comm import init_distributed
@@ -364,3 +364,34 @@ def init_inference(model, config=None, **kwargs):
     engine = InferenceEngine(model, config=ds_inference_config)

     return engine
+
+
+def tp_model_init(model, tp_size, dtype):
+    """
+    Initialize the model for tensor parallelism.
+
+    Args:
+        model (torch.nn.Module): The model to be initialized.
+        tp_size (int): The tensor parallelism size.
+        dtype (torch.dtype): The data type to be used for the model.
+
+    Returns:
+        torch.nn.Module: The initialized model with tensor parallelism.
+    """
+    # avoid re-entry
+    assert not hasattr(
+        model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."
+
+    set_autotp_mode(training=True)
+
+    from deepspeed.runtime.tensor_parallel import TpTrainingManager
+    # The expected usage here is for it to be invoked by transformers package.
+
+    #TODO: We should provide a custom TP mapping solution without using autoTP
+    #as modifying the autoTP logic may be more difficult for users compared to configuring it
+
+    model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module
+
+    setattr(model, 'ds_autotp_parsed', True)
+
+    return model
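A minimal usage sketch of the new entry point (the model name, tp_size=2, and dtype below are illustrative placeholders; in practice the call is expected to come from the transformers integration, and the default process group must already be initialized):

# Sketch only; assumes a multi-rank launch and that "gpt2" stands in for any HF causal-LM.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

deepspeed.init_distributed()                           # set up the default process group first
model = AutoModelForCausalLM.from_pretrained("gpt2")   # placeholder model choice
# Shard the supported linear layers across 2 tensor-parallel ranks in bf16.
model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)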

deepspeed/comm/comm.py

Lines changed: 6 additions & 0 deletions
@@ -224,6 +224,12 @@ def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='bro
     return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)


+@timed_op
+def broadcast_object_list(object_list, src, group=None, device=None):
+    global cdb
+    return cdb.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)
+
+
 @timed_op
 def all_gather(tensor_list,
                tensor,
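A small usage sketch for the added collective (assumes deepspeed.comm has already been initialized; the payload is an arbitrary example), mirroring the semantics of torch.distributed.broadcast_object_list:

# Sketch only; run under a distributed launcher with at least 2 ranks.
import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed()
objects = [{"tp_size": 2}] if dist.get_rank() == 0 else [None]
dist.broadcast_object_list(objects, src=0)   # every rank now holds rank 0's copy of the list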

deepspeed/comm/torch.py

Lines changed: 4 additions & 0 deletions
@@ -205,6 +205,10 @@ def broadcast(self, tensor, src, group=None, async_op=False):
         else:
             return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

+    @disable_compiler_collective
+    def broadcast_object_list(self, object_list, src, group=None, device=None):
+        return torch.distributed.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)
+
     @disable_compiler_collective
     def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:

deepspeed/inference/engine.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
 from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
 from deepspeed.utils.timer import SynchronizedWallClockTimer
 from deepspeed.runtime.compiler import is_compile_supported
-
 from ..runtime.state_dict_factory import SDLoaderFactory
 from ..runtime.weight_quantizer import WeightQuantization
 from ..module_inject import replace_transformer_layer, generic_injection

deepspeed/module_inject/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -6,5 +6,5 @@
 from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
 from .module_quantize import quantize_transformer_layer
 from .replace_policy import HFBertLayerPolicy
-from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
+from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode
 from .policy import DSPolicy
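A brief sketch of how the newly exported flag pairs with is_autotp_training_mode (this is the same toggle tp_model_init calls above; the assert is only for illustration):

# Sketch only; switches AutoTP from its inference default to training behavior.
from deepspeed.module_inject import set_autotp_mode
from deepspeed.module_inject.layers import is_autotp_training_mode

set_autotp_mode(training=True)      # enable training-oriented AutoTP replacements
assert is_autotp_training_mode()    # sharding code in auto_tp.py branches on this flag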

deepspeed/module_inject/auto_tp.py

Lines changed: 30 additions & 59 deletions
@@ -11,12 +11,14 @@
 from typing import Optional
 import torch
 from deepspeed import comm as dist
-from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
+from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
 from deepspeed.accelerator import get_accelerator
-from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
+from .fusedqkv_utils import require_tp_fused_qkvw
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
 import os
 import ast
+from deepspeed.utils import groups
+from deepspeed.module_inject.layers import is_autotp_training_mode


 def move(tensor, device, copy=True):
@@ -367,10 +369,18 @@ def tp_parser(model):
         return policy_list

     def set_tensor_parallel_config(self, mp_size, mp_group):
+
+        if is_autotp_training_mode():
+            self.mp_group = groups.get_tensor_model_parallel_group()
+            self.mp_size = groups.get_tensor_model_parallel_world_size()
+            return
+
         self.mp_size = mp_size
         self.mp_group = mp_group

     def _replace(self, child, name, conv_linear_layer):
+        # This function should clearly define the routing rules for specific layers
+        # and avoid any complex shard-related logic.
         if getattr(child, "replaced", False) == True:
             return
         device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
@@ -415,80 +425,41 @@ def _replace(self, child, name, conv_linear_layer):
         # For Yuan model
         if 'Yuan' in str(self.module):
             if 'v_proj' in name:
-                weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
-                                                         dist.get_world_size(), True)
-                return LinearLayer(weight=weight, bias=bias)
+                return Yuan_LinearLayer(child, self.mp_group)
+
             elif 'o_proj' in name:
-                weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
-                                                         dist.get_world_size(), False)
-                return LinearAllreduce(weight, bias, self.mp_group)
-        # For Arctic model, bypass to all_reduce replacement for w2 weights
+                return Yuan_LinearAllreduce(child, self.mp_group)
+
+        # For MLP including chunk layer.
+        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
+            return GateUpPack_LinearLayer(child, self.mp_group)
+        # For Arctic model, bypass to all_reduce replacement for w2 weights
         arctic_w2_all_reduce_linear = False
         if 'Arctic' in str(self.module) and 'w2' in name:
             arctic_w2_all_reduce_linear = True
         # For MoE MLP model, e.g., deepseek and jamba
         down_proj = False
         if 'down_proj' in name:
             down_proj = True
-        # For MLP including chunk layer.
-        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
-            weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
-            return LinearLayer(weight=weight, bias=bias)
         if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
-            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
-            # else [weight_shape[0], weight_shape[1] // mp_size]

+            setattr(child, "replaced", True)
             if self.conv_linear_layer:
-                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
-            data = child.weight.data.split(get_shard_size_list(
-                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
-                                           dim=1)
-            data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
-            del data
+                return Conv_LinearALlreduce(child, self.mp_group, name=name)
+            elif name == "lm_head" or name == 'embed_out':
+                return LmHeadLinearAllreduce(child, self.mp_group)

-            setattr(child, "replaced", True)
-            if name == "lm_head" or name == 'embed_out':
-                return LmHeadLinearAllreduce(
-                    torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
-                    child.bias if child.bias is None else torch.nn.parameter.Parameter(
-                        move(child.bias, device_name, return_new_copy)), self.mp_group)
-            return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
-                        torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
+            return LinearAllreduce(child, self.mp_group, name=name)
         else:

-            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
-            # else [weight_shape[0] // mp_size, weight_shape[1]]
+            setattr(child, "replaced", True)
             if self.conv_linear_layer:
-                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
-
-            if require_tp_fused_qkvw(name, self.mp_size):
+                conv_LinearLayer(child, self.mp_group)
+            elif require_tp_fused_qkvw(name, self.mp_size):
                 #Check and handle fused qkv for TP
-                #The copy is a regular copy, The shape of dst and src is the same
-                data_dc = move(
-                    prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
-                    device_name, return_new_copy)
-
-                bias_data_dc = None if child.bias is None else move(
-                    prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
-                    device_name, return_new_copy)
-            else:
-                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
-                                               dim=1 if self.conv_linear_layer else 0)
-                data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
-                del data
-
-                if child.bias is not None:
-                    bias_data = child.bias.data.split(get_shard_size_list(
-                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
-                                                      dim=0)
-                    bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
-                    bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
-                    del bias_data
-                else:
-                    bias_data_dc = None
+                return fused_LinearLayer(child, self.mp_group, fused_module=self.module)

-            setattr(child, "replaced", True)
-            return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
+            return LinearLayer(child, self.mp_group, name=name)

     def _slice_embedding(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
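The refactor moves all shard/copy logic out of _replace and into the layer classes, leaving _replace as pure name-based routing. An illustrative summary of that routing (a simplification of the substring checks above, not code from the commit; each wrapper receives the original module plus the TP group and shards itself internally):

# Sketch only; the keys paraphrase the name checks in _replace.
from deepspeed.module_inject.layers import (LinearLayer, LinearAllreduce, LmHeadLinearAllreduce,
                                            Yuan_LinearLayer, Yuan_LinearAllreduce, GateUpPack_LinearLayer,
                                            Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer)

ROUTING_SUMMARY = {
    "Yuan v_proj": Yuan_LinearLayer,
    "Yuan o_proj": Yuan_LinearAllreduce,
    "gate_up_proj / GLM dense_h_to_4h": GateUpPack_LinearLayer,
    "all-reduce linears (down_proj, Arctic w2, configured names)": LinearAllreduce,
    "conv-style all-reduce linears": Conv_LinearALlreduce,
    "lm_head / embed_out": LmHeadLinearAllreduce,
    "fused qkv projections": fused_LinearLayer,
    "conv-style column-parallel linears": conv_LinearLayer,
    "default column-parallel linear": LinearLayer,
}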