Skip to content

Commit 422f55a

Browse files
author
Yibing Liu
authored
Merge pull request #122 from kuke/fix_tune
Decouple ext scorer init & inference & decoding for the convenience o…
2 parents f896f0e + 7c6fa64 commit 422f55a

6 files changed

Lines changed: 185 additions & 188 deletions

File tree

deploy/demo_server.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,30 @@ def start_server():
160160

161161
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
162162

163+
if args.decoding_method == "ctc_beam_search":
164+
ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
165+
vocab_list)
163166
# prepare ASR inference handler
164167
def file_to_transcript(filename):
165168
feature = data_generator.process_utterance(filename, "")
166-
167-
result_transcript = ds2_model.infer_batch(
169+
probs_split = ds2_model.infer_batch_probs(
168170
infer_data=[feature],
169-
decoding_method=args.decoding_method,
170-
beam_alpha=args.alpha,
171-
beam_beta=args.beta,
172-
beam_size=args.beam_size,
173-
cutoff_prob=args.cutoff_prob,
174-
cutoff_top_n=args.cutoff_top_n,
175-
vocab_list=vocab_list,
176-
language_model_path=args.lang_model_path,
177-
num_processes=1,
178171
feeding_dict=data_generator.feeding)
172+
173+
if args.decoding_method == "ctc_greedy":
174+
result_transcript = ds2_model.decode_batch_greedy(
175+
probs_split=probs_split,
176+
vocab_list=vocab_list)
177+
else:
178+
result_transcript = ds2_model.decode_batch_beam_search(
179+
probs_split=probs_split,
180+
beam_alpha=args.alpha,
181+
beam_beta=args.beta,
182+
beam_size=args.beam_size,
183+
cutoff_prob=args.cutoff_prob,
184+
cutoff_top_n=args.cutoff_top_n,
185+
vocab_list=vocab_list,
186+
num_processes=1)
179187
return result_transcript[0]
180188

181189
# warming up with utterances sampled from Librispeech

examples/librispeech/run_tune.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 \
77
python -u tools/tune.py \
88
--num_batches=-1 \
99
--batch_size=128 \
10-
--trainer_count=8 \
10+
--trainer_count=4 \
1111
--beam_size=500 \
1212
--num_proc_bsearch=12 \
1313
--num_conv_layers=2 \

infer.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,28 @@ def infer():
9090
# decoders only accept string encoded in utf-8
9191
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
9292

93-
result_transcripts = ds2_model.infer_batch(
94-
infer_data=infer_data,
95-
decoding_method=args.decoding_method,
96-
beam_alpha=args.alpha,
97-
beam_beta=args.beta,
98-
beam_size=args.beam_size,
99-
cutoff_prob=args.cutoff_prob,
100-
cutoff_top_n=args.cutoff_top_n,
101-
vocab_list=vocab_list,
102-
language_model_path=args.lang_model_path,
103-
num_processes=args.num_proc_bsearch,
104-
feeding_dict=data_generator.feeding)
93+
if args.decoding_method == "ctc_greedy":
94+
ds2_model.logger.info("start inference ...")
95+
probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
96+
feeding_dict=data_generator.feeding)
97+
result_transcripts = ds2_model.decode_batch_greedy(
98+
probs_split=probs_split,
99+
vocab_list=vocab_list)
100+
else:
101+
ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
102+
vocab_list)
103+
ds2_model.logger.info("start inference ...")
104+
probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
105+
feeding_dict=data_generator.feeding)
106+
result_transcripts = ds2_model.decode_batch_beam_search(
107+
probs_split=probs_split,
108+
beam_alpha=args.alpha,
109+
beam_beta=args.beta,
110+
beam_size=args.beam_size,
111+
cutoff_prob=args.cutoff_prob,
112+
cutoff_top_n=args.cutoff_top_n,
113+
vocab_list=vocab_list,
114+
num_processes=args.num_proc_bsearch)
105115

106116
error_rate_func = cer if args.error_rate_type == 'cer' else wer
107117
target_transcripts = [data[1] for data in infer_data]

model_utils/model.py

Lines changed: 99 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -173,43 +173,19 @@ def infer_loss_batch(self, infer_data):
173173
# run inference
174174
return self._loss_inferer.infer(input=infer_data)
175175

176-
def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
177-
beam_size, cutoff_prob, cutoff_top_n, vocab_list,
178-
language_model_path, num_processes, feeding_dict):
179-
"""Model inference. Infer the transcription for a batch of speech
180-
utterances.
176+
def infer_batch_probs(self, infer_data, feeding_dict):
177+
"""Infer the prob matrices for a batch of speech utterances.
181178
182179
:param infer_data: List of utterances to infer, with each utterance
183180
consisting of a tuple of audio features and
184181
transcription text (empty string).
185182
:type infer_data: list
186-
:param decoding_method: Decoding method name, 'ctc_greedy' or
187-
'ctc_beam_search'.
188-
:type decoding_method: string
189-
:param beam_alpha: Parameter associated with language model.
190-
:type beam_alpha: float
191-
:param beam_beta: Parameter associated with word count.
192-
:type beam_beta: float
193-
:param beam_size: Width for Beam search.
194-
:type beam_size: int
195-
:param cutoff_prob: Cutoff probability in pruning,
196-
default 1.0, no pruning.
197-
:type cutoff_prob: float
198-
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
199-
characters with highest probs in vocabulary will be
200-
used in beam search, default 40.
201-
:type cutoff_top_n: int
202-
:param vocab_list: List of tokens in the vocabulary, for decoding.
203-
:type vocab_list: list
204-
:param language_model_path: Filepath for language model.
205-
:type language_model_path: basestring|None
206-
:param num_processes: Number of processes (CPU) for decoder.
207-
:type num_processes: int
208183
:param feeding_dict: Feeding is a map of field name and tuple index
209184
of the data that reader returns.
210185
:type feeding_dict: dict|list
211-
:return: List of transcription texts.
212-
:rtype: List of basestring
186+
:return: List of 2-D probability matrix, and each consists of prob
187+
vectors for one speech utterance.
188+
:rtype: List of matrix
213189
"""
214190
# define inferer
215191
if self._inferer == None:
@@ -227,49 +203,102 @@ def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
227203
infer_results[start_pos[i]:start_pos[i + 1]]
228204
for i in xrange(0, len(adapted_infer_data))
229205
]
230-
# run decoder
206+
return probs_split
207+
208+
def decode_batch_greedy(self, probs_split, vocab_list):
209+
"""Decode by best path for a batch of probs matrix input.
210+
211+
:param probs_split: List of 2-D probability matrix, and each consists
212+
of prob vectors for one speech utterance.
213+
:type probs_split: List of matrix
214+
:param vocab_list: List of tokens in the vocabulary, for decoding.
215+
:type vocab_list: list
216+
:return: List of transcription texts.
217+
:rtype: List of basestring
218+
"""
231219
results = []
232-
if decoding_method == "ctc_greedy":
233-
# best path decode
234-
for i, probs in enumerate(probs_split):
235-
output_transcription = ctc_greedy_decoder(
236-
probs_seq=probs, vocabulary=vocab_list)
237-
results.append(output_transcription)
238-
elif decoding_method == "ctc_beam_search":
239-
# initialize external scorer
240-
if self._ext_scorer == None:
241-
self._loaded_lm_path = language_model_path
242-
self.logger.info("begin to initialize the external scorer "
243-
"for decoding")
244-
self._ext_scorer = Scorer(beam_alpha, beam_beta,
245-
language_model_path, vocab_list)
246-
247-
lm_char_based = self._ext_scorer.is_character_based()
248-
lm_max_order = self._ext_scorer.get_max_order()
249-
lm_dict_size = self._ext_scorer.get_dict_size()
250-
self.logger.info("language model: "
251-
"is_character_based = %d," % lm_char_based +
252-
" max_order = %d," % lm_max_order +
253-
" dict_size = %d" % lm_dict_size)
254-
self.logger.info("end initializing scorer. Start decoding ...")
255-
else:
256-
self._ext_scorer.reset_params(beam_alpha, beam_beta)
257-
assert self._loaded_lm_path == language_model_path
258-
# beam search decode
259-
num_processes = min(num_processes, len(probs_split))
260-
beam_search_results = ctc_beam_search_decoder_batch(
261-
probs_split=probs_split,
262-
vocabulary=vocab_list,
263-
beam_size=beam_size,
264-
num_processes=num_processes,
265-
ext_scoring_func=self._ext_scorer,
266-
cutoff_prob=cutoff_prob,
267-
cutoff_top_n=cutoff_top_n)
268-
269-
results = [result[0][1] for result in beam_search_results]
220+
for i, probs in enumerate(probs_split):
221+
output_transcription = ctc_greedy_decoder(
222+
probs_seq=probs, vocabulary=vocab_list)
223+
results.append(output_transcription)
224+
return results
225+
226+
def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
227+
vocab_list):
228+
"""Initialize the external scorer.
229+
230+
:param beam_alpha: Parameter associated with language model.
231+
:type beam_alpha: float
232+
:param beam_beta: Parameter associated with word count.
233+
:type beam_beta: float
234+
:param language_model_path: Filepath for language model. If it is
235+
empty, the external scorer will be set to
236+
None, and the decoding method will be pure
237+
beam search without scorer.
238+
:type language_model_path: basestring|None
239+
:param vocab_list: List of tokens in the vocabulary, for decoding.
240+
:type vocab_list: list
241+
"""
242+
if language_model_path != '':
243+
self.logger.info("begin to initialize the external scorer "
244+
"for decoding")
245+
self._ext_scorer = Scorer(beam_alpha, beam_beta,
246+
language_model_path, vocab_list)
247+
lm_char_based = self._ext_scorer.is_character_based()
248+
lm_max_order = self._ext_scorer.get_max_order()
249+
lm_dict_size = self._ext_scorer.get_dict_size()
250+
self.logger.info("language model: "
251+
"is_character_based = %d," % lm_char_based +
252+
" max_order = %d," % lm_max_order +
253+
" dict_size = %d" % lm_dict_size)
254+
self.logger.info("end initializing scorer")
270255
else:
271-
raise ValueError("Decoding method [%s] is not supported." %
272-
decoding_method)
256+
self._ext_scorer = None
257+
self.logger.info("no language model provided, "
258+
"decoding by pure beam search without scorer.")
259+
260+
def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
261+
beam_size, cutoff_prob, cutoff_top_n,
262+
vocab_list, num_processes):
263+
"""Decode by beam search for a batch of probs matrix input.
264+
265+
:param probs_split: List of 2-D probability matrix, and each consists
266+
of prob vectors for one speech utterance.
267+
:type probs_split: List of matrix
268+
:param beam_alpha: Parameter associated with language model.
269+
:type beam_alpha: float
270+
:param beam_beta: Parameter associated with word count.
271+
:type beam_beta: float
272+
:param beam_size: Width for Beam search.
273+
:type beam_size: int
274+
:param cutoff_prob: Cutoff probability in pruning,
275+
default 1.0, no pruning.
276+
:type cutoff_prob: float
277+
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
278+
characters with highest probs in vocabulary will be
279+
used in beam search, default 40.
280+
:type cutoff_top_n: int
281+
:param vocab_list: List of tokens in the vocabulary, for decoding.
282+
:type vocab_list: list
283+
:param num_processes: Number of processes (CPU) for decoder.
284+
:type num_processes: int
285+
:return: List of transcription texts.
286+
:rtype: List of basestring
287+
"""
288+
if self._ext_scorer != None:
289+
self._ext_scorer.reset_params(beam_alpha, beam_beta)
290+
# beam search decode
291+
num_processes = min(num_processes, len(probs_split))
292+
beam_search_results = ctc_beam_search_decoder_batch(
293+
probs_split=probs_split,
294+
vocabulary=vocab_list,
295+
beam_size=beam_size,
296+
num_processes=num_processes,
297+
ext_scoring_func=self._ext_scorer,
298+
cutoff_prob=cutoff_prob,
299+
cutoff_top_n=cutoff_top_n)
300+
301+
results = [result[0][1] for result in beam_search_results]
273302
return results
274303

275304
def _adapt_feeding_dict(self, feeding_dict):

test.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,22 +90,33 @@ def evaluate():
9090
# decoders only accept string encoded in utf-8
9191
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
9292

93+
if args.decoding_method == "ctc_beam_search":
94+
ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
95+
vocab_list)
9396
errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
9497
errors_sum, len_refs, num_ins = 0.0, 0, 0
98+
ds2_model.logger.info("start evaluation ...")
9599
for infer_data in batch_reader():
96-
result_transcripts = ds2_model.infer_batch(
100+
probs_split = ds2_model.infer_batch_probs(
97101
infer_data=infer_data,
98-
decoding_method=args.decoding_method,
99-
beam_alpha=args.alpha,
100-
beam_beta=args.beta,
101-
beam_size=args.beam_size,
102-
cutoff_prob=args.cutoff_prob,
103-
cutoff_top_n=args.cutoff_top_n,
104-
vocab_list=vocab_list,
105-
language_model_path=args.lang_model_path,
106-
num_processes=args.num_proc_bsearch,
107102
feeding_dict=data_generator.feeding)
103+
104+
if args.decoding_method == "ctc_greedy":
105+
result_transcripts = ds2_model.decode_batch_greedy(
106+
probs_split=probs_split,
107+
vocab_list=vocab_list)
108+
else:
109+
result_transcripts = ds2_model.decode_batch_beam_search(
110+
probs_split=probs_split,
111+
beam_alpha=args.alpha,
112+
beam_beta=args.beta,
113+
beam_size=args.beam_size,
114+
cutoff_prob=args.cutoff_prob,
115+
cutoff_top_n=args.cutoff_top_n,
116+
vocab_list=vocab_list,
117+
num_processes=args.num_proc_bsearch)
108118
target_transcripts = [data[1] for data in infer_data]
119+
109120
for target, result in zip(target_transcripts, result_transcripts):
110121
errors, len_ref = errors_func(target, result)
111122
errors_sum += errors

0 commit comments

Comments
 (0)