Add HIP (ROCm) stt-stream build and update whisper-rs
This commit is contained in:
parent
6c7a6fec5f
commit
e16ba27ad6
5 changed files with 204 additions and 104 deletions
|
|
@ -28,31 +28,70 @@
|
||||||
pkgs = nixpkgs.legacyPackages.${system};
|
pkgs = nixpkgs.legacyPackages.${system};
|
||||||
craneLib = crane.mkLib pkgs;
|
craneLib = crane.mkLib pkgs;
|
||||||
|
|
||||||
# Rust STT streaming CLI
|
# Common build inputs for stt-stream
|
||||||
|
commonNativeBuildInputs = with pkgs; [
|
||||||
|
pkg-config
|
||||||
|
cmake
|
||||||
|
git # required by whisper-rs-sys build
|
||||||
|
];
|
||||||
|
|
||||||
|
commonBuildInputs = with pkgs; [
|
||||||
|
alsa-lib
|
||||||
|
openssl
|
||||||
|
];
|
||||||
|
|
||||||
|
# CPU-only build (default)
|
||||||
stt-stream = craneLib.buildPackage {
|
stt-stream = craneLib.buildPackage {
|
||||||
pname = "stt-stream";
|
pname = "stt-stream";
|
||||||
version = "0.1.0";
|
version = "0.1.0";
|
||||||
src = craneLib.cleanCargoSource ./stt-stream;
|
src = craneLib.cleanCargoSource ./stt-stream;
|
||||||
|
|
||||||
nativeBuildInputs = with pkgs; [
|
nativeBuildInputs = commonNativeBuildInputs ++ (with pkgs; [
|
||||||
pkg-config
|
|
||||||
cmake # for whisper-rs
|
|
||||||
clang
|
clang
|
||||||
llvmPackages.libclang
|
llvmPackages.libclang
|
||||||
];
|
]);
|
||||||
|
|
||||||
buildInputs = with pkgs; [
|
buildInputs = commonBuildInputs ++ (with pkgs; [
|
||||||
alsa-lib
|
|
||||||
openssl
|
|
||||||
# whisper.cpp dependencies
|
|
||||||
openblas
|
openblas
|
||||||
];
|
]);
|
||||||
|
|
||||||
# For bindgen to find libclang
|
# For bindgen to find libclang
|
||||||
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
|
||||||
|
};
|
||||||
|
|
||||||
# Enable CUDA if available (user can override)
|
# GPU build with ROCm/HIP support (AMD GPUs)
|
||||||
WHISPER_CUBLAS = "OFF";
|
stt-stream-hip = craneLib.buildPackage {
|
||||||
|
pname = "stt-stream-hip";
|
||||||
|
version = "0.1.0";
|
||||||
|
src = craneLib.cleanCargoSource ./stt-stream;
|
||||||
|
|
||||||
|
nativeBuildInputs = commonNativeBuildInputs ++ (with pkgs; [
|
||||||
|
# ROCm toolchain - clr contains the properly wrapped hipcc
|
||||||
|
rocmPackages.clr
|
||||||
|
# rocminfo provides rocm_agent_enumerator which hipcc needs
|
||||||
|
rocmPackages.rocminfo
|
||||||
|
]);
|
||||||
|
|
||||||
|
buildInputs = commonBuildInputs ++ (with pkgs; [
|
||||||
|
# ROCm/HIP libraries needed at link time
|
||||||
|
rocmPackages.clr # HIP runtime
|
||||||
|
rocmPackages.hipblas
|
||||||
|
rocmPackages.rocblas
|
||||||
|
rocmPackages.rocm-runtime
|
||||||
|
rocmPackages.rocm-device-libs
|
||||||
|
rocmPackages.rocm-comgr
|
||||||
|
]);
|
||||||
|
|
||||||
|
# Enable hipblas feature
|
||||||
|
cargoExtraArgs = "--features hipblas";
|
||||||
|
|
||||||
|
# The clr package's hipcc is already wrapped with all the right paths,
|
||||||
|
# but we need LIBCLANG_PATH for bindgen
|
||||||
|
LIBCLANG_PATH = "${pkgs.rocmPackages.llvm.clang}/lib";
|
||||||
|
|
||||||
|
# Target common AMD GPU architectures (user can override via AMDGPU_TARGETS)
|
||||||
|
# gfx1030 = RX 6000 series, gfx1100 = RX 7000 series; gfx906/gfx908 (MI50/MI100) are not in the default list below
|
||||||
|
AMDGPU_TARGETS = "gfx1030;gfx1100";
|
||||||
};
|
};
|
||||||
|
|
||||||
# Fcitx5 C++ shim addon
|
# Fcitx5 C++ shim addon
|
||||||
|
|
@ -82,10 +121,36 @@
|
||||||
mkdir -p $out/lib/fcitx5
|
mkdir -p $out/lib/fcitx5
|
||||||
'';
|
'';
|
||||||
};
|
};
|
||||||
|
# Fcitx5 addon variant using HIP-accelerated stt-stream
|
||||||
|
fcitx5-stt-hip = pkgs.stdenv.mkDerivation {
|
||||||
|
pname = "fcitx5-stt-hip";
|
||||||
|
version = "0.1.0";
|
||||||
|
src = ./fcitx5-stt;
|
||||||
|
|
||||||
|
nativeBuildInputs = with pkgs; [
|
||||||
|
cmake
|
||||||
|
extra-cmake-modules
|
||||||
|
pkg-config
|
||||||
|
];
|
||||||
|
|
||||||
|
buildInputs = with pkgs; [
|
||||||
|
fcitx5
|
||||||
|
];
|
||||||
|
|
||||||
|
cmakeFlags = [
|
||||||
|
"-DSTT_STREAM_PATH=${stt-stream-hip}/bin/stt-stream"
|
||||||
|
];
|
||||||
|
|
||||||
|
postInstall = ''
|
||||||
|
mkdir -p $out/share/fcitx5/addon
|
||||||
|
mkdir -p $out/share/fcitx5/inputmethod
|
||||||
|
mkdir -p $out/lib/fcitx5
|
||||||
|
'';
|
||||||
|
};
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
packages = {
|
packages = {
|
||||||
inherit stt-stream fcitx5-stt;
|
inherit stt-stream stt-stream-hip fcitx5-stt fcitx5-stt-hip;
|
||||||
default = fcitx5-stt;
|
default = fcitx5-stt;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -95,6 +160,10 @@
|
||||||
type = "app";
|
type = "app";
|
||||||
program = "${stt-stream}/bin/stt-stream";
|
program = "${stt-stream}/bin/stt-stream";
|
||||||
};
|
};
|
||||||
|
stt-stream-hip = {
|
||||||
|
type = "app";
|
||||||
|
program = "${stt-stream-hip}/bin/stt-stream";
|
||||||
|
};
|
||||||
default = {
|
default = {
|
||||||
type = "app";
|
type = "app";
|
||||||
program = "${stt-stream}/bin/stt-stream";
|
program = "${stt-stream}/bin/stt-stream";
|
||||||
|
|
@ -110,6 +179,18 @@
|
||||||
fcitx5
|
fcitx5
|
||||||
];
|
];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
# Dev shell with ROCm/HIP for GPU development
|
||||||
|
devShells.hip = pkgs.mkShell {
|
||||||
|
inputsFrom = [ stt-stream-hip ];
|
||||||
|
packages = with pkgs; [
|
||||||
|
rust-analyzer
|
||||||
|
rustfmt
|
||||||
|
clippy
|
||||||
|
fcitx5
|
||||||
|
rocmPackages.rocminfo # For debugging GPU detection
|
||||||
|
];
|
||||||
|
};
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
// {
|
// {
|
||||||
|
|
@ -124,6 +205,15 @@
|
||||||
let
|
let
|
||||||
cfg = config.ringofstorms.sttIme;
|
cfg = config.ringofstorms.sttIme;
|
||||||
sttPkgs = self.packages.${pkgs.stdenv.hostPlatform.system};
|
sttPkgs = self.packages.${pkgs.stdenv.hostPlatform.system};
|
||||||
|
|
||||||
|
# Select the appropriate package variant based on GPU backend
|
||||||
|
sttStreamPkg =
|
||||||
|
if cfg.gpuBackend == "hip" then sttPkgs.stt-stream-hip
|
||||||
|
else sttPkgs.stt-stream;
|
||||||
|
|
||||||
|
fcitx5SttPkg =
|
||||||
|
if cfg.gpuBackend == "hip" then sttPkgs.fcitx5-stt-hip
|
||||||
|
else sttPkgs.fcitx5-stt;
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
options.ringofstorms.sttIme = {
|
options.ringofstorms.sttIme = {
|
||||||
|
|
@ -135,16 +225,26 @@
|
||||||
description = "Whisper model to use (tiny, base, small, medium, large)";
|
description = "Whisper model to use (tiny, base, small, medium, large)";
|
||||||
};
|
};
|
||||||
|
|
||||||
|
gpuBackend = lib.mkOption {
|
||||||
|
type = lib.types.enum [ "cpu" "hip" ];
|
||||||
|
default = "cpu";
|
||||||
|
description = ''
|
||||||
|
GPU backend to use for acceleration:
|
||||||
|
- cpu: CPU-only (default, works everywhere)
|
||||||
|
- hip: AMD ROCm/HIP (requires AMD GPU with ROCm support)
|
||||||
|
'';
|
||||||
|
};
|
||||||
|
|
||||||
useGpu = lib.mkOption {
|
useGpu = lib.mkOption {
|
||||||
type = lib.types.bool;
|
type = lib.types.bool;
|
||||||
default = false;
|
default = false;
|
||||||
description = "Whether to use GPU acceleration (CUDA)";
|
description = "Whether to request GPU acceleration at runtime (--gpu flag)";
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
config = lib.mkIf cfg.enable {
|
config = lib.mkIf cfg.enable {
|
||||||
# Ensure fcitx5 addon is available
|
# Ensure fcitx5 addon is available
|
||||||
i18n.inputMethod.fcitx5.addons = [ sttPkgs.fcitx5-stt ];
|
i18n.inputMethod.fcitx5.addons = [ fcitx5SttPkg ];
|
||||||
|
|
||||||
# Add STT to the Fcitx5 input method group
|
# Add STT to the Fcitx5 input method group
|
||||||
# This assumes de_plasma sets up Groups/0 with keyboard-us (0) and mozc (1)
|
# This assumes de_plasma sets up Groups/0 with keyboard-us (0) and mozc (1)
|
||||||
|
|
@ -153,12 +253,12 @@
|
||||||
};
|
};
|
||||||
|
|
||||||
# Make stt-stream available system-wide
|
# Make stt-stream available system-wide
|
||||||
environment.systemPackages = [ sttPkgs.stt-stream ];
|
environment.systemPackages = [ sttStreamPkg ];
|
||||||
|
|
||||||
# Set default model via environment
|
# Set default model via environment
|
||||||
environment.sessionVariables = {
|
environment.sessionVariables = {
|
||||||
STT_STREAM_MODEL = cfg.model;
|
STT_STREAM_MODEL = cfg.model;
|
||||||
STT_STREAM_USE_GPU = if cfg.useGpu then "1" else "0";
|
STT_STREAM_GPU = if cfg.useGpu then "1" else "0";
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
||||||
90
flakes/stt_ime/stt-stream/Cargo.lock
generated
90
flakes/stt_ime/stt-stream/Cargo.lock
generated
|
|
@ -109,25 +109,22 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bindgen"
|
name = "bindgen"
|
||||||
version = "0.69.5"
|
version = "0.71.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
|
checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 2.10.0",
|
"bitflags 2.10.0",
|
||||||
"cexpr",
|
"cexpr",
|
||||||
"clang-sys",
|
"clang-sys",
|
||||||
"itertools 0.12.1",
|
"itertools",
|
||||||
"lazy_static",
|
|
||||||
"lazycell",
|
|
||||||
"log",
|
"log",
|
||||||
"prettyplease",
|
"prettyplease",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"regex",
|
"regex",
|
||||||
"rustc-hash 1.1.0",
|
"rustc-hash",
|
||||||
"shlex",
|
"shlex",
|
||||||
"syn",
|
"syn",
|
||||||
"which",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -139,11 +136,11 @@ dependencies = [
|
||||||
"bitflags 2.10.0",
|
"bitflags 2.10.0",
|
||||||
"cexpr",
|
"cexpr",
|
||||||
"clang-sys",
|
"clang-sys",
|
||||||
"itertools 0.13.0",
|
"itertools",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"regex",
|
"regex",
|
||||||
"rustc-hash 2.1.1",
|
"rustc-hash",
|
||||||
"shlex",
|
"shlex",
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
@ -538,15 +535,6 @@ dependencies = [
|
||||||
"ureq",
|
"ureq",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "home"
|
|
||||||
version = "0.5.12"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d"
|
|
||||||
dependencies = [
|
|
||||||
"windows-sys 0.61.2",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "icu_collections"
|
name = "icu_collections"
|
||||||
version = "2.1.1"
|
version = "2.1.1"
|
||||||
|
|
@ -678,15 +666,6 @@ version = "1.70.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
|
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "itertools"
|
|
||||||
version = "0.12.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
|
|
||||||
dependencies = [
|
|
||||||
"either",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "itertools"
|
name = "itertools"
|
||||||
version = "0.13.0"
|
version = "0.13.0"
|
||||||
|
|
@ -750,12 +729,6 @@ version = "1.5.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "lazycell"
|
|
||||||
version = "1.3.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libc"
|
name = "libc"
|
||||||
version = "0.2.180"
|
version = "0.2.180"
|
||||||
|
|
@ -782,12 +755,6 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "linux-raw-sys"
|
|
||||||
version = "0.4.15"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linux-raw-sys"
|
name = "linux-raw-sys"
|
||||||
version = "0.11.0"
|
version = "0.11.0"
|
||||||
|
|
@ -1313,12 +1280,6 @@ dependencies = [
|
||||||
"realfft",
|
"realfft",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rustc-hash"
|
|
||||||
version = "1.1.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustc-hash"
|
name = "rustc-hash"
|
||||||
version = "2.1.1"
|
version = "2.1.1"
|
||||||
|
|
@ -1339,19 +1300,6 @@ dependencies = [
|
||||||
"transpose",
|
"transpose",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rustix"
|
|
||||||
version = "0.38.44"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
|
|
||||||
dependencies = [
|
|
||||||
"bitflags 2.10.0",
|
|
||||||
"errno",
|
|
||||||
"libc",
|
|
||||||
"linux-raw-sys 0.4.15",
|
|
||||||
"windows-sys 0.59.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustix"
|
name = "rustix"
|
||||||
version = "1.1.3"
|
version = "1.1.3"
|
||||||
|
|
@ -1361,7 +1309,7 @@ dependencies = [
|
||||||
"bitflags 2.10.0",
|
"bitflags 2.10.0",
|
||||||
"errno",
|
"errno",
|
||||||
"libc",
|
"libc",
|
||||||
"linux-raw-sys 0.11.0",
|
"linux-raw-sys",
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.61.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -1616,7 +1564,7 @@ dependencies = [
|
||||||
"fastrand",
|
"fastrand",
|
||||||
"getrandom 0.3.4",
|
"getrandom 0.3.4",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"rustix 1.1.3",
|
"rustix",
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.61.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -1982,34 +1930,22 @@ dependencies = [
|
||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "which"
|
|
||||||
version = "4.4.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
|
|
||||||
dependencies = [
|
|
||||||
"either",
|
|
||||||
"home",
|
|
||||||
"once_cell",
|
|
||||||
"rustix 0.38.44",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "whisper-rs"
|
name = "whisper-rs"
|
||||||
version = "0.12.0"
|
version = "0.15.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5c597ac8a9d5c4719fee232abc871da184ea50a4fea38d2d00348fd95072b2b0"
|
checksum = "71ea5d2401f30f51d08126a2d133fee4c1955136519d7ac6cf6f5ac0a91e6bc8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"whisper-rs-sys",
|
"whisper-rs-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "whisper-rs-sys"
|
name = "whisper-rs-sys"
|
||||||
version = "0.10.0"
|
version = "0.14.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d22f00ed0995463eecc34ef89905845f6bf6fd37ea70789fed180520050da8f8"
|
checksum = "b5e2a6e06e7ac7b8f53c53a5f50bb0bc823ba69b63ecd887339f807a5598bbd2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bindgen 0.69.5",
|
"bindgen 0.71.1",
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"cmake",
|
"cmake",
|
||||||
"fs_extra",
|
"fs_extra",
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ cpal = "0.15"
|
||||||
# Resampling (48k -> 16k)
|
# Resampling (48k -> 16k)
|
||||||
rubato = "0.15"
|
rubato = "0.15"
|
||||||
# Whisper inference
|
# Whisper inference
|
||||||
whisper-rs = "0.12"
|
whisper-rs = "0.15"
|
||||||
# Voice activity detection
|
# Voice activity detection
|
||||||
# Using silero via ONNX (reserved for future use)
|
# Using silero via ONNX (reserved for future use)
|
||||||
# ort = { version = "2.0.0-rc.9", default-features = false, features = ["load-dynamic"] }
|
# ort = { version = "2.0.0-rc.9", default-features = false, features = ["load-dynamic"] }
|
||||||
|
|
@ -44,6 +44,8 @@ hf-hub = "0.3"
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
cuda = ["whisper-rs/cuda"]
|
cuda = ["whisper-rs/cuda"]
|
||||||
|
hipblas = ["whisper-rs/hipblas"]
|
||||||
|
metal = ["whisper-rs/metal"]
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
lto = true
|
lto = true
|
||||||
|
|
|
||||||
|
|
@ -124,6 +124,10 @@ struct Args {
|
||||||
/// Use GPU acceleration
|
/// Use GPU acceleration
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
gpu: bool,
|
gpu: bool,
|
||||||
|
|
||||||
|
/// Number of threads for transcription (default: auto-detect)
|
||||||
|
#[arg(long)]
|
||||||
|
threads: Option<i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Events emitted to stdout as NDJSON
|
/// Events emitted to stdout as NDJSON
|
||||||
|
|
@ -288,7 +292,42 @@ async fn main() -> Result<()> {
|
||||||
let model_path = get_model_path(&args).context("Failed to get model path")?;
|
let model_path = get_model_path(&args).context("Failed to get model path")?;
|
||||||
info!("Loading Whisper model from: {}", model_path);
|
info!("Loading Whisper model from: {}", model_path);
|
||||||
|
|
||||||
let ctx_params = WhisperContextParameters::default();
|
// Configure GPU and context parameters
|
||||||
|
let mut ctx_params = WhisperContextParameters::default();
|
||||||
|
|
||||||
|
// GPU is enabled by the --gpu flag or by STT_STREAM_GPU set to "1" or "true" (any case); the env var cannot turn off --gpu
|
||||||
|
let gpu_enabled = args.gpu || std::env::var("STT_STREAM_GPU").map(|v| v == "1" || v.to_lowercase() == "true").unwrap_or(false);
|
||||||
|
|
||||||
|
ctx_params.use_gpu(gpu_enabled);
|
||||||
|
if gpu_enabled {
|
||||||
|
ctx_params.flash_attn(true); // Enable flash attention for GPU acceleration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine thread counts
|
||||||
|
let available_threads = std::thread::available_parallelism()
|
||||||
|
.map(|p| p.get() as i32)
|
||||||
|
.unwrap_or(4);
|
||||||
|
let final_threads = args.threads.unwrap_or(available_threads.min(8));
|
||||||
|
let partial_threads = (final_threads / 2).max(1);
|
||||||
|
|
||||||
|
// Log backend configuration
|
||||||
|
let gpu_feature_compiled = cfg!(feature = "hipblas") || cfg!(feature = "cuda") || cfg!(feature = "metal");
|
||||||
|
info!("Backend configuration:");
|
||||||
|
info!(" GPU requested: {}", gpu_enabled);
|
||||||
|
info!(" GPU feature compiled: {} (hipblas={}, cuda={}, metal={})",
|
||||||
|
gpu_feature_compiled,
|
||||||
|
cfg!(feature = "hipblas"),
|
||||||
|
cfg!(feature = "cuda"),
|
||||||
|
cfg!(feature = "metal")
|
||||||
|
);
|
||||||
|
info!(" Flash attention: {}", gpu_enabled);
|
||||||
|
info!(" Model: {:?}", args.model);
|
||||||
|
info!(" Threads (final/partial): {}/{}", final_threads, partial_threads);
|
||||||
|
|
||||||
|
if gpu_enabled && !gpu_feature_compiled {
|
||||||
|
warn!("GPU requested but no GPU feature compiled! Build with --features hipblas or --features cuda");
|
||||||
|
}
|
||||||
|
|
||||||
let whisper_ctx = WhisperContext::new_with_params(&model_path, ctx_params)
|
let whisper_ctx = WhisperContext::new_with_params(&model_path, ctx_params)
|
||||||
.context("Failed to load Whisper model")?;
|
.context("Failed to load Whisper model")?;
|
||||||
|
|
||||||
|
|
@ -520,10 +559,11 @@ async fn main() -> Result<()> {
|
||||||
let buffer_copy = state.buffer.clone();
|
let buffer_copy = state.buffer.clone();
|
||||||
let ctx = whisper_ctx.clone();
|
let ctx = whisper_ctx.clone();
|
||||||
let lang = language.clone();
|
let lang = language.clone();
|
||||||
|
let threads = partial_threads;
|
||||||
|
|
||||||
// Transcribe in background
|
// Transcribe in background
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
if let Ok(text) = transcribe(&ctx, &buffer_copy, &lang, false) {
|
if let Ok(text) = transcribe(&ctx, &buffer_copy, &lang, false, threads) {
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
emit_event(&SttEvent::Partial { text });
|
emit_event(&SttEvent::Partial { text });
|
||||||
}
|
}
|
||||||
|
|
@ -545,7 +585,7 @@ async fn main() -> Result<()> {
|
||||||
let lang = language.clone();
|
let lang = language.clone();
|
||||||
|
|
||||||
// Final transcription
|
// Final transcription
|
||||||
match transcribe(&ctx, &buffer_copy, &lang, true) {
|
match transcribe(&ctx, &buffer_copy, &lang, true, final_threads) {
|
||||||
Ok(text) => {
|
Ok(text) => {
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
emit_event(&SttEvent::Final { text });
|
emit_event(&SttEvent::Final { text });
|
||||||
|
|
@ -588,18 +628,28 @@ fn transcribe(
|
||||||
samples: &[f32],
|
samples: &[f32],
|
||||||
language: &str,
|
language: &str,
|
||||||
is_final: bool,
|
is_final: bool,
|
||||||
|
threads: i32,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
|
let start_time = std::time::Instant::now();
|
||||||
|
|
||||||
let ctx = ctx.lock().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
|
let ctx = ctx.lock().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
|
||||||
let mut state = ctx.create_state()?;
|
let mut state = ctx.create_state()?;
|
||||||
|
|
||||||
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
|
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
|
||||||
|
|
||||||
|
// Configure threads
|
||||||
|
params.set_n_threads(threads);
|
||||||
|
|
||||||
// Configure for speed vs accuracy
|
// Configure for speed vs accuracy
|
||||||
if is_final {
|
if is_final {
|
||||||
params.set_n_threads(4);
|
// Final transcription: balanced speed and accuracy
|
||||||
|
params.set_single_segment(false);
|
||||||
} else {
|
} else {
|
||||||
params.set_n_threads(2);
|
// Partial transcription: optimize for speed
|
||||||
params.set_no_context(true);
|
params.set_no_context(true);
|
||||||
|
params.set_single_segment(true); // Faster for streaming
|
||||||
|
params.set_no_timestamps(true); // We don't use timestamps for partials
|
||||||
|
params.set_temperature_inc(0.0); // Disable fallback retries for speed
|
||||||
}
|
}
|
||||||
|
|
||||||
params.set_language(Some(language));
|
params.set_language(Some(language));
|
||||||
|
|
@ -608,18 +658,29 @@ fn transcribe(
|
||||||
params.set_print_realtime(false);
|
params.set_print_realtime(false);
|
||||||
params.set_print_timestamps(false);
|
params.set_print_timestamps(false);
|
||||||
params.set_suppress_blank(true);
|
params.set_suppress_blank(true);
|
||||||
params.set_suppress_non_speech_tokens(true);
|
params.set_suppress_nst(true);
|
||||||
|
|
||||||
// Run inference
|
// Run inference
|
||||||
state.full(params, samples)?;
|
state.full(params, samples)?;
|
||||||
|
|
||||||
|
let inference_time = start_time.elapsed();
|
||||||
|
let audio_duration_secs = samples.len() as f32 / 16000.0;
|
||||||
|
tracing::debug!(
|
||||||
|
"Transcription took {:?} for {:.1}s audio (RTF: {:.2}x)",
|
||||||
|
inference_time,
|
||||||
|
audio_duration_secs,
|
||||||
|
inference_time.as_secs_f32() / audio_duration_secs
|
||||||
|
);
|
||||||
|
|
||||||
// Collect segments
|
// Collect segments
|
||||||
let num_segments = state.full_n_segments()?;
|
let num_segments = state.full_n_segments();
|
||||||
let mut text = String::new();
|
let mut text = String::new();
|
||||||
|
|
||||||
for i in 0..num_segments {
|
for i in 0..num_segments {
|
||||||
if let Ok(segment) = state.full_get_segment_text(i) {
|
if let Some(segment) = state.get_segment(i) {
|
||||||
text.push_str(&segment);
|
if let Ok(segment_text) = segment.to_str_lossy() {
|
||||||
|
text.push_str(&segment_text);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -75,6 +75,7 @@
|
||||||
({
|
({
|
||||||
ringofstorms.sttIme = {
|
ringofstorms.sttIme = {
|
||||||
enable = true;
|
enable = true;
|
||||||
|
gpuBackend = "hip"; # Use AMD ROCm/HIP acceleration
|
||||||
useGpu = true;
|
useGpu = true;
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue