diff --git a/Cargo.lock b/Cargo.lock index 3d46e6c7e..f229644e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -264,6 +264,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "ansi_colours" version = "1.2.3" @@ -666,7 +672,7 @@ dependencies = [ "enumflags2", "futures-channel", "futures-util", - "rand 0.9.1", + "rand 0.9.2", "raw-window-handle 0.6.2", "serde", "serde_repr", @@ -713,6 +719,29 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "assert2" +version = "0.3.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6c710e60d14b07d8f42d0e702b16120865eea39edb751e75cd6bf401d18f14" +dependencies = [ + "assert2-macros", + "diff", + "yansi", +] + +[[package]] +name = "assert2-macros" +version = "0.3.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9008cbbba9e1d655538870b91fd93814bd82e6968f27788fc734375120ac6f57" +dependencies = [ + "proc-macro2", + "quote", + "rustc_version", + "syn 2.0.101", +] + [[package]] name = "assert_matches" version = "1.5.0" @@ -1179,14 +1208,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.4.5", "bytes", "futures-util", "http 1.3.1", "http-body 1.0.1", "http-body-util", "itoa", - "matchit", + "matchit 0.7.3", "memchr", "mime", "percent-encoding", @@ -1199,6 +1228,40 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +dependencies = [ + "axum-core 0.5.2", + 
"bytes", + "form_urlencoded", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower 0.5.2", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-core" version = "0.4.5" @@ -1219,6 +1282,26 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-core" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +dependencies = [ + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "az" version = "1.2.1" @@ -1240,6 +1323,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "base" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a956d500c2380c818e09d3d7c79ba4a1d7fc6354464f1fceaa5705483a29930" + [[package]] name = "base64" version = "0.13.1" @@ -1640,7 +1729,7 @@ dependencies = [ "metal 0.27.0", "num-traits", "num_cpus", - "rand 0.9.1", + "rand 0.9.2", "rand_distr", "rayon", "safetensors", @@ -1732,6 +1821,12 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.2.17" @@ -1867,6 +1962,33 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901" +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "cipher" version = "0.4.4" @@ -2303,6 +2425,42 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "criterion" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c76e09c1aae2bc52b3d2f29e13c6572553b30c4aa1b8a49fd70de6412654cb" +dependencies = [ + "anes", + "atty", + "cast", + "ciborium", + "clap 3.2.25", + "criterion-plot", + "itertools 0.10.5", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam" version = "0.8.4" @@ -2809,6 +2967,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab03c107fafeb3ee9f5925686dbb7a73bc76e3932abb0d2b365cb64b169cf04c" +[[package]] +name = "diff" +version = 
"0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + [[package]] name = "digest" version = "0.10.7" @@ -2947,7 +3111,6 @@ dependencies = [ "dora-coordinator", "dora-core", "dora-daemon", - "dora-download", "dora-message", "dora-node-api-c", "dora-operator-api-c", @@ -3004,6 +3167,7 @@ dependencies = [ name = "dora-core" version = "0.3.12" dependencies = [ + "dora-download", "dora-message", "dunce", "eyre", @@ -3266,6 +3430,34 @@ dependencies = [ "uuid 1.16.0", ] +[[package]] +name = "dora-openai-websocket" +version = "0.1.0" +dependencies = [ + "anyhow", + "assert2", + "axum 0.8.4", + "base", + "base64 0.22.1", + "bytes", + "criterion", + "dora-node-api", + "fastwebsockets", + "futures-concurrency", + "futures-util", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "rand 0.9.2", + "rustls-pemfile 1.0.4", + "serde", + "serde_json", + "tokio", + "tokio-rustls 0.24.1", + "trybuild", + "webpki-roots 0.23.1", +] + [[package]] name = "dora-operator-api" version = "0.3.12" @@ -3362,7 +3554,7 @@ dependencies = [ "ndarray 0.15.6", "pinyin", "pyo3", - "rand 0.9.1", + "rand 0.9.2", "rerun", "tokio", ] @@ -4161,6 +4353,26 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "fastwebsockets" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "305d3ba574508e27190906d11707dad683e0494e6b85eae9b044cb2734a5e422" +dependencies = [ + "base64 0.21.7", + "bytes", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "pin-project", + "rand 0.8.5", + "sha1", + "simdutf8", + "thiserror 1.0.69", + "tokio", + "utf-8", +] + [[package]] name = "fdeflate" version = "0.3.7" @@ -4272,7 +4484,7 @@ dependencies = [ "cudarc", "half", "num-traits", - "rand 0.9.1", + 
"rand 0.9.2", "rand_distr", ] @@ -5052,7 +5264,7 @@ dependencies = [ "cfg-if 1.0.0", "crunchy", "num-traits", - "rand 0.9.1", + "rand 0.9.2", "rand_distr", ] @@ -5392,7 +5604,7 @@ dependencies = [ "rustls 0.23.25", "rustls-pki-types", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", "tower-service", "webpki-roots 0.26.8", ] @@ -6575,6 +6787,12 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.9" @@ -6858,7 +7076,7 @@ dependencies = [ "image", "indexmap 2.8.0", "mistralrs-core", - "rand 0.9.1", + "rand 0.9.2", "reqwest", "serde", "serde_json", @@ -6910,7 +7128,7 @@ dependencies = [ "objc", "once_cell", "radix_trie", - "rand 0.9.1", + "rand 0.9.2", "rand_isaac", "rayon", "regex", @@ -6931,7 +7149,7 @@ dependencies = [ "tokio", "tokio-rayon", "toktrie_hf_tokenizers", - "toml", + "toml 0.8.20", "tqdm", "tracing", "tracing-subscriber", @@ -7893,6 +8111,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "openssl-probe" version = "0.1.6" @@ -8081,7 +8305,7 @@ dependencies = [ "glob", "opentelemetry 0.29.1", "percent-encoding", - "rand 0.9.1", + "rand 0.9.2", "serde_json", "thiserror 2.0.12", "tokio", @@ -8531,6 +8755,34 @@ dependencies = [ "time", ] +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" 
+dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "ply-rs" version = "0.1.3" @@ -9161,7 +9413,7 @@ checksum = "b820744eb4dc9b57a3398183639c511b5a26d2ed702cedd3febaa1393caa22cc" dependencies = [ "bytes", "getrandom 0.3.2", - "rand 0.9.1", + "rand 0.9.2", "ring 0.17.14", "rustc-hash 2.1.1", "rustls 0.23.25", @@ -9261,9 +9513,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", @@ -9314,7 +9566,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.9.1", + "rand 0.9.2", ] [[package]] @@ -9370,7 +9622,7 @@ dependencies = [ "simd_helpers", "system-deps", "thiserror 1.0.69", - "toml", + "toml 0.8.20", "v_frame", "y4m", ] @@ -10508,7 +10760,7 @@ dependencies = [ "serde", "syn 2.0.101", "tempfile", - "toml", + "toml 0.8.20", "unindent", "xshell", ] @@ -11260,7 +11512,7 @@ dependencies = [ "sync_wrapper", "system-configuration", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", "tokio-util", "tower 0.5.2", "tower-service", @@ -11660,7 +11912,7 @@ 
dependencies = [ "num-derive", "num-traits", "paste", - "rand 0.9.1", + "rand 0.9.2", "serde", "serde_repr", "socket2 0.5.8", @@ -11731,6 +11983,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring 0.17.14", + "rustls-webpki 0.101.7", + "sct", +] + [[package]] name = "rustls" version = "0.23.25" @@ -11824,6 +12088,26 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" +[[package]] +name = "rustls-webpki" +version = "0.100.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6a5fc258f1c1276dfe3016516945546e2d5383911efc0fc4f1cdc5df3a4ae3" +dependencies = [ + "ring 0.16.20", + "untrusted 0.7.1", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring 0.17.14", + "untrusted 0.9.0", +] + [[package]] name = "rustls-webpki" version = "0.102.8" @@ -12228,9 +12512,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" dependencies = [ "indexmap 2.8.0", "itoa", @@ -12239,6 +12523,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_plain" version = "1.0.2" @@ -12268,6 +12562,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40734c41988f7306bb04f0ecf60ec0f3f1caa34290e4e8ea471dcd3346483b83" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -13221,7 +13524,7 @@ dependencies = [ "cfg-expr", "heck 0.5.0", "pkg-config", - "toml", + "toml 0.8.20", "version-compare", ] @@ -13257,6 +13560,12 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "target-triple" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ac9aa371f599d22256307c24a9d748c041e548cbf599f35d890f9d365361790" + [[package]] name = "tempfile" version = "3.19.1" @@ -13515,6 +13824,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.9.0" @@ -13540,7 +13859,7 @@ dependencies = [ "pin-project-lite", "thiserror 2.0.12", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", ] [[package]] @@ -13638,6 +13957,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.12", + "tokio", +] + [[package]] name = 
"tokio-rustls" version = "0.26.2" @@ -13720,11 +14049,26 @@ checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" dependencies = [ "indexmap 2.8.0", "serde", - "serde_spanned", - "toml_datetime", + "serde_spanned 0.6.8", + "toml_datetime 0.6.8", "toml_edit", ] +[[package]] +name = "toml" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41ae868b5a0f67631c14589f7e250c1ea2c574ee5ba21c6c8dd4b1485705a5a1" +dependencies = [ + "indexmap 2.8.0", + "serde", + "serde_spanned 1.0.0", + "toml_datetime 0.7.0", + "toml_parser", + "toml_writer", + "winnow", +] + [[package]] name = "toml_datetime" version = "0.6.8" @@ -13734,6 +14078,15 @@ dependencies = [ "serde", ] +[[package]] +name = "toml_datetime" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bade1c3e902f58d73d3f294cd7f20391c1cb2fbcb643b73566bc773971df91e3" +dependencies = [ + "serde", +] + [[package]] name = "toml_edit" version = "0.22.24" @@ -13742,11 +14095,26 @@ checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" dependencies = [ "indexmap 2.8.0", "serde", - "serde_spanned", - "toml_datetime", + "serde_spanned 0.6.8", + "toml_datetime 0.6.8", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97200572db069e74c512a14117b296ba0a80a30123fbbb5aa1f4a348f639ca30" +dependencies = [ "winnow", ] +[[package]] +name = "toml_writer" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc842091f2def52017664b53082ecbbeb5c7731092bad69d2c63050401dfd64" + [[package]] name = "tonic" version = "0.12.3" @@ -13755,7 +14123,7 @@ checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" dependencies = [ "async-stream", "async-trait", - "axum", + "axum 0.7.9", "base64 
0.22.1", "bytes", "h2 0.4.8", @@ -13772,7 +14140,7 @@ dependencies = [ "rustls-pemfile 2.2.0", "socket2 0.5.8", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", "tokio-stream", "tower 0.4.13", "tower-layer", @@ -13858,6 +14226,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -14005,6 +14374,21 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "trybuild" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65af40ad689f2527aebbd37a0a816aea88ff5f774ceabe99de5be02f2f91dae2" +dependencies = [ + "glob", + "serde", + "serde_derive", + "serde_json", + "target-triple", + "termcolor", + "toml 0.9.4", +] + [[package]] name = "ttf-parser" version = "0.25.1" @@ -14352,7 +14736,7 @@ checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" dependencies = [ "getrandom 0.3.2", "js-sys", - "rand 0.9.1", + "rand 0.9.2", "serde", "uuid-macro-internal", "wasm-bindgen", @@ -14810,6 +15194,15 @@ dependencies = [ "webpki", ] +[[package]] +name = "webpki-roots" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338" +dependencies = [ + "rustls-webpki 0.100.3", +] + [[package]] name = "webpki-roots" version = "0.26.8" @@ -15676,9 +16069,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.4" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36" +checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95" dependencies = [ "memchr", ] @@ -15888,6 +16281,12 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "yansi" +version = 
"1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yoke" version = "0.7.5" @@ -16549,7 +16948,7 @@ dependencies = [ "time", "tls-listener", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", "tokio-util", "tracing", "webpki-roots 0.26.8", diff --git a/Cargo.toml b/Cargo.toml index 9d705ba23..aebb01581 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ members = [ "node-hub/dora-rerun", "node-hub/terminal-print", "node-hub/openai-proxy-server", + "node-hub/dora-openai-websocket", "node-hub/dora-kit-car", "node-hub/dora-object-to-pose", "node-hub/dora-mistral-rs", diff --git a/apis/python/node/src/lib.rs b/apis/python/node/src/lib.rs index abe5527c6..edf8d8055 100644 --- a/apis/python/node/src/lib.rs +++ b/apis/python/node/src/lib.rs @@ -1,14 +1,10 @@ #![allow(clippy::borrow_deref_ref)] // clippy warns about code generated by #[pymethods] -use std::env::current_dir; -use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; -use dora_download::download_file; use dora_node_api::dora_core::config::NodeId; -use dora_node_api::dora_core::descriptor::source_is_url; use dora_node_api::merged::{MergeExternalSend, MergedEvent}; use dora_node_api::{DataflowId, DoraNode, EventStream}; use dora_operator_api_python::{pydict_to_metadata, DelayedCleanup, NodeCleanupHandle, PyEvent}; @@ -360,22 +356,6 @@ pub fn start_runtime() -> eyre::Result<()> { dora_runtime::main().wrap_err("Dora Runtime raised an error.") } -pub fn resolve_dataflow(dataflow: String) -> eyre::Result { - let dataflow = if source_is_url(&dataflow) { - // try to download the shared library - let target_path = current_dir().context("Could not access the current dir")?; - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .context("tokio runtime failed")?; - rt.block_on(async { 
download_file(&dataflow, &target_path).await }) - .wrap_err("failed to download dataflow yaml file")? - } else { - PathBuf::from(dataflow) - }; - Ok(dataflow) -} - /// Run a Dataflow /// /// :rtype: None diff --git a/apis/rust/node/src/lib.rs b/apis/rust/node/src/lib.rs index 90f266217..cfee194e3 100644 --- a/apis/rust/node/src/lib.rs +++ b/apis/rust/node/src/lib.rs @@ -95,3 +95,4 @@ pub use node::{arrow_utils, DataSample, DoraNode, ZERO_COPY_THRESHOLD}; mod daemon_connection; mod event_stream; mod node; +pub mod requests; diff --git a/apis/rust/node/src/requests.rs b/apis/rust/node/src/requests.rs new file mode 100644 index 000000000..da53b8509 --- /dev/null +++ b/apis/rust/node/src/requests.rs @@ -0,0 +1,42 @@ +use std::sync::Arc; + +use dora_core::{ + topics::{DORA_DAEMON_LOCAL_LISTEN_PORT_DEFAULT, LOCALHOST}, + uhlc, +}; +use dora_message::{ + common::Timestamped, daemon_to_node::DaemonReply, node_to_daemon::DaemonRequest, DataflowId, +}; +use eyre::{bail, Context}; + +use crate::daemon_connection::DaemonChannel; + +pub fn start_dataflow( + dataflow: String, + name: Option, + uv: bool, +) -> eyre::Result { + let mut channel = init_daemon_channel()?; + let clock = Arc::new(uhlc::HLC::default()); + + let request = DaemonRequest::StartDataflow { dataflow, name, uv }; + let reply = channel + .request(&Timestamped { + inner: request, + timestamp: clock.new_timestamp(), + }) + .wrap_err("failed to trigger dataflow start through daemon")?; + match reply { + DaemonReply::StartDataflowResult(Ok(dataflow_id)) => Ok(dataflow_id), + DaemonReply::StartDataflowResult(Err(err)) => bail!("failed to start dataflow: {err}"), + other => bail!("unexpected StartDataflow reply from daemon: {other:?}"), + } +} + +fn init_daemon_channel() -> eyre::Result { + let daemon_address = (LOCALHOST, DORA_DAEMON_LOCAL_LISTEN_PORT_DEFAULT).into(); + + let channel = + DaemonChannel::new_tcp(daemon_address).context("Could not connect to the daemon")?; + Ok(channel) +} diff --git 
a/binaries/cli/Cargo.toml b/binaries/cli/Cargo.toml index 797fb9e04..e364dbb46 100644 --- a/binaries/cli/Cargo.toml +++ b/binaries/cli/Cargo.toml @@ -26,7 +26,6 @@ dora-core = { workspace = true } dora-message = { workspace = true } dora-node-api-c = { workspace = true } dora-operator-api-c = { workspace = true } -dora-download = { workspace = true } serde = { version = "1.0.136", features = ["derive"] } serde_yaml = { workspace = true } webbrowser = "0.8.3" diff --git a/binaries/cli/src/command/build/distributed.rs b/binaries/cli/src/command/build/distributed.rs index 1fd1ed914..453e4c4d4 100644 --- a/binaries/cli/src/command/build/distributed.rs +++ b/binaries/cli/src/command/build/distributed.rs @@ -1,5 +1,5 @@ use communication_layer_request_reply::{TcpConnection, TcpRequestReplyConnection}; -use dora_core::descriptor::Descriptor; +use dora_core::{descriptor::Descriptor, session::DataflowSession}; use dora_message::{ cli_to_coordinator::ControlRequest, common::{GitSource, LogMessage}, @@ -13,7 +13,7 @@ use std::{ net::{SocketAddr, TcpStream}, }; -use crate::{output::print_log_message, session::DataflowSession}; +use crate::output::print_log_message; pub fn build_distributed_dataflow( session: &mut TcpRequestReplyConnection, diff --git a/binaries/cli/src/command/build/local.rs b/binaries/cli/src/command/build/local.rs index 7f6c25579..ad81bb87f 100644 --- a/binaries/cli/src/command/build/local.rs +++ b/binaries/cli/src/command/build/local.rs @@ -4,12 +4,11 @@ use colored::Colorize; use dora_core::{ build::{BuildInfo, BuildLogger, Builder, GitManager, LogLevelOrStdout, PrevGitSource}, descriptor::{Descriptor, DescriptorExt}, + session::DataflowSession, }; use dora_message::{common::GitSource, id::NodeId}; use eyre::Context; -use crate::session::DataflowSession; - pub fn build_dataflow_locally( dataflow: Descriptor, git_sources: &BTreeMap, diff --git a/binaries/cli/src/command/build/mod.rs b/binaries/cli/src/command/build/mod.rs index 04f16f55f..92a830ec7 100644 
--- a/binaries/cli/src/command/build/mod.rs +++ b/binaries/cli/src/command/build/mod.rs @@ -50,6 +50,8 @@ use communication_layer_request_reply::TcpRequestReplyConnection; use dora_core::{ descriptor::{CoreNodeKind, CustomNode, Descriptor, DescriptorExt}, + resolve_dataflow, + session::DataflowSession, topics::{DORA_COORDINATOR_PORT_CONTROL_DEFAULT, LOCALHOST}, }; use dora_message::{descriptor::NodeSource, BuildId}; @@ -57,10 +59,7 @@ use eyre::Context; use std::{collections::BTreeMap, net::IpAddr}; use super::{default_tracing, Executable}; -use crate::{ - common::{connect_to_coordinator, local_working_dir, resolve_dataflow}, - session::DataflowSession, -}; +use crate::common::{connect_to_coordinator, local_working_dir}; use distributed::{build_distributed_dataflow, wait_until_dataflow_built}; use local::build_dataflow_locally; diff --git a/binaries/cli/src/command/coordinator.rs b/binaries/cli/src/command/coordinator.rs index da48d1b33..c6f69c1e7 100644 --- a/binaries/cli/src/command/coordinator.rs +++ b/binaries/cli/src/command/coordinator.rs @@ -9,6 +9,7 @@ use dora_tracing::TracingBuilder; use eyre::Context; use std::net::{IpAddr, SocketAddr}; use tokio::runtime::Builder; +#[cfg(feature = "tracing")] use tracing::level_filters::LevelFilter; #[derive(Debug, clap::Args)] diff --git a/binaries/cli/src/command/daemon.rs b/binaries/cli/src/command/daemon.rs index c4aa6ca72..9f94f9452 100644 --- a/binaries/cli/src/command/daemon.rs +++ b/binaries/cli/src/command/daemon.rs @@ -1,7 +1,8 @@ use super::Executable; -use crate::{common::handle_dataflow_result, session::DataflowSession}; -use dora_core::topics::{ - DORA_COORDINATOR_PORT_DEFAULT, DORA_DAEMON_LOCAL_LISTEN_PORT_DEFAULT, LOCALHOST, +use crate::common::handle_dataflow_result; +use dora_core::{ + session::DataflowSession, + topics::{DORA_COORDINATOR_PORT_DEFAULT, DORA_DAEMON_LOCAL_LISTEN_PORT_DEFAULT, LOCALHOST}, }; use dora_daemon::LogDestination; @@ -14,6 +15,7 @@ use std::{ path::PathBuf, }; use 
tokio::runtime::Builder; +#[cfg(feature = "tracing")] use tracing::level_filters::LevelFilter; #[derive(Debug, clap::Args)] diff --git a/binaries/cli/src/command/run.rs b/binaries/cli/src/command/run.rs index feb5947c0..55ef3d33c 100644 --- a/binaries/cli/src/command/run.rs +++ b/binaries/cli/src/command/run.rs @@ -6,12 +6,10 @@ //! Use `dora build --local` or manual build commands to build your nodes. use super::Executable; -use crate::{ - common::{handle_dataflow_result, resolve_dataflow}, - output::print_log_message, - session::DataflowSession, -}; +use crate::{common::handle_dataflow_result, output::print_log_message}; +use dora_core::{resolve_dataflow, session::DataflowSession}; use dora_daemon::{flume, Daemon, LogDestination}; +#[cfg(feature = "tracing")] use dora_tracing::TracingBuilder; use eyre::Context; use tokio::runtime::Builder; diff --git a/binaries/cli/src/command/start/mod.rs b/binaries/cli/src/command/start/mod.rs index 077a67b4b..c87e73876 100644 --- a/binaries/cli/src/command/start/mod.rs +++ b/binaries/cli/src/command/start/mod.rs @@ -5,17 +5,20 @@ use super::{default_tracing, Executable}; use crate::{ command::start::attach::attach_dataflow, - common::{connect_to_coordinator, local_working_dir, resolve_dataflow}, + common::{connect_to_coordinator, local_working_dir}, output::print_log_message, - session::DataflowSession, }; use communication_layer_request_reply::{TcpConnection, TcpRequestReplyConnection}; use dora_core::{ descriptor::{Descriptor, DescriptorExt}, + resolve_dataflow, + session::DataflowSession, topics::{DORA_COORDINATOR_PORT_CONTROL_DEFAULT, LOCALHOST}, }; use dora_message::{ - cli_to_coordinator::ControlRequest, common::LogMessage, coordinator_to_cli::ControlRequestReply, + cli_to_coordinator::{ControlRequest, StartRequest}, + common::LogMessage, + coordinator_to_cli::ControlRequestReply, }; use eyre::{bail, Context}; use std::{ @@ -31,28 +34,28 @@ mod attach; pub struct Start { /// Path to the dataflow descriptor file 
#[clap(value_name = "PATH")] - dataflow: String, + pub dataflow: String, /// Assign a name to the dataflow #[clap(long)] - name: Option, + pub name: Option, /// Address of the dora coordinator #[clap(long, value_name = "IP", default_value_t = LOCALHOST)] - coordinator_addr: IpAddr, + pub coordinator_addr: IpAddr, /// Port number of the coordinator control server #[clap(long, value_name = "PORT", default_value_t = DORA_COORDINATOR_PORT_CONTROL_DEFAULT)] - coordinator_port: u16, + pub coordinator_port: u16, /// Attach to the dataflow and wait for its completion #[clap(long, action)] - attach: bool, + pub attach: bool, /// Run the dataflow in background #[clap(long, action)] - detach: bool, + pub detach: bool, /// Enable hot reloading (Python only) #[clap(long, action)] - hot_reload: bool, + pub hot_reload: bool, // Use UV to run nodes. #[clap(long, action)] - uv: bool, + pub uv: bool, } impl Executable for Start { @@ -125,14 +128,14 @@ fn start_dataflow( let session: &mut TcpRequestReplyConnection = &mut *session; let reply_raw = session .request( - &serde_json::to_vec(&ControlRequest::Start { + &serde_json::to_vec(&ControlRequest::Start(StartRequest { build_id: dataflow_session.build_id, session_id: dataflow_session.session_id, dataflow, name, local_working_dir, uv, - }) + })) .unwrap(), ) .wrap_err("failed to send start dataflow message")?; diff --git a/binaries/cli/src/common.rs b/binaries/cli/src/common.rs index e02a16731..91321a7b7 100644 --- a/binaries/cli/src/common.rs +++ b/binaries/cli/src/common.rs @@ -1,18 +1,15 @@ use crate::formatting::FormatDataflowError; use communication_layer_request_reply::{RequestReplyLayer, TcpLayer, TcpRequestReplyConnection}; -use dora_core::descriptor::{source_is_url, Descriptor}; -use dora_download::download_file; +use dora_core::descriptor::Descriptor; use dora_message::{ cli_to_coordinator::ControlRequest, coordinator_to_cli::{ControlRequestReply, DataflowList, DataflowResult}, }; use eyre::{bail, Context, ContextCompat}; 
use std::{ - env::current_dir, net::SocketAddr, path::{Path, PathBuf}, }; -use tokio::runtime::Builder; use uuid::Uuid; pub(crate) fn handle_dataflow_result( @@ -56,22 +53,6 @@ pub(crate) fn connect_to_coordinator( TcpLayer::new().connect(coordinator_addr) } -pub(crate) fn resolve_dataflow(dataflow: String) -> eyre::Result { - let dataflow = if source_is_url(&dataflow) { - // try to download the shared library - let target_path = current_dir().context("Could not access the current dir")?; - let rt = Builder::new_current_thread() - .enable_all() - .build() - .context("tokio runtime failed")?; - rt.block_on(async { download_file(&dataflow, &target_path).await }) - .wrap_err("failed to download dataflow yaml file")? - } else { - PathBuf::from(dataflow) - }; - Ok(dataflow) -} - pub(crate) fn local_working_dir( dataflow_path: &Path, dataflow_descriptor: &Descriptor, diff --git a/binaries/cli/src/lib.rs b/binaries/cli/src/lib.rs index 868d7a5a3..f20484e3e 100644 --- a/binaries/cli/src/lib.rs +++ b/binaries/cli/src/lib.rs @@ -9,7 +9,7 @@ mod command; mod common; mod formatting; pub mod output; -pub mod session; + mod template; pub use command::run_func; diff --git a/binaries/coordinator/src/lib.rs b/binaries/coordinator/src/lib.rs index 67ee69a4a..ce94612d1 100644 --- a/binaries/coordinator/src/lib.rs +++ b/binaries/coordinator/src/lib.rs @@ -9,7 +9,7 @@ use dora_core::{ uhlc::{self, HLC}, }; use dora_message::{ - cli_to_coordinator::ControlRequest, + cli_to_coordinator::{ControlRequest, StartRequest}, common::{DaemonId, GitSource}, coordinator_to_cli::{ ControlRequestReply, DataflowIdAndName, DataflowList, DataflowListEntry, DataflowResult, @@ -457,51 +457,15 @@ async fn start_inner( reply_sender.send(Err(eyre!("unknown build id {build_id}"))); } } - ControlRequest::Start { - build_id, - session_id, - dataflow, - name, - local_working_dir, - uv, - } => { - let name = name.or_else(|| names::Generator::default().next()); - - let inner = async { - if let Some(name) = 
name.as_deref() { - // check that name is unique - if running_dataflows - .values() - .any(|d: &RunningDataflow| d.name.as_deref() == Some(name)) - { - bail!("there is already a running dataflow with name `{name}`"); - } - } - let dataflow = start_dataflow( - build_id, - session_id, - dataflow, - local_working_dir, - name, - &mut daemon_connections, - &clock, - uv, - ) - .await?; - Ok(dataflow) - }; - match inner.await { - Ok(dataflow) => { - let uuid = dataflow.uuid; - running_dataflows.insert(uuid, dataflow); - let _ = reply_sender.send(Ok( - ControlRequestReply::DataflowStartTriggered { uuid }, - )); - } - Err(err) => { - let _ = reply_sender.send(Err(err)); - } - } + ControlRequest::Start(start_request) => { + handle_start_request( + start_request, + &mut running_dataflows, + &mut daemon_connections, + &clock, + reply_sender, + ) + .await } ControlRequest::WaitForSpawn { dataflow_id } => { if let Some(dataflow) = running_dataflows.get_mut(&dataflow_id) { @@ -921,6 +885,58 @@ async fn start_inner( Ok(()) } +async fn handle_start_request( + start_request: StartRequest, + running_dataflows: &mut HashMap, + daemon_connections: &mut DaemonConnections, + clock: &HLC, + reply_sender: oneshot::Sender>, +) { + let StartRequest { + build_id, + session_id, + dataflow, + name, + local_working_dir, + uv, + } = start_request; + + let name = name.or_else(|| names::Generator::default().next()); + let inner = async { + if let Some(name) = name.as_deref() { + // check that name is unique + if running_dataflows + .values() + .any(|d: &RunningDataflow| d.name.as_deref() == Some(name)) + { + bail!("there is already a running dataflow with name `{name}`"); + } + } + let dataflow = start_dataflow( + build_id, + session_id, + dataflow, + local_working_dir, + name, + daemon_connections, + clock, + uv, + ) + .await?; + Ok(dataflow) + }; + match inner.await { + Ok(dataflow) => { + let uuid = dataflow.uuid; + running_dataflows.insert(uuid, dataflow); + let _ = 
reply_sender.send(Ok(ControlRequestReply::DataflowStartTriggered { uuid })); + } + Err(err) => { + let _ = reply_sender.send(Err(err)); + } + } +} + async fn send_log_message(log_subscribers: &mut Vec, message: &LogMessage) { for subscriber in log_subscribers.iter_mut() { let send_result = diff --git a/binaries/coordinator/src/listener.rs b/binaries/coordinator/src/listener.rs index ab7e3b9db..e19217fe6 100644 --- a/binaries/coordinator/src/listener.rs +++ b/binaries/coordinator/src/listener.rs @@ -1,11 +1,17 @@ -use crate::{tcp_utils::tcp_receive, DaemonRequest, DataflowEvent, Event}; +use crate::{ + tcp_utils::{tcp_receive, tcp_send}, + ControlEvent, DaemonRequest, DataflowEvent, Event, +}; use dora_core::uhlc::HLC; -use dora_message::daemon_to_coordinator::{CoordinatorRequest, DaemonEvent, Timestamped}; +use dora_message::{ + cli_to_coordinator::ControlRequest, + daemon_to_coordinator::{CoordinatorRequest, DaemonEvent, Timestamped}, +}; use eyre::Context; use std::{io::ErrorKind, net::SocketAddr, sync::Arc}; use tokio::{ net::{TcpListener, TcpStream}, - sync::mpsc, + sync::{mpsc, oneshot}, }; pub async fn create_listener(bind: SocketAddr) -> eyre::Result { @@ -136,6 +142,21 @@ pub async fn handle_connection( } } }, + CoordinatorRequest::StartDataflow(start_request) => { + let (reply_sender, reply_rx) = oneshot::channel(); + let event = Event::Control(ControlEvent::IncomingRequest { + request: ControlRequest::Start(start_request), + reply_sender, + }); + if events_tx.send(event).await.is_err() { + break; + } + let Ok(reply) = reply_rx.await else { break }; + let message = serde_json::to_vec(&reply.map_err(|e| format!("{e:?}"))).unwrap(); + if let Err(err) = tcp_send(&mut connection, &message).await { + tracing::warn!("failed to send StartDataflow reply to node: {}", err.kind()); + } + } }; } } diff --git a/binaries/daemon/src/lib.rs b/binaries/daemon/src/lib.rs index 89642c7b7..bf3619ad8 100644 --- a/binaries/daemon/src/lib.rs +++ b/binaries/daemon/src/lib.rs @@ 
-8,15 +8,18 @@ use dora_core::{ read_as_descriptor, CoreNodeKind, Descriptor, DescriptorExt, ResolvedNode, RuntimeNode, DYNAMIC_SOURCE, }, + resolve_dataflow, + session::DataflowSession, topics::LOCALHOST, uhlc::{self, HLC}, }; use dora_message::{ + cli_to_coordinator::StartRequest, common::{ DaemonId, DataMessage, DropToken, GitSource, LogLevel, NodeError, NodeErrorCause, NodeExitStatus, }, - coordinator_to_cli::DataflowResult, + coordinator_to_cli::{ControlRequestReply, DataflowResult}, coordinator_to_daemon::{BuildDataflowNodes, DaemonCoordinatorEvent, SpawnDataflowNodes}, daemon_to_coordinator::{ CoordinatorRequest, DaemonCoordinatorReply, DaemonEvent, DataflowDaemonResult, @@ -79,7 +82,7 @@ use dora_tracing::telemetry::serialize_context; #[cfg(feature = "telemetry")] use tracing_opentelemetry::OpenTelemetrySpanExt; -use crate::pending::DataflowStatus; +use crate::{pending::DataflowStatus, socket_stream_utils::socket_stream_receive}; const STDERR_LOG_LINES: usize = 10; @@ -1611,8 +1614,8 @@ impl Daemon { dataflow.check_drop_token(token, &self.clock).await?; } else { tracing::warn!( - "node `{node_id}` is not pending for drop token `{token:?}`" - ); + "node `{node_id}` is not pending for drop token `{token:?}`" + ); } } None => tracing::warn!("unknown drop token `{token:?}`"), @@ -1635,6 +1638,51 @@ impl Daemon { let reply = inner.await.map_err(|err| format!("{err:?}")); let _ = reply_sender.send(DaemonReply::Result(reply)); } + DaemonNodeEvent::StartDataflow { + dataflow, + name, + uv, + reply_sender, + } => { + let inner = async { + let Some(connection) = &mut self.coordinator_connection else { + bail!("no coordinator connection to send StartDataflow"); + }; + let dataflow = + resolve_dataflow(dataflow).context("could not resolve dataflow")?; + let dataflow_descriptor = Descriptor::blocking_read(&dataflow) + .wrap_err("Failed to read yaml dataflow")?; + let dataflow_session = DataflowSession::read_session(&dataflow) + .context("failed to read 
DataflowSession")?; + + let msg = serde_json::to_vec(&Timestamped { + inner: CoordinatorRequest::StartDataflow(StartRequest { + build_id: dataflow_session.build_id, + session_id: dataflow_session.session_id, + dataflow: dataflow_descriptor, + name, + local_working_dir: None, + uv, + }), + timestamp: self.clock.new_timestamp(), + })?; + socket_stream_send(connection, &msg) + .await + .wrap_err("failed to send StartDataflow message to dora-coordinator")?; + let reply_raw = socket_stream_receive(connection) + .await + .wrap_err("failed to receive StartDataflow reply")?; + let result: Timestamped> = + serde_json::from_slice(&reply_raw) + .wrap_err("failed to deserialize StartDataflow reply")?; + match result.inner.map_err(|e| eyre::eyre!(e))? { + ControlRequestReply::DataflowStartTriggered { uuid } => Ok(uuid), + other => bail!("unexpected StartDataflow reply: {other:?}"), + } + }; + let reply = inner.await.map_err(|err| format!("{err:?}")); + let _ = reply_sender.send(DaemonReply::StartDataflowResult(reply)); + } } Ok(()) } @@ -2723,6 +2771,12 @@ pub enum DaemonNodeEvent { EventStreamDropped { reply_sender: oneshot::Sender, }, + StartDataflow { + dataflow: String, + name: Option, + uv: bool, + reply_sender: oneshot::Sender, + }, } #[derive(Debug)] diff --git a/binaries/daemon/src/node_communication/mod.rs b/binaries/daemon/src/node_communication/mod.rs index 8ed9af0b6..96a84e256 100644 --- a/binaries/daemon/src/node_communication/mod.rs +++ b/binaries/daemon/src/node_communication/mod.rs @@ -457,6 +457,20 @@ impl Listener { ) .await?; } + DaemonRequest::StartDataflow { dataflow, name, uv } => { + let (reply_sender, reply) = oneshot::channel(); + self.process_daemon_event( + DaemonNodeEvent::StartDataflow { + dataflow, + name, + uv, + reply_sender, + }, + Some(reply), + connection, + ) + .await? 
+ } } Ok(()) } diff --git a/examples/multiple-daemons/run.rs b/examples/multiple-daemons/run.rs index cb558af38..1cf7cd998 100644 --- a/examples/multiple-daemons/run.rs +++ b/examples/multiple-daemons/run.rs @@ -1,11 +1,11 @@ -use dora_cli::session::DataflowSession; use dora_coordinator::{ControlEvent, Event}; use dora_core::{ descriptor::{read_as_descriptor, DescriptorExt}, + session::DataflowSession, topics::{DORA_COORDINATOR_PORT_CONTROL_DEFAULT, DORA_COORDINATOR_PORT_DEFAULT}, }; use dora_message::{ - cli_to_coordinator::ControlRequest, + cli_to_coordinator::{ControlRequest, StartRequest}, common::DaemonId, coordinator_to_cli::{ControlRequestReply, DataflowIdAndName}, }; @@ -147,14 +147,14 @@ async fn start_dataflow( let (reply_sender, reply) = oneshot::channel(); coordinator_events_tx .send(Event::Control(ControlEvent::IncomingRequest { - request: ControlRequest::Start { + request: ControlRequest::Start(StartRequest { build_id: dataflow_session.build_id, session_id: dataflow_session.session_id, dataflow: dataflow_descriptor, local_working_dir: Some(working_dir), name: None, uv: false, - }, + }), reply_sender, })) .await?; diff --git a/examples/openai-realtime/README.md b/examples/openai-realtime/README.md new file mode 100644 index 000000000..e4d872064 --- /dev/null +++ b/examples/openai-realtime/README.md @@ -0,0 +1,78 @@ +# Dora-OpenAI-Realtime (ROOT Repo) + +## Front End + +### Build Client + +```bash +git clone git@github.com:haixuanTao/moly.git --branch dora-backend-support +cd moly +cargo build --release +``` + +### Run Client + +```bash +cd moly +cargo run -r +``` + +## Server + +### Build server + +```bash +uv venv --seed -p 3.11 +source .venv/bin/activate +uv pip install dora-rs-cli dora-rs +dora build whisper-template-metal.yml --uv ## very long process +``` + +### Run server + +```bash +source .venv/bin/activate +dora up +cargo run --release -p dora-openai-websocket +``` + +## On finish + +```bash +dora destroy +``` + +## GUI + +- 
Go to MolyServer Tab +- Add a custom Provider +- In API Host, use: + + - Name: dora-websocket + - API Host: ws://127.0.0.1:8123 + - Type: OpenAI Realtime + +- Then go to Chat Tab +- New Chat +- ( Make sure the servver is running with: `cargo run --release -p dora-openai-websocket`) +- On bottom right, click on 🎧 icon. + > If nothing happen is that the server is not found. +- Click on start +- Wait for the first AI greeting +- Start speaking! +- You should get AI response! + +### WIP: Moyoyo + +## {Recommended} Install git-lfs + +```bash +brew install git-lfs # MacOS +``` + +## Clone Moxin Voice Chat + +```bash +git lfs install +git clone https://github.com/moxin-org/moxin-voice-chat.git +``` diff --git a/examples/openai-realtime/whisper-template-metal.yml b/examples/openai-realtime/whisper-template-metal.yml new file mode 100644 index 000000000..6a438a615 --- /dev/null +++ b/examples/openai-realtime/whisper-template-metal.yml @@ -0,0 +1,54 @@ +nodes: + - id: NODE_ID + path: dynamic + inputs: + audio: tts/audio + text: stt/text + speech_started: stt/speech_started + outputs: + - audio + - text + + - id: dora-vad + build: pip install -e ../../node-hub/dora-vad + path: dora-vad + inputs: + audio: + source: NODE_ID/audio + queue_size: 1000000 + outputs: + - audio + env: + MIN_SPEECH_DURATION_MS: 1000 + MIN_SILENCE_DURATION_MS: 1000 + THRESHOLD: 0.5 + + - id: stt + build: pip install -e ../../node-hub/dora-distil-whisper + path: dora-distil-whisper + inputs: + audio: dora-vad/audio + outputs: + - text + - word + - speech_started + + - id: llm + build: pip install -e ../../node-hub/dora-qwen + path: dora-qwen + inputs: + text: stt/text + text_to_audio: NODE_ID/text + outputs: + - text + env: + MODEL_NAME_OR_PATH: LLM_ID + MODEL_FILE_PATTERN: "*[qQ]6_[kK].[gG][gG][uU][fF]" + + - id: tts + build: pip install -e ../../node-hub/dora-kokoro-tts + path: dora-kokoro-tts + inputs: + text: llm/text + outputs: + - audio diff --git a/libraries/core/Cargo.toml 
b/libraries/core/Cargo.toml index fca59d4cf..6ae5e073e 100644 --- a/libraries/core/Cargo.toml +++ b/libraries/core/Cargo.toml @@ -32,3 +32,4 @@ itertools = "0.14" url = { version = "2.5.4", optional = true } git2 = { workspace = true, optional = true } fs_extra = "1.3.0" +dora-download = { workspace = true } diff --git a/libraries/core/src/lib.rs b/libraries/core/src/lib.rs index c45ec6137..d6be8feca 100644 --- a/libraries/core/src/lib.rs +++ b/libraries/core/src/lib.rs @@ -1,16 +1,24 @@ +use dora_download::download_file; use eyre::{bail, eyre, Context}; use std::{ - env::consts::{DLL_PREFIX, DLL_SUFFIX}, + env::{ + consts::{DLL_PREFIX, DLL_SUFFIX}, + current_dir, + }, ffi::OsStr, - path::Path, + path::{Path, PathBuf}, }; pub use dora_message::{config, uhlc}; +use crate::descriptor::source_is_url; + #[cfg(feature = "build")] pub mod build; pub mod descriptor; pub mod metadata; +#[cfg(feature = "build")] +pub mod session; pub mod topics; pub fn adjust_shared_library_path(path: &Path) -> Result { @@ -80,3 +88,19 @@ where }; Ok(()) } + +pub fn resolve_dataflow(dataflow: String) -> eyre::Result { + let dataflow = if source_is_url(&dataflow) { + // try to download the shared library + let target_path = current_dir().context("Could not access the current dir")?; + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .context("tokio runtime failed")?; + rt.block_on(async { download_file(&dataflow, &target_path).await }) + .wrap_err("failed to download dataflow yaml file")? 
+ } else { + PathBuf::from(dataflow) + }; + Ok(dataflow) +} diff --git a/binaries/cli/src/session.rs b/libraries/core/src/session.rs similarity index 99% rename from binaries/cli/src/session.rs rename to libraries/core/src/session.rs index 9a8ac5b8f..03ab580eb 100644 --- a/binaries/cli/src/session.rs +++ b/libraries/core/src/session.rs @@ -3,7 +3,7 @@ use std::{ path::{Path, PathBuf}, }; -use dora_core::build::BuildInfo; +use crate::build::BuildInfo; use dora_message::{common::GitSource, id::NodeId, BuildId, SessionId}; use eyre::{Context, ContextCompat}; diff --git a/libraries/message/src/cli_to_coordinator.rs b/libraries/message/src/cli_to_coordinator.rs index bf3d3a039..aaa74f6c1 100644 --- a/libraries/message/src/cli_to_coordinator.rs +++ b/libraries/message/src/cli_to_coordinator.rs @@ -29,21 +29,7 @@ pub enum ControlRequest { WaitForBuild { build_id: BuildId, }, - Start { - build_id: Option, - session_id: SessionId, - dataflow: Descriptor, - name: Option, - /// Allows overwriting the base working dir when CLI and daemon are - /// running on the same machine. - /// - /// Must not be used for multi-machine dataflows. - /// - /// Note that nodes with git sources still use a subdirectory of - /// the base working dir. - local_working_dir: Option, - uv: bool, - }, + Start(StartRequest), WaitForSpawn { dataflow_id: Uuid, }, @@ -82,3 +68,20 @@ pub enum ControlRequest { }, CliAndDefaultDaemonOnSameMachine, } + +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub struct StartRequest { + pub build_id: Option, + pub session_id: SessionId, + pub dataflow: Descriptor, + pub name: Option, + /// Allows overwriting the base working dir when CLI and daemon are + /// running on the same machine. + /// + /// Must not be used for multi-machine dataflows. + /// + /// Note that nodes with git sources still use a subdirectory of + /// the base working dir. 
+ pub local_working_dir: Option, + pub uv: bool, +} diff --git a/libraries/message/src/daemon_to_coordinator.rs b/libraries/message/src/daemon_to_coordinator.rs index ccafb0a5f..00bf0005c 100644 --- a/libraries/message/src/daemon_to_coordinator.rs +++ b/libraries/message/src/daemon_to_coordinator.rs @@ -4,7 +4,8 @@ pub use crate::common::{ DataMessage, LogLevel, LogMessage, NodeError, NodeErrorCause, NodeExitStatus, Timestamped, }; use crate::{ - common::DaemonId, current_crate_version, id::NodeId, versions_compatible, BuildId, DataflowId, + cli_to_coordinator::StartRequest, common::DaemonId, current_crate_version, id::NodeId, + versions_compatible, BuildId, DataflowId, }; #[derive(Debug, serde::Serialize, serde::Deserialize)] @@ -14,6 +15,7 @@ pub enum CoordinatorRequest { daemon_id: DaemonId, event: DaemonEvent, }, + StartDataflow(StartRequest), } #[derive(Debug, serde::Serialize, serde::Deserialize)] diff --git a/libraries/message/src/daemon_to_node.rs b/libraries/message/src/daemon_to_node.rs index 7d520f427..0d837a966 100644 --- a/libraries/message/src/daemon_to_node.rs +++ b/libraries/message/src/daemon_to_node.rs @@ -54,6 +54,7 @@ pub enum DaemonReply { NextDropEvents(Vec>), NodeConfig { result: Result }, Empty, + StartDataflowResult(Result), } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] diff --git a/libraries/message/src/node_to_daemon.rs b/libraries/message/src/node_to_daemon.rs index bb5a0850c..79f887bbe 100644 --- a/libraries/message/src/node_to_daemon.rs +++ b/libraries/message/src/node_to_daemon.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + pub use crate::common::{ DataMessage, DropToken, LogLevel, LogMessage, SharedMemoryId, Timestamped, }; @@ -33,6 +35,11 @@ pub enum DaemonRequest { NodeConfig { node_id: NodeId, }, + StartDataflow { + dataflow: String, + name: Option, + uv: bool, + }, } impl DaemonRequest { @@ -41,7 +48,8 @@ impl DaemonRequest { match self { DaemonRequest::SendMessage { .. } | DaemonRequest::NodeConfig { .. 
} - | DaemonRequest::ReportDropTokens { .. } => false, + | DaemonRequest::ReportDropTokens { .. } + | DaemonRequest::StartDataflow { .. } => false, DaemonRequest::Register(NodeRegisterRequest { .. }) | DaemonRequest::Subscribe | DaemonRequest::CloseOutputs(_) @@ -56,7 +64,7 @@ impl DaemonRequest { pub fn expects_tcp_json_reply(&self) -> bool { #[allow(clippy::match_like_matches_macro)] match self { - DaemonRequest::NodeConfig { .. } => true, + DaemonRequest::NodeConfig { .. } | Self::StartDataflow { .. } => true, DaemonRequest::Register(NodeRegisterRequest { .. }) | DaemonRequest::Subscribe | DaemonRequest::CloseOutputs(_) diff --git a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py index 007b3b43d..3c31e291c 100644 --- a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py +++ b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py @@ -6,6 +6,7 @@ import time from pathlib import Path +import numpy as np import pyarrow as pa import torch from dora import Node @@ -47,35 +48,6 @@ def normalize(s): text_words = normalized_text.split() noise_words = normalized_noise.split() - # Function to find and remove noise sequence flexibly - def remove_flexible(text_list, noise_list): - i = 0 - while i <= len(text_list) - len(noise_list): - match = True - extra_words = 0 - for j, noise_word in enumerate(noise_list): - if i + j + extra_words >= len(text_list): - match = False - break - # Allow skipping extra words in text_list - while ( - i + j + extra_words < len(text_list) - and text_list[i + j + extra_words] != noise_word - ): - extra_words += 1 - if i + j + extra_words >= len(text_list): - match = False - break - if not match: - break - if match: - # Remove matched part - del text_list[i : i + len(noise_list) + extra_words] - i = max(0, i - len(noise_list)) # Adjust index after removal - else: - i += 1 - return text_list - # Only remove parts of text_noise that are found in text cleaned_words = 
text_words[:] for noise_word in noise_words: @@ -124,7 +96,30 @@ def load_model(): BAD_SENTENCES = [ "", " so", + " So.", + " So, let's go.", " so so", + " What?", + " We'll see you next time.", + " I'll see you next time.", + " We're going to come back.", + " let's move on.", + " Here we go.", + " my", + " All right. Thank you.", + " That's what we're doing.", + " That's what I wanted to do.", + " I'll be back.", + " Hold this. Hold this.", + " Hold this one. Hold this one.", + " And we'll see you next time.", + " strength.", + " Length.", + " Let's go.", + " Let's do it.", + "You", + "You ", + " You", "字幕", "字幕志愿", "中文字幕", @@ -181,13 +176,29 @@ def cut_repetition(text, min_repeat_length=4, max_repeat_length=50): def main(): """TODO: Add docstring.""" - node = Node() text_noise = "" - noise_timestamp = time.time() # For macos use mlx: if sys.platform != "darwin": pipe = load_model() + else: + import mlx_whisper + result = mlx_whisper.transcribe( + np.array([]), + path_or_hf_repo="mlx-community/whisper-large-v3-turbo", + append_punctuations=".", + language=TARGET_LANGUAGE, + ) + result = mlx_whisper.transcribe( + np.array([]), + path_or_hf_repo="mlx-community/whisper-large-v3-turbo", + append_punctuations=".", + language=TARGET_LANGUAGE, + ) + + node = Node() + noise_timestamp = time.time() + cache_audio = None for event in node: if event["type"] == "INPUT": if "text_noise" in event["id"]: @@ -200,7 +211,12 @@ def main(): ) noise_timestamp = time.time() else: - audio = event["value"].to_numpy() + audio_input = event["value"].to_numpy() + if cache_audio is not None: + audio = np.concatenate([cache_audio, audio_input]) + else: + audio = audio_input + confg = ( {"language": TARGET_LANGUAGE, "task": "translate"} if TRANSLATE @@ -215,6 +231,7 @@ def main(): audio, path_or_hf_repo="mlx-community/whisper-large-v3-turbo", append_punctuations=".", + language=TARGET_LANGUAGE, ) else: @@ -223,6 +240,8 @@ def main(): generate_kwargs=confg, ) if result["text"] in BAD_SENTENCES: 
+ print("Discarded text: ", result["text"]) + # cache_audio = None continue text = cut_repetition(result["text"]) @@ -235,6 +254,29 @@ def main(): if text.strip() == "" or text.strip() == ".": continue - node.send_output( - "text", pa.array([text]), {"language": TARGET_LANGUAGE}, - ) + + if ( + text.endswith(".") + or text.endswith("!") + or text.endswith("?") + or text.endswith('."') + or text.endswith('!"') + or text.endswith('?"') + ) and not text.endswith("..."): + node.send_output( + "text", + pa.array([text]), + ) + node.send_output( + "speech_started", + pa.array([text]), + ) + cache_audio = None + audio = None + print("Text:", text) + elif text.endswith("..."): + print( + "Keeping audio in cache for next text output with punctuation" + ) + print("Discarded text", text) + cache_audio = audio diff --git a/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py b/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py index 7762cfca1..d1c28a4f1 100644 --- a/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py +++ b/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py @@ -1,4 +1,5 @@ """TODO: Add docstring.""" + import os import re @@ -8,11 +9,12 @@ LANGUAGE = os.getenv("LANGUAGE", "en") + def main(): """TODO: Add docstring.""" if LANGUAGE in ["en", "english"]: pipeline = KPipeline(lang_code="a") - elif LANGUAGE in ["zh","ch","chinese"]: + elif LANGUAGE in ["zh", "ch", "chinese"]: pipeline = KPipeline(lang_code="z") else: print("warning: Defaulting to english speaker as language not found") @@ -22,22 +24,23 @@ def main(): for event in node: if event["type"] == "INPUT": - if event["id"] == "text": - text = event["value"][0].as_py() - if re.findall(r'[\u4e00-\u9fff]+', text): - pipeline = KPipeline(lang_code="z") - elif pipeline.lang_code != "a": - pipeline = KPipeline(lang_code="a") # <= make sure lang_code matches voice - - generator = pipeline( - text, - voice="af_heart", # <= change voice here - speed=1.2, - split_pattern=r"\n+", - ) - for _, (_, _, audio) in enumerate(generator): 
- audio = audio.numpy() - node.send_output("audio", pa.array(audio), {"sample_rate": 24000}) + text = event["value"][0].as_py() + if re.findall(r"[\u4e00-\u9fff]+", text): + pipeline = KPipeline(lang_code="z") + elif pipeline.lang_code != "a": + pipeline = KPipeline( + lang_code="a" + ) # <= make sure lang_code matches voice + + generator = pipeline( + text, + voice="af_heart", # <= change voice here + speed=1.2, + split_pattern=r"\n+", + ) + for _, (_, _, audio) in enumerate(generator): + audio = audio.numpy() + node.send_output("audio", pa.array(audio), {"sample_rate": 24000}) if __name__ == "__main__": diff --git a/node-hub/dora-kokoro-tts/pyproject.toml b/node-hub/dora-kokoro-tts/pyproject.toml index 14abf6e79..24813a406 100644 --- a/node-hub/dora-kokoro-tts/pyproject.toml +++ b/node-hub/dora-kokoro-tts/pyproject.toml @@ -5,13 +5,14 @@ authors = [{ name = "Your Name", email = "email@email.com" }] description = "dora-kokoro-tts" license = { text = "MIT" } readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10" dependencies = [ "dora-rs >= 0.3.9", "kokoro>=0.2.2", "soundfile>=0.13.1", "misaki[zh]", + "en-core-web-sm", ] [dependency-groups] @@ -31,3 +32,6 @@ extend-select = [ "N", # Ruff's N rule "I", # Ruff's I rule ] + +[tool.uv.sources] +en-core-web-sm = { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl" } diff --git a/node-hub/dora-openai-websocket/Cargo.toml b/node-hub/dora-openai-websocket/Cargo.toml new file mode 100644 index 000000000..d106bd3f9 --- /dev/null +++ b/node-hub/dora-openai-websocket/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "dora-openai-websocket" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +dora-node-api = { workspace = true } +tokio = { version = "1.25.0", features = ["full", "macros"] } +tokio-rustls = "0.24.0" 
+rustls-pemfile = "1.0" +hyper-util = { version = "0.1.0", features = ["tokio"] } +http-body-util = { version = "0.1.0" } +hyper = { version = "1", features = ["http1", "server", "client"] } +assert2 = "0.3.4" +trybuild = "1.0.106" +criterion = "0.4.0" +anyhow = "1.0.71" +webpki-roots = "0.23.0" +bytes = "1.4.0" +axum = "0.8.1" +fastwebsockets = { version = "0.10.0", features = ["upgrade"] } +serde_json = "1.0.141" +serde = "1.0.219" +base = "0.1.0" +base64 = "0.22.1" +rand = "0.9.2" +futures-util = "0.3.31" +futures-concurrency = "7.6.3" diff --git a/node-hub/dora-openai-websocket/src/main.rs b/node-hub/dora-openai-websocket/src/main.rs new file mode 100644 index 000000000..47cf64001 --- /dev/null +++ b/node-hub/dora-openai-websocket/src/main.rs @@ -0,0 +1,498 @@ +// Copyright 2023 Divy Srivastava +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use base64::engine::general_purpose; +use base64::Engine; +use dora_node_api::arrow::array::AsArray; +use dora_node_api::arrow::datatypes::DataType; +use dora_node_api::dora_core::config::DataId; +use dora_node_api::dora_core::config::NodeId; +use dora_node_api::into_vec; +use dora_node_api::requests; +use dora_node_api::DoraNode; +use dora_node_api::IntoArrow; +use dora_node_api::MetadataParameters; +use fastwebsockets::upgrade; +use fastwebsockets::Frame; +use fastwebsockets::OpCode; +use fastwebsockets::Payload; +use fastwebsockets::WebSocketError; +use futures_concurrency::future::Race; +use futures_util::future; +use futures_util::future::Either; +use futures_util::FutureExt; +use http_body_util::Empty; +use hyper::body::Bytes; +use hyper::body::Incoming; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::Request; +use hyper::Response; +use rand::random; +use serde; +use serde::Deserialize; +use serde::Serialize; +use std::collections::HashMap; +use std::fs; +use std::io::{self, Write}; +use tokio::net::TcpListener; +#[derive(Serialize, Deserialize, Debug)] +pub struct ErrorDetails { + pub code: Option, + pub message: String, + pub param: Option, + #[serde(rename = "type")] + pub error_type: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "type")] +pub enum OpenAIRealtimeMessage { + #[serde(rename = "session.update")] + SessionUpdate { session: SessionConfig }, + #[serde(rename = "input_audio_buffer.append")] + InputAudioBufferAppend { + audio: String, // base64 encoded audio + }, + #[serde(rename = "input_audio_buffer.commit")] + InputAudioBufferCommit, + #[serde(rename = "response.create")] + ResponseCreate { response: ResponseConfig }, + #[serde(rename = "conversation.item.create")] + ConversationItemCreate { item: ConversationItem }, + #[serde(rename = "conversation.item.truncate")] + ConversationItemTruncate { + item_id: String, + content_index: u32, + audio_end_ms: u32, + #[serde(skip_serializing_if = 
"Option::is_none")] + event_id: Option, + }, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SessionConfig { + pub modalities: Vec, + pub instructions: String, + pub voice: String, + pub model: String, + pub input_audio_format: String, + pub output_audio_format: String, + pub input_audio_transcription: Option, + pub turn_detection: Option, + pub tools: Vec, + pub tool_choice: String, + pub temperature: f32, + pub max_response_output_tokens: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct TranscriptionConfig { + pub model: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct TurnDetectionConfig { + #[serde(rename = "type")] + pub detection_type: String, + pub threshold: f32, + pub prefix_padding_ms: u32, + pub silence_duration_ms: u32, + pub interrupt_response: bool, + pub create_response: bool, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ResponseConfig { + pub modalities: Vec, + pub instructions: Option, + pub voice: Option, + pub output_audio_format: Option, + pub tools: Option>, + pub tool_choice: Option, + pub temperature: Option, + pub max_output_tokens: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ConversationItem { + pub id: Option, + #[serde(rename = "type")] + pub item_type: String, + pub status: Option, + pub role: String, + pub content: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "input_text")] + InputText { text: String }, + #[serde(rename = "input_audio")] + InputAudio { + audio: String, + transcript: Option, + }, + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "audio")] + Audio { + audio: String, + transcript: Option, + }, +} + +// Incoming message types from OpenAI +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "type")] +pub enum OpenAIRealtimeResponse { + #[serde(rename = "error")] + Error { error: ErrorDetails }, + #[serde(rename = "session.created")] + 
SessionCreated { session: serde_json::Value }, + #[serde(rename = "session.updated")] + SessionUpdated { session: serde_json::Value }, + #[serde(rename = "conversation.item.created")] + ConversationItemCreated { item: serde_json::Value }, + #[serde(rename = "conversation.item.truncated")] + ConversationItemTruncated { item: serde_json::Value }, + #[serde(rename = "response.audio.delta")] + ResponseAudioDelta { + response_id: String, + item_id: String, + output_index: u32, + content_index: u32, + delta: String, // base64 encoded audio + }, + #[serde(rename = "response.audio.done")] + ResponseAudioDone { + response_id: String, + item_id: String, + output_index: u32, + content_index: u32, + }, + #[serde(rename = "response.text.delta")] + ResponseTextDelta { + response_id: String, + item_id: String, + output_index: u32, + content_index: u32, + delta: String, + }, + #[serde(rename = "response.audio_transcript.delta")] + ResponseAudioTranscriptDelta { + response_id: String, + item_id: String, + output_index: u32, + content_index: u32, + delta: String, + }, + #[serde(rename = "response.done")] + ResponseDone { response: serde_json::Value }, + #[serde(rename = "input_audio_buffer.speech_started")] + InputAudioBufferSpeechStarted { + audio_start_ms: u32, + item_id: String, + }, + #[serde(rename = "input_audio_buffer.speech_stopped")] + InputAudioBufferSpeechStopped { audio_end_ms: u32, item_id: String }, + #[serde(other)] + Other, +} + +fn convert_pcm16_to_f32(bytes: &[u8]) -> Vec { + let mut samples = Vec::with_capacity(bytes.len() / 2); + + for chunk in bytes.chunks_exact(2) { + let pcm16_sample = i16::from_le_bytes([chunk[0], chunk[1]]); + let f32_sample = pcm16_sample as f32 / 32767.0; + samples.push(f32_sample); + } + + samples +} + +fn convert_f32_to_pcm16(samples: &[f32]) -> Vec { + let mut pcm16_bytes = Vec::with_capacity(samples.len() * 2); + + for &sample in samples { + // Clamp to [-1.0, 1.0] and convert to i16 + let clamped = sample.max(-1.0).min(1.0); + let 
pcm16_sample = (clamped * 32767.0) as i16; + pcm16_bytes.extend_from_slice(&pcm16_sample.to_le_bytes()); + } + + pcm16_bytes +} + +/// Replaces a placeholder in a file and writes the result to an output file. +/// +/// # Arguments +/// +/// * `input_path` - Path to the input file with placeholder text. +/// * `placeholder` - The placeholder text to search for (e.g., "{{PLACEHOLDER}}"). +/// * `replacement` - The text to replace the placeholder with. +/// * `output_path` - Path to write the modified content. +fn replace_placeholder_in_file( + input_path: &str, + replacement: &HashMap, + output_path: &str, +) -> io::Result<()> { + // Read the file content into a string + let mut content = fs::read_to_string(input_path)?; + + // Replace the placeholder + for (placeholder, replacement) in replacement { + // Ensure the placeholder is wrapped in curly braces + // Replace the placeholder with the replacement text + content = content.replace(placeholder, replacement); + } + + // Write the modified content to the output file + let mut file = fs::File::create(output_path)?; + file.write_all(content.as_bytes())?; + + Ok(()) +} + +async fn handle_client(fut: upgrade::UpgradeFut) -> Result<(), WebSocketError> { + let mut ws = fastwebsockets::FragmentCollector::new(fut.await?); + + let frame = ws.read_frame().await?; + if frame.opcode != OpCode::Text { + return Err(WebSocketError::InvalidConnectionHeader); + } + let data: OpenAIRealtimeMessage = serde_json::from_slice(&frame.payload).unwrap(); + let OpenAIRealtimeMessage::SessionUpdate { session } = data else { + return Err(WebSocketError::InvalidConnectionHeader); + }; + + let input_audio_transcription = session + .input_audio_transcription + .map_or("moyoyo-whisper".to_string(), |t| t.model); + let llm = session.model.clone(); + let id = random::(); + let node_id = format!("server-{id}"); + let dataflow = format!("{input_audio_transcription}-{}.yml", id); + let template = 
format!("{input_audio_transcription}-template-metal.yml"); + let mut replacements = HashMap::new(); + replacements.insert("NODE_ID".to_string(), node_id.clone()); + replacements.insert("LLM_ID".to_string(), llm); + println!("Filling template: {}", template); + replace_placeholder_in_file(&template, &replacements, &dataflow).unwrap(); + // Copy configuration file but replace the node ID with "server-id" + // Read the configuration file and replace the node ID with "server-id" + requests::start_dataflow(dataflow, Some(node_id.to_string()), true).unwrap(); + let (mut node, mut events) = + DoraNode::init_from_node_id(NodeId::from(node_id.clone())).unwrap(); + let serialized_data = OpenAIRealtimeResponse::SessionCreated { + session: serde_json::Value::Null, + }; + + let payload = + Payload::Bytes(Bytes::from(serde_json::to_string(&serialized_data).unwrap()).into()); + let frame = Frame::text(payload); + ws.write_frame(frame).await?; + loop { + let event_fut = events.recv_async().map(Either::Left); + let frame_fut = ws.read_frame().map(Either::Right); + let event_stream = (event_fut, frame_fut).race(); + let mut finished = false; + let frame = match event_stream.await { + future::Either::Left(Some(ev)) => { + let frame = match ev { + dora_node_api::Event::Input { + id, + metadata: _, + data, + } => { + if data.data_type() == &DataType::Utf8 { + let data = data.as_string::(); + let str = data.value(0); + let serialized_data = + OpenAIRealtimeResponse::ResponseAudioTranscriptDelta { + response_id: "123".to_string(), + item_id: "123".to_string(), + output_index: 123, + content_index: 123, + delta: str.to_string(), + }; + + let frame = Frame::text(Payload::Bytes( + Bytes::from(serde_json::to_string(&serialized_data).unwrap()) + .into(), + )); + frame + } else if id.contains("audio") { + let data: Vec = into_vec(&data).unwrap(); + let data = convert_f32_to_pcm16(&data); + let serialized_data = OpenAIRealtimeResponse::ResponseAudioDelta { + response_id: "123".to_string(), + 
item_id: "123".to_string(), + output_index: 123, + content_index: 123, + delta: general_purpose::STANDARD.encode(data), + }; + finished = true; + + let frame = Frame::text(Payload::Bytes( + Bytes::from(serde_json::to_string(&serialized_data).unwrap()) + .into(), + )); + frame + } else if id.contains("") { + let serialized_data = + OpenAIRealtimeResponse::InputAudioBufferSpeechStarted { + audio_start_ms: 123, + item_id: "123".to_string(), + }; + finished = true; + + let frame = Frame::text(Payload::Bytes( + Bytes::from(serde_json::to_string(&serialized_data).unwrap()) + .into(), + )); + frame + } else { + unimplemented!() + } + } + dora_node_api::Event::Error(_) => { + // println!("Error in input: {}", s); + continue; + } + _ => break, + }; + Some(frame) + } + future::Either::Left(None) => break, + future::Either::Right(Ok(frame)) => { + match frame.opcode { + OpCode::Close => break, + OpCode::Text | OpCode::Binary => { + let data: OpenAIRealtimeMessage = + serde_json::from_slice(&frame.payload).unwrap(); + + match data { + OpenAIRealtimeMessage::InputAudioBufferAppend { audio } => { + // println!("Received audio data: {}", audio); + let f32_data = audio; + // Decode base64 encoded audio data + let f32_data = f32_data.trim(); + if f32_data.is_empty() { + continue; + } + + if let Ok(f32_data) = general_purpose::STANDARD.decode(f32_data) { + let f32_data = convert_pcm16_to_f32(&f32_data); + // Downsample to 16 kHz from 24 kHz + let f32_data = f32_data + .into_iter() + .enumerate() + .filter(|(i, _)| i % 3 != 0) + .map(|(_, v)| v) + .collect::>(); + + let mut parameter = MetadataParameters::default(); + parameter.insert( + "sample_rate".to_string(), + dora_node_api::Parameter::Integer(16000), + ); + node.send_output( + DataId::from("audio".to_string()), + parameter, + f32_data.into_arrow(), + ) + .unwrap(); + } + } + OpenAIRealtimeMessage::InputAudioBufferCommit => break, + OpenAIRealtimeMessage::ResponseCreate { response } => { + if let Some(text) = 
response.instructions { + node.send_output( + DataId::from("text".to_string()), + Default::default(), + text.into_arrow(), + ) + .unwrap(); + } + } + _ => {} + } + } + _ => break, + } + None + } + future::Either::Right(Err(_)) => break, + }; + if let Some(frame) = frame { + ws.write_frame(frame).await?; + } + if finished { + let serialized_data = OpenAIRealtimeResponse::ResponseDone { + response: serde_json::Value::Null, + }; + + let payload = Payload::Bytes( + Bytes::from(serde_json::to_string(&serialized_data).unwrap()).into(), + ); + println!("Sending response done: {:?}", serialized_data); + let frame = Frame::text(payload); + ws.write_frame(frame).await?; + }; + } + + Ok(()) +} +async fn server_upgrade( + mut req: Request, +) -> Result>, WebSocketError> { + let (response, fut) = upgrade::upgrade(&mut req)?; + + tokio::task::spawn(async move { + if let Err(e) = tokio::task::unconstrained(handle_client(fut)).await { + eprintln!("Error in websocket connection: {}", e); + } + }); + + Ok(response) +} + +fn main() -> Result<(), WebSocketError> { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + rt.block_on(async move { + let listener = TcpListener::bind("127.0.0.1:8123").await?; + println!("Server started, listening on {}", "127.0.0.1:8123"); + loop { + let (stream, _) = listener.accept().await?; + println!("Client connected"); + tokio::spawn(async move { + let io = hyper_util::rt::TokioIo::new(stream); + let conn_fut = http1::Builder::new() + .serve_connection(io, service_fn(server_upgrade)) + .with_upgrades(); + if let Err(e) = conn_fut.await { + println!("An error occurred: {:?}", e); + } + }); + } + }) +} diff --git a/node-hub/dora-qwen/dora_qwen/main.py b/node-hub/dora-qwen/dora_qwen/main.py index 957abf429..3d9618216 100644 --- a/node-hub/dora-qwen/dora_qwen/main.py +++ b/node-hub/dora-qwen/dora_qwen/main.py @@ -1,7 +1,6 @@ """TODO: Add docstring.""" import os -import sys import pyarrow as pa 
from dora import Node @@ -12,14 +11,24 @@ "You're a very succinct AI assistant with short answers.", ) +MODEL_NAME_OR_PATH = os.getenv("MODEL_NAME_OR_PATH", "Qwen/Qwen2.5-0.5B-Instruct-GGUF") +MODEL_FILE_PATTERN = os.getenv("MODEL_FILE_PATTERN", "*fp16.gguf") +MAX_TOKENS = int(os.getenv("MAX_TOKENS", "512")) +N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0")) +N_THREADS = int(os.getenv("N_THREADS", "4")) +CONTEXT_SIZE = int(os.getenv("CONTEXT_SIZE", "4096")) + def get_model_gguf(): """TODO: Add docstring.""" from llama_cpp import Llama return Llama.from_pretrained( - repo_id="Qwen/Qwen2.5-0.5B-Instruct-GGUF", - filename="*fp16.gguf", + repo_id=MODEL_NAME_OR_PATH, + filename=MODEL_FILE_PATTERN, + n_gpu_layers=N_GPU_LAYERS, + n_ctx=CONTEXT_SIZE, + n_threads=N_THREADS, verbose=False, ) @@ -71,12 +80,7 @@ def main(): """TODO: Add docstring.""" history = [] # If OS is not Darwin, use Huggingface model - if sys.platform == "darwin": - model = get_model_gguf() - elif sys.platform == "linux": - model, tokenizer = get_model_huggingface() - else: - model, tokenizer = get_model_darwin() + model = get_model_gguf() node = Node() @@ -89,24 +93,14 @@ def main(): if len(ACTIVATION_WORDS) == 0 or any( word in ACTIVATION_WORDS for word in words ): - # On linux, Windows - if sys.platform == "darwin": - response = model.create_chat_completion( - messages=[{"role": "user", "content": text}], # Prompt - max_tokens=24, - )["choices"][0]["message"]["content"] - elif sys.platform == "linux": - response, history = generate_hf(model, tokenizer, text, history) - else: - from mlx_lm import generate - - response = generate( - model, - tokenizer, - prompt=text, - verbose=False, - max_tokens=50, - ) + history += [{"role": "user", "content": text}] + + response = model.create_chat_completion( + messages=history, # Prompt + max_tokens=24, + )["choices"][0]["message"]["content"] + + history += [{"role": "assistant", "content": response}] node.send_output( output_id="text", diff --git 
a/node-hub/dora-vad/dora_vad/main.py b/node-hub/dora-vad/dora_vad/main.py index 9a674ddb9..11f5f7b8b 100644 --- a/node-hub/dora-vad/dora_vad/main.py +++ b/node-hub/dora-vad/dora_vad/main.py @@ -36,21 +36,32 @@ def main(): threshold=THRESHOLD, min_speech_duration_ms=MIN_SPEECH_DURATION_MS, min_silence_duration_ms=MIN_SILENCE_DURATION_MS, + sampling_rate=sr, ) - + if len(speech_timestamps) == 0: + # If there is no speech, skip this chunk + continue + arg_max = np.argmax([ts["end"] - ts["start"] for ts in speech_timestamps]) # Check ig there is timestamp if ( len(speech_timestamps) > 0 - and len(audio) > MIN_AUDIO_SAMPLING_DURATION_MS * sr / 1000 + and len( + audio[speech_timestamps[0]["start"] : speech_timestamps[-1]["end"]] + ) + > MIN_AUDIO_SAMPLING_DURATION_MS * sr / 1000 + and ( + (len(audio) - speech_timestamps[arg_max]["end"]) + > MIN_SILENCE_DURATION_MS / 1000 * sr * 5 + ) ): # Check if the audio is not cut at the end. And only return if there is a long time spent if speech_timestamps[-1]["end"] == len(audio): node.send_output( "timestamp_start", pa.array([speech_timestamps[-1]["start"]]), + metadata={"sample_rate": sr}, ) - continue - audio = audio[0 : speech_timestamps[-1]["end"]] + audio = audio[: speech_timestamps[-1]["end"]] node.send_output("audio", pa.array(audio), metadata={"sample_rate": sr}) last_audios = [audio[speech_timestamps[-1]["end"] :]]