Skip to content

Commit b3f5970

Browse files
authored
Fix AMD GPU alloca address space errors (#433)
* Fix AMD GPU alloca address space errors: AMD GPUs require stack allocations (alloca instructions) to be in address space 5 (private, i.e. per-thread scratch memory), not address space 0 (generic memory).
* Optimize AMDGPU allocas by keeping AS5 pointers throughout.
* Fix AMDGPU allocas to use address space 5 in MLIR lowering.
1 parent b458795 commit b3f5970

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

xla/codegen/emitters/transforms/lower_to_llvm.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,29 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
140140
std::move(mathPatterns)))) {
141141
signalPassFailure();
142142
}
143+
144+
// For AMDGPU, fix allocas to use address space 5 (private)
145+
// AMDGPU requires allocas in AS5, but MLIR lowering creates them in AS0
146+
if (device_spec_.IsAmdGpu()) {
147+
getOperation()->walk([](mlir::LLVM::AllocaOp alloca) {
148+
auto ptr_type = mlir::cast<mlir::LLVM::LLVMPointerType>(alloca.getResult().getType());
149+
// Check if address space is 0 (default/generic)
150+
if (ptr_type.getAddressSpace() == 0) {
151+
mlir::OpBuilder builder(alloca);
152+
// Create new alloca in address space 5
153+
auto new_ptr_type = mlir::LLVM::LLVMPointerType::get(builder.getContext(), 5);
154+
auto new_alloca = builder.create<mlir::LLVM::AllocaOp>(
155+
alloca.getLoc(),
156+
new_ptr_type,
157+
alloca.getElemType(),
158+
alloca.getArraySize(),
159+
alloca.getAlignment().value_or(0));
160+
alloca.replaceAllUsesWith(new_alloca.getResult());
161+
alloca.erase();
162+
}
163+
});
164+
VLOG(3) << "Fixed AMDGPU allocas to use address space 5";
165+
}
143166
}
144167

145168
private:

xla/service/llvm_ir/llvm_loop.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,9 @@ void ForLoop::Emit(llvm::IRBuilderBase* b) {
105105
llvm::Function* func = preheader_bb_->getParent();
106106
b->SetInsertPoint(&func->getEntryBlock(),
107107
func->getEntryBlock().getFirstInsertionPt());
108-
llvm::Value* indvar_address = b->CreateAlloca(
109-
start_index_->getType(), nullptr, GetQualifiedName("invar_address"));
108+
// Use EmitAllocaAtFunctionEntryWithCount which handles AMD GPU address space correctly
109+
llvm::Value* indvar_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
110+
start_index_->getType(), nullptr, GetQualifiedName("invar_address"), b, 0);
110111

111112
// Preheader basic block.
112113
// Initialize induction variable starting index. Create branch to the header.

xla/service/llvm_ir/tuple_ops.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ std::vector<llvm::Value*> EmitTupleAllocasAtFunctionEntry(
8282
CHECK(ShapeUtil::IsScalar(element_shape));
8383
llvm::Type* type = llvm_ir::PrimitiveTypeToIrType(
8484
element_shape.element_type(), b->getContext());
85-
llvm::AllocaInst* alloca = b->CreateAlloca(
86-
type,
87-
/*ArraySize=*/nullptr, AsStringRef(absl::StrCat("tuple_element_", i)));
85+
// Use EmitAllocaAtFunctionEntry which handles AMD GPU address space correctly
86+
llvm::AllocaInst* alloca = llvm_ir::EmitAllocaAtFunctionEntry(
87+
type, absl::StrCat("tuple_element_", i), b);
8888
generated_allocas.push_back(alloca);
8989
}
9090

0 commit comments

Comments
 (0)