diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
index a898965f1f7025..040fbb28377115 100644
--- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -287,20 +287,30 @@ std::vector<std::vector<pir::Value>> IfOp::Vjp(
 void WhileOp::Build(pir::Builder &builder,             // NOLINT
                     pir::OperationArgument &argument,  // NOLINT
                     pir::Value cond,
-                    const std::vector<pir::Value> &inputs) {
+                    const std::vector<pir::Value> &inputs,
+                    bool construct_body) {
   argument.AddInput(cond);
   argument.AddInputs(inputs);
-  auto &body = argument.AddRegion().emplace_back();
   std::vector<pir::Attribute> outs_stop_gradient;
-  for (auto val : inputs) {
-    argument.AddOutput(val.type());
-    auto arg = body.AddArgument(val.type());
-
-    auto bool_attr = val.attribute<pir::BoolAttribute>(kStopGradientAttrName);
-    arg.set_attribute(kStopGradientAttrName,
-                      bool_attr ? bool_attr : builder.bool_attr(false));
-    outs_stop_gradient.push_back(bool_attr ? bool_attr
-                                           : builder.bool_attr(false));
+  if (construct_body) {
+    auto &body = argument.AddRegion().emplace_back();
+    for (auto val : inputs) {
+      argument.AddOutput(val.type());
+      auto arg = body.AddArgument(val.type());
+      auto bool_attr = val.attribute<pir::BoolAttribute>(kStopGradientAttrName);
+      outs_stop_gradient.push_back(bool_attr ? bool_attr
+                                             : builder.bool_attr(false));
+      arg.set_attribute(kStopGradientAttrName,
+                        bool_attr ? bool_attr : builder.bool_attr(false));
+    }
+  } else {
+    argument.AddRegion(nullptr);
+    for (auto val : inputs) {
+      argument.AddOutput(val.type());
+      auto bool_attr = val.attribute<pir::BoolAttribute>(kStopGradientAttrName);
+      outs_stop_gradient.push_back(bool_attr ? bool_attr
+                                             : builder.bool_attr(false));
+    }
   }
 
   argument.AddAttribute(
@@ -343,6 +353,96 @@ void WhileOp::Print(pir::IrPrinter &printer) {
   os << "\n }";
 }
 
+void WhileOp::VerifySig() {
+  VLOG(4) << "Start Verifying inputs, outputs and attributes for: WhileOp.";
+  auto input_size = num_operands();
+  PADDLE_ENFORCE_GE(
+      input_size,
+      1u,
+      phi::errors::PreconditionNotMet(
+          "The size %d of inputs must be greater or equal to 1.", input_size));
+
+  if (auto cond_type = operand_type(0).dyn_cast<pir::DenseTensorType>()) {
+    PADDLE_ENFORCE_EQ(
+        cond_type.dtype().isa<pir::BoolType>(),
+        true,
+        phi::errors::PreconditionNotMet(
+            "Type validation failed for the 0th input, it should be a "
+            "bool DenseTensorType."));
+  } else if (auto cond_type =
+                 operand_type(0).dyn_cast<AllocatedDenseTensorType>()) {
+    PADDLE_ENFORCE_EQ(
+        cond_type.dtype().isa<pir::BoolType>(),
+        true,
+        phi::errors::PreconditionNotMet(
+            "Type validation failed for the 0th input, it should be a "
+            "bool DenseTensorType."));
+  } else {
+    PADDLE_THROW(phi::errors::PreconditionNotMet(
+        "Currently, the while op cond input only support bool dense_tensor "
+        "and bool allocated_dense_tensor."));
+  }
+  PADDLE_ENFORCE_EQ((*this)->num_regions(),
+                    1u,
+                    phi::errors::PreconditionNotMet(
+                        "The size %d of regions must be equal to 1.",
+                        (*this)->num_regions()));
+  auto output_size = num_results();
+  PADDLE_ENFORCE_EQ(output_size + 1,
+                    input_size,
+                    phi::errors::PreconditionNotMet(
+                        "The result size (%d) not equal to input size(%d) + 1.",
+                        num_results(),
+                        input_size));
+  for (size_t index = 0; index < output_size; ++index) {
+    PADDLE_ENFORCE_EQ(
+        operand_type(index + 1),
+        result_type(index),
+        phi::errors::PreconditionNotMet(
+            "The (%d) result and operand type is not equal.", index));
+  }
+}
+
+void WhileOp::VerifyRegion() {
+  VLOG(4) << "Start verifying sub regions for: WhileOp.";
+  PADDLE_ENFORCE_EQ(
+      (*this)->region(0).size(),
+      1u,
+      phi::errors::PreconditionNotMet("The size %d of body_region must be 1.",
+                                      (*this)->region(0).size()));
+  auto &body_block = body();
+  auto output_size = num_results();
+  PADDLE_ENFORCE_EQ(
+      body_block.args_size(),
+      output_size,
+      phi::errors::PreconditionNotMet(
+          "The result size (%d) not equal to block args size(%d) + 1.",
+          output_size,
+          body_block.args_size()));
+
+  PADDLE_ENFORCE_EQ(
+      body_block.empty(),
+      false,
+      phi::errors::PreconditionNotMet("The body block is empty."));
+
+  auto yield_op = body_block.back().dyn_cast<pir::YieldOp>();
+  auto input_size = num_operands();
+  PADDLE_ENFORCE_EQ(
+      yield_op && yield_op.num_operands() == input_size,
+      true,
+      phi::errors::PreconditionNotMet(
+          "The body block yield size not equal to operands size."));
+  // Todo: fix other bugs and make the following code work.
+  // for (size_t index = 0; index < input_size; ++index) {
+  //   PADDLE_ENFORCE_EQ(
+  //       operand_type(index),
+  //       yield_op.operand_type(index),
+  //       phi::errors::PreconditionNotMet(
+  //           "The (%d) operand and block yield type is not equal.", index));
+  // }
+  VLOG(4) << "Successful end verifying sub regions for: WhileOp.";
+}
+
 std::vector<std::vector<pir::Value>> WhileOp::Vjp(
     pir::Operation *op,
     const std::vector<std::vector<pir::Value>> &inputs,
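For context, a hedged sketch of the two Build modes introduced above. This is not part of the diff; the helper and its arguments are hypothetical, and only APIs visible in this patch are assumed:

```cpp
#include <vector>

#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/pir/core/builder.h"

// Hypothetical helper contrasting the two modes of WhileOp::Build.
void BuildBothModes(pir::Builder &builder,  // NOLINT
                    pir::Value cond,
                    const std::vector<pir::Value> &loop_inputs) {
  // Default mode: the body block and its stop_gradient-tagged arguments are
  // created eagerly, one argument per input.
  auto eager_op = builder.Build<paddle::dialect::WhileOp>(cond, loop_inputs);

  // Deferred mode: outputs and attributes are still created, but the single
  // region stays empty so an existing body can be moved in later, e.g. via
  // Region::swap (see PyWhileOp::OptimizeUpdate in this patch).
  auto deferred_op = builder.Build<paddle::dialect::WhileOp>(
      cond, loop_inputs, /*construct_body=*/false);
  (void)eager_op;
  (void)deferred_op;
}
```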
diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
index baffcadc127184..3c86d56d116165 100644
--- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
+++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
@@ -77,13 +77,14 @@ class WhileOp : public pir::Op<WhileOp, VjpInterface> {
   static void Build(pir::Builder &builder,             // NOLINT
                     pir::OperationArgument &argument,  // NOLINT
                     pir::Value cond,
-                    const std::vector<pir::Value> &inputs);
+                    const std::vector<pir::Value> &inputs,
+                    bool construct_body = true);
   TEST_API pir::Block &body();
   pir::Value cond();
   const pir::Block::ArgListType &block_args() { return body().args(); }
   void Print(pir::IrPrinter &printer);  // NOLINT
-  void VerifySig() {}
-  void VerifyRegion() {}
+  void VerifySig();
+  void VerifyRegion();
   static std::vector<std::vector<pir::Value>> Vjp(
       pir::Operation *op,
       const std::vector<std::vector<pir::Value>> &inputs_,
diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
index 8cd6375dbe7b64..7b5959a542e7af 100644
--- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
@@ -35,8 +35,7 @@ OperatorDialect::OperatorDialect(pir::IrContext *ctx)
   ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>();
   auto info = ctx->GetRegisteredOpInfo(pir::TuplePushOp::name());
   info.AttachInterface(std::move(
-      pir::InterfaceValue::
-          Get<pir::TuplePushOp, VjpInterface, TuplePushOpVjpInterfaceModel>()));
+      pir::InterfaceValue::Get<VjpInterface, TuplePushOpVjpInterfaceModel>()));
 }
 
 void OperatorDialect::initialize() {
diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc
index 2979d944e0bbf4..f66c3d3ccfddbb 100644
--- a/paddle/fluid/pybind/control_flow_api.cc
+++ b/paddle/fluid/pybind/control_flow_api.cc
@@ -39,6 +39,8 @@
 using paddle::dialect::AssertOp;
 using paddle::dialect::HasElementsOp;
 using paddle::dialect::IfOp;
 using paddle::dialect::WhileOp;
+using paddle::pybind::PyIfOp;
+using paddle::pybind::PyWhileOp;
 using pir::Block;
 using pir::Builder;
 using pir::Operation;
@@ -50,8 +52,6 @@
 using pir::Type;
 using pir::Value;
 using pir::YieldOp;
 using pybind11::return_value_policy;
-
-using paddle::pybind::PyIfOp;
 namespace {
 void BindIfOp(py::module* m) {
@@ -78,22 +78,24 @@ void BindIfOp(py::module* m) {
 }
 
 void BindWhileOp(py::module* m) {
-  m->def("build_while_op", [](Value cond, py::list loop_vars) {
+  m->def("build_while_op", [](Value cond, py::list loop_vars) -> PyWhileOp {
     std::vector<Value> loop_values;
     for (auto var : loop_vars) {
       loop_values.push_back(var.cast<Value>());
     }
-    return ApiBuilder::Instance().GetBuilder()->Build<WhileOp>(cond,
-                                                               loop_values);
+    return PyWhileOp(
+        ApiBuilder::Instance().GetBuilder()->Build<WhileOp>(cond, loop_values));
   });
-  py::class_<WhileOp> while_op(*m, "WhileOp", R"DOC(
+  py::class_<PyWhileOp> while_op(*m, "WhileOp", R"DOC(
     WhileOp in python api.
   )DOC");
-  while_op.def("body", &WhileOp::body, return_value_policy::reference)
-      .def("as_operation", &WhileOp::operation, return_value_policy::reference)
+  while_op.def("body", &PyWhileOp::body, return_value_policy::reference)
+      .def(
+          "as_operation", &PyWhileOp::operation, return_value_policy::reference)
       .def("block_arguments",
            &WhileOp::block_args,
-           return_value_policy::reference);
+           return_value_policy::reference)
+      .def("optimize_update", &PyWhileOp::OptimizeUpdate);
 }
 
 void BindAssertOp(py::module* m) {
@@ -227,7 +229,7 @@ PyIfOp::PyIfOp(IfOp if_op) : IfOp(if_op) {
 
 void PyIfOp::UpdateOutput() {
   PADDLE_ENFORCE_NOT_NULL(
-      *this,
+      operation_,
      paddle::platform::errors::InvalidArgument(
          "The if_op in PyIfOp used to update output can't be nullptr"));
   auto block = parent();
@@ -241,7 +243,68 @@ void PyIfOp::UpdateOutput() {
       cond(), true_region().TakeBack(), false_region().TakeBack());
   block->Assign(iter, new_if_op);
   IfOp::operator=(new_if_op);
-  VerifyRegion();
+  operation_->Verify();
+}
+
+PyWhileOp::PyWhileOp(WhileOp while_op) : WhileOp(while_op) {
+  PADDLE_ENFORCE_NOT_NULL(
+      operation_,
+      paddle::platform::errors::InvalidArgument(
+          "The while_op used to construct PyWhileOp can't be nullptr"));
+}
+
+std::vector<Value> PyWhileOp::OptimizeUpdate() {
+  PADDLE_ENFORCE_NOT_NULL(operation_,
+                          paddle::platform::errors::InvalidArgument(
+                              "The while_op in PyWhileOp used to remove unused "
+                              "loop vars can't be nullptr"));
+  auto parent_block = parent();
+  PADDLE_ENFORCE_NOT_NULL(
+      parent_block,
+      paddle::platform::errors::InvalidArgument(
+          "The parent block of while_op which used to remove "
+          "unused loop vars can't be nullptr"));
+
+  operation_->Verify();
+  auto& body_block = body();
+  auto yield_op = body_block.back().dyn_cast<YieldOp>();
+  auto operand_num = operation_->num_operands();
+  bool no_change = true;
+  std::vector<size_t> index_vec;
+  std::vector<Value> res, new_input, new_yield_val{yield_op.operand_source(0)};
+  for (uint32_t i = 0; i < num_results(); ++i) {
+    res.push_back(result(i));
+  }
+  for (size_t operand_index = 1u, arg_index = 0u; operand_index < operand_num;
+       ++operand_index) {
+    if (yield_op.operand_source(operand_index) == body_block.arg(arg_index)) {
+      body_block.arg(arg_index).ReplaceAllUsesWith(
+          operand_source(operand_index));
+      body_block.EraseArgument(arg_index);
+      no_change = false;
+      res[operand_index - 1u] = operand_source(operand_index);
+    } else {
+      new_input.push_back(operand_source(operand_index));
+      index_vec.push_back(operand_index - 1u);
+      new_yield_val.push_back(yield_op.operand_source(operand_index));
+      ++arg_index;
+    }
+  }
+  if (no_change) return res;
+  Block::Iterator iter = **this;
+  Builder builder(ir_context(), false);
+  auto new_while_op = builder.Build<WhileOp>(cond(), new_input, false);
+  new_while_op->region(0).swap(std::move(operation_->region(0)));
+  parent_block->Assign(iter, new_while_op);
+  WhileOp::operator=(new_while_op);
+  body_block.pop_back();
+  builder.SetInsertionPointToBlockEnd(&body_block);
+  builder.Build<YieldOp>(new_yield_val);
+  operation_->Verify();
+  for (size_t result_index = 0; result_index < num_results(); ++result_index) {
+    res[index_vec[result_index]] = result(result_index);
+  }
+  return res;
 }
 
 void BindControlFlowApi(py::module* m) {
diff --git a/paddle/fluid/pybind/control_flow_api.h b/paddle/fluid/pybind/control_flow_api.h
index 18905bdc096787..020904a6d999dc 100644
--- a/paddle/fluid/pybind/control_flow_api.h
+++ b/paddle/fluid/pybind/control_flow_api.h
@@ -25,6 +25,22 @@ class PyIfOp : public dialect::IfOp {
   void UpdateOutput();
 };
 
+class PyWhileOp : public dialect::WhileOp {
+ public:
+  explicit PyWhileOp(dialect::WhileOp while_op);
+
+  ///
+  /// \brief Construct a new while_op to replace the original while_op. The
+  /// input, output, and parameters of the new while_op no longer contain the
+  /// variables that have not been modified in the loop. The size of the
+  /// return value is equal to the output size of the original while_op, where
+  /// the value of the read-only loop variable is the corresponding operand of
+  /// the original while_op, and the value of the non-read-only loop variable
+  /// is the corresponding output of the new while_op.
+  ///
+  std::vector<pir::Value> OptimizeUpdate();
+};
+
 void BindControlFlowApi(pybind11::module *m);
 }  // namespace pybind
 }  // namespace paddle
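A hedged illustration of the contract documented in the comment above; the wrapper function is hypothetical and only the declarations from this header are assumed. The returned vector has the same arity as the original while_op's result list; entries for read-only loop vars are the original operands, the rest are results of the rebuilt op:

```cpp
#include <vector>

#include "paddle/fluid/pybind/control_flow_api.h"

// Hypothetical helper: shrink a while_op's carried vars in place and return
// the remapped outputs.
std::vector<pir::Value> ShrinkLoopVars(paddle::dialect::WhileOp while_op) {
  paddle::pybind::PyWhileOp py_while_op(while_op);  // throws on a null op
  return py_while_op.OptimizeUpdate();
}
```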
diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index bbd389c4886a3f..08b266364dfcb3 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -527,14 +527,8 @@ void BindOperation(py::module *m) {
       })
       .def("as_if_op",
           [](Operation &self) { return PyIfOp(self.dyn_cast<IfOp>()); })
-      .def("as_while_op", [](Operation &self) -> WhileOp {
-        auto while_op = self.dyn_cast<WhileOp>();
-        if (!while_op) {
-          PADDLE_THROW(phi::errors::InvalidArgument(
-              "Can't cast non-while type Operation to WhileOp."));
-        }
-        return while_op;
-      });
+      .def("as_while_op",
+           [](Operation &self) { return PyWhileOp(self.dyn_cast<WhileOp>()); });
   py::class_<Operation::BlockContainer> block_container(
       *m, "Operation_BlockContainer", R"DOC(
     The Operation_BlockContainer only use to walk all blocks in the operation.
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 16da7fbc021280..e0c06f0f40e0a5 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1859,6 +1859,7 @@ void IncrementInferMeta(const MetaTensor& x, float value, MetaTensor* out) {
                         product(x.dims())));
   out->set_dims(x.dims());
   out->share_lod(x);
+  out->set_layout(x.layout());
   out->set_dtype(x.dtype());
 }
 
diff --git a/paddle/pir/core/block.cc b/paddle/pir/core/block.cc
index 73902960c95ab7..49389454545d10 100644
--- a/paddle/pir/core/block.cc
+++ b/paddle/pir/core/block.cc
@@ -32,6 +32,12 @@ void Block::push_back(Operation *op) { insert(ops_.end(), op); }
 
 void Block::push_front(Operation *op) { insert(ops_.begin(), op); }
 
+void Block::pop_back() {
+  IR_ENFORCE(!ops_.empty(), "can't pop back from empty block.");
+  ops_.back()->Destroy();
+  ops_.pop_back();
+}
+
 Operation *Block::GetParentOp() const {
   return parent_ ? parent_->GetParent() : nullptr;
 }
@@ -50,8 +56,7 @@ Block::Iterator Block::erase(ConstIterator position) {
 
 void Block::clear() {
   while (!empty()) {
-    ops_.back()->Destroy();
-    ops_.pop_back();
+    pop_back();
   }
 }
 
@@ -103,6 +108,13 @@ Value Block::AddArgument(Type type) {
   return argument;
 }
 
+void Block::EraseArgument(uint32_t index) {
+  auto argument = arg(index);
+  IR_ENFORCE(argument.use_empty(),
+             "Erase a block argument that is still in use.");
+  argument.dyn_cast<BlockArgument>().Destroy();
+  arguments_.erase(arguments_.begin() + index);
+}
 bool Block::TopoOrderCheck(const OpListType &op_list) {
   std::unordered_set<Value> visited_values;
   for (Operation *op : op_list) {
diff --git a/paddle/pir/core/block.h b/paddle/pir/core/block.h
index a912676f7fb684..373f97e12c51ef 100644
--- a/paddle/pir/core/block.h
+++ b/paddle/pir/core/block.h
@@ -69,6 +69,7 @@ class IR_API Block {
   void push_back(Operation *op);
   void push_front(Operation *op);
+  void pop_back();
   Iterator insert(ConstIterator iterator, Operation *op);
   Iterator erase(ConstIterator position);
   void clear();
@@ -111,6 +112,7 @@ class IR_API Block {
   Type arg_type(uint32_t index) const { return arguments_[index].type(); }
   void ClearArguments();
   Value AddArgument(Type type);
+  void EraseArgument(uint32_t index);
   template <class TypeIter>
   void AddArguments(TypeIter first, TypeIter last);
   template <class TypeIter>
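A hedged sketch of how the two new Block helpers compose, mirroring the tail of PyWhileOp::OptimizeUpdate above. The function is hypothetical and the cf_op.h include path is an assumption:

```cpp
#include <vector>

#include "paddle/pir/core/block.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"  // pir::YieldOp (assumed path)

// Hypothetical helper: retire block argument `index` in favor of an external
// value, then rebuild the trailing yield with new inputs.
void RewriteTail(pir::Block &block,  // NOLINT
                 uint32_t index,
                 pir::Value replacement,
                 const std::vector<pir::Value> &new_yield_inputs,
                 pir::Builder &builder) {  // NOLINT
  block.arg(index).ReplaceAllUsesWith(replacement);
  block.EraseArgument(index);  // IR_ENFORCEs use_empty() first
  block.pop_back();            // destroys the old trailing yield op
  builder.SetInsertionPointToBlockEnd(&block);
  builder.Build<pir::YieldOp>(new_yield_inputs);
}
```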
diff --git a/paddle/pir/core/interface_support.h b/paddle/pir/core/interface_support.h
index f8fc83efa31720..60211a9437d7bb 100644
--- a/paddle/pir/core/interface_support.h
+++ b/paddle/pir/core/interface_support.h
@@ -39,8 +39,8 @@ class ConstructInterfacesOrTraits {
   /// Placement new interface.
   template <typename T>
   static void ConstrctInterface(InterfaceSet &interface_set) {  // NOLINT
-    InterfaceValue val = InterfaceValue::
-        Get<ConcreteT, T, typename T::template Model<ConcreteT>>();
+    InterfaceValue val =
+        InterfaceValue::Get<T, typename T::template Model<ConcreteT>>();
     auto suceess = interface_set.insert(std::move(val)).second;
     IR_ENFORCE(suceess,
               "Interface: id[%u] is already registered. inset failed",
diff --git a/paddle/pir/core/interface_value.h b/paddle/pir/core/interface_value.h
index 3115dc47a365e1..4c28e35c72ca22 100644
--- a/paddle/pir/core/interface_value.h
+++ b/paddle/pir/core/interface_value.h
@@ -22,7 +22,7 @@ namespace pir {
 
 class IR_API InterfaceValue {
  public:
-  template <typename ConcreteT, typename T, typename Model>
+  template <typename T, typename Model>
   static InterfaceValue Get();
   TypeId type_id() const { return type_id_; }
   void *model() const { return model_; }
@@ -52,7 +52,7 @@ class IR_API InterfaceValue {
   void *model_{nullptr};
 };
 
-template <typename ConcreteT, typename T, typename Model>
+template <typename T, typename Model>
 InterfaceValue InterfaceValue::Get() {
   InterfaceValue val;
   val.type_id_ = TypeId::get<T>();
diff --git a/paddle/pir/core/region.cc b/paddle/pir/core/region.cc
index 66e2e9d407f755..21a09198f1d791 100644
--- a/paddle/pir/core/region.cc
+++ b/paddle/pir/core/region.cc
@@ -70,6 +70,16 @@ void Region::clear() {
   }
 }
 
+void Region::swap(Region &&other) {
+  blocks_.swap(other.blocks_);
+  for (auto iter = begin(); iter != end(); ++iter) {
+    iter->SetParent(this, iter);
+  }
+  for (auto iter = other.begin(); iter != other.end(); ++iter) {
+    iter->SetParent(&other, iter);
+  }
+}
+
 template <typename FuncT>
 void Region::Walk(FuncT &&callback) {
   for (auto &block : *this) {
diff --git a/paddle/pir/core/region.h b/paddle/pir/core/region.h
index 9a4675990c8156..c8d4daadaa74ca 100644
--- a/paddle/pir/core/region.h
+++ b/paddle/pir/core/region.h
@@ -55,7 +55,6 @@ class IR_API Region {
   Block &front() { return *blocks_.front(); }
   Block &back() { return *blocks_.back(); }
-  const Block &front() const { return *blocks_.front(); }
   const Block &back() const { return *blocks_.back(); }
@@ -65,6 +64,7 @@ class IR_API Region {
   Iterator insert(ConstIterator position, Block *block);
   Iterator erase(ConstIterator position);
   void clear();
+  void swap(Region &&other);
 
   /// Operation Walkers, walk the operations in this region. The callback
   /// method is called for each nested region, block or operation,
@@ -77,7 +77,6 @@ class IR_API Region {
   void TakeBody(Region &&other);
 
   Operation *GetParent() const { return parent_; }
-  void set_parent(Operation *parent) { parent_ = parent; }
   // return the program which contains this region.
   // if region is not in a program, return nullptr.
   Program *parent_program() const;
@@ -85,7 +84,7 @@ class IR_API Region {
   IrContext *ir_context() const;
 
  private:
-  Operation *parent_{nullptr};  // not owned
-  std::list<Block *> blocks_;   // owned
+  Operation *const parent_{nullptr};  // not owned
+  std::list<Block *> blocks_;         // owned
 };
 }  // namespace pir
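Region::swap is what makes the `construct_body = false` Build mode above usable; a hedged sketch (hypothetical helper, assuming both operations own at least one region):

```cpp
#include <utility>

#include "paddle/pir/core/operation.h"
#include "paddle/pir/core/region.h"

// Hypothetical helper: move old_op's body into new_op. swap() exchanges the
// block lists and re-parents every block on both sides, so the moved blocks
// afterwards report new_op as their parent.
void MoveBody(pir::Operation *old_op, pir::Operation *new_op) {
  new_op->region(0).swap(std::move(old_op->region(0)));
}
```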
diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py
index 5ba3a14469d8ec..3d2f9858a1feb5 100644
--- a/python/paddle/static/nn/control_flow.py
+++ b/python/paddle/static/nn/control_flow.py
@@ -687,21 +687,23 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
     if in_pir_mode():
         while_op = build_while_op(pre_cond, flatten(loop_vars))
         with while_op.body() as cur_block:
-            args = cur_block.args()
-            next_var = body(*args)
+            args = pack_sequence_as(loop_vars, cur_block.args())
+            next_vars = body(*args)
             try:
                 assert_same_structure(
-                    flatten(next_var), flatten(loop_vars), check_types=False
+                    flatten(next_vars), flatten(loop_vars), check_types=False
                 )
             except ValueError as e:
                 raise ValueError(
                     "body in while_loop should return the same arity "
                     f"(length and structure) as loop_vars: {e}"
                 )
-            next_cond = cond(*next_var)
+            if not isinstance(next_vars, (list, tuple)):
+                next_vars = [next_vars]
+            next_cond = cond(*next_vars)
             next_cond.stop_gradient = True
-            cf_yield([next_cond, *next_var])
-        return while_op.as_operation().results()
+            cf_yield([next_cond, *flatten(next_vars)])
+        return pack_sequence_as(loop_vars, while_op.optimize_update())
     if in_dygraph_mode():
         now_cond = pre_cond.item()
diff --git a/test/ir/pir/test_ir_pybind.py b/test/ir/pir/test_ir_pybind.py
index fda8236020b4df..9ae4a3ebbf633e 100644
--- a/test/ir/pir/test_ir_pybind.py
+++ b/test/ir/pir/test_ir_pybind.py
@@ -42,7 +42,6 @@ def get_ir_program():
 class TestPybind(unittest.TestCase):
     def test_program(self):
         pir_program = get_ir_program()
-        print(pir_program)
         block = pir_program.global_block()
         program = block.program
@@ -152,7 +151,6 @@ def test_type(self):
         pir_program = get_ir_program()
         matmul_op = pir_program.global_block().ops[1]
         add_op = pir_program.global_block().ops[2]
-        print(matmul_op.result(0).type())
         self.assertEqual(
             matmul_op.result(0).type() == add_op.result(0).type(), True
         )
@@ -184,7 +182,6 @@ def test_attr(self):
         )
 
         pir_program = pir.translate_to_pir(main_program.desc)
-        print(pir_program)
         conv_attr = pir_program.global_block().ops[3].attrs()
         full_attr = pir_program.global_block().ops[8].attrs()
         self.assertEqual(conv_attr["stop_gradient"], [False])
diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py
index cc07cdbb58ad66..1a5ee3186d692a 100644
--- a/test/ir/pir/test_while_api.py
+++ b/test/ir/pir/test_while_api.py
@@ -57,7 +57,7 @@ def test_while_base(self):
         out = last_op.results()
         self.assertEqual(out[0].stop_gradient, False)
         self.assertEqual(last_op.name(), "pd_op.while")
-        self.assertEqual(len(out), 2)
+        self.assertEqual(len(out), 1)
 
     def test_get_used_external_value(self):
         main_program = paddle.static.Program()
@@ -177,20 +177,20 @@ def test_backward(self):
         )
         self.assertEqual(
             main_program.global_block()
-            .ops[-1]
+            .ops[-3]
             .as_while_op()
             .body()
-            .ops[-2]
+            .ops[-4]
             .name(),
             "cf.has_elements",
         )
 
         self.assertEqual(
             main_program.global_block()
-            .ops[-1]
+            .ops[-3]
             .as_while_op()
             .body()
-            .ops[-3]
+            .ops[-5]
             .name(),
             "pd_op.add_grad",
         )
diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py
index 534d5fa42e7e37..6a8adb425a775c 100644
--- a/test/legacy_test/test_while_loop_op.py
+++ b/test/legacy_test/test_while_loop_op.py
@@ -22,7 +22,6 @@
 from paddle import base
 from paddle.base import core
 from paddle.base.backward import append_backward
-from paddle.base.framework import program_guard
 from paddle.pir_utils import test_with_pir_api
 
 paddle.enable_static()
@@ -98,6 +97,7 @@ def body(i, mem):
         np.testing.assert_allclose(np.asarray(res[1]), data, rtol=1e-05)
 
     @compare_legacy_with_pt
+    @test_with_pir_api
     def test_var_dict(self):
         def cond(i, ten, test_dict, test_list, test_list_dict):
             return paddle.less_than(i, ten)
@@ -118,7 +118,7 @@ def body(i, ten, test_dict, test_list, test_list_dict):
 
         main_program = paddle.static.Program()
         startup_program = paddle.static.Program()
-        with program_guard(main_program, startup_program):
+        with paddle.static.program_guard(main_program, startup_program):
             i = paddle.zeros(shape=[1], dtype='int64')
             ten = paddle.tensor.fill_constant(
                 shape=[1], dtype='int64', value=10
@@ -130,7 +130,7 @@ def body(i, ten, test_dict, test_list, test_list_dict):
             test_dict = {"test_key": test_data}
             test_list = [
                 paddle.tensor.fill_constant(
-                    shape=[1, 2], dtype='int64', value=0
+                    shape=[2, 1], dtype='int64', value=0
                 )
             ]
             test_list_dict = [