diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index c580b2a..cc0d1df 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -662,7 +662,9 @@ def train( progress_bar_refresh_rate, save_every, generate_every, + n_generate, output_dir, + save_gdrive, avg_loss_smoothing, is_gpu_used, ) diff --git a/aitextgen/train_callback.py b/aitextgen/train_callback.py index a282efe..56a1ce2 100644 --- a/aitextgen/train_callback.py +++ b/aitextgen/train_callback.py @@ -18,7 +18,9 @@ def __init__( refresh_rate, save_every, generate_every, + n_generate, output_dir, + save_gdrive, avg_loss_smoothing, is_gpu_used, ): @@ -29,7 +31,9 @@ def __init__( self.refresh_rate = refresh_rate self.save_every = save_every self.generate_every = generate_every + self.n_generate = n_generate self.output_dir = output_dir + self.save_gdrive = save_gdrive self.smoothing = avg_loss_smoothing self.gpu = is_gpu_used self.steps = 0