Conversation

@johnnynunez
Contributor

@johnnynunez johnnynunez commented Sep 23, 2025

The errors in hopper/flash_api.cpp are C++11 narrowing warnings (treated as errors in strict builds) when initializing at::cuda::CUDAGuard with a non-constant char cast to c10::DeviceIndex (signed char).

cc @ko3n1g @tridao
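
For reference, a minimal sketch of the pattern and one way to avoid the narrowing (illustrative only, not the exact diff in hopper/flash_api.cpp; the helper function name is hypothetical, while the guard type and tensor accessor follow the public c10/ATen API):

// Illustrative sketch, not the actual patch. c10::DeviceIndex is a signed
// 8-bit type, so brace-initializing the CUDA guard from a non-constant value
// cast to plain char triggers -Wnarrowing; on aarch64 (e.g. GH200) plain char
// is unsigned by default, so char -> signed char is a narrowing conversion
// that strict builds promote to an error.
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>

void run_on_tensor_device(const at::Tensor& q) {  // hypothetical helper
  // Problematic pattern (paraphrased):
  //   at::cuda::CUDAGuard device_guard{(char)q.get_device()};

  // Converting explicitly to the guard's exact index type keeps the braced
  // initializer free of implicit narrowing.
  c10::cuda::CUDAGuard device_guard{
      static_cast<c10::DeviceIndex>(q.get_device())};

  // ... launch kernels on q's device while the guard is in scope ...
}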

@johnnynunez
Contributor Author

cc @janeyx99 can you review now?

Contributor

@janeyx99 janeyx99 left a comment


Looks okay, though I'd feel better if you're able to download a torch nightly after September and verify that FA3 still builds

@johnnynunez
Contributor Author

cc @ko3n1g @tridao

@johnnynunez
Contributor Author

@tridao @janeyx99 it is working on my GH200
[screenshot]

docker run -t --rm --network=host --privileged --shm-size=8g --runtime=nvidia \
  --volume /home/ubuntu/Projects/jetson-containers/packages/attention/flash-attention:/test \
  --volume /home/ubuntu/Projects/jetson-containers/data:/data \
  flash-attention:r38.2.aarch64-cu130-24.04-flash-attention \
    /bin/bash -c 'python3 /test/test.py' \
2>&1 | tee /home/ubuntu/Projects/jetson-containers/logs/20250924_092313/test/25-1_flash-attention_r38.2.aarch64-cu130-24.04-flash-attention_test.py.txt; exit ${PIPESTATUS[0]}

@johnnynunez
Contributor Author

[screenshot: 2025-09-24 13:10:40]

@janeyx99
Contributor

The screenshots look like FA2, but you're modifying FA3, right? So it should be building from the hopper directory. Also, can you verify that the torch version is newer than the 08.30 nightly?

@johnnynunez
Contributor Author

johnnynunez commented Sep 24, 2025

The screenshots look like FA2, but you're modifying FA3, right? So it should be building from the hopper directory. Also, can you verify that the torch version is newer than the 08.30 nightly?

I can run the test for FA3, let me do it

I'm using RC1: https://dev-discuss.pytorch.org/t/pytorch-2-9-rc1-produced-for-pytorch/3230

@johnnynunez
Contributor Author

The screenshots look like FA2, but you're modifying FA3, right? So it should be building from the hopper directory. Also, can you verify that the torch version is newer than the 08.30 nightly?

>>> import time
>>> import torch
>>> from flash_attn import flash_attn_varlen_func
>>> # from flash_attn_interface import flash_attn_varlen_func
>>> 
>>> 
>>> def call_fa(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
...     return flash_attn_varlen_func(q, k, v,
...                 cu_seqlens_q=cu_seqlens_q,
...                 cu_seqlens_k=cu_seqlens_k,
...                 max_seqlen_q=max_seqlen_q,
...                 max_seqlen_k=max_seqlen_k,
...             )
... 
>>> def bench_fa(num_head: int = 24, head_dim: int = 128, seq_len: int = 6480):
...     q = torch.rand([seq_len, num_head, head_dim], dtype=torch.bfloat16, device="cuda")
...     k = q.clone()
...     v = q.clone()
...     cu_seqlens_q = torch.tensor([    0, seq_len], dtype=torch.int32, device="cuda")
...     max_seqlen_q = seq_len
...     for _ in range(100):
...         call_fa(q, k, v, cu_seqlens_q, cu_seqlens_q, max_seqlen_q, max_seqlen_q)
...     torch.cuda.synchronize()
...     bench_iter = 500
...     st = time.time()
...     for _ in range(bench_iter):
...         call_fa(q, k, v, cu_seqlens_q, cu_seqlens_q, max_seqlen_q, max_seqlen_q)
...     torch.cuda.synchronize()
...     print(f"avg time: {(time.time() - st) * 1000 / bench_iter :.4f} ms")
...     
... 
>>> def test():
...     # bench_fa(num_head=24, head_dim=128, seq_len=6480)
...     bench_fa(num_head=3, head_dim=128, seq_len=50400)
...     
... 
>>> test()
avg time: 10.0027 ms
>>> print('testing PyTorch...')
testing PyTorch...
>>> 
>>> import torch
>>> 
>>> print(f'PyTorch version: {torch.__version__}')
PyTorch version: 2.9.0
>>> print(f'CUDA available:  {torch.cuda.is_available()}')
CUDA available:  True
>>> print(f'cuDNN version:   {torch.backends.cudnn.version()}\n')
cuDNN version:   91200

>>> 
>>> print(torch.__config__.show())
PyTorch built with:
  - GCC 13.3
  - C++ Version: 201703
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: DEFAULT
  - CUDA Runtime 13.0
  - NVCC architecture flags: -gencode;arch=compute_87,code=sm_87;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_103,code=sm_103;-gencode;arch=compute_110,code=sm_110;-gencode;arch=compute_120,code=sm_120;-gencode;arch=compute_121,code=sm_121
  - CuDNN 91.2
  - Build settings: BLAS_INFO=nvpl, BUILD_TYPE=Release, COMMIT_SHA=36d207fcaaede0d1e58a5168084c307b32b6fd8b, CUDA_VERSION=13.0, CUDNN_VERSION=9.12.0, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS=-ffunction-sections -fdata-sections -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_PYTORCH_QNNPACK -DAT_BUILD_ARM_VEC256_WITH_SLEEF -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow, LAPACK_INFO=nvpl, TORCH_VERSION=2.9.0, USE_CUDA=ON, USE_CUDNN=1, USE_CUSPARSELT=OFF, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=1, USE_MKLDNN=OFF, USE_MPI=0, USE_NCCL=1, USE_NNPACK=1, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, USE_XCCL=OFF, USE_XPU=OFF, 

>>> 
>>> # fail if CUDA isn't available
>>> assert(torch.cuda.is_available())
>>> 
>>> print(f'\nPyTorch {torch.__version__}\n')

PyTorch 2.9.0

>>> 
>>> try:
...     print(f'  * CUDA device     {torch.cuda.get_device_name()}')
...     print(f'  * CUDA version    {torch.version.cuda}')
...     print(f'  * CUDA cuDNN      {torch.backends.cudnn.version()}')
...     print(f'  * CUDA BLAS       {torch.backends.cuda.preferred_blas_library()}')
...     print(f'  * CUDA linalg     {torch.backends.cuda.preferred_blas_library()}')
...     print(f'  * CUDA flash_attn {torch.backends.cuda.is_flash_attention_available()}')
...     print(f'  * CUDA flash_sdp  {torch.backends.cuda.flash_sdp_enabled()}')
...     print(f'  * CUDA cudnn_sdp  {torch.backends.cuda.cudnn_sdp_enabled()}')
...     print(f'  * CUDA math_sdp   {torch.backends.cuda.math_sdp_enabled()}')
...     print(f'  * CUDA mem_efficient_sdp_enabled    {torch.backends.cuda.mem_efficient_sdp_enabled()}')
...     print(f'  * CUDA fp16_bf16_reduction_math_sdp {torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed()}')
... except Exception as error:
...     print(f'Exception trying to read PyTorch {torch.__version__} CUDA versions (this may be expected on older versions of PyTorch)\n{error}')
... 
  * CUDA device     NVIDIA GH200 480GB
  * CUDA version    13.0
  * CUDA cuDNN      91200
  * CUDA BLAS       _BlasBackend.Cublas
  * CUDA linalg     _BlasBackend.Cublas
  * CUDA flash_attn True
  * CUDA flash_sdp  True
  * CUDA cudnn_sdp  True
  * CUDA math_sdp   True
  * CUDA mem_efficient_sdp_enabled    True
  * CUDA fp16_bf16_reduction_math_sdp False
>>> 

@janeyx99
Contributor

Epic, thank u!

@johnnynunez
Contributor Author

johnnynunez commented Sep 25, 2025

cc @tridao @ko3n1g ready to merge

@tridao tridao merged commit add1756 into Dao-AILab:main Sep 25, 2025
@tridao
Member

tridao commented Sep 25, 2025

Thank you both!
