Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/test_bmm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def test_api_error(self):
y_data = np.arange(16, dtype='float32').reshape((2, 4, 2))
y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4))
y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2))
y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 2, 4))
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1)
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2)
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3)


if __name__ == "__main__":
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,10 @@ def bmm(x, y, name=None):
raise ValueError(
"x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".
format(x_shape, y_shape))
if x_shape[0] != y_shape[0]:
raise ValueError(
"x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".
format(x_shape, y_shape))
helper = LayerHelper('bmm', **locals())
if in_dygraph_mode():
return core.ops.bmm(x, y)
Expand Down