11 | 11 | from typing import Optional |
12 | 12 | import torch |
13 | 13 | from deepspeed import comm as dist |
14 | | -from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce |
| 14 | +from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer |
15 | 15 | from deepspeed.accelerator import get_accelerator |
16 | | -from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp |
| 16 | +from .fusedqkv_utils import require_tp_fused_qkvw |
17 | 17 | from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list |
18 | 18 | import os |
19 | 19 | import ast |
| 20 | +from deepspeed.utils import groups |
| 21 | +from deepspeed.module_inject.layers import is_autotp_training_mode |
20 | 22 |
21 | 23 |
22 | 24 | def move(tensor, device, copy=True): |
@@ -367,10 +369,18 @@ def tp_parser(model): |
367 | 369 | return policy_list |
368 | 370 |
369 | 371 | def set_tensor_parallel_config(self, mp_size, mp_group): |
| 372 | + |
| 373 | + if is_autotp_training_mode(): |
| 374 | + self.mp_group = groups.get_tensor_model_parallel_group() |
| 375 | + self.mp_size = groups.get_tensor_model_parallel_world_size() |
| 376 | + return |
| 377 | + |
370 | 378 | self.mp_size = mp_size |
371 | 379 | self.mp_group = mp_group |
372 | 380 |
373 | 381 | def _replace(self, child, name, conv_linear_layer): |
| 382 | + # This function should clearly define the routing rules for specific layers |
| 383 | + # and avoid any complex shard-related logic. |
374 | 384 | if getattr(child, "replaced", False) == True: |
375 | 385 | return |
376 | 386 | device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name() |
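
Aside: the new `is_autotp_training_mode()` branch in the hunk above makes `set_tensor_parallel_config` ignore the caller-supplied values whenever AutoTP training is active and read the group and world size from DeepSpeed's `groups` registry instead. A minimal, self-contained sketch of that control flow (the flag and the registry values below are stand-ins for illustration, not the real DeepSpeed modules):

# Illustrative sketch only: stand-ins for deepspeed.utils.groups and
# is_autotp_training_mode(); the names and values here are assumptions.
_TP_GROUP = object()      # placeholder for a torch.distributed process group
_TP_WORLD_SIZE = 4

def is_autotp_training_mode():
    return True           # hard-coded for the sketch

class ReplaceModule:
    def set_tensor_parallel_config(self, mp_size, mp_group):
        if is_autotp_training_mode():
            # Training path: take the globally initialized TP group/size.
            self.mp_group = _TP_GROUP
            self.mp_size = _TP_WORLD_SIZE
            return
        # Inference path: keep whatever the caller passed in.
        self.mp_size = mp_size
        self.mp_group = mp_group

cfg = ReplaceModule()
cfg.set_tensor_parallel_config(mp_size=2, mp_group=None)
print(cfg.mp_size)  # -> 4 here, because the training-mode branch wins
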
@@ -415,80 +425,41 @@ def _replace(self, child, name, conv_linear_layer): |
415 | 425 | # For Yuan model |
416 | 426 | if 'Yuan' in str(self.module): |
417 | 427 | if 'v_proj' in name: |
418 | | - weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(), |
419 | | - dist.get_world_size(), True) |
420 | | - return LinearLayer(weight=weight, bias=bias) |
| 428 | + return Yuan_LinearLayer(child, self.mp_group) |
| 429 | + |
421 | 430 | elif 'o_proj' in name: |
422 | | - weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(), |
423 | | - dist.get_world_size(), False) |
424 | | - return LinearAllreduce(weight, bias, self.mp_group) |
425 | | - # For Arctic model, bypass to all_reduce replacement for w2 weights |
| 431 | + return Yuan_LinearAllreduce(child, self.mp_group) |
| 432 | + |
| 433 | + # For MLP including chunk layer. |
| 434 | + if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)): |
| 435 | + return GateUpPack_LinearLayer(child, self.mp_group) |
| 436 | + # For Arctic model, bypass to all_reduce replacement for w2 weights |
426 | 437 | arctic_w2_all_reduce_linear = False |
427 | 438 | if 'Arctic' in str(self.module) and 'w2' in name: |
428 | 439 | arctic_w2_all_reduce_linear = True |
429 | 440 | # For MoE MLP model, e.g., deepseek and jamba |
430 | 441 | down_proj = False |
431 | 442 | if 'down_proj' in name: |
432 | 443 | down_proj = True |
433 | | - # For MLP including chunk layer. |
434 | | - if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)): |
435 | | - weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size()) |
436 | | - return LinearLayer(weight=weight, bias=bias) |
437 | 444 | if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj: |
438 | | - # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] |
439 | | - # else [weight_shape[0], weight_shape[1] // mp_size] |
440 | 445 |
| 446 | + setattr(child, "replaced", True) |
441 | 447 | if self.conv_linear_layer: |
442 | | - child.weight.data = child.weight.data.transpose(-1, -2).contiguous() |
443 | | - data = child.weight.data.split(get_shard_size_list( |
444 | | - weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name), |
445 | | - dim=1) |
446 | | - data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach() |
447 | | - del data |
| 448 | + return Conv_LinearALlreduce(child, self.mp_group, name=name) |
| 449 | + elif name == "lm_head" or name == 'embed_out': |
| 450 | + return LmHeadLinearAllreduce(child, self.mp_group) |
448 | 451 |
449 | | - setattr(child, "replaced", True) |
450 | | - if name == "lm_head" or name == 'embed_out': |
451 | | - return LmHeadLinearAllreduce( |
452 | | - torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(), |
453 | | - child.bias if child.bias is None else torch.nn.parameter.Parameter( |
454 | | - move(child.bias, device_name, return_new_copy)), self.mp_group) |
455 | | - return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \ |
456 | | - torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group) |
| 452 | + return LinearAllreduce(child, self.mp_group, name=name) |
457 | 453 | else: |
458 | 454 |
459 | | - # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] |
460 | | - # else [weight_shape[0] // mp_size, weight_shape[1]] |
| 455 | + setattr(child, "replaced", True) |
461 | 456 | if self.conv_linear_layer: |
462 | | - child.weight.data = child.weight.data.transpose(-1, -2).contiguous() |
463 | | - |
464 | | - if require_tp_fused_qkvw(name, self.mp_size): |
| 457 | + conv_LinearLayer(child, self.mp_group) |
| 458 | + elif require_tp_fused_qkvw(name, self.mp_size): |
465 | 459 | #Check and handle fused qkv for TP |
466 | | - #The copy is a regular copy, The shape of dst and src is the same |
467 | | - data_dc = move( |
468 | | - prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index), |
469 | | - device_name, return_new_copy) |
470 | | - |
471 | | - bias_data_dc = None if child.bias is None else move( |
472 | | - prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index), |
473 | | - device_name, return_new_copy) |
474 | | - else: |
475 | | - data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name), |
476 | | - dim=1 if self.conv_linear_layer else 0) |
477 | | - data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach() |
478 | | - del data |
479 | | - |
480 | | - if child.bias is not None: |
481 | | - bias_data = child.bias.data.split(get_shard_size_list( |
482 | | - weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name), |
483 | | - dim=0) |
484 | | - bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy) |
485 | | - bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False) |
486 | | - del bias_data |
487 | | - else: |
488 | | - bias_data_dc = None |
| 460 | + return fused_LinearLayer(child, self.mp_group, fused_module=self.module) |
489 | 461 |
490 | | - setattr(child, "replaced", True) |
491 | | - return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc) |
| 462 | + return LinearLayer(child, self.mp_group, name=name) |
492 | 463 |
493 | 464 | def _slice_embedding(self, child, name, conv_linear_layer): |
494 | 465 | if getattr(child, "replaced", False) == True: |
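
Taken together, the change replaces the inline transpose/split/move logic in `_replace` with constructor calls that hand the untouched `child` module and `self.mp_group` to a layer class (`LinearLayer`, `LinearAllreduce`, `LmHeadLinearAllreduce`, and the new `Yuan_*`, `GateUpPack_*`, `Conv_*` and `fused_*` variants), so each class owns its own sharding. As a rough, self-contained illustration of that wrapper pattern (a toy class under the same constructor shape, not DeepSpeed's actual LinearAllreduce):

import torch
import torch.nn as nn
import torch.distributed as dist

class ToyLinearAllreduce(nn.Module):
    # Toy row-parallel wrapper: takes the original child module and an
    # optional process group, keeps a shard of the weight, and all-reduces
    # the partial output in forward(). Illustration of the pattern only;
    # not DeepSpeed's implementation.
    def __init__(self, child: nn.Linear, mp_group=None):
        super().__init__()
        rank = dist.get_rank(mp_group) if dist.is_initialized() else 0
        world = dist.get_world_size(mp_group) if dist.is_initialized() else 1
        # Row-parallel split: each rank keeps a slice of the input features.
        shard = child.weight.data.chunk(world, dim=1)[rank].contiguous()
        self.weight = nn.Parameter(shard, requires_grad=False)
        self.bias = child.bias
        self.mp_group = mp_group

    def forward(self, x):
        # Under a real TP split, x would be the matching input shard.
        out = torch.nn.functional.linear(x, self.weight)
        if dist.is_initialized():
            # Sum the partial results across the tensor-parallel group.
            dist.all_reduce(out, group=self.mp_group)
        if self.bias is not None:
            out = out + self.bias
        return out

# Single-process smoke test: with world size 1 the wrapper matches nn.Linear.
layer = nn.Linear(8, 4)
wrapped = ToyLinearAllreduce(layer)
x = torch.randn(2, 8)
print(torch.allclose(wrapped(x), layer(x), atol=1e-6))

Keeping the split inside the wrapper means `_replace` only decides which wrapper applies, which is exactly what the new comment at the top of the function asks for: routing rules only, no shard-related logic.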