Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions flang/include/flang/Optimizer/Builder/IntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ struct IntrinsicLibrary {
fir::ExtendedValue genSum(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
void genSignalSubroutine(llvm::ArrayRef<fir::ExtendedValue>);
void genSleep(llvm::ArrayRef<fir::ExtendedValue>);
void genSyncThreads(llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genSyncThreadsAnd(mlir::Type,llvm::ArrayRef<mlir::Value>);
mlir::Value genSyncThreadsCount(mlir::Type,llvm::ArrayRef<mlir::Value>);
mlir::Value genSyncThreadsOr(mlir::Type,llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genSystem(std::optional<mlir::Type>,
mlir::ArrayRef<fir::ExtendedValue> args);
void genSystemClock(llvm::ArrayRef<fir::ExtendedValue>);
Expand All @@ -401,6 +405,9 @@ struct IntrinsicLibrary {
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genTranspose(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genTrim(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genUbound(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
Expand Down
85 changes: 85 additions & 0 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,10 @@ static constexpr IntrinsicHandler handlers[]{
{"dim", asValue},
{"mask", asBox, handleDynamicOptional}}},
/*isElemental=*/false},
{"syncthreads", &I::genSyncThreads, {}, /*isElemental=*/false},
{"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
{"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
{"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
{"system",
&I::genSystem,
{{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
Expand All @@ -660,6 +664,9 @@ static constexpr IntrinsicHandler handlers[]{
&I::genTranspose,
{{{"matrix", asAddr}}},
/*isElemental=*/false},
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
{"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
{"trim", &I::genTrim, {{{"string", asAddr}}}, /*isElemental=*/false},
{"ubound",
&I::genUbound,
Expand Down Expand Up @@ -7290,6 +7297,52 @@ IntrinsicLibrary::genSum(mlir::Type resultType,
resultType, args);
}

// SYNCTHREADS
void IntrinsicLibrary::genSyncThreads(llvm::ArrayRef<fir::ExtendedValue> args) {
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0";
mlir::FunctionType funcType =
mlir::FunctionType::get(builder.getContext(), {}, {});
auto funcOp = builder.createFunction(loc, funcName, funcType);
llvm::SmallVector<mlir::Value> noArgs;
builder.create<fir::CallOp>(loc, funcOp, noArgs);
}

// SYNCTHREADS_AND
mlir::Value
IntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.and";
mlir::MLIRContext *context = builder.getContext();
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
auto funcOp = builder.createFunction(loc, funcName, ftype);
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
}

// SYNCTHREADS_COUNT
mlir::Value
IntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.popc";
mlir::MLIRContext *context = builder.getContext();
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
auto funcOp = builder.createFunction(loc, funcName, ftype);
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
}

// SYNCTHREADS_OR
mlir::Value
IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.or";
mlir::MLIRContext *context = builder.getContext();
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
auto funcOp = builder.createFunction(loc, funcName, ftype);
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
}

// SYSTEM
fir::ExtendedValue
IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,
Expand Down Expand Up @@ -7420,6 +7473,38 @@ IntrinsicLibrary::genTranspose(mlir::Type resultType,
return readAndAddCleanUp(resultMutableBox, resultType, "TRANSPOSE");
}

// THREADFENCE
void IntrinsicLibrary::genThreadFence(llvm::ArrayRef<fir::ExtendedValue> args) {
constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.gl";
mlir::FunctionType funcType =
mlir::FunctionType::get(builder.getContext(), {}, {});
auto funcOp = builder.createFunction(loc, funcName, funcType);
llvm::SmallVector<mlir::Value> noArgs;
builder.create<fir::CallOp>(loc, funcOp, noArgs);
}

// THREADFENCE_BLOCK
void IntrinsicLibrary::genThreadFenceBlock(
llvm::ArrayRef<fir::ExtendedValue> args) {
constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.cta";
mlir::FunctionType funcType =
mlir::FunctionType::get(builder.getContext(), {}, {});
auto funcOp = builder.createFunction(loc, funcName, funcType);
llvm::SmallVector<mlir::Value> noArgs;
builder.create<fir::CallOp>(loc, funcOp, noArgs);
}

// THREADFENCE_SYSTEM
void IntrinsicLibrary::genThreadFenceSystem(
llvm::ArrayRef<fir::ExtendedValue> args) {
constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.sys";
mlir::FunctionType funcType =
mlir::FunctionType::get(builder.getContext(), {}, {});
auto funcOp = builder.createFunction(loc, funcName, funcType);
llvm::SmallVector<mlir::Value> noArgs;
builder.create<fir::CallOp>(loc, funcOp, noArgs);
}

// TRIM
fir::ExtendedValue
IntrinsicLibrary::genTrim(mlir::Type resultType,
Expand Down
14 changes: 7 additions & 7 deletions flang/module/cudadevice.f90
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@ module cudadevice
! Synchronization Functions

interface
attributes(device) subroutine syncthreads() bind(c, name='__syncthreads')
attributes(device) subroutine syncthreads()
end subroutine
end interface
public :: syncthreads

interface
attributes(device) integer function syncthreads_and(value) bind(c, name='__syncthreads_and')
attributes(device) integer function syncthreads_and(value)
integer :: value
end function
end interface
public :: syncthreads_and

interface
attributes(device) integer function syncthreads_count(value) bind(c, name='__syncthreads_count')
attributes(device) integer function syncthreads_count(value)
integer :: value
end function
end interface
public :: syncthreads_count

interface
attributes(device) integer function syncthreads_or(value) bind(c, name='__syncthreads_or')
attributes(device) integer function syncthreads_or(value)
integer :: value
end function
end interface
Expand All @@ -54,19 +54,19 @@ attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
! Memory Fences

interface
attributes(device) subroutine threadfence() bind(c, name='__threadfence')
attributes(device) subroutine threadfence()
end subroutine
end interface
public :: threadfence

interface
attributes(device) subroutine threadfence_block() bind(c, name='__threadfence_block')
attributes(device) subroutine threadfence_block()
end subroutine
end interface
public :: threadfence_block

interface
attributes(device) subroutine threadfence_system() bind(c, name='__threadfence_system')
attributes(device) subroutine threadfence_system()
end subroutine
end interface
public :: threadfence_system
Expand Down
8 changes: 4 additions & 4 deletions flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ attributes(global) subroutine devsub()
end

! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
! CHECK: fir.call @__syncthreads()
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
! CHECK: fir.call @__threadfence()
! CHECK: fir.call @__threadfence_block()
! CHECK: fir.call @__threadfence_system()
! CHECK: %{{.*}} = fir.call @__syncthreads_and(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_count(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_or(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%c1_i32_0) fastmath<contract> : (i32) -> i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%c1_i32_1) fastmath<contract> : (i32) -> i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%c1_i32_2) fastmath<contract> : (i32) -> i32

! CHECK: func.func private @__syncthreads() attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncthreads", fir.proc_attrs = #fir.proc_attrs<bind_c>}
! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
Expand Down
Loading