Stable state
This commit is contained in:
+116
-132
@@ -6,8 +6,16 @@ from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.llm.models import ContextUpdate, ExtractionResult
|
||||
from src.llm.models import (
|
||||
CharacterStateUpdate,
|
||||
ContextUpdate,
|
||||
ExtractionResult,
|
||||
LoreUpdate,
|
||||
)
|
||||
from src.llm.processor import LLMProcessor
|
||||
from src.llm.prompts import EXTRACTION_SYSTEM_PROMPT, NOISE_FILTER_SYSTEM_PROMPT
|
||||
from src.persistence.characters import update_character_state
|
||||
from src.persistence.lore import update_lore
|
||||
from src.rag.manager import RAGManager
|
||||
from src.stt.listener import AudioListener
|
||||
from src.stt.transcriber import Transcriber
|
||||
@@ -41,9 +49,10 @@ class PipelineOrchestrator:
|
||||
self.rag_manager = RAGManager()
|
||||
|
||||
# Queues
|
||||
self.transcript_queue = asyncio.Queue()
|
||||
self.proposal_queue = asyncio.Queue()
|
||||
self.context_queue = asyncio.Queue()
|
||||
self.stt_to_clean_queue = asyncio.Queue()
|
||||
self.ui_to_llm_queue = asyncio.Queue()
|
||||
self.clean_to_llm_queue = asyncio.Queue()
|
||||
self.llm_to_ui_queue = asyncio.Queue()
|
||||
|
||||
self.is_running = False
|
||||
|
||||
@@ -58,6 +67,20 @@ class PipelineOrchestrator:
|
||||
self.buffer_max_samples = self.buffer_max_seconds * self.sample_rate
|
||||
self.last_processed_end_time = 0.0
|
||||
|
||||
def _get_combined_context(self) -> str:
    """
    Returns the trimmed conversation history as a context string.

    The entries of ``self.history`` are joined with single spaces. When the
    joined text holds more than ``self.history_max_words``
    whitespace-separated words, only the most recent words are kept
    (re-joined with single spaces); otherwise the joined text is used
    unchanged.

    Returns:
        The history text prefixed with a "Conversation History:" header
        line and followed by a blank line.
    """
    full_history_text = " ".join(self.history)
    words = full_history_text.split()
    limit = self.history_max_words

    if limit <= 0:
        # words[-0:] is the whole list, so a zero limit would keep every
        # word instead of none, and a negative limit would drop the oldest
        # words rather than trim to a size. Treat both as "keep nothing".
        # NOTE(review): confirm 0 is not used elsewhere to mean "unlimited".
        context_text = ""
    elif len(words) > limit:
        # Keep only the last `limit` words.
        context_text = " ".join(words[-limit:])
    else:
        context_text = full_history_text

    return f"Conversation History:\n{context_text}\n\n"
|
||||
|
||||
async def stt_worker(self):
|
||||
"""
|
||||
Worker that handles STT: Audio -> Text.
|
||||
@@ -93,7 +116,7 @@ class PipelineOrchestrator:
|
||||
if new_segments:
|
||||
for speaker, text, start, end in new_segments:
|
||||
logger.info(f"Transcribed: [{speaker}] {text}")
|
||||
await self.transcript_queue.put((speaker, text))
|
||||
await self.stt_to_clean_queue.put((speaker, text))
|
||||
self.last_processed_end_time = max(
|
||||
self.last_processed_end_time, end
|
||||
)
|
||||
@@ -104,104 +127,104 @@ class PipelineOrchestrator:
|
||||
# Small sleep to prevent tight loop if get_chunk is fast
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def llm_worker(self):
|
||||
async def clean_worker(self):
|
||||
"""
|
||||
Worker that handles LLM: Text -> Proposal.
|
||||
Worker that handles Text Cleaning: Raw STT -> Filtered Text.
|
||||
"""
|
||||
logger.info("LLM Worker started.")
|
||||
logger.info("Clean Worker started.")
|
||||
while self.is_running:
|
||||
try:
|
||||
# Get raw text from transcript queue (now a tuple of (speaker, text))
|
||||
speaker, raw_text = await self.transcript_queue.get()
|
||||
# Get raw transcript from STT
|
||||
speaker, raw_text = await self.stt_to_clean_queue.get()
|
||||
logger.info(f"Clean Worker: Filtering text from {speaker}: {raw_text}")
|
||||
|
||||
logger.info(f"LLM Worker: Processing text from {speaker}: {raw_text}")
|
||||
# RAG Retrieval for context
|
||||
context = await asyncio.to_thread(self.rag_manager.retrieve, raw_text)
|
||||
|
||||
# 1. Prepare Context (Conversation History)
|
||||
# Store as "Speaker X: [text]"
|
||||
entry = f"{speaker}: {raw_text}"
|
||||
self.history.append(entry)
|
||||
|
||||
full_history_text = " ".join(self.history)
|
||||
words = full_history_text.split()
|
||||
if len(words) > self.history_max_words:
|
||||
# Keep the last N words
|
||||
kept_words = words[-self.history_max_words :]
|
||||
context_text = " ".join(kept_words)
|
||||
else:
|
||||
context_text = full_history_text
|
||||
|
||||
# 2. Prepare Context (Wiki / Database of Knowledge)
|
||||
# wiki_context = self._get_wiki_context()
|
||||
|
||||
# Combine both
|
||||
combined_context = f"Conversation History:\n{context_text}\n\n"
|
||||
|
||||
# --- New RAG Flow ---
|
||||
# a. Filter transcript first to get cleaned text
|
||||
# Filtering using the processor
|
||||
filter_result = await asyncio.to_thread(
|
||||
self.processor.filter_transcript, raw_text, context=combined_context
|
||||
self.processor.filter_transcript,
|
||||
raw_text,
|
||||
context=context,
|
||||
)
|
||||
|
||||
# b. Use filtered text to retrieve relevant snippets from RAG
|
||||
rag_snippets = []
|
||||
# Push filtered text to LLM queue
|
||||
if filter_result.filtered_text:
|
||||
try:
|
||||
snippets = await asyncio.to_thread(
|
||||
self.rag_manager.retrieve,
|
||||
filter_result.filtered_text,
|
||||
summarize=True,
|
||||
)
|
||||
rag_snippets = snippets
|
||||
except Exception as e:
|
||||
logger.error(f"RAG Retrieval Error in llm_worker: {e}")
|
||||
|
||||
# c. Combine RAG snippets with existing combined_context
|
||||
logger.info(f"LLM Processor (Extract): rag_snippets: {rag_snippets}")
|
||||
rag_context_text = "\n".join([s.snippet for s in rag_snippets])
|
||||
augmented_context = combined_context
|
||||
if rag_context_text:
|
||||
augmented_context += (
|
||||
f"\n\nRelevant RAG Context:\n{rag_context_text}"
|
||||
await self.clean_to_llm_queue.put(
|
||||
(speaker, filter_result.filtered_text)
|
||||
)
|
||||
logger.info(f"Clean Worker: Pushed filtered text to LLM queue.")
|
||||
else:
|
||||
logger.info("Clean Worker: No filtered text to push.")
|
||||
|
||||
# d. Extract structured data using the augmented context
|
||||
except Exception as e:
|
||||
logger.error(f"Clean Worker error: {e}")
|
||||
|
||||
# Small sleep to prevent tight loop
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def llm_worker(self):
|
||||
"""
|
||||
Worker that handles LLM: Filtered Text/UI Input -> Structured Data & UI Updates.
|
||||
"""
|
||||
logger.info("LLM Worker started.")
|
||||
|
||||
# Internal queue to serialize processing from multiple sources
|
||||
internal_queue = asyncio.Queue()
|
||||
|
||||
async def feed_clean():
|
||||
while self.is_running:
|
||||
try:
|
||||
item = await self.clean_to_llm_queue.get()
|
||||
await internal_queue.put(item)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM Feeder (Clean) error: {e}")
|
||||
|
||||
async def feed_ui():
|
||||
while self.is_running:
|
||||
try:
|
||||
text = await self.ui_to_llm_queue.get()
|
||||
await internal_queue.put(("UI", text))
|
||||
except Exception as e:
|
||||
logger.error(f"LLM Feeder (UI) error: {e}")
|
||||
|
||||
# Start feeder tasks
|
||||
feeders = [
|
||||
asyncio.create_task(feed_clean()),
|
||||
asyncio.create_task(feed_ui()),
|
||||
]
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
speaker, text = await internal_queue.get()
|
||||
logger.info(f"LLM Worker: Processing text from {speaker}: {text}")
|
||||
|
||||
# RAG Retrieval for context
|
||||
context = await asyncio.to_thread(self.rag_manager.retrieve, text)
|
||||
|
||||
# Structured extraction using the processor
|
||||
extraction_result = await asyncio.to_thread(
|
||||
self.processor.extract_structured_data,
|
||||
filter_result.filtered_text if filter_result.filtered_text else "",
|
||||
context=augmented_context,
|
||||
text,
|
||||
context=context,
|
||||
)
|
||||
|
||||
if (
|
||||
extraction_result.lore_updates
|
||||
or extraction_result.character_updates
|
||||
or extraction_result.significant_events
|
||||
):
|
||||
# Persistence: Lore Updates
|
||||
for lore_update in extraction_result.lore_updates:
|
||||
await asyncio.to_thread(update_lore, lore_update)
|
||||
logger.info(f"LLM Worker: Lore updated: {lore_update.topic}")
|
||||
|
||||
# Persistence: Character State Updates
|
||||
for char_update in extraction_result.character_updates:
|
||||
await asyncio.to_thread(update_character_state, char_update)
|
||||
logger.info(
|
||||
f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(extraction_result.lore_updates)}, Char: {len(extraction_result.character_updates)})"
|
||||
)
|
||||
await self.proposal_queue.put(extraction_result)
|
||||
|
||||
# Trigger RAG query based on extracted entities (for TUI updates)
|
||||
await self._trigger_rag_queries(extraction_result)
|
||||
else:
|
||||
logger.info("LLM Worker: No relevant game data extracted.")
|
||||
|
||||
# e. If the filter found contextual info, push it to the context queue
|
||||
if filter_result.contextual_info:
|
||||
logger.info(
|
||||
f"LLM Worker: Contextual info found: {filter_result.contextual_info}"
|
||||
)
|
||||
await self.context_queue.put(
|
||||
ContextUpdate(
|
||||
query="Filter",
|
||||
snippet=filter_result.contextual_info,
|
||||
source="Transcript",
|
||||
)
|
||||
f"LLM Worker: Character {char_update.character_name} state updated."
|
||||
)
|
||||
|
||||
# f. Push the distilled RAG snippets from extraction to the context queue
|
||||
for snippet in extraction_result.context_updates:
|
||||
await self.context_queue.put(snippet)
|
||||
# UI Notification: Context Updates
|
||||
for context_update in extraction_result.context_updates:
|
||||
await self.llm_to_ui_queue.put(context_update)
|
||||
logger.info(f"LLM Worker: Pushed context update to UI.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM Worker error: {e}")
|
||||
@@ -209,44 +232,9 @@ class PipelineOrchestrator:
|
||||
# Small sleep
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _trigger_rag_queries(self, result: ExtractionResult):
|
||||
"""
|
||||
Triggers RAG queries based on the extracted entities and results.
|
||||
"""
|
||||
queries = set()
|
||||
|
||||
# Collect entities from lore updates
|
||||
for update in result.lore_updates:
|
||||
if update.entity_name:
|
||||
queries.add(update.entity_name)
|
||||
|
||||
# Collect entities from character updates
|
||||
for update in result.character_updates:
|
||||
if update.character_name:
|
||||
queries.add(update.character_name)
|
||||
|
||||
# Collect events as potential queries
|
||||
for event in result.significant_events:
|
||||
queries.add(event)
|
||||
|
||||
if not queries:
|
||||
logger.info("RAG: No query terms identified from extraction result.")
|
||||
return
|
||||
|
||||
for query in queries:
|
||||
logger.info(f"RAG: Triggering query for: {query}")
|
||||
try:
|
||||
# Run retrieval in a thread to avoid blocking the event loop
|
||||
updates = await asyncio.to_thread(
|
||||
self.rag_manager.retrieve, query, summarize=True
|
||||
)
|
||||
for update in updates:
|
||||
await self.context_queue.put(update)
|
||||
logger.info(
|
||||
f"RAG: Retrieved snippet for {query} from {update.source}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"RAG: Error retrieving context for {query}: {e}")
|
||||
# Clean up feeders
|
||||
for f in feeders:
|
||||
f.cancel()
|
||||
|
||||
def _get_wiki_context(self) -> str:
|
||||
"""
|
||||
@@ -274,15 +262,15 @@ class PipelineOrchestrator:
|
||||
|
||||
async def tui_worker(self):
|
||||
"""
|
||||
Worker that handles TUI: Proposal -> Persistence.
|
||||
Worker that handles TUI: UI interactions.
|
||||
"""
|
||||
logger.info("TUI Worker started.")
|
||||
try:
|
||||
# Launch TUI exactly once.
|
||||
# Pass the proposal queue and context queue to the app.
|
||||
# Launch TUI.
|
||||
# Use the new queues for the TUI.
|
||||
app = ConfirmationApp(
|
||||
proposal_queue=self.proposal_queue,
|
||||
context_queue=self.context_queue,
|
||||
ui_to_llm_queue=self.ui_to_llm_queue,
|
||||
llm_to_ui_queue=self.llm_to_ui_queue,
|
||||
)
|
||||
await app.run_async()
|
||||
self.stop()
|
||||
@@ -308,12 +296,8 @@ class PipelineOrchestrator:
|
||||
# Start workers as background tasks
|
||||
tasks = [
|
||||
asyncio.create_task(self.stt_worker()),
|
||||
asyncio.create_task(self.clean_worker()),
|
||||
asyncio.create_task(self.llm_worker()),
|
||||
asyncio.create_task(
|
||||
self.context_pipeline.run(
|
||||
self.transcript_queue, self.context_queue, stop_event
|
||||
)
|
||||
),
|
||||
asyncio.create_task(self.tui_worker()),
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user