Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
- name: Install kaldialign
shell: bash
run: |
python3 -m pip install --verbose .
python3 -m pip install --verbose '.[test]'

- name: Test
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
project(kaldialign CXX)

# Please remember to also change line 3 of ./scripts/conda/kaldialign/meta.yaml
set(KALDIALIGN_VERSION "0.8.1")
set(KALDIALIGN_VERSION "0.9")

if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
Expand Down
86 changes: 83 additions & 3 deletions extensions/kaldi_align.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "kaldi_align.h"

#include <cmath>
#include <cstddef>
#include <random>
#include <stdexcept>

int LevenshteinEditDistance(const std::vector<int> &ref,
Expand All @@ -14,7 +15,7 @@ int LevenshteinEditDistance(const std::vector<int> &ref,
del_cost = DEL_COST;
sub_cost = SUB_COST;
}

// temp sequence to remember error type and stats.
std::vector<error_stats> e(ref.size()+1);
std::vector<error_stats> cur_e(ref.size()+1);
Expand Down Expand Up @@ -63,8 +64,15 @@ int LevenshteinEditDistance(const std::vector<int> &ref,
e = cur_e; // alternate for the next recursion.
}
size_t ref_index = e.size()-1;
*ins = e[ref_index].ins_num, *del =
e[ref_index].del_num, *sub = e[ref_index].sub_num;
if (ins != nullptr) {
*ins = e[ref_index].ins_num;
}
if (del != nullptr) {
*del = e[ref_index].del_num;
}
if (sub != nullptr) {
*sub = e[ref_index].sub_num;
}
return e[ref_index].total_num;
}

Expand Down Expand Up @@ -148,3 +156,75 @@ int LevenshteinAlignment(const std::vector<int> &a,
ReverseVector(output);
return e[M][N];
}

namespace internal {

/// Computes per-utterance edit statistics for paired reference/hypothesis
/// sequences.
///
/// @param refs Reference symbol sequences, one per utterance.
/// @param hyps Hypothesis symbol sequences; hyps[i] is scored against refs[i].
/// @return One (edit distance, reference length) pair per utterance, in input
///         order.
/// @throws std::invalid_argument if refs and hyps do not have the same number
///         of sequences (previously this read past the end of hyps).
std::vector<std::pair<int, int>> GetEdits(
    const std::vector<std::vector<int>> &refs,
    const std::vector<std::vector<int>> &hyps
) {
  if (refs.size() != hyps.size()) {
    throw std::invalid_argument(
        "GetEdits: refs and hyps must contain the same number of sequences");
  }

  std::vector<std::pair<int, int>> ans;
  ans.reserve(refs.size());
  for (std::size_t i = 0; i != refs.size(); ++i) {
    const auto &ref = refs[i];
    // Only the total distance is needed, so the insertion/deletion/substitution
    // out-pointers are left null.
    const int dist =
        LevenshteinEditDistance(ref, hyps[i], false, nullptr, nullptr, nullptr);
    ans.emplace_back(dist, static_cast<int>(ref.size()));
  }
  return ans;
}

/// Estimates the mean WER and the half-width of its 95% confidence interval by
/// bootstrap resampling.
///
/// Each replication draws edit_sym_per_hyp.size() entries uniformly with
/// replacement and computes (sum of errors) / (sum of reference symbols) over
/// the draw.
///
/// @param edit_sym_per_hyp (number of errors, number of reference symbols) per
///        utterance. Must not be empty.
/// @param replications Number of bootstrap replications. Must be > 0.
/// @param seed Seed for the std::mt19937 generator; the same seed yields the
///        same result.
/// @return (mean WER over replications, 1.96 * standard deviation of the
///         replicated WERs). The second value is 0 when the computed variance
///         is not positive.
/// @throws std::invalid_argument if edit_sym_per_hyp is empty or
///         replications <= 0.
///
/// NOTE: if a replication happens to draw only entries with zero reference
/// symbols, that replication's ratio is not finite and propagates into the
/// mean; callers should avoid corpora made only of empty references.
std::pair<double, double> GetBootstrapWerInterval(
    const std::vector<std::pair<int, int>> &edit_sym_per_hyp,
    const int replications,
    const unsigned int seed)
{
  // uniform_int_distribution requires a <= b, so an empty input is rejected
  // instead of constructing the invalid range [0, -1].
  if (edit_sym_per_hyp.empty()) {
    throw std::invalid_argument(
        "GetBootstrapWerInterval: edit_sym_per_hyp must not be empty");
  }
  if (replications <= 0) {
    throw std::invalid_argument(
        "GetBootstrapWerInterval: replications must be greater than 0");
  }

  const std::size_t num_hyps = edit_sym_per_hyp.size();
  std::mt19937 rng{seed};
  // Keep the default (int) result type so seeded output stays unchanged.
  std::uniform_int_distribution<> dist{0, static_cast<int>(num_hyps) - 1};

  double wer_accum = 0.0;
  double wer_sq_accum = 0.0;
  for (int i = 0; i < replications; ++i) {
    // 64-bit sums: int totals can overflow on large corpora.
    long long num_sym = 0;
    long long num_errs = 0;
    for (std::size_t j = 0; j != num_hyps; ++j) {
      const auto &nerr_nsym = edit_sym_per_hyp[dist(rng)];
      num_errs += nerr_nsym.first;
      num_sym += nerr_nsym.second;
    }
    const double wer_rep =
        static_cast<double>(num_errs) / static_cast<double>(num_sym);
    wer_accum += wer_rep;
    wer_sq_accum += wer_rep * wer_rep;
  }

  const double mean = wer_accum / replications;
  // E[w^2] - E[w]^2; rounding can push this slightly below zero.
  const double variance = wer_sq_accum / replications - mean * mean;
  const double interval = variance > 0 ? 1.96 * std::sqrt(variance) : 0.0;
  return std::make_pair(mean, interval);
}

/// Estimates, by bootstrap resampling, how often the second system makes fewer
/// errors than the first.
///
/// Each replication draws edit_sym_per_hyp.size() utterance indices uniformly
/// with replacement and sums (errors of system 1 - errors of system 2) over
/// the draw; the replication counts as an improvement when that sum is > 0.
///
/// @param edit_sym_per_hyp (errors, reference symbols) per utterance for
///        system 1. Must not be empty.
/// @param edit_sym_per_hyp2 Same statistics for system 2; must have the same
///        number of entries as edit_sym_per_hyp.
/// @param replications Number of bootstrap replications. Must be > 0.
/// @param seed Seed for the std::mt19937 generator.
/// @return Fraction of replications counted as an improvement, in [0, 1].
/// @throws std::invalid_argument if the inputs are empty, differ in length, or
///         replications <= 0.
double GetPImprov(
    const std::vector<std::pair<int, int>> &edit_sym_per_hyp,
    const std::vector<std::pair<int, int>> &edit_sym_per_hyp2,
    const int replications,
    const unsigned int seed
) {
  // uniform_int_distribution requires a <= b, so an empty input is rejected
  // instead of constructing the invalid range [0, -1].
  if (edit_sym_per_hyp.empty()) {
    throw std::invalid_argument(
        "GetPImprov: edit_sym_per_hyp must not be empty");
  }
  // Both vectors are indexed with the same sampled position.
  if (edit_sym_per_hyp.size() != edit_sym_per_hyp2.size()) {
    throw std::invalid_argument(
        "GetPImprov: edit_sym_per_hyp and edit_sym_per_hyp2 must have the same size");
  }
  if (replications <= 0) {
    throw std::invalid_argument(
        "GetPImprov: replications must be greater than 0");
  }

  const std::size_t num_hyps = edit_sym_per_hyp.size();
  std::mt19937 rng{seed};
  // Keep the default (int) result type so seeded output stays unchanged.
  std::uniform_int_distribution<> dist{0, static_cast<int>(num_hyps) - 1};

  int num_improved = 0;
  for (int i = 0; i < replications; ++i) {
    // 64-bit sum: the error difference is accumulated over the whole draw.
    long long err_diff = 0;
    for (std::size_t j = 0; j != num_hyps; ++j) {
      const auto selected = dist(rng);
      err_diff += edit_sym_per_hyp[selected].first - edit_sym_per_hyp2[selected].first;
    }
    if (err_diff > 0) {
      ++num_improved;
    }
  }

  return static_cast<double>(num_improved) / replications;
}

}
21 changes: 21 additions & 0 deletions extensions/kaldi_align.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,24 @@ int LevenshteinAlignment(const std::vector<int> &a,
int eps_symbol,
const bool sclite_mode,
std::vector<std::pair<int, int> > *output);


// Helpers behind the Python bootstrap-WER bindings.
namespace internal{
// Scores each refs[i] against hyps[i] with LevenshteinEditDistance and returns
// one (edit distance, reference length) pair per reference, in input order.
// hyps is indexed with refs' indices, so it must hold at least refs.size()
// sequences.
std::vector<std::pair<int, int>> GetEdits(
const std::vector<std::vector<int>> &refs,
const std::vector<std::vector<int>> &hyps
);

// Bootstrap estimate of the WER over (errors, reference symbols) pairs.
// Each of the `replications` rounds draws edit_sym_per_hyp.size() entries
// uniformly with replacement (std::mt19937 seeded with `seed`) and computes
// summed errors / summed symbols.
// Returns (mean of the per-round WERs, 1.96 * their standard deviation); the
// second value is 0 when the computed variance is not positive.
std::pair<double, double> GetBootstrapWerInterval(
const std::vector<std::pair<int, int>> &edit_sym_per_hyp,
const int replications,
const unsigned int seed
);

// Fraction of bootstrap rounds in which the summed error count of
// edit_sym_per_hyp exceeds that of edit_sym_per_hyp2 over the same resampled
// positions. Positions are drawn from [0, edit_sym_per_hyp.size() - 1] and used
// to index both vectors, so edit_sym_per_hyp2 must be at least as long.
double GetPImprov(
const std::vector<std::pair<int, int>> &edit_sym_per_hyp,
const std::vector<std::pair<int, int>> &edit_sym_per_hyp2,
const int replications,
const unsigned int seed
);
}
28 changes: 28 additions & 0 deletions extensions/kaldialign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,36 @@ Align(const std::vector<int> &a, const std::vector<int> &b, int eps_symbol, cons
return ans;
}

// Binding shim: hands both sequence lists to internal::GetEdits and returns
// its per-utterance (edit distance, reference length) pairs unchanged.
static std::vector<std::pair<int, int>> GetEdits(
    const std::vector<std::vector<int>> &refs,
    const std::vector<std::vector<int>> &hyps
) {
  std::vector<std::pair<int, int>> edits = internal::GetEdits(refs, hyps);
  return edits;
}

// Binding shim: runs internal::GetBootstrapWerInterval and repackages its
// (mean, interval) pair as a two-element Python tuple.
static py::tuple GetBootstrapWerInterval(
    const std::vector<std::pair<int, int>> &edit_sym_per_hyp,
    const int replications,
    const unsigned int seed
) {
  const std::pair<double, double> mean_and_interval =
      internal::GetBootstrapWerInterval(edit_sym_per_hyp, replications, seed);
  const double &mean = mean_and_interval.first;
  const double &interval = mean_and_interval.second;
  return py::make_tuple(mean, interval);
}

// Binding shim: forwards every argument to internal::GetPImprov and returns
// the resulting fraction as-is.
static double GetPImprov(
    const std::vector<std::pair<int, int>> &edit_sym_per_hyp,
    const std::vector<std::pair<int, int>> &edit_sym_per_hyp2,
    const int replications,
    const unsigned int seed
) {
  const double p_improv = internal::GetPImprov(
      edit_sym_per_hyp, edit_sym_per_hyp2, replications, seed);
  return p_improv;
}

// Registers the extension module `_kaldialign` and its Python-visible
// functions. Names with a leading underscore are the bootstrap helpers.
PYBIND11_MODULE(_kaldialign, m) {
  m.doc() = "Python wrapper for kaldialign";

  // Edit distance and alignment entry points.
  m.def("edit_distance", &EditDistance,
        py::arg("a"),
        py::arg("b"),
        py::arg("sclite_mode") = false);
  m.def("align", &Align,
        py::arg("a"),
        py::arg("b"),
        py::arg("eps_symbol"),
        py::arg("sclite_mode") = false);

  // Bootstrap WER helpers; replications and seed default to 10000 and 0.
  m.def("_get_edits", &GetEdits,
        py::arg("refs"),
        py::arg("hyps"));
  m.def("_get_boostrap_wer_interval", &GetBootstrapWerInterval,
        py::arg("edit_sym_per_hyp"),
        py::arg("replications") = 10000,
        py::arg("seed") = 0);
  m.def("_get_p_improv", &GetPImprov,
        py::arg("edit_sym_per_hyp"),
        py::arg("edit_sym_per_hyp2"),
        py::arg("replications") = 10000,
        py::arg("seed") = 0);
}
110 changes: 47 additions & 63 deletions kaldialign/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def align(


def bootstrap_wer_ci(
ref_seqs: Sequence[Sequence[Symbol]],
hyp_seqs: Sequence[Sequence[Symbol]],
hyp2_seqs: Optional[Sequence[Sequence[Symbol]]] = None,
refs: Sequence[Sequence[Symbol]],
hyps: Sequence[Sequence[Symbol]],
hyps2: Optional[Sequence[Sequence[Symbol]]] = None,
replications: int = 10000,
seed: int = 0,
) -> Dict:
Expand All @@ -96,9 +96,9 @@ def bootstrap_wer_ci(
The implementation is based on Kaldi's ``compute-wer-bootci`` script [2].

Args:
ref_seqs: A list of reference sequences (str, list[str], list[int])
hyp_seqs: A list of hypothesis sequences from system1 (str, list[str], list[int])
hyp2_seqs: A list of hypothesis sequences from system2 (str, list[str], list[int]).
refs: A list of reference sequences (str, list[str], list[list[int]])
hyps: A list of hypothesis sequences from system1 (str, list[str], list[list[int]])
hyps2: A list of hypothesis sequences from system2 (str, list[str], list[list[int]]).
When provided, we'll compute CI for both systems as well as the probability
of system2 improving over system1.
replications: The number of replications to use for bootstrapping.
Expand All @@ -119,21 +119,31 @@ def bootstrap_wer_ci(

[2] https://github.com/kaldi-asr/kaldi/blob/master/src/bin/compute-wer-bootci.cc
"""
assert len(hyp_seqs) == len(
ref_seqs
), f"Inconsistent number of reference ({len(ref_seqs)}) and hypothesis ({len(hyp_seqs)}) sequences."
edit_sym_per_hyp = _get_edits(ref_seqs, hyp_seqs)
from _kaldialign import _get_boostrap_wer_interval, _get_edits, _get_p_improv

assert len(hyps) == len(
refs
), f"Inconsistent number of reference ({len(refs)}) and hypothesis ({len(hyps)}) sequences."
assert replications > 0, "The number of replications must be greater than 0."
assert seed >= 0, "The seed must be 0 or greater."
assert not isinstance(refs, str) and not isinstance(
hyps, str
), "The input must be a list of strings or list of lists of ints."

refs, hyps, hyps2 = _convert_to_int(refs, hyps, hyps2)

edit_sym_per_hyp = _get_edits(refs, hyps)
mean, interval = _get_boostrap_wer_interval(
edit_sym_per_hyp, replications=replications, seed=seed
)
ans1 = _build_results(mean, interval)
if hyp2_seqs is None:
if hyps2 is None:
return ans1

assert len(hyp2_seqs) == len(
ref_seqs
), f"Inconsistent number of reference ({len(ref_seqs)}) and hypothesis ({len(hyp2_seqs)}) sequences for the second system (hyp2_seqs)."
edit_sym_per_hyp2 = _get_edits(ref_seqs, hyp2_seqs)
assert len(hyps2) == len(
refs
), f"Inconsistent number of reference ({len(refs)}) and hypothesis ({len(hyps2)}) sequences for the second system (hyp2_seqs)."
edit_sym_per_hyp2 = _get_edits(refs, hyps2)
mean2, interval2 = _get_boostrap_wer_interval(
edit_sym_per_hyp2, replications=replications, seed=seed
)
Expand All @@ -147,57 +157,31 @@ def bootstrap_wer_ci(
}


def _build_results(mean, interval):
def _build_results(mean: float, interval: float) -> Dict[str, float]:
return {
"wer": round(mean, ndigits=4),
"ci95": round(interval, ndigits=4),
"ci95min": round(mean - interval, ndigits=4),
"ci95max": round(mean + interval, ndigits=4),
"wer": mean,
"ci95": interval,
"ci95min": mean - interval,
"ci95max": mean + interval,
}


def _get_edits(ref_seqs, hyp_seqs):
edit_sym_per_hyp = []
for ref, hyp in zip(ref_seqs, hyp_seqs):
dist = edit_distance(ref, hyp)
edit_sym_per_hyp.append((dist["total"], len(ref)))
return edit_sym_per_hyp


def _get_boostrap_wer_interval(edit_sym_per_hyp, replications, seed):
rng = random.Random(seed)

wer_accum, wer_mult_accum = 0.0, 0.0
for i in range(replications):
num_sym, num_errs = 0, 0
for j in range(len(edit_sym_per_hyp)):
nerr, nsym = rng.choice(edit_sym_per_hyp)
num_sym += nsym
num_errs += nerr
wer_rep = num_errs / num_sym
wer_accum += wer_rep
wer_mult_accum += wer_rep**2
def _convert_to_int(
ref: Sequence[Sequence[Symbol]],
hyp: Sequence[Sequence[Symbol]],
hyp2: Sequence[Sequence[Symbol]] = None,
) -> Tuple[List[List[Symbol]], ...]:
sources = [ref, hyp]
if hyp2 is not None:
sources.append(hyp2)

mean = wer_accum / replications
_tmp = wer_mult_accum / replications - mean**2
if _tmp < 0:
interval = 0
else:
interval = 1.96 * math.sqrt(_tmp)

return mean, interval


def _get_p_improv(edit_sym_per_hyp, edit_sym_per_hyp2, replications, seed):
rng = random.Random(seed)

improv_accum = 0
for i in range(replications):
num_errs = 0
for j in range(len(edit_sym_per_hyp)):
pos = rng.randint(0, len(edit_sym_per_hyp) - 1)
num_errs += edit_sym_per_hyp[pos][0] - edit_sym_per_hyp2[pos][0]
if num_errs > 0:
improv_accum += 1
symbols = sorted(
set(symbol for source in sources for seq in source for symbol in seq)
)
int2sym = dict(enumerate(symbols))
sym2int = {v: k for k, v in int2sym.items()}

return improv_accum / replications
ints = [[[sym2int[item] for item in seq] for seq in source] for source in sources]
if hyp2 is None:
ints.append(None)
return tuple(ints)
2 changes: 1 addition & 1 deletion scripts/conda/kaldialign/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package:
name: kaldialign
version: "0.8.1"
version: "0.9"

source:
path: "{{ environ.get('KALDIALIGN_ROOT_DIR') }}"
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def get_package_version():
long_description_content_type="text/markdown",
ext_modules=[cmake_extension("_kaldialign")],
cmdclass={"build_ext": BuildExtension},
extras_require={"test": ["pytest"]},
keywords=[
"natural language processing",
"speech recognition",
Expand Down
Loading