Skip to content

Conversation

@LCStayingdullCircuit
Copy link
Contributor

@LCStayingdullCircuit LCStayingdullCircuit commented Jul 23, 2025

PR Category

Execute Infrastructure

PR Types

Bug fixes

Description


修复 p_norm 内核及相关 API 的多项问题

一、 概述

本次修复主要针对 p_norm kernel在以下四个方面的问题:

  1. 无穷范数 (p=inf):反向梯度分配策略与 PyTorch 不一致。
  2. L2范数 (p=2):在 FP16 精度下,中间结果累加可能导致上溢出变为 inf
  3. 多种范数:计算过程中存在 int32 溢出风险,可能导致 error 700 或精度错误。(rebase以后发现已经被修改)
  4. 负数范数 (p=-1):反向传播的计算逻辑与 PyTorch 未对齐,导致精度问题。

下面将对每个问题的具体成因和修复方案进行详细说明。


二、 问题详情与修复方案

1. 无穷范数 (p=inf) 的反向梯度分配策略
  • 问题描述
    在计算无穷范数的反向梯度时,当输入张量中存在多个绝对值相等的最大值时,PaddlePaddle 与 PyTorch 的梯度分配策略存在差异。

    • PaddlePaddle (修复前): 将梯度 1.0 赋给所有绝对值最大的元素。
    • PyTorch (对齐目标): 将梯度 1.0 在所有绝对值最大的元素之间进行平均分配
  • 修复方案
    p_norm_grad_kernel.cu 中,参考了 amax kernel 实现。该 kernel 会统计出绝对值最大元素的个数,并在反向传播时将梯度进行平均分配,从而与 PyTorch 的行为保持一致。

2. L2范数 (p=2) 在 FP16 精度下的溢出问题
  • 问题描述
    当使用 FP16 数据类型计算 L2 范数时,ReduceAnyKernelReduceHigherDimKernel 中的累加操作使用了 FP32 的累加器 reduce_var 以保证精度。但在计算结束后,通过 Ty result = static_cast<Ty>(reduce_var); 将结果转换回 FP16 (Ty 此时为 half)。如果 reduce_var 的值超过了 FP16 的最大表示范围 (65504),result 就会上溢出为 inf

  • 修复方案
    考虑到 Reduce* kernel 的通用性,为避免影响其他模块,选择在调用层进行处理。在 p_norm_kernel.cu 中,对调用 Reduce* kernel 的模板参数进行了修改,强制要求返回类型 TyFP32,从而避免了从 FP32FP16 的溢出转换,同时对计算结果直接进行开方和强转,避免后续的问题。

3. 多种范数下的整数溢出风险
  • 问题描述
    reduce_grad_functions.h 的实现中,部分用于索引计算或计数的变量使用了 int (32位整型)。当处理超大规模的张量时,这些变量可能发生整数溢出,进而导致 error 700 或计算结果的精度错误。

  • 修复方案
    reduce_grad_functions.h 中存在溢出风险的 int 类型变量统一调整为 int64_t,确保在处理大规模数据时能够正确计算。

4. 负数范数 (p=-1) 的反向传播逻辑
  • 问题描述
    p = -1 时,p_norm 反向传播的计算逻辑与 PyTorch 未完全对齐,导致在特定场景下出现精度差异。

  • 修复方案
    p_norm_grad_kernel.cuPNormGradFunctor 函数中,对 p < 0 分支下的反向计算公式进行了修正,使其与 PyTorch 的实现逻辑对齐。

@paddle-bot
Copy link

paddle-bot bot commented Jul 23, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

lshpku
lshpku previously approved these changes Jul 23, 2025
wanghuancoder
wanghuancoder previously approved these changes Jul 23, 2025
Copy link
Contributor

@wanghuancoder wanghuancoder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@LCStayingdullCircuit
Copy link
Contributor Author

/re-run all-failed

zyfncg
zyfncg previously approved these changes Jul 25, 2025
@LCStayingdullCircuit
Copy link
Contributor Author

/re-run all-failed

@paddle-ci-bot
Copy link

paddle-ci-bot bot commented Aug 2, 2025

Sorry to inform you that 8ebf85f's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@LCStayingdullCircuit
Copy link
Contributor Author

/re-run all-failed

@LCStayingdullCircuit
Copy link
Contributor Author

/re-run all-failed

2 similar comments
@LCStayingdullCircuit
Copy link
Contributor Author

/re-run all-failed

@LCStayingdullCircuit
Copy link
Contributor Author

/re-run all-failed

@LCStayingdullCircuit
Copy link
Contributor Author

/re-run all-failed

lshpku
lshpku previously approved these changes Aug 21, 2025
@LCStayingdullCircuit
Copy link
Contributor Author

/re-run all-failed

1 similar comment
@LCStayingdullCircuit
Copy link
Contributor Author

/re-run all-failed

wanghuancoder
wanghuancoder previously approved these changes Aug 25, 2025
Copy link
Contributor

@wanghuancoder wanghuancoder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/reduce.h"

#include "paddle/fluid/framework/tensor_util.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

phi目录下不能引入fluid目录的文件,这里是否是多余引用

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除

@zrr1999 zrr1999 dismissed stale reviews from wanghuancoder and lshpku via be1019d August 26, 2025 02:19
@zrr1999 zrr1999 force-pushed the bugfix/vector_norm branch from 38cc403 to be1019d Compare August 26, 2025 02:19
@swgu98 swgu98 merged commit a184716 into PaddlePaddle:develop Aug 26, 2025
136 of 147 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants