From e16ba27ad601f94fa0994f5e7ff8addcc8f03ba9 Mon Sep 17 00:00:00 2001 From: "RingOfStorms (Joshua Bell)" Date: Wed, 14 Jan 2026 12:14:24 -0600 Subject: [PATCH] Add HIP (ROCm) stt-stream build and update whisper-rs --- flakes/stt_ime/flake.nix | 134 ++++++++++++++++++++++---- flakes/stt_ime/stt-stream/Cargo.lock | 90 +++-------------- flakes/stt_ime/stt-stream/Cargo.toml | 4 +- flakes/stt_ime/stt-stream/src/main.rs | 79 +++++++++++++-- hosts/lio/flake.nix | 1 + 5 files changed, 204 insertions(+), 104 deletions(-) diff --git a/flakes/stt_ime/flake.nix b/flakes/stt_ime/flake.nix index 0e6ae14d..016f5b37 100644 --- a/flakes/stt_ime/flake.nix +++ b/flakes/stt_ime/flake.nix @@ -28,31 +28,70 @@ pkgs = nixpkgs.legacyPackages.${system}; craneLib = crane.mkLib pkgs; - # Rust STT streaming CLI + # Common build inputs for stt-stream + commonNativeBuildInputs = with pkgs; [ + pkg-config + cmake + git # required by whisper-rs-sys build + ]; + + commonBuildInputs = with pkgs; [ + alsa-lib + openssl + ]; + + # CPU-only build (default) stt-stream = craneLib.buildPackage { pname = "stt-stream"; version = "0.1.0"; src = craneLib.cleanCargoSource ./stt-stream; - nativeBuildInputs = with pkgs; [ - pkg-config - cmake # for whisper-rs + nativeBuildInputs = commonNativeBuildInputs ++ (with pkgs; [ clang llvmPackages.libclang - ]; + ]); - buildInputs = with pkgs; [ - alsa-lib - openssl - # whisper.cpp dependencies + buildInputs = commonBuildInputs ++ (with pkgs; [ openblas - ]; + ]); # For bindgen to find libclang LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib"; + }; - # Enable CUDA if available (user can override) - WHISPER_CUBLAS = "OFF"; + # GPU build with ROCm/HIP support (AMD GPUs) + stt-stream-hip = craneLib.buildPackage { + pname = "stt-stream-hip"; + version = "0.1.0"; + src = craneLib.cleanCargoSource ./stt-stream; + + nativeBuildInputs = commonNativeBuildInputs ++ (with pkgs; [ + # ROCm toolchain - clr contains the properly wrapped hipcc + rocmPackages.clr + # rocminfo provides rocm_agent_enumerator which hipcc needs + rocmPackages.rocminfo + ]); + + buildInputs = commonBuildInputs ++ (with pkgs; [ + # ROCm/HIP libraries needed at link time + rocmPackages.clr # HIP runtime + rocmPackages.hipblas + rocmPackages.rocblas + rocmPackages.rocm-runtime + rocmPackages.rocm-device-libs + rocmPackages.rocm-comgr + ]); + + # Enable hipblas feature + cargoExtraArgs = "--features hipblas"; + + # The clr package's hipcc is already wrapped with all the right paths, + # but we need LIBCLANG_PATH for bindgen + LIBCLANG_PATH = "${pkgs.rocmPackages.llvm.clang}/lib"; + + # Target common AMD GPU architectures (user can override via AMDGPU_TARGETS) + # gfx1030 = RX 6000 series, gfx1100 = RX 7000 series, gfx906/gfx908 = MI50/MI100 + AMDGPU_TARGETS = "gfx1030;gfx1100"; }; # Fcitx5 C++ shim addon @@ -82,10 +121,36 @@ mkdir -p $out/lib/fcitx5 ''; }; + # Fcitx5 addon variant using HIP-accelerated stt-stream + fcitx5-stt-hip = pkgs.stdenv.mkDerivation { + pname = "fcitx5-stt-hip"; + version = "0.1.0"; + src = ./fcitx5-stt; + + nativeBuildInputs = with pkgs; [ + cmake + extra-cmake-modules + pkg-config + ]; + + buildInputs = with pkgs; [ + fcitx5 + ]; + + cmakeFlags = [ + "-DSTT_STREAM_PATH=${stt-stream-hip}/bin/stt-stream" + ]; + + postInstall = '' + mkdir -p $out/share/fcitx5/addon + mkdir -p $out/share/fcitx5/inputmethod + mkdir -p $out/lib/fcitx5 + ''; + }; in { packages = { - inherit stt-stream fcitx5-stt; + inherit stt-stream stt-stream-hip fcitx5-stt fcitx5-stt-hip; default = fcitx5-stt; }; @@ -95,6 +160,10 @@ type = "app"; program = "${stt-stream}/bin/stt-stream"; }; + stt-stream-hip = { + type = "app"; + program = "${stt-stream-hip}/bin/stt-stream"; + }; default = { type = "app"; program = "${stt-stream}/bin/stt-stream"; @@ -110,6 +179,18 @@ fcitx5 ]; }; + + # Dev shell with ROCm/HIP for GPU development + devShells.hip = pkgs.mkShell { + inputsFrom = [ stt-stream-hip ]; + packages = with pkgs; [ + rust-analyzer + rustfmt + clippy + fcitx5 + rocmPackages.rocminfo # For debugging GPU detection + ]; + }; } ) // { @@ -124,6 +205,15 @@ let cfg = config.ringofstorms.sttIme; sttPkgs = self.packages.${pkgs.stdenv.hostPlatform.system}; + + # Select the appropriate package variant based on GPU backend + sttStreamPkg = + if cfg.gpuBackend == "hip" then sttPkgs.stt-stream-hip + else sttPkgs.stt-stream; + + fcitx5SttPkg = + if cfg.gpuBackend == "hip" then sttPkgs.fcitx5-stt-hip + else sttPkgs.fcitx5-stt; in { options.ringofstorms.sttIme = { @@ -135,16 +225,26 @@ description = "Whisper model to use (tiny, base, small, medium, large)"; }; + gpuBackend = lib.mkOption { + type = lib.types.enum [ "cpu" "hip" ]; + default = "cpu"; + description = '' + GPU backend to use for acceleration: + - cpu: CPU-only (default, works everywhere) + - hip: AMD ROCm/HIP (requires AMD GPU with ROCm support) + ''; + }; + useGpu = lib.mkOption { type = lib.types.bool; default = false; - description = "Whether to use GPU acceleration (CUDA)"; + description = "Whether to request GPU acceleration at runtime (--gpu flag)"; }; }; config = lib.mkIf cfg.enable { # Ensure fcitx5 addon is available - i18n.inputMethod.fcitx5.addons = [ sttPkgs.fcitx5-stt ]; + i18n.inputMethod.fcitx5.addons = [ fcitx5SttPkg ]; # Add STT to the Fcitx5 input method group # This assumes de_plasma sets up Groups/0 with keyboard-us (0) and mozc (1) @@ -153,12 +253,12 @@ }; # Make stt-stream available system-wide - environment.systemPackages = [ sttPkgs.stt-stream ]; + environment.systemPackages = [ sttStreamPkg ]; # Set default model via environment environment.sessionVariables = { STT_STREAM_MODEL = cfg.model; - STT_STREAM_USE_GPU = if cfg.useGpu then "1" else "0"; + STT_STREAM_GPU = if cfg.useGpu then "1" else "0"; }; }; }; diff --git a/flakes/stt_ime/stt-stream/Cargo.lock b/flakes/stt_ime/stt-stream/Cargo.lock index b25c4923..fa693b70 100644 --- a/flakes/stt_ime/stt-stream/Cargo.lock +++ b/flakes/stt_ime/stt-stream/Cargo.lock @@ -109,25 +109,22 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bindgen" -version = "0.69.5" +version = "0.71.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" dependencies = [ "bitflags 2.10.0", "cexpr", "clang-sys", - "itertools 0.12.1", - "lazy_static", - "lazycell", + "itertools", "log", "prettyplease", "proc-macro2", "quote", "regex", - "rustc-hash 1.1.0", + "rustc-hash", "shlex", "syn", - "which", ] [[package]] @@ -139,11 +136,11 @@ dependencies = [ "bitflags 2.10.0", "cexpr", "clang-sys", - "itertools 0.13.0", + "itertools", "proc-macro2", "quote", "regex", - "rustc-hash 2.1.1", + "rustc-hash", "shlex", "syn", ] @@ -538,15 +535,6 @@ dependencies = [ "ureq", ] -[[package]] -name = "home" -version = "0.5.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" -dependencies = [ - "windows-sys 0.61.2", -] - [[package]] name = "icu_collections" version = "2.1.1" @@ -678,15 +666,6 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -750,12 +729,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "libc" version = "0.2.180" @@ -782,12 +755,6 @@ dependencies = [ "libc", ] -[[package]] -name = "linux-raw-sys" -version = "0.4.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" - [[package]] name = "linux-raw-sys" version = "0.11.0" @@ -1313,12 +1280,6 @@ dependencies = [ "realfft", ] -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.1.1" @@ -1339,19 +1300,6 @@ dependencies = [ "transpose", ] -[[package]] -name = "rustix" -version = "0.38.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" -dependencies = [ - "bitflags 2.10.0", - "errno", - "libc", - "linux-raw-sys 0.4.15", - "windows-sys 0.59.0", -] - [[package]] name = "rustix" version = "1.1.3" @@ -1361,7 +1309,7 @@ dependencies = [ "bitflags 2.10.0", "errno", "libc", - "linux-raw-sys 0.11.0", + "linux-raw-sys", "windows-sys 0.61.2", ] @@ -1616,7 +1564,7 @@ dependencies = [ "fastrand", "getrandom 0.3.4", "once_cell", - "rustix 1.1.3", + "rustix", "windows-sys 0.61.2", ] @@ -1982,34 +1930,22 @@ dependencies = [ "rustls-pki-types", ] -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix 0.38.44", -] - [[package]] name = "whisper-rs" -version = "0.12.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c597ac8a9d5c4719fee232abc871da184ea50a4fea38d2d00348fd95072b2b0" +checksum = "71ea5d2401f30f51d08126a2d133fee4c1955136519d7ac6cf6f5ac0a91e6bc8" dependencies = [ "whisper-rs-sys", ] [[package]] name = "whisper-rs-sys" -version = "0.10.0" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d22f00ed0995463eecc34ef89905845f6bf6fd37ea70789fed180520050da8f8" +checksum = "b5e2a6e06e7ac7b8f53c53a5f50bb0bc823ba69b63ecd887339f807a5598bbd2" dependencies = [ - "bindgen 0.69.5", + "bindgen 0.71.1", "cfg-if", "cmake", "fs_extra", diff --git a/flakes/stt_ime/stt-stream/Cargo.toml b/flakes/stt_ime/stt-stream/Cargo.toml index cbd7c8de..e705b7d5 100644 --- a/flakes/stt_ime/stt-stream/Cargo.toml +++ b/flakes/stt_ime/stt-stream/Cargo.toml @@ -11,7 +11,7 @@ cpal = "0.15" # Resampling (48k -> 16k) rubato = "0.15" # Whisper inference -whisper-rs = "0.12" +whisper-rs = "0.15" # Voice activity detection # Using silero via ONNX (reserved for future use) # ort = { version = "2.0.0-rc.9", default-features = false, features = ["load-dynamic"] } @@ -44,6 +44,8 @@ hf-hub = "0.3" [features] default = [] cuda = ["whisper-rs/cuda"] +hipblas = ["whisper-rs/hipblas"] +metal = ["whisper-rs/metal"] [profile.release] lto = true diff --git a/flakes/stt_ime/stt-stream/src/main.rs b/flakes/stt_ime/stt-stream/src/main.rs index 261a10c4..4ffc233d 100644 --- a/flakes/stt_ime/stt-stream/src/main.rs +++ b/flakes/stt_ime/stt-stream/src/main.rs @@ -124,6 +124,10 @@ struct Args { /// Use GPU acceleration #[arg(long)] gpu: bool, + + /// Number of threads for transcription (default: auto-detect) + #[arg(long)] + threads: Option, } /// Events emitted to stdout as NDJSON @@ -288,7 +292,42 @@ async fn main() -> Result<()> { let model_path = get_model_path(&args).context("Failed to get model path")?; info!("Loading Whisper model from: {}", model_path); - let ctx_params = WhisperContextParameters::default(); + // Configure GPU and context parameters + let mut ctx_params = WhisperContextParameters::default(); + + // Check for GPU env var override + let gpu_enabled = args.gpu || std::env::var("STT_STREAM_GPU").map(|v| v == "1" || v.to_lowercase() == "true").unwrap_or(false); + + ctx_params.use_gpu(gpu_enabled); + if gpu_enabled { + ctx_params.flash_attn(true); // Enable flash attention for GPU acceleration + } + + // Determine thread counts + let available_threads = std::thread::available_parallelism() + .map(|p| p.get() as i32) + .unwrap_or(4); + let final_threads = args.threads.unwrap_or(available_threads.min(8)); + let partial_threads = (final_threads / 2).max(1); + + // Log backend configuration + let gpu_feature_compiled = cfg!(feature = "hipblas") || cfg!(feature = "cuda") || cfg!(feature = "metal"); + info!("Backend configuration:"); + info!(" GPU requested: {}", gpu_enabled); + info!(" GPU feature compiled: {} (hipblas={}, cuda={}, metal={})", + gpu_feature_compiled, + cfg!(feature = "hipblas"), + cfg!(feature = "cuda"), + cfg!(feature = "metal") + ); + info!(" Flash attention: {}", gpu_enabled); + info!(" Model: {:?}", args.model); + info!(" Threads (final/partial): {}/{}", final_threads, partial_threads); + + if gpu_enabled && !gpu_feature_compiled { + warn!("GPU requested but no GPU feature compiled! Build with --features hipblas or --features cuda"); + } + let whisper_ctx = WhisperContext::new_with_params(&model_path, ctx_params) .context("Failed to load Whisper model")?; @@ -520,10 +559,11 @@ async fn main() -> Result<()> { let buffer_copy = state.buffer.clone(); let ctx = whisper_ctx.clone(); let lang = language.clone(); + let threads = partial_threads; // Transcribe in background tokio::task::spawn_blocking(move || { - if let Ok(text) = transcribe(&ctx, &buffer_copy, &lang, false) { + if let Ok(text) = transcribe(&ctx, &buffer_copy, &lang, false, threads) { if !text.is_empty() { emit_event(&SttEvent::Partial { text }); } @@ -545,7 +585,7 @@ async fn main() -> Result<()> { let lang = language.clone(); // Final transcription - match transcribe(&ctx, &buffer_copy, &lang, true) { + match transcribe(&ctx, &buffer_copy, &lang, true, final_threads) { Ok(text) => { if !text.is_empty() { emit_event(&SttEvent::Final { text }); @@ -588,18 +628,28 @@ fn transcribe( samples: &[f32], language: &str, is_final: bool, + threads: i32, ) -> Result { + let start_time = std::time::Instant::now(); + let ctx = ctx.lock().map_err(|_| anyhow::anyhow!("Lock poisoned"))?; let mut state = ctx.create_state()?; let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); + // Configure threads + params.set_n_threads(threads); + // Configure for speed vs accuracy if is_final { - params.set_n_threads(4); + // Final transcription: balanced speed and accuracy + params.set_single_segment(false); } else { - params.set_n_threads(2); + // Partial transcription: optimize for speed params.set_no_context(true); + params.set_single_segment(true); // Faster for streaming + params.set_no_timestamps(true); // We don't use timestamps for partials + params.set_temperature_inc(0.0); // Disable fallback retries for speed } params.set_language(Some(language)); @@ -608,18 +658,29 @@ fn transcribe( params.set_print_realtime(false); params.set_print_timestamps(false); params.set_suppress_blank(true); - params.set_suppress_non_speech_tokens(true); + params.set_suppress_nst(true); // Run inference state.full(params, samples)?; + let inference_time = start_time.elapsed(); + let audio_duration_secs = samples.len() as f32 / 16000.0; + tracing::debug!( + "Transcription took {:?} for {:.1}s audio (RTF: {:.2}x)", + inference_time, + audio_duration_secs, + inference_time.as_secs_f32() / audio_duration_secs + ); + // Collect segments - let num_segments = state.full_n_segments()?; + let num_segments = state.full_n_segments(); let mut text = String::new(); for i in 0..num_segments { - if let Ok(segment) = state.full_get_segment_text(i) { - text.push_str(&segment); + if let Some(segment) = state.get_segment(i) { + if let Ok(segment_text) = segment.to_str_lossy() { + text.push_str(&segment_text); + } } } diff --git a/hosts/lio/flake.nix b/hosts/lio/flake.nix index a8fe0af3..5de6921a 100644 --- a/hosts/lio/flake.nix +++ b/hosts/lio/flake.nix @@ -75,6 +75,7 @@ ({ ringofstorms.sttIme = { enable = true; + gpuBackend = "hip"; # Use AMD ROCm/HIP acceleration useGpu = true; }; })