feat: implement RAG capabilities and Context Pane integration

- Add RAG capabilities using LlamaIndex and ChromaDB
- Implement RAGManager for PHB indexing and retrieval
- Integrate RAG pipeline into orchestrator to trigger queries based on extracted entities
- Update TUI to include a 3-column layout with a real-time Context Pane
- Define ContextUpdate data models in src/llm/models.py
- Update requirements.txt with new dependencies
This commit is contained in:
2026-05-26 22:07:12 -07:00
parent f4c98fb2b9
commit 954f2f50d8
6 changed files with 281 additions and 5 deletions
+10
View File
@@ -44,6 +44,16 @@ class CharacterStateUpdate(BaseModel):
)
# One retrieved context snippet, passed from the RAG layer to the TUI.
# RAGManager.retrieve builds instances with query/snippet/source keyword
# arguments, and the TUI's context-queue poller reads those same three
# attributes to render the Context Pane.
# All three fields are declared with `...` as the default, i.e. required.
# NOTE(review): documented with comments rather than a class docstring
# because pydantic copies a model's docstring into its generated schema.
class ContextUpdate(BaseModel):
# The text that was sent to the retriever.
query: str = Field(..., description="The search query used to retrieve the context")
# The retrieved passage itself.
snippet: str = Field(
..., description="The relevant text snippet retrieved from the source"
)
# Human-readable citation for the passage.
source: str = Field(
..., description="The source of the snippet (e.g., 'PHB p. 12')"
)
class ExtractionResult(BaseModel):
lore_updates: List[LoreUpdate] = Field(
default_factory=list, description="List of discovered lore facts", alias="lore"
+49 -3
View File
@@ -6,8 +6,9 @@ from typing import List, Optional
import numpy as np
from src.llm.models import ExtractionResult
from src.llm.models import ContextUpdate, ExtractionResult
from src.llm.processor import LLMProcessor
from src.rag.manager import RAGManager
from src.stt.listener import AudioListener
from src.stt.transcriber import Transcriber
from src.ui.tui import ConfirmationApp
@@ -37,10 +38,12 @@ class PipelineOrchestrator:
self.listener = AudioListener(loop=self.loop)
self.transcriber = Transcriber(model_size="small")
self.processor = LLMProcessor()
self.rag_manager = RAGManager()
# Queues
self.transcript_queue = asyncio.Queue()
self.proposal_queue = asyncio.Queue()
self.context_queue = asyncio.Queue()
self.is_running = False
@@ -148,6 +151,9 @@ class PipelineOrchestrator:
f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(result.lore_updates)}, Char: {len(result.character_updates)})"
)
await self.proposal_queue.put(result)
# Trigger RAG query based on extracted entities
await self._trigger_rag_queries(result)
else:
logger.info("LLM Worker: No relevant game data extracted.")
@@ -157,6 +163,43 @@ class PipelineOrchestrator:
# Small sleep
await asyncio.sleep(0.1)
async def _trigger_rag_queries(self, result: ExtractionResult):
    """
    Run one RAG lookup per distinct term found in an extraction result.

    Terms are gathered, in this order, from the entity names of the lore
    updates, the character names of the character updates, and the
    significant events. Every snippet returned by the RAG manager is put
    on ``self.context_queue``.

    Args:
        result: The extraction result whose entities drive the queries.

    Retrieval failures are logged (with traceback) and do not propagate,
    so one failing query never prevents the remaining ones from running.
    """
    candidates = [
        *(update.entity_name for update in result.lore_updates),
        *(update.character_name for update in result.character_updates),
        *result.significant_events,
    ]

    # dict.fromkeys de-duplicates while keeping first-seen order; a plain
    # set would make the query order (and therefore the order in which
    # snippets reach the queue) vary from run to run.
    cleaned = (term.strip() for term in candidates if isinstance(term, str))
    queries = [term for term in dict.fromkeys(cleaned) if term]

    if not queries:
        logger.info("RAG: No query terms identified from extraction result.")
        return

    for query in queries:
        logger.info(f"RAG: Triggering query for: {query}")
        try:
            # Run retrieval in a thread to avoid blocking the event loop
            updates = await asyncio.to_thread(self.rag_manager.retrieve, query)
            for update in updates:
                await self.context_queue.put(update)
                logger.info(
                    f"RAG: Retrieved snippet for {query} from {update.source}"
                )
        except Exception as e:
            # Best-effort: context lookups must never break the caller.
            logger.exception(f"RAG: Error retrieving context for {query}: {e}")
def _get_wiki_context(self) -> str:
"""
Reads all files in the lore directory and returns them as a context string.
@@ -188,8 +231,11 @@ class PipelineOrchestrator:
logger.info("TUI Worker started.")
try:
# Launch TUI exactly once.
# Pass the proposal queue to the app.
app = ConfirmationApp(proposal_queue=self.proposal_queue)
# Pass the proposal queue and context queue to the app.
app = ConfirmationApp(
proposal_queue=self.proposal_queue,
context_queue=self.context_queue,
)
await app.run_async()
self.stop()
except Exception as e:
+86
View File
@@ -0,0 +1,86 @@
import logging
import os
from typing import List, Optional

import chromadb
import pdfplumber
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

from src.llm.models import ContextUpdate
logger = logging.getLogger(__name__)


class RAGManager:
    """
    Index PDF text into a persistent ChromaDB collection and retrieve
    the most similar snippets for a query through LlamaIndex.
    """

    def __init__(
        self,
        persist_dir: str = "data/rag_index",
        collection_name: str = "phb_collection",
        embed_model_name: str = "BAAI/bge-small-en-v1.5",
    ):
        """
        Open (or create) the vector store and try to load an index from it.

        Args:
            persist_dir: Directory handed to the persistent Chroma client.
            collection_name: Name of the Chroma collection to use.
            embed_model_name: HuggingFace embedding model identifier.

        If loading the index fails, the error is logged and ``self.index``
        is left as ``None``; ``retrieve`` then returns no results.
        """
        self.persist_dir = persist_dir
        self.db = chromadb.PersistentClient(path=self.persist_dir)
        self.collection_name = collection_name

        # Initialize Chroma Vector Store
        self.vector_store = ChromaVectorStore(
            chroma_collection=self.db.get_or_create_collection(self.collection_name)
        )

        # Initialize Storage Context
        self.storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        # Use a local HuggingFace embedding model to avoid API key issues.
        # NOTE(review): this assigns to the shared llama_index Settings
        # object, so it presumably affects every index in the process.
        Settings.embed_model = HuggingFaceEmbedding(model_name=embed_model_name)

        # Load index if it exists, otherwise initialize
        try:
            self.index = VectorStoreIndex.from_vector_store(
                self.vector_store, storage_context=self.storage_context
            )
        except Exception:
            # Keep the manager usable, but do not hide why loading failed.
            logger.exception(
                "RAG: could not load index from %s; retrieval is disabled "
                "until documents are ingested.",
                self.persist_dir,
            )
            self.index = None

    def ingest_pdf(self, pdf_path: str, source_label: str = "PHB"):
        """
        Parse a PDF page by page and store the embeddings in ChromaDB.

        Args:
            pdf_path: Path of the PDF to ingest.
            source_label: Prefix of the ``source`` metadata, which is
                stored as ``"<source_label> p. <page number>"``.

        Each page with extractable text becomes one document; pages with
        no text (or only whitespace) are skipped. If nothing could be
        extracted, a warning is logged and the index is left untouched.
        """
        documents = []
        with pdfplumber.open(pdf_path) as pdf:
            for page_number, page in enumerate(pdf.pages, start=1):
                text = page.extract_text()
                if not text or not text.strip():
                    continue
                # Page-level chunking: one document per page.
                documents.append(
                    Document(
                        text=text,
                        metadata={
                            "source": f"{source_label} p. {page_number}",
                            "page": page_number,
                        },
                    )
                )

        if not documents:
            logger.warning("No text extracted from %s", pdf_path)
            return

        # Create index from documents
        self.index = VectorStoreIndex.from_documents(
            documents, storage_context=self.storage_context
        )
        logger.info("Successfully ingested %s into the vector store.", pdf_path)

    def retrieve(self, query: str, top_k: int = 3) -> List[ContextUpdate]:
        """
        Retrieve the top-K most relevant snippets for a given query.

        Args:
            query: Text to search for.
            top_k: Maximum number of snippets to return.

        Returns:
            One ``ContextUpdate`` per retrieved node. The list is empty
            when no index is loaded, when ``query`` is blank, or when
            ``top_k`` is not positive.
        """
        if self.index is None:
            logger.warning("Index not initialized. Please ingest documents first.")
            return []

        # Nothing meaningful to search for; skip the embedding round-trip.
        if not query or not query.strip() or top_k <= 0:
            return []

        retriever = self.index.as_retriever(similarity_top_k=top_k)
        return [
            ContextUpdate(
                query=query,
                snippet=node.text,
                source=node.metadata.get("source", "Unknown Source"),
            )
            for node in retriever.retrieve(query)
        ]
+35 -2
View File
@@ -17,13 +17,19 @@ class ConfirmationApp(App):
}
#left-pane {
width: 40%;
width: 30%;
border: solid;
padding: 1;
}
#middle-pane {
width: 30%;
border: solid;
padding: 1;
}
#right-pane {
width: 60%;
width: 40%;
border: solid;
padding: 1;
layout: vertical;
@@ -61,10 +67,12 @@ class ConfirmationApp(App):
self,
result: Optional[ExtractionResult] = None,
proposal_queue: Optional[asyncio.Queue] = None,
context_queue: Optional[asyncio.Queue] = None,
):
super().__init__()
self.result = result
self.proposal_queue = proposal_queue
self.context_queue = context_queue
self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = []
if result:
@@ -81,6 +89,10 @@ class ConfirmationApp(App):
DataTable(id="update-table"),
id="left-pane",
),
Vertical(
Static("No context available", id="context-pane"),
id="middle-pane",
),
Vertical(
Vertical(
Label("Details:", id="details-label"),
@@ -132,6 +144,9 @@ class ConfirmationApp(App):
if self.proposal_queue:
self.run_worker(self.poll_proposal_queue, thread=False)
if self.context_queue:
self.run_worker(self.poll_context_queue, thread=False)
async def poll_proposal_queue(self) -> None:
"""
Background worker that polls the proposal queue for new extraction results.
@@ -147,6 +162,24 @@ class ConfirmationApp(App):
# Log error but keep the worker running
self.log(f"Error polling proposal queue: {e}")
async def poll_context_queue(self) -> None:
    """
    Background worker that polls the context queue for new RAG updates.

    Each update taken from ``self.context_queue`` replaces the content of
    the ``#context-pane`` widget with its query, source and snippet.
    Errors are logged and the loop keeps running.

    Every item that was taken from the queue is acknowledged with
    ``task_done()`` even when rendering it fails, so the queue's
    unfinished-task count stays accurate.
    """
    while True:
        item_taken = False
        try:
            update = await self.context_queue.get()
            item_taken = True

            context_pane = self.query_one("#context-pane", Static)
            # Format the update for display
            # NOTE(review): the widget may interpret markup in the snippet
            # text (e.g. square brackets) - confirm and escape if needed.
            display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
            context_pane.update(display_text)
        except Exception as e:
            # Log error but keep the worker running
            self.log(f"Error polling context queue: {e}")
        finally:
            # Only acknowledge when get() actually returned an item;
            # otherwise task_done() would be called once too often.
            if item_taken and hasattr(self.context_queue, "task_done"):
                self.context_queue.task_done()
def add_result(self, result: ExtractionResult) -> None:
"""
Adds results from the LLM processor to the TUI table.