Skip to content

Commit 93c51b4

Browse files
authored
[METRICS] Fix tokenization issue of CJK languages for evaluation (#20)
For CJK languages, we need to tokenize them with `CJSegmenter` before sending them to `mweralign.align_texts`. This PR makes the following modifications: 1. Apply `CJSegmenter` before calling `mweralign.align_texts`. This is done for both latency scorer and quality scorer. 2. Add `latency_unit` argument to the quality scorer and use this argument to trigger `CJSegmenter` in the quality scorer.
1 parent 123054a commit 93c51b4

File tree

5 files changed

+295
-5
lines changed

5 files changed

+295
-5
lines changed

simulstream/metrics/score_quality.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ def cli_main():
151151
parser.add_argument(
152152
"--audio-definition", "-a", type=str, default=None,
153153
help="Path to the yaml file containing the segment-level audio information.")
154+
parser.add_argument(
155+
"--latency-unit", choices=["char", "word"], default="word",
156+
help="Whether to compute stats based on words or characters. Default: word.")
154157
parser.add_argument("--scorer", choices=QUALITY_SCORER_REGISTRY.keys(), required=True)
155158
args, _ = parser.parse_known_args()
156159

simulstream/metrics/scorers/latency/mwersegmenter.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import List
1818

1919
from mweralign import mweralign
20+
from mweralign.segmenter import CJSegmenter
2021

2122
from simulstream.metrics.readers import ReferenceSentenceDefinition, OutputWithDelays, text_items
2223
from simulstream.metrics.scorers.latency import LatencyScorer, LatencyScoringSample, LatencyScores
@@ -58,6 +59,7 @@ class MWERSegmenterBasedLatencyScorer(LatencyScorer):
5859
def __init__(self, args):
    """
    Initialize the latency scorer from parsed CLI arguments.

    :param args: parsed command-line arguments; ``args.latency_unit``
        ("word" or "char") selects the unit latency stats are computed on.
    """
    super().__init__(args)
    self.latency_unit = args.latency_unit
    # For CJK languages ("char" unit) the text must be character-tokenized
    # with CJSegmenter before being sent to mweralign.align_texts.
    self.segmenter = CJSegmenter() if args.latency_unit == "char" else None
6163

6264
def requires_reference(self) -> bool:
6365
return True
@@ -101,19 +103,50 @@ def _split_delays_by_segmented_text(
101103
f"Index {index} should have reached end of delays ({len(delays)})"
102104
return segmented_delays
103105

106+
def _tokenize(self, text: List[str]) -> List[str]:
107+
"""
108+
Tokenize text using the segmenter.
109+
110+
Borrowed from
111+
https://github.com/mjpost/mweralign/blob/d23a5479/mweralign/mweralign.py#L147
112+
"""
113+
if self.segmenter is not None:
114+
tokenized_text = []
115+
for i in range(len(text)):
116+
if " ### " in text[i]:
117+
pieces = text[i].strip().split(" ### ")
118+
encoded = [" ".join(self.segmenter.encode(p)) for p in pieces]
119+
tokenized_text.append(" ### ".join(encoded))
120+
elif "\t" in text[i]:
121+
pieces = text[i].strip().split("\t")
122+
# underlying C++ binary still uses ###
123+
encoded = [" ".join(self.segmenter.encode(p)) for p in pieces]
124+
tokenized_text.append(" ### ".join(encoded))
125+
else:
126+
tokenized_text.append(" ".join(self.segmenter.encode(text[i].strip())))
127+
return "\n".join(tokenized_text)
128+
else:
129+
return "\n".join(text)
130+
104131
def score(self, samples: List[LatencyScoringSample]) -> LatencyScores:
105132
resegmented_samples = []
106133
for sample in samples:
107134
assert sample.reference is not None, "Cannot realign hypothesis to missing reference"
108135

109-
resegmented_hypos = mweralign.align_texts(
110-
"\n".join([sentence_def.content for sentence_def in sample.reference]),
111-
sample.hypothesis.final_text).split("\n")
136+
hypo = self._tokenize([sample.hypothesis.final_text])
137+
refs = self._tokenize(
138+
[sentence_def.content for sentence_def in sample.reference])
139+
resegmented_hypos = mweralign.align_texts(refs, hypo).split("\n")
112140

113141
assert len(resegmented_hypos) == len(sample.reference), \
114142
f"Reference ({sample.audio_name}) has mismatched number of target " \
115143
f"({len(sample.reference)}) and resegmented lines ({len(resegmented_hypos)})"
116144

145+
if self.segmenter is not None:
146+
# segmenter.decode will strip() the spaces, but we need them to align with delays
147+
resegmented_hypos = [
148+
hypo.replace(" ", "").replace("_", " ") for hypo in resegmented_hypos]
149+
117150
ideal_delays_splits = self._split_delays_by_segmented_text(
118151
sample.hypothesis.ideal_delays,
119152
resegmented_hypos)

simulstream/metrics/scorers/quality/mwersegmenter.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import List, Optional
1818

1919
from mweralign import mweralign
20+
from mweralign.segmenter import CJSegmenter
2021

2122
from simulstream.metrics.scorers.quality import QualityScorer, QualityScoringSample
2223

@@ -56,6 +57,11 @@ class MWERSegmenterBasedQualityScorer(QualityScorer):
5657
... # Compute a custom quality score
5758
... return ...
5859
"""
60+
61+
def __init__(self, args):
    """
    Initialize the quality scorer from parsed CLI arguments.

    :param args: parsed command-line arguments; ``args.latency_unit == "char"``
        enables CJK character segmentation (CJSegmenter) before alignment.
    """
    super().__init__(args)
    # For CJK languages ("char" unit) the text must be character-tokenized
    # with CJSegmenter before being sent to mweralign.align_texts.
    self.segmenter = CJSegmenter() if args.latency_unit == "char" else None
64+
5965
def requires_reference(self) -> bool:
6066
return True
6167

@@ -75,15 +81,48 @@ def _do_score(self, samples: List[ResegmentedQualityScoringSample]) -> float:
7581
"""
7682
...
7783

84+
def _tokenize(self, text: List[str]) -> List[str]:
85+
"""
86+
Tokenize text using the segmenter.
87+
88+
Borrowed from
89+
https://github.com/mjpost/mweralign/blob/d23a5479/mweralign/mweralign.py#L147
90+
"""
91+
if self.segmenter is not None:
92+
tokenized_text = []
93+
for i in range(len(text)):
94+
if " ### " in text[i]:
95+
pieces = text[i].strip().split(" ### ")
96+
encoded = [" ".join(self.segmenter.encode(p)) for p in pieces]
97+
tokenized_text.append(" ### ".join(encoded))
98+
elif "\t" in text[i]:
99+
pieces = text[i].strip().split("\t")
100+
# underlying C++ binary still uses ###
101+
encoded = [" ".join(self.segmenter.encode(p)) for p in pieces]
102+
tokenized_text.append(" ### ".join(encoded))
103+
else:
104+
tokenized_text.append(" ".join(self.segmenter.encode(text[i].strip())))
105+
return "\n".join(tokenized_text)
106+
else:
107+
return "\n".join(text)
108+
78109
def score(self, samples: List[QualityScoringSample]) -> float:
79110
resegmented_samples = []
80111
for sample in samples:
81112
assert sample.reference is not None, "Cannot realign hypothesis to missing reference"
82-
resegmented_hypos = mweralign.align_texts(
83-
"\n".join(sample.reference), sample.hypothesis).split("\n")
113+
hypo = self._tokenize([sample.hypothesis])
114+
refs = self._tokenize(sample.reference)
115+
resegmented_hypos = mweralign.align_texts(refs, hypo).split("\n")
116+
84117
assert len(sample.reference) == len(resegmented_hypos), \
85118
f"Reference ({sample.audio_name}) has mismatched number of target " \
86119
f"({len(sample.reference)}) and resegmented lines ({len(resegmented_hypos)})"
120+
121+
if self.segmenter is not None:
122+
# undo the character tokenization: drop the inserted token-separator spaces
# and restore original spaces, which the segmenter encoded as "_"
123+
resegmented_hypos = [
124+
hypo.replace(" ", "").replace("_", " ") for hypo in resegmented_hypos]
125+
87126
resegmented_samples.append(ResegmentedQualityScoringSample(
88127
sample.audio_name,
89128
resegmented_hypos,

uts/metrics/test_stream_laal.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2026 FBK
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License
14+
15+
import unittest
16+
from argparse import Namespace
17+
18+
from simulstream.metrics.readers import OutputWithDelays, ReferenceSentenceDefinition
19+
from simulstream.metrics.scorers.latency import LatencyScoringSample
20+
from simulstream.metrics.scorers.latency.stream_laal import StreamLaal
21+
22+
23+
class StreamLaalTestCase(unittest.TestCase):
    """Tests for the StreamLaal latency scorer (word- and character-level)."""

    def test_basic(self):
        # Word-level (default) latency: an Italian hypothesis is realigned
        # against a two-sentence reference before latency is computed.
        reference = [
            ReferenceSentenceDefinition(
                "A New York, sono a capo di un'associazione no profit, chiamata Robin Hood.",
                12.61,
                4.07,
            ),
            ReferenceSentenceDefinition(
                "Quando non combatto la povertà, combatto gli incendi come assistente capitano di "
                "una brigata di pompieri volontari.",
                16.9,
                5.14,
            ),
        ]
        hypothesis = OutputWithDelays(
            "Tornando a New York, sono il capo dello sviluppo per un non-profit chiamato Robin "
            "Hood. Quando non sto combattendo la povertà, sto combattendo i fuochi.",
            # One ideal delay per hypothesis word.
            [14.0, 14.0, 14.0, 14.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 18.0,
             18.0, 18.0, 18.0, 18.0, 18.0, 18.0, 18.0, 20.0, 20.0, 20.0, 20.0],
            # One computation-aware delay per hypothesis word.
            [18.22, 18.22, 18.22, 18.22, 19.93, 19.93, 19.93, 19.93, 19.93, 19.93, 19.93, 19.93,
             19.93, 23.01, 23.01, 23.01, 23.01, 23.01, 23.01, 23.01, 23.01, 27.30, 27.30, 27.30,
             27.30],
        )
        scorer = StreamLaal(Namespace(latency_unit="word"))
        score = scorer.score([LatencyScoringSample("a", hypothesis, reference)])
        self.assertAlmostEqual(score.ideal_latency, 0.868587, 4)
        self.assertAlmostEqual(score.computational_aware_latency, 5.86, 4)

    def test_with_characters(self):
        # Character-level latency: CJK text is segmented into characters
        # (latency_unit="char"), mixed with a Latin-script segment ("Amy").
        reference = [
            ReferenceSentenceDefinition(
                "今天她看起很好,",
                12.61,
                3.07,
            ),
            ReferenceSentenceDefinition(
                "我们一起去公园散步吧。",
                16.9,
                3.14,
            ),
            ReferenceSentenceDefinition(
                "Amy",
                21.0,
                0.5,
            ),
            ReferenceSentenceDefinition(
                "今天心情很好",
                21.5,
                2.0,
            ),
        ]
        hypothesis = OutputWithDelays(
            "今天她很漂亮,我们一起去花园跑步吧。Amy 今天心情很好",
            [14.0, 14.0, 14.0, 15.0, 15.0, 16.0, 17.0,
             17.0, 17.0, 18.0, 18.0, 19.0, 19.0, 20.0, 20.0, 21.0, 21.0, 21.0,
             22.0, 22.0, 22.0, 22.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0],
            [14.5, 14.5, 14.5, 15.2, 15.2, 16.8, 17.5,
             18.0, 18.5, 18.5, 18.5, 20.1, 20.1, 21.3, 21.3, 22.0, 22.0, 22.0,
             23.0, 23.0, 23.0, 23.0, 25.0, 25.0, 25.0, 25.0, 25.0, 25.0],
        )
        scorer = StreamLaal(Namespace(latency_unit="char"))
        score = scorer.score([LatencyScoringSample("a", hypothesis, reference)])
        self.assertAlmostEqual(score.ideal_latency, 1.333312, 4)
        self.assertAlmostEqual(score.computational_aware_latency, 2.074095, 4)


if __name__ == '__main__':
    unittest.main()
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2026 FBK
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License
14+
15+
import copy
16+
import unittest
17+
from argparse import Namespace
18+
19+
from simulstream.metrics.scorers.quality.mwersegmenter import (
20+
MWERSegmenterBasedQualityScorer,
21+
)
22+
from simulstream.metrics.scorers.latency.mwersegmenter import (
23+
MWERSegmenterBasedLatencyScorer,
24+
)
25+
from simulstream.metrics.scorers.latency import LatencyScores
26+
27+
28+
class TokenizeNoInplaceModificationTestCase(unittest.TestCase):
    """
    Ensures that _tokenize does not alter the references.
    See https://github.com/hlt-mt/simulstream/pull/20#issuecomment-3960951980
    """

    def _make_quality_scorer(self, latency_unit="char"):
        """Create a concrete subclass of the abstract quality scorer."""
        class _Scorer(MWERSegmenterBasedQualityScorer):
            def _do_score(self, samples):
                return 0.0

            @classmethod
            def add_arguments(cls, parser):
                pass

            def requires_source(self):
                return False

        args = Namespace(latency_unit=latency_unit)
        return _Scorer(args)

    def _make_latency_scorer(self, latency_unit="char"):
        """Create a concrete subclass of the abstract latency scorer."""
        class _Scorer(MWERSegmenterBasedLatencyScorer):
            def _do_score(self, samples):
                return LatencyScores(0.0, [])

            @classmethod
            def add_arguments(cls, parser):
                pass

            def requires_source(self):
                return False

        args = Namespace(latency_unit=latency_unit)
        return _Scorer(args)

    def _assert_no_inplace_modification(self, scorer, text):
        """Run _tokenize and verify its input list is left byte-identical."""
        original = copy.deepcopy(text)
        scorer._tokenize(text)
        self.assertEqual(text, original)

    def test_quality_tokenize_does_not_modify_input(self):
        self._assert_no_inplace_modification(
            self._make_quality_scorer(latency_unit="char"), ["你好世界", "这是测试"])

    def test_latency_tokenize_does_not_modify_input(self):
        self._assert_no_inplace_modification(
            self._make_latency_scorer(latency_unit="char"), ["你好世界", "这是测试"])

    def test_quality_tokenize_no_modify_with_separator(self):
        self._assert_no_inplace_modification(
            self._make_quality_scorer(latency_unit="char"), ["你好 ### 世界"])

    def test_quality_tokenize_no_modify_with_tab(self):
        self._assert_no_inplace_modification(
            self._make_quality_scorer(latency_unit="char"), ["你好\t世界"])

    def test_quality_tokenize_does_not_modify_input_english(self):
        self._assert_no_inplace_modification(
            self._make_quality_scorer(latency_unit="word"), ["hello world", "this is a test"])

    def test_latency_tokenize_does_not_modify_input_english(self):
        self._assert_no_inplace_modification(
            self._make_latency_scorer(latency_unit="word"), ["hello world", "this is a test"])

    def test_quality_tokenize_no_modify_with_separator_english(self):
        self._assert_no_inplace_modification(
            self._make_quality_scorer(latency_unit="word"), ["hello ### world"])

    def test_quality_tokenize_no_modify_with_tab_english(self):
        self._assert_no_inplace_modification(
            self._make_quality_scorer(latency_unit="word"), ["hello\tworld"])


if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)