
[PHI] Align paddle.inner with torch in matmul logic #72843

Merged
lshpku merged 1 commit into PaddlePaddle:develop from lshpku:fix-inner-op
May 22, 2025

Conversation

@lshpku
Contributor

lshpku commented May 21, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

Align the call logic of paddle.inner(x, y) with torch's matmul path.

Rationale

In dynamic graph mode, Paddle previously called y.transpose() first and then matmul, so inner executed as 2 kernels, transpose(y) + matmul(x, yT):

void phi::funcs::TilingSwapDim1And2<phi::dtype::float16, (int)256, (int)32, (int)32, long>
void cutlass::Kernel<cutlass_75_wmma_tensorop_s161616gemm_f16_32x32_128x2_nn_align1>

This PR changes it to call matmul(x, y, transpose_y=True) directly, so the underlying dispatch logic executes only 1 kernel (plus one small reduce kernel):

void cutlass::Kernel<cutlass_75_wmma_tensorop_s161616gemm_f16_32x32_128x2_tn_align1>
void splitKreduce_kernel<(int)32, (int)16, int, float, __half, float, __half, (bool)1, (bool)0, ...>
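The semantic equivalence the fused call relies on can be checked with a small numpy sketch (numpy stands in for the Paddle tensors here; shapes are illustrative, not the PR's test sizes):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)  # x[M, C]
y = rng.standard_normal((5, 8)).astype(np.float32)  # y[N, C]

# Old path: materialize the transpose, then a plain (nn) matmul.
old = np.matmul(x, y.transpose().copy())

# New path: one GEMM reading the B operand as transposed (tn),
# which is what matmul(x, y, transpose_y=True) dispatches to.
new = x @ y.T  # y.T is a view here; no transpose kernel is needed

assert np.allclose(old, new)
assert np.allclose(np.inner(x, y), new)  # inner(x, y) == x @ y.T
```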

Note that the signature of the first cutlass kernel is nn, while the second is tn.

Reference: torch

  1. torch.inner(x, y) first calls tensordot(x, y, -1, -1): https://github.com/pytorch/pytorch/blob/afd7a13bca2ce518b7c32f868c8dba610a538e22/aten/src/ATen/native/LinearAlgebra.cpp#L1308
  2. tensordot(x, y) then calls y.permute() and in turn mm(x, yT): https://github.com/pytorch/pytorch/blob/afd7a13bca2ce518b7c32f868c8dba610a538e22/aten/src/ATen/native/Linear.cpp#L764
  3. Although tensordot also permutes y, torch's permute is lazy, so by the time execution reaches the cutlass layer only a single matmul kernel is launched
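The torch call chain above can be mirrored in numpy, whose transpose is likewise a lazy view rather than a copy (a sketch of the semantics, not torch's actual code):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((3, 6))
y = rng.standard_normal((4, 6))

# Step 1: inner(x, y) is a tensordot over the last axis of each operand.
via_tensordot = np.tensordot(x, y, axes=([-1], [-1]))

# Step 2: tensordot permutes y and calls mm(x, yT); y.T here is a
# view (no data movement), analogous to torch's lazy permute.
via_mm = x @ y.T

assert np.allclose(via_tensordot, via_mm)
assert not y.T.flags["OWNDATA"]  # the transpose is a view, not a copy
```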

Alignment results

This PR achieves full alignment with torch (results are exactly equal, and the nsys traces are identical).
Test range: x[M, C], y[N, C], with 1 <= N, M <= 2^28 and 1 <= C <= 2^30 (within what GPU memory can hold).

Performance improves by roughly 2x over the previous transpose + matmul version, since the transpose cost is eliminated.

In addition, static graph mode needs no change, because it automatically rewrites matmul_v2(x, y.T) into matmul(x, y, transpose_y: true).
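The shape of the change can be sketched in pure Python (hypothetical names; Paddle's actual dispatch lives in C++): instead of materializing a transposed copy of y, inner forwards a transpose_y flag so a single fused GEMM handles it.

```python
def matmul(x, y, transpose_y=False):
    """Minimal reference matmul over lists of lists.

    transpose_y folds the transpose into the index order (y[j][k]
    instead of y[k][j]) rather than materializing a transposed copy,
    mirroring what the tn GEMM kernel does on device.
    """
    if transpose_y:
        return [[sum(xi[k] * yj[k] for k in range(len(yj))) for yj in y]
                for xi in x]
    return [[sum(xi[k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))] for xi in x]

def inner(x, y):
    # New path: one call, no explicit transpose of y.
    return matmul(x, y, transpose_y=True)

x = [[1.0, 2.0], [3.0, 4.0]]                # x[M=2, C=2]
y = [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]   # y[N=3, C=2]

# Old path: materialize the transpose, then a plain matmul.
y_t = [list(col) for col in zip(*y)]
assert inner(x, y) == matmul(x, y_t)        # same [M, N] result
```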


Pcard-85711

@paddle-bot

paddle-bot bot commented May 21, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@codecov-commenter

Codecov Report

All modified and coverable lines are covered by tests ✅

Please upload report for BASE (develop@7869937). Learn more about missing BASE report.

Additional details and impacted files
@@             Coverage Diff             @@
##             develop    #72843   +/-   ##
===========================================
  Coverage           ?   100.00%           
===========================================
  Files              ?         1           
  Lines              ?         1           
  Branches           ?         0           
===========================================
  Hits               ?         1           
  Misses             ?         0           
  Partials           ?         0           


@lshpku lshpku merged commit 1d581f2 into PaddlePaddle:develop May 22, 2025
55 of 56 checks passed
wanghuancoder pushed a commit to wanghuancoder/Paddle that referenced this pull request May 27, 2025
wanghuancoder added a commit that referenced this pull request Jun 3, 2025
* refine forrange (#72360)

* refine forrange

* refine forrange

* reduce support big tensor (#71970)

* reduce support big tensor

* [PHI] Fix gridDim limit for reduce kernel (#72507)

* [API] isclose support bigtensor (#72516)

* isclose support bigtensor

* refine

* [API] isnan isinf isfinite support bigtensor (#72517)

* isnan isinf isfinite support bigtensor

* refine

* [PHI] Fix cum kernel for big tensor (#72562)

* [PHI] Preliminary fix for elementwise broadcast int32 shape overflow (#72584)

* [PHI] Align linalg.solve kernel with torch (#72608)

* Update strided copy kernel (#72662)

* [PHI] Fix grid sample kernel for big tensor (#72628)

* [PHI] Fix argsort big tensor bug (#72712)

* [PHI] Fixed argsort big tensor bug

* [PHI] Fixed shape mismatch problem.

* [PHI] Fix contiguous kernel for big tensor (#72705)

* [PHI] Fix flatten and split kernel for big tensor (#72634)

* [PHI] Fix out-of-bound issue of paddle.take_along_axis (#72757)

* [PHI] fix paddle.diag with big tensor (#72638)

* [API] fix paddle.cross with big tensor (#72652)

* [PHI] Fix paddle.where api for big tensor (#72717)

* [PHI] Fix bincount kernel for big tensor (#72706)

* fix bincount kernel for big tensor

* use HostAlloc to alloc memory

* add cpu test case

* [PHI] Fix full_like kernel for big tensor (#72831)

* [API] Fix int overflow and float16 support for paddle.frac (#72815)

* [PHI] Align paddle.inner with torch in matmul logic (#72843)

* [PHI] Fix paddle.var & paddle.std float16 overflow (#72650)

* [PHI] Fix logsumexp precision problem (#72681)

* [PHI] Debug for logsumexp, bug source found

* [PHI] Removed GetNumBlocks func to get correct logsumexp

* [PHI] Removed redundant debug VLOG

* [PHI] Elegant grid bounded solution

* [Accuracy diff No.55-56、76-77] Fix accuracy diff for var&std API (#72879)

* [Accuracy diff No.21] Fix accuracy diff for heaviside API (#72894)

---------

Co-authored-by: Shuhao Liang <[email protected]>
Co-authored-by: Qianyue He <[email protected]>
Co-authored-by: Lei Ding <[email protected]>
Co-authored-by: ggggxm <[email protected]>
Co-authored-by: xkkkkkk23 <[email protected]>
Co-authored-by: Zx <[email protected]>
Co-authored-by: huangjiyi <[email protected]>
Co-authored-by: ooo oo <[email protected]>
