Must all input dimensions be at least 16 for the tl.dot() operation?
I'm trying to compute tl.dot(a, b) between a 2D matrix (shape [16, 16]) and a vector (shape [16, 1]), and I'm getting this compilation error:
AssertionError('All values in both first input shape ([constexpr[16], constexpr[16]]) and second input shape ([constexpr[16], constexpr[1]]) must be >= 16!')
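The assertion states a simple rule: every dimension of both operands must be >= 16, so the trailing 1 in [16, 1] is what fails. A common way around such a check is to zero-pad the vector to 16 columns, do a full matrix product, and keep only the first column. The sketch below illustrates that idea in NumPy, not in Triton. `dot_shapes_ok` and `padded_matvec` are hypothetical helpers, and the threshold of 16 is taken from the error message, so it may differ in other Triton versions.

```python
import numpy as np

# Threshold taken from the assertion message; an assumption for other Triton versions.
MIN_DOT_DIM = 16


def dot_shapes_ok(a_shape, b_shape):
    """Hypothetical helper mirroring the rule in the error message:
    every dimension of both operands must be >= MIN_DOT_DIM."""
    return all(d >= MIN_DOT_DIM for d in (*a_shape, *b_shape))


def padded_matvec(a, v):
    """Compute a @ v for v of shape [K, 1] by zero-padding v to [K, 16],
    doing a full matrix product, and keeping only the first column."""
    k = v.shape[0]
    v_pad = np.zeros((k, MIN_DOT_DIM), dtype=v.dtype)
    v_pad[:, :1] = v                   # real data in column 0, zeros elsewhere
    assert dot_shapes_ok(a.shape, v_pad.shape)
    return (a @ v_pad)[:, :1]          # the padded columns are all zero, so drop them


if __name__ == "__main__":
    a = np.arange(256, dtype=np.float32).reshape(16, 16)
    v = np.ones((16, 1), dtype=np.float32)
    print(dot_shapes_ok(a.shape, v.shape))           # the original [16, 1] operand fails the rule
    print(np.allclose(padded_matvec(a, v), a @ v))   # padding gives the same column vector
```

In a Triton kernel, the same idea would mean loading the kernel vector into a [N, 16] tile whose extra columns are zero (for example via a masked load). That wastes 15/16 of the multiply work, so treat it as a compatibility trick rather than an optimization.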
I'm not sure whether I'm doing something wrong or whether Triton doesn't support tl.dot on 1D vectors.
Below is the code for reference.
```python
import triton
import triton.language as tl


@triton.jit
def kernel_opt(A_in,
               A_out,
               kernel, kernel_size_sq: tl.constexpr,
               BLOCK_SIZE: tl.constexpr):
    """ A_in (shape: [M,N]), A_out (shape: [M',1]), kernel (shape: [N,1]) """
    pid = tl.program_id(0)
    stride_x = kernel_size_sq
    offsets_x = tl.arange(0, kernel_size_sq)                  # column offsets, shape [N]
    offsets_y = tl.arange(0, BLOCK_SIZE)[:, None] * stride_x  # row offsets, shape [BLOCK_SIZE, 1]
    # pointers to a [BLOCK_SIZE, N] tile of A_in
    data_pointers = A_in + pid * kernel_size_sq * BLOCK_SIZE + offsets_y + offsets_x[None, :]
    kernel_pointers = kernel + offsets_x[None, :]             # shape [1, N]
    # load
    load_mat = tl.load(data_pointers)
    load_kernel = tl.load(kernel_pointers)
    # compute: [BLOCK_SIZE, N] x [N, 1]
    result = tl.dot(load_mat, tl.trans(load_kernel))  # result should be column vector
    # store result: one output row per row of the tile, shape [BLOCK_SIZE, 1]
    out_pointers = A_out + pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None]
    tl.store(out_pointers, result)
```
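A workaround that sidesteps tl.dot's shape check entirely is to compute each output element as a row-wise dot product. Broadcast `load_kernel` (shape [1, N]) against `load_mat` (shape [BLOCK_SIZE, N]), multiply elementwise, and reduce with `tl.sum(load_mat * load_kernel, axis=1)`. That gives a [BLOCK_SIZE] result, which would then be stored through 1D pointers such as `A_out + pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)`. Because this cannot be checked here without a GPU, the block below is only a NumPy emulation of that per-program tile logic. The function name `matvec_by_reduction` and the Python loop standing in for the launch grid are illustrative assumptions, not Triton API.

```python
import numpy as np


def matvec_by_reduction(a_in, kernel_vec, block_size):
    """Emulate, one 'program' at a time, a kernel that computes a_in @ kernel_vec
    with an elementwise multiply and a row-wise sum instead of tl.dot."""
    m, n = a_in.shape
    assert m % block_size == 0, "sketch assumes M is a multiple of BLOCK_SIZE"
    out = np.empty(m, dtype=a_in.dtype)
    for pid in range(m // block_size):                # stands in for the launch grid
        rows = pid * block_size + np.arange(block_size)
        load_mat = a_in[rows, :]                      # [BLOCK_SIZE, N] tile
        load_kernel = kernel_vec.reshape(1, n)        # [1, N], like kernel + offsets_x[None, :]
        # broadcast multiply, then reduce over N (analogue of tl.sum(..., axis=1))
        out[rows] = np.sum(load_mat * load_kernel, axis=1)
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a = rng.standard_normal((64, 16)).astype(np.float32)
    k = rng.standard_normal((16, 1)).astype(np.float32)
    print(np.allclose(matvec_by_reduction(a, k, 16), (a @ k).ravel(), atol=1e-5))
```

The result here is 1D rather than a [BLOCK_SIZE, 1] column, which is why the store pointers would also need to be 1D. For a matrix-vector product this reduction form is usually acceptable, since the operation tends to be memory-bound rather than limited by the matrix-multiply units that tl.dot targets.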