We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5fca76e commit 0aefdbfCopy full SHA for 0aefdbf
pix2tex/models/utils.py
@@ -46,7 +46,9 @@ def get_model(args, training=False):
46
decoder.to(args.device)
47
if args.wandb:
48
import wandb
49
- en_attn_layers = encoder.module.attn_layers if available_gpus > 1 else encoder.attn_layers
+ en_attn_layers = encoder
50
+ if args.encoder_structure.lower() == 'vit':
51
+ en_attn_layers = encoder.module.attn_layers if available_gpus > 1 else encoder.attn_layers
52
de_attn_layers = decoder.module.net.attn_layers if available_gpus > 1 else decoder.net.attn_layers
53
wandb.watch((en_attn_layers, de_attn_layers))
54
model = Model(encoder, decoder, args)
0 commit comments