13 changes: 0 additions & 13 deletions paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
@@ -28,15 +28,6 @@ bool IsInnerThreadSpatialLoopGT(const ScheduleConfig& config, int num) {
return config.tile_config.spatial_inner_num > num;
}

bool IsPerThreadReduceGELoopExtent(const ScheduleConfig& config,
const ir::Expr& loop) {
if (loop.As<ir::For>()->extent.is_constant()) {
int extent = ir::GetLoopExtent(loop);
return extent <= config.tile_config.tree_reduce_num;
}
return false;
}

bool IsReduceBlock(const ScheduleConfig& config, const std::string& block_id) {
return config.base_info->reduce_tensor_names.count(block_id) > 0;
}
@@ -184,10 +175,6 @@ void TileFirstGeneralTactic::SplitReduceInner(ir::IRSchedule* sch,
auto loops = sch->GetLoops(block_id);
auto reduce_loop = loops[reduce_current_axis_].As<ir::For>();

if (IsPerThreadReduceGELoopExtent(context_->config, reduce_loop)) {
return;
}

if (FLAGS_support_reduce_stride_read) {
if (context_->config.base_info->reduce_numel <= 256) {
std::vector<int> split_factors{
85 changes: 46 additions & 39 deletions test/ir/pir/cinn/test_cinn_sub_graph.py
@@ -158,53 +158,60 @@ def check_jit_kernel_info(self, static_fn):
# np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)


# class TestCinnSoftmax(TestCinnSubGraphBase):
# def train(self, use_cinn):
# paddle.seed(2022)
# net = CINNSoftmaxSubGraphNet()
# net = utils.apply_to_static(net, use_cinn)
# out = net(self.x, self.axis)

# loss = out.sum()
# loss.backward()
# print(self.x.gradient())
# return out, self.x.gradient()

# def test_forward(self):
# cinn_out, cinn_grad = self.train(use_cinn=True)
# dy_out, dy_grad = self.train(use_cinn=False)
# np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)
# np.testing.assert_allclose(cinn_grad, dy_grad, atol=1e-8)


class TestCinnLayerNorm(TestCinnSubGraphBase):
class TestCinnSoftmax(TestCinnSubGraphBase):
def train(self, use_cinn):
paddle.seed(2022)
self.prepare_data()
net = CINNLayerNormSubGraphNet(self.shape[-1])
net = CINNSoftmaxSubGraphNet()
net = utils.apply_to_static(net, use_cinn)
# net.eval()
weight = paddle.ones(shape=[self.shape[-1]], dtype="float64")
weight.stop_gradient = False
bias = paddle.ones(shape=[self.shape[-1]], dtype="float64")
bias.stop_gradient = False
self.x.stop_gradient = False
out = net(self.x, weight, bias)
out = net(self.x, self.axis)

loss = out.sum()
loss.backward()
return out, self.x.gradient()

return out, self.x.gradient(), weight.gradient(), bias.gradient()
def test_forward(self):
cinn_out, cinn_grad = self.train(use_cinn=True)
dy_out, dy_grad = self.train(use_cinn=False)
np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)
np.testing.assert_allclose(cinn_grad, dy_grad, atol=1e-8)

def test_train(self):
cinn_out, cinn_x_grad, cinn_w_grad, cinn_b_grad = self.train(
use_cinn=True
)

dy_out, dy_x_grad, dy_w_grad, dy_b_grad = self.train(use_cinn=False)
np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)
np.testing.assert_allclose(cinn_x_grad, dy_x_grad, atol=1e-8)
np.testing.assert_allclose(cinn_w_grad, dy_w_grad, atol=1e-8)
np.testing.assert_allclose(cinn_b_grad, dy_b_grad, atol=1e-8)
class TestCinnSmallSoftmax(TestCinnSoftmax):
def prepare_data(self):
self.shape = [1, 1, 17, 17]
self.axis = -1
self.x = paddle.uniform(self.shape, dtype="float64", min=-0.5, max=0.5)
self.x.stop_gradient = False


# class TestCinnLayerNorm(TestCinnSubGraphBase):
# def train(self, use_cinn):
# paddle.seed(2022)
# self.prepare_data()
# net = CINNLayerNormSubGraphNet(self.shape[-1])
# net = utils.apply_to_static(net, use_cinn)
# # net.eval()
# weight = paddle.ones(shape=[self.shape[-1]], dtype="float64")
# weight.stop_gradient = False
# bias = paddle.ones(shape=[self.shape[-1]], dtype="float64")
# bias.stop_gradient = False
# self.x.stop_gradient = False
# out = net(self.x, weight, bias)
# loss = out.sum()
# loss.backward()

# return out, self.x.gradient(), weight.gradient(), bias.gradient()

# def test_train(self):
# cinn_out, cinn_x_grad, cinn_w_grad, cinn_b_grad = self.train(
# use_cinn=True
# )

# dy_out, dy_x_grad, dy_w_grad, dy_b_grad = self.train(use_cinn=False)
# np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)
# np.testing.assert_allclose(cinn_x_grad, dy_x_grad, atol=1e-8)
# np.testing.assert_allclose(cinn_w_grad, dy_w_grad, atol=1e-8)
# np.testing.assert_allclose(cinn_b_grad, dy_b_grad, atol=1e-8)


# class TestAddDropoutLayerNorm(TestCinnSubGraphBase):
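Note: below is a minimal, hypothetical driver for exercising just the softmax cases touched by this PR. It assumes a CINN-enabled PaddlePaddle build and that it is run from test/ir/pir/cinn next to test_cinn_sub_graph.py; both assumptions are inferred from the file path in the diff, not stated in the PR.

# Hypothetical local driver (not part of the PR): runs the re-enabled
# TestCinnSoftmax plus the new small-shape TestCinnSmallSoftmax in isolation.
import unittest

from test_cinn_sub_graph import TestCinnSmallSoftmax, TestCinnSoftmax

if __name__ == "__main__":
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    # Collect only the softmax subgraph cases changed in this PR.
    suite.addTests(loader.loadTestsFromTestCase(TestCinnSoftmax))
    suite.addTests(loader.loadTestsFromTestCase(TestCinnSmallSoftmax))
    unittest.TextTestRunner(verbosity=2).run(suite)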