Add HIP (ROCm) stt-stream build and update whisper-rs

This commit is contained in:
RingOfStorms (Joshua Bell) 2026-01-14 12:14:24 -06:00
parent 6c7a6fec5f
commit e16ba27ad6
5 changed files with 204 additions and 104 deletions

View file

@ -109,25 +109,22 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bindgen"
version = "0.69.5"
version = "0.71.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
dependencies = [
"bitflags 2.10.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"itertools",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"rustc-hash",
"shlex",
"syn",
"which",
]
[[package]]
@ -139,11 +136,11 @@ dependencies = [
"bitflags 2.10.0",
"cexpr",
"clang-sys",
"itertools 0.13.0",
"itertools",
"proc-macro2",
"quote",
"regex",
"rustc-hash 2.1.1",
"rustc-hash",
"shlex",
"syn",
]
@ -538,15 +535,6 @@ dependencies = [
"ureq",
]
[[package]]
name = "home"
version = "0.5.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "icu_collections"
version = "2.1.1"
@ -678,15 +666,6 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.13.0"
@ -750,12 +729,6 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.180"
@ -782,12 +755,6 @@ dependencies = [
"libc",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
@ -1313,12 +1280,6 @@ dependencies = [
"realfft",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
@ -1339,19 +1300,6 @@ dependencies = [
"transpose",
]
[[package]]
name = "rustix"
version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
"bitflags 2.10.0",
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.59.0",
]
[[package]]
name = "rustix"
version = "1.1.3"
@ -1361,7 +1309,7 @@ dependencies = [
"bitflags 2.10.0",
"errno",
"libc",
"linux-raw-sys 0.11.0",
"linux-raw-sys",
"windows-sys 0.61.2",
]
@ -1616,7 +1564,7 @@ dependencies = [
"fastrand",
"getrandom 0.3.4",
"once_cell",
"rustix 1.1.3",
"rustix",
"windows-sys 0.61.2",
]
@ -1982,34 +1930,22 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix 0.38.44",
]
[[package]]
name = "whisper-rs"
version = "0.12.0"
version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c597ac8a9d5c4719fee232abc871da184ea50a4fea38d2d00348fd95072b2b0"
checksum = "71ea5d2401f30f51d08126a2d133fee4c1955136519d7ac6cf6f5ac0a91e6bc8"
dependencies = [
"whisper-rs-sys",
]
[[package]]
name = "whisper-rs-sys"
version = "0.10.0"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d22f00ed0995463eecc34ef89905845f6bf6fd37ea70789fed180520050da8f8"
checksum = "b5e2a6e06e7ac7b8f53c53a5f50bb0bc823ba69b63ecd887339f807a5598bbd2"
dependencies = [
"bindgen 0.69.5",
"bindgen 0.71.1",
"cfg-if",
"cmake",
"fs_extra",

View file

@ -11,7 +11,7 @@ cpal = "0.15"
# Resampling (48k -> 16k)
rubato = "0.15"
# Whisper inference
whisper-rs = "0.12"
whisper-rs = "0.15"
# Voice activity detection
# Using silero via ONNX (reserved for future use)
# ort = { version = "2.0.0-rc.9", default-features = false, features = ["load-dynamic"] }
@ -44,6 +44,8 @@ hf-hub = "0.3"
[features]
default = []
cuda = ["whisper-rs/cuda"]
hipblas = ["whisper-rs/hipblas"]
metal = ["whisper-rs/metal"]
[profile.release]
lto = true

View file

@ -124,6 +124,10 @@ struct Args {
/// Use GPU acceleration
#[arg(long)]
gpu: bool,
/// Number of threads for transcription (default: auto-detect)
#[arg(long)]
threads: Option<i32>,
}
/// Events emitted to stdout as NDJSON
@ -288,7 +292,42 @@ async fn main() -> Result<()> {
let model_path = get_model_path(&args).context("Failed to get model path")?;
info!("Loading Whisper model from: {}", model_path);
let ctx_params = WhisperContextParameters::default();
// Configure GPU and context parameters
let mut ctx_params = WhisperContextParameters::default();
// Check for GPU env var override
let gpu_enabled = args.gpu || std::env::var("STT_STREAM_GPU").map(|v| v == "1" || v.to_lowercase() == "true").unwrap_or(false);
ctx_params.use_gpu(gpu_enabled);
if gpu_enabled {
ctx_params.flash_attn(true); // Enable flash attention for GPU acceleration
}
// Determine thread counts
let available_threads = std::thread::available_parallelism()
.map(|p| p.get() as i32)
.unwrap_or(4);
let final_threads = args.threads.unwrap_or(available_threads.min(8));
let partial_threads = (final_threads / 2).max(1);
// Log backend configuration
let gpu_feature_compiled = cfg!(feature = "hipblas") || cfg!(feature = "cuda") || cfg!(feature = "metal");
info!("Backend configuration:");
info!(" GPU requested: {}", gpu_enabled);
info!(" GPU feature compiled: {} (hipblas={}, cuda={}, metal={})",
gpu_feature_compiled,
cfg!(feature = "hipblas"),
cfg!(feature = "cuda"),
cfg!(feature = "metal")
);
info!(" Flash attention: {}", gpu_enabled);
info!(" Model: {:?}", args.model);
info!(" Threads (final/partial): {}/{}", final_threads, partial_threads);
if gpu_enabled && !gpu_feature_compiled {
warn!("GPU requested but no GPU feature compiled! Build with --features hipblas or --features cuda");
}
let whisper_ctx = WhisperContext::new_with_params(&model_path, ctx_params)
.context("Failed to load Whisper model")?;
@ -520,10 +559,11 @@ async fn main() -> Result<()> {
let buffer_copy = state.buffer.clone();
let ctx = whisper_ctx.clone();
let lang = language.clone();
let threads = partial_threads;
// Transcribe in background
tokio::task::spawn_blocking(move || {
if let Ok(text) = transcribe(&ctx, &buffer_copy, &lang, false) {
if let Ok(text) = transcribe(&ctx, &buffer_copy, &lang, false, threads) {
if !text.is_empty() {
emit_event(&SttEvent::Partial { text });
}
@ -545,7 +585,7 @@ async fn main() -> Result<()> {
let lang = language.clone();
// Final transcription
match transcribe(&ctx, &buffer_copy, &lang, true) {
match transcribe(&ctx, &buffer_copy, &lang, true, final_threads) {
Ok(text) => {
if !text.is_empty() {
emit_event(&SttEvent::Final { text });
@ -588,18 +628,28 @@ fn transcribe(
samples: &[f32],
language: &str,
is_final: bool,
threads: i32,
) -> Result<String> {
let start_time = std::time::Instant::now();
let ctx = ctx.lock().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
let mut state = ctx.create_state()?;
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
// Configure threads
params.set_n_threads(threads);
// Configure for speed vs accuracy
if is_final {
params.set_n_threads(4);
// Final transcription: balanced speed and accuracy
params.set_single_segment(false);
} else {
params.set_n_threads(2);
// Partial transcription: optimize for speed
params.set_no_context(true);
params.set_single_segment(true); // Faster for streaming
params.set_no_timestamps(true); // We don't use timestamps for partials
params.set_temperature_inc(0.0); // Disable fallback retries for speed
}
params.set_language(Some(language));
@ -608,18 +658,29 @@ fn transcribe(
params.set_print_realtime(false);
params.set_print_timestamps(false);
params.set_suppress_blank(true);
params.set_suppress_non_speech_tokens(true);
params.set_suppress_nst(true);
// Run inference
state.full(params, samples)?;
let inference_time = start_time.elapsed();
let audio_duration_secs = samples.len() as f32 / 16000.0;
tracing::debug!(
"Transcription took {:?} for {:.1}s audio (RTF: {:.2}x)",
inference_time,
audio_duration_secs,
inference_time.as_secs_f32() / audio_duration_secs
);
// Collect segments
let num_segments = state.full_n_segments()?;
let num_segments = state.full_n_segments();
let mut text = String::new();
for i in 0..num_segments {
if let Ok(segment) = state.full_get_segment_text(i) {
text.push_str(&segment);
if let Some(segment) = state.get_segment(i) {
if let Ok(segment_text) = segment.to_str_lossy() {
text.push_str(&segment_text);
}
}
}