diff --git a/CMakeLists.txt b/CMakeLists.txt
index b48ad6cad751..a6cabb9feecd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,8 +3,15 @@ cmake_minimum_required(VERSION 3.13.4)
 include(CheckCXXSourceCompiles)
 set(POLYGEIST_ENABLE_CUDA 0 CACHE BOOL "Enable CUDA frontend and backend")
+set(POLYGEIST_ENABLE_CUDA_SYNTAX_ONLY 0 CACHE BOOL "Enable CUDA syntax parsing without requiring CUDA toolkit")
 set(POLYGEIST_ENABLE_ROCM 0 CACHE BOOL "Enable ROCM backend")
+# If CUDA_SYNTAX_ONLY is enabled, set flag for frontend-only mode
+if(POLYGEIST_ENABLE_CUDA_SYNTAX_ONLY)
+  set(POLYGEIST_CUDA_FRONTEND_ONLY 1)
+  message(STATUS "CUDA syntax-only mode enabled (no CUDA toolkit required)")
+endif()
+
 set(POLYGEIST_ENABLE_POLYMER 0 CACHE BOOL "Enable Polymer")
 set(POLYGEIST_POLYMER_ENABLE_ISL 0 CACHE BOOL "Enable Polymer isl")
 set(POLYGEIST_POLYMER_ENABLE_PLUTO 0 CACHE BOOL "Enable Polymer pluto")
diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h
index fd4ab057da3d..f48541eb3a60 100644
--- a/include/polygeist/Passes/Passes.h
+++ b/include/polygeist/Passes/Passes.h
@@ -79,6 +79,9 @@ createGpuSerializeToHsacoPass(StringRef arch, StringRef features,
 void registerGpuSerializeToCubinPass();
 void registerGpuSerializeToHsacoPass();
 
+std::unique_ptr<Pass> createConvertGPUToVortexPass();
+std::unique_ptr<Pass> createGenerateVortexMainPass();
+
 void populateForBreakToWhilePatterns(RewritePatternSet &patterns);
 } // namespace polygeist
 } // namespace mlir
diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td
index 70081f5502a0..8b6b714e48f1 100644
--- a/include/polygeist/Passes/Passes.td
+++ b/include/polygeist/Passes/Passes.td
@@ -300,4 +300,52 @@ def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp">
   ];
 }
 
+def ConvertGPUToVortex : Pass<"convert-gpu-to-vortex", "ModuleOp"> {
+  let summary = "Lower GPU dialect operations to Vortex RISC-V intrinsics";
+  let description = [{
+    This pass converts GPU dialect operations to LLVM dialect with Vortex-specific
+    intrinsics. It lowers index queries such as gpu.thread_id and gpu.block_id to
+    reads of the Vortex runtime's built-in threadIdx/blockIdx variables (through
+    generated accessor calls), and lowers gpu.barrier and shared-memory accesses
+    to vx_* runtime calls and RISC-V CSR reads via inline assembly, preparing the
+    code for the Vortex GPGPU backend.
+
+    Example:
+      %tid = gpu.thread_id x
+    becomes a call to the generated @vx_get_threadIdx accessor followed by a load
+    of the x field of the returned dim3 struct.
+
+    The pass also emits metadata files describing kernel argument layouts for
+    runtime argument marshaling: <kernel_name>.meta.json (for dynamic loading at
+    runtime) and <kernel_name>_args.h (a C header), written to the current
+    working directory.
+  }];
+  let constructor = "mlir::polygeist::createConvertGPUToVortexPass()";
+  let dependentDialects = [
+    "LLVM::LLVMDialect",
+    "gpu::GPUDialect",
+  ];
+}
+
+def GenerateVortexMain : Pass<"generate-vortex-main", "ModuleOp"> {
+  let summary = "Generate Vortex main() wrapper for kernel execution";
+  let description = [{
+    This pass generates the Vortex-specific main() entry point and kernel_body
+    wrapper function. It should run AFTER gpu-to-llvm lowering has converted
+    gpu.func to llvm.func.
+
+    The generated code consists of:
+    1. a main() function that:
+       - reads the kernel args pointer from VX_CSR_MSCRATCH (0x340) via inline
+         assembly
+       - calls vx_spawn_threads() with the grid dimensions and kernel callback
+    2. a kernel_body() wrapper that:
+       - takes a void* args pointer
+       - unpacks the individual arguments from the struct
+       - calls the original lowered kernel function
+
+    This matches the Vortex kernel execution model, where kernels are launched
+    via vx_spawn_threads() with a callback function.
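+
+    For illustration, the generated module corresponds roughly to the following
+    C code (a sketch only: the actual output is LLVM dialect, and the
+    vx_spawn_threads / vx_kernel_func_cb declarations come from the Vortex
+    kernel headers):
+
+      #include <stdint.h>
+
+      typedef void (*vx_kernel_func_cb)(void *arg);
+      extern int vx_spawn_threads(uint32_t dimension, const uint32_t *grid_dim,
+                                  const uint32_t *block_dim,
+                                  vx_kernel_func_cb kernel_func, const void *arg);
+
+      // Unpacks the argument buffer and calls the lowered kernel (elided here).
+      // Buffer layout: uint32_t grid_dim[3]; uint32_t block_dim[3]; user args...
+      static void kernel_body(void *args) {
+        uint8_t *base = (uint8_t *)args;
+        uint32_t block_dim_x = *(uint32_t *)(base + 12); // block_dim[0]
+        uint8_t *user_args = base + 24;                  // user argument buffer
+        // ...load each user argument at its offset and call the kernel...
+      }
+
+      int main() {
+        uint32_t args_addr;
+        asm volatile("csrr %0, 0x340" : "=r"(args_addr)); // VX_CSR_MSCRATCH
+        const uint32_t *grid_dim = (const uint32_t *)args_addr;  // offset 0
+        const uint32_t *block_dim = grid_dim + 3;                // offset 12
+        return vx_spawn_threads(/*dimension=*/1, grid_dim, block_dim,
+                                kernel_body, (const void *)args_addr);
+      }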
+ }]; + let constructor = "mlir::polygeist::createGenerateVortexMainPass()"; + let dependentDialects = [ + "LLVM::LLVMDialect", + ]; +} + #endif // POLYGEIST_PASSES diff --git a/lib/polygeist/ExecutionEngine/CMakeLists.txt b/lib/polygeist/ExecutionEngine/CMakeLists.txt index 3049f2fb3e54..b322448f5a2e 100644 --- a/lib/polygeist/ExecutionEngine/CMakeLists.txt +++ b/lib/polygeist/ExecutionEngine/CMakeLists.txt @@ -1,5 +1,11 @@ # TODO we do not support cross compilation currently +# Skip execution engine entirely if syntax-only mode +if(POLYGEIST_CUDA_FRONTEND_ONLY) + message(STATUS "Skipping CUDA execution engine (syntax-only mode)") + return() +endif() + if(POLYGEIST_ENABLE_CUDA) find_package(CUDA) enable_language(CUDA) diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt index c385d548a428..37ebd1b184b2 100644 --- a/lib/polygeist/Passes/CMakeLists.txt +++ b/lib/polygeist/Passes/CMakeLists.txt @@ -18,6 +18,8 @@ add_mlir_dialect_library(MLIRPolygeistTransforms InnerSerialization.cpp ForBreakToWhile.cpp ConvertParallelToGPU.cpp + ConvertGPUToVortex.cpp + GenerateVortexMain.cpp SerializeToCubin.cpp SerializeToHsaco.cpp ParallelLoopUnroll.cpp @@ -80,7 +82,8 @@ target_compile_definitions(obj.MLIRPolygeistTransforms POLYGEIST_PGO_DATA_DIR_ENV_VAR="${POLYGEIST_PGO_DATA_DIR_ENV_VAR}" ) -if(POLYGEIST_ENABLE_CUDA) +# Only require CUDA toolkit if full CUDA support (not syntax-only mode) +if(POLYGEIST_ENABLE_CUDA AND NOT POLYGEIST_CUDA_FRONTEND_ONLY) find_package(CUDA) enable_language(CUDA) diff --git a/lib/polygeist/Passes/ConvertGPUToVortex.cpp b/lib/polygeist/Passes/ConvertGPUToVortex.cpp new file mode 100644 index 000000000000..cf558ba8f5e1 --- /dev/null +++ b/lib/polygeist/Passes/ConvertGPUToVortex.cpp @@ -0,0 +1,1231 @@ +//===- ConvertGPUToVortex.cpp - Lower GPU dialect to Vortex intrinsics ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that lowers GPU dialect operations to LLVM +// dialect with Vortex-specific intrinsics (CSR reads, custom instructions). 
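+// Only the Vortex-specific operations (thread/block index queries, barriers,
+// shared-memory globals, printf) are rewritten here; the kernel bodies are then
+// extracted out of gpu.module into plain func.func so that standard MLIR
+// lowering passes can process them.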
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "polygeist/Passes/Passes.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" +#include +#include + +using namespace mlir; +using namespace mlir::gpu; + +namespace { + +//===----------------------------------------------------------------------===// +// Vortex CSR Addresses (from vx_intrinsics.h) +//===----------------------------------------------------------------------===// + +constexpr uint32_t VX_CSR_THREAD_ID = 0xCC0; +constexpr uint32_t VX_CSR_WARP_ID = 0xCC1; +constexpr uint32_t VX_CSR_CORE_ID = 0xCC2; +constexpr uint32_t VX_CSR_NUM_THREADS = 0xFC0; +constexpr uint32_t VX_CSR_NUM_WARPS = 0xFC1; +constexpr uint32_t VX_CSR_NUM_CORES = 0xFC2; +constexpr uint32_t VX_CSR_LOCAL_MEM_BASE = 0xFC3; + +//===----------------------------------------------------------------------===// +// Preprocessing: Consolidate Polygeist Alternatives +//===----------------------------------------------------------------------===// + +/// Extract base kernel name by removing Polygeist variant suffix +/// Example: _Z12launch_basicPiS_ji_kernel94565344022848 -> _Z12launch_basicPiS_ji +/// Example: __polygeist_launch_vecadd_kernel_kernel94... -> __polygeist_launch_vecadd_kernel +static StringRef extractBaseKernelName(StringRef mangledName) { + // Search from the end for "_kernel" followed by digits + // This handles cases like "vecadd_kernel_kernel94..." where the kernel name + // itself contains "_kernel" + size_t searchStart = 0; + size_t lastValidPos = StringRef::npos; + + while (true) { + size_t pos = mangledName.find("_kernel", searchStart); + if (pos == StringRef::npos) + break; + + size_t suffixStart = pos + 7; // Length of "_kernel" + if (suffixStart < mangledName.size() && + std::isdigit(mangledName[suffixStart])) { + // Found "_kernel" followed by digit - this is a potential suffix + lastValidPos = pos; + } + searchStart = pos + 1; + } + + if (lastValidPos != StringRef::npos) { + return mangledName.substr(0, lastValidPos); + } + return mangledName; +} + +/// Consolidate polygeist.alternatives to first variant only +/// This preprocessing step simplifies downstream processing by: +/// 1. Replacing polygeist.alternatives with content of first alternative +/// 2. 
Ensuring single canonical launch configuration for Vortex +static void consolidatePolygeistAlternatives(ModuleOp module) { + SmallVector altOps; + + // Collect all polygeist.alternatives operations + module.walk([&](Operation *op) { + if (op->getName().getStringRef() == "polygeist.alternatives") { + altOps.push_back(op); + } + }); + + // Replace each alternatives op with content of its first region + for (Operation *altOp : altOps) { + if (altOp->getNumRegions() == 0 || altOp->getRegion(0).empty()) + continue; + + OpBuilder builder(altOp); + Region &firstRegion = altOp->getRegion(0); + Block &firstBlock = firstRegion.front(); + + // Move all operations from first region to parent block (before alternatives op) + // This inlines the first alternative's content + auto &ops = firstBlock.getOperations(); + for (Operation &innerOp : llvm::make_early_inc_range(ops)) { + // Skip the terminator (polygeist.polygeist_yield) + if (innerOp.getName().getStringRef() == "polygeist.polygeist_yield") + continue; + innerOp.moveBefore(altOp); + } + + // Erase the now-empty alternatives operation + altOp->erase(); + } +} + +/// Remove duplicate GPU kernel functions, keeping only the first variant +/// After Polygeist auto-tuning, multiple kernel variants exist but only +/// the first one is referenced after consolidating alternatives. +static void removeDuplicateKernels(ModuleOp module) { + // Track seen kernel base names + llvm::StringMap seenKernels; + SmallVector toErase; + + // Walk all GPU modules + module.walk([&](gpu::GPUModuleOp gpuModule) { + // Collect all kernel functions + for (auto gpuFunc : gpuModule.getOps()) { + if (!gpuFunc.isKernel()) + continue; + + StringRef funcName = gpuFunc.getName(); + StringRef baseName = extractBaseKernelName(funcName); + + // Check if we've seen this kernel base name before + auto it = seenKernels.find(baseName); + if (it != seenKernels.end()) { + // Duplicate found - mark for deletion + toErase.push_back(gpuFunc); + } else { + // First occurrence - keep it + seenKernels[baseName] = gpuFunc; + } + } + }); + + // Erase duplicate kernels + for (auto func : toErase) { + func.erase(); + } +} + +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// + +/// Declare an external function to access TLS dim3_t variables +/// For thread-local variables like blockIdx/threadIdx, we generate helper +/// functions that return pointers to the TLS variables +/// Returns an LLVM function declaration +/// The function is declared within the gpu.module where it's being used +static LLVM::LLVMFuncOp getOrCreateDim3TLSAccessor(Operation *op, + OpBuilder &builder, + StringRef varName) { + // Find the gpu.module containing this operation + auto gpuModule = op->getParentOfType(); + MLIRContext *context = gpuModule.getContext(); + + // Create function name: e.g., "vx_get_blockIdx" + std::string funcName = ("vx_get_" + varName).str(); + + // Check if function already exists in gpu.module + if (auto existing = gpuModule.lookupSymbol(funcName)) { + return existing; + } + + // Create function type: () -> !llvm.ptr (returns pointer to dim3_t) + auto ptrType = LLVM::LLVMPointerType::get(context); + auto funcType = LLVM::LLVMFunctionType::get(ptrType, {}, /*isVarArg=*/false); + + // Declare external function within gpu.module + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(gpuModule.getBody()); + + return builder.create( + 
gpuModule.getLoc(), + funcName, + funcType, + LLVM::Linkage::External); +} + +/// Access a field of a TLS dim3_t variable (threadIdx or blockIdx) +/// dimension: gpu::Dimension::x (0), y (1), or z (2) +static Value createDim3TLSAccess(Operation *op, + ConversionPatternRewriter &rewriter, + Location loc, + StringRef varName, + gpu::Dimension dimension) { + auto module = op->getParentOfType(); + MLIRContext *context = module.getContext(); + + // Get or create the TLS accessor function + auto accessorFunc = getOrCreateDim3TLSAccessor(op, rewriter, varName); + + // Call the accessor function to get pointer to TLS variable + auto ptrType = LLVM::LLVMPointerType::get(context); + auto callResult = rewriter.create( + loc, accessorFunc, ValueRange{}); + Value dim3Ptr = callResult.getResult(); + + // Create GEP to access the specific field (x=0, y=1, z=2) + auto i32Type = rewriter.getI32Type(); + auto dim3Type = LLVM::LLVMStructType::getLiteral( + context, {i32Type, i32Type, i32Type}); + + // GEP indices: [0, dimension] + // First 0 is to dereference the pointer + // Second index selects the struct field + SmallVector indices; + indices.push_back(0); // Base index + indices.push_back(static_cast(dimension)); // Field index (0=x, 1=y, 2=z) + + auto gep = rewriter.create( + loc, ptrType, dim3Type, dim3Ptr, indices); + + // Load the value from the computed address + auto result = rewriter.create(loc, i32Type, gep); + + return result.getResult(); +} + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +/// Lower gpu.thread_id to TLS variable access +/// Accesses the threadIdx TLS variable set by vx_spawn_threads() +struct ThreadIdOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(ThreadIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // Get the dimension (X, Y, or Z) + auto dimension = op.getDimension(); + + // Access threadIdx.{x,y,z} from TLS + auto result = createDim3TLSAccess(op, rewriter, loc, + "threadIdx", dimension); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +/// Lower gpu.block_id to TLS variable access +/// Accesses the blockIdx TLS variable set by vx_spawn_threads() +struct BlockIdOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(BlockIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // Get the dimension (X, Y, or Z) + auto dimension = op.getDimension(); + + // Access blockIdx.{x,y,z} from TLS + auto result = createDim3TLSAccess(op, rewriter, loc, + "blockIdx", dimension); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +/// Lower gpu.block_dim to TLS variable access +/// Accesses the blockDim global variable set by vx_spawn_threads() +struct BlockDimOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // Get the dimension (X, Y, or Z) + auto dimension = op.getDimension(); + + // Access blockDim.{x,y,z} from global variable + // Note: blockDim is NOT thread-local, it's a regular global + auto result = 
createDim3TLSAccess(op, rewriter, loc, + "blockDim", dimension); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +/// Lower gpu.grid_dim to TLS variable access +/// Accesses the gridDim global variable set by vx_spawn_threads() +struct GridDimOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::GridDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // Get the dimension (X, Y, or Z) + auto dimension = op.getDimension(); + + // Access gridDim.{x,y,z} from global variable + // Note: gridDim is NOT thread-local, it's a regular global + auto result = createDim3TLSAccess(op, rewriter, loc, + "gridDim", dimension); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +/// Lower gpu.barrier to Vortex vx_barrier call +/// Synchronizes all threads in a block using Vortex hardware barriers +struct BarrierOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + // Declare functions in gpu.module (not top-level module) so they're visible + auto gpuModule = op->getParentOfType(); + if (!gpuModule) + return failure(); + MLIRContext *context = gpuModule.getContext(); + + // Allocate barrier ID (simple counter for now) + // TODO: Proper barrier ID allocation to avoid conflicts + static int barrierIdCounter = 0; + int barrierId = barrierIdCounter++; + + // Create barrier ID constant + auto i32Type = rewriter.getI32Type(); + auto barIdConstant = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(barrierId)); + + // Declare vx_num_warps_abi function in gpu.module if not already declared + // Using _abi suffix to call the non-inline wrapper in vx_intrinsics_abi.c + auto vxNumWarpsFunc = gpuModule.lookupSymbol("vx_num_warps_abi"); + if (!vxNumWarpsFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(gpuModule.getBody()); + + auto funcType = LLVM::LLVMFunctionType::get( + i32Type, {}, /*isVarArg=*/false); + + vxNumWarpsFunc = rewriter.create( + gpuModule.getLoc(), "vx_num_warps_abi", funcType); + } + + // Call vx_num_warps_abi() to get number of warps + auto numWarps = rewriter.create( + loc, vxNumWarpsFunc, ValueRange{}); + + // Declare vx_barrier_abi function in gpu.module if not already declared + // Using _abi suffix to call the non-inline wrapper in vx_intrinsics_abi.c + auto vxBarrierFunc = gpuModule.lookupSymbol("vx_barrier_abi"); + if (!vxBarrierFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(gpuModule.getBody()); + + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(context), + {i32Type, i32Type}, + /*isVarArg=*/false); + + vxBarrierFunc = rewriter.create( + gpuModule.getLoc(), "vx_barrier_abi", funcType); + } + + // Call vx_barrier_abi(barrier_id, num_warps) + SmallVector args; + args.push_back(barIdConstant.getResult()); + args.push_back(numWarps.getResult()); + + rewriter.replaceOpWithNewOp( + op, vxBarrierFunc, args); + + return success(); + } +}; + +/// Lower printf calls to vx_printf +/// Matches: llvm.call @printf(format, args...) +/// Replaces with: llvm.call @vx_printf(format, args...) 
+/// vx_printf has the same signature as standard printf +struct PrintfOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp callOp, + PatternRewriter &rewriter) const override { + // Only match calls to 'printf' + auto callee = callOp.getCalleeAttr(); + if (!callee) + return failure(); + + auto flatSymbolRef = callee.dyn_cast(); + if (!flatSymbolRef || flatSymbolRef.getValue() != "printf") + return failure(); + + // Only lower printf calls inside GPU modules + auto gpuModule = callOp->getParentOfType(); + if (!gpuModule) + return failure(); + + Location loc = callOp.getLoc(); + MLIRContext *context = gpuModule.getContext(); + auto i32Type = rewriter.getI32Type(); + + // Declare vx_printf function in gpu.module if not already declared + auto vxPrintfFunc = gpuModule.lookupSymbol("vx_printf"); + if (!vxPrintfFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(gpuModule.getBody()); + + auto ptrType = LLVM::LLVMPointerType::get(context); + auto funcType = LLVM::LLVMFunctionType::get(i32Type, {ptrType}, /*isVarArg=*/true); + vxPrintfFunc = rewriter.create( + gpuModule.getLoc(), "vx_printf", funcType); + } + + // Build argument list: pass all original arguments unchanged + SmallVector newArgs; + for (unsigned i = 0; i < callOp.getNumOperands(); ++i) { + newArgs.push_back(callOp.getOperand(i)); + } + + // Replace with call to vx_printf + rewriter.replaceOpWithNewOp( + callOp, vxPrintfFunc, newArgs); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Lowering (Address Space 3) +//===----------------------------------------------------------------------===// + +/// Track shared memory allocations for computing offsets +/// Maps global name to (offset, size) within the shared memory region +static llvm::StringMap> sharedMemoryLayout; +static unsigned totalSharedMemorySize = 0; + +/// Get or declare the __local_group_id TLS variable accessor +/// Returns a function that provides access to the per-warp group ID +static LLVM::LLVMFuncOp getOrCreateLocalGroupIdAccessor(Operation *op, + OpBuilder &builder) { + // Find the parent module (works for both gpu.module and regular module) + Operation *symbolTableOp = op->getParentOfType(); + if (!symbolTableOp) + symbolTableOp = op->getParentOfType(); + if (!symbolTableOp) + return nullptr; + + std::string funcName = "vx_get_local_group_id"; + + // Check if function already exists + if (auto existing = SymbolTable::lookupSymbolIn(symbolTableOp, + builder.getStringAttr(funcName))) { + return cast(existing); + } + + // Create function type: () -> i32 (returns the local group ID) + auto i32Type = builder.getI32Type(); + auto funcType = LLVM::LLVMFunctionType::get(i32Type, {}, /*isVarArg=*/false); + + // Declare external function + OpBuilder::InsertionGuard guard(builder); + if (auto gpuModule = dyn_cast(symbolTableOp)) { + builder.setInsertionPointToStart(gpuModule.getBody()); + } else if (auto module = dyn_cast(symbolTableOp)) { + builder.setInsertionPointToStart(module.getBody()); + } + + return builder.create( + symbolTableOp->getLoc(), + funcName, + funcType, + LLVM::Linkage::External); +} + +/// Lower memref.global with address space 3 (shared memory) +/// These become placeholders - the actual allocation is done by vx_spawn_threads +/// We record the size and assign offsets for memref.get_global to use +struct SharedMemoryGlobalOpLowering : public OpRewritePattern { 
+ using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::GlobalOp globalOp, + PatternRewriter &rewriter) const override { + // Only handle address space 3 (shared memory) + auto memrefType = globalOp.getType(); + if (memrefType.getMemorySpaceAsInt() != 3) + return failure(); + + // Skip if already processed + if (globalOp->hasAttr("vortex.shared_memory_offset")) + return failure(); + + // Calculate the size of this shared memory allocation + unsigned elementSize = 4; // Default to 4 bytes + Type elemType = memrefType.getElementType(); + if (elemType.isF32() || elemType.isInteger(32)) + elementSize = 4; + else if (elemType.isF64() || elemType.isInteger(64)) + elementSize = 8; + else if (elemType.isInteger(8)) + elementSize = 1; + else if (elemType.isInteger(16) || elemType.isF16()) + elementSize = 2; + + unsigned numElements = 1; + for (int64_t dim : memrefType.getShape()) { + if (dim == ShapedType::kDynamic) { + // Dynamic shared memory - can't compute static offset + globalOp.emitWarning("Dynamic shared memory size not supported"); + return failure(); + } + numElements *= dim; + } + unsigned totalBytes = numElements * elementSize; + + // Assign offset in shared memory layout + unsigned offset = totalSharedMemorySize; + sharedMemoryLayout[globalOp.getSymName()] = {offset, totalBytes}; + totalSharedMemorySize += totalBytes; + + // Mark as processed with offset attribute + rewriter.startRootUpdate(globalOp); + globalOp->setAttr("vortex.shared_memory_offset", + rewriter.getI32IntegerAttr(offset)); + globalOp->setAttr("vortex.shared_memory_size", + rewriter.getI32IntegerAttr(totalBytes)); + rewriter.finalizeRootUpdate(globalOp); + + return success(); + } +}; + +/// Lower memref.get_global with address space 3 to Vortex local memory access +/// Generates: (int8_t*)csr_read(VX_CSR_LOCAL_MEM_BASE) + __local_group_id * total_size + offset +/// Returns a proper memref descriptor for use with memref load/store operations +struct SharedMemoryGetGlobalOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::GetGlobalOp getGlobalOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only handle address space 3 (shared memory) + auto memrefType = getGlobalOp.getType(); + if (memrefType.getMemorySpaceAsInt() != 3) + return failure(); + + // Only handle static shapes (required for MemRefDescriptor::fromStaticShape) + if (!memrefType.hasStaticShape()) { + getGlobalOp.emitError("Dynamic shared memory shapes not supported"); + return failure(); + } + + Location loc = getGlobalOp.getLoc(); + MLIRContext *context = getGlobalOp.getContext(); + + // Look up the offset for this global + StringRef globalName = getGlobalOp.getName(); + auto it = sharedMemoryLayout.find(globalName); + if (it == sharedMemoryLayout.end()) { + // Not found - this can happen if GlobalOp lowering hasn't run yet + // Try to find the GlobalOp and compute offset + auto module = getGlobalOp->getParentOfType(); + auto gpuModule = getGlobalOp->getParentOfType(); + Operation *symbolTable = gpuModule ? 
(Operation*)gpuModule : (Operation*)module; + + if (auto globalOp = SymbolTable::lookupSymbolIn(symbolTable, + getGlobalOp.getNameAttr())) { + if (auto memGlobalOp = dyn_cast(globalOp)) { + if (auto offsetAttr = memGlobalOp->getAttrOfType( + "vortex.shared_memory_offset")) { + unsigned offset = offsetAttr.getInt(); + unsigned size = 0; + if (auto sizeAttr = memGlobalOp->getAttrOfType( + "vortex.shared_memory_size")) { + size = sizeAttr.getInt(); + } + sharedMemoryLayout[globalName] = {offset, size}; + it = sharedMemoryLayout.find(globalName); + } + } + } + + if (it == sharedMemoryLayout.end()) { + getGlobalOp.emitError("Shared memory global not found in layout: ") + << globalName; + return failure(); + } + } + + unsigned offset = it->second.first; + auto i32Type = rewriter.getI32Type(); + + // Get pointer type with address space 3 for shared memory + unsigned addressSpace = memrefType.getMemorySpaceAsInt(); + Type elementType = getTypeConverter()->convertType(memrefType.getElementType()); + auto ptrType = getTypeConverter()->getPointerType(elementType, addressSpace); + + // Generate CSR read for local memory base + // csrr %0, 0xFC3 (VX_CSR_LOCAL_MEM_BASE) + std::string csrAsmStr = "csrr $0, " + std::to_string(VX_CSR_LOCAL_MEM_BASE); + + auto csrRead = rewriter.create( + loc, + i32Type, // result type (single i32, not a struct) + ValueRange{}, // operands + csrAsmStr, // asm string + "=r", // constraints: output to register + /*has_side_effects=*/false, + /*is_align_stack=*/false, + /*asm_dialect=*/nullptr, + /*operand_attrs=*/nullptr); + + // The CSR read returns the base address directly + Value baseAddr = csrRead.getResult(0); + + // Get __local_group_id via external function call + auto localGroupIdFunc = getOrCreateLocalGroupIdAccessor( + getGlobalOp, rewriter); + if (!localGroupIdFunc) + return failure(); + + auto localGroupIdCall = rewriter.create( + loc, localGroupIdFunc, ValueRange{}); + Value localGroupId = localGroupIdCall.getResult(); + + // Calculate final address: + // base + localGroupId * totalSharedMemorySize + offset + // + // Note: For now, we use totalSharedMemorySize computed so far. + // A more robust approach would compute this in a second pass. 
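+    // Hypothetical worked example: with two shared globals of 256 and 128
+    // bytes, totalSharedMemorySize is 384; for the 128-byte global (offset 256)
+    // and __local_group_id == 2 the result is base + 2*384 + 256 = base + 1024,
+    // where base is the value read from VX_CSR_LOCAL_MEM_BASE above.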
+ Value totalSizeVal = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(totalSharedMemorySize)); + Value groupOffset = rewriter.create( + loc, i32Type, localGroupId, totalSizeVal); + Value baseWithGroup = rewriter.create( + loc, i32Type, baseAddr, groupOffset); + + // Add the static offset for this specific global + Value offsetVal = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(offset)); + Value finalAddr = rewriter.create( + loc, i32Type, baseWithGroup, offsetVal); + + // Convert to pointer + Value ptr = rewriter.create(loc, ptrType, finalAddr); + + // Create a memref descriptor from the computed pointer + // This creates a proper LLVM struct with allocated_ptr, aligned_ptr, offset, sizes, strides + Value descr = MemRefDescriptor::fromStaticShape( + rewriter, loc, *getTypeConverter(), memrefType, ptr); + + rewriter.replaceOp(getGlobalOp, descr); + + return success(); + } +}; + +/// Extract metadata from gpu.launch_func for Vortex kernel argument struct +/// For RV32, all arguments are 4 bytes (scalars and pointers) +struct LaunchFuncMetadataExtraction : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(gpu::LaunchFuncOp launchOp, + PatternRewriter &rewriter) const override { + // Skip if metadata already exists (avoid infinite loop in greedy rewriter) + if (launchOp->hasAttr("vortex.kernel_metadata")) + return failure(); + + Location loc = launchOp.getLoc(); + + // Get kernel name + StringRef kernelName = launchOp.getKernelName().getValue(); + + // Get kernel arguments + auto kernelOperands = launchOp.getKernelOperands(); + unsigned numArgs = kernelOperands.size(); + + // For RV32: all arguments are 4 bytes (scalars and pointers) + // Calculate total struct size: numArgs * 4 + unsigned totalSize = numArgs * 4; + + // Build metadata string for debugging/documentation + std::string metadataStr = "Kernel: " + kernelName.str() + + "\nNum args: " + std::to_string(numArgs) + + "\nTotal size (RV32): " + std::to_string(totalSize) + " bytes\nArguments:\n"; + + unsigned offset = 0; + for (auto [idx, arg] : llvm::enumerate(kernelOperands)) { + Type argType = arg.getType(); + bool isPointer = argType.isa(); + + metadataStr += " [" + std::to_string(idx) + "] offset=" + std::to_string(offset) + + ", size=4, type=" + (isPointer ? "pointer" : "scalar") + "\n"; + offset += 4; + } + + // Emit metadata as a comment for now (can be enhanced to create LLVM metadata) + rewriter.startRootUpdate(launchOp); + launchOp->setAttr("vortex.kernel_metadata", + rewriter.getStringAttr(metadataStr)); + rewriter.finalizeRootUpdate(launchOp); + + // Note: We don't replace the op, just annotate it with metadata + // The actual launch lowering will be handled separately + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Kernel Metadata JSON Emission +//===----------------------------------------------------------------------===// + +/// Structure to hold kernel argument metadata +struct KernelArgInfo { + std::string name; + std::string type; // "ptr", "i32", "u32", "f32", "f64", etc. 
+ unsigned size; // Size in bytes + unsigned offset; // Offset in args struct + bool isPointer; +}; + +/// Structure to hold complete kernel metadata +struct KernelMetadata { + std::string kernelName; + std::vector arguments; + unsigned totalArgsSize; +}; + +/// Convert MLIR type to metadata type string +static std::string getMetadataTypeString(Type type) { + if (type.isa() || type.isa()) + return "ptr"; + if (type.isInteger(32)) + return "i32"; + if (type.isInteger(64)) + return "i64"; + if (type.isF32()) + return "f32"; + if (type.isF64()) + return "f64"; + if (type.isIndex()) + return "i64"; // Index maps to i64 by default in LLVM lowering + return "unknown"; +} + +/// Get size in bytes for a type on RV32 +static unsigned getTypeSizeRV32(Type type) { + // On RV32 Vortex, pointers are 4 bytes + if (type.isa() || type.isa()) + return 4; + if (type.isInteger(32) || type.isF32()) + return 4; + if (type.isInteger(64) || type.isF64() || type.isIndex()) + return 8; // Index maps to i64 in LLVM lowering + return 4; // Default +} + +/// Convert metadata type string to C type +static std::string getCTypeString(const std::string &metaType) { + if (metaType == "ptr") return "uint32_t"; // RV32 pointer = 32-bit device address + if (metaType == "i32") return "int32_t"; + if (metaType == "u32") return "uint32_t"; + if (metaType == "i64") return "int64_t"; + if (metaType == "u64") return "uint64_t"; + if (metaType == "f32") return "float"; + if (metaType == "f64") return "double"; + return "uint32_t"; // Default +} + +/// Generate C header string for kernel args struct (Vortex-compatible) +static std::string generateKernelArgsHeader(const KernelMetadata &meta) { + std::ostringstream header; + + // Generate include guard + std::string guardName = meta.kernelName; + std::transform(guardName.begin(), guardName.end(), guardName.begin(), ::toupper); + std::replace(guardName.begin(), guardName.end(), '-', '_'); + + header << "// Auto-generated kernel argument structure for " << meta.kernelName << "\n"; + header << "// Generated by Polygeist ConvertGPUToVortex pass\n"; + header << "#ifndef " << guardName << "_ARGS_H\n"; + header << "#define " << guardName << "_ARGS_H\n\n"; + header << "#include \n\n"; + + header << "typedef struct {\n"; + for (const auto &arg : meta.arguments) { + std::string cType = getCTypeString(arg.type); + header << " " << cType << " " << arg.name << ";"; + header << " // offset=" << arg.offset << ", size=" << arg.size; + if (arg.isPointer) header << ", device pointer"; + header << "\n"; + } + header << "} " << meta.kernelName << "_args_t;\n\n"; + + header << "#define " << guardName << "_ARGS_SIZE " << meta.totalArgsSize << "\n\n"; + header << "#endif // " << guardName << "_ARGS_H\n"; + + return header.str(); +} + +/// Generate JSON string for kernel metadata (for runtime dynamic loading) +static std::string generateMetadataJSON(const KernelMetadata &meta, + const std::vector &originalOrder = {}) { + std::ostringstream json; + json << "{\n"; + json << " \"kernel_name\": \"" << meta.kernelName << "\",\n"; + json << " \"arguments\": [\n"; + + for (size_t i = 0; i < meta.arguments.size(); ++i) { + const auto &arg = meta.arguments[i]; + json << " {\n"; + json << " \"name\": \"" << arg.name << "\",\n"; + json << " \"type\": \"" << arg.type << "\",\n"; + json << " \"size\": " << arg.size << ",\n"; + json << " \"offset\": " << arg.offset << ",\n"; + json << " \"is_pointer\": " << (arg.isPointer ? 
"true" : "false") << "\n"; + json << " }"; + if (i < meta.arguments.size() - 1) + json << ","; + json << "\n"; + } + + json << " ],\n"; + json << " \"total_args_size\": " << meta.totalArgsSize << ",\n"; + + // Include original argument order mapping if available + // This maps from original (hipLaunchKernelGGL) order to device order + if (!originalOrder.empty()) { + json << " \"original_arg_order\": ["; + for (size_t i = 0; i < originalOrder.size(); ++i) { + json << originalOrder[i]; + if (i < originalOrder.size() - 1) + json << ", "; + } + json << "],\n"; + } + + json << " \"architecture\": \"rv32\"\n"; + json << "}\n"; + + return json.str(); +} + +/// Extract metadata from a GPU function and write metadata files +/// Generates both .meta.json (for runtime) and _args.h (for compile-time) +/// If outputDir is empty, uses current working directory +/// Uses pre-built originalArgIsPointer map for computing argument order mapping +static void emitKernelMetadata(gpu::GPUFuncOp funcOp, + StringRef outputDir, + const llvm::StringMap> &originalArgIsPointer) { + if (!funcOp.isKernel()) + return; + + KernelMetadata meta; + meta.kernelName = funcOp.getName().str(); + + // Extract base kernel name (remove Polygeist suffix if present) + StringRef baseName = extractBaseKernelName(funcOp.getName()); + meta.kernelName = baseName.str(); + + // Count leading scalar args before any memrefs/pointers + // These are typically derived from block_dim (e.g., block_dim.x) + // and should be skipped in user arg metadata since kernel_body + // derives them from the block_dim header, not user args. + auto argTypes = funcOp.getArgumentTypes(); + unsigned numLeadingScalars = 0; + for (auto argType : argTypes) { + if (argType.isa() || argType.isa()) { + break; + } + ++numLeadingScalars; + } + + // Skip first 2 leading scalars (derived from block_dim[0]) + // These are: arg0 = block_dim.x as index, arg1 = block_dim.x as i32 + unsigned argsToSkip = (numLeadingScalars >= 2) ? 
2 : 0; + + unsigned offset = 0; + unsigned argIndex = 0; + + for (auto argType : argTypes) { + if (argIndex < argsToSkip) { + // Skip this arg - it comes from block_dim header, not user args + argIndex++; + continue; + } + + KernelArgInfo argInfo; + argInfo.name = "arg" + std::to_string(argIndex); + argInfo.type = getMetadataTypeString(argType); + argInfo.size = getTypeSizeRV32(argType); + argInfo.offset = offset; + argInfo.isPointer = argType.isa() || + argType.isa(); + + meta.arguments.push_back(argInfo); + offset += argInfo.size; + argIndex++; + } + + meta.totalArgsSize = offset; + + // Look up pre-computed original argument types from host wrapper + // Base name should match the host wrapper function name + std::vector originalOrder; + + auto it = originalArgIsPointer.find(baseName); + if (it != originalArgIsPointer.end()) { + const std::vector &hostIsPointer = it->second; + + if (hostIsPointer.size() == meta.arguments.size()) { + // Build mapping from original order to device order + // Device order: scalars first, then pointers (preserving relative order) + // Original order: as declared in kernel signature + + // Count scalars in host (original) order + unsigned numScalars = 0; + for (bool isPtr : hostIsPointer) { + if (!isPtr) numScalars++; + } + + // Build the mapping: original_arg_order[device_idx] = original_idx + originalOrder.resize(hostIsPointer.size()); + unsigned deviceScalarIdx = 0; + unsigned devicePtrIdx = numScalars; + + for (unsigned origIdx = 0; origIdx < hostIsPointer.size(); ++origIdx) { + if (!hostIsPointer[origIdx]) { + // Scalar - goes to front of device args + originalOrder[deviceScalarIdx++] = origIdx; + } else { + // Pointer - goes to back of device args + originalOrder[devicePtrIdx++] = origIdx; + } + } + } + } + + // Determine output directory + SmallString<256> outDir; + if (outputDir.empty()) { + llvm::sys::fs::current_path(outDir); + } else { + outDir = outputDir; + } + + // Write JSON metadata file (with original order mapping if available) + { + SmallString<256> jsonPath(outDir); + llvm::sys::path::append(jsonPath, meta.kernelName + ".meta.json"); + + std::error_code ec; + llvm::raw_fd_ostream outFile(jsonPath, ec); + if (ec) { + llvm::errs() << "Error writing metadata file " << jsonPath << ": " + << ec.message() << "\n"; + } else { + outFile << generateMetadataJSON(meta, originalOrder); + outFile.close(); + llvm::outs() << "Wrote kernel metadata: " << jsonPath << "\n"; + } + } + + // Write C header file + { + SmallString<256> headerPath(outDir); + llvm::sys::path::append(headerPath, meta.kernelName + "_args.h"); + + std::error_code ec; + llvm::raw_fd_ostream outFile(headerPath, ec); + if (ec) { + llvm::errs() << "Error writing header file " << headerPath << ": " + << ec.message() << "\n"; + } else { + outFile << generateKernelArgsHeader(meta); + outFile.close(); + llvm::outs() << "Wrote kernel args header: " << headerPath << "\n"; + } + } +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +// Use the tablegen-generated base class which handles the pass options correctly +#define GEN_PASS_DECL_CONVERTGPUTOVORTEX +#define GEN_PASS_DEF_CONVERTGPUTOVORTEX +#include "polygeist/Passes/Passes.h.inc" + +struct ConvertGPUToVortexPass + : public impl::ConvertGPUToVortexBase { + + ConvertGPUToVortexPass() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp module = 
getOperation(); + + // FIRST: Build argument order map from host wrapper functions BEFORE any changes + // This maps kernel base name -> list of (isPointer, type) for original args + llvm::StringMap> originalArgIsPointer; + + // Find host wrapper functions (func.func @__polygeist_launch_) + for (auto funcOp : module.getOps()) { + StringRef funcName = funcOp.getName(); + if (!funcName.startswith("__polygeist_launch_")) + continue; + + // Host wrapper args: user args... + blocks + threads (last 2 are launch params) + auto hostArgTypes = funcOp.getArgumentTypes(); + unsigned numHostUserArgs = hostArgTypes.size() > 2 ? hostArgTypes.size() - 2 : 0; + + std::vector isPointerVec; + for (unsigned i = 0; i < numHostUserArgs; ++i) { + isPointerVec.push_back(hostArgTypes[i].isa() || + hostArgTypes[i].isa()); + } + originalArgIsPointer[funcName] = std::move(isPointerVec); + } + + // PREPROCESSING: Consolidate Polygeist auto-tuning artifacts + // This must happen before any conversion patterns are applied + consolidatePolygeistAlternatives(module); + removeDuplicateKernels(module); + + // Always emit kernel metadata for each kernel + // Files are written to current working directory: + // - .meta.json (for runtime dynamic loading) + // - _args.h (for compile-time type-safe usage) + // Pass pre-built argument order map for original argument positions + module.walk([&](gpu::GPUModuleOp gpuModule) { + for (auto gpuFunc : gpuModule.getOps()) { + if (gpuFunc.isKernel()) { + emitKernelMetadata(gpuFunc, "" /* use current directory */, originalArgIsPointer); + } + } + }); + + // Set up type converter for GPU to LLVM types + LLVMTypeConverter typeConverter(context); + + // Set up conversion target + // Mark only the Vortex-specific GPU operations as illegal + // All other operations (including GPU structural ops) remain legal + // A subsequent --gpu-to-llvm pass will handle gpu.module/gpu.func conversion + ConversionTarget target(*context); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + target.addIllegalOp(); + + // Set up rewrite patterns + RewritePatternSet patterns(context); + patterns.add(typeConverter); + + // Apply the conversion + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } + + // Apply metadata extraction, printf lowering, and shared memory global annotation + // as separate greedy rewrites (these patterns don't replace ops, just annotate) + RewritePatternSet metadataPatterns(context); + metadataPatterns.add(context); + if (failed(applyPatternsAndFoldGreedily(module, std::move(metadataPatterns)))) { + signalPassFailure(); + } + + // Lower memref.get_global for address space 3 (shared memory) to Vortex intrinsics + // This must run after SharedMemoryGlobalOpLowering has annotated the globals + { + ConversionTarget sharedMemTarget(*context); + sharedMemTarget.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + sharedMemTarget.addDynamicallyLegalOp( + [](memref::GetGlobalOp op) { + // Only make shared memory (address space 3) get_global ops illegal + return op.getType().getMemorySpaceAsInt() != 3; + }); + + RewritePatternSet sharedMemPatterns(context); + sharedMemPatterns.add(typeConverter); + + if (failed(applyPartialConversion(module, sharedMemTarget, + std::move(sharedMemPatterns)))) { + signalPassFailure(); + } + } + + // Remove gpu.launch_func operations - they were needed for Polygeist + // to generate proper MLIR but are not needed for Vortex kernel compilation. 
+ // The host code handles kernel launching through the Vortex runtime separately. + SmallVector launchOps; + module.walk([&](gpu::LaunchFuncOp launchOp) { + launchOps.push_back(launchOp); + }); + for (auto launchOp : launchOps) { + launchOp.erase(); + } + + // Remove host-side functions (those outside gpu.module) + // Keep only the kernel code inside gpu.module for kernel binary compilation + SmallVector hostFuncs; + for (auto funcOp : module.getOps()) { + hostFuncs.push_back(funcOp); + } + for (auto funcOp : hostFuncs) { + funcOp.erase(); + } + + // Extract kernel functions from gpu.module and convert to func.func + // This allows standard MLIR lowering passes to work on the kernel code + OpBuilder builder(context); + SmallVector gpuModulesToErase; + + module.walk([&](gpu::GPUModuleOp gpuModule) { + // Clone kernel functions as func.func at module level + for (auto gpuFunc : gpuModule.getOps()) { + // Create func.func with same name and type + builder.setInsertionPointToEnd(module.getBody()); + + auto funcOp = builder.create( + gpuFunc.getLoc(), + gpuFunc.getName(), + gpuFunc.getFunctionType()); + + // Don't copy GPU-specific attributes - they're not relevant for Vortex + // Skipped attributes: gpu.kernel, gpu.known_block_size, nvvm.*, rocdl.* + // The kernel will use Vortex runtime conventions instead + + // Clone the function body + IRMapping mapping; + gpuFunc.getBody().cloneInto(&funcOp.getBody(), mapping); + + // Replace gpu.return with func.return in the cloned body + funcOp.walk([&](gpu::ReturnOp returnOp) { + OpBuilder returnBuilder(returnOp); + returnBuilder.create(returnOp.getLoc(), + returnOp.getOperands()); + returnOp.erase(); + }); + + // Replace unrealized_conversion_cast (i32 -> index) with arith.index_cast + // These come from Polygeist's type conversions and can't be reconciled as-is + SmallVector castsToReplace; + funcOp.walk([&](UnrealizedConversionCastOp castOp) { + // Only replace i32 -> index casts + if (castOp.getNumOperands() == 1 && castOp.getNumResults() == 1) { + auto srcType = castOp.getOperand(0).getType(); + auto dstType = castOp.getResult(0).getType(); + if (srcType.isInteger(32) && dstType.isIndex()) { + castsToReplace.push_back(castOp); + } + } + }); + for (auto castOp : castsToReplace) { + OpBuilder castBuilder(castOp); + auto indexCast = castBuilder.create( + castOp.getLoc(), castOp.getResult(0).getType(), castOp.getOperand(0)); + castOp.getResult(0).replaceAllUsesWith(indexCast.getResult()); + castOp.erase(); + } + } + + // Also clone any llvm.func declarations (like vx_get_threadIdx) + for (auto llvmFunc : gpuModule.getOps()) { + builder.setInsertionPointToEnd(module.getBody()); + // Check if already exists at module level + if (!module.lookupSymbol(llvmFunc.getName())) { + llvmFunc.clone(); + builder.clone(*llvmFunc.getOperation()); + } + } + + gpuModulesToErase.push_back(gpuModule); + }); + + // Erase the gpu.module after extracting all functions + for (auto gpuModule : gpuModulesToErase) { + gpuModule.erase(); + } + + // Remove gpu.container_module attribute since we no longer have gpu.module + module->removeAttr("gpu.container_module"); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Pass Registration +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace polygeist { + +std::unique_ptr createConvertGPUToVortexPass() { + return std::make_unique(); +} + +} // namespace polygeist +} // namespace mlir diff --git 
a/lib/polygeist/Passes/GenerateVortexMain.cpp b/lib/polygeist/Passes/GenerateVortexMain.cpp new file mode 100644 index 000000000000..0e3781afd30d --- /dev/null +++ b/lib/polygeist/Passes/GenerateVortexMain.cpp @@ -0,0 +1,428 @@ +//===- GenerateVortexMain.cpp - Generate Vortex main() wrapper ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass generates the Vortex-specific main() entry point and kernel_body +// wrapper function. It should run AFTER gpu-to-llvm lowering has converted +// gpu.func to llvm.func. +// +// The Vortex execution model requires: +// 1. A main() function that reads args from VX_CSR_MSCRATCH and calls +// vx_spawn_threads() with the kernel callback +// 2. A kernel_body() wrapper that unpacks arguments and calls the kernel +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "polygeist/Passes/Passes.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { + +//===----------------------------------------------------------------------===// +// Vortex CSR Address +//===----------------------------------------------------------------------===// + +// VX_CSR_MSCRATCH - Machine scratch register used to pass kernel arguments +// From vortex/hw/rtl/VX_types.vh: `define VX_CSR_MSCRATCH 12'h340 +constexpr uint32_t VX_CSR_MSCRATCH = 0x340; + +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// + +/// Find the kernel function in the module (lowered from gpu.func) +/// After gpu-to-llvm, the kernel is an llvm.func with a mangled name +static LLVM::LLVMFuncOp findKernelFunction(ModuleOp module) { + LLVM::LLVMFuncOp kernelFunc = nullptr; + + module.walk([&](LLVM::LLVMFuncOp func) { + StringRef name = func.getName(); + // Look for functions with "_kernel" in the name (Polygeist naming convention) + // or functions that were marked as kernels + if (name.contains("_kernel") && !name.startswith("kernel_body")) { + // Prefer the first kernel found + if (!kernelFunc) { + kernelFunc = func; + } + } + }); + + return kernelFunc; +} + +/// Declare vx_spawn_threads external function +/// Signature: int vx_spawn_threads(uint32_t dimension, const uint32_t* grid_dim, +/// const uint32_t* block_dim, +/// vx_kernel_func_cb kernel_func, const void* arg) +static LLVM::LLVMFuncOp +getOrDeclareVxSpawnThreads(ModuleOp module, OpBuilder &builder) { + MLIRContext *ctx = module.getContext(); + + if (auto existing = module.lookupSymbol("vx_spawn_threads")) + return existing; + + auto i32Type = IntegerType::get(ctx, 32); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + // int vx_spawn_threads(uint32_t, uint32_t*, uint32_t*, void(*)(void*), void*) + auto funcType = LLVM::LLVMFunctionType::get( + i32Type, {i32Type, ptrType, ptrType, ptrType, ptrType}, + /*isVarArg=*/false); + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + return builder.create(module.getLoc(), "vx_spawn_threads", + funcType, 
LLVM::Linkage::External); +} + +/// Check if a sequence of parameter types represents a memref descriptor +/// LLVM memref descriptor format: { ptr, ptr, i64, [1 x i64], [1 x i64] } +/// After flattening, this becomes: ptr, ptr, i64, i64, i64 (for 1D memref) +static bool isMemrefDescriptorStart(LLVM::LLVMFunctionType funcType, + unsigned startIdx) { + unsigned numParams = funcType.getNumParams(); + // Need at least 5 params remaining for a 1D memref descriptor + if (startIdx + 5 > numParams) + return false; + + Type t0 = funcType.getParamType(startIdx); + Type t1 = funcType.getParamType(startIdx + 1); + Type t2 = funcType.getParamType(startIdx + 2); + Type t3 = funcType.getParamType(startIdx + 3); + Type t4 = funcType.getParamType(startIdx + 4); + + // Check pattern: ptr, ptr, i64, i64, i64 + return t0.isa() && + t1.isa() && t2.isInteger(64) && + t3.isInteger(64) && t4.isInteger(64); +} + +/// Generate kernel_body wrapper function +/// This function unpacks arguments from the void* args pointer and calls +/// the original kernel function +/// +/// IMPORTANT: Polygeist transforms kernel signatures to add computed arguments +/// from the launch configuration. The first few kernel args may come from +/// block_dim (e.g., block_dim.x as both index and i32), not from user args. +/// +/// This handles the memref descriptor expansion that happens during +/// LLVM lowering. Each memref in the original kernel becomes 5 params +/// (ptr, ptr, i64, i64, i64) after lowering. The host passes simple device +/// pointers, so we must construct the full descriptor from each pointer. +static LLVM::LLVMFuncOp +generateKernelBodyWrapper(ModuleOp module, LLVM::LLVMFuncOp kernelFunc, + OpBuilder &builder) { + MLIRContext *ctx = module.getContext(); + Location loc = kernelFunc.getLoc(); + + auto ptrType = LLVM::LLVMPointerType::get(ctx); + auto voidType = LLVM::LLVMVoidType::get(ctx); + auto i32Type = IntegerType::get(ctx, 32); + auto i64Type = IntegerType::get(ctx, 64); + + // Create function: void kernel_body(void* args) + auto funcType = + LLVM::LLVMFunctionType::get(voidType, {ptrType}, /*isVarArg=*/false); + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(kernelFunc); + + auto bodyFunc = builder.create(loc, "kernel_body", funcType, + LLVM::Linkage::External); + + // Create entry block + Block *entryBlock = bodyFunc.addEntryBlock(); + builder.setInsertionPointToStart(entryBlock); + + Value argsPtr = bodyFunc.getArgument(0); + + // Get kernel argument types + auto kernelFuncType = kernelFunc.getFunctionType(); + unsigned numArgs = kernelFuncType.getNumParams(); + + // Standard Vortex args layout: + // uint32_t grid_dim[3]; // 12 bytes (offsets 0, 4, 8) + // uint32_t block_dim[3]; // 12 bytes (offsets 12, 16, 20) + // // starting at offset 24 + // + // IMPORTANT: Polygeist transforms kernel args during GPU lowering. The first + // two kernel args are typically derived from block_dim.x (threads_per_block): + // arg0 = block_dim.x as i64 (index type) + // arg1 = block_dim.x as i32 + // These should be loaded from block_dim[0] in the header, not from user args. + // The remaining args (arg2+) come from the user args buffer. 
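+  // Hypothetical example: for a lowered kernel signature
+  //   (i64, i32, <memref: ptr, ptr, i64, i64, i64>, <memref: ptr, ptr, i64, i64, i64>, i32)
+  // the wrapper takes the leading i64/i32 from block_dim[0] at offset 12,
+  // rebuilds the two memref descriptors from 4-byte device pointers at user-arg
+  // offsets 24 and 28, and loads the trailing i32 from offset 32.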
+ constexpr unsigned BLOCK_DIM_OFFSET = 12; + constexpr unsigned USER_ARGS_OFFSET = 24; + + SmallVector unpackedArgs; + unsigned currentOffset = USER_ARGS_OFFSET; + auto i8Type = IntegerType::get(ctx, 8); + + // Count leading scalar args before any memrefs/pointers + // These are typically derived from launch config (block_dim) + unsigned numLeadingScalars = 0; + for (unsigned i = 0; i < numArgs; ++i) { + Type argType = kernelFuncType.getParamType(i); + if (isMemrefDescriptorStart(kernelFuncType, i) || + argType.isa()) { + break; + } + ++numLeadingScalars; + } + + // Pre-load block_dim[0] for the first two args (if needed) + Value blockDimX_i32 = nullptr; + Value blockDimX_i64 = nullptr; + if (numLeadingScalars >= 2) { + SmallVector gepIndices; + gepIndices.push_back(static_cast(BLOCK_DIM_OFFSET)); + auto blockDimPtr = + builder.create(loc, ptrType, i8Type, argsPtr, gepIndices); + blockDimX_i32 = builder.create(loc, i32Type, blockDimPtr); + blockDimX_i64 = builder.create(loc, i64Type, blockDimX_i32); + } + + for (unsigned i = 0; i < numArgs; ) { + Type argType = kernelFuncType.getParamType(i); + + // Check if this is the start of a memref descriptor (5 consecutive params) + if (isMemrefDescriptorStart(kernelFuncType, i)) { + // This is a memref descriptor - load single pointer from args and expand + SmallVector gepIndices; + gepIndices.push_back(static_cast(currentOffset)); + auto argBytePtr = builder.create(loc, ptrType, i8Type, + argsPtr, gepIndices); + + // Load the device pointer (4 bytes on RV32) + auto rawPtr = builder.create(loc, i32Type, argBytePtr); + auto devicePtr = builder.create(loc, ptrType, rawPtr); + + // Construct memref descriptor values: + // param 0: allocated pointer (same as device ptr) + // param 1: aligned pointer (same as device ptr) + // param 2: offset (0) + // param 3: size (use large value, kernel will bounds check) + // param 4: stride (1 for contiguous) + unpackedArgs.push_back(devicePtr); // allocated ptr + unpackedArgs.push_back(devicePtr); // aligned ptr + + auto zeroI64 = builder.create(loc, i64Type, 0); + auto maxI64 = builder.create( + loc, i64Type, std::numeric_limits::max()); + auto oneI64 = builder.create(loc, i64Type, 1); + + unpackedArgs.push_back(zeroI64); // offset = 0 + unpackedArgs.push_back(maxI64); // size = MAX (kernel has bounds check) + unpackedArgs.push_back(oneI64); // stride = 1 + + currentOffset += 4; // Single pointer in args buffer + i += 5; // Skip 5 params (the whole memref descriptor) + continue; + } + + // Scalar argument handling + Value argVal; + + // First two leading scalars come from block_dim[0], not user args + // arg0 = block_dim.x as i64 (for index type in LLVM lowering) + // arg1 = block_dim.x as i32 + if (i < 2 && numLeadingScalars >= 2) { + if (argType.isInteger(64)) { + argVal = blockDimX_i64; + } else { + argVal = blockDimX_i32; + } + } else { + // Regular user argument from user args buffer + SmallVector gepIndices; + gepIndices.push_back(static_cast(currentOffset)); + auto argBytePtr = + builder.create(loc, ptrType, i8Type, argsPtr, gepIndices); + + if (argType.isa()) { + // For pointers: load as i32 (RV32 pointer), then inttoptr + auto rawPtr = builder.create(loc, i32Type, argBytePtr); + argVal = builder.create(loc, ptrType, rawPtr); + currentOffset += 4; + } else if (argType.isInteger(32)) { + argVal = builder.create(loc, i32Type, argBytePtr); + currentOffset += 4; + } else if (argType.isInteger(64)) { + argVal = builder.create(loc, i64Type, argBytePtr); + currentOffset += 8; + } else if (argType.isF32()) { + 
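+        // f32 scalar: loaded as a 4-byte value from the args buffer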
+
+/// Generate the main() entry point function.
+/// This function:
+///   1. Reads args from VX_CSR_MSCRATCH via inline assembly
+///   2. Extracts the grid_dim pointer from the args struct
+///   3. Calls vx_spawn_threads() with the kernel_body callback
+static LLVM::LLVMFuncOp generateMainFunction(ModuleOp module,
+                                             LLVM::LLVMFuncOp bodyFunc,
+                                             LLVM::LLVMFuncOp spawnFunc,
+                                             OpBuilder &builder) {
+  MLIRContext *ctx = module.getContext();
+  Location loc = module.getLoc();
+
+  auto i32Type = IntegerType::get(ctx, 32);
+  auto ptrType = LLVM::LLVMPointerType::get(ctx);
+
+  // Create function: int main()
+  auto funcType =
+      LLVM::LLVMFunctionType::get(i32Type, {}, /*isVarArg=*/false);
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToEnd(module.getBody());
+
+  auto mainFunc = builder.create<LLVM::LLVMFuncOp>(loc, "main", funcType,
+                                                   LLVM::Linkage::External);
+
+  // Create entry block.
+  Block *entryBlock = mainFunc.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  // 1. Read args from VX_CSR_MSCRATCH using inline assembly:
+  //      csrr rd, 0x340
+  //    Note: LLVM 18+ requires a direct result type, not struct-wrapped.
+  auto inlineAsm = builder.create<LLVM::InlineAsmOp>(
+      loc,
+      /*resultTypes=*/i32Type,
+      /*operands=*/ValueRange{},
+      /*asm_string=*/"csrr $0, 0x340",
+      /*constraints=*/"=r",
+      /*has_side_effects=*/true,
+      /*is_align_stack=*/false,
+      /*asm_dialect=*/LLVM::AsmDialectAttr{},
+      /*operand_attrs=*/ArrayAttr{});
+
+  // Get the result directly (no struct extraction needed for a single output).
+  auto argsRaw = inlineAsm.getRes();
+
+  // Convert to pointer.
+  auto argsPtr = builder.create<LLVM::IntToPtrOp>(loc, ptrType, argsRaw);
+
+  // 2. Get the grid_dim pointer (first field of the args struct, offset 0).
+  //    The args struct layout is: grid_dim[3], block_dim[3], user_args...
+  //    Since grid_dim is at offset 0, argsPtr itself is the grid_dim pointer.
+  Value gridDimPtr = argsPtr;
+
+  // 3. Get the block_dim pointer (offset 12 = 3 * sizeof(uint32_t)).
+  auto i8Type = IntegerType::get(ctx, 8);
+  SmallVector<LLVM::GEPArg> blockDimIndices;
+  blockDimIndices.push_back(12);
+  auto blockDimPtr = builder.create<LLVM::GEPOp>(loc, ptrType, i8Type, argsPtr,
+                                                 blockDimIndices);
+
+  // 4. Get the kernel_body function pointer.
+  auto kernelPtr =
+      builder.create<LLVM::AddressOfOp>(loc, ptrType, bodyFunc.getName());
+
+  // 5. Call vx_spawn_threads(dimension=1, grid_dim, block_dim, kernel_body,
+  //    args). dimension=1 for a 1D grid (the most common case).
+  //    TODO: Support multi-dimensional grids by reading the dimension from
+  //    metadata.
+  auto dim = builder.create<LLVM::ConstantOp>(loc, i32Type, 1);
+
+  SmallVector<Value> spawnArgs = {dim, gridDimPtr, blockDimPtr, kernelPtr,
+                                  argsPtr};
+  auto result = builder.create<LLVM::CallOp>(loc, spawnFunc, spawnArgs);
+
+  // 6. Return the result from vx_spawn_threads.
+  builder.create<LLVM::ReturnOp>(loc, result.getResult());
+
+  return mainFunc;
+}
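(Aside, not part of the patch: the IR built by generateMainFunction corresponds roughly to the C++ below. The inline csrr read and the vx_spawn_threads prototype are assumptions inferred from the emitted call, not declarations taken from the Vortex headers.)

```cpp
#include <cstdint>

// Assumed prototypes, matching the five-argument call emitted above; the real
// Vortex SDK declarations may differ in detail.
extern "C" int vx_spawn_threads(uint32_t dimension, const uint32_t *grid_dim,
                                const uint32_t *block_dim,
                                void (*kernel)(void *), void *arg);
extern "C" void kernel_body(void *args); // the generated wrapper

// Rough C++ equivalent of the generated main(); 0x340 is VX_CSR_MSCRATCH.
int main() {
  uintptr_t raw;
  asm volatile("csrr %0, 0x340" : "=r"(raw)); // read the args pointer
  auto *args = reinterpret_cast<uint8_t *>(raw);
  auto *grid_dim = reinterpret_cast<uint32_t *>(args);       // offset 0
  auto *block_dim = reinterpret_cast<uint32_t *>(args + 12); // offset 12
  return vx_spawn_threads(/*dimension=*/1, grid_dim, block_dim, kernel_body,
                          args);
}
```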
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL_GENERATEVORTEXMAIN
+#define GEN_PASS_DEF_GENERATEVORTEXMAIN
+#include "polygeist/Passes/Passes.h.inc"
+
+struct GenerateVortexMainPass
+    : public impl::GenerateVortexMainBase<GenerateVortexMainPass> {
+
+  GenerateVortexMainPass() = default;
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    OpBuilder builder(module.getContext());
+
+    // 1. Find the kernel function.
+    LLVM::LLVMFuncOp kernelFunc = findKernelFunction(module);
+    if (!kernelFunc) {
+      // No kernel found - this might be a host-only module, skip silently.
+      return;
+    }
+
+    // 2. Check whether main() already exists.
+    if (module.lookupSymbol("main")) {
+      // main() already exists, skip generation.
+      return;
+    }
+
+    // 3. Declare vx_spawn_threads.
+    auto spawnFunc = getOrDeclareVxSpawnThreads(module, builder);
+
+    // 4. Generate the kernel_body wrapper.
+    auto bodyFunc = generateKernelBodyWrapper(module, kernelFunc, builder);
+
+    // 5. Generate the main() entry point.
+    generateMainFunction(module, bodyFunc, spawnFunc, builder);
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass Registration
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace polygeist {
+
+std::unique_ptr<Pass> createGenerateVortexMainPass() {
+  return std::make_unique<GenerateVortexMainPass>();
+}
+
+} // namespace polygeist
+} // namespace mlir
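(Aside, not part of the patch: a minimal sketch of how the two factory functions exposed in Passes.h might be scheduled, assuming the intermediate GPU-to-LLVM lowering is added in between; the function name and the elided pipeline are illustrative.)

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "polygeist/Passes/Passes.h"

// Sketch: ConvertGPUToVortex rewrites GPU ops up front, while
// GenerateVortexMain must run after the module is in llvm.func form.
mlir::LogicalResult lowerForVortex(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::polygeist::createConvertGPUToVortexPass());
  // ... gpu-to-llvm / convert-polygeist-to-llvm passes would be added here ...
  pm.addPass(mlir::polygeist::createGenerateVortexMainPass());
  return pm.run(module);
}
```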
diff --git a/lib/polygeist/Passes/SerializeToCubin.cpp b/lib/polygeist/Passes/SerializeToCubin.cpp
index abb6dbee1c1e..278a83f2a0f8 100644
--- a/lib/polygeist/Passes/SerializeToCubin.cpp
+++ b/lib/polygeist/Passes/SerializeToCubin.cpp
@@ -412,5 +412,11 @@ std::unique_ptr<Pass> createGpuSerializeToCubinPass(
 #else
 namespace mlir::polygeist {
 void registerGpuSerializeToCubinPass() {}
+std::unique_ptr<Pass> createGpuSerializeToCubinPass(
+    StringRef arch, StringRef features, int llvmOptLevel, int ptxasOptLevel,
+    std::string ptxasPath, std::string libDevicePath, bool outputIntermediate) {
+  llvm::errs() << "error: CUDA toolkit support not enabled in this build\n";
+  return nullptr;
+}
 } // namespace mlir::polygeist
 #endif
diff --git a/tools/cgeist/CMakeLists.txt b/tools/cgeist/CMakeLists.txt
index 84b4a739e9df..feca781d530a 100644
--- a/tools/cgeist/CMakeLists.txt
+++ b/tools/cgeist/CMakeLists.txt
@@ -31,11 +31,19 @@ add_clang_executable(cgeist
   Lib/TypeUtils.cc
   Lib/CGCall.cc
 )
-if(POLYGEIST_ENABLE_CUDA)
+if(POLYGEIST_ENABLE_CUDA OR POLYGEIST_CUDA_FRONTEND_ONLY)
   target_compile_definitions(cgeist
     PRIVATE
     POLYGEIST_ENABLE_CUDA=1
   )
+endif()
+
+# Define POLYGEIST_CUDA_FULL only if full CUDA support (execution engine available)
+if(POLYGEIST_ENABLE_CUDA AND NOT POLYGEIST_CUDA_FRONTEND_ONLY)
+  target_compile_definitions(cgeist
+    PRIVATE
+    POLYGEIST_CUDA_FULL=1
+  )
   add_dependencies(cgeist execution_engine_cuda_wrapper_binary_include)
 endif()
 if(POLYGEIST_ENABLE_ROCM)
diff --git a/tools/cgeist/Test/Verification/basic_hip_kernel.hip b/tools/cgeist/Test/Verification/basic_hip_kernel.hip
new file mode 100644
index 000000000000..5945c9060a9d
--- /dev/null
+++ b/tools/cgeist/Test/Verification/basic_hip_kernel.hip
@@ -0,0 +1,15 @@
+// Basic HIP kernel test
+#include "Inputs/cuda.h"
+#include "__clang_cuda_builtin_vars.h"
+
+__global__ void basic_kernel(int32_t* src, int32_t* dst, uint32_t count) {
+  uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < count) {
+    dst[tid] = src[tid];
+  }
+}
+
+void launch_basic(int32_t* d_src, int32_t* d_dst, uint32_t count, int threads_per_block) {
+  int num_blocks = (count + threads_per_block - 1) / threads_per_block;
+  basic_kernel<<<num_blocks, threads_per_block>>>(d_src, d_dst, count);
+}
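(Aside, not part of the patch: under the argument-buffer layout handled by GenerateVortexMain, a host runtime could pack the arguments of basic_kernel roughly as follows. The function name and the use of raw 32-bit device addresses are illustrative assumptions, not part of the Polygeist or Vortex APIs.)

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical packing of basic_kernel(src, dst, count) for a 1-D launch on
// RV32: the header (grid_dim, block_dim) is followed by user args at offset 24.
std::vector<uint8_t> packBasicKernelArgs(uint32_t num_blocks,
                                         uint32_t threads_per_block,
                                         uint32_t d_src, uint32_t d_dst,
                                         uint32_t count) {
  std::vector<uint8_t> buf(24 + 4 + 4 + 4, 0);
  uint32_t grid_dim[3] = {num_blocks, 1, 1};
  uint32_t block_dim[3] = {threads_per_block, 1, 1};
  std::memcpy(buf.data() + 0, grid_dim, sizeof(grid_dim));
  std::memcpy(buf.data() + 12, block_dim, sizeof(block_dim));
  std::memcpy(buf.data() + 24, &d_src, 4); // device pointer as a 4-byte value
  std::memcpy(buf.data() + 28, &d_dst, 4);
  std::memcpy(buf.data() + 32, &count, 4);
  return buf;
}
```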
diff --git a/tools/cgeist/Test/Verification/gpu_to_vortex_basic.mlir b/tools/cgeist/Test/Verification/gpu_to_vortex_basic.mlir
new file mode 100644
index 000000000000..e566b8c74c5e
--- /dev/null
+++ b/tools/cgeist/Test/Verification/gpu_to_vortex_basic.mlir
@@ -0,0 +1,81 @@
+// RUN: polygeist-opt %s -convert-gpu-to-vortex | FileCheck %s
+
+// Test basic gpu.thread_id and gpu.block_id lowering to Vortex TLS access
+
+module {
+  // CHECK-LABEL: func @test_thread_id_x
+  func.func @test_thread_id_x() -> index {
+    // CHECK: llvm.mlir.global external thread_local @threadIdx
+    // CHECK: llvm.mlir.addressof @threadIdx
+    // CHECK: llvm.getelementptr
+    // CHECK: llvm.load
+    %tid = gpu.thread_id x
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %tid : index
+  }
+
+  // CHECK-LABEL: func @test_thread_id_y
+  func.func @test_thread_id_y() -> index {
+    // CHECK: llvm.mlir.addressof @threadIdx
+    // CHECK: llvm.getelementptr
+    // CHECK-SAME: 1
+    // CHECK: llvm.load
+    %tid = gpu.thread_id y
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %tid : index
+  }
+
+  // CHECK-LABEL: func @test_thread_id_z
+  func.func @test_thread_id_z() -> index {
+    // CHECK: llvm.mlir.addressof @threadIdx
+    // CHECK: llvm.getelementptr
+    // CHECK-SAME: 2
+    // CHECK: llvm.load
+    %tid = gpu.thread_id z
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %tid : index
+  }
+
+  // CHECK-LABEL: func @test_block_id_x
+  func.func @test_block_id_x() -> index {
+    // CHECK: llvm.mlir.global external thread_local @blockIdx
+    // CHECK: llvm.mlir.addressof @blockIdx
+    // CHECK: llvm.getelementptr
+    // CHECK: llvm.load
+    %bid = gpu.block_id x
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %bid : index
+  }
+
+  // CHECK-LABEL: func @test_block_id_y
+  func.func @test_block_id_y() -> index {
+    // CHECK: llvm.mlir.addressof @blockIdx
+    // CHECK: llvm.getelementptr
+    // CHECK-SAME: 1
+    // CHECK: llvm.load
+    %bid = gpu.block_id y
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %bid : index
+  }
+
+  // CHECK-LABEL: func @test_combined
+  func.func @test_combined() -> index {
+    // CHECK: llvm.mlir.addressof @threadIdx
+    // CHECK: llvm.getelementptr
+    // CHECK: llvm.load
+    %tid = gpu.thread_id x
+    // CHECK: llvm.mlir.addressof @blockIdx
+    // CHECK: llvm.getelementptr
+    // CHECK: llvm.load
+    %bid = gpu.block_id x
+    // CHECK: arith.addi
+    %sum = arith.addi %tid, %bid : index
+    // CHECK: return
+    return %sum : index
+  }
+}
diff --git a/tools/cgeist/Test/Verification/gpu_to_vortex_thread_model.mlir b/tools/cgeist/Test/Verification/gpu_to_vortex_thread_model.mlir
new file mode 100644
index 000000000000..2b3864c0355b
--- /dev/null
+++ b/tools/cgeist/Test/Verification/gpu_to_vortex_thread_model.mlir
@@ -0,0 +1,185 @@
+// RUN: polygeist-opt %s -convert-gpu-to-vortex | FileCheck %s
+
+// Thread model & synchronization tests:
+// blockDim, gridDim, and gpu.barrier operations
+
+module {
+  //===----------------------------------------------------------------------===//
+  // Block Dimension Tests (blockDim.x, blockDim.y, blockDim.z)
+  //===----------------------------------------------------------------------===//
+
+  // CHECK-LABEL: func @test_block_dim_x
+  func.func @test_block_dim_x() -> index {
+    // CHECK: llvm.mlir.addressof @blockDim
+    // CHECK: llvm.getelementptr {{.*}}[0, 0]
+    // CHECK: llvm.load
+    %bdim = gpu.block_dim x
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %bdim : index
+  }
+
+  // CHECK-LABEL: func @test_block_dim_y
+  func.func @test_block_dim_y() -> index {
+    // CHECK: llvm.mlir.addressof @blockDim
+    // CHECK: llvm.getelementptr {{.*}}[0, 1]
+    // CHECK: llvm.load
+    %bdim = gpu.block_dim y
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %bdim : index
+  }
+
+  // CHECK-LABEL: func @test_block_dim_z
+  func.func @test_block_dim_z() -> index {
+    // CHECK: llvm.mlir.addressof @blockDim
+    // CHECK: llvm.getelementptr {{.*}}[0, 2]
+    // CHECK: llvm.load
+    %bdim = gpu.block_dim z
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %bdim : index
+  }
+
+  //===----------------------------------------------------------------------===//
+  // Grid Dimension Tests (gridDim.x, gridDim.y, gridDim.z)
+  //===----------------------------------------------------------------------===//
+
+  // CHECK-LABEL: func @test_grid_dim_x
+  func.func @test_grid_dim_x() -> index {
+    // CHECK: llvm.mlir.addressof @gridDim
+    // CHECK: llvm.getelementptr {{.*}}[0, 0]
+    // CHECK: llvm.load
+    %gdim = gpu.grid_dim x
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %gdim : index
+  }
+
+  // CHECK-LABEL: func @test_grid_dim_y
+  func.func @test_grid_dim_y() -> index {
+    // CHECK: llvm.mlir.addressof @gridDim
+    // CHECK: llvm.getelementptr {{.*}}[0, 1]
+    // CHECK: llvm.load
+    %gdim = gpu.grid_dim y
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %gdim : index
+  }
+
+  // CHECK-LABEL: func @test_grid_dim_z
+  func.func @test_grid_dim_z() -> index {
+    // CHECK: llvm.mlir.addressof @gridDim
+    // CHECK: llvm.getelementptr {{.*}}[0, 2]
+    // CHECK: llvm.load
+    %gdim = gpu.grid_dim z
+    // CHECK: builtin.unrealized_conversion_cast
+    // CHECK: return
+    return %gdim : index
+  }
+
+  //===----------------------------------------------------------------------===//
+  // Barrier Synchronization Tests
+  //===----------------------------------------------------------------------===//
+
+  // CHECK-LABEL: func @test_simple_barrier
+  func.func @test_simple_barrier() {
+    // CHECK: %[[BAR_ID:.*]] = llvm.mlir.constant({{[0-9]+}} : i32)
+    // CHECK: llvm.mlir.addressof @blockDim
+    // CHECK: llvm.getelementptr {{.*}}[0, 0]
+    // CHECK: llvm.load
+    // CHECK: llvm.mlir.addressof @blockDim
+    // CHECK: llvm.getelementptr {{.*}}[0, 1]
+    // CHECK: llvm.load
+    // CHECK: llvm.mlir.addressof @blockDim
+    // CHECK: llvm.getelementptr {{.*}}[0, 2]
+    // CHECK: llvm.load
+    // CHECK: llvm.mul
+    // CHECK: %[[NUM_THREADS:.*]] = llvm.mul
+    // CHECK: llvm.call @vx_barrier(%[[BAR_ID]], %[[NUM_THREADS]])
+    gpu.barrier
+    // CHECK: return
+    return
+  }
+
+  // CHECK-LABEL: func @test_multiple_barriers
+  func.func @test_multiple_barriers() {
+    // First barrier
+    // CHECK: %[[BAR_ID_0:.*]] = llvm.mlir.constant({{[0-9]+}} : i32)
+    // CHECK: llvm.call @vx_barrier(%[[BAR_ID_0]]
+    gpu.barrier
+
+    // Second barrier - its ID should differ from the first
+    // CHECK: %[[BAR_ID_1:.*]] = llvm.mlir.constant({{[0-9]+}} : i32)
+    // CHECK: llvm.call @vx_barrier(%[[BAR_ID_1]]
+    gpu.barrier
+
+    // CHECK: return
+    return
+  }
+
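(Aside, not part of the patch: the CHECK lines above encode the barrier lowering contract. Conceptually it behaves like the sketch below; the vx_barrier prototype is assumed from the emitted call and may differ from the actual Vortex runtime header.)

```cpp
#include <cstdint>

// Assumed runtime entry point, matching the llvm.call @vx_barrier(...) the
// pass emits; the exact Vortex prototype is not shown in this patch.
extern "C" void vx_barrier(uint32_t barrier_id, uint32_t count);

// Conceptual equivalent of one lowered gpu.barrier: every gpu.barrier gets a
// distinct constant ID, and the count is the flattened block size read from
// the thread-local @blockDim global (passed in here for illustration).
inline void loweredGpuBarrier(uint32_t barrier_id, const uint32_t block_dim[3]) {
  uint32_t num_threads = block_dim[0] * block_dim[1] * block_dim[2];
  vx_barrier(barrier_id, num_threads);
}
```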
+  //===----------------------------------------------------------------------===//
+  // Combined Test: Global ID Computation Pattern
+  //===----------------------------------------------------------------------===//
+
+  // CHECK-LABEL: func @test_global_id_pattern
+  func.func @test_global_id_pattern() -> index {
+    // Get threadIdx.x
+    // CHECK: llvm.mlir.addressof @threadIdx
+    // CHECK: llvm.getelementptr
+    // CHECK: llvm.load
+    %tid = gpu.thread_id x
+
+    // Get blockIdx.x
+    // CHECK: llvm.mlir.addressof @blockIdx
+    // CHECK: llvm.getelementptr
+    // CHECK: llvm.load
+    %bid = gpu.block_id x
+
+    // Get blockDim.x
+    // CHECK: llvm.mlir.addressof @blockDim
+    // CHECK: llvm.getelementptr
+    // CHECK: llvm.load
+    %bdim = gpu.block_dim x
+
+    // Compute: blockIdx.x * blockDim.x + threadIdx.x
+    // CHECK: arith.muli
+    %temp = arith.muli %bid, %bdim : index
+    // CHECK: arith.addi
+    %gid = arith.addi %temp, %tid : index
+
+    // CHECK: return
+    return %gid : index
+  }
+
+  //===----------------------------------------------------------------------===//
+  // Realistic Kernel Pattern with Barrier
+  //===----------------------------------------------------------------------===//
+
+  // CHECK-LABEL: func @test_kernel_with_barrier
+  func.func @test_kernel_with_barrier() -> index {
+    // Compute the global ID
+    // CHECK: llvm.mlir.addressof @threadIdx
+    %tid = gpu.thread_id x
+
+    // CHECK: llvm.mlir.addressof @blockIdx
+    %bid = gpu.block_id x
+
+    // CHECK: llvm.mlir.addressof @blockDim
+    %bdim = gpu.block_dim x
+
+    // CHECK: arith.muli
+    %temp = arith.muli %bid, %bdim : index
+    // CHECK: arith.addi
+    %gid = arith.addi %temp, %tid : index
+
+    // Synchronize threads
+    // CHECK: llvm.mlir.constant({{[0-9]+}} : i32)
+    // CHECK: llvm.call @vx_barrier
+    gpu.barrier
+
+    // CHECK: return
+    return %gid : index
+  }
+
+}
diff --git a/tools/cgeist/driver.cc b/tools/cgeist/driver.cc
index b49eecdcfec5..ff451007a79a 100644
--- a/tools/cgeist/driver.cc
+++ b/tools/cgeist/driver.cc
@@ -1025,7 +1025,7 @@ int main(int argc, char **argv) {
   CudaInstallationDetector detector(*driver, triple, argList);
 
   if (EmitCUDA) {
-#if POLYGEIST_ENABLE_CUDA
+#if POLYGEIST_CUDA_FULL
     std::string arch = CUDAGPUArch;
     if (arch == "")
       arch = "sm_60";
@@ -1104,7 +1104,7 @@ int main(int argc, char **argv) {
     llvm::errs() << "Failed to emit LLVM IR\n";
     return -1;
   }
-#if POLYGEIST_ENABLE_CUDA
+#if POLYGEIST_CUDA_FULL
   if (EmitCUDA) {
     // This header defines:
     // unsigned char CudaRuntimeWrappers_cpp_bc[]