112 changes: 62 additions & 50 deletions python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import ast
 import re
 from collections.abc import Iterable
 from dataclasses import dataclass
@@ -67,7 +68,7 @@ def __init__(
         self.destination_state_shard_info = destination_state_shard_info
         self.optim_state_name = [
             ".w_0",
-            ".moment1_0 ",
+            ".moment1_0",
             ".moment2_0",
             ".beta1_pow_acc_0",
             ".beta2_pow_acc_0",
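The trailing space removed from `".moment1_0 "` was a real bug, not cosmetics: the engine strips these suffixes with substring checks, and a suffix carrying a trailing space never matches a real key. A minimal illustration (key name assumed):

```python
key = "linear_0.w_0.moment1_0"   # hypothetical optimizer-state key
assert ".moment1_0 " not in key  # old entry: trailing space, never matched
assert ".moment1_0" in key       # fixed entry: suffix is found
```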
@@ -114,11 +115,13 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
             raise KeyError(
                 f"dst_state_key '{dst_state_key}' not in destination_state_shard_info"
             )
+
+        new_state_key = dst_state_key
         for state_name in self.optim_state_name:
             if state_name in dst_state_key:
                 new_state_key = dst_state_key.replace(state_name, "")
                 break
-            new_state_key = dst_state_key
+
         shard_infos = self.destination_state_shard_info[new_state_key]
         global_offset_set = set()
         for shard_info in shard_infos:
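Hoisting `new_state_key = dst_state_key` above the loop makes the fallback explicit: a key with no optimizer suffix maps to itself, and the first matching suffix wins (so list order matters). A sketch of the resulting mapping, with assumed key names:

```python
optim_state_name = [".w_0", ".moment1_0", ".moment2_0"]  # abridged list

def base_key(dst_state_key: str) -> str:
    new_state_key = dst_state_key          # default: key maps to itself
    for state_name in optim_state_name:
        if state_name in dst_state_key:    # first matching suffix wins
            new_state_key = dst_state_key.replace(state_name, "")
            break
    return new_state_key

assert base_key("embedding_0.b_0.moment2_0") == "embedding_0.b_0"
assert base_key("embedding_0.b_0") == "embedding_0.b_0"  # no suffix: unchanged
```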
@@ -148,9 +151,7 @@ def __init__(
         self.input_vars = self.build_input_vars()
         self.output_vars = {}
         self.need_remove_input_vars = set()
-        self.need_remove_output_vars = set()
-        self.need_transpose_output_vars = set()
-        self.need_transpose_input_vars = {}
+        self.need_add_output_vars = set()
 
         self.shape_propagation()
 
@@ -176,7 +177,7 @@ def split(
         sub_slices = []
         for aidx, src_sl, dst_sl, pp_list in tensor.slices:
             if pp_list is not None:
-                src_sl = self.postprocess_transpose(list(src_sl), pp_list)
+                src_sl = postprocess_transpose(list(src_sl), pp_list)
 
             dst_start = (
                 dst_sl[axis].start if dst_sl[axis].start is not None else 0
@@ -206,7 +207,7 @@
                         inter_begin - start, inter_begin - start + length
                     )
                     if pp_list is not None:
-                        sub_src_sl = self.postprocess_transpose(
+                        sub_src_sl = postprocess_transpose(
                             list(sub_src_sl), pp_list, reverse=True
                         )
                     sub_slices.append(
@@ -256,17 +257,19 @@ def concat(self, tensors: list[TensorDesc], axis: int) -> TensorDesc:
             curr += t.shape[axis]
         return TensorDesc(slices, tuple(shape))
 
-    def transpose(self, tensor: TensorDesc, transpose: str) -> TensorDesc:
+    def transpose(self, tensor: TensorDesc, permutation: str) -> TensorDesc:
         slices = []
-        tensor_shape = transpose_list(tensor.shape, eval(transpose))
+        tensor_shape = transpose_list(
+            tensor.shape, ast.literal_eval(permutation)
+        )
         for aidx, src_sl, dst_sl, pp_list in tensor.slices:
-            trans_dst_sl = transpose_list(dst_sl, eval(transpose))
+            trans_dst_sl = transpose_list(dst_sl, ast.literal_eval(permutation))
             if pp_list is not None:
                 new_pp_list = pp_list.copy()
-                new_pp_list.append(transpose)
+                new_pp_list.append(permutation)
                 slices.append((aidx, src_sl, trans_dst_sl, new_pp_list))
             else:
-                slices.append((aidx, src_sl, trans_dst_sl, [transpose]))
+                slices.append((aidx, src_sl, trans_dst_sl, [permutation]))
         return TensorDesc(slices, tensor_shape)
 
     def cast(self, tensor: TensorDesc, dtype: str) -> TensorDesc:
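Replacing `eval` with `ast.literal_eval` (hence the new `import ast`) still parses permutation strings such as `"[2, 0, 1]"`, but refuses arbitrary expressions. A quick sketch of the difference:

```python
import ast

assert ast.literal_eval("[2, 0, 1]") == [2, 0, 1]  # plain list literal: accepted

try:
    ast.literal_eval("__import__('os').getcwd()")  # a call, not a literal
except ValueError:
    print("rejected: literal_eval only evaluates Python literals")
```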
@@ -295,7 +298,6 @@ def _get_var_ref(var):
                 left_vars = stmt.left_vars
                 right_vars = stmt.right_vars
                 attrs = stmt.attrs
-
                 if len(left_vars) > 1 or len(right_vars) > 1:
                     if not (len(attrs) == 1 and attrs[0].key == "axis"):
                         raise ValueError(
@@ -338,47 +340,49 @@
                     if rvar.name == "_":
                         self.need_remove_input_vars.add(lvar.name)
                     elif lvar.name == "_":
-                        self.need_remove_output_vars.add(rvar.name)
+                        self.need_add_output_vars.add(rvar.name)
                     else:
-                        if attrs:
+                        if len(attrs) > 0:
                             for attr in attrs:
                                 in_ref = _get_var_ref(lvar)
-                                if attr.key == "transpose":
+                                if attr.key == "permute":
                                     if attr.value == "[]":
                                         ndim = len(in_ref.shape)
-                                        transpose = str(
-                                            list(range(ndim - 1, -1, -1))
-                                        )
+                                        perm = str(list(range(ndim - 1, -1, -1)))
                                     else:
-                                        transpose = attr.value
-                                    result = self.transpose(in_ref, transpose)
+                                        perm = attr.value
+                                    result = self.transpose(in_ref, perm)
                                 elif attr.key == "dtype":
                                     result = self.cast(in_ref, attr.value)
                                 elif attr.key == "axis":
                                     pass
                                 else:
                                     raise ValueError(
                                         f"Unsupported attribute: {attr}"
                                     )
-
-                            out_name = rvar.name
-                            intermediate_vars[out_name] = result
+                            intermediate_vars[rvar.name] = result
                             if (
-                                out_name
+                                rvar.name
                                 in self.context.get_all_dst_state_keys()
                             ):
-                                self.output_vars[out_name] = result
+                                self.output_vars[rvar.name] = result
                         else:
-                            intermediate_vars[rvar.name] = _get_var_ref(lvar)
+                            in_ref = _get_var_ref(lvar)
+                            intermediate_vars[rvar.name] = in_ref
                             if rvar.name in self.context.get_all_dst_state_keys():
-                                self.output_vars[rvar.name] = intermediate_vars[
-                                    rvar.name
-                                ]
+                                self.output_vars[rvar.name] = in_ref
 
             else:
                 raise SyntaxError(f'Unexpected statement: {stmt}')
 
         for name in self.destination_state_shard_info.keys():
             if name not in self.output_vars:
-                assert name in self.input_vars
-                self.output_vars[name] = self.input_vars[name]
+                if name in self.need_add_output_vars:
+                    self.output_vars[name] = None
+                else:
+                    assert name in self.input_vars
+                    self.output_vars[name] = self.input_vars[name]
 
     def find_source_slices(
         self, key: str, local_slice: tuple[slice, ...]
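As I read the renamed bookkeeping, a `src -> _` statement marks a source weight for removal, while `_ -> dst` declares a destination with no source; the final fill-in loop now gives such keys an explicit `None` instead of tripping the `assert`. A condensed sketch of that loop's effect (key names assumed):

```python
need_add_output_vars = {"new_head.w_0"}  # from a hypothetical `_ -> new_head.w_0`
input_vars = {"encoder.w_0": "desc"}     # stand-in for a real TensorDesc
output_vars = {}

for name in ["encoder.w_0", "new_head.w_0"]:  # destination state keys (assumed)
    if name not in output_vars:
        if name in need_add_output_vars:
            output_vars[name] = None          # declared output without a source
        else:
            assert name in input_vars
            output_vars[name] = input_vars[name]

print(output_vars)  # {'encoder.w_0': 'desc', 'new_head.w_0': None}
```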
@@ -406,7 +410,7 @@ def slice_intersect(a: slice, b: slice):
             else:
                 # Compute corresponding src_slice for the intersection
                 if pp_list is not None:
-                    sl_src = self.postprocess_transpose(list(sl_src), pp_list)
+                    sl_src = postprocess_transpose(list(sl_src), pp_list)
                 src_slice = []
                 for i in range(ndim):
                     dst = sl_dst[i]
@@ -424,7 +428,7 @@
                     )
                     src_slice.append(slice(src_inter_start, src_inter_stop, 1))
                 if pp_list is not None:
-                    src_slice = self.postprocess_transpose(
+                    src_slice = postprocess_transpose(
                         list(src_slice), pp_list, reverse=True
                     )
                 results.append(
@@ -484,6 +488,14 @@ def find_shard_sources(
                     tgt_global_offset,
                 )
 
+            if source_sharded_weight.key in self.need_remove_input_vars:
+                shard_mappings.append(
+                    ShardMappingEntry(
+                        target_sharded_weight, source_sharded_weight, []
+                    )
+                )
+                continue
+
             shard_mappings.append(
                 ShardMappingEntry(
                     target_sharded_weight,
@@ -493,23 +505,23 @@
             )
         return shard_mappings
 
-    def postprocess_transpose(
-        self,
-        li: list[tuple[slice, ...]] | tuple[tuple[slice, ...]],
-        postprocess_list: list[str],
-        reverse: bool = False,
-    ) -> list[tuple[slice, ...]] | tuple[tuple[slice, ...]]:
-        result = li
-        if reverse:
-            for pp in list(reversed(postprocess_list)):
-                if pp.startswith("["):
-                    reversed_transpose = np.argsort(eval(pp)).tolist()
-                    result = transpose_list(result, reversed_transpose)
-        else:
-            for pp in postprocess_list:
-                if pp.startswith("["):
-                    result = transpose_list(result, eval(pp))
-        return result
-
+
+def postprocess_transpose(
+    li: list[tuple[slice, ...]] | tuple[tuple[slice, ...]],
+    postprocess_list: list[str],
+    reverse: bool = False,
+) -> list[tuple[slice, ...]] | tuple[tuple[slice, ...]]:
+    result = li
+    if reverse:
+        for pp in list(reversed(postprocess_list)):
+            if pp.startswith("["):
+                reversed_transpose = np.argsort(ast.literal_eval(pp)).tolist()
+                result = transpose_list(result, reversed_transpose)
+    else:
+        for pp in postprocess_list:
+            if pp.startswith("["):
+                result = transpose_list(result, ast.literal_eval(pp))
+    return result
+
 
 def transpose_list(
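Promoting `postprocess_transpose` to a module-level function (its callers above drop the `self.`) also makes the inverse-permutation step easy to check in isolation: for a permutation `p`, `np.argsort(p)` is its inverse, assuming `transpose_list` reorders a sequence by index as the name suggests:

```python
import numpy as np

p = [2, 0, 1]                      # forward permutation
inv = np.argsort(p).tolist()       # [1, 2, 0], its inverse

x = ["a", "b", "c"]
fwd = [x[i] for i in p]            # ['c', 'a', 'b']
assert [fwd[i] for i in inv] == x  # applying the inverse restores the order
```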
5 changes: 3 additions & 2 deletions python/paddle/distributed/flex_checkpoint/aoa/lexer.py
@@ -49,7 +49,7 @@ class Lexer:
         ('COMMA', r','),
         ('NUMBER', r'\d+'),
         ('STRING', r'"[^"]*"|\'[^\']*\''),
-        ('IDENTIFIER', r'[A-Za-z][A-Za-z\.\$\_\*\d\^T]*'),
+        ('IDENTIFIER', r'[A-Za-z_][A-Za-z\.\$\_\*\d\^T]*'),
         ('SKIP', r'[ \t]+'),
         ('NEWLINE', r'[\r\n]+'),
         ('MISMATCH', r'.'),
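Adding `_` to the leading character class lets identifiers start with an underscore, which the engine's `_` placeholder on either side of an AOA statement appears to rely on:

```python
import re

old = re.compile(r'[A-Za-z][A-Za-z\.\$\_\*\d\^T]*')
new = re.compile(r'[A-Za-z_][A-Za-z\.\$\_\*\d\^T]*')

assert old.fullmatch("_") is None      # bare `_` failed to lex before
assert new.fullmatch("_") is not None  # now tokenizes as IDENTIFIER
assert new.fullmatch("moment1_0") is not None
```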
@@ -71,7 +71,8 @@ def tokenize(self, text):
         pos = 0
         mo = self.get_token(text, pos)
         tokens = []
-        text += '\n'
+        if not text.endswith('\n'):
+            text += '\n'
         while mo is not None:
             kind = mo.lastgroup
             value = mo.group()
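The guard makes the newline padding idempotent. Since the NEWLINE rule is `[\r\n]+`, unconditionally appending did not add an extra token to already-terminated input, but it did widen the final NEWLINE lexeme; a small sketch (token value taken from `mo.group()`):

```python
text = "a -> b\n"    # already newline-terminated AOA source (assumed syntax)
# old: text += '\n'  gives "a -> b\n\n", so the final NEWLINE lexeme is '\n\n'
# new: text is left untouched and the final NEWLINE lexeme stays '\n'
```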