@@ -15,6 +15,7 @@
 import paddle
 from paddle.fluid.dygraph.layers import Layer
 from ...utils.log_util import logger, layer_to_str
+from functools import partial
 
 __all__ = []
 
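The only change in this first hunk is the new `from functools import partial` import; `_build_layer` (last hunk below) uses `partial` to bind the shared layer instance as the first argument of a user-supplied `forward_func`.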
@@ -58,6 +59,20 @@ def __repr__(self):
                             **self.kwargs)
 
 
+class SharedLayerDesc(LayerDesc):
+    def __init__(self,
+                 key,
+                 layer_func,
+                 forward_func=None,
+                 shared_weight_attr='weight',
+                 *inputs,
+                 **kwargs):
+        super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs)
+        self.layer_name = key
+        self.forward_func = forward_func
+        self.shared_weight_attr = shared_weight_attr
+
+
 class PipelineLayer(Layer):
     def __init__(self,
                  layers,
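`SharedLayerDesc` marks a layer that must exist on more than one pipeline stage while owning a single logical parameter: every descriptor carrying the same `key` maps to one entry in `shared_layers`, `shared_weight_attr` names the tied parameter, and `forward_func` lets a second call site reuse the layer differently. A minimal sketch of the intended usage with tied input/output embeddings; `EmbeddingPipe` and `logits_forward` are illustrative names, not part of this patch:

import paddle
import paddle.nn as nn

class EmbeddingPipe(nn.Layer):
    # Illustrative pipeline-friendly embedding; only `weight` matters here.
    def __init__(self, vocab_size, hidden_size):
        super(EmbeddingPipe, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    @property
    def weight(self):
        return self.embedding.weight

    def forward(self, x):
        return self.embedding(x)

def logits_forward(embedding, hidden):
    # Reuse the embedding weight as the output projection.
    return paddle.matmul(hidden, embedding.weight, transpose_y=True)

descs = [
    # First stage: embeds tokens; with forward_func=None the layer's own
    # forward runs.
    SharedLayerDesc(
        'embed', EmbeddingPipe, shared_weight_attr='weight',
        vocab_size=50304, hidden_size=1024),
    # ... LayerDesc entries for the transformer blocks ...
    # Last stage: same key, so the weight is tied; _build_layer appends
    # partial(logits_forward, shared_layer), called as
    # logits_forward(shared_layer, hidden).
    SharedLayerDesc(
        'embed', EmbeddingPipe, forward_func=logits_forward,
        shared_weight_attr='weight', vocab_size=50304, hidden_size=1024),
]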
@@ -104,11 +119,86 @@ def __init__(self,
         self._start_pos = 0
         self._end_pos = self._num_layers - 1
         self._segment_network(seg_method)
+        self.shared_layers = paddle.nn.LayerDict()
+        self.shared_weight_attrs = {}
 
         # construct layer
         self.run_function = []
         self._build_layer()
 
+        self.shared_comm = self._construct_shared_comm()
+        self._synchronize_shared_weights()
+
+    def get_stage_from_index(self, layer_idx):
+        assert 0 <= layer_idx < self._num_layers, "layer_idx is out of bounds"
+        for stage in range(self._topo.get_dim('pipe')):
+            if self.segment_parts[stage] <= layer_idx < self.segment_parts[stage
+                                                                            + 1]:
+                return stage
+
+    def _construct_shared_comm(self):
+        shared_comm = {}
+        if self._topo.get_dim("pipe") == 1:
+            return shared_comm
+
+        layers_desc = self._layers_desc
+        shared_layer_names = set(
+            s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc))
+        for key in shared_layer_names:
+            shared_layers = []
+            for idx, layer in enumerate(layers_desc):
+                if isinstance(layer,
+                              SharedLayerDesc) and layer.layer_name == key:
+                    shared_layers.append(idx)
+
+            shared_stages = set(
+                self.get_stage_from_index(idx) for idx in shared_layers)
+            self._dp_degree = self._topo.get_dim('data')
+            self._mp_degree = self._topo.get_dim('model')
+
+            for dp in range(self._dp_degree):
+                for mp in range(self._mp_degree):
+                    shared_ranks = []
+                    for s in sorted(shared_stages):
+                        shared_ranks.append(
+                            self._topo.get_rank_from_stage(
+                                self.global_rank, pipe=s, data=dp, model=mp))
+
+                    group = paddle.distributed.new_group(ranks=shared_ranks)
+                    if self.global_rank in shared_ranks:
+                        assert key in self.shared_layers
+                        if key in self.shared_layers:
+                            shared_comm[key] = {
+                                'ranks': shared_ranks,
+                                'group': group,
+                                'weight_attr': self.shared_weight_attrs[key],
+                                'layer': self.shared_layers[key],
+                            }
+        return shared_comm
+
+    def _synchronize_shared_weights(self):
+        for key, comm in self.shared_comm.items():
+            with paddle.framework.no_grad():
+                paddle.distributed.broadcast(
+                    getattr(comm['layer'], comm['weight_attr']),
+                    src=min(comm['ranks']),
+                    group=comm['group'])
+
+    def allreduce_shared_weight_gradients(self):
+        for key, comm in self.shared_comm.items():
+            param = getattr(self.shared_layers[key], comm['weight_attr'])
+            # need to use trace_op to all-reduce the shared weight's gradient
+            with paddle.framework.no_grad():
+                paddle.fluid.framework._dygraph_tracer().trace_op(
+                    type="c_allreduce_sum",
+                    inputs={'X': param._grad_ivar()},
+                    outputs={'Out': param._grad_ivar()},
+                    attrs={
+                        'ring_id': comm['group'].id,
+                        'use_calc_stream': True
+                    })
+
     def _segment_network(self, seg_method):
         logger.info("start segment network..")
         seg = SegmentLayers(
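To see what `_construct_shared_comm` builds: for every shared key it collects the pipeline stages holding a copy, then forms one communication group per (data-parallel, model-parallel) coordinate. A toy check of that grouping, assuming for illustration that ranks are laid out pipeline-innermost (rank = dp * num_stages + stage; the real mapping comes from `self._topo.get_rank_from_stage`):

# Toy illustration only: assumes rank = dp * num_stages + stage.
num_stages, dp_degree = 4, 2
shared_stages = [0, 3]   # stages holding a copy of the shared layer

for dp in range(dp_degree):
    shared_ranks = [dp * num_stages + s for s in sorted(shared_stages)]
    print(dp, shared_ranks)
# prints: 0 [0, 3] and 1 [4, 7] -- one group per data-parallel replica.
# _synchronize_shared_weights then broadcasts from min(ranks) once at
# startup, so every copy starts from identical values.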
@@ -142,6 +232,21 @@ def _build_layer(self):
             if isinstance(layer, Layer):
                 self.run_function.append(layer)
                 self.add_sublayer(str(layer_index), layer)
+            elif isinstance(layer, SharedLayerDesc):
+                if layer.layer_name not in self.shared_layers:
+                    self.shared_layers[layer.layer_name] = layer.build_layer()
+                    self.shared_weight_attrs[
+                        layer.layer_name] = layer.shared_weight_attr
+
+                if layer.forward_func is None:
+                    self.run_function.append(
+                        self.shared_layers[layer.layer_name])
+                else:
+                    self.run_function.append(
+                        partial(layer.forward_func,
+                                self.shared_layers[layer.layer_name]))
+
             elif isinstance(layer, LayerDesc):
                 model = layer.build_layer()
                 self.run_function.append(model)
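Because the tied copies live on different ranks, their gradients diverge after every backward pass; `allreduce_shared_weight_gradients` sums them so the copies stay in lockstep. A sketch of where the call belongs in a training step; everything except that method is a placeholder, and in practice the hybrid-parallel pipeline runner drives this:

# Placeholder training step; `model` is a PipelineLayer.
for batch in train_loader:                       # hypothetical data loader
    loss = run_pipeline_schedule(model, batch)   # hypothetical fwd/bwd driver
    model.allreduce_shared_weight_gradients()    # sum grads of tied weights
    optimizer.step()
    optimizer.clear_grad()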