Fix paddle.incubate.nn.functional.fused_rms_norm big Tensor #74055

Merged
wanghuancoder merged 5 commits into PaddlePaddle:develop from xingmingyyj:fix_rms_norm_kernel
Jul 21, 2025

Conversation

xingmingyyj (Contributor) commented Jul 15, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

  • Fix an out-of-bounds memory access in fused_rms_norm for big Tensors.
  • The fused_rms_norm kernel loads each row into shared memory, so an overly large col value makes the kernel launch fail. The kernel enforced no check, so it exited without ever launching and the output silently became all zeros. This PR adds that check (a minimal guard is sketched right after this list).
  • When the input dtype is float16, fused_rms_norm casts the data to float32 for the norm computation to improve precision. Under float16, it matches the precision of the torch implementation below.
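A minimal sketch of the kind of pre-launch guard described above, assuming a fixed per-block shared-memory budget; the function name and the budget constant are illustrative assumptions, not Paddle's actual internals:

import math

# Assumed per-block shared-memory budget (illustrative; a real kernel
# queries the device and accounts for its exact staging layout).
SHARED_MEM_BYTES = 48 * 1024

def fits_in_shared_mem(shape, begin_norm_axis, dtype_bytes=4):
    """Return True if one flattened row fits in shared memory.

    dtype_bytes defaults to 4 because the kernel upcasts float16
    rows to float32 before computing the norm.
    """
    cols = math.prod(shape[begin_norm_axis:])
    return cols * dtype_bytes <= SHARED_MEM_BYTES

With such a guard in place, an oversized row raises a clear error instead of leaving an all-zero output after a silently failed launch. The torch reference implementation mentioned in the last bullet follows.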
from typing import Optional, Tuple

import torch


class RMSNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        norm_weight: torch.Tensor,
        norm_bias: Optional[torch.Tensor],
        epsilon: float,
        begin_norm_axis: int,
        bias: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        quant_scale: float = -1.0,
        quant_round_type: int = 0,
        quant_max_bound: float = 0.0,
        quant_min_bound: float = 0.0,
    ) -> torch.Tensor:
        """Forward pass of RMSNorm."""
        def _flatten_from_axis(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            """Flatten tensor starting from given axis."""
            rows = torch.prod(torch.tensor(tensor.shape[:axis])).item()
            cols = torch.prod(torch.tensor(tensor.shape[axis:])).item()
            return tensor.reshape(rows, cols)

        # Save original dtype and shape for later
        origin_dtype = x.dtype
        origin_shape = x.shape

        # Convert inputs to float32 if needed
        x = x.float() if x.dtype == torch.float16 else x
        norm_weight = norm_weight.float() if norm_weight.dtype == torch.float16 else norm_weight
        residual = residual.float() if residual is not None and residual.dtype == torch.float16 else residual
        bias = bias.float() if bias is not None and bias.dtype == torch.float16 else bias
        norm_bias = norm_bias.float() if norm_bias is not None and norm_bias.dtype == torch.float16 else norm_bias

        # Apply residual and bias if provided
        output = x
        if residual is not None:
            output = output + residual
        if bias is not None:
            output = output + bias

        # Normalization
        output = _flatten_from_axis(output, begin_norm_axis)
        output_sq = output.pow(2)
        mean_output_sq = output_sq.mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean_output_sq + epsilon)
        invvar = 1.0 / rms
        output_norm = output * invvar
        output = output_norm * norm_weight

        # Add norm_bias if provided
        if norm_bias is not None:
            output = output + _flatten_from_axis(norm_bias, begin_norm_axis)

        # Quantization if enabled
        if quant_scale > 0:
            output = output / quant_scale
            if quant_round_type == 0:
                output = torch.round(output)
            elif quant_round_type == 1:
                output = torch.where(
                    output >= 0,
                    torch.ceil(output - 0.5),
                    torch.floor(output + 0.5),
                )
            else:
                raise ValueError(f"Unsupported quant_round_type: {quant_round_type}")
            output = output * quant_scale
            output = torch.clamp(output, min=quant_min_bound, max=quant_max_bound)

        # Convert back to original dtype if no quantization
        if origin_dtype == torch.float16 and quant_scale <= 0:
            output = output.to(origin_dtype)
            norm_weight = norm_weight.to(origin_dtype)

        # Save tensors and metadata for backward
        ctx.save_for_backward(x, norm_weight, invvar)
        ctx.epsilon = epsilon
        ctx.exist_residual = residual is not None
        ctx.exist_bias = bias is not None
        ctx.exist_norm_bias = norm_bias is not None
        ctx.quant_scale = quant_scale
        ctx.begin_norm_axis = begin_norm_axis
        ctx.origin_shape = origin_shape

        # Restore the original shape (forward flattened from begin_norm_axis);
        # backward expects grad_output in this shape before re-flattening it.
        return output.reshape(origin_shape)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        """Backward pass of RMSNorm."""
        def _flatten_from_axis(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            """Flatten tensor starting from given axis."""
            rows = torch.prod(torch.tensor(tensor.shape[:axis])).item()
            cols = torch.prod(torch.tensor(tensor.shape[axis:])).item()
            return tensor.reshape(rows, cols)

        exist_bias = ctx.exist_bias
        exist_residual = ctx.exist_residual
        exist_norm_bias = ctx.exist_norm_bias
        quant_scale = ctx.quant_scale
    
        if quant_scale > 0 or exist_norm_bias or exist_bias or exist_residual:
            raise NotImplementedError

        # Retrieve saved tensors and metadata
        x, weight, invvar = ctx.saved_tensors
        origin_shape = ctx.origin_shape
        origin_dtype = grad_output.dtype

        # Flatten tensors for computation
        grad_output = _flatten_from_axis(grad_output.float(), ctx.begin_norm_axis)
        x = _flatten_from_axis(x.float(), ctx.begin_norm_axis)
        weight = weight.float()

        # Gradient w.r.t. weight (gamma)
        x_norm = x * invvar
        grad_weight = (grad_output * x_norm).sum(dim=tuple(range(grad_output.dim() - 1)), keepdim=False)
        grad_weight = grad_weight.to(origin_dtype)

        # Gradient w.r.t. input (x)
        D = x.size(-1)
        S = (grad_output * weight * x * invvar).sum(dim=1, keepdim=True)
        term1 = invvar / D
        grad_x = (D * grad_output * weight - x * invvar * S) * term1
        grad_x = grad_x.to(origin_dtype).reshape(origin_shape)

        # Return gradients (order matches forward inputs)
        return (
            grad_x,                   # x
            grad_weight,              # norm_weight
            None,                     # norm_bias
            None,                     # epsilon
            None,                     # begin_norm_axis
            None,                     # bias
            None,                     # residual
            None,                     # quant_scale
            None,                     # quant_round_type
            None,                     # quant_max_bound
            None,                     # quant_min_bound
        )
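A minimal usage sketch for the reference implementation above, with arbitrary shapes and epsilon, cross-checked against the direct RMSNorm formula:

import torch

x = torch.randn(4, 8, 64, dtype=torch.float16)
w = torch.ones(64, dtype=torch.float16)

# forward(x, norm_weight, norm_bias, epsilon, begin_norm_axis)
out = RMSNormFunction.apply(x, w, None, 1e-6, 2)

# Direct float32 RMSNorm for comparison, cast back to float16.
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
ref = (ref * w.float()).half()

print(out.shape, torch.allclose(out, ref, atol=1e-3))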

Pcard-73263


paddle-bot bot commented Jul 15, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@xingmingyyj xingmingyyj force-pushed the fix_rms_norm_kernel branch 2 times, most recently from 23c2faa to 046eee7 on July 15, 2025 13:56
@xingmingyyj xingmingyyj force-pushed the fix_rms_norm_kernel branch from 046eee7 to b99d0af on July 15, 2025 14:14
Review thread on the CUDA kernel change (diff):

  // bottom half sums
  if (threadIdx.y < offset) {
-   const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
+   const int64_t read_idx = threadIdx.y * blockDim.x + threadIdx.x;
Contributor suggested:

    const int64_t read_idx = static_cast<int64_t>(threadIdx.y) * blockDim.x + threadIdx.x;

(Casting an operand up front makes the multiply itself 64-bit, instead of widening a product already computed in 32-bit arithmetic.)

xingmingyyj (Contributor Author) replied:

done
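For illustration only, a NumPy sketch of the C/C++ rule behind that suggestion: a product of two 32-bit operands is computed in 32 bits before any widening. The values here are arbitrary (thread indices within one CUDA block are far smaller):

import numpy as np

y = np.uint32(70000)
bdx = np.uint32(70000)

# Multiply in uint32, then widen: the product wraps modulo 2**32.
# (NumPy may emit an overflow RuntimeWarning here.)
wrapped = np.int64(y * bdx)

# Widen first, then multiply: exact.
widened = np.int64(y) * np.int64(bdx)

print(wrapped)   # 605032704
print(widened)   # 4900000000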

xingmingyyj (Contributor Author) commented:

/re-run all-failed

wanghuancoder (Contributor) left a review:

LGTM

@wanghuancoder wanghuancoder merged commit 971eac1 into PaddlePaddle:develop Jul 21, 2025
72 of 73 checks passed
@xingmingyyj xingmingyyj deleted the fix_rms_norm_kernel branch July 30, 2025 02:26
