Skip to content

Commit ff2641c

Browse files
committed
use correct decoder args
1 parent 7154c03 commit ff2641c

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

pix2tex/models/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,6 @@ def get_decoder(args):
6161
dim=args.dim,
6262
depth=args.num_layers,
6363
heads=args.heads,
64-
cross_attend=True
64+
**args.decoder_args
6565
)),
6666
pad_value=args.pad_token)

pix2tex/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def save_models(e, step=0):
7070
max_bleu, max_token_acc = bleu_score, token_accuracy
7171
save_models(e, step=i)
7272
if (e+1) % args.save_freq == 0:
73-
save_models(e)
73+
save_models(e, step=len(dataloader))
7474
if args.wandb:
7575
wandb.log({'train/epoch': e+1})
7676
except KeyboardInterrupt:

0 commit comments

Comments
 (0)