Migrate to WhisperX for speaker diarization
Implement a sliding window audio buffer and update the transcriber to use WhisperX for transcription, alignment, and speaker identification. Update the pipeline to handle and store speaker-attributed transcripts. Additionally, update the LLM processor's reasoning parameter to "enable_thinking".
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
3.12
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
def main():
|
||||||
|
print("Hello from dnd-helpers!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
[project]
|
||||||
|
name = "dnd-helpers"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = []
|
||||||
+1
-1
@@ -1,5 +1,5 @@
|
|||||||
# Core dependencies for D&D Helpers
|
# Core dependencies for D&D Helpers
|
||||||
faster-whisper
|
whisperx
|
||||||
sounddevice
|
sounddevice
|
||||||
pydantic
|
pydantic
|
||||||
textual
|
textual
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ class LLMProcessor:
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
extra_body={"include_reasoning": False},
|
extra_body={"enable_thinking": False},
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -125,7 +125,7 @@ class LLMProcessor:
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
response_format={"type": "json_object"},
|
response_format={"type": "json_object"},
|
||||||
extra_body={"include_reasoning": False},
|
extra_body={"enable_thinking": False},
|
||||||
)
|
)
|
||||||
logger.info("LLM Processor (Extract): Response received from backend.")
|
logger.info("LLM Processor (Extract): Response received from backend.")
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from src.llm.models import ExtractionResult
|
from src.llm.models import ExtractionResult
|
||||||
from src.llm.processor import LLMProcessor
|
from src.llm.processor import LLMProcessor
|
||||||
from src.stt.listener import AudioListener
|
from src.stt.listener import AudioListener
|
||||||
@@ -18,6 +20,12 @@ logging.basicConfig(
|
|||||||
logging.FileHandler("pipeline.log"),
|
logging.FileHandler("pipeline.log"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
# Suppress verbose logging from STT libraries to keep the TUI clean
|
||||||
|
logging.getLogger("whisper").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("faster_whisper").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("pyannote").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("whisperx").setLevel(logging.WARNING)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -40,6 +48,13 @@ class PipelineOrchestrator:
|
|||||||
self.history = [] # List of strings (transcripts)
|
self.history = [] # List of strings (transcripts)
|
||||||
self.history_max_words = 1000
|
self.history_max_words = 1000
|
||||||
|
|
||||||
|
# STT Sliding Window Buffer
|
||||||
|
self.audio_buffer = [] # List of audio chunks
|
||||||
|
self.buffer_max_seconds = 30
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.buffer_max_samples = self.buffer_max_seconds * self.sample_rate
|
||||||
|
self.last_processed_end_time = 0.0
|
||||||
|
|
||||||
async def stt_worker(self):
|
async def stt_worker(self):
|
||||||
"""
|
"""
|
||||||
Worker that handles STT: Audio -> Text.
|
Worker that handles STT: Audio -> Text.
|
||||||
@@ -50,12 +65,35 @@ class PipelineOrchestrator:
|
|||||||
# Get audio chunk from listener
|
# Get audio chunk from listener
|
||||||
audio_chunk = await self.listener.get_chunk()
|
audio_chunk = await self.listener.get_chunk()
|
||||||
|
|
||||||
# Transcribe
|
# Maintain sliding window buffer
|
||||||
text = self.transcriber.transcribe(audio_chunk)
|
self.audio_buffer.append(audio_chunk)
|
||||||
|
current_buffer_samples = sum(len(c) for c in self.audio_buffer)
|
||||||
|
|
||||||
if text:
|
if current_buffer_samples > self.buffer_max_samples:
|
||||||
logger.info(f"Transcribed: {text}")
|
# Remove oldest chunks until we are within the buffer limit
|
||||||
await self.transcript_queue.put(text)
|
while (
|
||||||
|
sum(len(c) for c in self.audio_buffer) > self.buffer_max_samples
|
||||||
|
):
|
||||||
|
self.audio_buffer.pop(0)
|
||||||
|
|
||||||
|
# Concatenate buffer for transcription
|
||||||
|
full_audio = np.concatenate(self.audio_buffer)
|
||||||
|
|
||||||
|
# Transcribe (WhisperX now returns a list of (speaker, text, start, end))
|
||||||
|
results = self.transcriber.transcribe(full_audio)
|
||||||
|
|
||||||
|
# Filter for only new segments that start after the last processed segment
|
||||||
|
new_segments = [
|
||||||
|
res for res in results if res[2] >= self.last_processed_end_time
|
||||||
|
]
|
||||||
|
|
||||||
|
if new_segments:
|
||||||
|
for speaker, text, start, end in new_segments:
|
||||||
|
logger.info(f"Transcribed: [{speaker}] {text}")
|
||||||
|
await self.transcript_queue.put((speaker, text))
|
||||||
|
self.last_processed_end_time = max(
|
||||||
|
self.last_processed_end_time, end
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"STT Worker error: {e}")
|
logger.error(f"STT Worker error: {e}")
|
||||||
@@ -70,14 +108,16 @@ class PipelineOrchestrator:
|
|||||||
logger.info("LLM Worker started.")
|
logger.info("LLM Worker started.")
|
||||||
while self.is_running:
|
while self.is_running:
|
||||||
try:
|
try:
|
||||||
# Get raw text from transcript queue
|
# Get raw text from transcript queue (now a tuple of (speaker, text))
|
||||||
raw_text = await self.transcript_queue.get()
|
speaker, raw_text = await self.transcript_queue.get()
|
||||||
|
|
||||||
logger.info(f"LLM Worker: Processing text: {raw_text}")
|
logger.info(f"LLM Worker: Processing text from {speaker}: {raw_text}")
|
||||||
|
|
||||||
# 1. Prepare Context (Conversation History)
|
# 1. Prepare Context (Conversation History)
|
||||||
# Maintain history and truncate to max words
|
# Store as "Speaker X: [text]"
|
||||||
self.history.append(raw_text)
|
entry = f"{speaker}: {raw_text}"
|
||||||
|
self.history.append(entry)
|
||||||
|
|
||||||
full_history_text = " ".join(self.history)
|
full_history_text = " ".join(self.history)
|
||||||
words = full_history_text.split()
|
words = full_history_text.split()
|
||||||
if len(words) > self.history_max_words:
|
if len(words) > self.history_max_words:
|
||||||
@@ -119,7 +159,7 @@ class PipelineOrchestrator:
|
|||||||
|
|
||||||
def _get_wiki_context(self) -> str:
|
def _get_wiki_context(self) -> str:
|
||||||
"""
|
"""
|
||||||
Reads all files in the lore directory and returns them as a single context string.
|
Reads all files in the lore directory and returns them as a 저희 context string.
|
||||||
"""
|
"""
|
||||||
from src.persistence.lore import DATA_LORE_DIR
|
from src.persistence.lore import DATA_LORE_DIR
|
||||||
|
|
||||||
@@ -151,11 +191,12 @@ class PipelineOrchestrator:
|
|||||||
# Pass the proposal queue to the app.
|
# Pass the proposal queue to the app.
|
||||||
app = ConfirmationApp(proposal_queue=self.proposal_queue)
|
app = ConfirmationApp(proposal_queue=self.proposal_queue)
|
||||||
await app.run_async()
|
await app.run_async()
|
||||||
# Once the TUI exits, stop the entire pipeline
|
|
||||||
self.stop()
|
self.stop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"TUI Worker error: {e}")
|
logger.error(f"TUI Worker error: {e}")
|
||||||
self.stop()
|
self.stop()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""
|
"""
|
||||||
@@ -188,6 +229,6 @@ class PipelineOrchestrator:
|
|||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""
|
"""
|
||||||
Stops the pipeline.
|
Stops.
|
||||||
"""
|
"""
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|||||||
+64
-22
@@ -1,6 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from faster_whisper import WhisperModel
|
import numpy as np
|
||||||
|
import whisperx
|
||||||
|
from whisperx.diarize import DiarizationPipeline
|
||||||
|
|
||||||
# Do not call basicConfig here, as it's called in the orchestrator
|
# Do not call basicConfig here, as it's called in the orchestrator
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -8,62 +11,101 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class Transcriber:
|
class Transcriber:
|
||||||
"""
|
"""
|
||||||
Converts audio chunks (numpy arrays) into text using faster-whisper.
|
Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_size="base", device="cpu", compute_type="int8"):
|
def __init__(
|
||||||
|
self, model_size="base", device="cpu", compute_type="int8", language="en"
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the faster-whisper model.
|
Initializes the WhisperX model and diarization pipeline.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
|
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
|
||||||
device (str): The device to run the model on ("cpu" or "cuda").
|
device (str): The device to run the model on ("cpu" or "cuda").
|
||||||
compute_type (str): The compute type to use (e.g., "int8", "float16").
|
compute_type (str): The compute type to use (e.g., "int8", "float16").
|
||||||
|
language (str): The language code for alignment (e.g., "en").
|
||||||
"""
|
"""
|
||||||
|
self.device = device
|
||||||
|
self.compute_type = compute_type
|
||||||
|
self.language = language
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loading faster-whisper model: {model_size} on {device} ({compute_type})..."
|
f"Loading WhisperX model: {model_size} on {device} ({compute_type})..."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self.model = WhisperModel(
|
# Load transcription model
|
||||||
|
self.model = whisperx.load_model(
|
||||||
model_size, device=device, compute_type=compute_type
|
model_size, device=device, compute_type=compute_type
|
||||||
)
|
)
|
||||||
logger.info("Model loaded successfully.")
|
|
||||||
|
# Load alignment model (required for accurate speaker assignment)
|
||||||
|
# model_dir=None allows automatic model selection based on the language
|
||||||
|
self.align_model, self.align_metadata = whisperx.load_align_model(
|
||||||
|
device=device, model_dir=None, language_code=self.language
|
||||||
|
)
|
||||||
|
|
||||||
|
self.diarize_model = DiarizationPipeline()
|
||||||
|
|
||||||
|
logger.info("WhisperX and Diarization models loaded successfully.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load faster-whisper model: {e}")
|
logger.error(f"Failed to load WhisperX models: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def transcribe(self, audio_chunk):
|
def transcribe(self, audio_chunk):
|
||||||
"""
|
"""
|
||||||
Transcribes a single audio chunk.
|
Transcribes an audio chunk and performs speaker diarization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_chunk (np.ndarray): The audio data as a numpy array.
|
audio_chunk (np.ndarray): The audio data as a numpy array.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The transcribed text.
|
list: A list of tuples (speaker_id, text).
|
||||||
"""
|
"""
|
||||||
if audio_chunk is None:
|
if audio_chunk is None:
|
||||||
return ""
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# faster-whisper expects audio in float32 and 1D array
|
# WhisperX expects audio in float32 and 1D array
|
||||||
audio_data = audio_chunk.astype("float32").flatten()
|
audio = audio_chunk.astype("float32").flatten()
|
||||||
|
|
||||||
# Transcribe the audio
|
# 1. Perform transcription
|
||||||
segments, info = self.model.transcribe(audio_data, beam_size=5)
|
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
|
||||||
|
result = self.model.transcribe(audio, batch_size=16)
|
||||||
|
|
||||||
# Combine segments into a single string
|
# 2. Perform alignment
|
||||||
text = " ".join([segment.text.strip() for segment in segments])
|
# Alignment is necessary for the assign_words_to_speakers step
|
||||||
|
result_a = whisperx.align(
|
||||||
|
result["segments"],
|
||||||
|
self.align_model,
|
||||||
|
self.align_metadata,
|
||||||
|
audio,
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
|
||||||
return text.strip()
|
# 3. Perform diarization
|
||||||
|
diarize_segments = self.diarize_model(audio)
|
||||||
|
|
||||||
|
# 4. Align transcription segments with speakers
|
||||||
|
result_final = whisperx.assign_word_speakers(diarize_segments, result_a)
|
||||||
|
|
||||||
|
# Extract (speaker_id, text, start, end) tuples from the final result
|
||||||
|
output = []
|
||||||
|
for segment in result_final.get("segments", []):
|
||||||
|
speaker = segment.get("speaker", "Unknown")
|
||||||
|
text = segment.get("text", "").strip()
|
||||||
|
start = segment.get("start", 0.0)
|
||||||
|
end = segment.get("end", 0.0)
|
||||||
|
if text:
|
||||||
|
output.append((speaker, text, start, end))
|
||||||
|
|
||||||
|
return output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Transcription error: {e}")
|
logger.error(f"Transcription/Diarization error: {e}")
|
||||||
return ""
|
return []
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""
|
"""
|
||||||
Explicitly release model resources if necessary.
|
Explicitly release model resources if necessary.
|
||||||
"""
|
"""
|
||||||
# faster-whisper's WhisperModel doesn't have a standard close(),
|
|
||||||
# but we'll provide this for consistency.
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user