704 lines
22 KiB
Rust
704 lines
22 KiB
Rust
//! stt-stream: Local speech-to-text streaming CLI
|
|
//!
|
|
//! Captures audio from microphone, performs VAD, transcribes with Whisper,
|
|
//! and outputs JSON events to stdout for Fcitx5 integration.
|
|
|
|
use anyhow::{Context, Result};
|
|
use clap::{Parser, ValueEnum};
|
|
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
|
use rubato::{FftFixedInOut, Resampler};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::io::{BufRead, Write};
|
|
use std::sync::atomic::{AtomicBool, Ordering};
|
|
use std::sync::{Arc, Mutex};
|
|
use tokio::sync::mpsc;
|
|
use tracing::{error, info, warn};
|
|
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
|
|
|
|
/// Operating mode for the STT engine
|
|
#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)]
|
|
pub enum Mode {
|
|
/// Record until silence, transcribe, then reset (one-shot)
|
|
Oneshot,
|
|
/// Always listen, emit text when speech detected (continuous)
|
|
Continuous,
|
|
/// Manual start/stop via stdin commands
|
|
Manual,
|
|
}
|
|
|
|
/// Whisper model size
|
|
#[derive(Debug, Clone, Copy, ValueEnum)]
|
|
pub enum ModelSize {
|
|
Tiny,
|
|
TinyEn,
|
|
Base,
|
|
BaseEn,
|
|
Small,
|
|
SmallEn,
|
|
Medium,
|
|
MediumEn,
|
|
LargeV3,
|
|
}
|
|
|
|
impl ModelSize {
|
|
fn model_name(&self) -> &'static str {
|
|
match self {
|
|
ModelSize::Tiny => "tiny",
|
|
ModelSize::TinyEn => "tiny.en",
|
|
ModelSize::Base => "base",
|
|
ModelSize::BaseEn => "base.en",
|
|
ModelSize::Small => "small",
|
|
ModelSize::SmallEn => "small.en",
|
|
ModelSize::Medium => "medium",
|
|
ModelSize::MediumEn => "medium.en",
|
|
ModelSize::LargeV3 => "large-v3",
|
|
}
|
|
}
|
|
|
|
fn parse(input: &str) -> Option<Self> {
|
|
let normalized = input
|
|
.trim()
|
|
.to_lowercase()
|
|
.replace('.', "-")
|
|
.replace('_', "-");
|
|
|
|
match normalized.as_str() {
|
|
"tiny" => Some(ModelSize::Tiny),
|
|
"tiny-en" => Some(ModelSize::TinyEn),
|
|
"base" => Some(ModelSize::Base),
|
|
"base-en" => Some(ModelSize::BaseEn),
|
|
"small" => Some(ModelSize::Small),
|
|
"small-en" => Some(ModelSize::SmallEn),
|
|
"medium" => Some(ModelSize::Medium),
|
|
"medium-en" => Some(ModelSize::MediumEn),
|
|
"large-v3" => Some(ModelSize::LargeV3),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn hf_repo(&self) -> &'static str {
|
|
"ggerganov/whisper.cpp"
|
|
}
|
|
|
|
fn hf_filename(&self) -> String {
|
|
format!("ggml-{}.bin", self.model_name())
|
|
}
|
|
}
|
|
|
|
#[derive(Parser, Debug)]
|
|
#[command(name = "stt-stream")]
|
|
#[command(about = "Local speech-to-text streaming for Fcitx5")]
|
|
struct Args {
|
|
/// Operating mode
|
|
#[arg(short, long, value_enum, default_value = "manual")]
|
|
mode: Mode,
|
|
|
|
/// Whisper model size
|
|
#[arg(short = 'M', long, value_enum, default_value = "base-en")]
|
|
model: ModelSize,
|
|
|
|
/// Path to whisper model file (overrides --model)
|
|
#[arg(long)]
|
|
model_path: Option<String>,
|
|
|
|
/// VAD threshold (0.0-1.0)
|
|
#[arg(long, default_value = "0.5")]
|
|
vad_threshold: f32,
|
|
|
|
/// Silence duration (ms) to end utterance
|
|
#[arg(long, default_value = "800")]
|
|
silence_ms: u64,
|
|
|
|
/// Emit partial transcripts while speaking
|
|
#[arg(long, default_value = "true")]
|
|
partials: bool,
|
|
|
|
/// Partial transcript interval (ms)
|
|
#[arg(long, default_value = "500")]
|
|
partial_interval_ms: u64,
|
|
|
|
/// Language code (e.g., "en", "ja", "auto")
|
|
#[arg(short, long, default_value = "en")]
|
|
language: String,
|
|
|
|
/// Use GPU acceleration
|
|
#[arg(long)]
|
|
gpu: bool,
|
|
|
|
/// Number of threads for transcription (default: auto-detect)
|
|
#[arg(long)]
|
|
threads: Option<i32>,
|
|
}
|
|
|
|
/// Events emitted to stdout as NDJSON
|
|
#[derive(Debug, Serialize)]
|
|
#[serde(tag = "type", rename_all = "snake_case")]
|
|
pub enum SttEvent {
|
|
/// STT engine is ready
|
|
Ready,
|
|
/// Recording started
|
|
RecordingStarted,
|
|
/// Recording stopped
|
|
RecordingStopped,
|
|
/// Partial (unstable) transcript
|
|
Partial { text: String },
|
|
/// Final transcript
|
|
Final { text: String },
|
|
/// Error occurred
|
|
Error { message: String },
|
|
/// Engine shutting down
|
|
Shutdown,
|
|
}
|
|
|
|
/// Commands received from stdin as NDJSON
|
|
#[derive(Debug, Deserialize)]
|
|
#[serde(tag = "cmd", rename_all = "snake_case")]
|
|
pub enum SttCommand {
|
|
/// Start recording
|
|
Start,
|
|
/// Stop recording and transcribe
|
|
Stop,
|
|
/// Cancel current recording without transcribing
|
|
Cancel,
|
|
/// Shutdown the engine
|
|
Shutdown,
|
|
/// Switch mode
|
|
SetMode { mode: String },
|
|
}
|
|
|
|
fn emit_event(event: &SttEvent) {
|
|
if let Ok(json) = serde_json::to_string(event) {
|
|
let mut stdout = std::io::stdout().lock();
|
|
let _ = writeln!(stdout, "{}", json);
|
|
let _ = stdout.flush();
|
|
}
|
|
}
|
|
|
|
/// Simple energy-based VAD (placeholder for Silero VAD)
|
|
/// Returns true if the audio chunk likely contains speech
|
|
fn simple_vad(samples: &[f32], threshold: f32) -> bool {
|
|
if samples.is_empty() {
|
|
return false;
|
|
}
|
|
let energy: f32 = samples.iter().map(|s| s * s).sum::<f32>() / samples.len() as f32;
|
|
let db = 10.0 * energy.max(1e-10).log10();
|
|
// Typical speech is around -20 to -10 dB, silence is < -40 dB
|
|
// Map threshold 0-1 to dB range -50 to -20
|
|
let threshold_db = -50.0 + (threshold * 30.0);
|
|
db > threshold_db
|
|
}
|
|
|
|
/// Download or locate the Whisper model
|
|
fn get_model_path(args: &Args) -> Result<String> {
|
|
if let Some(ref path) = args.model_path {
|
|
return Ok(path.clone());
|
|
}
|
|
|
|
// Check environment variable
|
|
if let Ok(path) = std::env::var("STT_STREAM_MODEL_PATH") {
|
|
if std::path::Path::new(&path).exists() {
|
|
return Ok(path);
|
|
}
|
|
}
|
|
|
|
// Check XDG cache
|
|
let cache_dir = dirs::cache_dir()
|
|
.unwrap_or_else(|| std::path::PathBuf::from("."))
|
|
.join("stt-stream")
|
|
.join("models");
|
|
|
|
let model_file = cache_dir.join(args.model.hf_filename());
|
|
if model_file.exists() {
|
|
return Ok(model_file.to_string_lossy().to_string());
|
|
}
|
|
|
|
// Download from Hugging Face
|
|
info!("Downloading model {} from Hugging Face...", args.model.model_name());
|
|
std::fs::create_dir_all(&cache_dir)?;
|
|
|
|
let api = hf_hub::api::sync::Api::new()?;
|
|
let repo = api.model(args.model.hf_repo().to_string());
|
|
let path = repo.get(&args.model.hf_filename())?;
|
|
|
|
Ok(path.to_string_lossy().to_string())
|
|
}
|
|
|
|
/// Audio processing state
|
|
struct AudioState {
|
|
/// Audio samples buffer (16kHz mono)
|
|
buffer: Vec<f32>,
|
|
/// Whether we're currently recording
|
|
is_recording: bool,
|
|
/// Whether speech was detected in current segment
|
|
speech_detected: bool,
|
|
/// Samples since last speech
|
|
silence_samples: usize,
|
|
/// Last partial emission time
|
|
last_partial: std::time::Instant,
|
|
/// Manual mode: stop requested, finalize next tick
|
|
pending_finalize: bool,
|
|
}
|
|
|
|
impl AudioState {
|
|
fn new() -> Self {
|
|
Self {
|
|
buffer: Vec::with_capacity(16000 * 30), // 30 seconds max
|
|
is_recording: false,
|
|
speech_detected: false,
|
|
silence_samples: 0,
|
|
last_partial: std::time::Instant::now(),
|
|
pending_finalize: false,
|
|
}
|
|
}
|
|
|
|
fn clear(&mut self) {
|
|
self.buffer.clear();
|
|
self.speech_detected = false;
|
|
self.silence_samples = 0;
|
|
self.pending_finalize = false;
|
|
}
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
// Initialize logging
|
|
tracing_subscriber::fmt()
|
|
.with_env_filter(
|
|
tracing_subscriber::EnvFilter::from_default_env()
|
|
.add_directive("stt_stream=info".parse().unwrap()),
|
|
)
|
|
.with_writer(std::io::stderr)
|
|
.init();
|
|
|
|
let mut args = Args::parse();
|
|
|
|
// Allow Nix/session configuration via env vars.
|
|
// Precedence: explicit CLI args > env vars > defaults.
|
|
//
|
|
// `ringofstorms.sttIme.model` uses dot notation (e.g. "tiny.en"),
|
|
// while clap's value enum expects kebab-case (e.g. "tiny-en").
|
|
let cli_has_model_flag = std::env::args().any(|a| a == "--model" || a == "-M");
|
|
if !cli_has_model_flag && args.model_path.is_none() {
|
|
if let Ok(model) = std::env::var("STT_STREAM_MODEL") {
|
|
if let Some(parsed) = ModelSize::parse(&model) {
|
|
args.model = parsed;
|
|
}
|
|
}
|
|
}
|
|
|
|
info!("Starting stt-stream with mode: {:?}", args.mode);
|
|
|
|
// Load Whisper model
|
|
let model_path = get_model_path(&args).context("Failed to get model path")?;
|
|
info!("Loading Whisper model from: {}", model_path);
|
|
|
|
// Configure GPU and context parameters
|
|
let mut ctx_params = WhisperContextParameters::default();
|
|
|
|
// Check for GPU env var override
|
|
let gpu_enabled = args.gpu || std::env::var("STT_STREAM_GPU").map(|v| v == "1" || v.to_lowercase() == "true").unwrap_or(false);
|
|
|
|
ctx_params.use_gpu(gpu_enabled);
|
|
if gpu_enabled {
|
|
ctx_params.flash_attn(true); // Enable flash attention for GPU acceleration
|
|
}
|
|
|
|
// Determine thread counts
|
|
let available_threads = std::thread::available_parallelism()
|
|
.map(|p| p.get() as i32)
|
|
.unwrap_or(4);
|
|
let final_threads = args.threads.unwrap_or(available_threads.min(8));
|
|
let partial_threads = (final_threads / 2).max(1);
|
|
|
|
// Log backend configuration
|
|
let gpu_feature_compiled = cfg!(feature = "hipblas") || cfg!(feature = "cuda") || cfg!(feature = "metal");
|
|
info!("Backend configuration:");
|
|
info!(" GPU requested: {}", gpu_enabled);
|
|
info!(" GPU feature compiled: {} (hipblas={}, cuda={}, metal={})",
|
|
gpu_feature_compiled,
|
|
cfg!(feature = "hipblas"),
|
|
cfg!(feature = "cuda"),
|
|
cfg!(feature = "metal")
|
|
);
|
|
info!(" Flash attention: {}", gpu_enabled);
|
|
info!(" Model: {:?}", args.model);
|
|
info!(" Threads (final/partial): {}/{}", final_threads, partial_threads);
|
|
|
|
if gpu_enabled && !gpu_feature_compiled {
|
|
warn!("GPU requested but no GPU feature compiled! Build with --features hipblas or --features cuda");
|
|
}
|
|
|
|
let whisper_ctx = WhisperContext::new_with_params(&model_path, ctx_params)
|
|
.context("Failed to load Whisper model")?;
|
|
|
|
let whisper_ctx = Arc::new(Mutex::new(whisper_ctx));
|
|
|
|
// Audio capture setup
|
|
let host = cpal::default_host();
|
|
let device = host
|
|
.default_input_device()
|
|
.context("No input device available")?;
|
|
|
|
info!("Using input device: {}", device.name().unwrap_or_default());
|
|
|
|
let config = device.default_input_config()?;
|
|
let sample_rate = config.sample_rate().0;
|
|
let channels = config.channels() as usize;
|
|
|
|
info!("Input config: {}Hz, {} channels", sample_rate, channels);
|
|
|
|
// Resampler: input rate -> 16kHz
|
|
let resampler = if sample_rate != 16000 {
|
|
Some(Arc::new(Mutex::new(
|
|
FftFixedInOut::<f32>::new(sample_rate as usize, 16000, 1024, 1)
|
|
.context("Failed to create resampler")?,
|
|
)))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// Shared state
|
|
let audio_state = Arc::new(Mutex::new(AudioState::new()));
|
|
let running = Arc::new(AtomicBool::new(true));
|
|
let mode = Arc::new(Mutex::new(args.mode));
|
|
|
|
// Channel for audio data
|
|
let (audio_tx, mut audio_rx) = mpsc::channel::<Vec<f32>>(100);
|
|
|
|
// Audio callback
|
|
let resampler_clone = resampler.clone();
|
|
let running_clone = running.clone();
|
|
|
|
let stream = device.build_input_stream(
|
|
&config.into(),
|
|
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
|
if !running_clone.load(Ordering::Relaxed) {
|
|
return;
|
|
}
|
|
|
|
// Convert to mono if needed
|
|
let mono: Vec<f32> = if channels > 1 {
|
|
data.chunks(channels)
|
|
.map(|frame| frame.iter().sum::<f32>() / channels as f32)
|
|
.collect()
|
|
} else {
|
|
data.to_vec()
|
|
};
|
|
|
|
// Resample if needed
|
|
let resampled = if let Some(ref resampler) = resampler_clone {
|
|
if let Ok(mut r) = resampler.lock() {
|
|
// Pad input to required length
|
|
let input_frames = r.input_frames_next();
|
|
if mono.len() >= input_frames {
|
|
let input = vec![mono[..input_frames].to_vec()];
|
|
match r.process(&input, None) {
|
|
Ok(output) => output.into_iter().flatten().collect(),
|
|
Err(_) => return,
|
|
}
|
|
} else {
|
|
return;
|
|
}
|
|
} else {
|
|
return;
|
|
}
|
|
} else {
|
|
mono
|
|
};
|
|
|
|
let _ = audio_tx.blocking_send(resampled);
|
|
},
|
|
|err| {
|
|
error!("Audio stream error: {}", err);
|
|
},
|
|
None,
|
|
)?;
|
|
|
|
stream.play()?;
|
|
emit_event(&SttEvent::Ready);
|
|
|
|
// Stdin command reader
|
|
let running_stdin = running.clone();
|
|
let mode_stdin = mode.clone();
|
|
let audio_state_stdin = audio_state.clone();
|
|
|
|
let stdin_handle = std::thread::spawn(move || {
|
|
let stdin = std::io::stdin();
|
|
for line in stdin.lock().lines() {
|
|
if !running_stdin.load(Ordering::Relaxed) {
|
|
break;
|
|
}
|
|
|
|
let line = match line {
|
|
Ok(l) => l,
|
|
Err(_) => continue,
|
|
};
|
|
|
|
let cmd: SttCommand = match serde_json::from_str(&line) {
|
|
Ok(c) => c,
|
|
Err(_) => {
|
|
// Try simple text commands
|
|
match line.trim().to_lowercase().as_str() {
|
|
"start" => SttCommand::Start,
|
|
"stop" => SttCommand::Stop,
|
|
"cancel" => SttCommand::Cancel,
|
|
"shutdown" | "quit" | "exit" => SttCommand::Shutdown,
|
|
_ => continue,
|
|
}
|
|
}
|
|
};
|
|
|
|
match cmd {
|
|
SttCommand::Start => {
|
|
if let Ok(mut state) = audio_state_stdin.lock() {
|
|
state.is_recording = true;
|
|
state.clear();
|
|
emit_event(&SttEvent::RecordingStarted);
|
|
}
|
|
}
|
|
SttCommand::Stop => {
|
|
if let Ok(mut state) = audio_state_stdin.lock() {
|
|
state.is_recording = false;
|
|
state.pending_finalize = true;
|
|
emit_event(&SttEvent::RecordingStopped);
|
|
}
|
|
}
|
|
SttCommand::Cancel => {
|
|
if let Ok(mut state) = audio_state_stdin.lock() {
|
|
state.is_recording = false;
|
|
state.clear();
|
|
emit_event(&SttEvent::RecordingStopped);
|
|
}
|
|
}
|
|
SttCommand::Shutdown => {
|
|
running_stdin.store(false, Ordering::Relaxed);
|
|
break;
|
|
}
|
|
SttCommand::SetMode { mode: m } => {
|
|
if let Ok(mut current_mode) = mode_stdin.lock() {
|
|
*current_mode = match m.as_str() {
|
|
"oneshot" => Mode::Oneshot,
|
|
"continuous" => Mode::Continuous,
|
|
"manual" => Mode::Manual,
|
|
_ => continue,
|
|
};
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Main processing loop
|
|
let vad_threshold = args.vad_threshold;
|
|
let silence_samples_threshold = (args.silence_ms as f32 * 16.0) as usize; // 16 samples per ms at 16kHz
|
|
let partial_interval = std::time::Duration::from_millis(args.partial_interval_ms);
|
|
let emit_partials = args.partials;
|
|
let language = args.language.clone();
|
|
|
|
while running.load(Ordering::Relaxed) {
|
|
// Receive audio data
|
|
let samples = match tokio::time::timeout(
|
|
std::time::Duration::from_millis(100),
|
|
audio_rx.recv(),
|
|
)
|
|
.await
|
|
{
|
|
Ok(Some(s)) => s,
|
|
Ok(None) => break,
|
|
Err(_) => continue, // Timeout, check running flag
|
|
};
|
|
|
|
let current_mode = *mode.lock().unwrap();
|
|
let mut state = audio_state.lock().unwrap();
|
|
|
|
// Mode-specific behavior
|
|
match current_mode {
|
|
Mode::Manual => {
|
|
// In manual mode we normally ignore audio unless explicitly recording.
|
|
// Exception: after receiving a "stop" command, we need one more tick
|
|
// to finalize and emit the transcript.
|
|
if !state.is_recording && !state.pending_finalize {
|
|
continue;
|
|
}
|
|
}
|
|
Mode::Oneshot | Mode::Continuous => {
|
|
// Auto-start on speech detection
|
|
let has_speech = simple_vad(&samples, vad_threshold);
|
|
|
|
if !state.is_recording && has_speech {
|
|
state.is_recording = true;
|
|
state.clear();
|
|
emit_event(&SttEvent::RecordingStarted);
|
|
}
|
|
|
|
if !state.is_recording {
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Accumulate audio
|
|
state.buffer.extend_from_slice(&samples);
|
|
|
|
// VAD check
|
|
let has_speech = simple_vad(&samples, vad_threshold);
|
|
if has_speech {
|
|
state.speech_detected = true;
|
|
state.silence_samples = 0;
|
|
} else {
|
|
state.silence_samples += samples.len();
|
|
}
|
|
|
|
// Emit partial transcript if enabled
|
|
if emit_partials
|
|
&& state.speech_detected
|
|
&& state.last_partial.elapsed() > partial_interval
|
|
&& state.buffer.len() > 16000 // At least 1 second
|
|
{
|
|
state.last_partial = std::time::Instant::now();
|
|
let buffer_copy = state.buffer.clone();
|
|
let ctx = whisper_ctx.clone();
|
|
let lang = language.clone();
|
|
let threads = partial_threads;
|
|
|
|
// Transcribe in background
|
|
tokio::task::spawn_blocking(move || {
|
|
if let Ok(text) = transcribe(&ctx, &buffer_copy, &lang, false, threads) {
|
|
if !text.is_empty() {
|
|
emit_event(&SttEvent::Partial { text });
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
// Check for end of utterance
|
|
let should_finalize = match current_mode {
|
|
Mode::Manual => state.pending_finalize && state.speech_detected,
|
|
Mode::Oneshot | Mode::Continuous => {
|
|
state.speech_detected && state.silence_samples > silence_samples_threshold
|
|
}
|
|
};
|
|
|
|
if should_finalize && !state.buffer.is_empty() {
|
|
let buffer_copy = state.buffer.clone();
|
|
let ctx = whisper_ctx.clone();
|
|
let lang = language.clone();
|
|
|
|
// Final transcription
|
|
match transcribe(&ctx, &buffer_copy, &lang, true, final_threads) {
|
|
Ok(text) => {
|
|
if !text.is_empty() {
|
|
emit_event(&SttEvent::Final { text });
|
|
}
|
|
}
|
|
Err(e) => {
|
|
emit_event(&SttEvent::Error {
|
|
message: e.to_string(),
|
|
});
|
|
}
|
|
}
|
|
|
|
state.clear();
|
|
state.is_recording = current_mode == Mode::Continuous;
|
|
|
|
if current_mode == Mode::Oneshot {
|
|
emit_event(&SttEvent::RecordingStopped);
|
|
}
|
|
}
|
|
|
|
// Prevent buffer from growing too large
|
|
if state.buffer.len() > 16000 * 30 {
|
|
warn!("Buffer too large, truncating");
|
|
let start = state.buffer.len() - 16000 * 20;
|
|
state.buffer = state.buffer[start..].to_vec();
|
|
}
|
|
}
|
|
|
|
// Cleanup
|
|
drop(stream);
|
|
emit_event(&SttEvent::Shutdown);
|
|
let _ = stdin_handle.join();
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Transcribe audio buffer using Whisper
|
|
fn transcribe(
|
|
ctx: &Arc<Mutex<WhisperContext>>,
|
|
samples: &[f32],
|
|
language: &str,
|
|
is_final: bool,
|
|
threads: i32,
|
|
) -> Result<String> {
|
|
let start_time = std::time::Instant::now();
|
|
|
|
let ctx = ctx.lock().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
|
|
let mut state = ctx.create_state()?;
|
|
|
|
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
|
|
|
|
// Configure threads
|
|
params.set_n_threads(threads);
|
|
|
|
// Configure for speed vs accuracy
|
|
if is_final {
|
|
// Final transcription: balanced speed and accuracy
|
|
params.set_single_segment(false);
|
|
} else {
|
|
// Partial transcription: optimize for speed
|
|
params.set_no_context(true);
|
|
params.set_single_segment(true); // Faster for streaming
|
|
params.set_no_timestamps(true); // We don't use timestamps for partials
|
|
params.set_temperature_inc(0.0); // Disable fallback retries for speed
|
|
}
|
|
|
|
params.set_language(Some(language));
|
|
params.set_print_special(false);
|
|
params.set_print_progress(false);
|
|
params.set_print_realtime(false);
|
|
params.set_print_timestamps(false);
|
|
params.set_suppress_blank(true);
|
|
params.set_suppress_nst(true);
|
|
|
|
// Run inference
|
|
state.full(params, samples)?;
|
|
|
|
let inference_time = start_time.elapsed();
|
|
let audio_duration_secs = samples.len() as f32 / 16000.0;
|
|
tracing::debug!(
|
|
"Transcription took {:?} for {:.1}s audio (RTF: {:.2}x)",
|
|
inference_time,
|
|
audio_duration_secs,
|
|
inference_time.as_secs_f32() / audio_duration_secs
|
|
);
|
|
|
|
// Collect segments
|
|
let num_segments = state.full_n_segments();
|
|
let mut text = String::new();
|
|
|
|
for i in 0..num_segments {
|
|
if let Some(segment) = state.get_segment(i) {
|
|
if let Ok(segment_text) = segment.to_str_lossy() {
|
|
text.push_str(&segment_text);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(text.trim().to_string())
|
|
}
|
|
|
|
/// Stub for dirs crate functionality
|
|
mod dirs {
|
|
use std::path::PathBuf;
|
|
|
|
pub fn cache_dir() -> Option<PathBuf> {
|
|
std::env::var("XDG_CACHE_HOME")
|
|
.map(PathBuf::from)
|
|
.ok()
|
|
.or_else(|| {
|
|
std::env::var("HOME")
|
|
.map(|h| PathBuf::from(h).join(".cache"))
|
|
.ok()
|
|
})
|
|
}
|
|
}
|