Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions c/Makefile
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
CFLAGS=-I../src
LDFLAGS=-L../src -lvosk -ldl -lpthread -Wl,-rpath,../src

all: test_vosk test_vosk_speaker
all: test_vosk test_vosk_speaker test_phone_results

test_vosk: test_vosk.o
gcc $^ -o $@ $(LDFLAGS)

test_vosk_speaker: test_vosk_speaker.o
gcc $^ -o $@ $(LDFLAGS)

test_phone_results: test_phone_results.o
g++ $^ -o $@ $(LDFLAGS)
%.o: %.c
gcc $(CFLAGS) -c -o $@ $<

clean:
rm -f *.o *.a test_vosk test_vosk_speaker
rm -f *.o *.a test_vosk test_vosk_speaker test_phone_results
30 changes: 30 additions & 0 deletions c/test_phone_results.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include <vosk_api.h>
#include <stdio.h>

int main() {
FILE *wavin;
char buf[3200];
int nread, final;

VoskModel *model = vosk_model_new("model");
VoskRecognizer *recognizer = vosk_recognizer_new(model, 16000.0);
vosk_recognizer_set_result_options(recognizer, "phones");

wavin = fopen("test.wav", "rb");
fseek(wavin, 44, SEEK_SET);
while (!feof(wavin)) {
nread = fread(buf, 1, sizeof(buf), wavin);
final = vosk_recognizer_accept_waveform(recognizer, buf, nread);
if (final) {
printf("%s\n", vosk_recognizer_result(recognizer));
} else {
printf("%s\n", vosk_recognizer_partial_result(recognizer));
}
}
printf("%s\n", vosk_recognizer_final_result(recognizer));

vosk_recognizer_free(recognizer);
vosk_model_free(model);
fclose(wavin);
return 0;
}
3 changes: 3 additions & 0 deletions python/vosk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ def SetMaxAlternatives(self, max_alternatives):
def SetWords(self, enable_words):
_c.vosk_recognizer_set_words(self._handle, 1 if enable_words else 0)

def SetResultOptions(self, options):
_c.vosk_recognizer_set_result_options(self._handle, options.encode("utf-8"))

def SetPartialWords(self, enable_partial_words):
_c.vosk_recognizer_set_partial_words(self._handle, 1 if enable_partial_words else 0)

Expand Down
12 changes: 12 additions & 0 deletions src/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ void Model::ConfigureV1()
rnnlm_feat_embedding_rxfilename_ = model_path_str_ + "/rnnlm/feat_embedding.final.mat";
rnnlm_config_rxfilename_ = model_path_str_ + "/rnnlm/special_symbol_opts.conf";
rnnlm_lm_rxfilename_ = model_path_str_ + "/rnnlm/final.raw";
phone_syms_rxfilename_ = model_path_str_ + "/graph/phones.txt";
}

void Model::ConfigureV2()
Expand Down Expand Up @@ -204,6 +205,7 @@ void Model::ConfigureV2()
rnnlm_feat_embedding_rxfilename_ = model_path_str_ + "/rnnlm/feat_embedding.final.mat";
rnnlm_config_rxfilename_ = model_path_str_ + "/rnnlm/special_symbol_opts.conf";
rnnlm_lm_rxfilename_ = model_path_str_ + "/rnnlm/final.raw";
phone_syms_rxfilename_ = model_path_str_ + "/graph/phones.txt";
}

void Model::ReadDataFiles()
Expand Down Expand Up @@ -305,6 +307,16 @@ void Model::ReadDataFiles()
winfo_ = new kaldi::WordBoundaryInfo(opts, winfo_rxfilename_);
}

phone_symbol_table_ = NULL;
phone_syms_loaded_ = false;
//Providing phones.txt symbol table is optional and currently not required by Vosk
//If you provide it by default the phone information will be computed
if (stat(phone_syms_rxfilename_.c_str(), &buffer) == 0) {
KALDI_LOG << "Loading phonemes from " << phone_syms_rxfilename_;
phone_symbol_table_ = fst::SymbolTable::ReadText(phone_syms_rxfilename_);
phone_syms_loaded_ = true;
}

if (stat(carpa_rxfilename_.c_str(), &buffer) == 0) {

KALDI_LOG << "Loading subtract G.fst model from " << std_fst_rxfilename_;
Expand Down
3 changes: 3 additions & 0 deletions src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Model {
string fbank_conf_rxfilename_;
string global_cmvn_stats_rxfilename_;
string pitch_conf_rxfilename_;
string phone_syms_rxfilename_;

string rnnlm_word_feats_rxfilename_;
string rnnlm_feat_embedding_rxfilename_;
Expand All @@ -87,6 +88,8 @@ class Model {
bool word_syms_loaded_ = false;
kaldi::WordBoundaryInfo *winfo_ = nullptr;
vector<int32> disambig_;
const fst::SymbolTable *phone_symbol_table_;
bool phone_syms_loaded_;

fst::Fst<fst::StdArc> *hclg_fst_ = nullptr;
fst::Fst<fst::StdArc> *hcl_fst_ = nullptr;
Expand Down
121 changes: 112 additions & 9 deletions src/recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "lat/sausages.h"
#include "language_model.h"


using namespace fst;
using namespace kaldi::nnet3;

Expand Down Expand Up @@ -241,6 +242,11 @@ void Recognizer::SetMaxAlternatives(int max_alternatives)
max_alternatives_ = max_alternatives;
}

void Recognizer::SetResultOptions(const char *result_opts)
{
result_opts_ = result_opts;
}

void Recognizer::SetWords(bool words)
{
words_ = words;
Expand Down Expand Up @@ -427,6 +433,35 @@ static void CopyLatticeForMbr(CompactLattice &lat, CompactLattice *lat_out)
TopSortCompactLatticeIfNeeded(lat_out);
}

void ComputePhoneInfo(const TransitionModel &tmodel, const CompactLattice &clat, const fst::SymbolTable &word_syms_, const fst::SymbolTable &phone_symbol_table_, std::vector<std::vector<std::string> > *phoneme_labels, std::vector<std::vector<int32> > *phone_lengths)
{
//This function computes the phone information i.e. phone labels and lengths
vector<int32> words_ph_ids, times_lat, lengths;
vector<vector<int32> > prons;

kaldi::CompactLattice best_path;
kaldi::CompactLatticeShortestPath(clat, &best_path);

kaldi::CompactLatticeToWordProns(tmodel, best_path, &words_ph_ids, &times_lat, &lengths,
&prons, phone_lengths);


for (size_t z = 0; z < words_ph_ids.size(); z++) {

//auto word_str = word_syms_.Find(words_ph_ids[z]);
string phone_label_str;
std::vector<std::string> word2phn;

for (size_t j = 0; j < prons[z].size(); j++) {

auto phone_str = phone_symbol_table_.Find(prons[z][j]);
word2phn.push_back(phone_str);
}
phoneme_labels->push_back(word2phn);
}

}

const char *Recognizer::MbrResult(CompactLattice &rlat)
{

Expand All @@ -437,33 +472,101 @@ const char *Recognizer::MbrResult(CompactLattice &rlat)
CopyLatticeForMbr(rlat, &aligned_lat);
}

MinimumBayesRisk mbr(aligned_lat);
kaldi::MinimumBayesRiskOptions mbr_options;

std::vector<std::vector<std::string> > phoneme_labels;
std::vector<std::vector<int32> > phone_lengths;
int phon_vec_size = 1;

if (model_->phone_syms_loaded_ && (result_opts_ == "phones")){
//Compute phone info if phone symbol table is provided

ComputePhoneInfo(*model_->trans_model_, aligned_lat, *model_->word_syms_, *model_->phone_symbol_table_, &phoneme_labels, &phone_lengths);
phon_vec_size = phoneme_labels.size();
mbr_options.print_silence = true; //Print silences in the word-level outputs only if you need phone outputs
mbr_options.decode_mbr = false; // Turn off MBR decoding if you want to print out phone information
}

MinimumBayesRisk mbr(aligned_lat, mbr_options);
const vector<BaseFloat> &conf = mbr.GetOneBestConfidences();
const vector<int32> &words = mbr.GetOneBest();
const vector<int32> &word_ids = mbr.GetOneBest();
const vector<pair<BaseFloat, BaseFloat> > &times =
mbr.GetOneBestTimes();

int size = words.size();
int size = word_ids.size();

json::JSON obj;
stringstream text;

int phone_ptr=0;
// Create JSON object
for (int i = 0; i < size; i++) {
json::JSON word;

if (words_) {
word["word"] = model_->word_syms_->Find(words[i]);
word["word"] = model_->word_syms_->Find(word_ids[i]);
word["start"] = samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].first) * 0.03;
word["end"] = samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].second) * 0.03;
word["conf"] = conf[i];
obj["result"].append(word);
}


//When printing silences some extra silence words that we call "gaps" that have length of 0 seconds
//get printed out so we filter them out
//It is possible to have trailing silences in the word output that are not present in the phone output so we
//filter them to generate consistent outputs
if ((samples_round_start_ / sample_frequency_ + (frame_offset_ + (times[i].second-times[i].first)) * 0.03) > 0.0 && phone_ptr < phon_vec_size) {

word["word"] = model_->word_syms_->Find(word_ids[i]);
word["start"] = samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].first) * 0.03;
word["end"] = samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].second) * 0.03;
word["conf"] = conf[i];

if (i) {
text << " ";
}
text << model_->word_syms_->Find(words[i]);
//Add phone info to json if phone symbol table is provided
if (model_->phone_syms_loaded_ && (result_opts_ == "phones")){
kaldi::BaseFloat phone_start_time = 0.0;
kaldi::BaseFloat phone_end_time = 0.0;

//If there are silences without phone output (since they are coming from different places) then set the label and timestamps
if (word_ids[i] == 0 && phoneme_labels[phone_ptr][0] != "SIL"){
word["phone_label"].append( "SIL" );
phone_start_time=samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].first) * 0.03;
phone_end_time = samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].second) * 0.03;
word["phone_start"].append( phone_start_time );
word["phone_end"].append( phone_end_time );
}

//Else add the information generated from ComputePhoneInfo to results
else {
for ( auto phone: phoneme_labels[phone_ptr]){

word["phone_label"].append( phone );
}
for (int x=0; x < phone_lengths[phone_ptr].size(); x++){
if (x==0){
phone_start_time=samples_round_start_ / sample_frequency_ + (frame_offset_ + times[i].first) * 0.03;
phone_end_time = phone_start_time + (phone_lengths[phone_ptr][x]) * 0.03;
}
else{
phone_start_time = phone_end_time;
phone_end_time = phone_start_time + (phone_lengths[phone_ptr][x]) * 0.03;
}
word["phone_start"].append( phone_start_time );
word["phone_end"].append( phone_end_time );
}
phone_ptr += 1;
}
}

obj["result"].append(word);

if (word_ids[i] != 0){ // Don't print silence symbols
if (i) {
text << " ";
}
text << model_->word_syms_->Find(word_ids[i]);
}
}
}
obj["text"] = text.str();

Expand Down
7 changes: 6 additions & 1 deletion src/recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@
#include "nnet3/am-nnet-simple.h"
#include "nnet3/nnet-am-decodable-simple.h"
#include "nnet3/nnet-utils.h"
#include "lat/lattice-functions-transition-model.h"

#include "model.h"
#include "spk_model.h"

#include <string>

using namespace kaldi;

enum RecognizerState {
Expand All @@ -47,6 +50,7 @@ class Recognizer {
Recognizer(Model *model, float sample_frequency, char const *grammar);
~Recognizer();
void SetMaxAlternatives(int max_alternatives);
void SetResultOptions(const char *result_opts);
void SetSpkModel(SpkModel *spk_model);
void SetWords(bool words);
void SetPartialWords(bool partial_words);
Expand All @@ -58,7 +62,7 @@ class Recognizer {
const char* FinalResult();
const char* PartialResult();
void Reset();

private:
void InitState();
void InitRescoring();
Expand Down Expand Up @@ -96,6 +100,7 @@ class Recognizer {

// Other
int max_alternatives_ = 0; // Disable alternatives by default
std::string result_opts_ = "words"; // By default enable only word-level results
bool words_ = false;
bool partial_words_ = false;
bool nlsml_ = false;
Expand Down
5 changes: 5 additions & 0 deletions src/vosk_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ void vosk_recognizer_set_max_alternatives(VoskRecognizer *recognizer, int max_al
((Recognizer *)recognizer)->SetMaxAlternatives(max_alternatives);
}

void vosk_recognizer_set_result_options(VoskRecognizer *recognizer, const char *result_opts)
{
((Recognizer *)recognizer)->SetResultOptions(result_opts);
}

void vosk_recognizer_set_words(VoskRecognizer *recognizer, int words)
{
((Recognizer *)recognizer)->SetWords((bool)words);
Expand Down
Loading