diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py index ba2f4fb2cc016d..3c79671e6d9a77 100644 --- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py +++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @@ -200,9 +200,12 @@ def _share_param_buffer(self): stop_gradient = self._param.stop_gradient self._param.stop_gradient = True self._param.flatten_() - self._param_buffer[ - self._index : self._index + self._param._numel() - ] = self._param + paddle.assign( + self._param, + self._param_buffer._slice( + self._index, self._index + self._param._numel() + ), + ) self._param.get_tensor()._set_dims(param_shape) self._param.stop_gradient = stop_gradient self._param_buffer._slice(