@@ -471,7 +471,7 @@ def infer_shape(self, in_shape):
471471 List of aux shapes calculated from in_shape,
472472 in the same order as declared in list_auxiliary_states.
473473 """
474- return in_shape , [ in_shape [0 ]], []
474+ return in_shape , ( in_shape [0 ],) * len ( self . list_outputs ()), ()
475475
476476 def infer_type (self , in_type ):
477477 """infer_type interface. override to create new operators
@@ -753,9 +753,7 @@ def forward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
753753 NDArrayHandle ),
754754 writable = False ))
755755 reqs = [req_enum [reqs [i ]] for i in range (len (tensors [1 ]))]
756- op .forward (is_train = is_train , req = reqs ,
757- in_data = tensors [0 ], out_data = tensors [1 ],
758- aux = tensors [4 ])
756+ op .forward (is_train , reqs , tensors [0 ], tensors [1 ], tensors [4 ])
759757 except Exception :
760758 print ('Error in CustomOp.forward: %s' % traceback .format_exc ())
761759 return False
@@ -776,10 +774,8 @@ def backward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
776774 NDArrayHandle ),
777775 writable = False ))
778776 reqs = [req_enum [reqs [i ]] for i in range (len (tensors [2 ]))]
779- op .backward (req = reqs ,
780- in_data = tensors [0 ], out_data = tensors [1 ],
781- in_grad = tensors [2 ], out_grad = tensors [3 ],
782- aux = tensors [4 ])
777+ op .backward (reqs , tensors [0 ], tensors [1 ], tensors [2 ],
778+ tensors [3 ], tensors [4 ])
783779 except Exception :
784780 print ('Error in CustomOp.backward: %s' % traceback .format_exc ())
785781 return False
0 commit comments