@@ -612,15 +612,6 @@ def where(condition, x=None, y=None, name=None):
612612 if x is None or y is None :
613613 raise ValueError ("either both or neither of x and y should be given" )
614614
615- if not paddle .in_dynamic_mode ():
616- check_variable_and_dtype (condition , 'condition' , ['bool' ], 'where' )
617- check_variable_and_dtype (
618- x , 'x' , ['float32' , 'float64' , 'int32' , 'int64' ], 'where'
619- )
620- check_variable_and_dtype (
621- y , 'y' , ['float32' , 'float64' , 'int32' , 'int64' ], 'where'
622- )
623-
624615 condition_shape = list (condition .shape )
625616 x_shape = list (x .shape )
626617 y_shape = list (y .shape )
@@ -646,6 +637,14 @@ def where(condition, x=None, y=None, name=None):
646637 if in_dygraph_mode ():
647638 return _C_ops .where (broadcast_condition , broadcast_x , broadcast_y )
648639 else :
640+ check_variable_and_dtype (condition , 'condition' , ['bool' ], 'where' )
641+ check_variable_and_dtype (
642+ x , 'x' , ['float32' , 'float64' , 'int32' , 'int64' ], 'where'
643+ )
644+ check_variable_and_dtype (
645+ y , 'y' , ['float32' , 'float64' , 'int32' , 'int64' ], 'where'
646+ )
647+
649648 helper = LayerHelper ("where" , ** locals ())
650649 out = helper .create_variable_for_type_inference (dtype = x .dtype )
651650
0 commit comments