Skip to content

Conversation

@zrr1999
Copy link
Member

@zrr1999 zrr1999 commented Sep 11, 2025

PR Category

Operator Mechanism

PR Types

Performance

Description

import torch
import paddle

dtype = torch.float32
start = torch.tensor(1.0)
stop = torch.tensor(-1.0)

# paddle
dx = ((stop - start).to(dtype=torch.float64) / 28)
(start + dx*6).to(dtype).tolist()

# torch
dx = ((stop.to(dtype=dtype) - stop.to(dtype=dtype)) / 28)
(start + dx.to(dtype)*6).tolist()

受影响的 API:

  • paddle.linspace 对齐全部case(共 85 个)

修改内容

  • 计算 step 时,始终将 start和end转换为 double 再进行计算,以确保获取到最高精度的step。
  • 改为先对stop和start进行cast在计算,可以避免溢出。
  • 当 tensor的dtype为整数时且为GPU kernel时,将 step type 转换为float (原本为double)。
  • 当 tensor的dtype为浮点数时,将 step type 转换为 tensor 的 dtype进行后续计算。

其他说明

GPU Kernel 在T为整数时,使用float比double性能更好,可以对齐PyTorch。

  • paddle float32 forward 5.1293089389801025
  • paddle float64 forward 8.464366674423218

性能测试代码

import paddle
import time

test_loop = 2**10

paddle.linspace(
    -1,
    1,
    2 ** 30,
    dtype="float32",
)
paddle.linspace(
    -1,
    1,
    2 ** 30,
    dtype="float64",
)
with paddle.no_grad():
    paddle.base.core._cuda_synchronize(paddle.CUDAPlace(0))
    start = time.time()
    for i in range(test_loop):
        paddle.linspace(
            -1,
            1,
            2 ** 30,
            dtype="float32",
        )
    start2 = time.time()
    paddle.base.core._cuda_synchronize(paddle.CUDAPlace(0))
    end = time.time()
    timeused = end - start
    print("paddle float32 forward", timeused)

with paddle.no_grad():
    paddle.base.core._cuda_synchronize(paddle.CUDAPlace(0))
    start = time.time()
    for i in range(test_loop):
        paddle.linspace(
            -1,
            1,
            2** 30,
            dtype="float64",
        )
    start2 = time.time()
    paddle.base.core._cuda_synchronize(paddle.CUDAPlace(0))
    end = time.time()
    timeused = end - start
    print("paddle float64 forward", timeused)
    print("end - start2", end - start2)

TODO:

  • GPU kernel 有性能提升空间,一个线程可以计算多个值,降低调度开销。

Pcard-67164

@paddle-bot
Copy link

paddle-bot bot commented Sep 11, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@codecov-commenter
Copy link

codecov-commenter commented Sep 11, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@7cc27eb). Learn more about missing BASE report.

Additional details and impacted files
@@             Coverage Diff             @@
##             develop    #75238   +/-   ##
===========================================
  Coverage           ?   100.00%           
===========================================
  Files              ?         2           
  Lines              ?         5           
  Branches           ?         0           
===========================================
  Hits               ?         5           
  Misses             ?         0           
  Partials           ?         0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@zrr1999 zrr1999 changed the title align LinspaceKernel LinspaceKernel uses the dtype of 'self' as the type of 'step' when tensor is floating Sep 12, 2025
Copy link
Contributor

@A-nnonymous A-nnonymous left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need some modifications

double step = (static_cast<double>(stop_data - start_data)) / (num - 1);
// step should be of StepT type
StepT step =
(static_cast<StepT>(stop_data) - static_cast<StepT>(start_data)) /
Copy link
Contributor

@A-nnonymous A-nnonymous Sep 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在StepT为整形时,我们是不是应该添加一个检查来确认:stop-start后的值之间,是否有num-1个有效整数值?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的StepT不应该为整数,如果T是整数,应该通过 using StepT = std::conditional_t<std::is_integral_v, double, T>;转为double

}

template <typename T, typename StepT>
__global__ void LinspaceKernelInnerForInt(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不建议使用两个kernel名字来管理这个linspace功能,可以考虑使用C++的模版特化,使用统一的名称来管理两个kernel

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

bool isIntegral =
(t == DataType::UINT8 || t == DataType::INT8 || t == DataType::UINT16 ||
t == DataType::INT16 || t == DataType::UINT32 || t == DataType::INT32 ||
t == DataType::UINT64 || t == DataType::INT64);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认一下paddle是否支持超长整型,比如INT128等

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前kernel还没有找到int128相关的内容,后面如果添加了再修改这里

@zrr1999 zrr1999 requested a review from A-nnonymous September 15, 2025 03:15
Copy link
Contributor

@wanghuancoder wanghuancoder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@A-nnonymous A-nnonymous left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, except performance for now

@zrr1999 zrr1999 merged commit 8e3d549 into PaddlePaddle:develop Sep 17, 2025
147 of 154 checks passed
@zrr1999 zrr1999 deleted the acc/linspace branch September 17, 2025 09:10
zhengshengning pushed a commit to zhengshengning/Paddle that referenced this pull request Oct 24, 2025
…nsor is floating (PaddlePaddle#75238)

* align LinspaceKernel

* update meta

* update gpu kernel

* fix LinspaceKernelInner

* improve kernel
zhengshengning pushed a commit to zhengshengning/Paddle that referenced this pull request Oct 24, 2025
…nsor is floating (PaddlePaddle#75238)

* align LinspaceKernel

* update meta

* update gpu kernel

* fix LinspaceKernelInner

* improve kernel
zhengshengning added a commit that referenced this pull request Oct 27, 2025
* CallScalarFunction uses the dtype of 'self' as the type of 'other' when opotype is 'div'(#75237)

* LinspaceKernel uses the dtype of 'self' as the type of 'step' when tensor is floating (#75238)

* align LinspaceKernel

* update meta

* update gpu kernel

* fix LinspaceKernelInner

* improve kernel

* fix CudaSigmoidGradFunctor and CudaSiluGradFunctor (#75341)

* Softplus accuracy and torch alignment 1 (#75363)

* [Precision Depth Alignment] paddle.tan reverse calculation: dx = dout *(1 + tan(x)^2) (#75335)

* Tan reverse calculation: dx = dout *(1 + tan(x)^2)

* [Precision Depth Alignment] Add support for CUDNN to paddle.nn.functional.grid_sample to align with torch accuracy.  (#75355)

* accuracy_stable_grid_sample

* fix

* correlation supports big tensor (#75383)

* fix

* fix test

* fix

* paddle.tanh Grad and torch alignment (float16) (#75454)

* [Precision Depth Alignment] paddle.sin and paddle.cos aligns with torch precision. (#75503)

* accuracy_stable_sin

* accuracy_stable_cos

* [深度对齐]Divide (#75379)

* fix

* fix

* fix

* fix

* fix

* [Precision Depth Alignment] fix precision for float16 of paddle.tan backward (#75525)

* fix precision for float16 of paddle.tan backward

* fix else branch of CudaTanGradFunctor

* [Precision Depth Alignment] fix precision for  paddle.expm1 (#75549)

* accuracy_stable_expm1

* fix

* Bigtensor排查修复[Paddle/paddle/phi/kernels/funcs] (#75523)

* fix

* fix

* [Precision Depth Alignment]  fix beta and threshold of paddle.nn.functional.softplus  to double (#75426)

* fix beta and threshold of Softplus to double

* fix test_softplus_activation_fuse_pass v1

* fix test_activation_zero

* fix flaot of SoftplusDoubleGradKernel to double

* add op_patches for softplus

* add yaml for ops/yaml/legacy

* fix infershape/operator for FLOAT64

* fix

* add SoftPlusOpTranscriber

* fix

* fix

* fix1

* fix2

* fix coverage

* fix coverage2

* fix (#75605)

* [深度对齐] dot (#75717)

* fix

* fix

* fix dcu

* [Precision Depth Alignment]  paddle.log aligns with torch precision (#75799)

* accuracy_stable_log

* accuracy_stable_log

* fix

* fix

* fix

* fix

* fix5

* [Precision Depth Alignment] fix eps of paddle.logit from float to double (#75816)

* accuracy_stable_logit

* add LogitOpTranscriber

* fix coverage

* fix 0yaml

* [Precision Depth Alignment] paddle.log_sigmoid (#75898)

* accuracy_stable_log_sigmoid

* fix test_activation_stride_op.py

* [Precision Depth Alignment] Modify the negative_slope parameter of the paddle.nn.functional.leaky_relu API to double (#75547)

* [big tensor] Paddle/paddle/phi/kernels/funcs gpuBigtensor (#75856)

* fix funcs

* gpu

* fix

* fix

* 修改PADDLE_ENFORCE信息

* fix cpu error

* fix dcu

* fix dcu

* fix

* [Fix] log sigmoid complex (#75953)

* feature: Add specialized LogSigmoidFunctor and CudaLogSigmoidFunctor for complex numbers

This commit introduces specialized implementations of LogSigmoidFunctor and CudaLogSigmoidFunctor to handle complex number inputs. The new implementations utilize direct formulas for improved accuracy and stability in calculations involving complex types.

* refactor: Optimize LogSigmoidFunctor and CudaLogSigmoidFunctor for complex types by caching exp(-x) to reduce redundant computations. This change enhances performance while maintaining accuracy in calculations.

* refactor: modified the formula in LogSigmoidFunctor to make it numerical stable

---------

Co-authored-by: Zhan Rongrui <[email protected]>
Co-authored-by: 正在学习 <[email protected]>
Co-authored-by: Bvicii <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants