Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/paddle/fluid/tests/unittests/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def setUpClass(cls):
np.random.seed(12345)

cls.TEST_SAMPLES = {
"a": np.random.rand(1, 1),
"b": np.random.rand(1),
"x": np.random.rand(5),
"y": np.random.rand(7),
"A": np.random.rand(4, 5),
Expand Down Expand Up @@ -179,6 +181,11 @@ def setUp(self):
self.sample = {"paradigm": "ij,ij->ij", "data": ["A", "A"]}


class TestEinsumDegenerateMatrixVecMul(TestEinsum):
    # Matrix-vector contraction where both operands are degenerate
    # (one-element) tensors; exercises the "ij,j" paradigm with no
    # explicit output subscript.
    def setUp(self):
        self.sample = dict(paradigm="ij,j", data=["a", "b"])


class TestEinsumMatrixVecMul(TestEinsum):
    # Standard matrix-vector product, "ij,j->i", over the 4x5 matrix "A"
    # and length-5 vector "x" from TEST_SAMPLES.
    def setUp(self):
        self.sample = dict(paradigm="ij,j->i", data=["A", "x"])
Expand Down
61 changes: 36 additions & 25 deletions python/paddle/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,11 +394,12 @@ def plan_reduce(plan, op, reduce_dims, keepdim):
def plan_scalar_prod(plan, op1, op2):
    '''
    Append a scalar-product step to the execution plan.

    Parameters
    ----------
    plan:
        The plan object to append the step to (via ``plan.add_step``).
    op1, op2:
        Integer operand indices. ``op1`` refers to a one-element tensor
        (see the caller's check before dispatching here); ``op2`` is the
        other operand. The product is written back into ``op2``'s slot.
    '''
    varnames = [f'op{op1}', f'op{op2}']

    # Reduce op1 to a scalar first so a one-element tensor of any rank
    # multiplies op2 without shape mismatch.
    def f(var1, var2):
        return paddle_sum(var1) * var2

    step = f, varnames, varnames[1]
    plan.add_step(step)


def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K):
def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K):
'''
plan matmul
'''
Expand All @@ -416,7 +417,7 @@ def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K):
K1_dims = [op1_view[ax] for ax in K]
K2_dims = [op2_view[ax] for ax in K]

op1_mask, op2_mask = [g_op_masks[op] for op in (op1, op2)]
op1_mask, op2_mask = [g_supports[op] for op in (op1, op2)]
op1_vshape = [s if m else 1 for s, m in zip(g_shape, op1_mask)]
op2_vshape = [s if m else 1 for s, m in zip(g_shape, op2_mask)]

Expand Down Expand Up @@ -515,13 +516,13 @@ def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K):
op2_view[ax], dim = dim, dim + 1


def plan_summation(plan, g_view, op1, op2, g_op_masks, g_shape, g_count,
def plan_summation(plan, g_view, op1, op2, g_supports, g_shape, g_count,
n_bcast):
'''
Plan various kinds of summation
'''
op1_view, op2_view = g_view[op1], g_view[op2]
op1_mask, op2_mask = g_op_masks[op1], g_op_masks[op2]
op1_mask, op2_mask = g_supports[op1], g_supports[op2]

ndim = len(op1_view)
nout = ndim - len(g_count)
Expand Down Expand Up @@ -553,7 +554,7 @@ def plan_summation(plan, g_view, op1, op2, g_op_masks, g_shape, g_count,

# Now it's OK to merge the K dims as the same shape holds
# print(f'I: {I} J1: {J1} J2: {J2} K: {K}')
plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K)
plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K)


def rearrange(axes):
Expand Down Expand Up @@ -625,7 +626,7 @@ def execute(self):
return res


def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast):
def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast):
'''
Plans the actual execution steps.
Results
Expand All @@ -646,17 +647,18 @@ def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast):
plan_broadcast(plan, operands, g_view)
return plan

# Down count axis >= nout and degenerate dimensions (masked is not set)
for view, mask in zip(g_view, g_op_masks):
# Down count degenerate contraction dimensions.
for view, support in zip(g_view, g_supports):
# To collect the down count number, we use a type casting trick
down_count = [
1 if (dim > -1 and not masked) else 0
for dim, masked in zip(view[nout:], mask[nout:])
int((d + 1) and (not s))
for d, s in zip(view[nout:], support[nout:])
]
for i, d in enumerate(down_count):
g_count[i] -= d
for i, count in enumerate(down_count):
g_count[i] -= count

# Reduce any dimension for which g_mask is set and g_count == 1
for i, view, mask in zip(range(nop), g_view, g_op_masks):
# Reduce any dimension for which g_support is set and g_count == 1
for i, view, mask in zip(range(nop), g_view, g_supports):
to_reduce = []
for dim, masked, count in zip(view[nout:], mask[nout:], g_count):
to_reduce.append(dim if (masked and count == 1) else -1)
Expand Down Expand Up @@ -695,27 +697,36 @@ def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast):
# (4) Elsewise, either I... or J... not empty, and K... not empty, use a general matmul

# Resolve the summation kind: dot, matmul or *
if not any(g_op_masks[i - 1]):
# op1 is a scalar
if not any(g_supports[i - 1]):
# op1 is a one element tensor.
plan_scalar_prod(plan, i - 1, i)
else:
plan_summation(plan, g_view, i - 1, i, g_op_masks, g_shape, g_count,
plan_summation(plan, g_view, i - 1, i, g_supports, g_shape, g_count,
n_bcast)

# for ax, dim in enumerate(g_view[nop-1][:nout]):
# assert dim == ax
assert all(not masked for masked in g_op_masks[nop - 1][nout:])
assert all(not masked for masked in g_supports[nop - 1][nout:])

view = g_view[-1]
if any(ax != dim for ax, dim in enumerate(view[:nout])):
perm = [dim for dim in view if dim >= 0]
varname = f'op{nop-1}'
step = transpose, [varname], varname, perm
plan.add_step(step)
if sorted(perm) != perm:
varname = f'op{nop-1}'
step = transpose, [varname], varname, perm
plan.add_step(step)
dim = 0
unsqueeze_dims = []
for ax, d in enumerate(view):
if d != -1:
view[ax], dim = dim, dim + 1
for ax, d in enumerate(view[:nout]):
if d == -1:
unsqueeze_dims.append(ax)
if unsqueeze_dims:
varname = f'op{nop-1}'
step = unsqueeze, [varname], varname, unsqueeze_dims
plan.add_step(step)

squeeze_dims = [dim for dim in view[nout:] if dim != -1]
if squeeze_dims:
Expand Down Expand Up @@ -922,18 +933,18 @@ def einsum(equation, *operands):
# should broadcast to
# g_nout:
# Number of output axes
# g_op_masks
# A list of masks that specify each operand's non-trivial dimensions
# g_supports
# Booleans indicating each operand's non-trivial dimensions
# g_count
# Counting how many non-trivial dimensions remain for each ax

g_labels, g_view, g_nout, g_count = build_global_view(nop_labels, rhs,
n_bcast_dims)
g_shape, g_op_masks = build_global_shape(g_view, g_labels,
g_shape, g_supports = build_global_shape(g_view, g_labels,
[op.shape for op in operands])

# Now we're ready to build up an execution plan
args = operands, g_view, g_shape, g_op_masks, g_count, n_bcast_dims
args = operands, g_view, g_shape, g_supports, g_count, n_bcast_dims
plan = plan_einsum(*args)
result = plan.execute()

Expand Down