-
Notifications
You must be signed in to change notification settings - Fork 260
Expand file tree
/
Copy pathtrainer.py
More file actions
406 lines (340 loc) · 15.7 KB
/
trainer.py
File metadata and controls
406 lines (340 loc) · 15.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
import collections
import copy
import logging
from federatedscope.core.trainers.base_trainer import BaseTrainer
from federatedscope.core.trainers.enums import MODE, LIFECYCLE
from federatedscope.core.auxiliaries.decorators import use_diff
from federatedscope.core.trainers.utils import format_log_hooks, \
filter_by_specified_keywords
from federatedscope.core.trainers.context import Context, CtxVar, lifecycle
# Module-level logger; the Trainer uses it to report hook removal and to
# print trainer meta-information.
logger = logging.getLogger(__name__)
class Trainer(BaseTrainer):
    """
    Register, organize and run the train/test/val procedures.

    Each procedure (train / eval / finetune) owns a hooks dict that maps a
    trigger name from ``HOOK_TRIGGER`` to an ordered list of callables.
    The ``_run_*`` methods call every hook registered for a trigger, in list
    order, passing ``self.ctx`` as the single argument.
    """

    # Valid trigger names, listed in the order they fire during one routine:
    # fit start -> (per epoch) epoch start -> (per batch) start / forward /
    # backward / end -> epoch end -> fit end.
    HOOK_TRIGGER = [
        "on_fit_start", "on_epoch_start", "on_batch_start", "on_batch_forward",
        "on_batch_backward", "on_batch_end", "on_epoch_end", "on_fit_end"
    ]

    def __init__(self,
                 model,
                 data,
                 device,
                 config,
                 only_for_eval=False,
                 monitor=None):
        """
        Arguments:
            model: the model to run; also stored as ``ctx.models = [model]``.
            data: raw data, turned into per-split context vars by
                ``parse_data``.
            device: device handed to the ``Context``.
            config: configuration object (exposed as ``self.cfg``).
            only_for_eval (bool): when True, the default training hooks are
                not registered.
            monitor: monitor object attached to ``ctx.monitor``; required.
        """
        self._cfg = config
        self.ctx = Context(model, self.cfg, data, device)
        # Parse data and setup init vars in ctx
        self._setup_data_related_var_in_ctx(self.ctx)

        assert monitor is not None, \
            f"Monitor not found in trainer with class {type(self)}"
        self.ctx.monitor = monitor
        # the "model_nums", and "models" are used for multi-model case and
        # model size calculation
        self.model_nums = 1
        self.ctx.models = [model]
        # "mirrored_models": whether the internal multi-models adopt the
        # same architectures and almost the same behaviors,
        # which is used to simplify the flops, model size calculation
        self.ctx.mirrored_models = False

        # Atomic operations during training/evaluation, keyed by trigger
        self.hooks_in_train = collections.defaultdict(list)
        # By default, use the same trigger keys
        self.hooks_in_eval = copy.deepcopy(self.hooks_in_train)
        self.hooks_in_ft = copy.deepcopy(self.hooks_in_train)

        # register necessary hooks into self.hooks_in_train,
        # self.hooks_in_ft and self.hooks_in_eval
        if not only_for_eval:
            self.register_default_hooks_train()
        if self.cfg.finetune.before_eval:
            self.register_default_hooks_ft()
        self.register_default_hooks_eval()

        # In standalone mode, by default, the trainer info is not printed
        # here, so that it appears only once for better log readability.
        if self.cfg.federate.mode == 'distributed':
            self.print_trainer_meta_info()

    @property
    def cfg(self):
        """The configuration object used by this trainer."""
        return self._cfg

    @cfg.setter
    def cfg(self, new_cfg):
        """Replace the config, propagate it to ``ctx`` and re-parse data."""
        self._cfg = new_cfg
        self.ctx.cfg = new_cfg
        self._setup_data_related_var_in_ctx(self.ctx)

    def parse_data(self, data):
        """
        Populate ``${split}_data``, ``${split}_loader`` and \
        ``num_${split}_data`` for different data splits.

        Must be implemented by subclasses; the returned dict is merged into
        the context.
        """
        raise NotImplementedError

    def setup_data(self, ctx):
        """
        Initialize data according to ``cfg``. No-op by default.
        """
        pass

    def _setup_data_related_var_in_ctx(self, ctx):
        """
        Populate ``${split}_data``, ``${split}_loader`` and \
        ``num_${split}_data`` for different data splits, and setup init var \
        in ctx.
        """
        self.setup_data(ctx)
        init_dict = self.parse_data(ctx.data)
        ctx.merge_from_dict(init_dict)

    def register_default_hooks_train(self):
        """Register the default training hooks; no-op in the base class."""
        pass

    def register_default_hooks_eval(self):
        """Register the default evaluation hooks; no-op in the base class."""
        pass

    def register_default_hooks_ft(self):
        """Register the default finetune hooks; no-op in the base class."""
        pass

    def reset_hook_in_train(self, target_trigger, target_hook_name=None):
        """
        Remove training hooks at ``target_trigger``.

        See ``_reset_hook_in_trigger`` for the meaning of the return value.
        """
        return self._reset_hook_in_trigger(self.hooks_in_train,
                                           target_hook_name, target_trigger)

    def reset_hook_in_eval(self, target_trigger, target_hook_name=None):
        """
        Remove evaluation hooks at ``target_trigger``.

        See ``_reset_hook_in_trigger`` for the meaning of the return value.
        """
        return self._reset_hook_in_trigger(self.hooks_in_eval,
                                           target_hook_name, target_trigger)

    def reset_hook_in_ft(self, target_trigger, target_hook_name=None):
        """
        Remove finetune hooks at ``target_trigger``.

        See ``_reset_hook_in_trigger`` for the meaning of the return value.
        """
        return self._reset_hook_in_trigger(self.hooks_in_ft,
                                           target_hook_name, target_trigger)

    def replace_hook_in_train(self, new_hook, target_trigger,
                              target_hook_name):
        """
        Replace the training hook named ``target_hook_name`` by
        ``new_hook`` at the same position. If the named hook is not found
        (or ``target_hook_name`` is None, which clears the trigger),
        ``new_hook`` is appended.
        """
        del_one_hook_idx = self.reset_hook_in_train(
            target_trigger=target_trigger, target_hook_name=target_hook_name)
        self.register_hook_in_train(new_hook=new_hook,
                                    trigger=target_trigger,
                                    insert_pos=del_one_hook_idx)

    def replace_hook_in_eval(self, new_hook, target_trigger, target_hook_name):
        """
        Replace the evaluation hook named ``target_hook_name`` by
        ``new_hook`` at the same position. If the named hook is not found
        (or ``target_hook_name`` is None, which clears the trigger),
        ``new_hook`` is appended.
        """
        del_one_hook_idx = self.reset_hook_in_eval(
            target_trigger=target_trigger, target_hook_name=target_hook_name)
        self.register_hook_in_eval(new_hook=new_hook,
                                   trigger=target_trigger,
                                   insert_pos=del_one_hook_idx)

    def replace_hook_in_ft(self, new_hook, target_trigger, target_hook_name):
        """
        Replace the finetune hook named ``target_hook_name`` by
        ``new_hook`` at the same position. If the named hook is not found
        (or ``target_hook_name`` is None, which clears the trigger),
        ``new_hook`` is appended.
        """
        del_one_hook_idx = self.reset_hook_in_ft(
            target_trigger=target_trigger, target_hook_name=target_hook_name)
        self.register_hook_in_ft(new_hook=new_hook,
                                 trigger=target_trigger,
                                 insert_pos=del_one_hook_idx)

    def _reset_hook_in_trigger(self, hooks_dict, target_hook_name,
                               target_trigger):
        """
        Clean/delete existing hooks for a specific trigger.

        If ``target_hook_name`` is given, only the first hook whose
        ``__name__`` equals it is removed; otherwise all hooks of the
        trigger are removed.

        Returns:
            int or None: the index of the removed hook, ``-1`` when the
            whole list was cleared, or ``None`` when no hook with the given
            name was found (a warning is logged in that case).
        """
        assert target_trigger in self.HOOK_TRIGGER, \
            f"Got {target_trigger} as hook trigger, you should specify a " \
            f"string within {self.HOOK_TRIGGER}."
        if target_hook_name is None:
            hooks_dict[target_trigger] = []
            return -1  # -1 indicates del the whole list

        hooks = hooks_dict[target_trigger]
        for hook_idx, hook in enumerate(hooks):
            if hook.__name__ == target_hook_name:
                # Safe to mutate here: we stop iterating right after.
                hooks.pop(hook_idx)
                logger.info(f"Remove the hook `{hook.__name__}` from "
                            f"hooks_set at trigger `{target_trigger}`")
                return hook_idx

        logger.warning(f"In hook del procedure, can't find the target hook "
                       f"named {target_hook_name}")
        return None

    def register_hook_in_train(self,
                               new_hook,
                               trigger,
                               insert_pos=None,
                               base_hook=None,
                               insert_mode="before"):
        """Register ``new_hook`` into the training hooks; see
        ``_register_hook`` for the positioning arguments."""
        self._register_hook(base_hook, self.hooks_in_train, insert_mode,
                            insert_pos, new_hook, trigger)

    def register_hook_in_ft(self,
                            new_hook,
                            trigger,
                            insert_pos=None,
                            base_hook=None,
                            insert_mode="before"):
        """Register ``new_hook`` into the finetune hooks; see
        ``_register_hook`` for the positioning arguments."""
        self._register_hook(base_hook, self.hooks_in_ft, insert_mode,
                            insert_pos, new_hook, trigger)

    def register_hook_in_eval(self,
                              new_hook,
                              trigger,
                              insert_pos=None,
                              base_hook=None,
                              insert_mode="before"):
        """Register ``new_hook`` into the evaluation hooks; see
        ``_register_hook`` for the positioning arguments."""
        self._register_hook(base_hook, self.hooks_in_eval, insert_mode,
                            insert_pos, new_hook, trigger)

    def _register_hook(self, base_hook, hooks_dict, insert_mode, insert_pos,
                       new_hook, trigger):
        """
        Insert ``new_hook`` into ``hooks_dict[trigger]``.

        Position resolution, in priority order:
            1. ``insert_pos`` given: ``-1`` appends, otherwise it must lie in
               ``[0, len(hooks)]`` and is used as the ``list.insert`` index.
            2. ``base_hook`` given: insert directly before ``base_hook`` when
               ``insert_mode == "before"``, otherwise directly after it.
               ``list.index`` raises ValueError if ``base_hook`` is not
               registered at ``trigger``.
            3. Neither: append, so the new hook is called last.
        """
        assert trigger in self.HOOK_TRIGGER, \
            f"Got {trigger} as hook trigger, you should specify a string " \
            f"within {self.HOOK_TRIGGER}."
        # For a defaultdict this creates and stores the list on first access,
        # so mutating target_hook_set mutates hooks_dict[trigger].
        target_hook_set = hooks_dict[trigger]
        if insert_pos is not None:
            assert insert_pos == -1 or \
                0 <= insert_pos <= len(target_hook_set), \
                f"Got {insert_pos} as insert pos, you should specify a " \
                f"integer (1) =-1 " \
                f"or (2) =0 for null target_hook_set;" \
                f"or (3) within [0, {(len(target_hook_set))}]."
        elif base_hook is not None:
            base_hook_pos = target_hook_set.index(base_hook)
            # list.insert(i, x) places x right before the current element i,
            # so "before" uses the base hook's own index and "after" the
            # next one (at most len(target_hook_set), i.e. an append).
            insert_pos = base_hook_pos if insert_mode == "before" else \
                base_hook_pos + 1
        else:
            insert_pos = -1  # By default, the new hook is called finally

        if insert_pos == -1:
            target_hook_set.append(new_hook)
        else:
            target_hook_set.insert(insert_pos, new_hook)

    @use_diff
    def train(self, target_data_split_name="train", hooks_set=None):
        """
        Run the training routine on ``target_data_split_name``.

        Arguments:
            target_data_split_name (str): data split to train on.
            hooks_set (dict): hooks to run; falls back to
                ``self.hooks_in_train`` when falsy.

        Returns:
            tuple: ``(num_samples, model parameters, ctx.eval_metrics)`` as
            produced by the body (before any processing by ``use_diff``).
        """
        hooks_set = hooks_set or self.hooks_in_train

        self.ctx.check_split(target_data_split_name)

        num_samples = self._run_routine(MODE.TRAIN, hooks_set,
                                        target_data_split_name)

        return num_samples, self.get_model_para(), self.ctx.eval_metrics

    def evaluate(self, target_data_split_name="test", hooks_set=None):
        """
        Run the evaluation routine on ``target_data_split_name``.

        If ``ctx.check_split(..., skip=True)`` reports the split as
        unavailable, no hooks run and ``ctx.eval_metrics`` is reset to an
        empty dict.

        Returns:
            dict: ``self.ctx.eval_metrics``.
        """
        hooks_set = hooks_set or self.hooks_in_eval

        if self.ctx.check_split(target_data_split_name, skip=True):
            self._run_routine(MODE.TEST, hooks_set, target_data_split_name)
        else:
            self.ctx.eval_metrics = dict()

        return self.ctx.eval_metrics

    def finetune(self, target_data_split_name="train", hooks_set=None):
        """
        Run the finetune routine on ``target_data_split_name`` using
        ``hooks_set`` or, when falsy, ``self.hooks_in_ft``.
        """
        hooks_set = hooks_set or self.hooks_in_ft

        self.ctx.check_split(target_data_split_name)

        self._run_routine(MODE.FINETUNE, hooks_set, target_data_split_name)

    @lifecycle(LIFECYCLE.ROUTINE)
    def _run_routine(self, mode, hooks_set, dataset_name=None):
        """Run the hooks_set and maintain the mode

        Arguments:
            mode: running mode of client, e.g. ``MODE.TRAIN``,
                ``MODE.TEST`` or ``MODE.FINETUNE``.
            hooks_set (dict): trigger name -> list of hooks.
            dataset_name (str): name of the data split. It is not read in
                the body; presumably the ``lifecycle`` decorator consumes it
                together with ``mode`` (TODO confirm).

        Returns:
            ``self.ctx.num_samples`` after the ``on_fit_end`` hooks ran.

        Note:
            Considering evaluation could be in ```hooks_set["on_epoch_end"]```,
            there could be two data loaders in self.ctx, we must tell the
            running hooks which data_loader to call and which
            num_samples to count
        """
        for hook in hooks_set["on_fit_start"]:
            hook(self.ctx)

        self._run_epoch(hooks_set)

        for hook in hooks_set["on_fit_end"]:
            hook(self.ctx)

        return self.ctx.num_samples

    @lifecycle(LIFECYCLE.EPOCH)
    def _run_epoch(self, hooks_set):
        """Run ``num_${cur_split}_epoch`` epochs of the batch loop."""
        for epoch_i in range(
                getattr(self.ctx, f"num_{self.ctx.cur_split}_epoch")):
            self.ctx.cur_epoch_i = CtxVar(epoch_i, "epoch")

            for hook in hooks_set["on_epoch_start"]:
                hook(self.ctx)

            self._run_batch(hooks_set)

            for hook in hooks_set["on_epoch_end"]:
                hook(self.ctx)

    @lifecycle(LIFECYCLE.BATCH)
    def _run_batch(self, hooks_set):
        """
        Run up to ``num_${cur_split}_batch`` batches within one epoch.

        In TRAIN/FINETUNE mode, during the final epoch the loop stops after
        ``ctx.num_train_batch_last_epoch`` batches.
        """
        for batch_i in range(
                getattr(self.ctx, f"num_{self.ctx.cur_split}_batch")):
            self.ctx.cur_batch_i = CtxVar(batch_i, LIFECYCLE.BATCH)

            for hook in hooks_set["on_batch_start"]:
                hook(self.ctx)

            for hook in hooks_set["on_batch_forward"]:
                hook(self.ctx)

            for hook in hooks_set["on_batch_backward"]:
                hook(self.ctx)

            for hook in hooks_set["on_batch_end"]:
                hook(self.ctx)

            # Break in the final epoch
            if self.ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE] and \
                    self.ctx.cur_epoch_i == self.ctx.num_train_epoch - 1:
                if batch_i >= self.ctx.num_train_batch_last_epoch - 1:
                    break

    def update(self, model_parameters, strict=False):
        """
        Called by the FL client to update the model parameters.
        No-op in the base class.

        Arguments:
            model_parameters (dict): {model_name: model_val}
            strict (bool): ensure the k-v pairs are strictly same
        """
        pass

    def get_model_para(self):
        """
        :return: model_parameters (dict): {model_name: model_val}.
            Returns None in the base class.
        """
        pass

    def print_trainer_meta_info(self):
        """
        print some meta info for code-users, e.g., model type; the para
        names will be filtered out, etc.,
        """
        logger.info(f"Model meta-info: {type(self.ctx.model)}.")
        logger.debug(f"Model meta-info: {self.ctx.model}.")
        # logger.info(f"Data meta-info: {self.ctx['data']}.")

        ori_para_names = set(self.ctx.model.state_dict().keys())
        preserved_paras = self._param_filter(self.ctx.model.state_dict())
        preserved_para_names = set(preserved_paras.keys())
        filtered_para_names = ori_para_names - preserved_para_names
        logger.info(f"Num of original para names: {len(ori_para_names)}.")
        logger.info(f"Num of original trainable para names:"
                    f" {len(self.ctx['trainable_para_names'])}.")
        logger.info(
            f"Num of preserved para names in local update:"
            f" {len(preserved_para_names)}. \n"
            f"Preserved para names in local update: {preserved_para_names}.")
        logger.info(
            f"Num of filtered para names in local update:"
            f" {len(filtered_para_names)}. \n"
            f"Filtered para names in local update: {filtered_para_names}.")

        logger.info(f"After register default hooks,\n"
                    f"\tthe hooks_in_train is:\n\t"
                    f"{format_log_hooks(self.hooks_in_train)};\n"
                    f"\tthe hooks_in_eval is:\n\t"
                    f"{format_log_hooks(self.hooks_in_eval)}")

    def _param_filter(self, state_dict, filter_keywords=None):
        """
        model parameter filter when transmit between local and global,
        which is useful in personalization.
        e.g., setting cfg.personalization.local_param= ['bn', 'norms']
        indicates the implementation of
        "FedBN: Federated Learning on Non-IID Features via Local Batch
        Normalization, ICML2021", which can be found in
        https://openreview.net/forum?id=6YEQUn0QICG

        An entry is kept only if
        (a) ``cfg.personalization.share_non_trainable_para`` is True or its
            name is in ``ctx.trainable_para_names``, and
        (b) ``filter_by_specified_keywords(name, filter_keywords)`` accepts
            it.
        When ``cfg.federate.method`` is "local" or "global", nothing is
        kept.

        Arguments:
            state_dict (dict): PyTorch Module object's state_dict.
            filter_keywords (list): keywords of entries to keep local;
                defaults to ``cfg.personalization.local_param``.

        Returns:
            state_dict (dict): remove the keys that match any of the given
                keywords.
        """
        if self.cfg.federate.method in ["local", "global"]:
            return {}

        if filter_keywords is None:
            filter_keywords = self.cfg.personalization.local_param

        if self.cfg.personalization.share_non_trainable_para:

            def trainable_filter(name):
                return True
        else:
            trainable_names = self.ctx.trainable_para_names

            # Filter on the parameter NAME: trainable_para_names holds names,
            # not tensors.
            def trainable_filter(name):
                return name in trainable_names

        return {
            name: value
            for name, value in state_dict.items()
            if trainable_filter(name)
            and filter_by_specified_keywords(name, filter_keywords)
        }

    def save_model(self, path, cur_round=-1):
        """Backend-specific; must be implemented by subclasses."""
        raise NotImplementedError(
            "The function `save_model` should be implemented according to "
            "the ML backend (Pytorch, Tensorflow ...).")

    def load_model(self, path):
        """Backend-specific; must be implemented by subclasses."""
        raise NotImplementedError(
            "The function `load_model` should be implemented according to "
            "the ML backend (Pytorch, Tensorflow ...).")