@@ -107,6 +107,42 @@ class UnfoldOp : public framework::OperatorWithKernel {
107107 " But recieved dims(strides: %u) != dims(dilations: %u)." ,
108108 strides.size (), dilations.size ()));
109109
110+ // check kernel_sizes
111+ PADDLE_ENFORCE_GT (kernel_sizes[0 ], 0 ,
112+ platform::errors::InvalidArgument (
113+ " The `kernel_sizes` should be greater than zero, "
114+ " but recieved kernel_height: %d kernel_width: %d." ,
115+ kernel_sizes[0 ], kernel_sizes[1 ]));
116+ PADDLE_ENFORCE_GT (kernel_sizes[1 ], 0 ,
117+ platform::errors::InvalidArgument (
118+ " The `kernel_sizes` should be greater than zero, "
119+ " but recieved kernel_height: %d kernel_width: %d." ,
120+ kernel_sizes[0 ], kernel_sizes[1 ]));
121+ // check strides
122+ PADDLE_ENFORCE_GT (strides[0 ], 0 ,
123+ platform::errors::InvalidArgument (
124+ " The `strides` should be greater than zero, "
125+ " but recieved strides_height: %d strides_width: %d." ,
126+ strides[0 ], strides[1 ]));
127+ PADDLE_ENFORCE_GT (strides[1 ], 0 ,
128+ platform::errors::InvalidArgument (
129+ " The `strides` should be greater than zero, "
130+ " but recieved strides_height: %d strides_width: %d." ,
131+ strides[0 ], strides[1 ]));
132+ // check dilations
133+ PADDLE_ENFORCE_GT (
134+ dilations[0 ], 0 ,
135+ platform::errors::InvalidArgument (
136+ " The `dilations` should be greater than zero, "
137+ " but recieved dilations_height: %d dilations_width: %d." ,
138+ dilations[0 ], dilations[1 ]));
139+ PADDLE_ENFORCE_GT (
140+ dilations[1 ], 0 ,
141+ platform::errors::InvalidArgument (
142+ " The `dilations` should be greater than zero, "
143+ " but recieved dilations_height: %d dilations_width: %d." ,
144+ dilations[0 ], dilations[1 ]));
145+
110146 std::vector<int > out_dims;
111147 out_dims.push_back (in_dims[0 ]);
112148
0 commit comments