
Commit 6c55b12

Merge pull request #29 from huggingface/fix_ds
Fix deepspeed
2 parents: e030193 + 6c0effa

File tree

2 files changed: +6, -2 lines


src/transformers/integrations/mxfp4.py (1 addition, 1 deletion)

@@ -331,7 +331,7 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
         else:
             setattr(module, param_name.rsplit(".", 1)[1], param_value)
         dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
-        dequantized = dequantized.transpose(1, 2).to(target_device)
+        dequantized = dequantized.transpose(1, 2).contiguous().to(target_device)
         setattr(module, proj, torch.nn.Parameter(dequantized))
         delattr(module, blocks_attr)
         delattr(module, scales_attr)
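The one-line fix above adds .contiguous() between the transpose and the device move. transpose(1, 2) only swaps strides and returns a non-contiguous view of the same storage; presumably the DeepSpeed path (the PR title is "Fix deepspeed") is what trips over a parameter backed by such a view, since ZeRO-style partitioning flattens parameter storage. A minimal sketch of the effect, with made-up shapes:

import torch

# Illustration only: a small stand-in for the dequantized MoE weight;
# real shapes come from the model config, not from this example.
dequantized = torch.randn(2, 4, 8)

transposed = dequantized.transpose(1, 2)
print(transposed.is_contiguous())   # False: transpose only swaps strides, no data moves

fixed = transposed.contiguous()     # materializes a row-major copy
print(fixed.is_contiguous())        # True

# Wrapping the contiguous copy keeps the nn.Parameter backed by flat storage,
# which partitioning code (e.g. DeepSpeed ZeRO-3) generally expects.
param = torch.nn.Parameter(fixed)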

src/transformers/modeling_utils.py (5 additions, 1 deletion)

@@ -887,8 +887,12 @@ def _load_state_dict_into_meta_model(
             # and then cast it to CPU to avoid excessive memory usage on each GPU
             # in comparison to the sharded model across GPUs.
             if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
+                param_name = hf_quantizer.update_param_name(param_name)
                 module, param_type = get_module_from_name(model, param_name)
                 value = getattr(module, param_type)
+                # special case for OpenAIMoeForCausalLM: wait for the param to leave the meta device before casting it to cpu
+                if model.__class__.__name__ == "OpenAIMoeForCausalLM" and value.device.type == "meta":
+                    continue
                 param_to = "cpu"
                 if is_fsdp_enabled() and not is_local_dist_rank_0():
                     param_to = "meta"
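The first hunk makes two changes: the quantizer now remaps the parameter name via hf_quantizer.update_param_name(param_name) before the module lookup, and OpenAIMoeForCausalLM parameters that are still on PyTorch's meta device (shape and dtype only, no storage) are skipped instead of being cast to CPU. A rough, self-contained sketch of the meta-device check; the nn.Linear below is only an illustration, not the code path inside _load_state_dict_into_meta_model:

import torch
from torch import nn

# A parameter created on the meta device carries metadata (shape, dtype) but no data.
layer = nn.Linear(4, 4, device="meta")
value = layer.weight

if value.device.type == "meta":
    # Casting a meta tensor to CPU is not meaningful, so the loading loop above
    # skips it and handles the parameter once it has been materialized.
    print("parameter is still on the meta device, skipping the cpu cast")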
@@ -5124,8 +5128,8 @@ def _assign_original_dtype(module):
        dispatch_model(model, **device_map_kwargs)

        if hf_quantizer is not None:
-            hf_quantizer.postprocess_model(model, config=config)
            model.hf_quantizer = hf_quantizer
+            hf_quantizer.postprocess_model(model, config=config)

        if _adapter_model_path is not None:
            adapter_kwargs["key_mapping"] = key_mapping
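The second hunk only reorders two existing statements: model.hf_quantizer is now assigned before hf_quantizer.postprocess_model(model, config=config) runs, presumably so the quantizer is already attached to the model while postprocessing happens.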
