
Commit 51dde0e

Merge pull request #261 from tarun-menta/ocr-error-model
Add OCR Error Detection Model

2 parents: a3fde2f + f63a4bf

File tree

7 files changed: +1599, −1 lines changed

signatures/version1/cla.json

Lines changed: 8 additions & 0 deletions

@@ -63,6 +63,14 @@
       "created_at": "2024-10-30T17:55:23Z",
       "repoId": 741297064,
       "pullRequestNo": 235
+    },
+    {
+      "name": "ArthurMor4is",
+      "id": 42987302,
+      "comment_id": 2515315717,
+      "created_at": "2024-12-03T18:37:45Z",
+      "repoId": 741297064,
+      "pullRequestNo": 255
     }
   ]
 }

surya/model/ocr_error/config.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+from collections import OrderedDict
+from typing import Mapping
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.onnx import OnnxConfig
+
+ID2LABEL = {
+    0: 'good',
+    1: 'bad'
+}
+
+class DistilBertConfig(PretrainedConfig):
+    model_type = "distilbert"
+    attribute_map = {
+        "hidden_size": "dim",
+        "num_attention_heads": "n_heads",
+        "num_hidden_layers": "n_layers",
+    }
+
+    def __init__(
+        self,
+        vocab_size=30522,
+        max_position_embeddings=512,
+        sinusoidal_pos_embds=False,
+        n_layers=6,
+        n_heads=12,
+        dim=768,
+        hidden_dim=4 * 768,
+        dropout=0.1,
+        attention_dropout=0.1,
+        activation="gelu",
+        initializer_range=0.02,
+        qa_dropout=0.1,
+        seq_classif_dropout=0.2,
+        pad_token_id=0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.sinusoidal_pos_embds = sinusoidal_pos_embds
+        self.n_layers = n_layers
+        self.n_heads = n_heads
+        self.dim = dim
+        self.hidden_dim = hidden_dim
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation = activation
+        self.initializer_range = initializer_range
+        self.qa_dropout = qa_dropout
+        self.seq_classif_dropout = seq_classif_dropout
+        super().__init__(**kwargs, pad_token_id=pad_token_id)
+
+
+class DistilBertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+            ]
+        )
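
For context: this file mirrors the stock DistilBERT configuration from transformers, with attribute_map aliasing the generic Hugging Face names (hidden_size, num_attention_heads, num_hidden_layers) onto DistilBERT's own fields (dim, n_heads, n_layers), and ID2LABEL defining the two OCR quality classes. A minimal usage sketch, not part of the diff, assuming the file is importable as surya.model.ocr_error.config and a transformers version with ONNX export support is installed:

# Hypothetical usage sketch (not from the commit): wire the good/bad labels
# into the config and exercise the attribute aliases and ONNX input spec.
from surya.model.ocr_error.config import (
    DistilBertConfig,
    DistilBertOnnxConfig,
    ID2LABEL,
)

config = DistilBertConfig(
    num_labels=len(ID2LABEL),
    id2label=ID2LABEL,
    label2id={v: k for k, v in ID2LABEL.items()},
)

# attribute_map resolves the generic names to DistilBERT's own fields.
assert config.hidden_size == config.dim
assert config.num_attention_heads == config.n_heads
assert config.num_hidden_layers == config.n_layers

# The ONNX config marks batch and sequence as dynamic axes for export.
onnx_config = DistilBertOnnxConfig(config, task="sequence-classification")
print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'})])

Declaring batch and sequence as dynamic axes keeps an exported ONNX graph usable with variable batch sizes and sequence lengths rather than fixing them at export time.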
