diff --git a/requirements.txt b/requirements.txt index e1db5a1..3554082 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,6 @@ textual typer openai python-dotenv +llama-index +chromadb +pdfplumber diff --git a/src/llm/models.py b/src/llm/models.py index 72256f8..caae1d5 100644 --- a/src/llm/models.py +++ b/src/llm/models.py @@ -44,6 +44,16 @@ class CharacterStateUpdate(BaseModel): ) +class ContextUpdate(BaseModel): + query: str = Field(..., description="The search query used to retrieve the context") + snippet: str = Field( + ..., description="The relevant text snippet retrieved from the source" + ) + source: str = Field( + ..., description="The source of the snippet (e.g., 'PHB p. 12')" + ) + + class ExtractionResult(BaseModel): lore_updates: List[LoreUpdate] = Field( default_factory=list, description="List of discovered lore facts", alias="lore" diff --git a/src/pipeline/orchestrator.py b/src/pipeline/orchestrator.py index 806b3e7..7deefa4 100644 --- a/src/pipeline/orchestrator.py +++ b/src/pipeline/orchestrator.py @@ -6,8 +6,9 @@ from typing import List, Optional import numpy as np -from src.llm.models import ExtractionResult +from src.llm.models import ContextUpdate, ExtractionResult from src.llm.processor import LLMProcessor +from src.rag.manager import RAGManager from src.stt.listener import AudioListener from src.stt.transcriber import Transcriber from src.ui.tui import ConfirmationApp @@ -37,10 +38,12 @@ class PipelineOrchestrator: self.listener = AudioListener(loop=self.loop) self.transcriber = Transcriber(model_size="small") self.processor = LLMProcessor() + self.rag_manager = RAGManager() # Queues self.transcript_queue = asyncio.Queue() self.proposal_queue = asyncio.Queue() + self.context_queue = asyncio.Queue() self.is_running = False @@ -148,6 +151,9 @@ class PipelineOrchestrator: f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(result.lore_updates)}, Char: {len(result.character_updates)})" ) await self.proposal_queue.put(result) + + # Trigger RAG query based on extracted entities + await self._trigger_rag_queries(result) else: logger.info("LLM Worker: No relevant game data extracted.") @@ -157,6 +163,43 @@ class PipelineOrchestrator: # Small sleep await asyncio.sleep(0.1) + async def _trigger_rag_queries(self, result: ExtractionResult): + """ + Triggers RAG queries based on the extracted entities and results. + """ + queries = set() + + # Collect entities from lore updates + for update in result.lore_updates: + if update.entity_name: + queries.add(update.entity_name) + + # Collect entities from character updates + for update in result.character_updates: + if update.character_name: + queries.add(update.character_name) + + # Collect events as potential queries + for event in result.significant_events: + queries.add(event) + + if not queries: + logger.info("RAG: No query terms identified from extraction result.") + return + + for query in queries: + logger.info(f"RAG: Triggering query for: {query}") + try: + # Run retrieval in a thread to avoid blocking the event loop + updates = await asyncio.to_thread(self.rag_manager.retrieve, query) + for update in updates: + await self.context_queue.put(update) + logger.info( + f"RAG: Retrieved snippet for {query} from {update.source}" + ) + except Exception as e: + logger.error(f"RAG: Error retrieving context for {query}: {e}") + def _get_wiki_context(self) -> str: """ Reads all files in the lore directory and returns them as a 저희 context string. @@ -188,8 +231,11 @@ class PipelineOrchestrator: logger.info("TUI Worker started.") try: # Launch TUI exactly once. - # Pass the proposal queue to the app. - app = ConfirmationApp(proposal_queue=self.proposal_queue) + # Pass the proposal queue and context queue to the app. + app = ConfirmationApp( + proposal_queue=self.proposal_queue, + context_queue=self.context_queue, + ) await app.run_async() self.stop() except Exception as e: diff --git a/src/rag/manager.py b/src/rag/manager.py new file mode 100644 index 0000000..17bd295 --- /dev/null +++ b/src/rag/manager.py @@ -0,0 +1,86 @@ +import os +from typing import List, Optional + +import chromadb +import pdfplumber +from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex +from llama_index.embeddings.huggingface import HuggingFaceEmbedding +from llama_index.vector_stores.chroma import ChromaVectorStore + +from src.llm.models import ContextUpdate + + +class RAGManager: + def __init__(self, persist_dir: str = "data/rag_index"): + self.persist_dir = persist_dir + self.db = chromadb.PersistentClient(path=self.persist_dir) + self.collection_name = "phb_collection" + + # Initialize Chroma Vector Store + self.vector_store = ChromaVectorStore( + chroma_collection=self.db.get_or_create_collection(self.collection_name) + ) + + # Initialize Storage Context + self.storage_context = StorageContext.from_defaults( + vector_store=self.vector_store + ) + + # Use a local HuggingFace embedding model to avoid API key issues during verification + Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5") + + # Load index if it exists, otherwise initialize + try: + self.index = VectorStoreIndex.from_vector_store( + self.vector_store, storage_context=self.storage_context + ) + except Exception: + self.index = None + + def ingest_pdf(self, pdf_path: str): + """ + Parses a PDF, chunks it, and stores embeddings in ChromaDB. + """ + documents = [] + with pdfplumber.open(pdf_path) as pdf: + for i, page in enumerate(pdf.pages): + text = page.extract_text() + if text: + # Create a document for each page + # In a real scenario, we might use a recursive character splitter + # but for PHB, page-level chunking is a good start. + doc = Document( + text=text, metadata={"source": f"PHB p. {i + 1}", "page": i + 1} + ) + documents.append(doc) + + if not documents: + print(f"No text extracted from {pdf_path}") + return + + # Create index from documents + self.index = VectorStoreIndex.from_documents( + documents, storage_context=self.storage_context + ) + print(f"Successfully ingested {pdf_path} into the vector store.") + + def retrieve(self, query: str, top_k: int = 3) -> List[ContextUpdate]: + """ + Retrieves the top-K most relevant snippets for a given query. + """ + if not self.index: + print("Index not initialized. Please ingest documents first.") + return [] + + # Create a retriever + retriever = self.index.as_retriever(similarity_top_k=top_k) + nodes = retriever.retrieve(query) + + results = [] + for node in nodes: + # Extract metadata + source = node.metadata.get("source", "Unknown Source") + + results.append(ContextUpdate(query=query, snippet=node.text, source=source)) + + return results diff --git a/src/ui/tui.py b/src/ui/tui.py index 9cb5454..dcee0fb 100644 --- a/src/ui/tui.py +++ b/src/ui/tui.py @@ -17,13 +17,19 @@ class ConfirmationApp(App): } #left-pane { - width: 40%; + width: 30%; + border: solid; + padding: 1; + } + + #middle-pane { + width: 30%; border: solid; padding: 1; } #right-pane { - width: 60%; + width: 40%; border: solid; padding: 1; layout: vertical; @@ -61,10 +67,12 @@ class ConfirmationApp(App): self, result: Optional[ExtractionResult] = None, proposal_queue: Optional[asyncio.Queue] = None, + context_queue: Optional[asyncio.Queue] = None, ): super().__init__() self.result = result self.proposal_queue = proposal_queue + self.context_queue = context_queue self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = [] if result: @@ -81,6 +89,10 @@ class ConfirmationApp(App): DataTable(id="update-table"), id="left-pane", ), + Vertical( + Static("No context available", id="context-pane"), + id="middle-pane", + ), Vertical( Vertical( Label("Details:", id="details-label"), @@ -132,6 +144,9 @@ class ConfirmationApp(App): if self.proposal_queue: self.run_worker(self.poll_proposal_queue, thread=False) + if self.context_queue: + self.run_worker(self.poll_context_queue, thread=False) + async def poll_proposal_queue(self) -> None: """ Background worker that polls the proposal queue for new extraction results. @@ -147,6 +162,24 @@ class ConfirmationApp(App): # Log error but keep the worker running self.log(f"Error polling proposal queue: {e}") + async def poll_context_queue(self) -> None: + """ + Background worker that polls the context queue for new RAG updates. + """ + while True: + try: + update = await self.context_queue.get() + context_pane = self.query_one("#context-pane", Static) + + # Format the update for display + display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}" + context_pane.update(display_text) + + if hasattr(self.context_queue, "task_done"): + self.context_queue.task_done() + except Exception as e: + self.log(f"Error polling context queue: {e}") + def add_result(self, result: ExtractionResult) -> None: """ Adds results from the LLM processor to the TUI table. diff --git a/tests/test_rag.py b/tests/test_rag.py new file mode 100644 index 0000000..5fbbf0e --- /dev/null +++ b/tests/test_rag.py @@ -0,0 +1,98 @@ +import os + +from reportlab.pdfgen import canvas + +from src.rag.manager import RAGManager + + +def create_dummy_phb(pdf_path: str): + """ + Creates a dummy PDF file to simulate the Player's Handbook for verification. + """ + print(f"Creating dummy PHB at {pdf_path}...") + c = canvas.Canvas(pdf_path) + + # Page 1: Fireball + c.drawString(100, 750, "Fireball") + c.drawString( + 100, + 730, + "A bright streak flashes from your pointing finger to a point you choose within range.", + ) + c.drawString( + 100, + 710, + "Each creature in a 20-foot-radius sphere centered on that point must make a Dexterity saving throw.", + ) + c.showPage() + + # Page 2: Grappling + c.drawString(100, 750, "Grappling") + c.drawString( + 100, + 730, + "Grappling is a special option available to any attack that hits with a melee attack.", + ) + c.drawString( + 100, + 710, + "A creature is grappled if it is the target of the grapple attack and the attack hits.", + ) + c.showPage() + + # Page 3: General Rules + c.drawString(100, 750, "General Rules") + c.drawString( + 100, + 730, + "Combat is resolved in rounds, and each round represents 6 seconds of in-game time.", + ) + c.showPage() + + c.save() + print(f"Dummy PHB created successfully at {pdf_path}") + + +def test_rag(): + pdf_path = "data/phb_dummy.pdf" + create_dummy_phb(pdf_path) + + # Initialize RAG Manager + rag = RAGManager(persist_dir="data/rag_index_test") + + # Task 2.2: Ingest PDF + print("\nTesting Ingestion...") + rag.ingest_pdf(pdf_path) + + # Task 2.3: Retrieve Logic + print("\nTesting Retrieval...") + + # Test 1: Fireball + query1 = "What is Fireball?" + results1 = rag.retrieve(query1) + print(f"Query: {query1}") + for res in results1: + print(f"Source: {res.source} | Snippet: {res.snippet[:100]}...") + + assert len(results1) > 0, "Should have retrieved at least one result for 'Fireball'" + assert "Fireball" in results1[0].snippet, "The top result should contain 'Fireball'" + + # Test 2: Grappling + query2 = "How does grappling work?" + results2 = rag.retrieve(query2) + print(f"Query: {query2}") + for res in results2: + print(f"Source: {res.source} | Snippet: {res.snippet[:100]}...") + + assert len(results2) > 0, ( + "Should have retrieved at least one result for 'Grappling'" + ) + assert "Grappling" in results2[0].snippet, ( + "The top result should contain 'Grappling'" + ) + + print("\n✅ RAG Verification Successful!") + + +if __name__ == "__main__": + test_rag()