-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathopt_utils.py
More file actions
145 lines (121 loc) · 4.71 KB
/
opt_utils.py
File metadata and controls
145 lines (121 loc) · 4.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
from collections import defaultdict
from torch import Tensor
from torch.distributed.tensor import DTensor
from typing import Generator, List, Optional, Union
def to_local(tensor: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
"""
Convert a single DTensor or list of DTensors to local tensors.
This is a no-op for regular tensors.
"""
if isinstance(tensor, Tensor):
return tensor.to_local() if isinstance(tensor, DTensor) else tensor
return [t.to_local() if isinstance(t, DTensor) else t for t in tensor]
def dtensor_from_local(
tensor: Union[Tensor, List[Tensor]], ref: Tensor
) -> Union[DTensor, List[DTensor]]:
"""
Convert a single local Tensor or list of local Tensors to DTensor.
The reference tensor's device mesh and placements are used to create the DTensor.
if the reference tensor is not a DTensor, we return the input unmodified.
"""
if not isinstance(ref, DTensor):
assert isinstance(ref, Tensor)
return tensor
device_mesh = ref.device_mesh
placements = ref.placements
# If we have a single tensor
if isinstance(tensor, Tensor):
assert not isinstance(tensor, DTensor)
return DTensor.from_local(
tensor, device_mesh=device_mesh, placements=placements
)
# We have a list of tensors
assert not isinstance(tensor[0], DTensor)
return [
DTensor.from_local(t, device_mesh=device_mesh, placements=placements)
for t in tensor
]
def create_param_batches(
params: List[Tensor], batch_size: int
) -> Generator[List[Tensor], None, None]:
"""
Batch parameters into groups of size `batch_size`.
Tensors in each batch will have identical shape, sharding, and dtype.
"""
# Group parameters by shape, sharding, and dtype
groups = defaultdict(list)
for p in params:
sharding = p.placements if isinstance(p, DTensor) else None
groups[(p.shape, sharding, p.dtype)].append(p)
# Create batches from grouped parameters
for group in groups.values():
for i in range(0, len(group), batch_size):
batch = group[i : i + batch_size]
yield batch
def pad_batch(batch: List[Tensor], batch_size: int) -> List[Tensor]:
    """
    Grow `batch` in place to exactly `batch_size` elements by appending
    uninitialized dummy tensors shaped like the first element.

    Returns the same (mutated) list for convenience.
    """
    assert 0 < len(batch) <= batch_size
    missing = batch_size - len(batch)
    for _ in range(missing):
        # Dummy contents are irrelevant; only shape/dtype must match.
        batch.append(torch.empty_like(batch[0]))
    return batch
class AsyncTask:
    """
    A cooperatively-scheduled unit of work backed by a generator.

    Each call to run() advances the generator to its next `yield`,
    letting other tasks interleave around distributed operations. The
    first step is executed eagerly during construction.
    """

    def __init__(self, generator: Generator[None, None, None]):
        self._generator = generator
        self.run()  # Eagerly advance to the first yield.

    def run(self) -> bool:
        """Advance one step. True while still running, False once done."""
        try:
            next(self._generator)
        except StopIteration:
            return False
        return True
class AsyncRuntime:
    """
    A minimal round-robin scheduler that drives AsyncTask objects to
    completion.

    Tasks are pulled lazily from a generator, with at most
    `max_concurrent_tasks` in flight; each loop iteration admits at most
    one new task (capacity permitting) and then steps every live task
    once.
    """

    def __init__(
        self, task_gen: Generator["AsyncTask", None, None], max_concurrent_tasks: int
    ):
        # Validate the concurrency cap up front.
        if max_concurrent_tasks <= 0:
            raise ValueError(f"{max_concurrent_tasks=} cannot be <= 0")
        self._task_gen = task_gen
        self._max_concurrent_tasks = max_concurrent_tasks

    def _get_next_task(self) -> Optional["AsyncTask"]:
        # None signals the task source is exhausted.
        return next(self._task_gen, None)

    def run(self):
        """Run the event loop until every task has completed."""
        source_live = True
        active: List["AsyncTask"] = []
        while source_live or active:
            survivors: List["AsyncTask"] = []
            # Admit at most one new task per iteration, capacity permitting.
            if source_live and len(active) < self._max_concurrent_tasks:
                candidate = self._get_next_task()
                if candidate is None:
                    # Task source exhausted; drain what remains.
                    source_live = False
                else:
                    survivors.append(candidate)
            # Step every previously-admitted task once; keep the live ones.
            for task in active:
                if task.run():
                    survivors.append(task)
            active = survivors