|
12 | 12 | import torch.utils.data |
13 | 13 | import numpy as np |
14 | 14 |
|
15 | | -from utils import CTCLabelConverter, AttnLabelConverter, Averager |
| 15 | +from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager |
16 | 16 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset |
17 | 17 | from model import Model |
18 | 18 | from test import validation |
@@ -45,7 +45,10 @@ def train(opt): |
45 | 45 |
|
46 | 46 | """ model configuration """ |
47 | 47 | if 'CTC' in opt.Prediction: |
48 | | - converter = CTCLabelConverter(opt.character) |
| 48 | + if opt.baiduCTC: |
| 49 | + converter = CTCLabelConverterForBaiduWarpctc(opt.character) |
| 50 | + else: |
| 51 | + converter = CTCLabelConverter(opt.character) |
49 | 52 | else: |
50 | 53 | converter = AttnLabelConverter(opt.character) |
51 | 54 | opt.num_class = len(converter.character) |
@@ -86,7 +89,12 @@ def train(opt): |
86 | 89 |
|
87 | 90 | """ setup loss """ |
88 | 91 | if 'CTC' in opt.Prediction: |
89 | | - criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) |
| 92 | + if opt.baiduCTC: |
| 93 | + # need to install warpctc. see our guideline. |
| 94 | + from warpctc_pytorch import CTCLoss |
| 95 | + criterion = CTCLoss() |
| 96 | + else: |
| 97 | + criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) |
90 | 98 | else: |
91 | 99 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 |
92 | 100 | # loss averager |
@@ -144,8 +152,12 @@ def train(opt): |
144 | 152 | if 'CTC' in opt.Prediction: |
145 | 153 | preds = model(image, text) |
146 | 154 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) |
147 | | - preds = preds.log_softmax(2).permute(1, 0, 2) |
148 | | - cost = criterion(preds, text, preds_size, length) |
| 155 | + if opt.baiduCTC: |
| 156 | + preds = preds.permute(1, 0, 2) # to use CTCLoss format |
| 157 | + cost = criterion(preds, text, preds_size, length) / batch_size |
| 158 | + else: |
| 159 | + preds = preds.log_softmax(2).permute(1, 0, 2) |
| 160 | + cost = criterion(preds, text, preds_size, length) |
149 | 161 |
|
150 | 162 | else: |
151 | 163 | preds = model(image, text[:, :-1]) # align with Attention.forward |
@@ -232,6 +244,7 @@ def train(opt): |
232 | 244 | parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95') |
233 | 245 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') |
234 | 246 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5') |
| 247 | +    parser.add_argument('--baiduCTC', action='store_true', help='use Baidu warpctc (warpctc_pytorch.CTCLoss) instead of torch.nn.CTCLoss') |
235 | 248 | """ Data processing """ |
236 | 249 | parser.add_argument('--select_data', type=str, default='MJ-ST', |
237 | 250 | help='select training data (default is MJ-ST, which means MJ and ST used as training data)') |
|
0 commit comments