1 change: 1 addition & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -2405,6 +2405,7 @@
  infer_meta:
    func: SwiGLUGradInferMeta
    param: [x, y]
    spmd_rule: SwiGLUGradInferSpmd
  kernel:
    func: swiglu_grad
  optional: y
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -2766,6 +2766,7 @@
  output : Tensor(out)
  infer_meta:
    func: SwiGLUInferMeta
    spmd_rule: SwiGLUInferSpmd
  kernel:
    func : swiglu
  optional : y
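The two `spmd_rule` entries above are what hook the SwiGLU forward and backward ops into auto-parallel sharding propagation. The sketch below is illustrative only (not part of this diff): it shows how the rule would be exercised when `swiglu` runs on sharded inputs in dynamic auto-parallel mode. The 2-rank mesh and tensor shapes are assumptions, and running it requires a distributed build.

```python
# Illustrative sketch only: exercising the SwiGLU SPMD rule in dynamic
# auto-parallel mode. The 2-rank mesh and shapes are assumed, not from this PR.
import paddle
import paddle.distributed as dist
from paddle.incubate.nn.functional import swiglu

mesh = dist.ProcessMesh(mesh=[0, 1])
x = paddle.randn([8, 128])
y = paddle.randn([8, 128])

# Shard the last dimension of both inputs; with the new spmd_rule entries the
# output's placement is expected to be inferred as Shard(1) as well.
dist_x = dist.shard_tensor(x, mesh, [dist.Shard(1)])
dist_y = dist.shard_tensor(y, mesh, [dist.Shard(1)])
out = swiglu(dist_x, dist_y)
```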
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/spmd_rules/elementwise.h
@@ -54,5 +54,15 @@ SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x,
                                        const DistMetaTensor& out_grad,
                                        int64_t axis = -1);

SpmdInfo SwiGLUInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y);

SpmdInfo SwiGLUInferSpmdReverse(const DistMetaTensor& x,
                                const DistMetaTensor& y,
                                const DistMetaTensor& out);

SpmdInfo SwiGLUGradInferSpmd(const DistMetaTensor& x,
                             const DistMetaTensor& y,
                             const DistMetaTensor& out_grad);

} // namespace distributed
} // namespace phi
5 changes: 3 additions & 2 deletions paddle/phi/infermeta/spmd_rules/rules.cc
@@ -435,12 +435,13 @@ PD_REGISTER_SPMD_RULE(
    logical_xor,
    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd),
    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse));

PD_REGISTER_SPMD_RULE(
    not_equal,
    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd),
    PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse));

PD_REGISTER_SPMD_RULE(swiglu,
                      PD_INFER_SPMD(phi::distributed::SwiGLUInferSpmd),
                      PD_INFER_SPMD(phi::distributed::SwiGLUInferSpmdReverse));
// TODO(pkuzyc): add multiary elementwise rule

// reduction rule
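`PD_REGISTER_SPMD_RULE` makes the new rule discoverable by name, which is how the unit test at the bottom of this diff drives it from Python. A minimal sketch of that lookup path, using the same calls as the test, for orientation:

```python
# Minimal sketch of the lookup enabled by PD_REGISTER_SPMD_RULE(swiglu, ...),
# mirroring the unit test in test_swiglu.py further down.
import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)

rule = paddle.base.core.get_phi_spmd_rule("swiglu")

attr = TensorDistAttr()
attr.dims_mapping = [-1, 0]  # shard the last dim along mesh axis 0
attr.process_mesh = dist.ProcessMesh(mesh=[0, 1, 2, 3])

x_spec = DistTensorSpec([64, 32], attr)
y_spec = DistTensorSpec([64, 32], attr)

input_attrs, output_attrs = rule.infer_forward(x_spec, y_spec)
# Elementwise-style propagation keeps dims_mapping [-1, 0] for the output.
print(output_attrs[0].dims_mapping)
```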
60 changes: 60 additions & 0 deletions paddle/phi/infermeta/spmd_rules/swiglu.cc
@@ -0,0 +1,60 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/elementwise.h"

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi {
namespace distributed {

SpmdInfo SwiGLUInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y) {
  // An empty y.dist_attr() means that y is None.
  if (y.dist_attr() == TensorDistAttr()) {
    PADDLE_THROW(
        phi::errors::Unimplemented("The input y is not allowed to be None"));
  } else {
    return ElementwiseBinaryInferSpmd(x, y);
  }
}

SpmdInfo SwiGLUInferSpmdReverse(const DistMetaTensor& x,
                                const DistMetaTensor& y,
                                const DistMetaTensor& out) {
  if (y.dist_attr() == TensorDistAttr()) {
    PADDLE_THROW(
        phi::errors::Unimplemented("The input y is not allowed to be None"));
  } else {
    return ElementwiseBinaryInferSpmdReverse(x, y, out);
  }
}

SpmdInfo SwiGLUGradInferSpmd(const DistMetaTensor& x,
                             const DistMetaTensor& y,
                             const DistMetaTensor& out_grad) {
  if (y.dist_attr() == TensorDistAttr()) {
    PADDLE_THROW(
        phi::errors::Unimplemented("The input y is not allowed to be None"));
  } else {
    return ElementwiseBinaryGradInferSpmd(x, y, out_grad);
  }
}

} // namespace distributed
} // namespace phi
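All three SwiGLU rules simply delegate to the elementwise binary SPMD rules once `y` is known to be present. For intuition only, here is a toy sketch of how an elementwise binary rule merges the `dims_mapping` of two same-shaped inputs; the real `ElementwiseBinaryInferSpmd` also handles broadcasting, conflict resolution, and partial states, so treat this as a simplification rather than the actual algorithm.

```python
# Toy illustration (not Paddle code) of the dims_mapping merge an elementwise
# binary rule performs for two same-rank inputs. -1 means "replicated"; a
# non-negative value is the mesh axis the dimension is sharded on.
def merge_dims_mapping(x_mapping, y_mapping):
    merged = []
    for xd, yd in zip(x_mapping, y_mapping):
        if xd == yd:
            merged.append(xd)   # both agree (replicated or same shard axis)
        elif xd == -1:
            merged.append(yd)   # x replicated -> follow y's sharding
        elif yd == -1:
            merged.append(xd)   # y replicated -> follow x's sharding
        else:
            merged.append(-1)   # conflict: fall back to replicate here
    return merged               # (the real rule resolves conflicts differently)


# Matches the expectation in TestSwigluSpmd: [-1, 0] + [-1, 0] -> [-1, 0]
print(merge_dims_mapping([-1, 0], [-1, 0]))
```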
77 changes: 77 additions & 0 deletions test/legacy_test/test_swiglu.py
@@ -15,13 +15,24 @@
import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.distributed as dist
import paddle.nn.functional as F
from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.incubate.nn.functional import swiglu as fused_swiglu_impl


def swiglu(x, y, out_grad):
    if isinstance(x, np.ndarray):
        x = paddle.to_tensor(x)
        y = paddle.to_tensor(y)
        out_grad = paddle.to_tensor(out_grad)

    origin_x = x.detach().clone()
    origin_x.stop_gradient = False
    x = origin_x
@@ -160,5 +171,71 @@ def test_main(self):
        self.check_main([4, 101])


class TestSwigluOp(OpTest):
    def config(self):
        self.x_shape = (8, 128)
        self.check_auto_parallel = True

    def setUp(self):
        self.config()
        self.op_type = "swiglu"
        self.python_api = fused_swiglu_impl
        x = np.random.uniform(-1, 1, self.x_shape).astype("float64")
        y = np.random.uniform(-1, 1, self.x_shape).astype("float64")
        out_grad = np.random.uniform(-1, 1, self.x_shape).astype("float64")
        res = swiglu(x, y, out_grad)
        self.inputs = {'x': x, 'y': y}
        self.outputs = {'out': res[0].numpy()}
        self.placements = {
            'x': [dist.Shard(1)],
            'y': [dist.Shard(1)],
            'out': [dist.Shard(1)],
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(
            ['x', 'y'],
            'out',
            check_auto_parallel=self.check_auto_parallel,
            check_dygraph=1,
        )


@unittest.skipIf(
    not paddle.base.core.is_compiled_with_dist(),
    "The spmd rule should be tested with distributed=ON",
)
class TestSwigluSpmd(unittest.TestCase):
    def setUp(self):
        self.kernel = 'swiglu'
        self.rule = paddle.base.core.get_phi_spmd_rule(self.kernel)
        x_shape = [64, 32]
        process_mesh = dist.ProcessMesh(mesh=[0, 1, 2, 3])
        x_tensor_dist_attr = TensorDistAttr()
        x_tensor_dist_attr.dims_mapping = [-1, 0]
        x_tensor_dist_attr.process_mesh = process_mesh
        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
        self.y_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
        self.out_dist_tensor_spec = DistTensorSpec(self.x_dist_tensor_spec)

    def test_input_x_y(self):
        result_dist_attrs = self.rule.infer_forward(
            self.x_dist_tensor_spec, self.y_dist_tensor_spec
        )
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(len(result_dist_attrs), 2)
        self.assertEqual(len(inferred_input_dist_attrs), 2)
        self.assertEqual(len(inferred_output_dist_attrs), 1)
        self.assertEqual(inferred_output_dist_attrs[0].dims_mapping, [-1, 0])

    def test_input_x(self):
        with self.assertRaises(NotImplementedError):
            self.rule.infer_forward(self.x_dist_tensor_spec, DistTensorSpec())


if __name__ == "__main__":
    unittest.main()