We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7154c03 commit ff2641cCopy full SHA for ff2641c
pix2tex/models/transformer.py
@@ -61,6 +61,6 @@ def get_decoder(args):
61
dim=args.dim,
62
depth=args.num_layers,
63
heads=args.heads,
64
- cross_attend=True
+ **args.decoder_args
65
)),
66
pad_value=args.pad_token)
pix2tex/train.py
@@ -70,7 +70,7 @@ def save_models(e, step=0):
70
max_bleu, max_token_acc = bleu_score, token_accuracy
71
save_models(e, step=i)
72
if (e+1) % args.save_freq == 0:
73
- save_models(e)
+ save_models(e, step=len(dataloader))
74
if args.wandb:
75
wandb.log({'train/epoch': e+1})
76
except KeyboardInterrupt:
0 commit comments