@@ -223,7 +223,7 @@ def flip(x, axis, name=None):
223223 Args:
224224 x (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor x
225225 should be float32, float64, int32, int64, bool.
226- axis (list|tuple): The axis(axes) to flip on. Negative indices for indexing from the end are accepted.
226+ axis (list|tuple|int ): The axis(axes) to flip on. Negative indices for indexing from the end are accepted.
227227 name (str, optional): The default value is None. Normally there is no need for user to set this property.
228228 For more information, please refer to :ref:`api_guide_Name` .
229229
@@ -240,10 +240,17 @@ def flip(x, axis, name=None):
240240 x = np.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape)
241241 x = x.astype('float32')
242242 img = paddle.to_tensor(x)
243- out = paddle.flip(img, [0,1])
243+ tmp = paddle.flip(img, [0,1])
244+ print(tmp) # [[[10,11],[8, 9]], [[6, 7],[4, 5]], [[2, 3],[0, 1]]]
244245
245- print(out) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]]
246+ out = paddle.flip(tmp,-1)
247+ print(out) # [[[11,10],[9, 8]], [[7, 6],[5, 4]], [[3, 2],[1, 0]]]
246248 """
249+ if isinstance (axis , int ):
250+ axis = [axis ]
251+ if in_dygraph_mode ():
252+ return core .ops .flip (x , "axis" , axis )
253+
247254 helper = LayerHelper ("flip" , ** locals ())
248255 check_type (x , 'X' , (Variable ), 'flip' )
249256 dtype = helper .input_dtype ('x' )
0 commit comments