diff --git a/paddle2onnx/op_mapper/__init__.py b/paddle2onnx/op_mapper/__init__.py
index 5c219ec94..fdad53f29 100644
--- a/paddle2onnx/op_mapper/__init__.py
+++ b/paddle2onnx/op_mapper/__init__.py
@@ -36,4 +36,3 @@ from .custom_paddle_op import collect_fpn_proposals
 from .custom_paddle_op import distribute_fpn_proposals
 from .custom_paddle_op import box_clip
-from .custom_paddle_op import fill_constant_batch_size_like
diff --git a/paddle2onnx/op_mapper/activation.py b/paddle2onnx/op_mapper/activation.py
index 6f590e5db..65c369d48 100644
--- a/paddle2onnx/op_mapper/activation.py
+++ b/paddle2onnx/op_mapper/activation.py
@@ -52,6 +52,25 @@ def opset_1(cls, graph, node, **kw):
         alpha=node.attr('alpha'))
 
 
+@op_mapper('softplus')
+class Softplus():
+    support_opset_verison_range = (1, 12)
+
+    @classmethod
+    def opset_1(cls, graph, node, **kw):
+        beta = node.attr('beta')
+        threshold = node.attr('threshold')
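+        # ONNX Softplus has no beta/threshold attributes, so the mapping is
+        # only exact for Paddle's default beta=1.0 and threshold=20.0.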
+        if np.isclose(beta, 1.0, 1e-06, 1e-06) and \
+                np.isclose(threshold, 20.0, 1e-06, 1e-06):
+            onnx_node = graph.make_node(
+                'Softplus',
+                inputs=[node.input('X')[0]],
+                outputs=node.output('Out'))
+        else:
+            raise Exception("[ERROR] Operator softplus is only supported " \
+                "when beta == 1.0 and threshold == 20.0")
+
+
 @op_mapper('prelu')
 class PRelu():
     support_opset_verison_range = (9, 13)
diff --git a/paddle2onnx/op_mapper/custom_paddle_op/fill_constant_batch_size_like.py b/paddle2onnx/op_mapper/custom_paddle_op/fill_constant_batch_size_like.py
deleted file mode 100644
index eaa5eab09..000000000
--- a/paddle2onnx/op_mapper/custom_paddle_op/fill_constant_batch_size_like.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-
-import numpy as np
-import paddle
-from paddle.fluid import layers
-from paddle2onnx.constant import dtypes
-from paddle2onnx.op_mapper import CustomPaddleOp, register_custom_paddle_op
-
-
-class FillConstantBatchSizeLike(CustomPaddleOp):
-    def __init__(self, node, **kw):
-        super(FillConstantBatchSizeLike, self).__init__(node)
-
-    def forward(self):
-        input = self.input('Input', 0)
-        input_shape = paddle.shape(input)
-        updates = input_shape[self.node.attr('input_dim_idx')]
-        shape = paddle.assign(np.array(self.node.attr('shape')).astype('int32'))
-        dims = len(self.node.attr('shape'))
-        new_shape = paddle.concat([
-            shape[:self.node.attr('output_dim_idx')], updates,
-            shape[self.node.attr('output_dim_idx') + 1:dims]
-        ])
-        dtype = dtypes.DTYPE_PADDLE_STR_MAP[self.node.attr('dtype')]
-        out = paddle.full(new_shape, self.node.attr('value'), dtype)
-        return {'Out': [out]}
-
-
-register_custom_paddle_op('fill_constant_batch_size_like',
-                          FillConstantBatchSizeLike)
diff --git a/paddle2onnx/op_mapper/detection/multiclass_nms.py b/paddle2onnx/op_mapper/detection/multiclass_nms.py
index ea6fad3d4..86e4cc260 100644
--- a/paddle2onnx/op_mapper/detection/multiclass_nms.py
+++ b/paddle2onnx/op_mapper/detection/multiclass_nms.py
@@ -77,7 +77,7 @@ def nms(cls, graph, node, scores, bboxes, class_id=None):
             iou_threshold = 0.5
             logging.warning(
                 "Operator:{} is not supported completely, so we use traditional"
-                " NMS (iou_theshold={}) to instead it, which introduce some difference.".
+                " NMS (nms_threshold={}) instead, which may introduce some difference.".
                 format(node.type, str(iou_threshold)))
         else:
             iou_threshold = node.attr('nms_threshold')
@@ -162,14 +162,12 @@ def keep_top_k(cls,
         if background == 0:
             nonzero = graph.make_node('NonZero', inputs=[squeezed_class_id])
         else:
-            thresh = graph.make_node(
-                'Constant', inputs=[], dtype=dtypes.ONNX.INT32, value=[-1])
-
-            cast = graph.make_node('Cast', inputs=[squeezed_class_id], to=6)
-
-            greater = graph.make_node('Greater', inputs=[cast, thresh])
-
-            nonzero = graph.make_node('NonZero', inputs=[greater])
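+            # Subtracting the background class id zeroes out background
+            # entries, so NonZero keeps only the non-background indices.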
+            filter_cls_id = graph.make_node(
+                'Constant', dtype=dtypes.ONNX.INT32, value=[background])
+            cast = graph.make_node(
+                'Cast', inputs=[squeezed_class_id], to=dtypes.ONNX.INT32)
+            filter_index = graph.make_node('Sub', inputs=[cast, filter_cls_id])
+            nonzero = graph.make_node('NonZero', inputs=[filter_index])
 
         class_id = graph.make_node('Gather', inputs=[class_id, nonzero], axis=0)
@@ -295,7 +293,9 @@ def keep_top_k(cls,
             axes=[0])
         if node.type in ['matrix_nms', 'multiclass_nms3']:
             select_bboxes_shape = graph.make_node(
-                'Shape', inputs=[final_indices])
+                'Shape', inputs=[concat_final_results])
+            select_bboxes_shape1 = graph.make_node(
+                'Cast', inputs=[select_bboxes_shape], to=dtypes.ONNX.INT32)
             indices = graph.make_node(
                 'Constant', dtype=dtypes.ONNX.INT64, value=[0])
             rois_num = None
@@ -306,5 +306,5 @@ def keep_top_k(cls,
         if rois_num is not None:
             graph.make_node(
                 "Gather",
-                inputs=[select_bboxes_shape, indices],
+                inputs=[select_bboxes_shape1, indices],
                 outputs=rois_num)
diff --git a/paddle2onnx/op_mapper/logic.py b/paddle2onnx/op_mapper/logic.py
index b86898b16..25a0d9a6a 100644
--- a/paddle2onnx/op_mapper/logic.py
+++ b/paddle2onnx/op_mapper/logic.py
@@ -31,6 +31,30 @@ def opset_12(cls, graph, node, **kw):
             outputs=node.output('Out'))
 
 
+@op_mapper('equal')
+class Equal():
+    support_opset_verison_range = (12, )
+
+    @classmethod
+    def opset_1(cls, graph, node, **kw):
+        onnx_node = graph.make_node(
+            'Equal',
+            inputs=[node.input('X', 0), node.input('Y', 0)],
+            outputs=node.output('Out'))
+
+
+@op_mapper('greater_than')
+class GreaterThan():
+    support_opset_verison_range = (1, )
+
+    @classmethod
+    def opset_1(cls, graph, node, **kw):
+        onnx_node = graph.make_node(
+            'Greater',
+            inputs=[node.input('X', 0), node.input('Y', 0)],
+            outputs=node.output('Out'))
+
+
 @op_mapper('logical_and')
 class LogicalAnd():
     support_opset_verison_range = (1, )
diff --git a/paddle2onnx/op_mapper/math.py b/paddle2onnx/op_mapper/math.py
index b3123f520..3693aa267 100644
--- a/paddle2onnx/op_mapper/math.py
+++ b/paddle2onnx/op_mapper/math.py
@@ -88,9 +88,14 @@ def opset_1(cls, graph, node, **kw):
 
 @op_mapper(
     [
-        'elementwise_add', 'elementwise_sub', 'elementwise_div',
-        'elementwise_mul', 'elementwise_min', 'elementwise_max',
-        'elementwise_pow'
+        'elementwise_add',
+        'elementwise_sub',
+        'elementwise_div',
+        'elementwise_mul',
+        'elementwise_min',
+        'elementwise_max',
+        'elementwise_pow',
+        'elementwise_mod',
     ],
     mapper_dict={
         'elementwise_add': 'Add',
@@ -100,12 +105,13 @@ def opset_1(cls, graph, node, **kw):
         'elementwise_min': 'Min',
         'elementwise_max': 'Max',
         'elementwise_pow': 'Pow',
+        'elementwise_mod': 'Mod',
     })
 class ElementwiseOps():
     support_opset_version_range = (7, 12)
 
     @classmethod
-    def opset_7(cls, graph, node, **kw):
+    def opset_9(cls, graph, node, **kw):
         op_type = kw['mapper_dict'][node.type]
         axis = node.attr('axis')
         x = node.input('X', 0)
@@ -130,6 +136,51 @@ def opset_7(cls, graph, node, **kw):
             op_type, inputs=[x, y_node], outputs=node.output('Out'))
 
 
+@op_mapper('elementwise_floordiv')
+class ElementWiseFloorDiv():
+    support_opset_version_range = (11, 12)
+
+    @classmethod
+    def opset_7(cls, graph, node, **kw):
+        x = node.input('X', 0)
+        y = node.input('Y', 0)
+        axis = node.attr('axis')
+        x_shape = node.input_shape('X', 0)
+        y_shape = node.input_shape('Y', 0)
+        x_dtype = node.input_dtype('X', 0)
+        y_dtype = node.input_dtype('Y', 0)
+        x_dtype = dtypes.DTYPE_PADDLE_STR_MAP[x_dtype]
+        y_dtype = dtypes.DTYPE_PADDLE_STR_MAP[y_dtype]
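+        # ONNX Div performs integer division when both inputs are integers,
+        # so Floor is only appended for floating-point operands.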
+        is_int = False
+        if x_dtype.count('int') > 0 and y_dtype.count('int') > 0:
+            is_int = True
+        if axis == -1 or axis == (len(x_shape) - 1
+                                  ) or len(x_shape) == len(y_shape):
+            if is_int:
+                graph.make_node(
+                    'Div', inputs=[x, y], outputs=node.output('Out'))
+            else:
+                div_node = graph.make_node('Div', inputs=[x, y])
+                graph.make_node(
+                    'Floor', inputs=[div_node], outputs=node.output('Out'))
+        else:
+            broadcast_shape = [1] * len(x_shape)
+            broadcast_shape[axis:axis + len(y_shape)] = y_shape
+            broadcast_shape_node = graph.make_node(
+                'Constant',
+                dtype=dtypes.ONNX.INT64,
+                value=list(broadcast_shape))
+            y_node = graph.make_node(
+                'Reshape', inputs=[y, broadcast_shape_node])
+            if is_int:
+                div_node = graph.make_node(
+                    'Div', inputs=[x, y_node], outputs=node.output('Out'))
+            else:
+                div_node = graph.make_node('Div', inputs=[x, y_node])
+                graph.make_node(
+                    'Floor', inputs=[div_node], outputs=node.output('Out'))
+
+
 @op_mapper('pow')
 class Pow():
     support_opset_version_range = (8, 12)
@@ -233,6 +284,65 @@ def opset_1(cls, graph, node, **kw):
             'MatMul', inputs=[x, y], outputs=node.output('Out'))
 
 
+@op_mapper('p_norm')
+class PNorm():
+    support_opset_version_range = (1, 12)
+
+    @classmethod
+    def opset_1(cls, graph, node, **kw):
+        x = node.input('X', 0)
+        axis = node.attr('axis')
+        p = node.attr('porder')
+        keepdim = node.attr('keepdim')
+        epsilon = node.attr('epsilon')
+        assert axis == 1, "Only axis == 1 is supported for p_norm"
+        if (p == 1 or p == 2) and not keepdim:
+            graph.make_node(
+                'LpNormalization',
+                inputs=[x],
+                outputs=node.output('Out'),
+                axis=1,
+                p=p)
+        else:
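+            # General case: (sum_i x_i^p)^(1/p) along axis 1, built from
+            # Pow(p), ReduceSum and Pow(1/p).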
+            pnode = graph.make_node(
+                'Constant', dtype=dtypes.ONNX.FLOAT, value=[p])
+            mul = graph.make_node('Pow', inputs=[x, pnode])
+            reduce_sum = graph.make_node(
+                'ReduceSum', inputs=[mul], axes=[1], keepdims=keepdim)
+            pnode1 = graph.make_node(
+                'Constant', dtype=dtypes.ONNX.FLOAT, value=[1.0 / p])
+            graph.make_node(
+                'Pow', inputs=[reduce_sum, pnode1], outputs=node.output('Out'))
+
+    @classmethod
+    def opset_13(cls, graph, node, **kw):
+        x = node.input('X', 0)
+        axis = node.attr('axis')
+        p = node.attr('porder')
+        keepdim = node.attr('keepdim')
+        epsilon = node.attr('epsilon')
+        assert axis == 1, "Only axis == 1 is supported for p_norm"
+        if (p == 1 or p == 2) and not keepdim:
+            graph.make_node(
+                'LpNormalization',
+                inputs=[x],
+                outputs=node.output('Out'),
+                axis=1,
+                p=p)
+        else:
+            pnode = graph.make_node(
+                'Constant', dtype=dtypes.ONNX.FLOAT, value=[p])
+            mul = graph.make_node('Pow', inputs=[x, pnode])
+            axes = graph.make_node(
+                'Constant', dtype=dtypes.ONNX.INT64, value=[1])
+            reduce_sum = graph.make_node(
+                'ReduceSum', inputs=[mul, axes], keepdims=keepdim)
+            pnode1 = graph.make_node(
+                'Constant', dtype=dtypes.ONNX.FLOAT, value=[1.0 / p])
+            graph.make_node(
+                'Pow', inputs=[reduce_sum, pnode1], outputs=node.output('Out'))
+
+
 @op_mapper('sum')
 class Sum():
     support_opset_version_range = (1, 12)
diff --git a/paddle2onnx/op_mapper/nn.py b/paddle2onnx/op_mapper/nn.py
index 0269ed13f..b2604d016 100644
--- a/paddle2onnx/op_mapper/nn.py
+++ b/paddle2onnx/op_mapper/nn.py
@@ -57,7 +57,7 @@ def opset_1(cls, graph, node, **kw):
         attrs=attrs)
 
 
-@op_mapper('conv2d_transpose')
+@op_mapper(['conv2d_transpose', 'depthwise_conv2d_transpose'])
 class ConvTranspose():
     support_opset_verison_range = (1, 12)
 
@@ -71,7 +71,7 @@ def opset_1(cls, graph, node, **kw):
             dilations=node.attr('dilations'),
             kernel_shape=kernel_shape[-2:],
             strides=node.attr('strides'),
-            group=1,
+            group=node.attr('groups'),
             pads=node.attr('paddings') + node.attr('paddings'))
 
 
@@ -291,6 +291,79 @@ def opset_1(cls, graph, node, **kw):
             **onnx_attr)
 
 
+@op_mapper('group_norm')
+class GroupNorm():
+    support_opset_verison_range = (1, 12)
+
+    @classmethod
+    def opset_13(cls, graph, node, **kw):
+        num_groups = node.attr('groups')
+        epsilon = node.attr('epsilon')
+        ipt = node.input('X')[0]
+
+        ipt_shape = node.input_shape('X', 0)
+        assert len(
+            ipt_shape) == 4, "Only support 4D-Tensor as input for GroupNorm"
+
+        scale = node.input('Scale')[0]
+        bias = node.input('Bias')[0]
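+        # Reshape to [N, num_groups, -1] (0 copies the batch dim), then use
+        # InstanceNormalization with unit scale and zero bias to normalize
+        # each group; the per-channel Scale/Bias are applied afterwards.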
+        shape = graph.make_node(
+            'Constant', dtype=dtypes.ONNX.INT64, value=[0, num_groups, -1])
+        reshape_input = graph.make_node('Reshape', inputs=[ipt, shape])
+        scale_ = graph.make_node(
+            'Constant', dtype=dtypes.ONNX.FLOAT, value=[1.0] * num_groups)
+        bias_ = graph.make_node(
+            'Constant', dtype=dtypes.ONNX.FLOAT, value=[0.0] * num_groups)
+        reshaped_output = graph.make_node(
+            'InstanceNormalization',
+            inputs=[reshape_input, scale_, bias_],
+            epsilon=epsilon)
+        origin_shape = graph.make_node('Shape', inputs=[ipt])
+        output = graph.make_node(
+            'Reshape', inputs=[reshaped_output, origin_shape])
+        axes = graph.make_node(
+            'Constant', dtype=dtypes.ONNX.INT64, value=[1, 2])
+        unsqueezed_scale = graph.make_node('Unsqueeze', inputs=[scale, axes])
+        unsqueezed_bias = graph.make_node('Unsqueeze', inputs=[bias, axes])
+        part0 = graph.make_node('Mul', inputs=[output, unsqueezed_scale])
+        graph.make_node(
+            'Add', inputs=[part0, unsqueezed_bias], outputs=node.output('Y'))
+
+    @classmethod
+    def opset_11(cls, graph, node, **kw):
+        num_groups = node.attr('groups')
+        epsilon = node.attr('epsilon')
+        ipt = node.input('X')[0]
+
+        ipt_shape = node.input_shape('X', 0)
+        assert len(
+            ipt_shape) == 4, "Only support 4D-Tensor as input for GroupNorm"
+
+        scale = node.input('Scale')[0]
+        bias = node.input('Bias')[0]
+        shape = graph.make_node(
+            'Constant', dtype=dtypes.ONNX.INT64, value=[0, num_groups, -1])
+        reshape_input = graph.make_node('Reshape', inputs=[ipt, shape])
+        scale_ = graph.make_node(
+            'Constant', dtype=dtypes.ONNX.FLOAT, value=[1.0] * num_groups)
+        bias_ = graph.make_node(
+            'Constant', dtype=dtypes.ONNX.FLOAT, value=[0.0] * num_groups)
+        reshaped_output = graph.make_node(
+            'InstanceNormalization',
+            inputs=[reshape_input, scale_, bias_],
+            epsilon=epsilon)
+        origin_shape = graph.make_node('Shape', inputs=[ipt])
+        output = graph.make_node(
+            'Reshape', inputs=[reshaped_output, origin_shape])
+        unsqueezed_scale = graph.make_node(
+            'Unsqueeze', inputs=[scale], axes=[1, 2])
+        unsqueezed_bias = graph.make_node(
+            'Unsqueeze', inputs=[bias], axes=[1, 2])
+        part0 = graph.make_node('Mul', inputs=[output, unsqueezed_scale])
+        graph.make_node(
+            'Add', inputs=[part0, unsqueezed_bias], outputs=node.output('Y'))
+
+
 @op_mapper('instance_norm')
 class InstanceNorm():
     support_opset_verison_range = (1, 12)
diff --git a/paddle2onnx/op_mapper/search.py b/paddle2onnx/op_mapper/search.py
index dc3c92998..15d8ed03b 100644
--- a/paddle2onnx/op_mapper/search.py
+++ b/paddle2onnx/op_mapper/search.py
@@ -57,7 +57,7 @@ def opset_11(cls, graph, node, **kw):
         k = node.attr('k')
         k_node = graph.make_node(
             'Constant', attrs={'dtype': dtypes.ONNX.INT64,
-                               'value': k})
+                               'value': [k]})
         graph.make_node(
             'TopK',
             inputs=[node.input('X', 0), k_node],
diff --git a/paddle2onnx/op_mapper/tensor.py b/paddle2onnx/op_mapper/tensor.py
index 6545a7780..6db5cc8f5 100644
--- a/paddle2onnx/op_mapper/tensor.py
+++ b/paddle2onnx/op_mapper/tensor.py
@@ -111,17 +111,25 @@ class ExpandV2():
 
     @classmethod
     def opset_8(cls, graph, node, **kw):
-        if len(node.input('Shape')) > 0:
-            shape = mapper_helper.cast(graph,
-                                       node.input('Shape', 0),
-                                       node.input_dtype('Shape', 0), 'int64')
-        elif len(node.attr('shape')) > 0:
-            shape = node.attr('shape')
-            for idx in range(len(shape)):
-                if shape[idx] == -1:
-                    shape[idx] = 1
+        if node.input('expand_shapes_tensor') is not None and len(
+                node.input('expand_shapes_tensor')) > 0:
             shape = graph.make_node(
-                'Constant', dtype=dtypes.ONNX.INT64, value=shape)
+                'Concat', inputs=node.input('expand_shapes_tensor'), axis=-1)
+            shape = graph.make_node(
+                'Cast', inputs=[shape], to=dtypes.ONNX.INT64)
+        else:
+            if len(node.input('Shape')) > 0:
+                shape = mapper_helper.cast(graph,
+                                           node.input('Shape', 0),
+                                           node.input_dtype('Shape', 0),
+                                           'int64')
+            elif len(node.attr('shape')) > 0:
+                shape = node.attr('shape')
+                for idx in range(len(shape)):
+                    if shape[idx] == -1:
+                        shape[idx] = 1
+                shape = graph.make_node(
+                    'Constant', dtype=dtypes.ONNX.INT64, value=shape)
         node = graph.make_node(
             'Expand',
             inputs=[node.input('X', 0), shape],
@@ -377,10 +385,11 @@ def opset_1(cls, graph, node, **kw):
             value = np.ones(shape) * value
             value = value.astype(dtypes.DTYPE_PADDLE_NUMPY_MAP[dtype])
             value = value.flatten().tolist()
-        if len(shape) ==0 and len(node.input('ShapeTensor')) > 0:
-            shape_tensor = mapper_helper.cast(graph,
-                                              node.input('ShapeTensor', 0),
-                                              node.input_dtype('ShapeTensor', 0), 'int64')
+        if len(shape) == 0 and len(node.input('ShapeTensor')) > 0:
+            shape_tensor = mapper_helper.cast(
+                graph,
+                node.input('ShapeTensor', 0),
+                node.input_dtype('ShapeTensor', 0), 'int64')
             graph.make_node(
                 'ConstantOfShape',
                 inputs=shape_tensor,
@@ -423,33 +432,78 @@ class FillConstantBatchSizeLike():
     support_opset_verison_range = (9, 12)
 
     @classmethod
-    def opset_11(cls, graph, node, **kw):
-        input_dim_idx = tensor_shape = graph.make_node(
+    def opset_10(cls, graph, node, **kw):
+        out_shape = node.attr('shape')
+        input_dim_idx = node.attr('input_dim_idx')
+        output_dim_idx = node.attr('output_dim_idx')
+
+        del out_shape[output_dim_idx]
+        out_shape.insert(0, 1)
+
+        dtype = dtypes.DTYPE_PADDLE_ONNX_MAP[node.attr('dtype')]
+        value = node.attr('value')
+        input_shape = node.input_shape('Input', 0)
+
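+        # Build a constant of the target shape with the batch dim moved to
+        # the front, tile it by the batch size sliced from the input's
+        # shape, then transpose the batch dim back to output_dim_idx.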
+        constant = graph.make_node(
             'Constant',
-            dtype=dtypes.ONNX.INT64,
-            dims=[1],
-            value=node.attr('input_dim_idx'))
-        output_dim_idx = tensor_shape = graph.make_node(
+            dtype=dtype,
+            dims=out_shape,
+            value=[value] * np.prod(out_shape))
+
+        shape = graph.make_node('Shape', inputs=node.input('Input'))
+        start = graph.make_node(
+            'Constant', dtype=dtypes.ONNX.INT64, value=[input_dim_idx])
+        end = graph.make_node(
+            'Constant', dtype=dtypes.ONNX.INT64, value=[input_dim_idx + 1])
+        batch = graph.make_node('Slice', inputs=[shape, start, end])
+        repeat = graph.make_node(
             'Constant',
             dtype=dtypes.ONNX.INT64,
-            dims=[1],
-            value=node.attr('output_dim_idx'))
-        input_shape = graph.make_node('Shape', inputs=node.input('Input'))
-        updates = graph.make_node('Gather', inputs=[input_shape, input_dim_idx])
-        tensor_shape = tensor_shape = graph.make_node(
-            'Constant',
-            attrs={'dtype': dtypes.ONNX.INT64,
-                   'value': node.attr('shape')})
-        tensor_shape = graph.make_node(
-            'ScatterND', inputs=[tensor_shape, output_dim_idx, updates])
-        dtype = dtypes.DTYPE_PADDLE_ONNX_MAP[node.attr('dtype')]
-        graph.make_node(
-            'ConstantOfShape',
-            inputs=[tensor_shape],
-            outputs=node.output('Out'),
-            dims=[1],
-            dtype=dtype,
-            value=node.attr('value'))
+            value=[1] * (len(out_shape) - 1))
+        repeat = graph.make_node('Concat', inputs=[batch, repeat], axis=-1)
+        if output_dim_idx == 0:
+            graph.make_node(
+                'Tile', inputs=[constant, repeat], outputs=node.output('Out'))
+        else:
+            out = graph.make_node('Tile', inputs=[constant, repeat])
+            perm = list(range(len(out_shape)))
+            perm[0] = output_dim_idx
+            perm[output_dim_idx] = 0
+            graph.make_node(
+                'Transpose',
+                inputs=[out],
+                perm=perm,
+                outputs=node.output('Out'))
 
 
 @op_mapper('fill_any_like')
@@ -540,7 +594,7 @@ def opset_1(cls, graph, node, **kw):
         if value is None or value.size < 1:
             value = np.array(node.attr('int32_values'))
         if value is None or value.size < 1:
-            value = np.array(node.attr('int64_value'))
+            value = np.array(node.attr('int64_values'))
         parameter = {
             'data': value,
             'dtype': node.attr('dtype'),