Merged

Changes from all commits (44 commits):
e38da5b  fix  (enkilee, Jan 29, 2024)
36c0063  fix  (enkilee, Feb 1, 2024)
79b74c3  fix  (enkilee, Feb 1, 2024)
80c2c53  Merge branch 'develop' into PIR-optest-fix-27  (enkilee, Mar 1, 2024)
a585498  Merge branch 'develop' into PIR-optest-fix-27  (enkilee, Mar 6, 2024)
1cbd853  fix  (enkilee, Mar 6, 2024)
8ff226d  fix  (enkilee, Mar 6, 2024)
584ee74  fix  (enkilee, Mar 6, 2024)
b3b2586  fix  (enkilee, Mar 7, 2024)
51385d7  fix  (enkilee, Mar 8, 2024)
59d0f3d  fix  (enkilee, Mar 8, 2024)
d62cd21  fix  (enkilee, Mar 11, 2024)
06c367f  fix  (enkilee, Mar 11, 2024)
57a84b9  fix  (enkilee, Mar 11, 2024)
6c0a155  fix  (enkilee, Mar 11, 2024)
ef60899  fix  (enkilee, Mar 11, 2024)
2ada845  fix  (enkilee, Mar 11, 2024)
6127fe5  fix  (enkilee, Mar 11, 2024)
abe0446  fix  (enkilee, Mar 11, 2024)
fbc0884  fix  (enkilee, Mar 12, 2024)
7601680  fix  (enkilee, Mar 12, 2024)
f9c8eb9  fix  (enkilee, Mar 12, 2024)
812afa7  fix  (enkilee, Mar 13, 2024)
4e0ed3f  Merge branch 'develop' into PIR-optest-fix-27  (enkilee, Mar 13, 2024)
daa5211  fix  (enkilee, Mar 13, 2024)
cdc3bf1  Merge branch 'PIR-optest-fix-27' of https://github.com/enkilee/Paddle…  (enkilee, Mar 13, 2024)
782fe3f  fix  (enkilee, Mar 13, 2024)
0935545  fix  (enkilee, Mar 14, 2024)
14d275d  fix  (enkilee, Mar 15, 2024)
1bc577f  Merge branch 'develop' into PIR-optest-fix-27  (enkilee, Mar 15, 2024)
af9be1c  fix  (enkilee, Mar 18, 2024)
4428c2f  fix  (enkilee, Mar 18, 2024)
d66d375  fix  (enkilee, Mar 18, 2024)
1180742  fix  (enkilee, Mar 20, 2024)
00f941c  fix  (enkilee, Mar 22, 2024)
0a167bc  fix  (enkilee, Mar 23, 2024)
f082a23  fix  (enkilee, Mar 26, 2024)
36dde2a  fix  (enkilee, Mar 27, 2024)
33a33c1  fix  (enkilee, Mar 27, 2024)
45bebe0  fix  (enkilee, Mar 27, 2024)
8afa4f2  fix  (enkilee, Apr 2, 2024)
d3c0cae  fix  (enkilee, Apr 2, 2024)
30fa5ea  Merge branch 'develop' into PIR-optest-fix-27  (kangguangli, Apr 7, 2024)
d59ebfb  fix  (enkilee, Apr 8, 2024)
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -141,6 +141,8 @@
'c_softmax_with_cross_entropy',
'c_split',
'decayed_adagrad',
'distributed_fused_lamb',
'distributed_fused_lamb_',
'distributed_push_sparse',
'distributed_lookup_table',
'dgc_momentum',
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -509,6 +509,16 @@
data_type : fpn_rois
optional : rois_num, multi_level_rois_num

- op : distributed_fused_lamb
args : (Tensor[] param, Tensor[] grad, Tensor fp32_fused_param, Tensor fp32_fused_grad, Tensor fp16_fused_param, Tensor fp16_fused_grad, Tensor moment1, Tensor moment2, Tensor beta1pow, Tensor beta2pow, Tensor fused_param_offsets, Tensor fp32_shard_fused_param_offsets, Tensor fp16_shard_fused_param_offsets, Tensor param_info, Tensor param_order, Tensor learning_rate, Tensor global_scale, float beta1, float beta2, float epsilon, float max_global_grad_norm, float weight_decay, bool clip_after_allreduce, int[] ring_ids= {}, int acc_steps = 1, bool use_master_param_norm = true, bool use_master_acc_grad = true, bool is_grad_scaled_by_nranks = true, int64_t nranks = 1, bool use_hierarchical_allreduce = false)
output : Tensor(fp32_fused_param_out), Tensor(fp16_fused_param_out), Tensor(fp32_acc_fused_grad), Tensor(fp16_acc_fused_grad), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1pow_out), Tensor(beta2pow_out), Tensor[](param_out){param.size()}, Tensor(found_inf), Tensor(acc_step), Tensor(stop_update), Tensor(step)
kernel :
func : distributed_fused_lamb
data_type : DataType::FLOAT32
param : [param, grad, fp32_fused_param, fp32_fused_grad, fp16_fused_param, fp16_fused_grad, moment1, moment2, beta1pow, beta2pow, fused_param_offsets, fp32_shard_fused_param_offsets, fp16_shard_fused_param_offsets, param_info, param_order, learning_rate, global_scale, acc_steps, beta1, beta2, epsilon, max_global_grad_norm, weight_decay, clip_after_allreduce, use_master_param_norm, use_master_acc_grad, is_grad_scaled_by_nranks, use_hierarchical_allreduce, nranks, ring_ids]
optional : fp32_fused_param, fp32_fused_grad, fp16_fused_param, fp16_fused_grad, fp32_fused_param_out, fp16_fused_param_out, fp32_acc_fused_grad, fp16_acc_fused_grad, acc_step, stop_update
inplace : (fp32_fused_param -> fp32_fused_param_out), (fp16_fused_param -> fp16_fused_param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1pow -> beta1pow_out), (beta2pow -> beta2pow_out), (param -> param_out)

- op : distributed_fused_lamb_init
args : (Tensor[] param, Tensor[] grad, float beta1, float beta2, int[] apply_weight_decay, int alignment, int rank, int nranks)
output : Tensor(fp32_fused_param), Tensor(fp32_fused_grad), Tensor(fp16_fused_param), Tensor(fp16_fused_grad), Tensor(moment1), Tensor(moment2), Tensor(beta1_pow), Tensor(beta2_pow), Tensor(fused_param_offsets), Tensor(fp32_shard_fused_param_offsets), Tensor(fp16_shard_fused_param_offsets), Tensor(param_info), Tensor(param_order), Tensor[](param_out){param.size()}, Tensor[](master_param_out){param.size()}, Tensor[](grad_out){grad.size()}, Tensor(global_scale), Tensor(step)
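
For context, the ops.yaml entry above is the PIR-side definition that the legacy ProgramDesc-to-PIR translator targets for this op. Below is a minimal sketch of how that translation can be exercised, assuming the paddle.pir.translate_to_pir helper that the tests under test/ir/pir/translator rely on; the sketch is illustrative only and not part of this PR.

import paddle
from paddle import pir

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    # Append the legacy distributed_fused_lamb op here, e.g. with
    # LayerHelper.append_op as the new test at the end of this PR does.
    ...

# Translate the legacy ProgramDesc into PIR; once the op has been appended,
# the op name from the ops.yaml entry above should show up in the result.
pir_program = pir.translate_to_pir(main_program.desc)
print("distributed_fused_lamb" in str(pir_program))
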
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -3725,6 +3725,12 @@
multi_level_rois_num: MultiLevelRoIsNum
restore_index: RestoreIndex

- op: distributed_fused_lamb
inputs:
{param: Param, grad: Grad, fp32_fused_param: FP32FusedParam, fp32_fused_grad: FP32FusedGrad, fp16_fused_param: FP16FusedParam, fp16_fused_grad: FP16FusedGrad, moment1: Moment1, moment2: Moment2, beta1pow: Beta1Pow, beta2pow: Beta2Pow, fused_param_offsets: FusedParamOffsets, fp32_shard_fused_param_offsets: FP32ShardFusedParamOffsets, fp16_shard_fused_param_offsets: FP16ShardFusedParamOffsets, param_info: ParamInfo, param_order: ParamOrder, learning_rate: LearningRate, global_scale: GlobalScale}
outputs:
{param_out : ParamOut, fp32_fused_param_out: FP32FusedParamOut, fp16_fused_param_out: FP16FusedParamOut, fp32_acc_fused_grad: FP32AccFusedGrad, fp16_acc_fused_grad: FP16AccFusedGrad, moment1_out: Moment1Out, moment2_out: Moment2Out, beta1pow_out: Beta1PowOut, beta2pow_out: Beta2PowOut, found_inf: FoundInf, acc_step: AccStep, stop_update: StopUpdate, step: Step}

- op: distributed_fused_lamb_init
inputs:
{param: Param, grad: Grad}
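
The op_compat.yaml block above records how each legacy CamelCase argument of distributed_fused_lamb corresponds to the snake_case name used by the PIR definition. The translator consumes the YAML directly; purely as an illustration (the dict name below is made up, the entries are copied from the block above and abridged), the inputs mapping amounts to:

# PIR snake_case input name -> legacy ProgramDesc input name (abridged).
DFL_INPUT_NAME_MAP = {
    "param": "Param",
    "grad": "Grad",
    "fp32_fused_param": "FP32FusedParam",
    "fp16_fused_grad": "FP16FusedGrad",
    "moment1": "Moment1",
    "learning_rate": "LearningRate",
    "global_scale": "GlobalScale",
    # ... the remaining inputs and the outputs follow the same pattern
}
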
1 change: 1 addition & 0 deletions test/ir/pir/translator/CMakeLists.txt
@@ -14,6 +14,7 @@ list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_prod_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_scatter_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_split_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_distributed_fused_lamb_init)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_distributed_fused_lamb)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST
test_distributed_lookup_table_translate)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST
218 changes: 218 additions & 0 deletions test/ir/pir/translator/test_distributed_fused_lamb.py
@@ -0,0 +1,218 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import test_op_translator

import paddle
from paddle.base import core, unique_name
from paddle.base.layer_helper import LayerHelper


class TestDistributedFusedLambOpTranslator(test_op_translator.TestOpTranslator):
def setUp(self):
super().setUp()
assert (
not paddle.in_dynamic_mode()
), "DistributedFusedLamb does not support dygraph mode"
self._beta1 = 0.9
self._beta2 = 0.999
self._epsilon = 1e-6
self._weight_decay = 0.01
self._max_global_grad_norm = -1.0
self._alignment = 128
self._clip_after_allreduce = True
self._is_grad_scaled_by_nranks = True
self._scale = None
self._use_master_param_norm = True
self._gradient_accumulation_steps = 1
self._use_master_acc_grad = True
self._use_hierarchical_allreduce = False
self.helper = LayerHelper("distributed_fused_lamb")

main_block = self.helper.main_program.global_block()
self._found_inf = main_block.create_var(
name=unique_name.generate("found_inf"),
shape=[1],
dtype=core.VarDesc.VarType.BOOL,
)
self._step = None

self._param_to_master_param = {}

def _create_persistable_var(self, name=None, shape=[-1], dtype="float32"):
startup_block = self.helper.startup_program.global_block()
if name is not None:
name = unique_name.generate(name)
startup_var = startup_block.create_var(
name=name,
shape=shape,
dtype=dtype,
persistable=True,
stop_gradient=True,
)
main_block = self.helper.main_program.global_block()
main_var = main_block.create_var(
name=startup_var.name,
shape=startup_var.shape,
dtype=startup_var.dtype,
persistable=True,
stop_gradient=True,
)
return main_var

def _create_scale_from_constant(self, value):
name = unique_name.generate('global_scale')
return paddle.static.create_global_var(
name=name,
shape=[1],
dtype='float32',
value=float(value),
persistable=True,
)

def append_op(self):
self.op_type = "distributed_fused_lamb"
params = [paddle.ones(shape=(1, 1), dtype="float32")]
grads = [paddle.ones(shape=(1, 1), dtype="float32")]
lr = paddle.to_tensor(0.001, dtype="float32")
rank = paddle.distributed.get_rank()
nranks = paddle.distributed.get_world_size()
fp32_fused_param = self._create_persistable_var("fp32_fused_param")
fp32_fused_grad = self._create_persistable_var("fp32_fused_grad")
fp16_fused_param = self._create_persistable_var(
"fp16_fused_param", dtype="float16"
)
fp16_fused_grad = self._create_persistable_var(
"fp16_fused_grad", dtype="float16"
)

moment1 = self._create_persistable_var("moment1")
moment1.is_distributed = True
moment2 = self._create_persistable_var("moment2")
moment2.is_distributed = True
beta1pow = self._create_persistable_var("beta1pow")
beta2pow = self._create_persistable_var("beta2pow")

param_info = self._create_persistable_var("param_info", dtype="int32")
param_info.is_distributed = True

fused_offsets = self._create_persistable_var(
"fused_offsets", dtype="int32"
)

fp32_partial_fused_offsets = self._create_persistable_var(
"fp32_partial_fused_offsets", dtype="int32"
)
fp32_partial_fused_offsets.is_distributed = True

fp16_partial_fused_offsets = self._create_persistable_var(
"fp16_partial_fused_offsets", dtype="int32"
)
fp16_partial_fused_offsets.is_distributed = True

param_order = self._create_persistable_var("param_order", dtype="int32")
param_order.is_distributed = True

fp32_acc_fused_grad = [
self._create_persistable_var("fp32_acc_fused_grad")
]
fp16_acc_fused_grad = [
self._create_persistable_var("fp16_acc_fused_grad", dtype="float16")
]
acc_step = [self._create_persistable_var("acc_step", dtype="int64")]

scale = self._create_scale_from_constant(1.0)

step = self._create_persistable_var('step', dtype='int64')

ring_ids = []
ring_id = 0
ring_ids.append(ring_id)
main_block = self.helper.main_program.global_block()
_found_inf = main_block.create_var(
name=unique_name.generate("found_inf"),
shape=[1],
dtype=core.VarDesc.VarType.BOOL,
)
_stop_update = main_block.create_var(
name=unique_name.generate("stop_update"),
shape=[1],
dtype=core.VarDesc.VarType.BOOL,
)
attrs = {
"weight_decay": 0.01,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-6,
"max_global_grad_norm": -1.0,
"clip_after_allreduce": True,
"rank": rank,
"nranks": nranks,
"ring_ids": ring_ids,
"use_master_param_norm": True,
"is_grad_scaled_by_nranks": True,
"acc_steps": 1,
"use_master_acc_grad": True,
"use_hierarchical_allreduce": False,
}

helper = LayerHelper(self.op_type)
helper.append_op(
type=self.op_type,
inputs={
"FP32FusedParam": [fp32_fused_param],
"FP32FusedGrad": [fp32_fused_grad],
"FP16FusedParam": [fp16_fused_param],
"FP16FusedGrad": [fp16_fused_grad],
"LearningRate": [lr],
"Moment1": [moment1],
"Moment2": [moment2],
"Beta1Pow": [beta1pow],
"Beta2Pow": [beta2pow],
"GlobalScale": [scale],
"ParamInfo": [param_info],
"Param": params,
"Grad": grads,
"FusedParamOffsets": [fused_offsets],
"FP32ShardFusedParamOffsets": [fp32_partial_fused_offsets],
"FP16ShardFusedParamOffsets": [fp16_partial_fused_offsets],
"ParamOrder": [param_order],
},
outputs={
"FP32FusedParamOut": [fp32_fused_param],
"FP16FusedParamOut": [fp16_fused_param],
"Moment1Out": [moment1],
"Moment2Out": [moment2],
"Beta1PowOut": [beta1pow],
"Beta2PowOut": [beta2pow],
"ParamOut": params,
"GradOut": grads,
"FoundInf": [_found_inf],
"FP32AccFusedGrad": fp32_acc_fused_grad,
"FP16AccFusedGrad": fp16_acc_fused_grad,
"AccStep": acc_step,
"StopUpdate": _stop_update,
"Step": [step],
},
attrs=attrs,
)

def test_translator(self):
self.check()


if __name__ == "__main__":
unittest.main()
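
A note on how the test is driven: the check itself lives in the TestOpTranslator base class imported above (test_op_translator.py in the same directory), which is not part of this diff; presumably it builds the program in static-graph mode, translates it with paddle.pir.translate_to_pir, and asserts that self.op_type appears in the translated program. Assuming a Paddle build that includes the distributed kernels this op needs, the CMakeLists.txt change above registers the test with CTest, so it can be run with ctest -R test_distributed_fused_lamb from the build tree, or directly as python test_distributed_fused_lamb.py from test/ir/pir/translator.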