🐛 Describe the bug
Bug
F.grid_sample backward on CUDA runs without raising a RuntimeError when torch.use_deterministic_algorithms(True) is set, despite being documented as an operation that should throw. The backward then produces non-deterministic gradients (max abs diff ~2e-4) between otherwise identical runs.
Expected behavior
Per the docs, torch.nn.functional.grid_sample() should throw a RuntimeError when attempting to differentiate a CUDA tensor under use_deterministic_algorithms(True):
The following normally-nondeterministic operations will throw a RuntimeError when mode=True:
...
torch.nn.functional.grid_sample() when attempting to differentiate a CUDA tensor
Instead, backward completes silently and produces different gradients across runs with identical seeds and inputs.
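For reference, this is the contract we expected to hold (a minimal sketch of the documented behavior; it asserts only that some RuntimeError is raised, without depending on the exact message):

import torch
import torch.nn.functional as F

torch.use_deterministic_algorithms(True)

x = torch.randn(4, 1, 16, 16, device='cuda', requires_grad=True)
grid = torch.rand(4, 1, 81, 2, device='cuda') * 2 - 1
try:
    F.grid_sample(x, grid, align_corners=True).sum().backward()
    print("BUG: backward completed without raising")  # what we actually observe
except RuntimeError as e:
    print(f"documented behavior: {e}")  # what the docs promise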
Minimal reproduction
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# Bug 1: Should throw RuntimeError but doesn't
x = torch.randn(4, 1, 16, 16, device='cuda', requires_grad=True)
grid = torch.rand(4, 1, 81, 2, device='cuda') * 2 - 1
out = F.grid_sample(x, grid, align_corners=True)
out.sum().backward() # No error raised
# Bug 2: Gradients differ between identical runs
grads = []
for run in range(2):
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    conv = nn.Conv2d(3, 64, 7, stride=8, padding=3).cuda()
    x = torch.randn(2, 3, 256, 256, device='cuda')
    feat = conv(x)
    f2 = torch.randn(2, 64, 16, 16, device='cuda')
    corr = torch.matmul(
        feat.view(2, 64, -1).transpose(1, 2),
        f2.view(2, 64, -1),
    ).view(2048, 1, 16, 16)
    grid = torch.rand(2048, 1, 81, 2, device='cuda') * 2 - 1
    F.grid_sample(corr, grid, align_corners=True).sum().backward()
    grads.append(conv.weight.grad.clone())

diff = (grads[0] - grads[1]).abs().max().item()
print(f"Max gradient diff: {diff:.6e}")  # ~2.4e-04, should be 0.0
Observed behavior
Max gradient diff: 2.441406e-04
No RuntimeError is raised, and gradients differ by ~2e-4 between two identical runs in the same process on the same GPU.
Note: Replacing the random grid with uniform coordinates (torch.zeros(...)) makes it deterministic — the non-determinism is triggered specifically by non-uniform grid coordinates across the batch dimension.
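For completeness, the control experiment behind that note (a sketch; same setup as Bug 1, with the random grid replaced by zeros so every batch element samples the same coordinate):

# Control: uniform coordinates across the batch
grads = []
for run in range(2):
    torch.manual_seed(0)
    x = torch.randn(4, 1, 16, 16, device='cuda', requires_grad=True)
    grid = torch.zeros(4, 1, 81, 2, device='cuda')  # all samples at the center
    F.grid_sample(x, grid, align_corners=True).sum().backward()
    grads.append(x.grad.clone())
print((grads[0] - grads[1]).abs().max().item())  # 0.0, per the note above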
Environment
Tested across multiple versions — same result on all:
| PyTorch | CUDA | cuDNN   | GPU  | Result                      |
|---------|------|---------|------|-----------------------------|
| 2.0.1   | 11.8 | 8.7.0   | H200 | Non-deterministic, no error |
| 2.1.2   | 12.1 | 8.9.02  | H200 | Non-deterministic, no error |
| 2.4.1   | 12.4 | 9.1.0   | H200 | Non-deterministic, no error |
| 2.11.0  | 12.6 | 9.10.02 | H200 | Non-deterministic, no error |
Not yet tested on A100/V100.
Impact
Any model that uses F.grid_sample in its forward pass (e.g., RAFT, optical flow networks, spatial transformers, homography estimation) cannot achieve deterministic training on CUDA, even with all deterministic flags properly configured. Users get no warning — the operation silently produces different results.
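A possible workaround for affected models (a sketch under the assumption that the CPU grid_sample backward is deterministic; the device round-trip makes it slow):

def grid_sample_deterministic(x, grid, **kwargs):
    # Hypothetical helper: route the op through CPU so the backward runs on
    # the CPU kernel; autograd differentiates through the .cpu()/.to() moves.
    return F.grid_sample(x.cpu(), grid.cpu(), **kwargs).to(x.device)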
Additional context
- grid_sample forward is deterministic — only backward is affected (see the forward check after this list)
- The non-determinism scales with gradient magnitude (~1e-5 relative error)
- Uniform grids (all batch elements sample the same coordinates) are deterministic; non-uniform grids (each batch element samples different coordinates) are not
- The effect is independent of spatial size, batch size, and number of sample points
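Sketch of the forward-only check behind the first bullet (same tensors as Bug 1; presumably the forward is a pure gather, while the backward scatters gradients with atomic adds):

with torch.no_grad():
    x = torch.randn(4, 1, 16, 16, device='cuda')
    grid = torch.rand(4, 1, 81, 2, device='cuda') * 2 - 1
    out1 = F.grid_sample(x, grid, align_corners=True)
    out2 = F.grid_sample(x, grid, align_corners=True)
    print(torch.equal(out1, out2))  # True: forward is bit-identical across calls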
Versions
Collecting environment information...
PyTorch version: 2.4.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.11.14 (main, Feb 3 2026, 22:51:56) [Clang 21.1.4 ] (64-bit runtime)
Python platform: Linux-5.15.0-139-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA H200
GPU 1: NVIDIA H200
GPU 2: NVIDIA H200
GPU 3: NVIDIA H200
GPU 4: NVIDIA H200
GPU 5: NVIDIA H200
GPU 6: NVIDIA H200
GPU 7: NVIDIA H200
Nvidia driver version: 570.133.20
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 240
On-line CPU(s) list: 0-239
Vendor ID: GenuineIntel
Model name: INTEL(R) XEON(R) PLATINUM 8580
CPU family: 6
Model: 207
Thread(s) per core: 2
Core(s) per socket: 60
Socket(s): 2
Stepping: 2
CPU max MHz: 2000.0000
CPU min MHz: 800.0000
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 5.6 MiB (120 instances)
L1i cache: 3.8 MiB (120 instances)
L2 cache: 240 MiB (120 instances)
L3 cache: 600 MiB (2 instances)
NUMA node(s): 4
NUMA node0 CPU(s): 0-29,120-149
NUMA node1 CPU(s): 30-59,150-179
NUMA node2 CPU(s): 60-89,180-209
NUMA node3 CPU(s): 90-119,210-239
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
cc @ezyang @gchanan @kadeng @msaroufim @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @mruberry @jbschlosser @walterddr @mikaylagawarecki @ptrblck @eqy @jerryzh168 @tinglvv @nWEIdia @kurtamohler