Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ static int GetSharedSize(const cinn::dialect::ir::OpNode& op_node) {
lane = inshape[idx];
}
// int max_num_threads =
// cinn::common::DefaultNVGPUTarget().max_num_threads();
// cinn::common::DefaultDeviceTarget().max_num_threads();
int max_num_threads = 1000;
if (lane > max_num_threads / 2) {
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ int GetSharedSize(::pir::Operation* op) {
lane = inshape[idx];
}
// int max_num_threads =
// cinn::common::DefaultNVGPUTarget().max_num_threads(); todo(phlrain): get
// gpu max threads
// cinn::common::DefaultDeviceTarget().max_num_threads();
// todo(phlrain): get gpu max threads
int max_num_threads = 2048;
if (lane > max_num_threads / 2) {
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void FusionOpAnalysis::PreCompileGroup() {
}
// Build and trigger compilation cache.
VLOG(4) << "Parallel Pre-Compile for Group with size: " << groups.size();
PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget());
PirCompiler pir_compiler(cinn::common::DefaultDeviceTarget());
pir_compiler.Build(groups);
}
} // namespace cinn::dialect::ir::details
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ std::vector<pir::Value> GetBlockOutsideInput(
std::unordered_map<OpLoweringGroupPtr,
std::unordered_map<std::string, pir::Attribute>>
CompileGroupAsOpAttribute(const std::vector<OpLoweringGroupPtr>& group_list) {
PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget());
PirCompiler pir_compiler(cinn::common::DefaultDeviceTarget());
auto fn_ptr_res = pir_compiler.Build(group_list);

std::unordered_map<OpLoweringGroupPtr,
Expand All @@ -85,7 +85,7 @@ std::unordered_map<std::string, ::pir::Attribute> GetJitKernelAttr(
hlir::framework::pir::FusionInfo fusion_info(*group);
return CompilationCache::Instance().GetKernelInfo(fusion_info);
} else {
PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget());
PirCompiler pir_compiler(cinn::common::DefaultDeviceTarget());
return pir_compiler.Build({group})[0];
}
};
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/op_lowering_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, // NOLINT
// If the number of current device SM is smaller than the number of SM
// required by Warp Reduce, the performance of Warp Reduce is better.
// Otherwise, use Block Reduce.
auto max_num_threads = cinn::common::DefaultNVGPUTarget().max_num_threads();
auto max_num_threads = cinn::common::DefaultDeviceTarget().max_num_threads();
int need_reduce_last_count = 1;
for (int i = 0; i < inshape.size(); i++) {
if (find(axes.begin(), axes.end(), i) == axes.end()) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, // NOLINT
// If the number of current device SM is smaller than the number of SM
// required by Warp Reduce, the performance of Warp Reduce is better.
// Otherwise, use Block Reduce.
auto max_num_threads = cinn::common::DefaultNVGPUTarget().max_num_threads();
auto max_num_threads = cinn::common::DefaultDeviceTarget().max_num_threads();
int need_reduce_last_count = 1;
for (int i = 0; i < inshape.size(); i++) {
if (find(axes.begin(), axes.end(), i) == axes.end()) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/op/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
reduce_tmp_out.as_tensor_ref(),
tmp_out.as_tensor_ref(),
out.as_tensor_ref(),
cinn::common::DefaultNVGPUTarget());
cinn::common::DefaultDeviceTarget());

std::vector<CINNValue> res{
CINNValue(ir_sch.GetModule().GetExprs().at(0))};
Expand All @@ -279,7 +279,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
reduce_tmp_out.as_tensor_ref(),
tmp_out.as_tensor_ref(),
out.as_tensor_ref(),
cinn::common::DefaultNVGPUTarget());
cinn::common::DefaultDeviceTarget());

std::vector<CINNValue> res{
CINNValue(ir_sch.GetModule().GetExprs().at(0))};
Expand Down
6 changes: 3 additions & 3 deletions paddle/cinn/hlir/pe/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ std::vector<ir::Tensor> TwoStepBlockReduceInternal(
// If the number of current device SM is smaller than the number of SM
// required by Warp Reduce, the performance of Warp Reduce is better.
// Otherwise, use Block Reduce.
auto max_num_threads = cinn::common::DefaultNVGPUTarget().max_num_threads();
auto max_num_threads = cinn::common::DefaultDeviceTarget().max_num_threads();
int need_reduce_last_count = 1;
for (int i = 0; i < A->shape.size(); i++) {
if (find(axes.begin(), axes.end(), i) == axes.end()) {
Expand All @@ -851,9 +851,9 @@ std::vector<ir::Tensor> TwoStepBlockReduceInternal(
int warp_reduce_need_sm_count =
ceil((need_reduce_last_count * 32) /
static_cast<float>(
cinn::common::DefaultNVGPUTarget().get_max_threads_per_sm()));
cinn::common::DefaultDeviceTarget().get_max_threads_per_sm()));
// Set Num_max_threads to 32 is Warp Reduce
if (cinn::common::DefaultNVGPUTarget().get_multi_processor_count() <
if (cinn::common::DefaultDeviceTarget().get_multi_processor_count() <
warp_reduce_need_sm_count) {
max_num_threads = 32;
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void TileTactic::Init(ScheduleContext* context) {
};
auto GetTreeReduceSize = [&](const ir::Expr& total_rb_extent) -> int64_t {
const int64_t max_num_threads =
common::DefaultNVGPUTarget().max_num_threads();
cinn::common::DefaultDeviceTarget().max_num_threads();
int64_t nums_thread_per_block = max_num_threads;
if (total_rb_extent.is_constant()) {
int64_t extent = static_cast<int64_t>(total_rb_extent.get_constant());
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/optim/map_extern_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void DealWithIntrinsicsImpl(common::NVGPUArch, ir::Call *node, Expr *expr) {
}

std::string extern_func =
hlir::GetExternFuncName(cinn::common::DefaultNVGPUTarget(), dtype, name);
hlir::GetExternFuncName(cinn::common::DefaultDeviceTarget(), dtype, name);
*expr = lang::CallExtern(extern_func, node->read_args, node->attrs);
}

Expand Down