From f4c98fb2b9e959cb06aeb864c8737935836275aa Mon Sep 17 00:00:00 2001 From: charles Date: Tue, 26 May 2026 21:48:30 -0700 Subject: [PATCH] Migrate to WhisperX for speaker diarization Implement a sliding window audio buffer and update the transcriber to use WhisperX for transcription, alignment, and speaker identification. Update the pipeline to handle and store speaker-attributed transcripts. Additionally, update the LLM processor's reasoning parameter to "enable_thinking". --- .python-version | 1 + main.py | 6 +++ pyproject.toml | 7 +++ requirements.txt | 2 +- src/llm/processor.py | 4 +- src/pipeline/orchestrator.py | 67 ++++++++++++++++++++++------ src/stt/transcriber.py | 86 +++++++++++++++++++++++++++--------- 7 files changed, 135 insertions(+), 38 deletions(-) create mode 100644 .python-version create mode 100644 main.py create mode 100644 pyproject.toml diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/main.py b/main.py new file mode 100644 index 0000000..dc3638d --- /dev/null +++ b/main.py @@ -0,0 +1,6 @@ +def main(): + print("Hello from dnd-helpers!") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..910a652 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[project] +name = "dnd-helpers" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [] diff --git a/requirements.txt b/requirements.txt index ab5a4a2..e1db5a1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # Core dependencies for D&D Helpers -faster-whisper +whisperx sounddevice pydantic textual diff --git a/src/llm/processor.py b/src/llm/processor.py index d973614..99fc04c 100644 --- a/src/llm/processor.py +++ b/src/llm/processor.py @@ -83,7 +83,7 @@ class LLMProcessor: model=self.model, messages=messages, 
response_format=response_format, - extra_body={"include_reasoning": False}, + extra_body={"enable_thinking": False}, ) return response.choices[0].message.content except Exception as e: @@ -125,7 +125,7 @@ class LLMProcessor: model=self.model, messages=messages, response_format={"type": "json_object"}, - extra_body={"include_reasoning": False}, + extra_body={"enable_thinking": False}, ) logger.info("LLM Processor (Extract): Response received from backend.") diff --git a/src/pipeline/orchestrator.py b/src/pipeline/orchestrator.py index b9ddd09..806b3e7 100644 --- a/src/pipeline/orchestrator.py +++ b/src/pipeline/orchestrator.py @@ -4,6 +4,8 @@ import os from pathlib import Path from typing import List, Optional +import numpy as np + from src.llm.models import ExtractionResult from src.llm.processor import LLMProcessor from src.stt.listener import AudioListener @@ -18,6 +20,12 @@ logging.basicConfig( logging.FileHandler("pipeline.log"), ], ) +# Suppress verbose logging from STT libraries to keep the TUI clean +logging.getLogger("whisper").setLevel(logging.WARNING) +logging.getLogger("faster_whisper").setLevel(logging.WARNING) +logging.getLogger("pyannote").setLevel(logging.WARNING) +logging.getLogger("whisperx").setLevel(logging.WARNING) + logger = logging.getLogger(__name__) @@ -40,6 +48,13 @@ class PipelineOrchestrator: self.history = [] # List of strings (transcripts) self.history_max_words = 1000 + # STT Sliding Window Buffer + self.audio_buffer = [] # List of audio chunks + self.buffer_max_seconds = 30 + self.sample_rate = 16000 + self.buffer_max_samples = self.buffer_max_seconds * self.sample_rate + self.last_processed_end_time = 0.0 + async def stt_worker(self): """ Worker that handles STT: Audio -> Text. 
@@ -50,12 +65,35 @@ class PipelineOrchestrator: # Get audio chunk from listener audio_chunk = await self.listener.get_chunk() - # Transcribe - text = self.transcriber.transcribe(audio_chunk) + # Maintain sliding window buffer + self.audio_buffer.append(audio_chunk) + current_buffer_samples = sum(len(c) for c in self.audio_buffer) - if text: - logger.info(f"Transcribed: {text}") - await self.transcript_queue.put(text) + if current_buffer_samples > self.buffer_max_samples: + # Remove oldest chunks until we are within the buffer limit + while ( + sum(len(c) for c in self.audio_buffer) > self.buffer_max_samples + ): + self.audio_buffer.pop(0) + + # Concatenate buffer for transcription + full_audio = np.concatenate(self.audio_buffer) + + # Transcribe (WhisperX now returns a list of (speaker, text, start, end)) + results = self.transcriber.transcribe(full_audio) + + # Filter for only new segments that start after the last processed segment + new_segments = [ + res for res in results if res[2] >= self.last_processed_end_time + ] + + if new_segments: + for speaker, text, start, end in new_segments: + logger.info(f"Transcribed: [{speaker}] {text}") + await self.transcript_queue.put((speaker, text)) + self.last_processed_end_time = max( + self.last_processed_end_time, end + ) except Exception as e: logger.error(f"STT Worker error: {e}") @@ -70,14 +108,16 @@ class PipelineOrchestrator: logger.info("LLM Worker started.") while self.is_running: try: - # Get raw text from transcript queue - raw_text = await self.transcript_queue.get() + # Get raw text from transcript queue (now a tuple of (speaker, text)) + speaker, raw_text = await self.transcript_queue.get() - logger.info(f"LLM Worker: Processing text: {raw_text}") + logger.info(f"LLM Worker: Processing text from {speaker}: {raw_text}") # 1. 
Prepare Context (Conversation History) - # Maintain history and truncate to max words - self.history.append(raw_text) + # Store as "Speaker X: [text]" + entry = f"{speaker}: {raw_text}" + self.history.append(entry) + full_history_text = " ".join(self.history) words = full_history_text.split() if len(words) > self.history_max_words: @@ -119,7 +159,7 @@ class PipelineOrchestrator: def _get_wiki_context(self) -> str: """ - Reads all files in the lore directory and returns them as a single context string. + Reads all files in the lore directory and returns them as a single context string. """ from src.persistence.lore import DATA_LORE_DIR @@ -151,11 +191,12 @@ class PipelineOrchestrator: # Pass the proposal queue to the app. app = ConfirmationApp(proposal_queue=self.proposal_queue) await app.run_async() - # Once the TUI exits, stop the entire pipeline self.stop() except Exception as e: logger.error(f"TUI Worker error: {e}") self.stop() + except asyncio.CancelledError: + pass async def run(self): """ @@ -188,6 +229,6 @@ class PipelineOrchestrator: def stop(self): """ - Stops the pipeline. + Stops the pipeline. """ self.is_running = False diff --git a/src/stt/transcriber.py b/src/stt/transcriber.py index 18a276b..bfabd95 100644 --- a/src/stt/transcriber.py +++ b/src/stt/transcriber.py @@ -1,6 +1,9 @@ import logging +import os -from faster_whisper import WhisperModel +import numpy as np +import whisperx +from whisperx.diarize import DiarizationPipeline # Do not call basicConfig here, as it's called in the orchestrator logger = logging.getLogger(__name__) @@ -8,62 +11,101 @@ logger = logging.getLogger(__name__) class Transcriber: """ - Converts audio chunks (numpy arrays) into text using faster-whisper. + Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX. 
""" - def __init__(self, model_size="base", device="cpu", compute_type="int8"): + def __init__( + self, model_size="base", device="cpu", compute_type="int8", language="en" + ): """ - Initializes the faster-whisper model. + Initializes the WhisperX model and diarization pipeline. Args: model_size (str): The size of the model to use (e.g., "tiny", "base", "small"). device (str): The device to run the model on ("cpu" or "cuda"). compute_type (str): The compute type to use (e.g., "int8", "float16"). + language (str): The language code for alignment (e.g., "en"). """ + self.device = device + self.compute_type = compute_type + self.language = language + logger.info( - f"Loading faster-whisper model: {model_size} on {device} ({compute_type})..." + f"Loading WhisperX model: {model_size} on {device} ({compute_type})..." ) try: - self.model = WhisperModel( + # Load transcription model + self.model = whisperx.load_model( model_size, device=device, compute_type=compute_type ) - logger.info("Model loaded successfully.") + + # Load alignment model (required for accurate speaker assignment) + # model_dir=None allows automatic model selection based on the language + self.align_model, self.align_metadata = whisperx.load_align_model( + device=device, model_dir=None, language_code=self.language + ) + + self.diarize_model = DiarizationPipeline() + + logger.info("WhisperX and Diarization models loaded successfully.") except Exception as e: - logger.error(f"Failed to load faster-whisper model: {e}") + logger.error(f"Failed to load WhisperX models: {e}") raise def transcribe(self, audio_chunk): """ - Transcribes a single audio chunk. + Transcribes an audio chunk and performs speaker diarization. Args: audio_chunk (np.ndarray): The audio data as a numpy array. Returns: - str: The transcribed text. + list: A list of tuples (speaker_id, text, start, end). 
""" if audio_chunk is None: - return "" + return [] try: - # faster-whisper expects audio in float32 and 1D array - audio_data = audio_chunk.astype("float32").flatten() + # WhisperX expects audio in float32 and 1D array + audio = audio_chunk.astype("float32").flatten() - # Transcribe the audio - segments, info = self.model.transcribe(audio_data, beam_size=5) + # 1. Perform transcription + # batch_size is set to 16 for efficiency; can be adjusted based on VRAM + result = self.model.transcribe(audio, batch_size=16) - # Combine segments into a single string - text = " ".join([segment.text.strip() for segment in segments]) + # 2. Perform alignment + # Alignment is necessary for the assign_word_speakers step + result_a = whisperx.align( + result["segments"], + self.align_model, + self.align_metadata, + audio, + self.device, + ) - return text.strip() + # 3. Perform diarization + diarize_segments = self.diarize_model(audio) + + # 4. Align transcription segments with speakers + result_final = whisperx.assign_word_speakers(diarize_segments, result_a) + + # Extract (speaker_id, text, start, end) tuples from the final result + output = [] + for segment in result_final.get("segments", []): + speaker = segment.get("speaker", "Unknown") + text = segment.get("text", "").strip() + start = segment.get("start", 0.0) + end = segment.get("end", 0.0) + if text: + output.append((speaker, text, start, end)) + + return output except Exception as e: - logger.error(f"Transcription error: {e}") - return "" + logger.error(f"Transcription/Diarization error: {e}") + return [] def close(self): """ Explicitly release model resources if necessary. """ - # faster-whisper's WhisperModel doesn't have a standard close(), - # but we'll provide this for consistency. pass