feat: implement RAG capabilities and Context Pane integration
- Add RAG capabilities using LlamaIndex and ChromaDB - Implement RAGManager for PHB indexing and retrieval - Integrate RAG pipeline into orchestrator to trigger queries based on extracted entities - Update TUI to include a 3-column layout with a real-time Context Pane - Define ContextUpdate data models in src/llm/models.py - Update requirements.txt with new dependencies
This commit is contained in:
@@ -6,3 +6,6 @@ textual
|
||||
typer
|
||||
openai
|
||||
python-dotenv
|
||||
llama-index
|
||||
chromadb
|
||||
pdfplumber
|
||||
|
||||
@@ -44,6 +44,16 @@ class CharacterStateUpdate(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ContextUpdate(BaseModel):
|
||||
query: str = Field(..., description="The search query used to retrieve the context")
|
||||
snippet: str = Field(
|
||||
..., description="The relevant text snippet retrieved from the source"
|
||||
)
|
||||
source: str = Field(
|
||||
..., description="The source of the snippet (e.g., 'PHB p. 12')"
|
||||
)
|
||||
|
||||
|
||||
class ExtractionResult(BaseModel):
|
||||
lore_updates: List[LoreUpdate] = Field(
|
||||
default_factory=list, description="List of discovered lore facts", alias="lore"
|
||||
|
||||
@@ -6,8 +6,9 @@ from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.llm.models import ExtractionResult
|
||||
from src.llm.models import ContextUpdate, ExtractionResult
|
||||
from src.llm.processor import LLMProcessor
|
||||
from src.rag.manager import RAGManager
|
||||
from src.stt.listener import AudioListener
|
||||
from src.stt.transcriber import Transcriber
|
||||
from src.ui.tui import ConfirmationApp
|
||||
@@ -37,10 +38,12 @@ class PipelineOrchestrator:
|
||||
self.listener = AudioListener(loop=self.loop)
|
||||
self.transcriber = Transcriber(model_size="small")
|
||||
self.processor = LLMProcessor()
|
||||
self.rag_manager = RAGManager()
|
||||
|
||||
# Queues
|
||||
self.transcript_queue = asyncio.Queue()
|
||||
self.proposal_queue = asyncio.Queue()
|
||||
self.context_queue = asyncio.Queue()
|
||||
|
||||
self.is_running = False
|
||||
|
||||
@@ -148,6 +151,9 @@ class PipelineOrchestrator:
|
||||
f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(result.lore_updates)}, Char: {len(result.character_updates)})"
|
||||
)
|
||||
await self.proposal_queue.put(result)
|
||||
|
||||
# Trigger RAG query based on extracted entities
|
||||
await self._trigger_rag_queries(result)
|
||||
else:
|
||||
logger.info("LLM Worker: No relevant game data extracted.")
|
||||
|
||||
@@ -157,6 +163,43 @@ class PipelineOrchestrator:
|
||||
# Small sleep
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _trigger_rag_queries(self, result: ExtractionResult):
|
||||
"""
|
||||
Triggers RAG queries based on the extracted entities and results.
|
||||
"""
|
||||
queries = set()
|
||||
|
||||
# Collect entities from lore updates
|
||||
for update in result.lore_updates:
|
||||
if update.entity_name:
|
||||
queries.add(update.entity_name)
|
||||
|
||||
# Collect entities from character updates
|
||||
for update in result.character_updates:
|
||||
if update.character_name:
|
||||
queries.add(update.character_name)
|
||||
|
||||
# Collect events as potential queries
|
||||
for event in result.significant_events:
|
||||
queries.add(event)
|
||||
|
||||
if not queries:
|
||||
logger.info("RAG: No query terms identified from extraction result.")
|
||||
return
|
||||
|
||||
for query in queries:
|
||||
logger.info(f"RAG: Triggering query for: {query}")
|
||||
try:
|
||||
# Run retrieval in a thread to avoid blocking the event loop
|
||||
updates = await asyncio.to_thread(self.rag_manager.retrieve, query)
|
||||
for update in updates:
|
||||
await self.context_queue.put(update)
|
||||
logger.info(
|
||||
f"RAG: Retrieved snippet for {query} from {update.source}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"RAG: Error retrieving context for {query}: {e}")
|
||||
|
||||
def _get_wiki_context(self) -> str:
|
||||
"""
|
||||
Reads all files in the lore directory and returns them as a 저희 context string.
|
||||
@@ -188,8 +231,11 @@ class PipelineOrchestrator:
|
||||
logger.info("TUI Worker started.")
|
||||
try:
|
||||
# Launch TUI exactly once.
|
||||
# Pass the proposal queue to the app.
|
||||
app = ConfirmationApp(proposal_queue=self.proposal_queue)
|
||||
# Pass the proposal queue and context queue to the app.
|
||||
app = ConfirmationApp(
|
||||
proposal_queue=self.proposal_queue,
|
||||
context_queue=self.context_queue,
|
||||
)
|
||||
await app.run_async()
|
||||
self.stop()
|
||||
except Exception as e:
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import chromadb
|
||||
import pdfplumber
|
||||
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
from src.llm.models import ContextUpdate
|
||||
|
||||
|
||||
class RAGManager:
|
||||
def __init__(self, persist_dir: str = "data/rag_index"):
|
||||
self.persist_dir = persist_dir
|
||||
self.db = chromadb.PersistentClient(path=self.persist_dir)
|
||||
self.collection_name = "phb_collection"
|
||||
|
||||
# Initialize Chroma Vector Store
|
||||
self.vector_store = ChromaVectorStore(
|
||||
chroma_collection=self.db.get_or_create_collection(self.collection_name)
|
||||
)
|
||||
|
||||
# Initialize Storage Context
|
||||
self.storage_context = StorageContext.from_defaults(
|
||||
vector_store=self.vector_store
|
||||
)
|
||||
|
||||
# Use a local HuggingFace embedding model to avoid API key issues during verification
|
||||
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
|
||||
|
||||
# Load index if it exists, otherwise initialize
|
||||
try:
|
||||
self.index = VectorStoreIndex.from_vector_store(
|
||||
self.vector_store, storage_context=self.storage_context
|
||||
)
|
||||
except Exception:
|
||||
self.index = None
|
||||
|
||||
def ingest_pdf(self, pdf_path: str):
|
||||
"""
|
||||
Parses a PDF, chunks it, and stores embeddings in ChromaDB.
|
||||
"""
|
||||
documents = []
|
||||
with pdfplumber.open(pdf_path) as pdf:
|
||||
for i, page in enumerate(pdf.pages):
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
# Create a document for each page
|
||||
# In a real scenario, we might use a recursive character splitter
|
||||
# but for PHB, page-level chunking is a good start.
|
||||
doc = Document(
|
||||
text=text, metadata={"source": f"PHB p. {i + 1}", "page": i + 1}
|
||||
)
|
||||
documents.append(doc)
|
||||
|
||||
if not documents:
|
||||
print(f"No text extracted from {pdf_path}")
|
||||
return
|
||||
|
||||
# Create index from documents
|
||||
self.index = VectorStoreIndex.from_documents(
|
||||
documents, storage_context=self.storage_context
|
||||
)
|
||||
print(f"Successfully ingested {pdf_path} into the vector store.")
|
||||
|
||||
def retrieve(self, query: str, top_k: int = 3) -> List[ContextUpdate]:
|
||||
"""
|
||||
Retrieves the top-K most relevant snippets for a given query.
|
||||
"""
|
||||
if not self.index:
|
||||
print("Index not initialized. Please ingest documents first.")
|
||||
return []
|
||||
|
||||
# Create a retriever
|
||||
retriever = self.index.as_retriever(similarity_top_k=top_k)
|
||||
nodes = retriever.retrieve(query)
|
||||
|
||||
results = []
|
||||
for node in nodes:
|
||||
# Extract metadata
|
||||
source = node.metadata.get("source", "Unknown Source")
|
||||
|
||||
results.append(ContextUpdate(query=query, snippet=node.text, source=source))
|
||||
|
||||
return results
|
||||
+35
-2
@@ -17,13 +17,19 @@ class ConfirmationApp(App):
|
||||
}
|
||||
|
||||
#left-pane {
|
||||
width: 40%;
|
||||
width: 30%;
|
||||
border: solid;
|
||||
padding: 1;
|
||||
}
|
||||
|
||||
#middle-pane {
|
||||
width: 30%;
|
||||
border: solid;
|
||||
padding: 1;
|
||||
}
|
||||
|
||||
#right-pane {
|
||||
width: 60%;
|
||||
width: 40%;
|
||||
border: solid;
|
||||
padding: 1;
|
||||
layout: vertical;
|
||||
@@ -61,10 +67,12 @@ class ConfirmationApp(App):
|
||||
self,
|
||||
result: Optional[ExtractionResult] = None,
|
||||
proposal_queue: Optional[asyncio.Queue] = None,
|
||||
context_queue: Optional[asyncio.Queue] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.result = result
|
||||
self.proposal_queue = proposal_queue
|
||||
self.context_queue = context_queue
|
||||
self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = []
|
||||
|
||||
if result:
|
||||
@@ -81,6 +89,10 @@ class ConfirmationApp(App):
|
||||
DataTable(id="update-table"),
|
||||
id="left-pane",
|
||||
),
|
||||
Vertical(
|
||||
Static("No context available", id="context-pane"),
|
||||
id="middle-pane",
|
||||
),
|
||||
Vertical(
|
||||
Vertical(
|
||||
Label("Details:", id="details-label"),
|
||||
@@ -132,6 +144,9 @@ class ConfirmationApp(App):
|
||||
if self.proposal_queue:
|
||||
self.run_worker(self.poll_proposal_queue, thread=False)
|
||||
|
||||
if self.context_queue:
|
||||
self.run_worker(self.poll_context_queue, thread=False)
|
||||
|
||||
async def poll_proposal_queue(self) -> None:
|
||||
"""
|
||||
Background worker that polls the proposal queue for new extraction results.
|
||||
@@ -147,6 +162,24 @@ class ConfirmationApp(App):
|
||||
# Log error but keep the worker running
|
||||
self.log(f"Error polling proposal queue: {e}")
|
||||
|
||||
async def poll_context_queue(self) -> None:
|
||||
"""
|
||||
Background worker that polls the context queue for new RAG updates.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
update = await self.context_queue.get()
|
||||
context_pane = self.query_one("#context-pane", Static)
|
||||
|
||||
# Format the update for display
|
||||
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
|
||||
context_pane.update(display_text)
|
||||
|
||||
if hasattr(self.context_queue, "task_done"):
|
||||
self.context_queue.task_done()
|
||||
except Exception as e:
|
||||
self.log(f"Error polling context queue: {e}")
|
||||
|
||||
def add_result(self, result: ExtractionResult) -> None:
|
||||
"""
|
||||
Adds results from the LLM processor to the TUI table.
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
import os
|
||||
|
||||
from reportlab.pdfgen import canvas
|
||||
|
||||
from src.rag.manager import RAGManager
|
||||
|
||||
|
||||
def create_dummy_phb(pdf_path: str):
|
||||
"""
|
||||
Creates a dummy PDF file to simulate the Player's Handbook for verification.
|
||||
"""
|
||||
print(f"Creating dummy PHB at {pdf_path}...")
|
||||
c = canvas.Canvas(pdf_path)
|
||||
|
||||
# Page 1: Fireball
|
||||
c.drawString(100, 750, "Fireball")
|
||||
c.drawString(
|
||||
100,
|
||||
730,
|
||||
"A bright streak flashes from your pointing finger to a point you choose within range.",
|
||||
)
|
||||
c.drawString(
|
||||
100,
|
||||
710,
|
||||
"Each creature in a 20-foot-radius sphere centered on that point must make a Dexterity saving throw.",
|
||||
)
|
||||
c.showPage()
|
||||
|
||||
# Page 2: Grappling
|
||||
c.drawString(100, 750, "Grappling")
|
||||
c.drawString(
|
||||
100,
|
||||
730,
|
||||
"Grappling is a special option available to any attack that hits with a melee attack.",
|
||||
)
|
||||
c.drawString(
|
||||
100,
|
||||
710,
|
||||
"A creature is grappled if it is the target of the grapple attack and the attack hits.",
|
||||
)
|
||||
c.showPage()
|
||||
|
||||
# Page 3: General Rules
|
||||
c.drawString(100, 750, "General Rules")
|
||||
c.drawString(
|
||||
100,
|
||||
730,
|
||||
"Combat is resolved in rounds, and each round represents 6 seconds of in-game time.",
|
||||
)
|
||||
c.showPage()
|
||||
|
||||
c.save()
|
||||
print(f"Dummy PHB created successfully at {pdf_path}")
|
||||
|
||||
|
||||
def test_rag():
|
||||
pdf_path = "data/phb_dummy.pdf"
|
||||
create_dummy_phb(pdf_path)
|
||||
|
||||
# Initialize RAG Manager
|
||||
rag = RAGManager(persist_dir="data/rag_index_test")
|
||||
|
||||
# Task 2.2: Ingest PDF
|
||||
print("\nTesting Ingestion...")
|
||||
rag.ingest_pdf(pdf_path)
|
||||
|
||||
# Task 2.3: Retrieve Logic
|
||||
print("\nTesting Retrieval...")
|
||||
|
||||
# Test 1: Fireball
|
||||
query1 = "What is Fireball?"
|
||||
results1 = rag.retrieve(query1)
|
||||
print(f"Query: {query1}")
|
||||
for res in results1:
|
||||
print(f"Source: {res.source} | Snippet: {res.snippet[:100]}...")
|
||||
|
||||
assert len(results1) > 0, "Should have retrieved at least one result for 'Fireball'"
|
||||
assert "Fireball" in results1[0].snippet, "The top result should contain 'Fireball'"
|
||||
|
||||
# Test 2: Grappling
|
||||
query2 = "How does grappling work?"
|
||||
results2 = rag.retrieve(query2)
|
||||
print(f"Query: {query2}")
|
||||
for res in results2:
|
||||
print(f"Source: {res.source} | Snippet: {res.snippet[:100]}...")
|
||||
|
||||
assert len(results2) > 0, (
|
||||
"Should have retrieved at least one result for 'Grappling'"
|
||||
)
|
||||
assert "Grappling" in results2[0].snippet, (
|
||||
"The top result should contain 'Grappling'"
|
||||
)
|
||||
|
||||
print("\n✅ RAG Verification Successful!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rag()
|
||||
Reference in New Issue
Block a user