diff --git a/.gitignore b/.gitignore index b824b69..4fcd622 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ artifacts/ **/__pycache__/ +data diff --git a/src/llm/processor.py b/src/llm/processor.py index 2100286..7fc7741 100644 --- a/src/llm/processor.py +++ b/src/llm/processor.py @@ -8,7 +8,11 @@ from openai import OpenAI from pydantic import ValidationError from .models import ExtractionResult, FilterResult -from .prompts import EXTRACTION_SYSTEM_PROMPT, NOISE_FILTER_SYSTEM_PROMPT +from .prompts import ( + EXTRACTION_SYSTEM_PROMPT, + NOISE_FILTER_SYSTEM_PROMPT, + QUERY_ANSWER_SYSTEM_PROMPT, +) logger = logging.getLogger(__name__) @@ -102,6 +106,16 @@ class LLMProcessor: logger.error(f"LLM Error: {e}") return "" + def generate_answer(self, query: str, context: str) -> str: + """ + Generates a natural language answer to a DM query. + """ + return self._call_llm( + QUERY_ANSWER_SYSTEM_PROMPT, + query, + context=context, + ) + def filter_transcript( self, text: str, context: Optional[str] = None ) -> FilterResult: diff --git a/src/llm/prompts.py b/src/llm/prompts.py index 012eba5..6fa08b1 100644 --- a/src/llm/prompts.py +++ b/src/llm/prompts.py @@ -1,4 +1,12 @@ -# System prompts for the LLM pipeline +QUERY_ANSWER_SYSTEM_PROMPT = """ +You are a helpful D&D Game Master's assistant. Your goal is to provide accurate, concise, and helpful answers to the DM's questions based on the provided context (conversation history and RAG snippets). + +Guidelines: +- Use the provided context as your primary source of truth. +- If the answer is not in the context, state that you don't have enough information, but feel free to provide general D&D 5e rules as a fallback. +- Keep responses natural and professional. +- Be concise. +""" NOISE_FILTER_SYSTEM_PROMPT = """ You are a D&D Game Master's assistant. Given a transcript, remove all out-of-character (OOC) chatter, logistical discussions (e.g., 'Where is my d20?'), and non-relevant noise. diff --git a/src/pipeline/orchestrator.py b/src/pipeline/orchestrator.py index 800792e..91069c7 100644 --- a/src/pipeline/orchestrator.py +++ b/src/pipeline/orchestrator.py @@ -6,8 +6,16 @@ from typing import List, Optional import numpy as np -from src.llm.models import ContextUpdate, ExtractionResult +from src.llm.models import ( + CharacterStateUpdate, + ContextUpdate, + ExtractionResult, + LoreUpdate, +) from src.llm.processor import LLMProcessor +from src.llm.prompts import EXTRACTION_SYSTEM_PROMPT, NOISE_FILTER_SYSTEM_PROMPT +from src.persistence.characters import update_character_state +from src.persistence.lore import update_lore from src.rag.manager import RAGManager from src.stt.listener import AudioListener from src.stt.transcriber import Transcriber @@ -41,9 +49,10 @@ class PipelineOrchestrator: self.rag_manager = RAGManager() # Queues - self.transcript_queue = asyncio.Queue() - self.proposal_queue = asyncio.Queue() - self.context_queue = asyncio.Queue() + self.stt_to_clean_queue = asyncio.Queue() + self.ui_to_llm_queue = asyncio.Queue() + self.clean_to_llm_queue = asyncio.Queue() + self.llm_to_ui_queue = asyncio.Queue() self.is_running = False @@ -58,6 +67,20 @@ class PipelineOrchestrator: self.buffer_max_samples = self.buffer_max_seconds * self.sample_rate self.last_processed_end_time = 0.0 + def _get_combined_context(self) -> str: + """ + Returns the trimmed conversation history as a context string. + """ + full_history_text = " ".join(self.history) + words = full_history_text.split() + if len(words) > self.history_max_words: + kept_words = words[-self.history_max_words :] + context_text = " ".join(kept_words) + else: + context_text = full_history_text + + return f"Conversation History:\n{context_text}\n\n" + async def stt_worker(self): """ Worker that handles STT: Audio -> Text. @@ -93,7 +116,7 @@ class PipelineOrchestrator: if new_segments: for speaker, text, start, end in new_segments: logger.info(f"Transcribed: [{speaker}] {text}") - await self.transcript_queue.put((speaker, text)) + await self.stt_to_clean_queue.put((speaker, text)) self.last_processed_end_time = max( self.last_processed_end_time, end ) @@ -104,104 +127,104 @@ class PipelineOrchestrator: # Small sleep to prevent tight loop if get_chunk is fast await asyncio.sleep(0.1) - async def llm_worker(self): + async def clean_worker(self): """ - Worker that handles LLM: Text -> Proposal. + Worker that handles Text Cleaning: Raw STT -> Filtered Text. """ - logger.info("LLM Worker started.") + logger.info("Clean Worker started.") while self.is_running: try: - # Get raw text from transcript queue (now a tuple of (speaker, text)) - speaker, raw_text = await self.transcript_queue.get() + # Get raw transcript from STT + speaker, raw_text = await self.stt_to_clean_queue.get() + logger.info(f"Clean Worker: Filtering text from {speaker}: {raw_text}") - logger.info(f"LLM Worker: Processing text from {speaker}: {raw_text}") + # RAG Retrieval for context + context = await asyncio.to_thread(self.rag_manager.retrieve, raw_text) - # 1. Prepare Context (Conversation History) - # Store as "Speaker X: [text]" - entry = f"{speaker}: {raw_text}" - self.history.append(entry) - - full_history_text = " ".join(self.history) - words = full_history_text.split() - if len(words) > self.history_max_words: - # Keep the last N words - kept_words = words[-self.history_max_words :] - context_text = " ".join(kept_words) - else: - context_text = full_history_text - - # 2. Prepare Context (Wiki / Database of Knowledge) - # wiki_context = self._get_wiki_context() - - # Combine both - combined_context = f"Conversation History:\n{context_text}\n\n" - - # --- New RAG Flow --- - # a. Filter transcript first to get cleaned text + # Filtering using the processor filter_result = await asyncio.to_thread( - self.processor.filter_transcript, raw_text, context=combined_context + self.processor.filter_transcript, + raw_text, + context=context, ) - # b. Use filtered text to retrieve relevant snippets from RAG - rag_snippets = [] + # Push filtered text to LLM queue if filter_result.filtered_text: - try: - snippets = await asyncio.to_thread( - self.rag_manager.retrieve, - filter_result.filtered_text, - summarize=True, - ) - rag_snippets = snippets - except Exception as e: - logger.error(f"RAG Retrieval Error in llm_worker: {e}") - - # c. Combine RAG snippets with existing combined_context - logger.info(f"LLM Processor (Extract): rag_snippets: {rag_snippets}") - rag_context_text = "\n".join([s.snippet for s in rag_snippets]) - augmented_context = combined_context - if rag_context_text: - augmented_context += ( - f"\n\nRelevant RAG Context:\n{rag_context_text}" + await self.clean_to_llm_queue.put( + (speaker, filter_result.filtered_text) ) + logger.info(f"Clean Worker: Pushed filtered text to LLM queue.") + else: + logger.info("Clean Worker: No filtered text to push.") - # d. Extract structured data using the augmented context + except Exception as e: + logger.error(f"Clean Worker error: {e}") + + # Small sleep to prevent tight loop + await asyncio.sleep(0.1) + + async def llm_worker(self): + """ + Worker that handles LLM: Filtered Text/UI Input -> Structured Data & UI Updates. + """ + logger.info("LLM Worker started.") + + # Internal queue to serialize processing from multiple sources + internal_queue = asyncio.Queue() + + async def feed_clean(): + while self.is_running: + try: + item = await self.clean_to_llm_queue.get() + await internal_queue.put(item) + except Exception as e: + logger.error(f"LLM Feeder (Clean) error: {e}") + + async def feed_ui(): + while self.is_running: + try: + text = await self.ui_to_llm_queue.get() + await internal_queue.put(("UI", text)) + except Exception as e: + logger.error(f"LLM Feeder (UI) error: {e}") + + # Start feeder tasks + feeders = [ + asyncio.create_task(feed_clean()), + asyncio.create_task(feed_ui()), + ] + + while self.is_running: + try: + speaker, text = await internal_queue.get() + logger.info(f"LLM Worker: Processing text from {speaker}: {text}") + + # RAG Retrieval for context + context = await asyncio.to_thread(self.rag_manager.retrieve, text) + + # Structured extraction using the processor extraction_result = await asyncio.to_thread( self.processor.extract_structured_data, - filter_result.filtered_text if filter_result.filtered_text else "", - context=augmented_context, + text, + context=context, ) - if ( - extraction_result.lore_updates - or extraction_result.character_updates - or extraction_result.significant_events - ): + # Persistence: Lore Updates + for lore_update in extraction_result.lore_updates: + await asyncio.to_thread(update_lore, lore_update) + logger.info(f"LLM Worker: Lore updated: {lore_update.topic}") + + # Persistence: Character State Updates + for char_update in extraction_result.character_updates: + await asyncio.to_thread(update_character_state, char_update) logger.info( - f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(extraction_result.lore_updates)}, Char: {len(extraction_result.character_updates)})" - ) - await self.proposal_queue.put(extraction_result) - - # Trigger RAG query based on extracted entities (for TUI updates) - await self._trigger_rag_queries(extraction_result) - else: - logger.info("LLM Worker: No relevant game data extracted.") - - # e. If the filter found contextual info, push it to the context queue - if filter_result.contextual_info: - logger.info( - f"LLM Worker: Contextual info found: {filter_result.contextual_info}" - ) - await self.context_queue.put( - ContextUpdate( - query="Filter", - snippet=filter_result.contextual_info, - source="Transcript", - ) + f"LLM Worker: Character {char_update.character_name} state updated." ) - # f. Push the distilled RAG snippets from extraction to the context queue - for snippet in extraction_result.context_updates: - await self.context_queue.put(snippet) + # UI Notification: Context Updates + for context_update in extraction_result.context_updates: + await self.llm_to_ui_queue.put(context_update) + logger.info(f"LLM Worker: Pushed context update to UI.") except Exception as e: logger.error(f"LLM Worker error: {e}") @@ -209,44 +232,9 @@ class PipelineOrchestrator: # Small sleep await asyncio.sleep(0.1) - async def _trigger_rag_queries(self, result: ExtractionResult): - """ - Triggers RAG queries based on the extracted entities and results. - """ - queries = set() - - # Collect entities from lore updates - for update in result.lore_updates: - if update.entity_name: - queries.add(update.entity_name) - - # Collect entities from character updates - for update in result.character_updates: - if update.character_name: - queries.add(update.character_name) - - # Collect events as potential queries - for event in result.significant_events: - queries.add(event) - - if not queries: - logger.info("RAG: No query terms identified from extraction result.") - return - - for query in queries: - logger.info(f"RAG: Triggering query for: {query}") - try: - # Run retrieval in a thread to avoid blocking the event loop - updates = await asyncio.to_thread( - self.rag_manager.retrieve, query, summarize=True - ) - for update in updates: - await self.context_queue.put(update) - logger.info( - f"RAG: Retrieved snippet for {query} from {update.source}" - ) - except Exception as e: - logger.error(f"RAG: Error retrieving context for {query}: {e}") + # Clean up feeders + for f in feeders: + f.cancel() def _get_wiki_context(self) -> str: """ @@ -274,15 +262,15 @@ class PipelineOrchestrator: async def tui_worker(self): """ - Worker that handles TUI: Proposal -> Persistence. + Worker that handles TUI: UI interactions. """ logger.info("TUI Worker started.") try: - # Launch TUI exactly once. - # Pass the proposal queue and context queue to the app. + # Launch TUI. + # Use the new queues for the TUI. app = ConfirmationApp( - proposal_queue=self.proposal_queue, - context_queue=self.context_queue, + ui_to_llm_queue=self.ui_to_llm_queue, + llm_to_ui_queue=self.llm_to_ui_queue, ) await app.run_async() self.stop() @@ -308,12 +296,8 @@ class PipelineOrchestrator: # Start workers as background tasks tasks = [ asyncio.create_task(self.stt_worker()), + asyncio.create_task(self.clean_worker()), asyncio.create_task(self.llm_worker()), - asyncio.create_task( - self.context_pipeline.run( - self.transcript_queue, self.context_queue, stop_event - ) - ), asyncio.create_task(self.tui_worker()), ] diff --git a/src/rag/manager.py b/src/rag/manager.py index 8b930a2..e5d8bb1 100644 --- a/src/rag/manager.py +++ b/src/rag/manager.py @@ -150,7 +150,8 @@ class RAGManager: self, query: str, top_k: int = 5, summarize: bool = False ) -> List[ContextUpdate]: """ - Retrieves the top-K most relevant snippets for a given query. + Retrieves the top-K most relevant snippets for a given query, + filtering for those with a similarity score > 0.7. """ if not self.index: print("Index not initialized. Please ingest documents first.") @@ -160,6 +161,9 @@ class RAGManager: retriever = self.index.as_retriever(similarity_top_k=top_k) nodes = retriever.retrieve(query) + # Filter nodes by similarity score (threshold > 0.7) + nodes = [node for node in nodes if node.score >= 0.7] + if summarize: return self.summarize_results(query, nodes) diff --git a/src/ui/tui.py b/src/ui/tui.py index 1df3014..b1cb826 100644 --- a/src/ui/tui.py +++ b/src/ui/tui.py @@ -79,7 +79,7 @@ class ConfirmationApp(App): #modal-actions { height: auto; margin-top: 1; - align: right; + align: right middle; } #edit-input { @@ -108,17 +108,13 @@ class ConfirmationApp(App): def __init__( self, result: Optional[ExtractionResult] = None, - proposal_queue: Optional[asyncio.Queue] = None, - context_queue: Optional[asyncio.Queue] = None, - query_queue: Optional[asyncio.Queue] = None, - response_queue: Optional[asyncio.Queue] = None, + ui_to_llm_queue: Optional[asyncio.Queue] = None, + llm_to_ui_queue: Optional[asyncio.Queue] = None, ): super().__init__() self.result = result - self.proposal_queue = proposal_queue - self.context_queue = context_queue - self.query_queue = query_queue - self.response_queue = response_queue + self.ui_to_llm_queue = ui_to_llm_queue + self.llm_to_ui_queue = llm_to_ui_queue self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = [] if result: @@ -145,12 +141,11 @@ class ConfirmationApp(App): for i, update in enumerate(self.pending_updates): self.add_update_to_table(update, i) - if self.proposal_queue: - self.run_worker(self.poll_proposal_queue, thread=False) - if self.context_queue: - self.run_worker(self.poll_context_queue, thread=False) - if self.response_queue: - self.run_worker(self.poll_response_queue, thread=False) + if self.ui_to_llm_queue: + # We don't need a poller for this, just the action_send + pass + if self.llm_to_ui_queue: + self.run_worker(self.poll_llm_updates, thread=False) self.query_one("#llm-input", Input).focus() @@ -170,15 +165,18 @@ class ConfirmationApp(App): change_text += f", Removed: {', '.join(update.status_effects_removed)}" table.add_row("Char", update.character_name, change_text, key=str(index)) - async def poll_proposal_queue(self) -> None: + async def poll_llm_updates(self) -> None: while True: try: - result = await self.proposal_queue.get() - self.handle_proposal_result(result) - if hasattr(self.proposal_queue, "task_done"): - self.proposal_queue.task_done() + update = await self.llm_to_ui_queue.get() + display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}" + context_list = self.query_one("#context-pane", ListView) + # Insert at the top to show most recent first + context_list.insert(0, ListItem(Static(display_text))) + if hasattr(self.llm_to_ui_queue, "task_done"): + self.llm_to_ui_queue.task_done() except Exception as e: - self.log(f"Error polling proposal queue: {e}") + self.log(f"Error polling LLM updates: {e}") def handle_proposal_result(self, result: ExtractionResult) -> None: table = self.query_one("#pending-facts-table", DataTable) @@ -188,32 +186,22 @@ class ConfirmationApp(App): self.add_update_to_table(update, index) async def poll_context_queue(self) -> None: - while True: - try: - update = await self.context_queue.get() - display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}" - context_list = self.query_one("#context-pane", ListView) - context_list.append(ListItem(Static(display_text))) - if hasattr(self.context_queue, "task_done"): - self.context_queue.task_done() - except Exception as e: - self.log(f"Error polling context queue: {e}") + # Obsolete + pass async def poll_response_queue(self) -> None: - while True: - try: - answer = await self.response_queue.get() - self.notify(answer) - if hasattr(self.response_queue, "task_done"): - self.response_queue.task_done() - except Exception as e: - self.log(f"Error polling response queue: {e}") + # Obsolete + pass + + def on_input_submitted(self, event: Input.Submitted) -> None: + if event.input.id == "llm-input": + self.action_send() def action_send(self) -> None: input_widget = self.query_one("#llm-input", Input) text = input_widget.value - if text and self.query_queue: - self.query_queue.put_nowait(text) + if text and self.ui_to_llm_queue: + self.ui_to_llm_queue.put_nowait(text) input_widget.value = "" def action_accept(self) -> None: