Migrate to WhisperX for speaker diarization

Implement a sliding window audio buffer and update the transcriber to
use WhisperX for transcription, alignment, and speaker identification.
Update the pipeline to handle and store speaker-attributed transcripts.

Additionally, rename the LLM processor's `extra_body` reasoning flag from
"include_reasoning" to "enable_thinking".
This commit is contained in:
2026-05-26 21:48:30 -07:00
parent d0fcdfab01
commit f4c98fb2b9
7 changed files with 135 additions and 38 deletions
+1
View File
@@ -0,0 +1 @@
3.12
+6
View File
@@ -0,0 +1,6 @@
def main():
    """Print the project's placeholder greeting to standard output."""
    greeting = "Hello from dnd-helpers!"
    print(greeting)


# Run the greeting only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
+7
View File
@@ -0,0 +1,7 @@
[project]
name = "dnd-helpers"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = []
+1 -1
View File
@@ -1,5 +1,5 @@
# Core dependencies for D&D Helpers # Core dependencies for D&D Helpers
faster-whisper whisperx
sounddevice sounddevice
pydantic pydantic
textual textual
+2 -2
View File
@@ -83,7 +83,7 @@ class LLMProcessor:
model=self.model, model=self.model,
messages=messages, messages=messages,
response_format=response_format, response_format=response_format,
extra_body={"include_reasoning": False}, extra_body={"enable_thinking": False},
) )
return response.choices[0].message.content return response.choices[0].message.content
except Exception as e: except Exception as e:
@@ -125,7 +125,7 @@ class LLMProcessor:
model=self.model, model=self.model,
messages=messages, messages=messages,
response_format={"type": "json_object"}, response_format={"type": "json_object"},
extra_body={"include_reasoning": False}, extra_body={"enable_thinking": False},
) )
logger.info("LLM Processor (Extract): Response received from backend.") logger.info("LLM Processor (Extract): Response received from backend.")
+54 -13
View File
@@ -4,6 +4,8 @@ import os
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
import numpy as np
from src.llm.models import ExtractionResult from src.llm.models import ExtractionResult
from src.llm.processor import LLMProcessor from src.llm.processor import LLMProcessor
from src.stt.listener import AudioListener from src.stt.listener import AudioListener
@@ -18,6 +20,12 @@ logging.basicConfig(
logging.FileHandler("pipeline.log"), logging.FileHandler("pipeline.log"),
], ],
) )
# Suppress verbose logging from STT libraries to keep the TUI clean
logging.getLogger("whisper").setLevel(logging.WARNING)
logging.getLogger("faster_whisper").setLevel(logging.WARNING)
logging.getLogger("pyannote").setLevel(logging.WARNING)
logging.getLogger("whisperx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -40,6 +48,13 @@ class PipelineOrchestrator:
self.history = [] # List of strings (transcripts) self.history = [] # List of strings (transcripts)
self.history_max_words = 1000 self.history_max_words = 1000
# STT Sliding Window Buffer
self.audio_buffer = [] # List of audio chunks
self.buffer_max_seconds = 30
self.sample_rate = 16000
self.buffer_max_samples = self.buffer_max_seconds * self.sample_rate
self.last_processed_end_time = 0.0
async def stt_worker(self): async def stt_worker(self):
""" """
Worker that handles STT: Audio -> Text. Worker that handles STT: Audio -> Text.
@@ -50,12 +65,35 @@ class PipelineOrchestrator:
# Get audio chunk from listener # Get audio chunk from listener
audio_chunk = await self.listener.get_chunk() audio_chunk = await self.listener.get_chunk()
# Transcribe # Maintain sliding window buffer
text = self.transcriber.transcribe(audio_chunk) self.audio_buffer.append(audio_chunk)
current_buffer_samples = sum(len(c) for c in self.audio_buffer)
if text: if current_buffer_samples > self.buffer_max_samples:
logger.info(f"Transcribed: {text}") # Remove oldest chunks until we are within the buffer limit
await self.transcript_queue.put(text) while (
sum(len(c) for c in self.audio_buffer) > self.buffer_max_samples
):
self.audio_buffer.pop(0)
# Concatenate buffer for transcription
full_audio = np.concatenate(self.audio_buffer)
# Transcribe (WhisperX now returns a list of (speaker, text, start, end))
results = self.transcriber.transcribe(full_audio)
# Filter for only new segments that start after the last processed segment
new_segments = [
res for res in results if res[2] >= self.last_processed_end_time
]
if new_segments:
for speaker, text, start, end in new_segments:
logger.info(f"Transcribed: [{speaker}] {text}")
await self.transcript_queue.put((speaker, text))
self.last_processed_end_time = max(
self.last_processed_end_time, end
)
except Exception as e: except Exception as e:
logger.error(f"STT Worker error: {e}") logger.error(f"STT Worker error: {e}")
@@ -70,14 +108,16 @@ class PipelineOrchestrator:
logger.info("LLM Worker started.") logger.info("LLM Worker started.")
while self.is_running: while self.is_running:
try: try:
# Get raw text from transcript queue # Get raw text from transcript queue (now a tuple of (speaker, text))
raw_text = await self.transcript_queue.get() speaker, raw_text = await self.transcript_queue.get()
logger.info(f"LLM Worker: Processing text: {raw_text}") logger.info(f"LLM Worker: Processing text from {speaker}: {raw_text}")
# 1. Prepare Context (Conversation History) # 1. Prepare Context (Conversation History)
# Maintain history and truncate to max words # Store as "Speaker X: [text]"
self.history.append(raw_text) entry = f"{speaker}: {raw_text}"
self.history.append(entry)
full_history_text = " ".join(self.history) full_history_text = " ".join(self.history)
words = full_history_text.split() words = full_history_text.split()
if len(words) > self.history_max_words: if len(words) > self.history_max_words:
@@ -119,7 +159,7 @@ class PipelineOrchestrator:
def _get_wiki_context(self) -> str: def _get_wiki_context(self) -> str:
""" """
Reads all files in the lore directory and returns them as a single context string. Reads all files in the lore directory and returns them as a single context string.
""" """
from src.persistence.lore import DATA_LORE_DIR from src.persistence.lore import DATA_LORE_DIR
@@ -151,11 +191,12 @@ class PipelineOrchestrator:
# Pass the proposal queue to the app. # Pass the proposal queue to the app.
app = ConfirmationApp(proposal_queue=self.proposal_queue) app = ConfirmationApp(proposal_queue=self.proposal_queue)
await app.run_async() await app.run_async()
# Once the TUI exits, stop the entire pipeline
self.stop() self.stop()
except Exception as e: except Exception as e:
logger.error(f"TUI Worker error: {e}") logger.error(f"TUI Worker error: {e}")
self.stop() self.stop()
except asyncio.CancelledError:
pass
async def run(self): async def run(self):
""" """
@@ -188,6 +229,6 @@ class PipelineOrchestrator:
def stop(self): def stop(self):
""" """
Stops the pipeline. Stops the pipeline.
""" """
self.is_running = False self.is_running = False
+64 -22
View File
@@ -1,6 +1,9 @@
import logging import logging
import os
from faster_whisper import WhisperModel import numpy as np
import whisperx
from whisperx.diarize import DiarizationPipeline
# Do not call basicConfig here, as it's called in the orchestrator # Do not call basicConfig here, as it's called in the orchestrator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -8,62 +11,101 @@ logger = logging.getLogger(__name__)
class Transcriber: class Transcriber:
""" """
Converts audio chunks (numpy arrays) into text using faster-whisper. Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX.
""" """
def __init__(self, model_size="base", device="cpu", compute_type="int8"): def __init__(
self, model_size="base", device="cpu", compute_type="int8", language="en"
):
""" """
Initializes the faster-whisper model. Initializes the WhisperX model and diarization pipeline.
Args: Args:
model_size (str): The size of the model to use (e.g., "tiny", "base", "small"). model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
device (str): The device to run the model on ("cpu" or "cuda"). device (str): The device to run the model on ("cpu" or "cuda").
compute_type (str): The compute type to use (e.g., "int8", "float16"). compute_type (str): The compute type to use (e.g., "int8", "float16").
language (str): The language code for alignment (e.g., "en").
""" """
self.device = device
self.compute_type = compute_type
self.language = language
logger.info( logger.info(
f"Loading faster-whisper model: {model_size} on {device} ({compute_type})..." f"Loading WhisperX model: {model_size} on {device} ({compute_type})..."
) )
try: try:
self.model = WhisperModel( # Load transcription model
self.model = whisperx.load_model(
model_size, device=device, compute_type=compute_type model_size, device=device, compute_type=compute_type
) )
logger.info("Model loaded successfully.")
# Load alignment model (required for accurate speaker assignment)
# model_dir=None allows automatic model selection based on the language
self.align_model, self.align_metadata = whisperx.load_align_model(
device=device, model_dir=None, language_code=self.language
)
self.diarize_model = DiarizationPipeline()
logger.info("WhisperX and Diarization models loaded successfully.")
except Exception as e: except Exception as e:
logger.error(f"Failed to load faster-whisper model: {e}") logger.error(f"Failed to load WhisperX models: {e}")
raise raise
def transcribe(self, audio_chunk): def transcribe(self, audio_chunk):
""" """
Transcribes a single audio chunk. Transcribes an audio chunk and performs speaker diarization.
Args: Args:
audio_chunk (np.ndarray): The audio data as a numpy array. audio_chunk (np.ndarray): The audio data as a numpy array.
Returns: Returns:
str: The transcribed text. list: A list of tuples (speaker_id, text, start, end).
""" """
if audio_chunk is None: if audio_chunk is None:
return "" return []
try: try:
# faster-whisper expects audio in float32 and 1D array # WhisperX expects audio in float32 and 1D array
audio_data = audio_chunk.astype("float32").flatten() audio = audio_chunk.astype("float32").flatten()
# Transcribe the audio # 1. Perform transcription
segments, info = self.model.transcribe(audio_data, beam_size=5) # batch_size is set to 16 for efficiency; can be adjusted based on VRAM
result = self.model.transcribe(audio, batch_size=16)
# Combine segments into a single string # 2. Perform alignment
text = " ".join([segment.text.strip() for segment in segments]) # Alignment is necessary for the assign_word_speakers step
result_a = whisperx.align(
result["segments"],
self.align_model,
self.align_metadata,
audio,
self.device,
)
return text.strip() # 3. Perform diarization
diarize_segments = self.diarize_model(audio)
# 4. Align transcription segments with speakers
result_final = whisperx.assign_word_speakers(diarize_segments, result_a)
# Extract (speaker_id, text, start, end) tuples from the final result
output = []
for segment in result_final.get("segments", []):
speaker = segment.get("speaker", "Unknown")
text = segment.get("text", "").strip()
start = segment.get("start", 0.0)
end = segment.get("end", 0.0)
if text:
output.append((speaker, text, start, end))
return output
except Exception as e: except Exception as e:
logger.error(f"Transcription error: {e}") logger.error(f"Transcription/Diarization error: {e}")
return "" return []
def close(self): def close(self):
""" """
Explicitly release model resources if necessary. Explicitly release model resources if necessary.
""" """
# faster-whisper's WhisperModel doesn't have a standard close(),
# but we'll provide this for consistency.
pass pass