Commit 6ebb20b
[CINN][New Hardware Update] replace DefaultNVGPUTarget
* replace DefaultNVGPUTarget with CurrentTarget
1 parent 3a4b1b7 commit 6ebb20b

File tree: 11 files changed (+30 -20 lines)
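Every hunk below applies the same one-line substitution. As a hedged before/after sketch (the wrapper functions are illustrative; the two headers and the two calls are taken verbatim from the hunks):

#include "paddle/cinn/common/target.h"  // cinn::common::DefaultNVGPUTarget()
#include "paddle/cinn/runtime/flags.h"  // cinn::runtime::CurrentTarget

// Before: call sites were hard-wired to the default NVIDIA GPU target.
cinn::common::Target TargetBefore() {
  return cinn::common::DefaultNVGPUTarget();
}

// After: call sites ask the runtime which target is currently active,
// so a new hardware backend can be selected without editing each pass.
cinn::common::Target TargetAfter() {
  return cinn::runtime::CurrentTarget::GetCurrentTarget();
}

The two files under hlir/framework additionally drop the global lookup entirely in favor of a target already passed in as a parameter (see those hunks below).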

paddle/cinn/backends/codegen_device_util.h

Lines changed: 4 additions & 2 deletions
@@ -25,6 +25,7 @@
 #include "paddle/cinn/ir/ir.h"
 #include "paddle/cinn/ir/ir_mutator.h"
 #include "paddle/cinn/ir/utils/ir_copy.h"
+#include "paddle/cinn/runtime/flags.h"

 namespace cinn {
 namespace backends {
@@ -51,8 +52,9 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
   explicit CollectHostFunctionVisitor(const std::string& module_name)
       : host_module_builder(module_name + "_host",
                             cinn::common::DefaultHostTarget()),
-        device_module_builder(module_name + "_gpu_device",
-                              cinn::common::DefaultNVGPUTarget()) {}
+        device_module_builder(
+            module_name + "_gpu_device",
+            cinn::runtime::CurrentTarget::GetCurrentTarget()) {}

   std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
     ir::IRMutator<>::Visit(expr, expr);

paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass_utils.h

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@

 #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_group.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
+#include "paddle/cinn/runtime/flags.h"

 namespace cinn {
 namespace dialect {
@@ -140,7 +141,7 @@ static int GetSharedSize(const cinn::dialect::ir::OpNode& op_node) {
     lane = inshape[idx];
   }
   // int max_num_threads =
-  //     cinn::common::DefaultNVGPUTarget().max_num_threads();
+  //     cinn::runtime::CurrentTarget::GetCurrentTarget().max_num_threads();
   int max_num_threads = 1000;
   if (lane > max_num_threads / 2) {
     return 0;
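Note that in this helper the device query stays commented out and the hard-coded cap of 1000 remains; only the comment's wording changes. A hypothetical sketch of what enabling that comment would look like (not part of this commit):

  // Hypothetical: replace the magic constant with the active target's limit,
  // following the commented-out line in the hunk above.
  int max_num_threads =
      cinn::runtime::CurrentTarget::GetCurrentTarget().max_num_threads();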

paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@
 #include <unordered_set>
 #include <vector>

+#include "paddle/cinn/runtime/flags.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/pir/include/core/builtin_attribute.h"
 #include "paddle/pir/include/core/ir_printer.h"
@@ -197,8 +198,8 @@ int GetSharedSize(::pir::Operation* op) {
     lane = inshape[idx];
   }
   // int max_num_threads =
-  //     cinn::common::DefaultNVGPUTarget().max_num_threads(); todo(phlrain): get
-  // gpu max threads
+  //     cinn::runtime::CurrentTarget::GetCurrentTarget().max_num_threads();
+  // todo(phlrain): get gpu max threads
   int max_num_threads = 2048;
   if (lane > max_num_threads / 2) {
     return 0;

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.cc

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"
 #include "paddle/cinn/hlir/framework/pir_compiler.h"
+#include "paddle/cinn/runtime/flags.h"
 #include "paddle/common/flags.h"

 PD_DECLARE_bool(enable_cinn_compile_cache);
@@ -56,7 +57,7 @@ void FusionOpAnalysis::PreCompileGroup() {
   }
   // Build and trigger compilaion cache.
   VLOG(4) << "Parallel Pre-Compile for Group with size: " << groups.size();
-  PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget());
+  PirCompiler pir_compiler(cinn::runtime::CurrentTarget::GetCurrentTarget());
   pir_compiler.Build(groups);
 }
 }  // namespace cinn::dialect::ir::details

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.cc

Lines changed: 3 additions & 2 deletions
@@ -61,7 +61,7 @@ std::vector<pir::Value> GetBlockOutsideInput(
 std::unordered_map<OpLoweringGroupPtr,
                    std::unordered_map<std::string, pir::Attribute>>
 CompileGroupAsOpAttribute(const std::vector<OpLoweringGroupPtr>& group_list) {
-  PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget());
+  PirCompiler pir_compiler(cinn::runtime::CurrentTarget::GetCurrentTarget());
   auto fn_ptr_res = pir_compiler.Build(group_list);

   std::unordered_map<OpLoweringGroupPtr,
@@ -85,7 +85,8 @@ std::unordered_map<std::string, ::pir::Attribute> GetJitKernelAttr(
     hlir::framework::pir::FusionInfo fusion_info(*group);
     return CompilationCache::Instance().GetKernelInfo(fusion_info);
   } else {
-    PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget());
+    PirCompiler pir_compiler(
+        cinn::runtime::CurrentTarget::GetCurrentTarget());
     return pir_compiler.Build({group})[0];
   }
 };

paddle/cinn/hlir/framework/op_lowering_util.cc

Lines changed: 1 addition & 1 deletion
@@ -717,7 +717,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch,  // NOLINT
   // If the number of current device SM is smaller than the number of SM
   // required by Warp Reduce, the performance of Warp Reduce is better.
   // Otherwise, use Block Reduce.
-  auto max_num_threads = cinn::common::DefaultNVGPUTarget().max_num_threads();
+  auto max_num_threads = target.max_num_threads();
   int need_reduce_last_count = 1;
   for (int i = 0; i < inshape.size(); i++) {
     if (find(axes.begin(), axes.end(), i) == axes.end()) {

paddle/cinn/hlir/framework/pir/op_lowering_util.cc

Lines changed: 1 addition & 1 deletion
@@ -577,7 +577,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch,  // NOLINT
   // If the number of current device SM is smaller than the number of SM
   // required by Warp Reduce, the performance of Warp Reduce is better.
   // Otherwise, use Block Reduce.
-  auto max_num_threads = cinn::common::DefaultNVGPUTarget().max_num_threads();
+  auto max_num_threads = target.max_num_threads();
   int need_reduce_last_count = 1;
   for (int i = 0; i < inshape.size(); i++) {
     if (find(axes.begin(), axes.end(), i) == axes.end()) {
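In both op_lowering_util.cc variants the replacement is simpler than elsewhere: LoopAssignReduceWithLast already receives the target as a parameter, so the global query is dropped in favor of that argument. A minimal sketch of the pattern (the wrapper is illustrative; only target.max_num_threads() comes from the hunks):

#include "paddle/cinn/common/target.h"

// Query the thread cap from the target handed in by the caller rather
// than from a process-wide default.
int MaxThreadsFor(const cinn::common::Target& target) {
  return target.max_num_threads();
}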

paddle/cinn/hlir/op/reduction.cc

Lines changed: 2 additions & 2 deletions
@@ -263,7 +263,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
                            reduce_tmp_out.as_tensor_ref(),
                            tmp_out.as_tensor_ref(),
                            out.as_tensor_ref(),
-                           cinn::common::DefaultNVGPUTarget());
+                           cinn::runtime::CurrentTarget::GetCurrentTarget());

       std::vector<CINNValue> res{
           CINNValue(ir_sch.GetModule().GetExprs().at(0))};
@@ -279,7 +279,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
                            reduce_tmp_out.as_tensor_ref(),
                            tmp_out.as_tensor_ref(),
                            out.as_tensor_ref(),
-                           cinn::common::DefaultNVGPUTarget());
+                           cinn::runtime::CurrentTarget::GetCurrentTarget());

       std::vector<CINNValue> res{
           CINNValue(ir_sch.GetModule().GetExprs().at(0))};

paddle/cinn/hlir/pe/reduction.cc

Lines changed: 7 additions & 5 deletions
@@ -27,6 +27,7 @@
 #include "paddle/cinn/ir/tensor.h"
 #include "paddle/cinn/lang/builtin.h"
 #include "paddle/cinn/lang/compute.h"
+#include "paddle/cinn/runtime/flags.h"
 #include "paddle/cinn/utils/string.h"

 namespace cinn {
@@ -841,7 +842,8 @@ std::vector<ir::Tensor> TwoStepBlockReduceInternal(
   // If the number of current device SM is smaller than the number of SM
   // required by Warp Reduce, the performance of Warp Reduce is better.
   // Otherwise, use Block Reduce.
-  auto max_num_threads = cinn::common::DefaultNVGPUTarget().max_num_threads();
+  auto max_num_threads =
+      cinn::runtime::CurrentTarget::GetCurrentTarget().max_num_threads();
   int need_reduce_last_count = 1;
   for (int i = 0; i < A->shape.size(); i++) {
     if (find(axes.begin(), axes.end(), i) == axes.end()) {
@@ -850,11 +852,11 @@
   }
   int warp_reduce_need_sm_count =
       ceil((need_reduce_last_count * 32) /
-           static_cast<float>(
-               cinn::common::DefaultNVGPUTarget().get_max_threads_per_sm()));
+           static_cast<float>(cinn::runtime::CurrentTarget::GetCurrentTarget()
+                                  .get_max_threads_per_sm()));
   // Set Num_max_threads to 32 is Warp Reduce
-  if (cinn::common::DefaultNVGPUTarget().get_multi_processor_count() <
-      warp_reduce_need_sm_count) {
+  if (cinn::runtime::CurrentTarget::GetCurrentTarget()
+          .get_multi_processor_count() < warp_reduce_need_sm_count) {
     max_num_threads = 32;
   }

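The arithmetic this hunk touches is the warp-vs-block reduce decision: the reduction wants one 32-lane warp per need_reduce_last_count, i.e. ceil(32 * need_reduce_last_count / max_threads_per_sm) SMs, and when the device has fewer SMs than that, the thread cap drops to a single warp. A standalone restatement of the check (variable names follow the hunk; the function wrapper is illustrative):

#include <cmath>

// Returns the thread cap after the warp-vs-block reduce decision. The three
// device limits would come from the current target, as in the hunk above.
int WarpOrBlockThreadCap(int need_reduce_last_count,
                         int max_num_threads,
                         int max_threads_per_sm,
                         int multi_processor_count) {
  int warp_reduce_need_sm_count = std::ceil(
      (need_reduce_last_count * 32) / static_cast<float>(max_threads_per_sm));
  // Fewer SMs than Warp Reduce needs: cap at one warp (32 threads).
  if (multi_processor_count < warp_reduce_need_sm_count) {
    return 32;
  }
  return max_num_threads;
}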

paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@
 #include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h"
 #include "paddle/cinn/common/target.h"
 #include "paddle/cinn/ir/ir.h"
+#include "paddle/cinn/runtime/flags.h"

 namespace cinn {
 namespace ir {
@@ -46,7 +47,7 @@ void TileTactic::Init(ScheduleContext* context) {
   };
   auto GetTreeReduceSize = [&](const ir::Expr& total_rb_extent) -> int64_t {
     const int64_t max_num_threads =
-        common::DefaultNVGPUTarget().max_num_threads();
+        cinn::runtime::CurrentTarget::GetCurrentTarget().max_num_threads();
     int64_t nums_thread_per_block = max_num_threads;
     if (total_rb_extent.is_constant()) {
       int64_t extent = static_cast<int64_t>(total_rb_extent.get_constant());
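The tail of GetTreeReduceSize is truncated in this hunk. A hedged sketch of the clamp it begins (only the max_num_threads seed is taken from the diff; the power-of-two shrink for small constant extents is an assumption):

#include <cstdint>

int64_t GetTreeReduceSizeSketch(int64_t extent, int64_t max_num_threads) {
  // Seed from the current target's thread cap, as in the hunk.
  int64_t nums_thread_per_block = max_num_threads;
  // Assumption: a small constant extent shrinks the block to the next
  // power of two that covers it.
  int64_t threads = 1;
  while (threads < extent) threads <<= 1;
  if (threads < nums_thread_per_block) nums_thread_per_block = threads;
  return nums_thread_per_block;
}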
