Skip to content

Commit 29a4231

Browse files
committed
Fix example based on triton-lang#2701.
1 parent d04f288 commit 29a4231

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

python/examples/copy_strided.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import triton
22
import triton.language as tl
3+
import triton.compiler as tc
34

45

56
# triton kernel
@@ -14,5 +15,11 @@ def kernel(X, stride_xm, #
1415
tl.store(Zs, tl.load(Xs))
1516

1617

17-
ret = triton.compile(kernel, signature="*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
18+
src = tc.ASTSource(
19+
fn=kernel,
20+
constants={"BLOCK_M": 64, "BLOCK_N": 64},
21+
signature="*fp32,i32,*fp32,i32",
22+
)
23+
24+
ret = triton.compile(src)
1825
print(ret.asm["ttgir"])

0 commit comments

Comments
 (0)