diff --git a/runtime/android/.gitignore b/runtime/android/.gitignore new file mode 100644 index 00000000..ce9a2ea2 --- /dev/null +++ b/runtime/android/.gitignore @@ -0,0 +1,8 @@ +.gradle/ +build/ +local.properties +*.iml +.idea/ +.cxx/ +.externalNativeBuild/ +captures/ diff --git a/runtime/android/README.md b/runtime/android/README.md new file mode 100644 index 00000000..f503e467 --- /dev/null +++ b/runtime/android/README.md @@ -0,0 +1,41 @@ +# WeSpeaker Android Speaker Verification Demo + +This app extracts speaker embeddings from two **16 kHz PCM WAV** clips on device, computes cosine similarity (same mapping to 0–1 as desktop `runtime/onnxruntime` `asv_main`), and compares against a threshold to decide same vs different speaker. + +## ONNX model + +Export ONNX on a PC following the repo docs: + +```bash +python wespeaker/bin/export_onnx.py \ + --config $exp/config.yaml \ + --checkpoint $exp/avg_model.pt \ + --output_model final.onnx +``` + +Copy `final.onnx` to `app/src/main/assets/final.onnx` and build (filename must match). + +## Build + +**JDK 17 or newer** is required (Android Gradle Plugin 8.x). If only Java 8 is installed, install JDK 17 or pick the bundled JDK under Android Studio *Settings → Build → Gradle → Gradle JDK*. + +Open `runtime/android` in Android Studio, or use the Gradle wrapper: + +```bash +cd runtime/android +./gradlew :app:assembleDebug +``` + +The app depends on [ONNX Runtime Android](https://github.com/microsoft/onnxruntime) (`onnxruntime-android` AAR). Native integration matches [wekws/runtime/android](https://github.com/wenet-e2e/wekws/tree/main/runtime/android): a resolvable `extractForNativeBuild` configuration unpacks `headers/` and `jni/` from the AAR; CMake uses `include_directories` and links `libonnxruntime.so` (no Prefab / `find_package`). App logic reuses this repo’s `runtime/core` Fbank, `SpeakerEngine`, and ONNX backend. + +## Usage + +1. 
After installing the APK, pick enroll and test WAV files (16 kHz recommended; other sample rates are linearly resampled to 16 kHz in-app, which may slightly degrade quality).
+2. Tune threshold, embedding dim, and chunk samples to match training/export settings.
+3. Tap **Compare** to see similarity score and same/different verdict.
+
+## Notes
+
+- If CMake reports missing `onnxruntime*.aar` extract dir, **Sync** and run a full **Build** so `extractAARForNativeBuild` runs before `configureCMake` (same idea as wekws).
+- First inference copies the model from assets to app-private storage; ensure `assets/final.onnx` exists and is non-empty.
+- Default threshold is 0.5; tune on your validation set.
diff --git a/runtime/android/app/proguard-rules.pro b/runtime/android/app/proguard-rules.pro
new file mode 100644
index 00000000..834134db
--- /dev/null
+++ b/runtime/android/app/proguard-rules.pro
@@ -0,0 +1,4 @@
+# WeSpeaker JNI
+-keepclasseswithmembernames class com.wespeaker.app.WespeakerNative {
+    native <methods>;
+}
diff --git a/runtime/android/app/src/main/AndroidManifest.xml b/runtime/android/app/src/main/AndroidManifest.xml
new file mode 100644
index 00000000..2117714d
--- /dev/null
+++ b/runtime/android/app/src/main/AndroidManifest.xml
@@ -0,0 +1,23 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/runtime/android/app/src/main/assets/README.txt b/runtime/android/app/src/main/assets/README.txt
new file mode 100644
index 00000000..88b90d44
--- /dev/null
+++ b/runtime/android/app/src/main/assets/README.txt
@@ -0,0 +1,2 @@
+Copy the final.onnx produced by export_onnx.py into this directory (filename must be final.onnx).
+After rebuild and install, the app copies it from assets to internal storage for native inference.
diff --git a/runtime/android/app/src/main/cpp/CMakeLists.txt b/runtime/android/app/src/main/cpp/CMakeLists.txt new file mode 100644 index 00000000..5e07b121 --- /dev/null +++ b/runtime/android/app/src/main/cpp/CMakeLists.txt @@ -0,0 +1,55 @@ +cmake_minimum_required(VERSION 3.22.1) +project(wespeaker_jni) + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# runtime/core: from this dir, five levels up to runtime/, then into core/ +set(CORE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../../core") + +include(deps.cmake) + +# Same as wekws/runtime/android: headers and jni from Gradle-extracted AAR (no Prefab find_package). +# Typical path: app/build/onnxruntime-android-x.y.z.aar/ +set(build_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../build") +file(GLOB ORT_ROOT_LIST "${build_DIR}/onnxruntime*.aar") +list(LENGTH ORT_ROOT_LIST _ort_len) +if(_ort_len EQUAL 0) + message(FATAL_ERROR + "No ${build_DIR}/onnxruntime*.aar found. Run Gradle task extractAARForNativeBuild first (pulled in before configureCMake).") +endif() +list(GET ORT_ROOT_LIST 0 ORT_ROOT) +# onnxruntime-android AAR: headers live under headers/ (not headers/include/). 
+include_directories("${ORT_ROOT}/headers") +link_directories("${ORT_ROOT}/jni/${ANDROID_ABI}") + +add_definitions(-DUSE_ONNX) + +add_library(utils STATIC "${CORE_DIR}/utils/utils.cc") +target_include_directories(utils PUBLIC "${CORE_DIR}") +target_link_libraries(utils PUBLIC glog gflags) + +add_library(frontend STATIC + "${CORE_DIR}/frontend/feature_pipeline.cc" + "${CORE_DIR}/frontend/fft.cc" +) +target_include_directories(frontend PUBLIC "${CORE_DIR}") +target_link_libraries(frontend PUBLIC utils) + +add_library(speaker STATIC + "${CORE_DIR}/speaker/speaker_engine.cc" + "${CORE_DIR}/speaker/onnx_speaker_model.cc" +) +target_include_directories(speaker PUBLIC "${CORE_DIR}") +target_link_libraries(speaker PUBLIC frontend onnxruntime) + +add_library(wespeaker_jni SHARED wespeaker_jni.cpp) +target_include_directories(wespeaker_jni PRIVATE "${CORE_DIR}") +target_link_libraries(wespeaker_jni + PRIVATE + speaker + onnxruntime + glog + gflags + log +) diff --git a/runtime/android/app/src/main/cpp/deps.cmake b/runtime/android/app/src/main/cpp/deps.cmake new file mode 100644 index 00000000..7a1dea99 --- /dev/null +++ b/runtime/android/app/src/main/cpp/deps.cmake @@ -0,0 +1,34 @@ +# Android NDK: FetchContent for gflags / glog. 
+include(FetchContent)
+set(FETCHCONTENT_QUIET ON)
+
+FetchContent_Declare(gflags
+  URL https://github.com/gflags/gflags/archive/refs/tags/v2.3.0.zip
+  URL_HASH SHA256=ca732b5fd17bf3a27a01a6784b947cbe6323644ecc9e26bbe2117ec43bf7e13b)
+FetchContent_MakeAvailable(gflags)
+
+set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
+set(WITH_GFLAGS ON CACHE BOOL "" FORCE)
+
+FetchContent_Declare(glog
+  URL https://github.com/google/glog/archive/v0.4.0.zip
+  URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc)
+FetchContent_GetProperties(glog)
+if(NOT glog_POPULATED)
+  FetchContent_Populate(glog)
+  file(READ ${glog_SOURCE_DIR}/CMakeLists.txt _glog_cm)
+  # glog 0.4.0: bump cmake_minimum for CMake 4+; on Android, execinfo probe can pass but link fails.
+  string(REGEX REPLACE
+    "cmake_minimum_required[ ]*\\([ ]*VERSION[ ]+[^)]+\\)"
+    "cmake_minimum_required(VERSION 3.10)" _glog_cm "${_glog_cm}")
+  if(ANDROID)
+    string(REPLACE
+      "check_include_file (execinfo.h HAVE_EXECINFO_H)"
+      "if(ANDROID)\n set(HAVE_EXECINFO_H 0)\nelse()\n check_include_file (execinfo.h HAVE_EXECINFO_H)\nendif()"
+      _glog_cm "${_glog_cm}")
+  endif()
+  file(WRITE ${glog_SOURCE_DIR}/CMakeLists.txt "${_glog_cm}")
+  add_subdirectory(${glog_SOURCE_DIR} ${glog_BINARY_DIR})
+endif()
+
+include_directories(${gflags_BINARY_DIR}/include ${glog_SOURCE_DIR}/src ${glog_BINARY_DIR})
diff --git a/runtime/android/app/src/main/cpp/wespeaker_jni.cpp b/runtime/android/app/src/main/cpp/wespeaker_jni.cpp
new file mode 100644
index 00000000..dcb380a2
--- /dev/null
+++ b/runtime/android/app/src/main/cpp/wespeaker_jni.cpp
@@ -0,0 +1,110 @@
+// Copyright 2023 Chengdong Liang (WeSpeaker runtime)
+// SPDX-License-Identifier: Apache-2.0
+
+#include <jni.h>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <vector>
+
+#include "frontend/wav.h"
+#include "glog/logging.h"
+#include "speaker/speaker_engine.h"
+
+namespace {
+
+std::once_flag g_glog_init;
+
+void EnsureGlog() {
+  std::call_once(g_glog_init, []() {
+    google::InitGoogleLogging("wespeaker");
+    FLAGS_logtostderr = 1;
+    FLAGS_minloglevel = 2;
+  });
+}
+
+jfloatArray MakeFloatArray(JNIEnv* env, float a, float b) {
+  jfloatArray out = env->NewFloatArray(2);
+  if (!out) return nullptr;
+  jfloat buf[2] = {a, b};
+  env->SetFloatArrayRegion(out, 0, 2, buf);
+  return out;
+}
+
+void ThrowIo(JNIEnv* env, const char* msg) {
+  jclass ex = env->FindClass("java/io/IOException");
+  if (ex) env->ThrowNew(ex, msg);
+}
+
+std::string JStringToUtf8(JNIEnv* env, jstring s) {
+  if (!s) return {};
+  const char* p = env->GetStringUTFChars(s, nullptr);
+  std::string out(p ? p : "");
+  if (p) env->ReleaseStringUTFChars(s, p);
+  return out;
+}
+
+}  // namespace
+
+extern "C" JNIEXPORT jfloatArray JNICALL
+Java_com_wespeaker_app_WespeakerNative_compare(JNIEnv* env, jclass /* clazz */,
+                                               jstring j_enroll, jstring j_test,
+                                               jstring j_model,
+                                               jdouble j_threshold,
+                                               jint j_fbank_dim,
+                                               jint j_sample_rate) {
+  EnsureGlog();
+
+  const std::string enroll_path = JStringToUtf8(env, j_enroll);
+  const std::string test_path = JStringToUtf8(env, j_test);
+  const std::string model_path = JStringToUtf8(env, j_model);
+  const float threshold = static_cast<float>(j_threshold);
+
+  if (enroll_path.empty() || test_path.empty() || model_path.empty()) {
+    ThrowIo(env, "路径不能为空");
+    return nullptr;
+  }
+
+  try {
+    wenet::WavReader enroll_reader;
+    if (!enroll_reader.Open(enroll_path)) {
+      ThrowIo(env, "无法打开注册音频(需有效 WAV)");
+      return nullptr;
+    }
+    wenet::WavReader test_reader;
+    if (!test_reader.Open(test_path)) {
+      ThrowIo(env, "无法打开测试音频(需有效 WAV)");
+      return nullptr;
+    }
+    if (enroll_reader.num_sample() <= 0 || test_reader.num_sample() <= 0) {
+      ThrowIo(env, "音频长度无效");
+      return nullptr;
+    }
+
+    auto speaker_engine = std::make_shared<wespeaker::SpeakerEngine>(
+        model_path, j_fbank_dim, j_sample_rate,
+        0 /* embedding size: infer from ONNX output shape */,
+        -1 /* one embedding for full audio; same as per_chunk_samples_ <= 0 */);
+    const int embedding_size = speaker_engine->EmbeddingSize();
+
+    int16_t* enroll_data = const_cast<int16_t*>(enroll_reader.data());
+    const int enroll_samples = enroll_reader.num_sample();
+    int16_t* test_data = const_cast<int16_t*>(test_reader.data());
+    const int test_samples = test_reader.num_sample();
+
+    std::vector<float> enroll_emb(embedding_size, 0.f);
+    std::vector<float> test_emb(embedding_size, 0.f);
+    speaker_engine->ExtractEmbedding(enroll_data, enroll_samples, &enroll_emb);
+    speaker_engine->ExtractEmbedding(test_data, test_samples, &test_emb);
+
+    const float score = speaker_engine->CosineSimilarity(enroll_emb, test_emb);
+    const float same = (score >= threshold) ? 1.f : 0.f;
+    return MakeFloatArray(env, score, same);
+  } catch (const std::exception& e) {
+    ThrowIo(env, e.what());
+    return nullptr;
+  } catch (...) {
+    ThrowIo(env, "native 推理异常");
+    return nullptr;
+  }
+}
diff --git a/runtime/android/app/src/main/java/com/wespeaker/app/MainActivity.kt b/runtime/android/app/src/main/java/com/wespeaker/app/MainActivity.kt
new file mode 100644
index 00000000..7d9b4d21
--- /dev/null
+++ b/runtime/android/app/src/main/java/com/wespeaker/app/MainActivity.kt
@@ -0,0 +1,284 @@
+package com.wespeaker.app
+
+import android.Manifest
+import android.content.pm.PackageManager
+import android.os.Bundle
+import android.widget.Toast
+import androidx.activity.result.contract.ActivityResultContracts
+import androidx.appcompat.app.AppCompatActivity
+import androidx.core.content.ContextCompat
+import androidx.lifecycle.lifecycleScope
+import com.google.android.material.dialog.MaterialAlertDialogBuilder
+import com.wespeaker.app.audio.MicRecorder
+import com.wespeaker.app.audio.WavNormalize
+import com.wespeaker.app.databinding.ActivityMainBinding
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
+import java.io.File
+
+class MainActivity : AppCompatActivity() {
+
+    private lateinit var binding: ActivityMainBinding
+    private var enrollFile: File? = null
+    private var testFile: File?
= null + + private val micRecorder = MicRecorder() + /** 0=enroll, 1=test, null=not recording */ + private var recordingSlot: Int? = null + private var captureSampleRate: Int = WavNormalize.TARGET_SAMPLE_RATE + + private var pendingAfterMicPermission: (() -> Unit)? = null + + private val requestMicPermission = + registerForActivityResult(ActivityResultContracts.RequestPermission()) { granted -> + val run = pendingAfterMicPermission + pendingAfterMicPermission = null + if (granted) { + run?.invoke() + } else { + Toast.makeText(this, "需要麦克风权限才能录制", Toast.LENGTH_SHORT).show() + } + } + + private val pickEnroll = + registerForActivityResult(ActivityResultContracts.OpenDocument()) { uri -> + if (uri != null) { + lifecycleScope.launch { + prepareWavFromUri(uri, isEnroll = true) + } + } + } + + private val pickTest = + registerForActivityResult(ActivityResultContracts.OpenDocument()) { uri -> + if (uri != null) { + lifecycleScope.launch { + prepareWavFromUri(uri, isEnroll = false) + } + } + } + + override fun onCreate(savedInstanceState: Bundle?) 
{ + super.onCreate(savedInstanceState) + binding = ActivityMainBinding.inflate(layoutInflater) + setContentView(binding.root) + + binding.btnEnroll.setOnClickListener { + pickEnroll.launch(arrayOf("audio/*", "audio/wav", "audio/x-wav")) + } + binding.btnTest.setOnClickListener { + pickTest.launch(arrayOf("audio/*", "audio/wav", "audio/x-wav")) + } + binding.btnEnrollRecord.setOnClickListener { toggleMicRecording(isEnroll = true) } + binding.btnTestRecord.setOnClickListener { toggleMicRecording(isEnroll = false) } + binding.btnCompare.setOnClickListener { runCompare() } + } + + private suspend fun prepareWavFromUri(uri: android.net.Uri, isEnroll: Boolean) { + val tv = if (isEnroll) binding.tvEnroll else binding.tvTest + try { + tv.text = getString(R.string.processing_audio) + val bytes = withContext(Dispatchers.IO) { + contentResolver.openInputStream(uri)!!.use { it.readBytes() } + } + val pcm = WavNormalize.wavBytesToMono16k(bytes) + val out = File(cacheDir, if (isEnroll) "enroll.wav" else "test.wav") + withContext(Dispatchers.IO) { + WavNormalize.writeMono16Wav(out, pcm.samples) + } + if (isEnroll) { + enrollFile = out + } else { + testFile = out + } + tv.text = getString( + R.string.label_audio_ready, + out.name, + pcm.samples.size, + ) + } catch (e: Exception) { + if (isEnroll) enrollFile = null else testFile = null + tv.text = getString(R.string.audio_prepare_failed, e.message ?: e.javaClass.simpleName) + } + } + + private fun toggleMicRecording(isEnroll: Boolean) { + val wantSlot = if (isEnroll) 0 else 1 + if (recordingSlot != null && recordingSlot != wantSlot) { + Toast.makeText(this, R.string.stop_other_recording_first, Toast.LENGTH_SHORT).show() + return + } + if (recordingSlot == wantSlot) { + stopMicAndSave(isEnroll) + return + } + ensureMicPermissionThen { + val sr = micRecorder.preferredSampleRate() + captureSampleRate = sr + if (!micRecorder.start(sr)) { + Toast.makeText(this, R.string.mic_open_failed, Toast.LENGTH_SHORT).show() + 
return@ensureMicPermissionThen + } + recordingSlot = wantSlot + refreshRecordingUi() + } + } + + private fun ensureMicPermissionThen(block: () -> Unit) { + when { + ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) == + PackageManager.PERMISSION_GRANTED -> block() + else -> { + pendingAfterMicPermission = block + requestMicPermission.launch(Manifest.permission.RECORD_AUDIO) + } + } + } + + private fun stopMicAndSave(isEnroll: Boolean) { + lifecycleScope.launch(Dispatchers.Default) { + val sr = captureSampleRate + val samples = micRecorder.stop() + recordingSlot = null + withContext(Dispatchers.Main) { refreshRecordingUi() } + if (samples.isEmpty()) { + withContext(Dispatchers.Main) { + Toast.makeText(this@MainActivity, "未采集到有效音频", Toast.LENGTH_SHORT).show() + } + return@launch + } + val mono16k = WavNormalize.monoPcmTo16k(samples, sr) + val out = File(cacheDir, if (isEnroll) "enroll.wav" else "test.wav") + withContext(Dispatchers.IO) { + WavNormalize.writeMono16Wav(out, mono16k) + } + if (isEnroll) { + enrollFile = out + } else { + testFile = out + } + val tv = if (isEnroll) binding.tvEnroll else binding.tvTest + withContext(Dispatchers.Main) { + tv.text = getString( + R.string.label_audio_ready, + out.name, + mono16k.size, + ) + } + } + } + + private fun refreshRecordingUi() { + val enrollRec = recordingSlot == 0 + val testRec = recordingSlot == 1 + binding.btnEnrollRecord.text = if (enrollRec) { + getString(R.string.stop_recording) + } else { + getString(R.string.record_enroll) + } + binding.btnTestRecord.text = if (testRec) { + getString(R.string.stop_recording) + } else { + getString(R.string.record_test) + } + binding.btnEnroll.isEnabled = recordingSlot == null + binding.btnTest.isEnabled = recordingSlot == null + binding.btnEnrollRecord.isEnabled = !testRec + binding.btnTestRecord.isEnabled = !enrollRec + binding.btnCompare.isEnabled = recordingSlot == null + if (enrollRec) { + binding.tvEnroll.text = 
getString(R.string.recording_hint_enroll) + } + if (testRec) { + binding.tvTest.text = getString(R.string.recording_hint_test) + } + } + + private fun runCompare() { + val e = enrollFile + val t = testFile + if (e == null || !e.exists() || t == null || !t.exists()) { + Toast.makeText(this, "请先选择或录制两段音频", Toast.LENGTH_SHORT).show() + return + } + val threshold = binding.etThreshold.text?.toString()?.toDoubleOrNull() ?: 0.5 + + binding.btnCompare.isEnabled = false + + lifecycleScope.launch { + try { + val modelFile = withContext(Dispatchers.IO) { copyModelFromAssetsIfPresent() } + if (modelFile == null) { + showCompareDialog( + getString(R.string.dialog_title_tip), + getString(R.string.model_missing), + ) + return@launch + } + val out = withContext(Dispatchers.Default) { + WespeakerNative.compare( + e.absolutePath, + t.absolutePath, + modelFile.absolutePath, + threshold, + 80, + 16000, + ) + } + if (out == null || out.size < 2) { + showCompareDialog( + getString(R.string.dialog_title_tip), + "推理失败(返回为空)", + ) + return@launch + } + val score = out[0] + val same = out[1] >= 0.5f + val verdict = + getString(if (same) R.string.verdict_same else R.string.verdict_diff) + showCompareDialog( + getString(R.string.dialog_title_compare_result), + getString( + R.string.dialog_msg_score_format, + "%.4f".format(score), + verdict, + ), + ) + } catch (ex: Exception) { + showCompareDialog( + getString(R.string.dialog_title_error), + ex.message ?: ex.javaClass.simpleName, + ) + } finally { + binding.btnCompare.isEnabled = true + } + } + } + + private fun showCompareDialog(title: String, message: String) { + MaterialAlertDialogBuilder(this) + .setTitle(title) + .setMessage(message) + .setPositiveButton(R.string.dialog_positive_ok, null) + .show() + } + + /** If assets contain final.onnx, copy to app files dir for native loading */ + private fun copyModelFromAssetsIfPresent(): File? 
{ + val out = File(filesDir, "final.onnx") + if (out.exists() && out.length() > 0) return out + return try { + assets.open(ASSET_ONNX).use { input -> + java.io.FileOutputStream(out).use { output -> input.copyTo(output) } + } + out + } catch (_: Exception) { + null + } + } + + companion object { + private const val ASSET_ONNX = "final.onnx" + } +} diff --git a/runtime/android/app/src/main/java/com/wespeaker/app/WespeakerNative.kt b/runtime/android/app/src/main/java/com/wespeaker/app/WespeakerNative.kt new file mode 100644 index 00000000..923f77ba --- /dev/null +++ b/runtime/android/app/src/main/java/com/wespeaker/app/WespeakerNative.kt @@ -0,0 +1,21 @@ +package com.wespeaker.app + +object WespeakerNative { + init { + System.loadLibrary("wespeaker_jni") + } + + /** + * @return floatArrayOf(score, sameFlag) — sameFlag 1f means same speaker (score >= threshold). + * Embedding dim is inferred from ONNX output shape; features use the **full** utterance (no chunking). + */ + @JvmStatic + external fun compare( + enrollPath: String, + testPath: String, + modelPath: String, + threshold: Double, + fbankDim: Int, + sampleRate: Int, + ): FloatArray +} diff --git a/runtime/android/app/src/main/java/com/wespeaker/app/audio/MicRecorder.kt b/runtime/android/app/src/main/java/com/wespeaker/app/audio/MicRecorder.kt new file mode 100644 index 00000000..96167569 --- /dev/null +++ b/runtime/android/app/src/main/java/com/wespeaker/app/audio/MicRecorder.kt @@ -0,0 +1,103 @@ +package com.wespeaker.app.audio + +import android.media.AudioFormat +import android.media.AudioRecord +import android.media.MediaRecorder +import java.util.concurrent.atomic.AtomicBoolean + +/** + * Captures PCM via [AudioRecord]; prefers 16 kHz mono, falls back to device rates, then [WavNormalize] to 16 kHz. + */ +class MicRecorder { + + private var audioRecord: AudioRecord? = null + private var thread: Thread? 
= null
+    private val running = AtomicBoolean(false)
+    private val chunks = mutableListOf<ShortArray>()
+
+    fun preferredSampleRate(): Int {
+        val tryRates = intArrayOf(16000, 48000, 44100, 22050)
+        for (rate in tryRates) {
+            val min = AudioRecord.getMinBufferSize(
+                rate,
+                AudioFormat.CHANNEL_IN_MONO,
+                AudioFormat.ENCODING_PCM_16BIT,
+            )
+            if (min > 0) return rate
+        }
+        return 44100
+    }
+
+    fun start(sampleRate: Int): Boolean {
+        if (running.get()) return false
+        chunks.clear()
+        val minBuf = AudioRecord.getMinBufferSize(
+            sampleRate,
+            AudioFormat.CHANNEL_IN_MONO,
+            AudioFormat.ENCODING_PCM_16BIT,
+        )
+        if (minBuf <= 0) return false
+        val record = AudioRecord(
+            MediaRecorder.AudioSource.VOICE_RECOGNITION,
+            sampleRate,
+            AudioFormat.CHANNEL_IN_MONO,
+            AudioFormat.ENCODING_PCM_16BIT,
+            minBuf * 2,
+        )
+        if (record.state != AudioRecord.STATE_INITIALIZED) {
+            record.release()
+            return false
+        }
+        audioRecord = record
+        running.set(true)
+        record.startRecording()
+        val bufSize = minBuf / 2
+        val readBuf = ShortArray(bufSize.coerceAtLeast(256))
+        thread = Thread({
+            while (running.get()) {
+                val n = record.read(readBuf, 0, readBuf.size)
+                if (n > 0) {
+                    synchronized(chunks) {
+                        chunks.add(readBuf.copyOf(n))
+                    }
+                } else if (n < 0) {
+                    break
+                }
+            }
+        }, "wespeaker-mic").also { it.start() }
+        return true
+    }
+
+    fun stop(): ShortArray {
+        running.set(false)
+        val rec = audioRecord
+        if (rec != null) {
+            try {
+                rec.stop()
+            } catch (_: Exception) {
+            }
+        }
+        try {
+            thread?.join(5000)
+        } catch (_: InterruptedException) {
+            Thread.currentThread().interrupt()
+        }
+        thread = null
+        rec?.release()
+        audioRecord = null
+        val merged: List<ShortArray>
+        synchronized(chunks) {
+            merged = chunks.toList()
+            chunks.clear()
+        }
+        val total = merged.sumOf { it.size }
+        if (total == 0) return ShortArray(0)
+        val out = ShortArray(total)
+        var o = 0
+        for (c in merged) {
+            c.copyInto(out, o)
+            o += c.size
+        }
+        return out
+    }
+}
diff --git
a/runtime/android/app/src/main/java/com/wespeaker/app/audio/WavNormalize.kt b/runtime/android/app/src/main/java/com/wespeaker/app/audio/WavNormalize.kt new file mode 100644 index 00000000..2a86c786 --- /dev/null +++ b/runtime/android/app/src/main/java/com/wespeaker/app/audio/WavNormalize.kt @@ -0,0 +1,196 @@ +package com.wespeaker.app.audio + +import java.io.File +import java.io.FileOutputStream +import java.nio.ByteBuffer +import java.nio.ByteOrder + +/** + * Converts common PCM / float WAV to model input: mono, 16 kHz, 16-bit PCM, and writes a standard WAV. + */ +object WavNormalize { + + const val TARGET_SAMPLE_RATE = 16000 + const val TARGET_CHANNELS = 1 + + data class PcmMono16(val samples: ShortArray, val sampleRate: Int) + + /** + * Parses WAV bytes to mono int16 (average channels if multi-channel; linear resample if not 16 kHz). + */ + fun wavBytesToMono16k(bytes: ByteArray): PcmMono16 { + val parsed = parseWav(bytes) + var mono = toMono(parsed.samples, parsed.channels) + mono = resampleLinear(mono, parsed.sampleRate, TARGET_SAMPLE_RATE) + return PcmMono16(mono, TARGET_SAMPLE_RATE) + } + + /** When input is already mono PCM, only resample to 16 kHz (mic capture path). 
*/ + fun monoPcmTo16k(samples: ShortArray, sampleRate: Int): ShortArray { + if (sampleRate == TARGET_SAMPLE_RATE) return samples + return resampleLinear(samples, sampleRate, TARGET_SAMPLE_RATE) + } + + fun writeMono16Wav(file: File, samples: ShortArray, sampleRate: Int = TARGET_SAMPLE_RATE) { + val dataSize = samples.size * 2 + val riffSize = 36 + dataSize + FileOutputStream(file).use { os -> + val hdr = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN) + hdr.put("RIFF".toByteArray()) + hdr.putInt(riffSize) + hdr.put("WAVE".toByteArray()) + hdr.put("fmt ".toByteArray()) + hdr.putInt(16) + hdr.putShort(1) // PCM + hdr.putShort(1) // mono + hdr.putInt(sampleRate) + hdr.putInt(sampleRate * 2) + hdr.putShort(2) + hdr.putShort(16) + hdr.put("data".toByteArray()) + hdr.putInt(dataSize) + os.write(hdr.array()) + val buf = ByteBuffer.allocate(samples.size * 2).order(ByteOrder.LITTLE_ENDIAN) + for (s in samples) buf.putShort(s) + os.write(buf.array()) + } + } + + private data class ParsedWav( + val samples: ShortArray, + val channels: Int, + val sampleRate: Int, + ) + + private fun parseWav(bytes: ByteArray): ParsedWav { + require(bytes.size >= 44) { "WAV 过短" } + require(String(bytes, 0, 4) == "RIFF") { "非 RIFF" } + require(String(bytes, 8, 4) == "WAVE") { "非 WAVE" } + + var pos = 12 + var audioFormat = 0 + var numChannels = 0 + var sampleRate = 0 + var bitsPerSample = 0 + var dataOffset = -1 + var dataSize = 0 + + while (pos + 8 <= bytes.size) { + val id = String(bytes, pos, 4) + val size = ByteBuffer.wrap(bytes, pos + 4, 4).order(ByteOrder.LITTLE_ENDIAN).int + val contentStart = pos + 8 + if (id == "fmt ") { + require(size >= 16) { "fmt 块无效" } + val bb = ByteBuffer.wrap(bytes, contentStart, size).order(ByteOrder.LITTLE_ENDIAN) + audioFormat = bb.short.toInt() and 0xffff + numChannels = bb.short.toInt() and 0xffff + sampleRate = bb.int + bb.int // byte rate + bb.short // block align + bitsPerSample = bb.short.toInt() and 0xffff + } else if (id == "data") { + 
dataOffset = contentStart + dataSize = size + break + } + pos = contentStart + size + (size and 1) + } + require(dataOffset >= 0 && dataSize > 0) { "缺少 data 块" } + require(numChannels in 1..16) { "声道数异常: $numChannels" } + + val samples: ShortArray = when (audioFormat) { + 1 -> when (bitsPerSample) { + 16 -> decodePcm16Interleaved(bytes, dataOffset, dataSize, numChannels) + 8 -> decodePcm8Interleaved(bytes, dataOffset, dataSize, numChannels) + else -> throw IllegalArgumentException("不支持的 PCM 位深: $bitsPerSample") + } + 3 -> { + require(bitsPerSample == 32) { "float WAV 需 32-bit" } + decodeFloat32Interleaved(bytes, dataOffset, dataSize, numChannels) + } + else -> throw IllegalArgumentException("不支持的 WAV 格式码: $audioFormat") + } + return ParsedWav(samples, numChannels, sampleRate) + } + + private fun decodePcm16Interleaved( + bytes: ByteArray, + offset: Int, + dataSize: Int, + channels: Int, + ): ShortArray { + val total = dataSize / 2 + val out = ShortArray(total) + val bb = ByteBuffer.wrap(bytes, offset, dataSize).order(ByteOrder.LITTLE_ENDIAN) + for (i in 0 until total) { + out[i] = bb.short + } + return out + } + + private fun decodePcm8Interleaved( + bytes: ByteArray, + offset: Int, + dataSize: Int, + channels: Int, + ): ShortArray { + val total = dataSize + val out = ShortArray(total) + for (i in 0 until total) { + val u = bytes[offset + i].toInt() and 0xff + out[i] = ((u - 128) * 256).coerceIn(Short.MIN_VALUE.toInt(), Short.MAX_VALUE.toInt()).toShort() + } + return out + } + + private fun decodeFloat32Interleaved( + bytes: ByteArray, + offset: Int, + dataSize: Int, + channels: Int, + ): ShortArray { + val floatCount = dataSize / 4 + val out = ShortArray(floatCount) + val bb = ByteBuffer.wrap(bytes, offset, dataSize).order(ByteOrder.LITTLE_ENDIAN) + for (i in 0 until floatCount) { + val f = bb.float + val s = (f * 32767.0f).toInt().coerceIn(Short.MIN_VALUE.toInt(), Short.MAX_VALUE.toInt()) + out[i] = s.toShort() + } + return out + } + + private fun 
toMono(interleaved: ShortArray, channels: Int): ShortArray { + if (channels == 1) return interleaved + val frames = interleaved.size / channels + val out = ShortArray(frames) + for (i in 0 until frames) { + var sum = 0L + for (c in 0 until channels) { + sum += interleaved[i * channels + c].toInt() + } + out[i] = (sum / channels).toInt().toShort() + } + return out + } + + private fun resampleLinear(input: ShortArray, srcRate: Int, dstRate: Int): ShortArray { + if (srcRate == dstRate || input.isEmpty()) return input + require(srcRate > 0 && dstRate > 0) + val outLen = ((input.size.toLong() * dstRate + srcRate / 2) / srcRate).toInt().coerceAtLeast(1) + val out = ShortArray(outLen) + for (i in 0 until outLen) { + val srcPos = (i.toDouble() * srcRate) / dstRate + val idx = srcPos.toInt().coerceIn(0, input.size - 1) + val frac = srcPos - idx + val s0 = input[idx].toDouble() + val s1 = if (idx + 1 < input.size) input[idx + 1].toDouble() else s0 + val v = (s0 + (s1 - s0) * frac).toInt().coerceIn( + Short.MIN_VALUE.toInt(), + Short.MAX_VALUE.toInt(), + ) + out[i] = v.toShort() + } + return out + } +} diff --git a/runtime/android/app/src/main/res/drawable/wenet_gradient_header.xml b/runtime/android/app/src/main/res/drawable/wenet_gradient_header.xml new file mode 100644 index 00000000..12a87fdb --- /dev/null +++ b/runtime/android/app/src/main/res/drawable/wenet_gradient_header.xml @@ -0,0 +1,8 @@ + + + + diff --git a/runtime/android/app/src/main/res/drawable/wenet_section_indicator.xml b/runtime/android/app/src/main/res/drawable/wenet_section_indicator.xml new file mode 100644 index 00000000..91e266e1 --- /dev/null +++ b/runtime/android/app/src/main/res/drawable/wenet_section_indicator.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/runtime/android/app/src/main/res/layout/activity_main.xml b/runtime/android/app/src/main/res/layout/activity_main.xml new file mode 100644 index 00000000..819e0032 --- /dev/null +++ b/runtime/android/app/src/main/res/layout/activity_main.xml @@ -0,0 
+1,263 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/runtime/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/runtime/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 00000000..9892f09e --- /dev/null +++ b/runtime/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/runtime/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/runtime/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 00000000..9892f09e --- /dev/null +++ b/runtime/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/runtime/android/app/src/main/res/values-night/colors.xml b/runtime/android/app/src/main/res/values-night/colors.xml new file mode 100644 index 00000000..db1a52a1 --- /dev/null +++ b/runtime/android/app/src/main/res/values-night/colors.xml @@ -0,0 +1,14 @@ + + + #101418 + #1A2329 + #2A3440 + #1E3A5F + #BBDEFB + #1B3D36 + #B2DFDB + #ECEFF1 + #90A4AE + #3D4A54 + #152028 + diff --git a/runtime/android/app/src/main/res/values/colors.xml b/runtime/android/app/src/main/res/values/colors.xml new file mode 100644 index 00000000..aa2fac71 --- /dev/null +++ b/runtime/android/app/src/main/res/values/colors.xml @@ -0,0 +1,41 @@ + + + + #FF000000 + #FFFFFFFF + + + #1565C0 + #0D47A1 + #42A5F5 + #FFFFFF + #E3F2FD + #0D47A1 + + + #00897B + #00695C + #FFFFFF + #E0F2F1 + #004D40 + + + #F0F4F8 + #FFFFFF + #ECEFF1 + #1C2833 + #1C2833 + #546E7A + #CFD8DC + + + #2E7D32 + #FAFCFE + + + #FFBB86FC + @color/wenet_primary + @color/wenet_primary_dark + @color/wenet_secondary_container + @color/wenet_secondary + diff --git a/runtime/android/app/src/main/res/values/dimens.xml b/runtime/android/app/src/main/res/values/dimens.xml new file mode 100644 index 00000000..8c3f7a6d --- /dev/null 
+++ b/runtime/android/app/src/main/res/values/dimens.xml @@ -0,0 +1,11 @@ + + + 20dp + 16dp + 16dp + 18dp + 12dp + 10dp + 10dp + 52dp + diff --git a/runtime/android/app/src/main/res/values/strings.xml b/runtime/android/app/src/main/res/values/strings.xml new file mode 100644 index 00000000..130a6081 --- /dev/null +++ b/runtime/android/app/src/main/res/values/strings.xml @@ -0,0 +1,30 @@ + + WeSpeaker + 说话人验证 · WeNet 社区 + 支持 WAV 文件或麦克风录制;送入模型前会统一为单声道 16 kHz PCM,与桌面端流程一致。 + 注册音频 + 测试音频 + 参数 + 选择 WAV 文件 + 选择 WAV 文件 + 麦克风录制 + 麦克风录制 + 停止录制 + 正在转换为单声道 16kHz… + 已就绪: %1$s(%2$d 采样点 @16kHz) + 音频处理失败: %1$s + 请先停止另一路录音 + 无法打开麦克风 + 正在录制注册音频… + 正在录制测试音频… + 比对 + 判定阈值 (0–1) + 请先将 export 的 final.onnx 放入 app/src/main/assets/ 并重新编译 + 比对结果 + 提示 + 错误 + 相似度得分:%1$s\n判定:%2$s + 同一人 + 不同人 + 确定 + diff --git a/runtime/android/app/src/main/res/values/styles.xml b/runtime/android/app/src/main/res/values/styles.xml new file mode 100644 index 00000000..547611b6 --- /dev/null +++ b/runtime/android/app/src/main/res/values/styles.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + diff --git a/runtime/android/app/src/main/res/values/themes.xml b/runtime/android/app/src/main/res/values/themes.xml new file mode 100644 index 00000000..98caca4d --- /dev/null +++ b/runtime/android/app/src/main/res/values/themes.xml @@ -0,0 +1,26 @@ + + + + diff --git a/runtime/android/gradle.properties b/runtime/android/gradle.properties new file mode 100644 index 00000000..526ab824 --- /dev/null +++ b/runtime/android/gradle.properties @@ -0,0 +1,7 @@ +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# Connection timeout (ms) when downloading Gradle or dependencies +systemProp.org.gradle.internal.http.connectionTimeout=120000 +systemProp.org.gradle.internal.http.socketTimeout=120000 +android.useAndroidX=true +android.nonTransitiveRClass=true +kotlin.code.style=official diff --git a/runtime/android/gradle/wrapper/gradle-wrapper.jar b/runtime/android/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 
00000000..e6441136 Binary files /dev/null and b/runtime/android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/runtime/android/gradle/wrapper/gradle-wrapper.properties b/runtime/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..586c672b --- /dev/null +++ b/runtime/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,8 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +# If services.gradle.org is slow, use a mirror (e.g. Tencent Cloud) with the same distribution URL hash. +distributionUrl=https\://mirrors.cloud.tencent.com/gradle/gradle-8.7-bin.zip +networkTimeout=120000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/runtime/android/gradlew b/runtime/android/gradlew new file mode 100755 index 00000000..20beb7f4 --- /dev/null +++ b/runtime/android/gradlew @@ -0,0 +1,120 @@ +#!/usr/bin/env sh + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +PRG="$0" +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +DEFAULT_JVM_OPTS="" + +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + JAVACMD="$JAVA_HOME/jre/sh/java" + else + 
JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` +fi + +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=$(save "$@") + +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then + cd "$(dirname "$0")" +fi + +exec "$JAVACMD" "$@" diff --git a/runtime/android/settings.gradle.kts b/runtime/android/settings.gradle.kts new file mode 100644 index 00000000..5573550e --- /dev/null +++ b/runtime/android/settings.gradle.kts @@ -0,0 +1,18 @@ +pluginManagement { + repositories { + google() + mavenCentral() + gradlePluginPortal() + } +} + +dependencyResolutionManagement { + 
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) + repositories { + google() + mavenCentral() + } +} + +rootProject.name = "wespeaker-android" +include(":app") diff --git a/runtime/core/bin/asv_main.cc b/runtime/core/bin/asv_main.cc index f61a6ddb..0639ac02 100644 --- a/runtime/core/bin/asv_main.cc +++ b/runtime/core/bin/asv_main.cc @@ -27,8 +27,11 @@ DEFINE_double(threshold, 0.5, "Threshold"); DEFINE_string(speaker_model_path, "", "path of speaker model"); DEFINE_int32(fbank_dim, 80, "fbank feature dimension"); DEFINE_int32(sample_rate, 16000, "sample rate"); -DEFINE_int32(embedding_size, 256, "embedding size"); -DEFINE_int32(SamplesPerChunk, 32000, "samples of one chunk"); +DEFINE_int32(SamplesPerChunk, -1, + "<=0 whole utterance; >0 chunk size in samples (then average)"); +DEFINE_int32(embedding_size, 0, + "ONNX: <=0 infer from model; >0 must match output dim. Other " + "backends: >0."); int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); diff --git a/runtime/core/bin/extract_emb_main.cc b/runtime/core/bin/extract_emb_main.cc index 62bc4a5e..318b23a7 100644 --- a/runtime/core/bin/extract_emb_main.cc +++ b/runtime/core/bin/extract_emb_main.cc @@ -31,8 +31,11 @@ DEFINE_string(result, "", "output embedding file"); DEFINE_string(speaker_model_path, "", "path of speaker model"); DEFINE_int32(fbank_dim, 80, "fbank feature dimension"); DEFINE_int32(sample_rate, 16000, "sample rate"); -DEFINE_int32(embedding_size, 256, "embedding size"); -DEFINE_int32(samples_per_chunk, 32000, "samples of one chunk"); +DEFINE_int32(samples_per_chunk, -1, + "<=0 whole utterance; >0 chunk size in samples (then average)"); +DEFINE_int32(embedding_size, 0, + "ONNX: <=0 infer from model; >0 must match output dim. 
Other " + "backends: >0."); DEFINE_int32(thread_num, 1, "num of extract_emb thread"); std::ofstream g_result; @@ -53,7 +56,7 @@ void extract_emb(std::pair wav) { int16_t* data = const_cast(wav_reader.data()); int samples = wav_reader.num_sample(); // NOTE(cdliang): memory allocation - std::vector embs(FLAGS_embedding_size, 0); + std::vector embs(embedding_size, 0); int wave_dur = static_cast(static_cast(samples) / wav_reader.sample_rate() * 1000); diff --git a/runtime/core/cmake/gflags.cmake b/runtime/core/cmake/gflags.cmake index 53ae5763..86bed768 100644 --- a/runtime/core/cmake/gflags.cmake +++ b/runtime/core/cmake/gflags.cmake @@ -1,6 +1,6 @@ FetchContent_Declare(gflags - URL https://github.com/gflags/gflags/archive/v2.2.2.zip - URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5 + URL https://github.com/gflags/gflags/archive/refs/tags/v2.3.0.zip + URL_HASH SHA256=ca732b5fd17bf3a27a01a6784b947cbe6323644ecc9e26bbe2117ec43bf7e13b ) FetchContent_MakeAvailable(gflags) include_directories(${gflags_BINARY_DIR}/include) \ No newline at end of file diff --git a/runtime/core/cmake/glog.cmake b/runtime/core/cmake/glog.cmake index 447ab413..c3017354 100644 --- a/runtime/core/cmake/glog.cmake +++ b/runtime/core/cmake/glog.cmake @@ -2,5 +2,16 @@ FetchContent_Declare(glog URL https://github.com/google/glog/archive/v0.4.0.zip URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc ) -FetchContent_MakeAvailable(glog) -include_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR}) \ No newline at end of file +FetchContent_GetProperties(glog) +if(NOT glog_POPULATED) + FetchContent_Populate(glog) + file(READ ${glog_SOURCE_DIR}/CMakeLists.txt _glog_cmake) + # glog 0.4.0 uses cmake_minimum_required(VERSION 3.0); CMake 4+ rejects <3.5. 
+ string(REGEX REPLACE + "cmake_minimum_required[ ]*\\([ ]*VERSION[ ]+[^)]+\\)" + "cmake_minimum_required(VERSION 3.10)" + _glog_cmake "${_glog_cmake}") + file(WRITE ${glog_SOURCE_DIR}/CMakeLists.txt "${_glog_cmake}") + add_subdirectory(${glog_SOURCE_DIR} ${glog_BINARY_DIR}) +endif() +include_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR}) diff --git a/runtime/core/cmake/onnx.cmake b/runtime/core/cmake/onnx.cmake index ec2ebac7..ba39836b 100644 --- a/runtime/core/cmake/onnx.cmake +++ b/runtime/core/cmake/onnx.cmake @@ -1,28 +1,28 @@ if(ONNX) - set(ONNX_VERSION "1.12.0") + set(ONNX_VERSION "1.16.1") if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows") set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-win-x64-${ONNX_VERSION}.zip") - set(URL_HASH "SHA256=8b5d61204989350b7904ac277f5fbccd3e6736ddbb6ec001e412723d71c9c176") + set(URL_HASH "SHA256=05a972384c73c05bce51ffd3e15b1e78325ea9fa652573113159b5cac547ecce") elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux") if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-aarch64-${ONNX_VERSION}.tgz") - set(URL_HASH "SHA256=5820d9f343df73c63b6b2b174a1ff62575032e171c9564bcf92060f46827d0ac") + set(URL_HASH "SHA256=f10851b62eb44f9e811134737e7c6edd15733d2c1549cb6ce403808e9c047385") else() if(GPU) set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-gpu-${ONNX_VERSION}.tgz") - set(URL_HASH "SHA256=bc2e615314df0a871c560b7af6d4ce5896f351d23cad476562d2715208c9c7f7") + set(URL_HASH "SHA256=474d5d74b588d54aa3e167f38acc9b1b8d20c292d0db92299bdc33a81eb4492d") else() set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz") - set(URL_HASH 
"SHA256=5d503ce8540358b59be26c675e42081be14a3e833a5301926f555451046929c5") + set(URL_HASH "SHA256=53a0f03f71587ed602e99e82773132fc634b74c2d227316fbfd4bf67181e72ed") endif() endif() elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64") set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-arm64-${ONNX_VERSION}.tgz") - set(URL_HASH "SHA256=23117b6f5d7324d4a7c51184e5f808dd952aec411a6b99a1b6fd1011de06e300") + set(URL_HASH "SHA256=56ca6b8de3a220ea606c2067ba65d11dfa6e4f722e01ac7dc75f7152b81445e0") else() set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-x86_64-${ONNX_VERSION}.tgz") - set(URL_HASH "SHA256=09b17f712f8c6f19bb63da35d508815b443cbb473e16c6192abfaa297c02f600") + set(URL_HASH "SHA256=0b8ae24401a8f75e1c4f75257d4eaeb1b6d44055e027df4aa4a84e67e0f9b9e3") endif() else() message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')") diff --git a/runtime/core/speaker/onnx_speaker_model.cc b/runtime/core/speaker/onnx_speaker_model.cc index 3e7c6eb0..00bfeba2 100644 --- a/runtime/core/speaker/onnx_speaker_model.cc +++ b/runtime/core/speaker/onnx_speaker_model.cc @@ -14,6 +14,8 @@ #ifdef USE_ONNX +#include +#include #include #include "glog/logging.h" @@ -22,6 +24,33 @@ namespace wespeaker { +namespace { + +int InferEmbeddingSizeFromOutputShape(const std::vector& shape) { + if (shape.empty()) { + return -1; + } + int64_t prod = 1; + int num_positive = 0; + for (int64_t d : shape) { + if (d > 0) { + prod *= d; + ++num_positive; + } + } + if (num_positive == static_cast(shape.size())) { + return static_cast(prod); + } + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + if (shape[i] > 0) { + return static_cast(shape[i]); + } + } + return -1; +} + +} // namespace + Ort::Env OnnxSpeakerModel::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, 
"OnnxModel"); Ort::SessionOptions OnnxSpeakerModel::session_options_ = Ort::SessionOptions(); @@ -48,24 +77,67 @@ OnnxSpeakerModel::OnnxSpeakerModel(const std::string& model_path) { speaker_session_ = std::make_shared(env_, model_path.c_str(), session_options_); #endif - // 2. Model info + // 2. Model info: ORT 1.13+ removed GetInputName; only GetInputNameAllocated + // remains. Ort::AllocatorWithDefaultOptions allocator; // 2.1. input info - int num_nodes = speaker_session_->GetInputCount(); + int num_nodes = static_cast(speaker_session_->GetInputCount()); // NOTE(cdliang): for speaker model, num_nodes is 1. CHECK_EQ(num_nodes, 1); - input_names_.resize(num_nodes); - char* name = speaker_session_->GetInputName(0, allocator); - input_names_[0] = name; - LOG(INFO) << "Ouput name: " << name; +#if ORT_API_VERSION >= 13 + { + auto name_ptr = speaker_session_->GetInputNameAllocated(0, allocator); + input_name_strs_.emplace_back(name_ptr.get()); + input_names_.push_back(input_name_strs_.back().c_str()); + } +#else + { + char* name = speaker_session_->GetInputName(0, allocator); + input_name_strs_.emplace_back(name != nullptr ? name : ""); + input_names_.push_back(input_name_strs_.back().c_str()); + } +#endif + LOG(INFO) << "Input name: " << input_name_strs_[0]; // 2.2. output info - num_nodes = speaker_session_->GetOutputCount(); + num_nodes = static_cast(speaker_session_->GetOutputCount()); CHECK_EQ(num_nodes, 1); - output_names_.resize(num_nodes); - name = speaker_session_->GetOutputName(0, allocator); - output_names_[0] = name; - LOG(INFO) << "Output name: " << name; +#if ORT_API_VERSION >= 13 + { + auto name_ptr = speaker_session_->GetOutputNameAllocated(0, allocator); + output_name_strs_.emplace_back(name_ptr.get()); + output_names_.push_back(output_name_strs_.back().c_str()); + } +#else + { + char* name = speaker_session_->GetOutputName(0, allocator); + output_name_strs_.emplace_back(name != nullptr ? 
name : ""); + output_names_.push_back(output_name_strs_.back().c_str()); + } +#endif + LOG(INFO) << "Output name: " << output_name_strs_[0]; + + Ort::TypeInfo output_type = speaker_session_->GetOutputTypeInfo(0); + auto tensor_shape = output_type.GetTensorTypeAndShapeInfo(); + std::vector out_shape = tensor_shape.GetShape(); + std::ostringstream oss; + for (size_t i = 0; i < out_shape.size(); ++i) { + if (i) oss << ","; + oss << out_shape[i]; + } + LOG(INFO) << "ONNX output shape: [" << oss.str() << "]"; + + int64_t elem_count = tensor_shape.GetElementCount(); + if (elem_count > 0) { + embedding_size_ = static_cast(elem_count); + } else { + embedding_size_ = InferEmbeddingSizeFromOutputShape(out_shape); + } + CHECK_GT(embedding_size_, 0) + << "Cannot infer embedding size from ONNX output (shape may be fully " + "dynamic); try exporting with static output shape or set embedding " + "size explicitly."; + LOG(INFO) << "Inferred embedding size from ONNX: " << embedding_size_; } void OnnxSpeakerModel::ExtractEmbedding( diff --git a/runtime/core/speaker/onnx_speaker_model.h b/runtime/core/speaker/onnx_speaker_model.h index 7ba9565b..2415d846 100644 --- a/runtime/core/speaker/onnx_speaker_model.h +++ b/runtime/core/speaker/onnx_speaker_model.h @@ -33,18 +33,25 @@ class OnnxSpeakerModel : public SpeakerModel { #ifdef USE_GPU static void SetGpuDeviceId(int gpu_id = 0); #endif - public: + explicit OnnxSpeakerModel(const std::string& model_path); void ExtractEmbedding(const std::vector>& feats, std::vector* embed) override; + /** Embedding length from ONNX output shape (matches Run() output element + * count). */ + int EmbeddingSize() const { return embedding_size_; } + private: // session static Ort::Env env_; static Ort::SessionOptions session_options_; std::shared_ptr speaker_session_ = nullptr; - // node names + // Name strings must outlive the session; Run() uses c_str() from + // input_names_/output_names_. 
+ std::vector input_name_strs_; + std::vector output_name_strs_; std::vector input_names_; std::vector output_names_; int embedding_size_ = 0; diff --git a/runtime/core/speaker/speaker_engine.cc b/runtime/core/speaker/speaker_engine.cc index cd669f14..e43318e6 100644 --- a/runtime/core/speaker/speaker_engine.cc +++ b/runtime/core/speaker/speaker_engine.cc @@ -33,8 +33,6 @@ SpeakerEngine::SpeakerEngine(const std::string& model_path, const int feat_dim, // NOTE(cdliang): default num_threads = 1 const int kNumGemmThreads = 1; LOG(INFO) << "Reading model " << model_path; - embedding_size_ = embedding_size; - LOG(INFO) << "Embedding size: " << embedding_size_; per_chunk_samples_ = SamplesPerChunk; LOG(INFO) << "per_chunk_samples: " << per_chunk_samples_; sample_rate_ = sample_rate; @@ -51,10 +49,29 @@ SpeakerEngine::SpeakerEngine(const std::string& model_path, const int feat_dim, OnnxSpeakerModel::SetGpuDeviceId(0); #endif model_ = std::make_shared(model_path); + { + auto onnx_model = std::static_pointer_cast(model_); + const int inferred = onnx_model->EmbeddingSize(); + if (embedding_size <= 0) { + embedding_size_ = inferred; + } else { + CHECK_EQ(inferred, embedding_size) + << "ONNX output embedding dim " << inferred + << " != provided embedding_size " << embedding_size; + embedding_size_ = embedding_size; + } + } + LOG(INFO) << "Embedding size: " << embedding_size_; #elif USE_MNN + CHECK_GT(embedding_size, 0) << "embedding_size must be > 0 for MNN backend"; + embedding_size_ = embedding_size; model_ = std::make_shared(model_path, kNumGemmThreads); + LOG(INFO) << "Embedding size: " << embedding_size_; #elif USE_BPU + CHECK_GT(embedding_size, 0) << "embedding_size must be > 0 for BPU backend"; + embedding_size_ = embedding_size; model_ = std::make_shared(model_path); + LOG(INFO) << "Embedding size: " << embedding_size_; #endif } @@ -123,8 +140,8 @@ void SpeakerEngine::ExtractFeature( chunk_feat.begin() + (num_chunk_frames_ - chunk_feat.size())); } else { 
chunk_feat.insert(chunk_feat.end(), (*chunks_feat)[0].begin(), - (*chunks_feat)[0].begin() + (num_chunk_frames_ - - chunk_feat.size())); + (*chunks_feat)[0].begin() + + (num_chunk_frames_ - chunk_feat.size())); } CHECK_EQ(chunk_feat.size(), num_chunk_frames_); chunks_feat->emplace_back(chunk_feat); diff --git a/runtime/core/speaker/speaker_engine.h b/runtime/core/speaker/speaker_engine.h index 219bbb3f..5712218f 100644 --- a/runtime/core/speaker/speaker_engine.h +++ b/runtime/core/speaker/speaker_engine.h @@ -48,7 +48,7 @@ class SpeakerEngine { std::shared_ptr feature_config_ = nullptr; std::shared_ptr feature_pipeline_ = nullptr; int embedding_size_ = 0; - int per_chunk_samples_ = 32000; + int per_chunk_samples_ = 0; // <=0: whole utterance (default) int sample_rate_ = 16000; };