Skip to content

Commit 0d7640f

Browse files
author
yangyaming
committed
Follow comments.
1 parent 8e3c26f commit 0d7640f

File tree

1 file changed

+45
-49
lines changed

1 file changed

+45
-49
lines changed

deep_speech_2/error_rate.py

Lines changed: 45 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
# -- * -- coding: utf-8 -- * --
1+
# -*- coding: utf-8 -*-
2+
"""
3+
This module provides functions to calculate error rate in different level.
4+
e.g. wer for word-level, cer for char-level.
5+
"""
6+
27
import numpy as np
38

49

@@ -14,9 +19,9 @@ def levenshtein_distance(ref, hyp):
1419
if hyp_len == 0:
1520
return ref_len
1621

17-
distance = np.zeros((ref_len + 1, hyp_len + 1), dtype=np.int64)
22+
distance = np.zeros((ref_len + 1, hyp_len + 1), dtype=np.int32)
1823

19-
# initialization distance matrix
24+
# initialize distance matrix
2025
for j in xrange(hyp_len + 1):
2126
distance[0][j] = j
2227
for i in xrange(ref_len + 1):
@@ -36,11 +41,10 @@ def levenshtein_distance(ref, hyp):
3641
return distance[ref_len][hyp_len]
3742

3843

39-
def wer(reference, hypophysis, delimiter=' ', filter_none=True):
44+
def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
4045
"""
41-
Calculate word error rate (WER). WER is a popular evaluation metric used
42-
in speech recognition. It compares a reference with an hypophysis and
43-
is defined like this:
46+
Calculate word error rate (WER). WER compares reference text and
47+
hypothesis text in word-level. WER is defined as:
4448
4549
.. math::
4650
WER = (Sw + Dw + Iw) / Nw
@@ -54,41 +58,39 @@ def wer(reference, hypophysis, delimiter=' ', filter_none=True):
5458
Iw is the number of words inserted,
5559
Nw is the number of words in the reference
5660
57-
We can use levenshtein distance to calculate WER. Please draw an attention
58-
that this function will truncate the beginning and ending delimiter for
59-
reference and hypophysis sentences before calculating WER.
61+
We can use levenshtein distance to calculate WER. Please draw an attention that
62+
empty items will be removed when splitting sentences by delimiter.
6063
6164
:param reference: The reference sentence.
62-
:type reference: str
63-
:param hypophysis: The hypophysis sentence.
64-
:type reference: str
65+
:type reference: basestring
66+
:param hypothesis: The hypothesis sentence.
67+
:type hypothesis: basestring
68+
:param ignore_case: Whether case-sensitive or not.
69+
:type ignore_case: bool
6570
:param delimiter: Delimiter of input sentences.
6671
:type delimiter: char
67-
:param filter_none: Whether to remove None value when splitting sentence.
68-
:type filter_none: bool
69-
:return: WER
72+
:return: Word error rate.
7073
:rtype: float
7174
"""
75+
if ignore_case == True:
76+
reference = reference.lower()
77+
hypothesis = hypothesis.lower()
7278

73-
if len(reference.strip(delimiter)) == 0:
74-
raise ValueError("Reference's word number should be greater than 0.")
79+
ref_words = filter(None, reference.split(delimiter))
80+
hyp_words = filter(None, hypothesis.split(delimiter))
7581

76-
if filter_none == True:
77-
ref_words = filter(None, reference.strip(delimiter).split(delimiter))
78-
hyp_words = filter(None, hypophysis.strip(delimiter).split(delimiter))
79-
else:
80-
ref_words = reference.strip(delimiter).split(delimiter)
81-
hyp_words = reference.strip(delimiter).split(delimiter)
82+
if len(ref_words) == 0:
83+
raise ValueError("Reference's word number should be greater than 0.")
8284

8385
edit_distance = levenshtein_distance(ref_words, hyp_words)
8486
wer = float(edit_distance) / len(ref_words)
8587
return wer
8688

8789

88-
def cer(reference, hypophysis, squeeze=True, ignore_case=False, strip_char=''):
90+
def cer(reference, hypothesis, ignore_case=False):
8991
"""
90-
Calculate charactor error rate (CER). CER will compare reference text and
91-
hypophysis text in char-level. CER is defined as:
92+
Calculate charactor error rate (CER). CER compares reference text and
93+
hypothesis text in char-level. CER is defined as:
9294
9395
.. math::
9496
CER = (Sc + Dc + Ic) / Nc
@@ -97,41 +99,35 @@ def cer(reference, hypophysis, squeeze=True, ignore_case=False, strip_char=''):
9799
98100
.. code-block:: text
99101
100-
Sc is the number of character substituted,
101-
Dc is the number of deleted,
102-
Ic is the number of inserted
102+
Sc is the number of characters substituted,
103+
Dc is the number of characters deleted,
104+
Ic is the number of characters inserted
103105
Nc is the number of characters in the reference
104106
105107
We can use levenshtein distance to calculate CER. Chinese input should be
106-
encoded to unicode.
108+
encoded to unicode. Please draw an attention that the leading and tailing
109+
white space characters will be truncated and multiple consecutive white
110+
space characters in a sentence will be replaced by one white space character.
107111
108112
:param reference: The reference sentence.
109-
:type reference: str
110-
:param hypophysis: The hypophysis sentence.
111-
:type reference: str
112-
:param squeeze: If set true, consecutive space character
113-
will be squeezed to one
114-
:type squeeze: bool
113+
:type reference: basestring
114+
:param hypothesis: The hypothesis sentence.
115+
:type hypothesis: basestring
115116
:param ignore_case: Whether case-sensitive or not.
116117
:type ignore_case: bool
117-
:param strip_char: If not set to '', strip_char in beginning and ending of
118-
sentence will be truncated.
119-
:type strip_char: char
120-
:return: CER
118+
:return: Character error rate.
121119
:rtype: float
122120
"""
123121
if ignore_case == True:
124122
reference = reference.lower()
125-
hypophysis = hypophysis.lower()
126-
if strip_char != '':
127-
reference = reference.strip(strip_char)
128-
hypophysis = hypophysis.strip(strip_char)
129-
if squeeze == True:
130-
reference = ' '.join(filter(None, reference.split(' ')))
131-
hypophysis = ' '.join(filter(None, hypophysis.split(' ')))
123+
hypothesis = hypothesis.lower()
124+
125+
reference = ' '.join(filter(None, reference.split(' ')))
126+
hypothesis = ' '.join(filter(None, hypothesis.split(' ')))
132127

133128
if len(reference) == 0:
134129
raise ValueError("Length of reference should be greater than 0.")
135-
edit_distance = levenshtein_distance(reference, hypophysis)
130+
131+
edit_distance = levenshtein_distance(reference, hypothesis)
136132
cer = float(edit_distance) / len(reference)
137133
return cer

0 commit comments

Comments
 (0)