Add HIP (ROCm) stt-stream build and update whisper-rs
This commit is contained in:
parent
6c7a6fec5f
commit
e16ba27ad6
5 changed files with 204 additions and 104 deletions
|
|
@ -124,6 +124,10 @@ struct Args {
|
|||
/// Use GPU acceleration
|
||||
#[arg(long)]
|
||||
gpu: bool,
|
||||
|
||||
/// Number of threads for transcription (default: auto-detect)
|
||||
#[arg(long)]
|
||||
threads: Option<i32>,
|
||||
}
|
||||
|
||||
/// Events emitted to stdout as NDJSON
|
||||
|
|
@ -288,7 +292,42 @@ async fn main() -> Result<()> {
|
|||
let model_path = get_model_path(&args).context("Failed to get model path")?;
|
||||
info!("Loading Whisper model from: {}", model_path);
|
||||
|
||||
let ctx_params = WhisperContextParameters::default();
|
||||
// Configure GPU and context parameters
|
||||
let mut ctx_params = WhisperContextParameters::default();
|
||||
|
||||
// Check for GPU env var override
|
||||
let gpu_enabled = args.gpu || std::env::var("STT_STREAM_GPU").map(|v| v == "1" || v.to_lowercase() == "true").unwrap_or(false);
|
||||
|
||||
ctx_params.use_gpu(gpu_enabled);
|
||||
if gpu_enabled {
|
||||
ctx_params.flash_attn(true); // Enable flash attention for GPU acceleration
|
||||
}
|
||||
|
||||
// Determine thread counts
|
||||
let available_threads = std::thread::available_parallelism()
|
||||
.map(|p| p.get() as i32)
|
||||
.unwrap_or(4);
|
||||
let final_threads = args.threads.unwrap_or(available_threads.min(8));
|
||||
let partial_threads = (final_threads / 2).max(1);
|
||||
|
||||
// Log backend configuration
|
||||
let gpu_feature_compiled = cfg!(feature = "hipblas") || cfg!(feature = "cuda") || cfg!(feature = "metal");
|
||||
info!("Backend configuration:");
|
||||
info!(" GPU requested: {}", gpu_enabled);
|
||||
info!(" GPU feature compiled: {} (hipblas={}, cuda={}, metal={})",
|
||||
gpu_feature_compiled,
|
||||
cfg!(feature = "hipblas"),
|
||||
cfg!(feature = "cuda"),
|
||||
cfg!(feature = "metal")
|
||||
);
|
||||
info!(" Flash attention: {}", gpu_enabled);
|
||||
info!(" Model: {:?}", args.model);
|
||||
info!(" Threads (final/partial): {}/{}", final_threads, partial_threads);
|
||||
|
||||
if gpu_enabled && !gpu_feature_compiled {
|
||||
warn!("GPU requested but no GPU feature compiled! Build with --features hipblas or --features cuda");
|
||||
}
|
||||
|
||||
let whisper_ctx = WhisperContext::new_with_params(&model_path, ctx_params)
|
||||
.context("Failed to load Whisper model")?;
|
||||
|
||||
|
|
@ -520,10 +559,11 @@ async fn main() -> Result<()> {
|
|||
let buffer_copy = state.buffer.clone();
|
||||
let ctx = whisper_ctx.clone();
|
||||
let lang = language.clone();
|
||||
let threads = partial_threads;
|
||||
|
||||
// Transcribe in background
|
||||
tokio::task::spawn_blocking(move || {
|
||||
if let Ok(text) = transcribe(&ctx, &buffer_copy, &lang, false) {
|
||||
if let Ok(text) = transcribe(&ctx, &buffer_copy, &lang, false, threads) {
|
||||
if !text.is_empty() {
|
||||
emit_event(&SttEvent::Partial { text });
|
||||
}
|
||||
|
|
@ -545,7 +585,7 @@ async fn main() -> Result<()> {
|
|||
let lang = language.clone();
|
||||
|
||||
// Final transcription
|
||||
match transcribe(&ctx, &buffer_copy, &lang, true) {
|
||||
match transcribe(&ctx, &buffer_copy, &lang, true, final_threads) {
|
||||
Ok(text) => {
|
||||
if !text.is_empty() {
|
||||
emit_event(&SttEvent::Final { text });
|
||||
|
|
@ -588,18 +628,28 @@ fn transcribe(
|
|||
samples: &[f32],
|
||||
language: &str,
|
||||
is_final: bool,
|
||||
threads: i32,
|
||||
) -> Result<String> {
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
let ctx = ctx.lock().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
|
||||
let mut state = ctx.create_state()?;
|
||||
|
||||
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
|
||||
|
||||
// Configure threads
|
||||
params.set_n_threads(threads);
|
||||
|
||||
// Configure for speed vs accuracy
|
||||
if is_final {
|
||||
params.set_n_threads(4);
|
||||
// Final transcription: balanced speed and accuracy
|
||||
params.set_single_segment(false);
|
||||
} else {
|
||||
params.set_n_threads(2);
|
||||
// Partial transcription: optimize for speed
|
||||
params.set_no_context(true);
|
||||
params.set_single_segment(true); // Faster for streaming
|
||||
params.set_no_timestamps(true); // We don't use timestamps for partials
|
||||
params.set_temperature_inc(0.0); // Disable fallback retries for speed
|
||||
}
|
||||
|
||||
params.set_language(Some(language));
|
||||
|
|
@ -608,18 +658,29 @@ fn transcribe(
|
|||
params.set_print_realtime(false);
|
||||
params.set_print_timestamps(false);
|
||||
params.set_suppress_blank(true);
|
||||
params.set_suppress_non_speech_tokens(true);
|
||||
params.set_suppress_nst(true);
|
||||
|
||||
// Run inference
|
||||
state.full(params, samples)?;
|
||||
|
||||
let inference_time = start_time.elapsed();
|
||||
let audio_duration_secs = samples.len() as f32 / 16000.0;
|
||||
tracing::debug!(
|
||||
"Transcription took {:?} for {:.1}s audio (RTF: {:.2}x)",
|
||||
inference_time,
|
||||
audio_duration_secs,
|
||||
inference_time.as_secs_f32() / audio_duration_secs
|
||||
);
|
||||
|
||||
// Collect segments
|
||||
let num_segments = state.full_n_segments()?;
|
||||
let num_segments = state.full_n_segments();
|
||||
let mut text = String::new();
|
||||
|
||||
for i in 0..num_segments {
|
||||
if let Ok(segment) = state.full_get_segment_text(i) {
|
||||
text.push_str(&segment);
|
||||
if let Some(segment) = state.get_segment(i) {
|
||||
if let Ok(segment_text) = segment.to_str_lossy() {
|
||||
text.push_str(&segment_text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue