feat: implement RAG capabilities and Context Pane integration

- Add RAG capabilities using LlamaIndex and ChromaDB
- Implement RAGManager for PHB indexing and retrieval
- Integrate RAG pipeline into orchestrator to trigger queries based on extracted entities
- Update TUI to include a 3-column layout with a real-time Context Pane
- Define ContextUpdate data models in src/llm/models.py
- Update requirements.txt with new dependencies
This commit is contained in:
2026-05-26 22:07:12 -07:00
parent f4c98fb2b9
commit 954f2f50d8
6 changed files with 281 additions and 5 deletions
+3
View File
@@ -6,3 +6,6 @@ textual
typer typer
openai openai
python-dotenv python-dotenv
llama-index
chromadb
pdfplumber
+10
View File
@@ -44,6 +44,16 @@ class CharacterStateUpdate(BaseModel):
) )
# One retrieved reference snippet, tagged with the query that produced it and
# the place it came from. (Kept as comments rather than a class docstring so
# the generated model schema is unchanged.)
class ContextUpdate(BaseModel):
    # Search text that was sent to the retriever.
    query: str = Field(
        ...,
        description="The search query used to retrieve the context",
    )
    # Text of the retrieved passage.
    snippet: str = Field(
        ...,
        description="The relevant text snippet retrieved from the source",
    )
    # Human-readable origin of the passage.
    source: str = Field(
        ...,
        description="The source of the snippet (e.g., 'PHB p. 12')",
    )
class ExtractionResult(BaseModel): class ExtractionResult(BaseModel):
lore_updates: List[LoreUpdate] = Field( lore_updates: List[LoreUpdate] = Field(
default_factory=list, description="List of discovered lore facts", alias="lore" default_factory=list, description="List of discovered lore facts", alias="lore"
+49 -3
View File
@@ -6,8 +6,9 @@ from typing import List, Optional
import numpy as np import numpy as np
from src.llm.models import ExtractionResult from src.llm.models import ContextUpdate, ExtractionResult
from src.llm.processor import LLMProcessor from src.llm.processor import LLMProcessor
from src.rag.manager import RAGManager
from src.stt.listener import AudioListener from src.stt.listener import AudioListener
from src.stt.transcriber import Transcriber from src.stt.transcriber import Transcriber
from src.ui.tui import ConfirmationApp from src.ui.tui import ConfirmationApp
@@ -37,10 +38,12 @@ class PipelineOrchestrator:
self.listener = AudioListener(loop=self.loop) self.listener = AudioListener(loop=self.loop)
self.transcriber = Transcriber(model_size="small") self.transcriber = Transcriber(model_size="small")
self.processor = LLMProcessor() self.processor = LLMProcessor()
self.rag_manager = RAGManager()
# Queues # Queues
self.transcript_queue = asyncio.Queue() self.transcript_queue = asyncio.Queue()
self.proposal_queue = asyncio.Queue() self.proposal_queue = asyncio.Queue()
self.context_queue = asyncio.Queue()
self.is_running = False self.is_running = False
@@ -148,6 +151,9 @@ class PipelineOrchestrator:
f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(result.lore_updates)}, Char: {len(result.character_updates)})" f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(result.lore_updates)}, Char: {len(result.character_updates)})"
) )
await self.proposal_queue.put(result) await self.proposal_queue.put(result)
# Trigger RAG query based on extracted entities
await self._trigger_rag_queries(result)
else: else:
logger.info("LLM Worker: No relevant game data extracted.") logger.info("LLM Worker: No relevant game data extracted.")
@@ -157,6 +163,43 @@ class PipelineOrchestrator:
# Small sleep # Small sleep
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
async def _trigger_rag_queries(self, result: ExtractionResult):
    """
    Triggers RAG queries based on the extracted entities and results.

    Query terms are gathered from the lore entity names, the character
    names and the significant events of ``result``. Each distinct,
    non-blank term is looked up through ``self.rag_manager.retrieve`` and
    every returned update is pushed onto ``self.context_queue``.

    A failure while handling one term is logged and does not stop the
    remaining terms from being processed.

    Args:
        result: Extraction result whose entities and events seed the queries.
    """
    # NOTE(review): terms are assumed to be plain strings (they are passed
    # straight to ``retrieve(query: str)``) -- confirm against the models.
    candidates = [
        *(update.entity_name for update in result.lore_updates),
        *(update.character_name for update in result.character_updates),
        *result.significant_events,
    ]

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # queries (and the snippets reaching the UI) are issued deterministically.
    # Blank / whitespace-only terms are dropped: they carry nothing to search.
    queries = list(
        dict.fromkeys(term.strip() for term in candidates if term and term.strip())
    )

    if not queries:
        logger.info("RAG: No query terms identified from extraction result.")
        return

    for query in queries:
        logger.info("RAG: Triggering query for: %s", query)
        try:
            # Run retrieval in a thread to avoid blocking the event loop
            updates = await asyncio.to_thread(self.rag_manager.retrieve, query)
            for update in updates:
                await self.context_queue.put(update)
                logger.info(
                    "RAG: Retrieved snippet for %s from %s", query, update.source
                )
        except Exception as e:
            # Best effort: one failing lookup must not abort the other terms.
            logger.error("RAG: Error retrieving context for %s: %s", query, e)
def _get_wiki_context(self) -> str: def _get_wiki_context(self) -> str:
""" """
Reads all files in the lore directory and returns them as a single context string. Reads all files in the lore directory and returns them as a single context string.
@@ -188,8 +231,11 @@ class PipelineOrchestrator:
logger.info("TUI Worker started.") logger.info("TUI Worker started.")
try: try:
# Launch TUI exactly once. # Launch TUI exactly once.
# Pass the proposal queue to the app. # Pass the proposal queue and context queue to the app.
app = ConfirmationApp(proposal_queue=self.proposal_queue) app = ConfirmationApp(
proposal_queue=self.proposal_queue,
context_queue=self.context_queue,
)
await app.run_async() await app.run_async()
self.stop() self.stop()
except Exception as e: except Exception as e:
+86
View File
@@ -0,0 +1,86 @@
import logging
import os
from typing import List, Optional

import chromadb
import pdfplumber
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

from src.llm.models import ContextUpdate
logger = logging.getLogger(__name__)


class RAGManager:
    """
    Indexes PDF pages in a persistent ChromaDB collection and retrieves the
    most relevant pages for a text query.

    Diagnostics go through ``logging`` rather than ``print`` so that calls
    made while a terminal UI is running do not write into its screen.

    NOTE(review): ``llama_index.embeddings.huggingface`` and
    ``llama_index.vector_stores.chroma`` are imported by this module; confirm
    the packages providing them are listed in requirements.txt.
    """

    def __init__(self, persist_dir: str = "data/rag_index"):
        """
        Opens (or creates) the Chroma collection and attaches an index to it.

        Args:
            persist_dir: Directory used by the persistent Chroma client.
        """
        self.persist_dir = persist_dir
        self.db = chromadb.PersistentClient(path=self.persist_dir)
        self.collection_name = "phb_collection"

        # Initialize Chroma Vector Store
        self.vector_store = ChromaVectorStore(
            chroma_collection=self.db.get_or_create_collection(self.collection_name)
        )

        # Initialize Storage Context
        self.storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        # Use a local HuggingFace embedding model to avoid API key issues.
        # This assigns to the shared llama_index ``Settings`` object, so it is
        # not limited to this instance.
        Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

        # Attach to whatever is already stored; fall back to "no index" on
        # failure, but record why instead of hiding the error.
        try:
            self.index = VectorStoreIndex.from_vector_store(
                self.vector_store, storage_context=self.storage_context
            )
        except Exception:
            logger.exception(
                "Could not load index from vector store at %s", self.persist_dir
            )
            self.index = None

    def ingest_pdf(self, pdf_path: str, source_label: str = "PHB"):
        """
        Parses a PDF, chunks it page by page, and stores embeddings in ChromaDB.

        Args:
            pdf_path: Path of the PDF to ingest.
            source_label: Prefix of the ``source`` metadata attached to every
                page (``"<label> p. <n>"``). Defaults to ``"PHB"``.
        """
        documents = []
        with pdfplumber.open(pdf_path) as pdf:
            for page_number, page in enumerate(pdf.pages, start=1):
                text = page.extract_text()
                # Skip pages with no text or only whitespace: nothing to embed.
                if not text or not text.strip():
                    continue
                # One document per page; page-level chunking keeps the
                # page number available as the citation.
                documents.append(
                    Document(
                        text=text,
                        metadata={
                            "source": f"{source_label} p. {page_number}",
                            "page": page_number,
                        },
                    )
                )

        if not documents:
            logger.warning("No text extracted from %s", pdf_path)
            return

        # Create index from documents
        self.index = VectorStoreIndex.from_documents(
            documents, storage_context=self.storage_context
        )
        logger.info("Successfully ingested %s into the vector store.", pdf_path)

    def retrieve(self, query: str, top_k: int = 3) -> List[ContextUpdate]:
        """
        Retrieves the top-K most relevant snippets for a given query.

        Args:
            query: Text to search for.
            top_k: Maximum number of snippets to return.

        Returns:
            One ``ContextUpdate`` per retrieved node; an empty list when no
            index is available, the query is blank, or ``top_k`` is below 1.
        """
        if self.index is None:
            logger.warning("Index not initialized. Please ingest documents first.")
            return []

        # Nothing meaningful can be retrieved for these inputs.
        if not query or not query.strip() or top_k < 1:
            return []

        retriever = self.index.as_retriever(similarity_top_k=top_k)
        return [
            ContextUpdate(
                query=query,
                snippet=node.text,
                source=node.metadata.get("source", "Unknown Source"),
            )
            for node in retriever.retrieve(query)
        ]
+35 -2
View File
@@ -17,13 +17,19 @@ class ConfirmationApp(App):
} }
#left-pane { #left-pane {
width: 40%; width: 30%;
border: solid;
padding: 1;
}
#middle-pane {
width: 30%;
border: solid; border: solid;
padding: 1; padding: 1;
} }
#right-pane { #right-pane {
width: 60%; width: 40%;
border: solid; border: solid;
padding: 1; padding: 1;
layout: vertical; layout: vertical;
@@ -61,10 +67,12 @@ class ConfirmationApp(App):
self, self,
result: Optional[ExtractionResult] = None, result: Optional[ExtractionResult] = None,
proposal_queue: Optional[asyncio.Queue] = None, proposal_queue: Optional[asyncio.Queue] = None,
context_queue: Optional[asyncio.Queue] = None,
): ):
super().__init__() super().__init__()
self.result = result self.result = result
self.proposal_queue = proposal_queue self.proposal_queue = proposal_queue
self.context_queue = context_queue
self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = [] self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = []
if result: if result:
@@ -81,6 +89,10 @@ class ConfirmationApp(App):
DataTable(id="update-table"), DataTable(id="update-table"),
id="left-pane", id="left-pane",
), ),
Vertical(
Static("No context available", id="context-pane"),
id="middle-pane",
),
Vertical( Vertical(
Vertical( Vertical(
Label("Details:", id="details-label"), Label("Details:", id="details-label"),
@@ -132,6 +144,9 @@ class ConfirmationApp(App):
if self.proposal_queue: if self.proposal_queue:
self.run_worker(self.poll_proposal_queue, thread=False) self.run_worker(self.poll_proposal_queue, thread=False)
if self.context_queue:
self.run_worker(self.poll_context_queue, thread=False)
async def poll_proposal_queue(self) -> None: async def poll_proposal_queue(self) -> None:
""" """
Background worker that polls the proposal queue for new extraction results. Background worker that polls the proposal queue for new extraction results.
@@ -147,6 +162,24 @@ class ConfirmationApp(App):
# Log error but keep the worker running # Log error but keep the worker running
self.log(f"Error polling proposal queue: {e}") self.log(f"Error polling proposal queue: {e}")
async def poll_context_queue(self) -> None:
    """
    Background worker that polls the context queue for new RAG updates.

    Each update taken from ``self.context_queue`` replaces the content of
    the ``#context-pane`` widget. Errors are logged and the worker keeps
    running. Every item that was successfully taken from the queue is
    marked done, even when displaying it fails, so the queue's
    unfinished-task accounting stays correct.
    """
    while True:
        try:
            update = await self.context_queue.get()
        except Exception as e:
            # Nothing was dequeued, so there is nothing to mark as done.
            self.log(f"Error polling context queue: {e}")
            continue

        try:
            context_pane = self.query_one("#context-pane", Static)

            # Format the update for display
            display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
            context_pane.update(display_text)
        except Exception as e:
            # Log error but keep the worker running
            self.log(f"Error polling context queue: {e}")
        finally:
            # Guard kept for queue-like objects that lack task_done().
            if hasattr(self.context_queue, "task_done"):
                self.context_queue.task_done()
def add_result(self, result: ExtractionResult) -> None: def add_result(self, result: ExtractionResult) -> None:
""" """
Adds results from the LLM processor to the TUI table. Adds results from the LLM processor to the TUI table.
+98
View File
@@ -0,0 +1,98 @@
import os
from reportlab.pdfgen import canvas
from src.rag.manager import RAGManager
def create_dummy_phb(pdf_path: str):
    """
    Creates a dummy PDF file to simulate the Player's Handbook for verification.

    The PDF has three pages (Fireball, Grappling, General Rules), each with
    a title line followed by its body lines.

    Args:
        pdf_path: Destination path of the PDF. The parent directory is
            created when it does not exist yet.
    """
    print(f"Creating dummy PHB at {pdf_path}...")

    # Make sure the target directory exists before the PDF is written,
    # e.g. on a fresh checkout without a "data" directory.
    parent_dir = os.path.dirname(pdf_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # (title, body lines) for every page, in page order.
    pages = [
        (
            "Fireball",
            [
                "A bright streak flashes from your pointing finger to a point you choose within range.",
                "Each creature in a 20-foot-radius sphere centered on that point must make a Dexterity saving throw.",
            ],
        ),
        (
            "Grappling",
            [
                "Grappling is a special option available to any attack that hits with a melee attack.",
                "A creature is grappled if it is the target of the grapple attack and the attack hits.",
            ],
        ),
        (
            "General Rules",
            [
                "Combat is resolved in rounds, and each round represents 6 seconds of in-game time.",
            ],
        ),
    ]

    c = canvas.Canvas(pdf_path)
    for title, body_lines in pages:
        # Title at y=750, then one body line every 20 units below it.
        y = 750
        c.drawString(100, y, title)
        for line in body_lines:
            y -= 20
            c.drawString(100, y, line)
        c.showPage()

    c.save()
    print(f"Dummy PHB created successfully at {pdf_path}")
def test_rag():
    """
    End-to-end check of the RAG manager: build the dummy PDF, ingest it,
    then verify that each sample question retrieves a page containing the
    expected keyword as its top result.
    """
    dummy_pdf = "data/phb_dummy.pdf"
    create_dummy_phb(dummy_pdf)

    manager = RAGManager(persist_dir="data/rag_index_test")

    # Ingestion
    print("\nTesting Ingestion...")
    manager.ingest_pdf(dummy_pdf)

    # Retrieval: (question, keyword expected in the best snippet)
    print("\nTesting Retrieval...")
    cases = (
        ("What is Fireball?", "Fireball"),
        ("How does grappling work?", "Grappling"),
    )
    for question, keyword in cases:
        hits = manager.retrieve(question)
        print(f"Query: {question}")
        for hit in hits:
            print(f"Source: {hit.source} | Snippet: {hit.snippet[:100]}...")

        assert len(hits) > 0, (
            f"Should have retrieved at least one result for '{keyword}'"
        )
        assert keyword in hits[0].snippet, (
            f"The top result should contain '{keyword}'"
        )

    print("\n✅ RAG Verification Successful!")
# Script entry point: run the RAG verification when this file is executed directly.
if __name__ == "__main__":
test_rag()