Implement RAG summarization and context pipeline

- Add ContextPipeline for async RAG lookups
- Implement RAG result summarization via LLMProcessor
- Add CLI flag for PDF ingestion
- Strip markdown code blocks from LLM responses
- Update TUI context display to use ListItems
This commit is contained in:
2026-05-27 00:17:47 -07:00
parent b83d9b5e6a
commit b25f82cefc
7 changed files with 225 additions and 14 deletions
+20
View File
@@ -1,4 +1,24 @@
import argparse
from src.rag.manager import RAGManager
def main(): def main():
parser = argparse.ArgumentParser(description="D&D Helpers CLI")
parser.add_argument(
"--ingest-pdf",
type=str,
help="Path to a PDF file to ingest into the RAG system",
)
args = parser.parse_args()
if args.ingest_pdf:
print(f"Ingesting PDF: {args.ingest_pdf}...")
rag_manager = RAGManager()
rag_manager.ingest_pdf(args.ingest_pdf)
print("PDF ingestion complete.")
print("Hello from dnd-helpers!") print("Hello from dnd-helpers!")
+11 -1
View File
@@ -87,7 +87,17 @@ class LLMProcessor:
response_format=response_format, response_format=response_format,
extra_body={"enable_thinking": False}, extra_body={"enable_thinking": False},
) )
return response.choices[0].message.content content = response.choices[0].message.content
# Strip markdown code blocks if present
if content.startswith("```"):
import re
content = re.sub(
r"^```(?:json)?\n?|```$", "", content, flags=re.MULTILINE
).strip()
return content
except Exception as e: except Exception as e:
logger.error(f"LLM Error: {e}") logger.error(f"LLM Error: {e}")
return "" return ""
+6 -4
View File
@@ -4,8 +4,8 @@ NOISE_FILTER_SYSTEM_PROMPT = """
You are a D&D Game Master's assistant. Given a transcript, remove all out-of-character (OOC) chatter, logistical discussions (e.g., 'Where is my d20?'), and non-relevant noise. You are a D&D Game Master's assistant. Given a transcript, remove all out-of-character (OOC) chatter, logistical discussions (e.g., 'Where is my d20?'), and non-relevant noise.
You must output your response as a JSON object with the following keys: You must output your response as a JSON object with the following keys:
- "contextual_info": Information that is interesting or relevant to the story/session but doesn't fit into lore, character state, or significant events (e.g., flavor text, atmospheric descriptions, player questions, or player commentary that adds context). - "contextual_info": Information that is interesting or relevant to the story/session but doesn't fit into lore, character state, or significant events (e.g., flavor text, atmospheric descriptions, player commentary that adds context).
- "filtered_text": The cleaned transcript used for structured data extraction. - "filtered_text": The cleaned transcript. IMPORTANT: Keep all player questions, requests for rule clarifications, and mentions of spells, NPCs, or locations in this field, as they are used to trigger knowledge base lookups.
Keep the original speakers' names if they are present in the transcript. Keep the original speakers' names if they are present in the transcript.
Do not add any commentary or summaries. Just filter the text. Do not add any commentary or summaries. Just filter the text.
@@ -39,7 +39,8 @@ Return a JSON object with exactly these keys:
- "inventory_changes": (list of objects with "item", "quantity", "action") - "inventory_changes": (list of objects with "item", "quantity", "action")
3. "events": A list of strings. Each string should be a concise description of a significant plot development. 3. "events": A list of strings. Each string should be a concise description of a significant plot development.
4. "context": A list of objects. Each object MUST have: 4. "context": A list of objects. Each object MUST have:
- "contextual_info": (string) The contextual information for the event - "query": (string) The original query or topic this context relates to
- "snippet": (string) The contextual information or rule explanation
- "source": (string) The source of the context (e.g., "players handbook, page 68") - "source": (string) The source of the context (e.g., "players handbook, page 68")
Example Output: Example Output:
@@ -65,7 +66,8 @@ Example Output:
], ],
"context": [ "context": [
{ {
"contextual_info": "fireball does 1d6 damage, with an area of effect of 10 feet circle", "query": "fireball",
"snippet": "fireball does 1d6 damage, with an area of effect of 10 feet circle",
"source": "players handbook, page 68" "source": "players handbook, page 68"
} }
] ]
+67
View File
@@ -0,0 +1,67 @@
import asyncio
import logging
from typing import Tuple
from src.llm.models import ContextUpdate
from src.rag.manager import RAGManager
logger = logging.getLogger(__name__)
class ContextPipeline:
def __init__(self, rag_manager: RAGManager):
self.rag_manager = rag_manager
async def process_message(
self, speaker: str, text: str, context_queue: asyncio.Queue
):
"""
Processes a single message and pushes summarized insights to the context queue.
"""
try:
# Use RAGManager.retrieve with summarize=True to get concise insights
# Run in a thread to avoid blocking the event loop
insights = await asyncio.to_thread(
self.rag_manager.retrieve, text, summarize=True
)
if insights:
logger.info(
f"ContextPipeline: Found {len(insights)} insights for text: {text}"
)
for insight in insights:
await context_queue.put(insight)
else:
logger.debug(f"ContextPipeline: No insights found for text: {text}")
except Exception as e:
logger.error(f"ContextPipeline error processing message: {e}")
async def run(
self,
transcript_queue: asyncio.Queue,
context_queue: asyncio.Queue,
stop_event: asyncio.Event,
):
"""
Main loop that listens to the transcript queue and triggers RAG lookups.
"""
logger.info("Context Pipeline started.")
while not stop_event.is_set():
try:
# Get raw text from transcript queue (speaker, text)
speaker, text = await transcript_queue.get()
# For now, implement the basic flow: every message triggers a lookup.
# If performance becomes an issue, a filter can be added here.
await self.process_message(speaker, text, context_queue)
# Mark the task as done
transcript_queue.task_done()
except Exception as e:
logger.error(f"Context Pipeline loop error: {e}")
# Small sleep to avoid tight loop
await asyncio.sleep(0.1)
logger.info("Context Pipeline stopped.")
+20 -4
View File
@@ -147,7 +147,9 @@ class PipelineOrchestrator:
if filter_result.filtered_text: if filter_result.filtered_text:
try: try:
snippets = await asyncio.to_thread( snippets = await asyncio.to_thread(
self.rag_manager.retrieve, filter_result.filtered_text self.rag_manager.retrieve,
filter_result.filtered_text,
summarize=True,
) )
rag_snippets = snippets rag_snippets = snippets
except Exception as e: except Exception as e:
@@ -197,8 +199,8 @@ class PipelineOrchestrator:
) )
) )
# f. Also push the RAG snippets used for extraction to the context queue # f. Push the distilled RAG snippets from extraction to the context queue
for snippet in rag_snippets: for snippet in extraction_result.context_updates:
await self.context_queue.put(snippet) await self.context_queue.put(snippet)
except Exception as e: except Exception as e:
@@ -235,7 +237,9 @@ class PipelineOrchestrator:
logger.info(f"RAG: Triggering query for: {query}") logger.info(f"RAG: Triggering query for: {query}")
try: try:
# Run retrieval in a thread to avoid blocking the event loop # Run retrieval in a thread to avoid blocking the event loop
updates = await asyncio.to_thread(self.rag_manager.retrieve, query) updates = await asyncio.to_thread(
self.rag_manager.retrieve, query, summarize=True
)
for update in updates: for update in updates:
await self.context_queue.put(update) await self.context_queue.put(update)
logger.info( logger.info(
@@ -295,10 +299,21 @@ class PipelineOrchestrator:
self.is_running = True self.is_running = True
self.listener.start() self.listener.start()
# Initialize Context Pipeline
from src.pipeline.context_pipeline import ContextPipeline
self.context_pipeline = ContextPipeline(self.rag_manager)
stop_event = asyncio.Event()
# Start workers as background tasks # Start workers as background tasks
tasks = [ tasks = [
asyncio.create_task(self.stt_worker()), asyncio.create_task(self.stt_worker()),
asyncio.create_task(self.llm_worker()), asyncio.create_task(self.llm_worker()),
asyncio.create_task(
self.context_pipeline.run(
self.transcript_queue, self.context_queue, stop_event
)
),
asyncio.create_task(self.tui_worker()), asyncio.create_task(self.tui_worker()),
] ]
@@ -310,6 +325,7 @@ class PipelineOrchestrator:
pass pass
finally: finally:
self.is_running = False self.is_running = False
stop_event.set()
self.listener.stop() self.listener.stop()
for task in tasks: for task in tasks:
task.cancel() task.cancel()
+89 -2
View File
@@ -1,5 +1,5 @@
import os import os
from typing import List, Optional from typing import Any, List, Optional
import chromadb import chromadb
import pdfplumber import pdfplumber
@@ -8,6 +8,7 @@ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.chroma import ChromaVectorStore
from src.llm.models import ContextUpdate from src.llm.models import ContextUpdate
from src.llm.processor import LLMProcessor
class RAGManager: class RAGManager:
@@ -64,7 +65,90 @@ class RAGManager:
) )
print(f"Successfully ingested {pdf_path} into the vector store.") print(f"Successfully ingested {pdf_path} into the vector store.")
def retrieve(self, query: str, top_k: int = 3) -> List[ContextUpdate]: def ingest_file(self, file_path: str):
"""
Loads a single markdown file into the index.
"""
with open(file_path, "r", encoding="utf-8") as f:
text = f.read()
# Use the filename as the source
source = os.path.basename(file_path)
doc = Document(text=text, metadata={"source": source})
# If index doesn't exist, initialize it
if not self.index:
self.index = VectorStoreIndex.from_documents(
[doc], storage_context=self.storage_context
)
else:
# Insert into existing index
self.index.insert(doc)
print(f"Successfully ingested {file_path} into the vector store.")
def summarize_results(self, query: str, nodes: List[Any]) -> List[ContextUpdate]:
"""
Uses an LLM to transform raw snippets into concise "insights", filtering out irrelevant content.
"""
if not nodes:
return []
processor = LLMProcessor()
# Construct the context from retrieved nodes
context_text = "\n\n".join(
[
f"Source: {node.metadata.get('source', 'Unknown')}\nContent: {node.text}"
for node in nodes
]
)
system_prompt = (
"You are a precise research assistant. Your task is to analyze provided text snippets "
"and extract only the information that is directly relevant to the user's query. "
"1. If a snippet is irrelevant to the query, discard it completely. "
"2. For relevant information, synthesize it into a concise, single-sentence 'insight'. "
"3. Do not simply repeat the raw text; summarize it for clarity and brevity. "
"4. If no snippets are relevant to the query, return an empty list. "
"5. Be factual and do not hallucinate. Use only the provided snippets."
)
user_prompt = (
f"Query: {query}\n\n"
f"Snippets:\n{context_text}\n\n"
"Return a JSON object with a key 'insights' containing a list of objects, each with 'snippet' and 'source'."
)
result = processor._call_llm(
system_prompt,
user_prompt,
response_format={"type": "json_object"},
)
import json
try:
data = json.loads(result)
# Expecting a format like {"insights": [{"snippet": "...", "source": "..."}, ...]}
insights = data.get("insights", []) if isinstance(data, dict) else data
if not insights:
print(f"Summarization: No relevant insights found for query: {query}")
return [
ContextUpdate(
query=query, snippet=item["snippet"], source=item["source"]
)
for item in insights
]
except (json.JSONDecodeError, KeyError, TypeError) as e:
print(f"Summarization parsing error: {e}")
return []
def retrieve(
self, query: str, top_k: int = 5, summarize: bool = False
) -> List[ContextUpdate]:
""" """
Retrieves the top-K most relevant snippets for a given query. Retrieves the top-K most relevant snippets for a given query.
""" """
@@ -76,6 +160,9 @@ class RAGManager:
retriever = self.index.as_retriever(similarity_top_k=top_k) retriever = self.index.as_retriever(similarity_top_k=top_k)
nodes = retriever.retrieve(query) nodes = retriever.retrieve(query)
if summarize:
return self.summarize_results(query, nodes)
results = [] results = []
for node in nodes: for node in nodes:
# Extract metadata # Extract metadata
+12 -3
View File
@@ -3,7 +3,16 @@ from typing import List, Optional, Union
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
from textual.containers import Container, Horizontal, Vertical from textual.containers import Container, Horizontal, Vertical
from textual.widgets import Button, DataTable, Footer, Input, Label, ListView, Static from textual.widgets import (
Button,
DataTable,
Footer,
Input,
Label,
ListItem,
ListView,
Static,
)
from src.llm.models import CharacterStateUpdate, ExtractionResult, LoreUpdate from src.llm.models import CharacterStateUpdate, ExtractionResult, LoreUpdate
from src.persistence.characters import update_character_state from src.persistence.characters import update_character_state
@@ -181,8 +190,8 @@ class ConfirmationApp(App):
# Format the update for display # Format the update for display
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}" display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
# Add a new Static widget to the ListView # Add a new ListItem widget to the top of the ListView for 'most recent'
context_pane.mount(Static(display_text)) context_pane.mount(ListItem(Label(display_text), index=0))
if hasattr(self.context_queue, "task_done"): if hasattr(self.context_queue, "task_done"):
self.context_queue.task_done() self.context_queue.task_done()