Migrate to WhisperX for speaker diarization
Implement a sliding window audio buffer and update the transcriber to use WhisperX for transcription, alignment, and speaker identification. Update the pipeline to handle and store speaker-attributed transcripts. Additionally, update the LLM processor's reasoning parameter to "enable_thinking".
This commit is contained in:
@@ -4,6 +4,8 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.llm.models import ExtractionResult
|
||||
from src.llm.processor import LLMProcessor
|
||||
from src.stt.listener import AudioListener
|
||||
@@ -18,6 +20,12 @@ logging.basicConfig(
|
||||
logging.FileHandler("pipeline.log"),
|
||||
],
|
||||
)
|
||||
# Suppress verbose logging from STT libraries to keep the TUI clean
|
||||
logging.getLogger("whisper").setLevel(logging.WARNING)
|
||||
logging.getLogger("faster_whisper").setLevel(logging.WARNING)
|
||||
logging.getLogger("pyannote").setLevel(logging.WARNING)
|
||||
logging.getLogger("whisperx").setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -40,6 +48,13 @@ class PipelineOrchestrator:
|
||||
self.history = [] # List of strings (transcripts)
|
||||
self.history_max_words = 1000
|
||||
|
||||
# STT Sliding Window Buffer
|
||||
self.audio_buffer = [] # List of audio chunks
|
||||
self.buffer_max_seconds = 30
|
||||
self.sample_rate = 16000
|
||||
self.buffer_max_samples = self.buffer_max_seconds * self.sample_rate
|
||||
self.last_processed_end_time = 0.0
|
||||
|
||||
async def stt_worker(self):
|
||||
"""
|
||||
Worker that handles STT: Audio -> Text.
|
||||
@@ -50,12 +65,35 @@ class PipelineOrchestrator:
|
||||
# Get audio chunk from listener
|
||||
audio_chunk = await self.listener.get_chunk()
|
||||
|
||||
# Transcribe
|
||||
text = self.transcriber.transcribe(audio_chunk)
|
||||
# Maintain sliding window buffer
|
||||
self.audio_buffer.append(audio_chunk)
|
||||
current_buffer_samples = sum(len(c) for c in self.audio_buffer)
|
||||
|
||||
if text:
|
||||
logger.info(f"Transcribed: {text}")
|
||||
await self.transcript_queue.put(text)
|
||||
if current_buffer_samples > self.buffer_max_samples:
|
||||
# Remove oldest chunks until we are within the buffer limit
|
||||
while (
|
||||
sum(len(c) for c in self.audio_buffer) > self.buffer_max_samples
|
||||
):
|
||||
self.audio_buffer.pop(0)
|
||||
|
||||
# Concatenate buffer for transcription
|
||||
full_audio = np.concatenate(self.audio_buffer)
|
||||
|
||||
# Transcribe (WhisperX now returns a list of (speaker, text, start, end))
|
||||
results = self.transcriber.transcribe(full_audio)
|
||||
|
||||
# Filter for only new segments that start after the last processed segment
|
||||
new_segments = [
|
||||
res for res in results if res[2] >= self.last_processed_end_time
|
||||
]
|
||||
|
||||
if new_segments:
|
||||
for speaker, text, start, end in new_segments:
|
||||
logger.info(f"Transcribed: [{speaker}] {text}")
|
||||
await self.transcript_queue.put((speaker, text))
|
||||
self.last_processed_end_time = max(
|
||||
self.last_processed_end_time, end
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"STT Worker error: {e}")
|
||||
@@ -70,14 +108,16 @@ class PipelineOrchestrator:
|
||||
logger.info("LLM Worker started.")
|
||||
while self.is_running:
|
||||
try:
|
||||
# Get raw text from transcript queue
|
||||
raw_text = await self.transcript_queue.get()
|
||||
# Get raw text from transcript queue (now a tuple of (speaker, text))
|
||||
speaker, raw_text = await self.transcript_queue.get()
|
||||
|
||||
logger.info(f"LLM Worker: Processing text: {raw_text}")
|
||||
logger.info(f"LLM Worker: Processing text from {speaker}: {raw_text}")
|
||||
|
||||
# 1. Prepare Context (Conversation History)
|
||||
# Maintain history and truncate to max words
|
||||
self.history.append(raw_text)
|
||||
# Store as "Speaker X: [text]"
|
||||
entry = f"{speaker}: {raw_text}"
|
||||
self.history.append(entry)
|
||||
|
||||
full_history_text = " ".join(self.history)
|
||||
words = full_history_text.split()
|
||||
if len(words) > self.history_max_words:
|
||||
@@ -119,7 +159,7 @@ class PipelineOrchestrator:
|
||||
|
||||
def _get_wiki_context(self) -> str:
|
||||
"""
|
||||
Reads all files in the lore directory and returns them as a single context string.
|
||||
Reads all files in the lore directory and returns them as a 저희 context string.
|
||||
"""
|
||||
from src.persistence.lore import DATA_LORE_DIR
|
||||
|
||||
@@ -151,11 +191,12 @@ class PipelineOrchestrator:
|
||||
# Pass the proposal queue to the app.
|
||||
app = ConfirmationApp(proposal_queue=self.proposal_queue)
|
||||
await app.run_async()
|
||||
# Once the TUI exits, stop the entire pipeline
|
||||
self.stop()
|
||||
except Exception as e:
|
||||
logger.error(f"TUI Worker error: {e}")
|
||||
self.stop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
@@ -188,6 +229,6 @@ class PipelineOrchestrator:
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stops the pipeline.
|
||||
Stops.
|
||||
"""
|
||||
self.is_running = False
|
||||
|
||||
Reference in New Issue
Block a user