feat: implement RAG capabilities and Context Pane integration
- Add RAG capabilities using LlamaIndex and ChromaDB
- Implement RAGManager for PHB indexing and retrieval
- Integrate RAG pipeline into orchestrator to trigger queries based on extracted entities
- Update TUI to include a 3-column layout with a real-time Context Pane
- Define ContextUpdate data models in src/llm/models.py
- Update requirements.txt with new dependencies
This commit is contained in:
@@ -6,3 +6,6 @@ textual
|
|||||||
typer
|
typer
|
||||||
openai
|
openai
|
||||||
python-dotenv
|
python-dotenv
|
||||||
|
llama-index
|
||||||
|
chromadb
|
||||||
|
pdfplumber
|
||||||
|
|||||||
@@ -44,6 +44,16 @@ class CharacterStateUpdate(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextUpdate(BaseModel):
    """One retrieved reference snippet together with the query and its source."""

    # Search text that produced this snippet.
    query: str = Field(
        ...,
        description="The search query used to retrieve the context",
    )
    # Text of the retrieved passage.
    snippet: str = Field(
        ...,
        description="The relevant text snippet retrieved from the source",
    )
    # Human-readable citation for the passage.
    source: str = Field(
        ...,
        description="The source of the snippet (e.g., 'PHB p. 12')",
    )
|
||||||
|
|
||||||
|
|
||||||
class ExtractionResult(BaseModel):
|
class ExtractionResult(BaseModel):
|
||||||
lore_updates: List[LoreUpdate] = Field(
|
lore_updates: List[LoreUpdate] = Field(
|
||||||
default_factory=list, description="List of discovered lore facts", alias="lore"
|
default_factory=list, description="List of discovered lore facts", alias="lore"
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from src.llm.models import ExtractionResult
|
from src.llm.models import ContextUpdate, ExtractionResult
|
||||||
from src.llm.processor import LLMProcessor
|
from src.llm.processor import LLMProcessor
|
||||||
|
from src.rag.manager import RAGManager
|
||||||
from src.stt.listener import AudioListener
|
from src.stt.listener import AudioListener
|
||||||
from src.stt.transcriber import Transcriber
|
from src.stt.transcriber import Transcriber
|
||||||
from src.ui.tui import ConfirmationApp
|
from src.ui.tui import ConfirmationApp
|
||||||
@@ -37,10 +38,12 @@ class PipelineOrchestrator:
|
|||||||
self.listener = AudioListener(loop=self.loop)
|
self.listener = AudioListener(loop=self.loop)
|
||||||
self.transcriber = Transcriber(model_size="small")
|
self.transcriber = Transcriber(model_size="small")
|
||||||
self.processor = LLMProcessor()
|
self.processor = LLMProcessor()
|
||||||
|
self.rag_manager = RAGManager()
|
||||||
|
|
||||||
# Queues
|
# Queues
|
||||||
self.transcript_queue = asyncio.Queue()
|
self.transcript_queue = asyncio.Queue()
|
||||||
self.proposal_queue = asyncio.Queue()
|
self.proposal_queue = asyncio.Queue()
|
||||||
|
self.context_queue = asyncio.Queue()
|
||||||
|
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
@@ -148,6 +151,9 @@ class PipelineOrchestrator:
|
|||||||
f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(result.lore_updates)}, Char: {len(result.character_updates)})"
|
f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(result.lore_updates)}, Char: {len(result.character_updates)})"
|
||||||
)
|
)
|
||||||
await self.proposal_queue.put(result)
|
await self.proposal_queue.put(result)
|
||||||
|
|
||||||
|
# Trigger RAG query based on extracted entities
|
||||||
|
await self._trigger_rag_queries(result)
|
||||||
else:
|
else:
|
||||||
logger.info("LLM Worker: No relevant game data extracted.")
|
logger.info("LLM Worker: No relevant game data extracted.")
|
||||||
|
|
||||||
@@ -157,6 +163,43 @@ class PipelineOrchestrator:
|
|||||||
# Small sleep
|
# Small sleep
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
async def _trigger_rag_queries(self, result: ExtractionResult):
    """
    Run one RAG lookup per distinct term found in ``result`` and queue the hits.

    Query terms are, in order: the entity names of the lore updates, the
    character names of the character updates, and the significant events.
    Empty terms are ignored and duplicates are queried only once. Every
    update returned by ``self.rag_manager.retrieve`` is put on
    ``self.context_queue``.

    A failure while handling one term is logged and does not prevent the
    remaining terms from being queried.
    """
    candidates = [
        *(update.entity_name for update in result.lore_updates),
        *(update.character_name for update in result.character_updates),
        *result.significant_events,
    ]

    # dict.fromkeys de-duplicates while keeping first-seen order. A set of
    # strings would iterate in a different order on every run, making the
    # order in which snippets reach the UI unpredictable.
    queries = list(dict.fromkeys(term for term in candidates if term))

    if not queries:
        logger.info("RAG: No query terms identified from extraction result.")
        return

    for query in queries:
        logger.info(f"RAG: Triggering query for: {query}")
        try:
            # Run retrieval in a thread to avoid blocking the event loop
            updates = await asyncio.to_thread(self.rag_manager.retrieve, query)
            for update in updates:
                await self.context_queue.put(update)
                logger.info(
                    f"RAG: Retrieved snippet for {query} from {update.source}"
                )
        except Exception as e:
            # Best-effort per term: keep the traceback in the log and move on
            # to the next query instead of aborting the whole batch.
            logger.exception(f"RAG: Error retrieving context for {query}: {e}")
|
||||||
|
|
||||||
def _get_wiki_context(self) -> str:
|
def _get_wiki_context(self) -> str:
|
||||||
"""
|
"""
|
||||||
Reads all files in the lore directory and returns them as a single context string.
|
Reads all files in the lore directory and returns them as a single context string.
|
||||||
@@ -188,8 +231,11 @@ class PipelineOrchestrator:
|
|||||||
logger.info("TUI Worker started.")
|
logger.info("TUI Worker started.")
|
||||||
try:
|
try:
|
||||||
# Launch TUI exactly once.
|
# Launch TUI exactly once.
|
||||||
# Pass the proposal queue to the app.
|
# Pass the proposal queue and context queue to the app.
|
||||||
app = ConfirmationApp(proposal_queue=self.proposal_queue)
|
app = ConfirmationApp(
|
||||||
|
proposal_queue=self.proposal_queue,
|
||||||
|
context_queue=self.context_queue,
|
||||||
|
)
|
||||||
await app.run_async()
|
await app.run_async()
|
||||||
self.stop()
|
self.stop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -0,0 +1,86 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
import pdfplumber
|
||||||
|
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
|
||||||
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||||
|
|
||||||
|
from src.llm.models import ContextUpdate
|
||||||
|
|
||||||
|
|
||||||
|
class RAGManager:
    """
    Index PDF pages in a persistent ChromaDB collection and retrieve from it.

    The manager owns a Chroma persistent client rooted at ``persist_dir``,
    a LlamaIndex vector store on top of the ``phb_collection`` collection,
    and (when available) a ``VectorStoreIndex`` over that store.
    """

    def __init__(self, persist_dir: str = "data/rag_index"):
        """
        Open (or create) the vector store under ``persist_dir``.

        ``self.index`` is set to ``None`` when an index cannot be built from
        the existing store; ``retrieve`` then returns no results until
        ``ingest_pdf`` has been called.
        """
        self.persist_dir = persist_dir
        self.db = chromadb.PersistentClient(path=self.persist_dir)
        self.collection_name = "phb_collection"

        # Initialize Chroma Vector Store
        self.vector_store = ChromaVectorStore(
            chroma_collection=self.db.get_or_create_collection(self.collection_name)
        )

        # Initialize Storage Context
        self.storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        # Use a local HuggingFace embedding model to avoid API key issues during verification
        # NOTE(review): this assigns to llama_index's shared Settings object, so it
        # presumably affects every other llama_index user in the process - confirm.
        Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

        # Load index if it exists, otherwise initialize
        try:
            self.index = VectorStoreIndex.from_vector_store(
                self.vector_store, storage_context=self.storage_context
            )
        except Exception as exc:
            # Still best-effort (the manager stays usable for ingestion), but
            # report why loading failed instead of hiding the error.
            print(f"Could not load existing index from {self.persist_dir}: {exc}")
            self.index = None

    def ingest_pdf(self, pdf_path: str):
        """
        Parses a PDF, chunks it page by page, and stores embeddings in ChromaDB.

        Each page with extractable text becomes one document whose metadata
        holds a 1-based ``page`` number and a ``source`` label of the form
        ``"PHB p. <page>"``. Pages with no text (or only whitespace) are
        skipped. If no page yields text, nothing is indexed.
        """
        # NOTE(review): every call embeds the pages again; re-ingesting the same
        # PDF into the same persist_dir presumably duplicates entries - confirm.
        documents = []
        with pdfplumber.open(pdf_path) as pdf:
            for page_number, page in enumerate(pdf.pages, start=1):
                text = page.extract_text()
                if not text or not text.strip():
                    continue
                # Page-level chunking: one document per page.
                documents.append(
                    Document(
                        text=text,
                        metadata={
                            "source": f"PHB p. {page_number}",
                            "page": page_number,
                        },
                    )
                )

        if not documents:
            print(f"No text extracted from {pdf_path}")
            return

        # Create index from documents
        self.index = VectorStoreIndex.from_documents(
            documents, storage_context=self.storage_context
        )
        print(f"Successfully ingested {pdf_path} into the vector store.")

    def retrieve(self, query: str, top_k: int = 3) -> List["ContextUpdate"]:
        """
        Retrieves the top-K most relevant snippets for a given query.

        Returns an empty list when no index is available, when ``query`` is
        empty or whitespace-only, or when ``top_k`` is less than 1. Otherwise
        returns one ``ContextUpdate`` per retrieved node, using the node's
        ``source`` metadata (or ``"Unknown Source"`` when it is missing).
        """
        if self.index is None:
            print("Index not initialized. Please ingest documents first.")
            return []

        # Nothing meaningful to search for; skip the embedding round-trip.
        if not query or not query.strip() or top_k < 1:
            return []

        retriever = self.index.as_retriever(similarity_top_k=top_k)
        nodes = retriever.retrieve(query)

        return [
            ContextUpdate(
                query=query,
                snippet=node.text,
                source=node.metadata.get("source", "Unknown Source"),
            )
            for node in nodes
        ]
|
||||||
+35
-2
@@ -17,13 +17,19 @@ class ConfirmationApp(App):
|
|||||||
}
|
}
|
||||||
|
|
||||||
#left-pane {
|
#left-pane {
|
||||||
width: 40%;
|
width: 30%;
|
||||||
|
border: solid;
|
||||||
|
padding: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#middle-pane {
|
||||||
|
width: 30%;
|
||||||
border: solid;
|
border: solid;
|
||||||
padding: 1;
|
padding: 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#right-pane {
|
#right-pane {
|
||||||
width: 60%;
|
width: 40%;
|
||||||
border: solid;
|
border: solid;
|
||||||
padding: 1;
|
padding: 1;
|
||||||
layout: vertical;
|
layout: vertical;
|
||||||
@@ -61,10 +67,12 @@ class ConfirmationApp(App):
|
|||||||
self,
|
self,
|
||||||
result: Optional[ExtractionResult] = None,
|
result: Optional[ExtractionResult] = None,
|
||||||
proposal_queue: Optional[asyncio.Queue] = None,
|
proposal_queue: Optional[asyncio.Queue] = None,
|
||||||
|
context_queue: Optional[asyncio.Queue] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.result = result
|
self.result = result
|
||||||
self.proposal_queue = proposal_queue
|
self.proposal_queue = proposal_queue
|
||||||
|
self.context_queue = context_queue
|
||||||
self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = []
|
self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = []
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
@@ -81,6 +89,10 @@ class ConfirmationApp(App):
|
|||||||
DataTable(id="update-table"),
|
DataTable(id="update-table"),
|
||||||
id="left-pane",
|
id="left-pane",
|
||||||
),
|
),
|
||||||
|
Vertical(
|
||||||
|
Static("No context available", id="context-pane"),
|
||||||
|
id="middle-pane",
|
||||||
|
),
|
||||||
Vertical(
|
Vertical(
|
||||||
Vertical(
|
Vertical(
|
||||||
Label("Details:", id="details-label"),
|
Label("Details:", id="details-label"),
|
||||||
@@ -132,6 +144,9 @@ class ConfirmationApp(App):
|
|||||||
if self.proposal_queue:
|
if self.proposal_queue:
|
||||||
self.run_worker(self.poll_proposal_queue, thread=False)
|
self.run_worker(self.poll_proposal_queue, thread=False)
|
||||||
|
|
||||||
|
if self.context_queue:
|
||||||
|
self.run_worker(self.poll_context_queue, thread=False)
|
||||||
|
|
||||||
async def poll_proposal_queue(self) -> None:
|
async def poll_proposal_queue(self) -> None:
|
||||||
"""
|
"""
|
||||||
Background worker that polls the proposal queue for new extraction results.
|
Background worker that polls the proposal queue for new extraction results.
|
||||||
@@ -147,6 +162,24 @@ class ConfirmationApp(App):
|
|||||||
# Log error but keep the worker running
|
# Log error but keep the worker running
|
||||||
self.log(f"Error polling proposal queue: {e}")
|
self.log(f"Error polling proposal queue: {e}")
|
||||||
|
|
||||||
|
async def poll_context_queue(self) -> None:
    """
    Background worker that polls the context queue for new RAG updates.

    Waits for the next item on ``self.context_queue`` and renders its
    ``query``, ``source`` and ``snippet`` into the ``#context-pane`` widget,
    replacing whatever the pane showed before. Errors are logged and the
    worker keeps running.
    """
    while True:
        try:
            update = await self.context_queue.get()
        except Exception as e:
            self.log(f"Error polling context queue: {e}")
            continue

        try:
            context_pane = self.query_one("#context-pane", Static)

            # Format the update for display
            display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
            context_pane.update(display_text)
        except Exception as e:
            self.log(f"Error polling context queue: {e}")
        finally:
            # Acknowledge the item even when rendering failed, so the queue's
            # unfinished-task count stays balanced with the successful get().
            if hasattr(self.context_queue, "task_done"):
                self.context_queue.task_done()
|
||||||
|
|
||||||
def add_result(self, result: ExtractionResult) -> None:
|
def add_result(self, result: ExtractionResult) -> None:
|
||||||
"""
|
"""
|
||||||
Adds results from the LLM processor to the TUI table.
|
Adds results from the LLM processor to the TUI table.
|
||||||
|
|||||||
@@ -0,0 +1,98 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from reportlab.pdfgen import canvas
|
||||||
|
|
||||||
|
from src.rag.manager import RAGManager
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_phb(pdf_path: str):
    """
    Creates a dummy PDF file to simulate the Player's Handbook for verification.

    The PDF has three pages ("Fireball", "Grappling", "General Rules"), each
    with a title line followed by one or two lines of rule text. The parent
    directory of ``pdf_path`` is created if it does not exist yet.
    """
    print(f"Creating dummy PHB at {pdf_path}...")

    # Make sure the target folder exists; a fresh checkout may not have it yet.
    parent_dir = os.path.dirname(pdf_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # (title, body lines) for each page, in page order.
    pages = [
        (
            "Fireball",
            [
                "A bright streak flashes from your pointing finger to a point you choose within range.",
                "Each creature in a 20-foot-radius sphere centered on that point must make a Dexterity saving throw.",
            ],
        ),
        (
            "Grappling",
            [
                "Grappling is a special option available to any attack that hits with a melee attack.",
                "A creature is grappled if it is the target of the grapple attack and the attack hits.",
            ],
        ),
        (
            "General Rules",
            [
                "Combat is resolved in rounds, and each round represents 6 seconds of in-game time.",
            ],
        ),
    ]

    c = canvas.Canvas(pdf_path)
    for title, body_lines in pages:
        c.drawString(100, 750, title)
        # Body lines go below the title, 20 units apart (730, 710, ...).
        for offset, line in enumerate(body_lines, start=1):
            c.drawString(100, 750 - 20 * offset, line)
        c.showPage()

    c.save()
    print(f"Dummy PHB created successfully at {pdf_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag():
    """
    End-to-end check of RAG ingestion and retrieval against a dummy PHB.

    Builds the dummy PDF, ingests it into a test index, then verifies for
    each sample question that at least one snippet is returned and that the
    top snippet mentions the expected keyword.
    """
    pdf_path = "data/phb_dummy.pdf"
    create_dummy_phb(pdf_path)

    # Initialize RAG Manager
    rag = RAGManager(persist_dir="data/rag_index_test")

    # Task 2.2: Ingest PDF
    print("\nTesting Ingestion...")
    rag.ingest_pdf(pdf_path)

    # Task 2.3: Retrieve Logic
    print("\nTesting Retrieval...")

    # (question, keyword the top snippet must contain)
    cases = [
        ("What is Fireball?", "Fireball"),
        ("How does grappling work?", "Grappling"),
    ]
    for question, keyword in cases:
        hits = rag.retrieve(question)
        print(f"Query: {question}")
        for hit in hits:
            print(f"Source: {hit.source} | Snippet: {hit.snippet[:100]}...")

        assert len(hits) > 0, (
            f"Should have retrieved at least one result for '{keyword}'"
        )
        assert keyword in hits[0].snippet, (
            f"The top result should contain '{keyword}'"
        )

    print("\n✅ RAG Verification Successful!")


if __name__ == "__main__":
    test_rag()
|
||||||
Reference in New Issue
Block a user