Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,36 @@
#include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h"
#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/op/ir_operators.h"

namespace cinn {
namespace ir {

// Fills in the schedule context shared by all tactics and registers the
// tactics by appending them to `tactics_` in the order listed below.
//
// The context receives the compilation target, the output tensor names, the
// global master node, the iteration-space info built from that master node,
// and a single hard-coded bucket.
void DynamicShapeGroupScheduler::Init() {
  // Only 1 bucket for test now.
  // The target is assigned exactly once; the former second assignment after
  // ConstructIterSpaceInfo() wrote the same value again and was redundant.
  schedule_context_.target = target_;
  schedule_context_.output_names = OutputTensorNames();
  schedule_context_.global_master = FindGlobalMasterNode();
  schedule_context_.iter_space_info =
      ConstructIterSpaceInfo(schedule_context_.global_master);
  schedule_context_.bucket_info = {/* sp_lower_bound = */ 1024,
                                   /* sp_upper_bound = */ INT_MAX,
                                   /* rb_lower_bound = */ 64,
                                   /* rb_upper_bound = */ INT_MAX};
  // NOTE(review): presumably `tactics_` holds owning smart pointers that take
  // over these raw allocations -- confirm against the class declaration.
  tactics_.emplace_back(new AlignIterSpaceTactic());
  tactics_.emplace_back(new TileTactic());
  tactics_.emplace_back(new ComputeInlineTactic());
  tactics_.emplace_back(new BindCudaTactic());
  tactics_.emplace_back(new ArrangeStorageTactic());
}

void DynamicShapeGroupScheduler::Schedule() {
// Fake schedule for test
ApplyTactics();
std::vector<Expr> all_blocks = ir_sch_->GetAllBlocks();
auto block0_loops = ir_sch_->GetLoops(all_blocks[0]);
auto splited_loops1 = ir_sch_->Split(block0_loops[0], {1024, -1});
ir_sch_->Bind(splited_loops1[0], "threadIdx.x");

// Fake bucket for test
ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024));
std::unique_ptr<ir::IRSchedule> new_ir_sch1 =
std::make_unique<ir::IRSchedule>(*ir_sch_);
Expand All @@ -55,12 +60,12 @@ void DynamicShapeGroupScheduler::ApplyTactics() {
VLOG(6) << "before applying [" << tactic->TacticName()
<< "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
<< ir_sch_->GetModule().GetExprs().front();
tactic->Init(&schedule_context_);
tactic->Apply(ir_sch_, node->id());
VLOG(6) << "after applying [" << tactic->TacticName()
<< "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
<< ir_sch_->GetModule().GetExprs().front();
};
tactic->Init(&schedule_context_);
schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc);
schedule_block_graph_->Update(*ir_sch_);
VLOG(5) << "[End " << tactic->TacticName()
Expand Down Expand Up @@ -96,6 +101,7 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo(
std::unordered_map<ir::Var, ir::Expr> iter_var2value =
analyzer::GetIterVarToValueOfSBlock(block);

// init iter info
if (!reduce_iter_vars.empty()) {
std::set<ir::Expr> reduce_loads = ir::ir_utils::CollectIRNodesWithoutTensor(
block,
Expand Down Expand Up @@ -161,6 +167,20 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo(
info.rb_last_order.push_back(i);
}
}
// init total extents
ir::Expr sp_extent = ir::Expr(1);
ir::Expr rb_extent = ir::Expr(1);
for (const auto& axis : info.sp_space) {
const ir::Expr& extent = std::get<0>(axis);
sp_extent = sp_extent * extent;
}
for (const auto& axis : info.rb_space) {
const ir::Expr& extent = std::get<0>(axis);
rb_extent = rb_extent * extent;
}
info.total_sp_extent = sp_extent;
info.total_rb_extent = rb_extent;

return info;
}

Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Project-defined helper; presumably collects this directory's headers for the
# core target -- confirm against the macro definition.
core_gather_headers()

# Hand each schedule tactic implementation in this directory to gather_srcs
# under the `cinnapi_src` name, one source file per call.
gather_srcs(cinnapi_src SRCS align_iter_space_tactic.cc)
gather_srcs(cinnapi_src SRCS tile_tactic.cc)
gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc)
gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc)
gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#pragma once

#include <string>
#include <unordered_set>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
Expand Down
58 changes: 58 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h"
#include <unordered_map>
#include "paddle/cinn/ir/ir.h"

namespace cinn {
namespace ir {

// Remembers the shared schedule context; Apply() later reads the iteration
// space info from it. The pointer is stored as-is.
void BindCudaTactic::Init(ScheduleContext* context) {
  context_ = context;
}

// Lookup table from a CUDA axis type to the axis name string that
// BindCudaTactic::Apply passes to IRSchedule::Bind. Axis types without an
// entry here (for example kSerial) are skipped by Apply and stay unbound.
const std::unordered_map<IterativeSpaceInfo::AxisType, std::string>
axis_type2bind_info = {
{IterativeSpaceInfo::AxisType::kCudaBlockX, "blockIdx.x"},
{IterativeSpaceInfo::AxisType::kCudaBlockY, "blockIdx.y"},
{IterativeSpaceInfo::AxisType::kCudaBlockZ, "blockIdx.z"},
{IterativeSpaceInfo::AxisType::kCudaThreadX, "threadIdx.x"},
{IterativeSpaceInfo::AxisType::kCudaThreadY, "threadIdx.y"},
{IterativeSpaceInfo::AxisType::kCudaThreadZ, "threadIdx.z"},
};

// Binds the loops surrounding `block_id` to CUDA axes.
//
// The loops returned by `sch->GetLoops(block_id)` are matched positionally
// against the axes in the context's `sp_space` followed by `rb_space`: the
// k-th loop corresponds to the k-th axis of that concatenation. A loop is
// bound only when its axis type has an entry in `axis_type2bind_info`; other
// axes are skipped but still consume one loop. Matching stops as soon as
// either the loops or the axes run out.
//
// @param sch      schedule to modify.
// @param block_id id of the schedule block whose loops are bound.
void BindCudaTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) {
  std::vector<ir::Expr> loops = sch->GetLoops(block_id);
  // Shared cursor over `loops`: the rb axes continue where the sp axes
  // stopped.
  auto loop_it = loops.begin();

  // Walks one iteration space, advancing the shared loop cursor by one per
  // axis and binding the loop when the axis type maps to a CUDA axis name.
  const auto bind_space =
      [&](const std::vector<
          std::tuple<ir::Expr, IterativeSpaceInfo::AxisType>>& space) {
        for (const auto& axis : space) {
          if (loop_it == loops.end()) {
            return;
          }
          // Single lookup instead of count() followed by at().
          const auto bind_it = axis_type2bind_info.find(std::get<1>(axis));
          if (bind_it != axis_type2bind_info.end()) {
            sch->Bind(*loop_it, bind_it->second);
          }
          ++loop_it;
        }
      };

  bind_space(context_->iter_space_info.sp_space);
  bind_space(context_->iter_space_info.rb_space);
}

} // namespace ir
} // namespace cinn
36 changes: 36 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
namespace ir {

// Schedule tactic that binds the loops of a schedule block to CUDA
// block/thread axes, driven by the axis types stored in the schedule context.
class BindCudaTactic final : public ScheduleTactic {
 public:
  // Stores `context` for later use by Apply(). Must be called before Apply().
  void Init(ScheduleContext* context) override;

  // Binds the loops of the block identified by `block_id` in `sch`.
  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

  // Name of this tactic.
  std::string TacticName() const override { return "BindCudaTactic"; }

 private:
  // Context pointer handed in through Init(); this class never deletes it.
  // Defaults to nullptr so that use before Init() is a detectable null
  // dereference instead of a read through an indeterminate pointer.
  ScheduleContext* context_{nullptr};
};

} // namespace ir
} // namespace cinn
12 changes: 12 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ struct IterativeSpaceInfo {
std::vector<std::tuple<ir::Expr, AxisType>> sp_space;
// reduce or broadcast iterative space
std::vector<std::tuple<ir::Expr, AxisType>> rb_space;
// total sp extent
ir::Expr total_sp_extent;
// total rb extent
ir::Expr total_rb_extent;
// original loop order with same iteration order as the memory order
std::vector<ir::Expr> memory_consistent_order_space;
// index that transform from memory consistent order to rb last order
Expand All @@ -45,11 +49,19 @@ struct IterativeSpaceInfo {
std::vector<int> rb_last_order;
};

// Inclusive-looking lower/upper bounds describing one shape bucket.
// `rb` refers to the reduce-or-broadcast iteration space; `sp` is presumably
// the remaining (spatial) iteration space -- confirm with the scheduler.
//
// The upper bounds default to INT_MAX. They were previously initialized from
// UINT_MAX, which does not fit in `int` and typically converts to -1, leaving
// the default upper bound below the default lower bound.
struct BucketInfo {
  int sp_lower_bound = 0;
  int sp_upper_bound = INT_MAX;
  int rb_lower_bound = 0;
  int rb_upper_bound = INT_MAX;
};

// State shared between the group scheduler and its tactics; every tactic
// receives a pointer to it through ScheduleTactic::Init().
struct ScheduleContext {
  // Names of output tensors.
  std::unordered_set<std::string> output_names;
  // Global master schedule-block node. Defaults to nullptr so that a context
  // that has not been populated yet never carries an indeterminate pointer.
  ScheduleBlockNode* global_master = nullptr;
  // sp / rb iteration-space description of the group.
  IterativeSpaceInfo iter_space_info;
  // Compilation target.
  Target target;
  // Bounds of the bucket being scheduled.
  BucketInfo bucket_info;
};

class ScheduleTactic {
Expand Down
84 changes: 84 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h"
#include "paddle/cinn/ir/ir.h"

namespace cinn {
namespace ir {

// Stores the context and replaces its sp/rb iteration spaces with a fixed
// tiling derived from the bucket lower bounds (placeholder strategy).
//
// If the context had a non-empty sp space it becomes three axes:
//   [sp_lower_bound / f -> blockIdx.x,
//    f -> threadIdx.y when an rb space exists, otherwise threadIdx.x,
//    -1 -> serial]
// where f is the largest proper divisor of sp_lower_bound.
// If the context had a non-empty rb space it becomes two axes:
//   [rb_lower_bound -> threadIdx.x, -1 -> serial].
//
// @param context schedule context to read from and overwrite; the pointer is
//                kept for use by Apply().
void TileTactic::Init(ScheduleContext* context) {
  context_ = context;
  // fake strategy
  // Returns the largest proper divisor of `num`, or 1 when none exists
  // (num <= 1). The previous version fell off the end without returning in
  // that case, which is undefined behaviour.
  auto GetFirstFactor = [](int num) {
    // No proper divisor can exceed num / 2, so the search starts there.
    for (int i = num / 2; i >= 1; --i) {
      if (num % i == 0) {
        return i;
      }
    }
    return 1;
  };

  // Record which spaces existed before they are cleared and rebuilt.
  bool has_rb_iter = !context_->iter_space_info.rb_space.empty();
  bool has_sp_iter = !context_->iter_space_info.sp_space.empty();
  context_->iter_space_info.rb_space.clear();
  context_->iter_space_info.sp_space.clear();

  if (has_sp_iter) {
    int sp_factor = GetFirstFactor(context_->bucket_info.sp_lower_bound);
    context_->iter_space_info.sp_space.emplace_back(
        ir::Expr(context_->bucket_info.sp_lower_bound / sp_factor),
        IterativeSpaceInfo::AxisType::kCudaBlockX);
    // threadIdx.x is given to the rb axis when there is one, so the sp thread
    // axis moves to threadIdx.y in that case.
    context_->iter_space_info.sp_space.emplace_back(
        ir::Expr(sp_factor),
        has_rb_iter ? IterativeSpaceInfo::AxisType::kCudaThreadY
                    : IterativeSpaceInfo::AxisType::kCudaThreadX);
    // Extent -1: presumably asks Split to infer the remaining extent --
    // confirm against IRSchedule::Split.
    context_->iter_space_info.sp_space.emplace_back(
        ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial);
  }

  if (has_rb_iter) {
    context_->iter_space_info.rb_space.emplace_back(
        ir::Expr(context_->bucket_info.rb_lower_bound),
        IterativeSpaceInfo::AxisType::kCudaThreadX);
    context_->iter_space_info.rb_space.emplace_back(
        ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial);
  }
}

// Splits the loops of `block_id` using the factors stored in the context.
//
// The block must have exactly one loop (sp only) or two loops (sp, then rb).
// With two loops the rb loop (index 1) is split first and the loop list is
// re-queried; afterwards the sp loop (index 0) is split.
void TileTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) {
  // Gathers the extent expression (tuple element 0) of every axis, in order.
  const auto collect_factors =
      [](const std::vector<
          std::tuple<ir::Expr, IterativeSpaceInfo::AxisType>>& space) {
        std::vector<ir::Expr> factors;
        for (const auto& axis : space) {
          factors.push_back(std::get<0>(axis));
        }
        return factors;
      };

  std::vector<ir::Expr> block_loops = sch->GetLoops(block_id);
  CHECK(block_loops.size() == 1 || block_loops.size() == 2)
      << "All loops must be unified as sp_loop or rb_loop.";

  if (block_loops.size() == 2) {
    const std::vector<ir::Expr> rb_factors =
        collect_factors(context_->iter_space_info.rb_space);
    sch->Split(block_loops[1], rb_factors);
    // Refresh the loop list after the rb split before touching the sp loop.
    block_loops = sch->GetLoops(block_id);
    VLOG(6) << "after split rb loop of " << block_id << ": "
            << sch->GetModule().GetExprs()[0];
  }

  const std::vector<ir::Expr> sp_factors =
      collect_factors(context_->iter_space_info.sp_space);
  sch->Split(block_loops[0], sp_factors);
  VLOG(6) << "after split sp loop of " << block_id << ": "
          << sch->GetModule().GetExprs()[0];
}

} // namespace ir
} // namespace cinn
36 changes: 36 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/tile_tactic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
namespace ir {

// Schedule tactic that rewrites the context's sp/rb iteration spaces into a
// tiling and splits the loops of each schedule block accordingly.
class TileTactic final : public ScheduleTactic {
 public:
  // Stores `context` and overwrites its iteration-space info with the tiling
  // factors. Must be called before Apply().
  void Init(ScheduleContext* context) override;

  // Splits the loops of the block identified by `block_id` in `sch`.
  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

  // Name of this tactic.
  std::string TacticName() const override { return "TileTactic"; }

 private:
  // Context pointer handed in through Init(); this class never deletes it.
  // Defaults to nullptr so that use before Init() is a detectable null
  // dereference instead of a read through an indeterminate pointer.
  ScheduleContext* context_{nullptr};
};

} // namespace ir
} // namespace cinn