@@ -173,43 +173,19 @@ def infer_loss_batch(self, infer_data):
173173 # run inference
174174 return self ._loss_inferer .infer (input = infer_data )
175175
176- def infer_batch (self , infer_data , decoding_method , beam_alpha , beam_beta ,
177- beam_size , cutoff_prob , cutoff_top_n , vocab_list ,
178- language_model_path , num_processes , feeding_dict ):
179- """Model inference. Infer the transcription for a batch of speech
180- utterances.
176+ def infer_batch_probs (self , infer_data , feeding_dict ):
177+ """Infer the prob matrices for a batch of speech utterances.
181178
182179 :param infer_data: List of utterances to infer, with each utterance
183180 consisting of a tuple of audio features and
184181 transcription text (empty string).
185182 :type infer_data: list
186- :param decoding_method: Decoding method name, 'ctc_greedy' or
187- 'ctc_beam_search'.
188- :param decoding_method: string
189- :param beam_alpha: Parameter associated with language model.
190- :type beam_alpha: float
191- :param beam_beta: Parameter associated with word count.
192- :type beam_beta: float
193- :param beam_size: Width for Beam search.
194- :type beam_size: int
195- :param cutoff_prob: Cutoff probability in pruning,
196- default 1.0, no pruning.
197- :type cutoff_prob: float
198- :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
199- characters with highest probs in vocabulary will be
200- used in beam search, default 40.
201- :type cutoff_top_n: int
202- :param vocab_list: List of tokens in the vocabulary, for decoding.
203- :type vocab_list: list
204- :param language_model_path: Filepath for language model.
205- :type language_model_path: basestring|None
206- :param num_processes: Number of processes (CPU) for decoder.
207- :type num_processes: int
208183 :param feeding_dict: Feeding is a map of field name and tuple index
209184 of the data that reader returns.
210185 :type feeding_dict: dict|list
211- :return: List of transcription texts.
212- :rtype: List of basestring
186+ :return: List of 2-D probability matrices, each consisting of prob
187+ vectors for one speech utterance.
188+ :rtype: List of matrix
213189 """
214190 # define inferer
215191 if self ._inferer == None :
@@ -227,49 +203,102 @@ def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
227203 infer_results [start_pos [i ]:start_pos [i + 1 ]]
228204 for i in xrange (0 , len (adapted_infer_data ))
229205 ]
230- # run decoder
206+ return probs_split
207+
208+ def decode_batch_greedy (self , probs_split , vocab_list ):
209+ """Decode by best path for a batch of probs matrix input.
210+
211+ :param probs_split: List of 2-D probability matrices, each consisting
212+ of prob vectors for one speech utterance.
213+ :type probs_split: List of matrix
214+ :param vocab_list: List of tokens in the vocabulary, for decoding.
215+ :type vocab_list: list
216+ :return: List of transcription texts.
217+ :rtype: List of basestring
218+ """
231219 results = []
232- if decoding_method == "ctc_greedy" :
233- # best path decode
234- for i , probs in enumerate (probs_split ):
235- output_transcription = ctc_greedy_decoder (
236- probs_seq = probs , vocabulary = vocab_list )
237- results .append (output_transcription )
238- elif decoding_method == "ctc_beam_search" :
239- # initialize external scorer
240- if self ._ext_scorer == None :
241- self ._loaded_lm_path = language_model_path
242- self .logger .info ("begin to initialize the external scorer "
243- "for decoding" )
244- self ._ext_scorer = Scorer (beam_alpha , beam_beta ,
245- language_model_path , vocab_list )
246-
247- lm_char_based = self ._ext_scorer .is_character_based ()
248- lm_max_order = self ._ext_scorer .get_max_order ()
249- lm_dict_size = self ._ext_scorer .get_dict_size ()
250- self .logger .info ("language model: "
251- "is_character_based = %d," % lm_char_based +
252- " max_order = %d," % lm_max_order +
253- " dict_size = %d" % lm_dict_size )
254- self .logger .info ("end initializing scorer. Start decoding ..." )
255- else :
256- self ._ext_scorer .reset_params (beam_alpha , beam_beta )
257- assert self ._loaded_lm_path == language_model_path
258- # beam search decode
259- num_processes = min (num_processes , len (probs_split ))
260- beam_search_results = ctc_beam_search_decoder_batch (
261- probs_split = probs_split ,
262- vocabulary = vocab_list ,
263- beam_size = beam_size ,
264- num_processes = num_processes ,
265- ext_scoring_func = self ._ext_scorer ,
266- cutoff_prob = cutoff_prob ,
267- cutoff_top_n = cutoff_top_n )
268-
269- results = [result [0 ][1 ] for result in beam_search_results ]
220+ for i , probs in enumerate (probs_split ):
221+ output_transcription = ctc_greedy_decoder (
222+ probs_seq = probs , vocabulary = vocab_list )
223+ results .append (output_transcription )
224+ return results
225+
226+ def init_ext_scorer (self , beam_alpha , beam_beta , language_model_path ,
227+ vocab_list ):
228+ """Initialize the external scorer.
229+
230+ :param beam_alpha: Parameter associated with language model.
231+ :type beam_alpha: float
232+ :param beam_beta: Parameter associated with word count.
233+ :type beam_beta: float
234+ :param language_model_path: Filepath for language model. If it is
235+ an empty string, the external scorer will be
236+ set to None, and decoding will be pure beam
237+ search without a scorer.
238+ :type language_model_path: basestring|None
239+ :param vocab_list: List of tokens in the vocabulary, for decoding.
240+ :type vocab_list: list
241+ """
242+ if language_model_path != '' :
243+ self .logger .info ("begin to initialize the external scorer "
244+ "for decoding" )
245+ self ._ext_scorer = Scorer (beam_alpha , beam_beta ,
246+ language_model_path , vocab_list )
247+ lm_char_based = self ._ext_scorer .is_character_based ()
248+ lm_max_order = self ._ext_scorer .get_max_order ()
249+ lm_dict_size = self ._ext_scorer .get_dict_size ()
250+ self .logger .info ("language model: "
251+ "is_character_based = %d," % lm_char_based +
252+ " max_order = %d," % lm_max_order +
253+ " dict_size = %d" % lm_dict_size )
254+ self .logger .info ("end initializing scorer" )
270255 else :
271- raise ValueError ("Decoding method [%s] is not supported." %
272- decoding_method )
256+ self ._ext_scorer = None
257+ self .logger .info ("no language model provided, "
258+ "decoding by pure beam search without scorer." )
259+
260+ def decode_batch_beam_search (self , probs_split , beam_alpha , beam_beta ,
261+ beam_size , cutoff_prob , cutoff_top_n ,
262+ vocab_list , num_processes ):
263+ """Decode by beam search for a batch of probs matrix input.
264+
265+ :param probs_split: List of 2-D probability matrices, each consisting
266+ of prob vectors for one speech utterance.
267+ :type probs_split: List of matrix
268+ :param beam_alpha: Parameter associated with language model.
269+ :type beam_alpha: float
270+ :param beam_beta: Parameter associated with word count.
271+ :type beam_beta: float
272+ :param beam_size: Width for Beam search.
273+ :type beam_size: int
274+ :param cutoff_prob: Cutoff probability in pruning,
275+ default 1.0, no pruning.
276+ :type cutoff_prob: float
277+ :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
278+ characters with highest probs in vocabulary will be
279+ used in beam search, default 40.
280+ :type cutoff_top_n: int
281+ :param vocab_list: List of tokens in the vocabulary, for decoding.
282+ :type vocab_list: list
283+ :param num_processes: Number of processes (CPU) for decoder.
284+ :type num_processes: int
285+ :return: List of transcription texts.
286+ :rtype: List of basestring
287+ """
288+ if self ._ext_scorer != None :
289+ self ._ext_scorer .reset_params (beam_alpha , beam_beta )
290+ # beam search decode
291+ num_processes = min (num_processes , len (probs_split ))
292+ beam_search_results = ctc_beam_search_decoder_batch (
293+ probs_split = probs_split ,
294+ vocabulary = vocab_list ,
295+ beam_size = beam_size ,
296+ num_processes = num_processes ,
297+ ext_scoring_func = self ._ext_scorer ,
298+ cutoff_prob = cutoff_prob ,
299+ cutoff_top_n = cutoff_top_n )
300+
301+ results = [result [0 ][1 ] for result in beam_search_results ]
273302 return results
274303
275304 def _adapt_feeding_dict (self , feeding_dict ):
0 commit comments