diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 84ab245..421bcf8 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -38,7 +38,7 @@ jobs: - name: Install kaldialign shell: bash run: | - python3 -m pip install --verbose . + python3 -m pip install --verbose '.[test]' - name: Test shell: bash diff --git a/CMakeLists.txt b/CMakeLists.txt index d0417d3..cc59d5d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,7 +14,7 @@ cmake_minimum_required(VERSION 3.8 FATAL_ERROR) project(kaldialign CXX) # Please remember to also change line 3 of ./scripts/conda/kaldialign/meta.yaml -set(KALDIALIGN_VERSION "0.8.1") +set(KALDIALIGN_VERSION "0.9") if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) diff --git a/extensions/kaldi_align.cpp b/extensions/kaldi_align.cpp index 2007f60..7336b7f 100644 --- a/extensions/kaldi_align.cpp +++ b/extensions/kaldi_align.cpp @@ -1,3 +1,4 @@ +#include #include "kaldi_align.h" int LevenshteinEditDistance(const std::vector &ref, @@ -14,7 +15,7 @@ int LevenshteinEditDistance(const std::vector &ref, del_cost = DEL_COST; sub_cost = SUB_COST; } - + // temp sequence to remember error type and stats. std::vector e(ref.size()+1); std::vector cur_e(ref.size()+1); @@ -63,8 +64,15 @@ int LevenshteinEditDistance(const std::vector &ref, e = cur_e; // alternate for the next recursion. } size_t ref_index = e.size()-1; - *ins = e[ref_index].ins_num, *del = - e[ref_index].del_num, *sub = e[ref_index].sub_num; + if (ins != nullptr) { + *ins = e[ref_index].ins_num; + } + if (del != nullptr) { + *del = e[ref_index].del_num; + } + if (sub != nullptr) { + *sub = e[ref_index].sub_num; + } return e[ref_index].total_num; } @@ -148,3 +156,75 @@ int LevenshteinAlignment(const std::vector &a, ReverseVector(output); return e[M][N]; } + +namespace internal { + + std::vector> GetEdits( + const std::vector> &refs, + const std::vector> &hyps + ) { + std::vector> ans; + for (int i = 0; i != refs.size(); ++i) { + const auto &ref = refs[i]; + const auto dist = LevenshteinEditDistance(ref, hyps[i], false, nullptr, nullptr, nullptr); + ans.emplace_back(dist, ref.size()); + } + return ans; + } + + std::pair GetBootstrapWerInterval( + const std::vector> &edit_sym_per_hyp, + const int replications, + const unsigned int seed) + { + std::mt19937 rng{seed}; + std::uniform_int_distribution<> dist{0, static_cast(edit_sym_per_hyp.size()) - 1}; + + double wer_accum = 0.0, wer_mult_accum = 0.0; + for (int i = 0; i != replications; ++i) { + int num_sym = 0, num_errs = 0; + for (int j = 0; j != edit_sym_per_hyp.size(); ++j) { + const auto selected = dist(rng); + const auto &nerr_nsym = edit_sym_per_hyp[selected]; + num_errs += nerr_nsym.first; + num_sym += nerr_nsym.second; + } + const double wer_rep = static_cast(num_errs) / num_sym; + wer_accum += wer_rep; + wer_mult_accum += std::pow(wer_rep, 2); + } + + const double mean = wer_accum / replications; + const double _tmp = wer_mult_accum / replications - std::pow(mean, 2); + double interval = 0.0; + if (_tmp > 0) { + interval = 1.96 * std::sqrt(_tmp); + } + return std::make_pair(mean, interval); + } + + double GetPImprov( + const std::vector> &edit_sym_per_hyp, + const std::vector> &edit_sym_per_hyp2, + const int replications, + const unsigned int seed + ) { + std::mt19937 rng{seed}; + std::uniform_int_distribution<> dist{0, static_cast(edit_sym_per_hyp.size()) - 1}; + + double improv_accum = 0.0; + for (int i = 0; i != replications; ++i) { + int num_errs = 0; + for (int j = 0; j != edit_sym_per_hyp.size(); ++j) { + const auto selected = dist(rng); + num_errs += edit_sym_per_hyp[selected].first - edit_sym_per_hyp2[selected].first; + } + if (num_errs > 0) { + improv_accum += 1; + } + } + + return improv_accum / replications; + } + +} diff --git a/extensions/kaldi_align.h b/extensions/kaldi_align.h index 9fb1b43..32239cb 100644 --- a/extensions/kaldi_align.h +++ b/extensions/kaldi_align.h @@ -42,3 +42,24 @@ int LevenshteinAlignment(const std::vector &a, int eps_symbol, const bool sclite_mode, std::vector > *output); + + +namespace internal{ + std::vector> GetEdits( + const std::vector> &refs, + const std::vector> &hyps + ); + + std::pair GetBootstrapWerInterval( + const std::vector> &edit_sym_per_hyp, + const int replications, + const unsigned int seed + ); + + double GetPImprov( + const std::vector> &edit_sym_per_hyp, + const std::vector> &edit_sym_per_hyp2, + const int replications, + const unsigned int seed + ); +} diff --git a/extensions/kaldialign.cpp b/extensions/kaldialign.cpp index 9a018de..156ea37 100644 --- a/extensions/kaldialign.cpp +++ b/extensions/kaldialign.cpp @@ -26,8 +26,36 @@ Align(const std::vector &a, const std::vector &b, int eps_symbol, cons return ans; } +static std::vector> GetEdits( + const std::vector> &refs, + const std::vector> &hyps +) { + return internal::GetEdits(refs, hyps); +} + +static py::tuple GetBootstrapWerInterval( + const std::vector> &edit_sym_per_hyp, + const int replications, + const unsigned int seed +) { + const auto ans = internal::GetBootstrapWerInterval(edit_sym_per_hyp, replications, seed); + return py::make_tuple(ans.first, ans.second); +} + +static double GetPImprov( + const std::vector> &edit_sym_per_hyp, + const std::vector> &edit_sym_per_hyp2, + const int replications, + const unsigned int seed +) { + return internal::GetPImprov(edit_sym_per_hyp, edit_sym_per_hyp2, replications, seed); +} + PYBIND11_MODULE(_kaldialign, m) { m.doc() = "Python wrapper for kaldialign"; m.def("edit_distance", &EditDistance, py::arg("a"), py::arg("b"), py::arg("sclite_mode") = false); m.def("align", &Align, py::arg("a"), py::arg("b"), py::arg("eps_symbol"), py::arg("sclite_mode") = false); + m.def("_get_edits", &GetEdits, py::arg("refs"), py::arg("hyps")); + m.def("_get_boostrap_wer_interval", &GetBootstrapWerInterval, py::arg("edit_sym_per_hyp"), py::arg("replications") = 10000, py::arg("seed") = 0); + m.def("_get_p_improv", &GetPImprov, py::arg("edit_sym_per_hyp"), py::arg("edit_sym_per_hyp2"), py::arg("replications") = 10000, py::arg("seed") = 0); } diff --git a/kaldialign/__init__.py b/kaldialign/__init__.py index fb83d5c..3d725cb 100644 --- a/kaldialign/__init__.py +++ b/kaldialign/__init__.py @@ -84,9 +84,9 @@ def align( def bootstrap_wer_ci( - ref_seqs: Sequence[Sequence[Symbol]], - hyp_seqs: Sequence[Sequence[Symbol]], - hyp2_seqs: Optional[Sequence[Sequence[Symbol]]] = None, + refs: Sequence[Sequence[Symbol]], + hyps: Sequence[Sequence[Symbol]], + hyps2: Optional[Sequence[Sequence[Symbol]]] = None, replications: int = 10000, seed: int = 0, ) -> Dict: @@ -96,9 +96,9 @@ def bootstrap_wer_ci( The implementation is based on Kaldi's ``compute-wer-bootci`` script [2]. Args: - ref_seqs: A list of reference sequences (str, list[str], list[int]) - hyp_seqs: A list of hypothesis sequences from system1 (str, list[str], list[int]) - hyp2_seqs: A list of hypothesis sequences from system2 (str, list[str], list[int]). + refs: A list of reference sequences (str, list[str], list[list[[int]]) + hyps: A list of hypothesis sequences from system1 (str, list[str], list[list[int]]) + hyps2: A list of hypothesis sequences from system2 (str, list[str], list[list[int]]). When provided, we'll compute CI for both systems as well as the probability of system2 improving over system1. replications: The number of replications to use for bootstrapping. @@ -119,21 +119,31 @@ def bootstrap_wer_ci( [2] https://github.com/kaldi-asr/kaldi/blob/master/src/bin/compute-wer-bootci.cc """ - assert len(hyp_seqs) == len( - ref_seqs - ), f"Inconsistent number of reference ({len(ref_seqs)}) and hypothesis ({len(hyp_seqs)}) sequences." - edit_sym_per_hyp = _get_edits(ref_seqs, hyp_seqs) + from _kaldialign import _get_boostrap_wer_interval, _get_edits, _get_p_improv + + assert len(hyps) == len( + refs + ), f"Inconsistent number of reference ({len(refs)}) and hypothesis ({len(hyps)}) sequences." + assert replications > 0, "The number of replications must be greater than 0." + assert seed >= 0, "The seed must be 0 or greater." + assert not isinstance(refs, str) and not isinstance( + hyps, str + ), "The input must be a list of strings or list of lists of ints." + + refs, hyps, hyps2 = _convert_to_int(refs, hyps, hyps2) + + edit_sym_per_hyp = _get_edits(refs, hyps) mean, interval = _get_boostrap_wer_interval( edit_sym_per_hyp, replications=replications, seed=seed ) ans1 = _build_results(mean, interval) - if hyp2_seqs is None: + if hyps2 is None: return ans1 - assert len(hyp2_seqs) == len( - ref_seqs - ), f"Inconsistent number of reference ({len(ref_seqs)}) and hypothesis ({len(hyp2_seqs)}) sequences for the second system (hyp2_seqs)." - edit_sym_per_hyp2 = _get_edits(ref_seqs, hyp2_seqs) + assert len(hyps2) == len( + refs + ), f"Inconsistent number of reference ({len(refs)}) and hypothesis ({len(hyps2)}) sequences for the second system (hyp2_seqs)." + edit_sym_per_hyp2 = _get_edits(refs, hyps2) mean2, interval2 = _get_boostrap_wer_interval( edit_sym_per_hyp2, replications=replications, seed=seed ) @@ -147,57 +157,31 @@ def bootstrap_wer_ci( } -def _build_results(mean, interval): +def _build_results(mean: float, interval: float) -> Dict[str, float]: return { - "wer": round(mean, ndigits=4), - "ci95": round(interval, ndigits=4), - "ci95min": round(mean - interval, ndigits=4), - "ci95max": round(mean + interval, ndigits=4), + "wer": mean, + "ci95": interval, + "ci95min": mean - interval, + "ci95max": mean + interval, } -def _get_edits(ref_seqs, hyp_seqs): - edit_sym_per_hyp = [] - for ref, hyp in zip(ref_seqs, hyp_seqs): - dist = edit_distance(ref, hyp) - edit_sym_per_hyp.append((dist["total"], len(ref))) - return edit_sym_per_hyp - - -def _get_boostrap_wer_interval(edit_sym_per_hyp, replications, seed): - rng = random.Random(seed) - - wer_accum, wer_mult_accum = 0.0, 0.0 - for i in range(replications): - num_sym, num_errs = 0, 0 - for j in range(len(edit_sym_per_hyp)): - nerr, nsym = rng.choice(edit_sym_per_hyp) - num_sym += nsym - num_errs += nerr - wer_rep = num_errs / num_sym - wer_accum += wer_rep - wer_mult_accum += wer_rep**2 +def _convert_to_int( + ref: Sequence[Sequence[Symbol]], + hyp: Sequence[Sequence[Symbol]], + hyp2: Sequence[Sequence[Symbol]] = None, +) -> Tuple[List[List[Symbol]], ...]: + sources = [ref, hyp] + if hyp2 is not None: + sources.append(hyp2) - mean = wer_accum / replications - _tmp = wer_mult_accum / replications - mean**2 - if _tmp < 0: - interval = 0 - else: - interval = 1.96 * math.sqrt(_tmp) - - return mean, interval - - -def _get_p_improv(edit_sym_per_hyp, edit_sym_per_hyp2, replications, seed): - rng = random.Random(seed) - - improv_accum = 0 - for i in range(replications): - num_errs = 0 - for j in range(len(edit_sym_per_hyp)): - pos = rng.randint(0, len(edit_sym_per_hyp) - 1) - num_errs += edit_sym_per_hyp[pos][0] - edit_sym_per_hyp2[pos][0] - if num_errs > 0: - improv_accum += 1 + symbols = sorted( + set(symbol for source in sources for seq in source for symbol in seq) + ) + int2sym = dict(enumerate(symbols)) + sym2int = {v: k for k, v in int2sym.items()} - return improv_accum / replications + ints = [[[sym2int[item] for item in seq] for seq in source] for source in sources] + if hyp2 is None: + ints.append(None) + return tuple(ints) diff --git a/scripts/conda/kaldialign/meta.yaml b/scripts/conda/kaldialign/meta.yaml index 8adfe21..33b04e0 100644 --- a/scripts/conda/kaldialign/meta.yaml +++ b/scripts/conda/kaldialign/meta.yaml @@ -1,6 +1,6 @@ package: name: kaldialign - version: "0.8.1" + version: "0.9" source: path: "{{ environ.get('KALDIALIGN_ROOT_DIR') }}" diff --git a/setup.py b/setup.py index c687fd4..85533e2 100644 --- a/setup.py +++ b/setup.py @@ -126,6 +126,7 @@ def get_package_version(): long_description_content_type="text/markdown", ext_modules=[cmake_extension("_kaldialign")], cmdclass={"build_ext": BuildExtension}, + extras_require={"test": ["pytest"]}, keywords=[ "natural language processing", "speech recognition", diff --git a/tests/test_align.py b/tests/test_align.py index b2548c4..1832874 100644 --- a/tests/test_align.py +++ b/tests/test_align.py @@ -1,3 +1,7 @@ +from functools import partial + +import pytest + from kaldialign import align, bootstrap_wer_ci, edit_distance EPS = "*" @@ -96,6 +100,9 @@ def test_edit_distance_sclite(): } +approx = partial(pytest.approx, abs=3e-3) + + def test_bootstrap_wer_ci_1system(): ref = [ ("a", "b", "c"), @@ -108,11 +115,12 @@ def test_bootstrap_wer_ci_1system(): ] ans = bootstrap_wer_ci(ref, hyp) + print(ans) - assert ans["wer"] == 0.4989 - assert ans["ci95"] == 0.2312 - assert ans["ci95min"] == 0.2678 - assert ans["ci95max"] == 0.7301 + assert ans["wer"] == approx(0.50) + assert ans["ci95"] == approx(0.23) + assert ans["ci95min"] == approx(0.269) + assert ans["ci95max"] == approx(0.731) def test_bootstrap_wer_ci_2system(): @@ -132,18 +140,19 @@ def test_bootstrap_wer_ci_2system(): ] ans = bootstrap_wer_ci(ref, hyp, hyp2) + print(ans) s = ans["system1"] - assert s["wer"] == 0.4989 - assert s["ci95"] == 0.2312 - assert s["ci95min"] == 0.2678 - assert s["ci95max"] == 0.7301 + assert s["wer"] == approx(0.50) + assert s["ci95"] == approx(0.23) + assert s["ci95min"] == approx(0.269) + assert s["ci95max"] == approx(0.731) s = ans["system2"] - assert s["wer"] == 0.1656 - assert s["ci95"] == 0.2312 - assert s["ci95min"] == -0.0656 - assert s["ci95max"] == 0.3968 + assert s["wer"] == approx(0.166) + assert s["ci95"] == approx(0.231) + assert s["ci95min"] == approx(-0.064) + assert s["ci95max"] == approx(0.397) assert ans["p_s2_improv_over_s1"] == 1.0