C++11 fix warnings #1904
Conversation
The errors in hopper/flash_api.cpp are C++11 narrowing warnings (treated as errors in strict builds) when initializing at::cuda::CUDAGuard with a non-constant char cast to c10::DeviceIndex (signed char).
cc @ko3n1g @tridao
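For context, a minimal sketch of the kind of change this describes; the function and variable names below are hypothetical and the actual edit in hopper/flash_api.cpp may differ:

// Sketch only: hypothetical names, not the exact diff in hopper/flash_api.cpp.
#include <ATen/cuda/CUDAGuard.h>   // at::cuda::CUDAGuard
#include <c10/core/Device.h>       // c10::DeviceIndex (an alias for int8_t)

void run_on_device(int device_id) {
  // Before: brace-initializing the guard with a non-constant value cast to
  // plain char narrows to c10::DeviceIndex (signed char); under C++11 rules
  // this is a narrowing conversion and becomes an error in strict builds,
  // notably on aarch64 where plain char is unsigned:
  //
  //   at::cuda::CUDAGuard device_guard{(char)device_id};
  //
  // After: cast explicitly to c10::DeviceIndex so no narrowing occurs.
  at::cuda::CUDAGuard device_guard(static_cast<c10::DeviceIndex>(device_id));
  // ... launch work on the selected device ...
}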
cc @janeyx99 can you review now?
janeyx99
left a comment
Looks okay, though I'd feel better if you're able to download a torch nightly after September and verify that FA3 still builds.
@tridao @janeyx99 It is working on my GH200:

docker run -t --rm --network=host --privileged --shm-size=8g --runtime=nvidia \
    --volume /home/ubuntu/Projects/jetson-containers/packages/attention/flash-attention:/test \
    --volume /home/ubuntu/Projects/jetson-containers/data:/data \
    flash-attention:r38.2.aarch64-cu130-24.04-flash-attention \
    /bin/bash -c 'python3 /test/test.py' \
    2>&1 | tee /home/ubuntu/Projects/jetson-containers/logs/20250924_092313/test/25-1_flash-attention_r38.2.aarch64-cu130-24.04-flash-attention_test.py.txt; exit ${PIPESTATUS[0]}
The screenshots look like FA2, but you're modifying FA3, right? So it should be building from the hopper directory. Also, can you verify that the torch version is greater than the 08.30 nightly?
I can run the test for FA3, let me do it. I'm using RC1: https://dev-discuss.pytorch.org/t/pytorch-2-9-rc1-produced-for-pytorch/3230
>>> import time
>>> import torch
>>> from flash_attn import flash_attn_varlen_func
>>> # from flash_attn_interface import flash_attn_varlen_func
>>>
>>>
>>> def call_fa(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
... return flash_attn_varlen_func(q, k, v,
... cu_seqlens_q=cu_seqlens_q,
... cu_seqlens_k=cu_seqlens_k,
... max_seqlen_q=max_seqlen_q,
... max_seqlen_k=max_seqlen_k,
... )
...
>>> def bench_fa(num_head: int = 24, head_dim: int = 128, seq_len: int = 6480):
... q = torch.rand([seq_len, num_head, head_dim], dtype=torch.bfloat16, device="cuda")
... k = q.clone()
... v = q.clone()
... cu_seqlens_q = torch.tensor([ 0, seq_len], dtype=torch.int32, device="cuda")
... max_seqlen_q = seq_len
... for _ in range(100):
... call_fa(q, k, v, cu_seqlens_q, cu_seqlens_q, max_seqlen_q, max_seqlen_q)
... torch.cuda.synchronize()
... bench_iter = 500
... st = time.time()
... for _ in range(bench_iter):
... call_fa(q, k, v, cu_seqlens_q, cu_seqlens_q, max_seqlen_q, max_seqlen_q)
... torch.cuda.synchronize()
... print(f"avg time: {(time.time() - st) * 1000 / bench_iter :.4f} ms")
...
...
>>> def test():
... # bench_fa(num_head=24, head_dim=128, seq_len=6480)
... bench_fa(num_head=3, head_dim=128, seq_len=50400)
...
...
>>> test()
avg time: 10.0027 ms
>>> print('testing PyTorch...')
testing PyTorch...
>>>
>>> import torch
>>>
>>> print(f'PyTorch version: {torch.__version__}')
PyTorch version: 2.9.0
>>> print(f'CUDA available: {torch.cuda.is_available()}')
CUDA available: True
>>> print(f'cuDNN version: {torch.backends.cudnn.version()}\n')
cuDNN version: 91200
>>>
>>> print(torch.__config__.show())
PyTorch built with:
- GCC 13.3
- C++ Version: 201703
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: DEFAULT
- CUDA Runtime 13.0
- NVCC architecture flags: -gencode;arch=compute_87,code=sm_87;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_103,code=sm_103;-gencode;arch=compute_110,code=sm_110;-gencode;arch=compute_120,code=sm_120;-gencode;arch=compute_121,code=sm_121
- CuDNN 91.2
- Build settings: BLAS_INFO=nvpl, BUILD_TYPE=Release, COMMIT_SHA=36d207fcaaede0d1e58a5168084c307b32b6fd8b, CUDA_VERSION=13.0, CUDNN_VERSION=9.12.0, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS=-ffunction-sections -fdata-sections -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_PYTORCH_QNNPACK -DAT_BUILD_ARM_VEC256_WITH_SLEEF -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow, LAPACK_INFO=nvpl, TORCH_VERSION=2.9.0, USE_CUDA=ON, USE_CUDNN=1, USE_CUSPARSELT=OFF, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=1, USE_MKLDNN=OFF, USE_MPI=0, USE_NCCL=1, USE_NNPACK=1, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, USE_XCCL=OFF, USE_XPU=OFF,
>>>
>>> # fail if CUDA isn't available
>>> assert(torch.cuda.is_available())
>>>
>>> print(f'\nPyTorch {torch.__version__}\n')
PyTorch 2.9.0
>>>
>>> try:
... print(f' * CUDA device {torch.cuda.get_device_name()}')
... print(f' * CUDA version {torch.version.cuda}')
... print(f' * CUDA cuDNN {torch.backends.cudnn.version()}')
... print(f' * CUDA BLAS {torch.backends.cuda.preferred_blas_library()}')
... print(f' * CUDA linalg {torch.backends.cuda.preferred_blas_library()}')
... print(f' * CUDA flash_attn {torch.backends.cuda.is_flash_attention_available()}')
... print(f' * CUDA flash_sdp {torch.backends.cuda.flash_sdp_enabled()}')
... print(f' * CUDA cudnn_sdp {torch.backends.cuda.cudnn_sdp_enabled()}')
... print(f' * CUDA math_sdp {torch.backends.cuda.math_sdp_enabled()}')
... print(f' * CUDA mem_efficient_sdp_enabled {torch.backends.cuda.mem_efficient_sdp_enabled()}')
... print(f' * CUDA fp16_bf16_reduction_math_sdp {torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed()}')
... except Exception as error:
... print(f'Exception trying to read PyTorch {torch.__version__} CUDA versions (this may be expected on older versions of PyTorch)\n{error}')
...
* CUDA device NVIDIA GH200 480GB
* CUDA version 13.0
* CUDA cuDNN 91200
* CUDA BLAS _BlasBackend.Cublas
* CUDA linalg _BlasBackend.Cublas
* CUDA flash_attn True
* CUDA flash_sdp True
* CUDA cudnn_sdp True
* CUDA math_sdp True
* CUDA mem_efficient_sdp_enabled True
* CUDA fp16_bf16_reduction_math_sdp False
>>>
Epic, thank you!
Thank you both!