
tl.dot() input dimensions #2266

@var-nan

Must every input dimension be at least 16 for a tl.dot() operation?

I'm trying to compute tl.dot(a, b) between a 2D matrix (shape [16, 16]) and a vector stored as a column block (shape [16, 1]), and I get this compilation error:
AssertionError('All values in both first input shape ([constexpr[16], constexpr[16]]) and second input shape ([constexpr[16], constexpr[1]]) must be >= 16!')
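The assertion checks every dimension of both operands, so the [16, 1] block fails only because of its trailing 1. As a minimal sketch, the hypothetical host-side helper below (not part of Triton; `check_dot_shapes` and `MIN_DOT_DIM` are illustrative names) mirrors that rule so a bad shape can be caught before the kernel is compiled. The limit of 16 is taken from the error text above and may differ between Triton versions.

```python
# Hypothetical host-side helper, NOT a Triton API: it mirrors the assertion
# message above, which requires every dimension of both tl.dot operands to be
# at least 16 in the Triton version that raised it.
MIN_DOT_DIM = 16  # assumption taken from the error text, not from Triton itself


def check_dot_shapes(a_shape, b_shape, min_dim=MIN_DOT_DIM):
    """Return a list of problems with the block shapes passed to tl.dot(a, b)."""
    problems = []
    if len(a_shape) != 2 or len(b_shape) != 2:
        problems.append("both operands must be 2D blocks")
        return problems
    # Inner dimensions must agree for a matrix product.
    if a_shape[1] != b_shape[0]:
        problems.append(f"inner dimensions differ: {a_shape[1]} vs {b_shape[0]}")
    # Every dimension of each operand must reach the minimum size.
    for name, shape in (("first", a_shape), ("second", b_shape)):
        small = [d for d in shape if d < min_dim]
        if small:
            problems.append(
                f"{name} input shape {list(shape)} has dims < {min_dim}: {small}"
            )
    return problems


print(check_dot_shapes((16, 16), (16, 1)))
# prints ['second input shape [16, 1] has dims < 16: [1]']
```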

I'm not sure whether I'm doing something wrong or Triton doesn't support tl.dot when one operand is a vector.
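One way to stay inside that constraint, sketched here as an assumption rather than an official recipe, is to zero-pad the [16, 1] operand to [16, 16], run the full product, and keep only column 0. In a kernel this would roughly mean loading the vector into a wider block with `tl.load(..., mask=..., other=0.0)` so the extra columns are zero. The pure-Python code below does not use Triton; it only emulates the arithmetic on nested lists to show why column 0 of the padded product equals the matrix-vector result. The helper names are hypothetical.

```python
# Pure-Python emulation (NOT Triton code) of the zero-padding workaround:
# widen the [K, 1] column vector to [K, 16] with zeros so every dimension
# meets the >= 16 requirement, multiply, then keep only column 0.
PAD_N = 16  # assumed minimum dimension, taken from the error message


def matmul(a, b):
    """Plain row-by-column matrix product of nested lists: [M, K] x [K, N]."""
    inner = len(b)
    cols = len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]


def pad_columns(b, width=PAD_N):
    """Append zero columns so every row of b has `width` entries."""
    return [row + [0.0] * (width - len(row)) for row in b]


def matvec_via_padding(a, v_col):
    """Compute a @ v_col (v_col shaped [K, 1]) through a padded [K, 16] product."""
    full = matmul(a, pad_columns(v_col))  # shape [M, 16]; columns 1..15 are zero
    return [[row[0]] for row in full]     # keep column 0 -> shape [M, 1]


a = [[float(i + j) for j in range(16)] for i in range(16)]
v = [[1.0] for _ in range(16)]
result = matvec_via_padding(a, v)  # row i holds 16*i + 120
```

The trade-off is that 15 of the 16 output columns are wasted work, which may or may not matter for a small kernel like this one.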

Below is the code for reference.

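import triton
import triton.language as tl


@triton.jit
def kernel_opt(A_in,
               A_out,
               kernel, kernel_size_sq: tl.constexpr,
               BLOCK_SIZE: tl.constexpr):

    """ A_in (shape: [M,N]), A_out (shape: [M',1]), kernel (shape: [N,1]) """

    pid = tl.program_id(0)

    # Each program handles BLOCK_SIZE rows of A_in; each row has kernel_size_sq entries.
    stride_x = kernel_size_sq
    offsets_x = tl.arange(0, kernel_size_sq)
    offsets_y = tl.arange(0, BLOCK_SIZE)[:, None] * stride_x

    # [BLOCK_SIZE, kernel_size_sq] block of pointers into A_in
    data_pointers = A_in + pid * kernel_size_sq * BLOCK_SIZE + offsets_y + offsets_x[None, :]
    # [1, kernel_size_sq] block of pointers into the kernel vector
    kernel_pointers = kernel + offsets_x[None, :]

    # load
    load_mat = tl.load(data_pointers)
    load_kernel = tl.load(kernel_pointers)

    # compute: [BLOCK_SIZE, K] x [K, 1]; the trailing dimension of 1 is what
    # triggers the ">= 16" assertion at compile time
    result = tl.dot(load_mat, tl.trans(load_kernel))  # result should be a column vector

    # store result: one output value per row of the block
    out_pointers = A_out + pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None]
    tl.store(out_pointers, result)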
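A common alternative that avoids tl.dot altogether, assuming load_kernel keeps its [1, kernel_size_sq] shape, is `result = tl.sum(load_mat * load_kernel, axis=1)`. Broadcasting turns the product into a [BLOCK_SIZE, kernel_size_sq] block and the reduction leaves one value per row, i.e. a 1D [BLOCK_SIZE] result, so the store would then use 1D offsets such as `A_out + pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)`. The sketch below does not run Triton; it is a pure-Python emulation on nested lists (helper names are hypothetical) showing that broadcast-multiply-then-sum gives the same numbers as a matrix-vector product.

```python
# Pure-Python emulation (NOT Triton code) of replacing
#     tl.dot(load_mat, tl.trans(load_kernel))
# with a broadcast multiply followed by a row-wise sum, i.e. in Triton terms
#     tl.sum(load_mat * load_kernel, axis=1)
# where load_mat is [BLOCK_SIZE, K] and load_kernel is [1, K].


def broadcast_mul(mat, row_vec):
    """Multiply each row of mat elementwise by the single row of row_vec ([1, K])."""
    (vec,) = row_vec  # enforce exactly one row, like a [1, K] block
    return [[x * w for x, w in zip(row, vec)] for row in mat]


def sum_axis1(mat):
    """Reduce over the last axis, like tl.sum(x, axis=1): [M, K] -> [M]."""
    return [sum(row) for row in mat]


def rowwise_dot(mat, row_vec):
    """Matrix-vector product computed as broadcast multiply + row-wise sum."""
    return sum_axis1(broadcast_mul(mat, row_vec))


mat = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
kernel_row = [[1.0, 0.0, -1.0]]
print(rowwise_dot(mat, kernel_row))  # prints [-2.0, -2.0]
```

Because no operand shape is constrained by the reduction, this form also works when kernel_size_sq or BLOCK_SIZE is below 16, at the cost of not using the matrix-multiply hardware path that tl.dot targets.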
