diff --git a/paddle/cinn/hlir/pe/broadcast.cc b/paddle/cinn/hlir/pe/broadcast.cc index 9ab00fc8ce5dac..23485461496691 100644 --- a/paddle/cinn/hlir/pe/broadcast.cc +++ b/paddle/cinn/hlir/pe/broadcast.cc @@ -23,6 +23,7 @@ #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/lang/builtin.h" #include "paddle/cinn/lang/compute.h" +#include "paddle/common/errors.h" PD_DECLARE_bool(cinn_bucket_compile); namespace cinn { @@ -376,16 +377,20 @@ Tensor BroadcastTo(const Tensor& A, const std::vector& out_shape, const std::string& out_name) { auto A_shape = A->shape; - CHECK_EQ(A_shape.size(), out_shape.size()) - << "broadcast_to's out_shape's size should be same with the input " - "shape's size"; + PADDLE_ENFORCE_GE( + out_shape.size(), + A_shape.size(), + ::common::errors::InvalidArgument( + "broadcast_to's out_shape's size should be GreaterEqual " + "with the input shape's size")); return Compute( ToCinnExprs(out_shape), [=](const std::vector& indice) { std::vector broadcast_indice; - for (int idx = 0; idx < out_shape.size(); ++idx) { - ir::Expr a_shape_i = A_shape[idx]; + int out_A_offset = out_shape.size() - A_shape.size(); + for (int idx = out_A_offset; idx < out_shape.size(); ++idx) { + ir::Expr a_shape_i = A_shape[idx - out_A_offset]; if (MathEqual(a_shape_i, ir::Expr(1))) { broadcast_indice.push_back(ir::Expr(0)); } else if (MathEqual(a_shape_i, out_shape[idx])) {