Skip to content

Commit d9fc1b4

Browse files
committed
Custom callback support
1 parent 097e153 commit d9fc1b4

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

aitextgen/aitextgen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,7 @@ def train(
581581
progress_bar_refresh_rate: int = 20,
582582
freeze_layers: bool = False,
583583
num_layers_freeze: int = None,
584+
custom_callbacks: dict = {},
584585
**kwargs,
585586
) -> None:
586587

@@ -667,6 +668,7 @@ def train(
667668
save_gdrive,
668669
avg_loss_smoothing,
669670
is_gpu_used,
671+
custom_callbacks
670672
)
671673
)
672674

aitextgen/train_callback.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
save_gdrive,
2424
avg_loss_smoothing,
2525
is_gpu_used,
26+
custom_callbacks
2627
):
2728
self.training_bar = None
2829
self.model = model
@@ -39,6 +40,7 @@ def __init__(
3940
self.steps = 0
4041
self.current_loss = None
4142
self.prev_avg_loss = None
43+
self.custom_callbacks = custom_callbacks
4244

4345
@property
4446
def save_every_check(self):
@@ -54,17 +56,22 @@ def on_train_begin(self, args, state, control, **kwargs):
5456
file=sys.stdout,
5557
)
5658

59+
self.custom_callbacks.get('on_train_begin', lambda: None)()
60+
5761
def on_train_end(self, args, state, control, **kwargs):
5862
if state.is_local_process_zero:
5963
self.training_bar.close()
6064
self.training_bar = None
6165

66+
self.custom_callbacks.get('on_train_end', lambda: None)()
67+
6268
def on_evaluate(self, args, state, control, metrics, **kwargs):
69+
print(f'''on_evaluate called: train_loss={metrics.get('train_loss')}''')
6370
if state.is_local_process_zero:
6471
self.current_loss = float(metrics.get("train_loss"))
6572

6673
def on_step_end(self, args, state, control, **kwargs):
67-
74+
print(f'on_step_end called\nargs={args}\nstate={state}\ncontrol={control}\nkwargs={kwargs}\n=====')
6875
if state.is_local_process_zero:
6976
self.steps += 1
7077
avg_loss = 0
@@ -100,6 +107,8 @@ def on_step_end(self, args, state, control, **kwargs):
100107
if self.current_loss:
101108
self.training_bar.set_description(desc)
102109

110+
self.custom_callbacks.get('on_step_end', lambda steps, max, curr, avg, trainer: None)(self.steps, state.max_steps, self.current_loss, avg_loss, self.trainer)
111+
103112
if self.save_every > 0 and self.steps % self.save_every == 0:
104113
self.save_pytorch_model()
105114

@@ -138,6 +147,8 @@ def generate_sample_text(self):
138147

139148
self.training_bar.write("=" * 10)
140149

150+
self.custom_callbacks.get('on_sample_text_generated', lambda texts: None)(gen_texts)
151+
141152
def save_pytorch_model(self):
142153
# only runs on state.is_local_process_zero
143154
self.training_bar.write(
@@ -154,6 +165,8 @@ def save_pytorch_model(self):
154165
os.path.join("/content/drive/My Drive/", self.run_id, pt_file),
155166
)
156167

168+
self.custom_callbacks.get('on_model_saved', lambda curr, out: None)(self.steps, self.output_dir)
169+
157170
def average_loss(self, current_loss, prev_avg_loss, smoothing):
158171
if prev_avg_loss is None:
159172
return current_loss

0 commit comments

Comments
 (0)