2 changes: 2 additions & 0 deletions .gitignore

@@ -138,3 +138,5 @@ pix2tex/model/checkpoints/**
 !**/.gitkeep
 .vscode
 .DS_Store
+test/*
+

10 changes: 3 additions & 7 deletions pix2tex/cli.py

@@ -75,6 +75,7 @@ def __init__(self, arguments=None):
             download_checkpoints()
         self.model = get_model(self.args)
         self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device))
+        self.model.eval()

         if 'image_resizer.pth' in os.listdir(os.path.dirname(self.args.checkpoint)) and not arguments.no_resize:
             self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(self.args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
@@ -123,13 +124,8 @@ def __call__(self, img=None, resize=True) -> str:
         t = test_transform(image=img)['image'][:1].unsqueeze(0)
         im = t.to(self.args.device)

-        with torch.no_grad():
-            self.model.eval()
-            device = self.args.device
-            encoded = self.model.encoder(im.to(device))
-            dec = self.model.decoder.generate(torch.LongTensor([self.args.bos_token])[:, None].to(device), self.args.max_seq_len,
-                                              eos_token=self.args.eos_token, context=encoded.detach(), temperature=self.args.get('temperature', .25))
-            pred = post_process(token2str(dec, self.tokenizer)[0])
+        dec = self.model.generate(im.to(self.args.device), temperature=self.args.get('temperature', .25))
+        pred = post_process(token2str(dec, self.tokenizer)[0])
         try:
             clipboard.copy(pred)
         except:
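
Both call sites in this PR now delegate to a single model.generate() method. Its implementation lives in the new pix2tex/models package rather than in this diff, so the following is only a sketch, reassembled from the lines deleted above, of what such a wrapper plausibly looks like:

import torch
from torch import nn

# Hypothetical reconstruction, not the actual pix2tex/models/utils.py code:
# it folds the generation logic deleted from cli.py into one method so that
# cli.py and eval.py share a single inference path.
class Model(nn.Module):
    def __init__(self, encoder, decoder, args):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.args = args

    @torch.no_grad()
    def generate(self, x: torch.Tensor, temperature: float = 0.25) -> torch.Tensor:
        # Encode the image batch once, then decode autoregressively,
        # starting from one BOS token per sample, until EOS or max_seq_len.
        start = torch.LongTensor([self.args.bos_token] * len(x))[:, None].to(x.device)
        return self.decoder.generate(start, self.args.max_seq_len,
                                     eos_token=self.args.eos_token,
                                     context=self.encoder(x),
                                     temperature=temperature)

Note that the explicit eval-mode toggle moved into LatexOCR.__init__ (the self.model.eval() addition above), so it no longer has to run on every prediction.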
4 changes: 1 addition & 3 deletions pix2tex/eval.py

@@ -50,10 +50,8 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int
     for i, (seq, im) in pbar:
         if seq is None or im is None:
             continue
-        encoded = model.encoder(im.to(device))
         #loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
-        dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
-                                     eos_token=args.pad_token, context=encoded, temperature=args.get('temperature', .2))
+        dec = model.generate(im.to(device), temperature=args.get('temperature', .2))
         pred = detokenize(dec, dataset.tokenizer)
         truth = detokenize(seq['input_ids'], dataset.tokenizer)
         bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
52 changes: 52 additions & 0 deletions pix2tex/model/settings/config-vit.yaml

@@ -0,0 +1,52 @@
+gpu_devices: null #[0,1,2,3,4,5,6,7]
+betas:
+- 0.9
+- 0.999
+batchsize: 64
+bos_token: 1
+channels: 1
+data: dataset/data/train.pkl
+debug: false
+decoder_args:
+  attn_on_attn: true
+  cross_attend: true
+  ff_glu: true
+  rel_pos_bias: false
+  use_scalenorm: false
+dim: 256
+emb_dropout: 0
+encoder_depth: 4
+eos_token: 2
+epochs: 10
+gamma: 0.9995
+heads: 8
+id: null
+load_chkpt: null
+lr: 0.0005
+lr_step: 30
+max_height: 192
+max_seq_len: 512
+max_width: 672
+min_height: 32
+min_width: 32
+micro_batchsize: 64
+model_path: checkpoints_add
+name: pix2tex-vit
+num_layers: 4
+num_tokens: 8000
+optimizer: Adam
+output_path: outputs
+pad: false
+pad_token: 0
+patch_size: 16
+sample_freq: 1000
+save_freq: 5
+scheduler: StepLR
+seed: 42
+encoder_structure: vit
+temperature: 0.2
+test_samples: 5
+testbatchsize: 20
+tokenizer: dataset/tokenizer.json
+valbatches: 100
+valdata: dataset/data/val.pkl
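
A quick sanity check of the geometry these settings imply, assuming the vit encoder path patchifies the raw pixel canvas directly (no convolutional backbone in front):

# Patch-grid arithmetic from the values in config-vit.yaml above.
max_height, max_width, patch_size = 192, 672, 16
grid_h, grid_w = max_height // patch_size, max_width // patch_size
assert (grid_h, grid_w) == (12, 42)
num_positions = grid_h * grid_w + 1  # +1 for the cls token
print(num_positions)  # 505 positional embeddings cover the largest input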
2 changes: 2 additions & 0 deletions pix2tex/model/settings/config.yaml

@@ -1,3 +1,4 @@
+gpu_devices: null #[0,1,2,3,4,5,6,7]
 backbone_layers:
 - 2
 - 3
@@ -45,6 +46,7 @@ sample_freq: 3000
 save_freq: 5
 scheduler: StepLR
 seed: 42
+encoder_structure: hybrid
 temperature: 0.2
 test_samples: 5
 testbatchsize: 20
1 change: 1 addition & 0 deletions pix2tex/model/settings/debug.yaml

@@ -51,6 +51,7 @@ decoder_args:
 heads: 8
 num_tokens: 8000
 max_seq_len: 1024
+encoder_structure: hybrid

 # Other
 seed: 42
160 changes: 0 additions & 160 deletions pix2tex/models.py

This file was deleted.

1 change: 1 addition & 0 deletions pix2tex/models/__init__.py

@@ -0,0 +1 @@
+from .utils import *
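
The star import re-exports the contents of pix2tex/models/utils.py, which is not shown in this diff. Assuming it defines the get_model factory called in cli.py and the Model class named in eval.py's signature, package-level imports stay short:

# Assumed exports; pix2tex/models/utils.py itself is not part of this diff.
from pix2tex.models import get_model, Model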
56 changes: 56 additions & 0 deletions pix2tex/models/hybrid.py

@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+
+from timm.models.vision_transformer import VisionTransformer
+from timm.models.vision_transformer_hybrid import HybridEmbed
+from timm.models.resnetv2 import ResNetV2
+from timm.models.layers import StdConv2dSame
+from einops import repeat
+
+class CustomVisionTransformer(VisionTransformer):
+    def __init__(self, img_size=224, patch_size=16, *args, **kwargs):
+        super(CustomVisionTransformer, self).__init__(img_size=img_size, patch_size=patch_size, *args, **kwargs)
+        self.height, self.width = img_size
+        self.patch_size = patch_size
+
+    def forward_features(self, x):
+        B, c, h, w = x.shape
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        h, w = h//self.patch_size, w//self.patch_size
+        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
+        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
+        x += self.pos_embed[:, pos_emb_ind]
+        #x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x = self.norm(x)
+        return x
+
+
+def get_encoder(args):
+    backbone = ResNetV2(
+        layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
+        preact=False, stem_type='same', conv_layer=StdConv2dSame)
+    min_patch_size = 2**(len(args.backbone_layers)+1)
+
+    def embed_layer(**x):
+        ps = x.pop('patch_size', min_patch_size)
+        assert ps % min_patch_size == 0 and ps >= min_patch_size, 'patch_size needs to be multiple of %i with current backbone configuration' % min_patch_size
+        return HybridEmbed(**x, patch_size=ps//min_patch_size, backbone=backbone)
+
+    encoder = CustomVisionTransformer(img_size=(args.max_height, args.max_width),
+                                      patch_size=args.patch_size,
+                                      in_chans=args.channels,
+                                      num_classes=0,
+                                      embed_dim=args.dim,
+                                      depth=args.encoder_depth,
+                                      num_heads=args.heads,
+                                      embed_layer=embed_layer
+                                      )
+    return encoder
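
The core trick in CustomVisionTransformer is the pos_emb_ind computation in forward_features: the positional-embedding table is laid out row-major for the maximum patch grid (W = self.width // self.patch_size columns), and an input with a smaller h x w grid picks out the top-left block, mapping patch (i, j) to flat index i*W + j, with index 0 reserved for the cls token. A toy-sized trace of that index math:

import torch
from einops import repeat

# Toy sizes for illustration only: a 4-patch-wide maximum grid (W=4) and a
# 2x2 input grid (h=w=2); in the model, W comes from max_width // patch_size.
W, h, w = 4, 2, 2
pos_emb_ind = repeat(torch.arange(h) * (W - w), 'h -> (h w)', w=w) + torch.arange(h * w)
print(pos_emb_ind)  # tensor([0, 1, 4, 5]) == i*W + j for the top-left 2x2 block
pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
print(pos_emb_ind)  # tensor([0, 1, 2, 5, 6]); slot 0 is the cls token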