Implement RAG summarization and context pipeline
- Add ContextPipeline for async RAG lookups - Implement RAG result summarization via LLMProcessor - Add CLI flag for PDF ingestion - Strip markdown code blocks from LLM responses - Update TUI context display to use ListItems
This commit is contained in:
@@ -1,4 +1,24 @@
|
||||
import argparse
|
||||
|
||||
from src.rag.manager import RAGManager
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="D&D Helpers CLI")
|
||||
parser.add_argument(
|
||||
"--ingest-pdf",
|
||||
type=str,
|
||||
help="Path to a PDF file to ingest into the RAG system",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.ingest_pdf:
|
||||
print(f"Ingesting PDF: {args.ingest_pdf}...")
|
||||
rag_manager = RAGManager()
|
||||
rag_manager.ingest_pdf(args.ingest_pdf)
|
||||
print("PDF ingestion complete.")
|
||||
|
||||
print("Hello from dnd-helpers!")
|
||||
|
||||
|
||||
|
||||
+11
-1
@@ -87,7 +87,17 @@ class LLMProcessor:
|
||||
response_format=response_format,
|
||||
extra_body={"enable_thinking": False},
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# Strip markdown code blocks if present
|
||||
if content.startswith("```"):
|
||||
import re
|
||||
|
||||
content = re.sub(
|
||||
r"^```(?:json)?\n?|```$", "", content, flags=re.MULTILINE
|
||||
).strip()
|
||||
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"LLM Error: {e}")
|
||||
return ""
|
||||
|
||||
+6
-4
@@ -4,8 +4,8 @@ NOISE_FILTER_SYSTEM_PROMPT = """
|
||||
You are a D&D Game Master's assistant. Given a transcript, remove all out-of-character (OOC) chatter, logistical discussions (e.g., 'Where is my d20?'), and non-relevant noise.
|
||||
|
||||
You must output your response as a JSON object with the following keys:
|
||||
- "contextual_info": Information that is interesting or relevant to the story/session but doesn't fit into lore, character state, or significant events (e.g., flavor text, atmospheric descriptions, player questions, or player commentary that adds context).
|
||||
- "filtered_text": The cleaned transcript used for structured data extraction.
|
||||
- "contextual_info": Information that is interesting or relevant to the story/session but doesn't fit into lore, character state, or significant events (e.g., flavor text, atmospheric descriptions, player commentary that adds context).
|
||||
- "filtered_text": The cleaned transcript. IMPORTANT: Keep all player questions, requests for rule clarifications, and mentions of spells, NPCs, or locations in this field, as they are used to trigger knowledge base lookups.
|
||||
|
||||
Keep the original speakers' names if they are present in the transcript.
|
||||
Do not add any commentary or summaries. Just filter the text.
|
||||
@@ -39,7 +39,8 @@ Return a JSON object with exactly these keys:
|
||||
- "inventory_changes": (list of objects with "item", "quantity", "action")
|
||||
3. "events": A list of strings. Each string should be a concise description of a significant plot development.
|
||||
4. "context": A list of objects. Each object MUST have:
|
||||
- "contextual_info": (string) The contextual information for the event
|
||||
- "query": (string) The original query or topic this context relates to
|
||||
- "snippet": (string) The contextual information or rule explanation
|
||||
- "source": (string) The source of the context (e.g., "players handbook, page 68")
|
||||
|
||||
Example Output:
|
||||
@@ -65,7 +66,8 @@ Example Output:
|
||||
],
|
||||
"context": [
|
||||
{
|
||||
"contextual_info": "fireball does 1d6 damage, with an area of effect of 10 feet circle",
|
||||
"query": "fireball",
|
||||
"snippet": "fireball does 1d6 damage, with an area of effect of 10 feet circle",
|
||||
"source": "players handbook, page 68"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from src.llm.models import ContextUpdate
|
||||
from src.rag.manager import RAGManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextPipeline:
|
||||
def __init__(self, rag_manager: RAGManager):
|
||||
self.rag_manager = rag_manager
|
||||
|
||||
async def process_message(
|
||||
self, speaker: str, text: str, context_queue: asyncio.Queue
|
||||
):
|
||||
"""
|
||||
Processes a single message and pushes summarized insights to the context queue.
|
||||
"""
|
||||
try:
|
||||
# Use RAGManager.retrieve with summarize=True to get concise insights
|
||||
# Run in a thread to avoid blocking the event loop
|
||||
insights = await asyncio.to_thread(
|
||||
self.rag_manager.retrieve, text, summarize=True
|
||||
)
|
||||
|
||||
if insights:
|
||||
logger.info(
|
||||
f"ContextPipeline: Found {len(insights)} insights for text: {text}"
|
||||
)
|
||||
for insight in insights:
|
||||
await context_queue.put(insight)
|
||||
else:
|
||||
logger.debug(f"ContextPipeline: No insights found for text: {text}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ContextPipeline error processing message: {e}")
|
||||
|
||||
async def run(
|
||||
self,
|
||||
transcript_queue: asyncio.Queue,
|
||||
context_queue: asyncio.Queue,
|
||||
stop_event: asyncio.Event,
|
||||
):
|
||||
"""
|
||||
Main loop that listens to the transcript queue and triggers RAG lookups.
|
||||
"""
|
||||
logger.info("Context Pipeline started.")
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
# Get raw text from transcript queue (speaker, text)
|
||||
speaker, text = await transcript_queue.get()
|
||||
|
||||
# For now, implement the basic flow: every message triggers a lookup.
|
||||
# If performance becomes an issue, a filter can be added here.
|
||||
await self.process_message(speaker, text, context_queue)
|
||||
|
||||
# Mark the task as done
|
||||
transcript_queue.task_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Context Pipeline loop error: {e}")
|
||||
|
||||
# Small sleep to avoid tight loop
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
logger.info("Context Pipeline stopped.")
|
||||
@@ -147,7 +147,9 @@ class PipelineOrchestrator:
|
||||
if filter_result.filtered_text:
|
||||
try:
|
||||
snippets = await asyncio.to_thread(
|
||||
self.rag_manager.retrieve, filter_result.filtered_text
|
||||
self.rag_manager.retrieve,
|
||||
filter_result.filtered_text,
|
||||
summarize=True,
|
||||
)
|
||||
rag_snippets = snippets
|
||||
except Exception as e:
|
||||
@@ -197,8 +199,8 @@ class PipelineOrchestrator:
|
||||
)
|
||||
)
|
||||
|
||||
# f. Also push the RAG snippets used for extraction to the context queue
|
||||
for snippet in rag_snippets:
|
||||
# f. Push the distilled RAG snippets from extraction to the context queue
|
||||
for snippet in extraction_result.context_updates:
|
||||
await self.context_queue.put(snippet)
|
||||
|
||||
except Exception as e:
|
||||
@@ -235,7 +237,9 @@ class PipelineOrchestrator:
|
||||
logger.info(f"RAG: Triggering query for: {query}")
|
||||
try:
|
||||
# Run retrieval in a thread to avoid blocking the event loop
|
||||
updates = await asyncio.to_thread(self.rag_manager.retrieve, query)
|
||||
updates = await asyncio.to_thread(
|
||||
self.rag_manager.retrieve, query, summarize=True
|
||||
)
|
||||
for update in updates:
|
||||
await self.context_queue.put(update)
|
||||
logger.info(
|
||||
@@ -295,10 +299,21 @@ class PipelineOrchestrator:
|
||||
self.is_running = True
|
||||
self.listener.start()
|
||||
|
||||
# Initialize Context Pipeline
|
||||
from src.pipeline.context_pipeline import ContextPipeline
|
||||
|
||||
self.context_pipeline = ContextPipeline(self.rag_manager)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
# Start workers as background tasks
|
||||
tasks = [
|
||||
asyncio.create_task(self.stt_worker()),
|
||||
asyncio.create_task(self.llm_worker()),
|
||||
asyncio.create_task(
|
||||
self.context_pipeline.run(
|
||||
self.transcript_queue, self.context_queue, stop_event
|
||||
)
|
||||
),
|
||||
asyncio.create_task(self.tui_worker()),
|
||||
]
|
||||
|
||||
@@ -310,6 +325,7 @@ class PipelineOrchestrator:
|
||||
pass
|
||||
finally:
|
||||
self.is_running = False
|
||||
stop_event.set()
|
||||
self.listener.stop()
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
+89
-2
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import chromadb
|
||||
import pdfplumber
|
||||
@@ -8,6 +8,7 @@ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
from src.llm.models import ContextUpdate
|
||||
from src.llm.processor import LLMProcessor
|
||||
|
||||
|
||||
class RAGManager:
|
||||
@@ -64,7 +65,90 @@ class RAGManager:
|
||||
)
|
||||
print(f"Successfully ingested {pdf_path} into the vector store.")
|
||||
|
||||
def retrieve(self, query: str, top_k: int = 3) -> List[ContextUpdate]:
|
||||
def ingest_file(self, file_path: str):
|
||||
"""
|
||||
Loads a single markdown file into the index.
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
# Use the filename as the source
|
||||
source = os.path.basename(file_path)
|
||||
doc = Document(text=text, metadata={"source": source})
|
||||
|
||||
# If index doesn't exist, initialize it
|
||||
if not self.index:
|
||||
self.index = VectorStoreIndex.from_documents(
|
||||
[doc], storage_context=self.storage_context
|
||||
)
|
||||
else:
|
||||
# Insert into existing index
|
||||
self.index.insert(doc)
|
||||
|
||||
print(f"Successfully ingested {file_path} into the vector store.")
|
||||
|
||||
def summarize_results(self, query: str, nodes: List[Any]) -> List[ContextUpdate]:
|
||||
"""
|
||||
Uses an LLM to transform raw snippets into concise "insights", filtering out irrelevant content.
|
||||
"""
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
processor = LLMProcessor()
|
||||
|
||||
# Construct the context from retrieved nodes
|
||||
context_text = "\n\n".join(
|
||||
[
|
||||
f"Source: {node.metadata.get('source', 'Unknown')}\nContent: {node.text}"
|
||||
for node in nodes
|
||||
]
|
||||
)
|
||||
|
||||
system_prompt = (
|
||||
"You are a precise research assistant. Your task is to analyze provided text snippets "
|
||||
"and extract only the information that is directly relevant to the user's query. "
|
||||
"1. If a snippet is irrelevant to the query, discard it completely. "
|
||||
"2. For relevant information, synthesize it into a concise, single-sentence 'insight'. "
|
||||
"3. Do not simply repeat the raw text; summarize it for clarity and brevity. "
|
||||
"4. If no snippets are relevant to the query, return an empty list. "
|
||||
"5. Be factual and do not hallucinate. Use only the provided snippets."
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
f"Query: {query}\n\n"
|
||||
f"Snippets:\n{context_text}\n\n"
|
||||
"Return a JSON object with a key 'insights' containing a list of objects, each with 'snippet' and 'source'."
|
||||
)
|
||||
|
||||
result = processor._call_llm(
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
import json
|
||||
|
||||
try:
|
||||
data = json.loads(result)
|
||||
# Expecting a format like {"insights": [{"snippet": "...", "source": "..."}, ...]}
|
||||
insights = data.get("insights", []) if isinstance(data, dict) else data
|
||||
|
||||
if not insights:
|
||||
print(f"Summarization: No relevant insights found for query: {query}")
|
||||
|
||||
return [
|
||||
ContextUpdate(
|
||||
query=query, snippet=item["snippet"], source=item["source"]
|
||||
)
|
||||
for item in insights
|
||||
]
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||
print(f"Summarization parsing error: {e}")
|
||||
return []
|
||||
|
||||
def retrieve(
|
||||
self, query: str, top_k: int = 5, summarize: bool = False
|
||||
) -> List[ContextUpdate]:
|
||||
"""
|
||||
Retrieves the top-K most relevant snippets for a given query.
|
||||
"""
|
||||
@@ -76,6 +160,9 @@ class RAGManager:
|
||||
retriever = self.index.as_retriever(similarity_top_k=top_k)
|
||||
nodes = retriever.retrieve(query)
|
||||
|
||||
if summarize:
|
||||
return self.summarize_results(query, nodes)
|
||||
|
||||
results = []
|
||||
for node in nodes:
|
||||
# Extract metadata
|
||||
|
||||
+12
-3
@@ -3,7 +3,16 @@ from typing import List, Optional, Union
|
||||
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.containers import Container, Horizontal, Vertical
|
||||
from textual.widgets import Button, DataTable, Footer, Input, Label, ListView, Static
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
DataTable,
|
||||
Footer,
|
||||
Input,
|
||||
Label,
|
||||
ListItem,
|
||||
ListView,
|
||||
Static,
|
||||
)
|
||||
|
||||
from src.llm.models import CharacterStateUpdate, ExtractionResult, LoreUpdate
|
||||
from src.persistence.characters import update_character_state
|
||||
@@ -181,8 +190,8 @@ class ConfirmationApp(App):
|
||||
# Format the update for display
|
||||
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
|
||||
|
||||
# Add a new Static widget to the ListView
|
||||
context_pane.mount(Static(display_text))
|
||||
# Add a new ListItem widget to the top of the ListView for 'most recent'
|
||||
context_pane.mount(ListItem(Label(display_text), index=0))
|
||||
|
||||
if hasattr(self.context_queue, "task_done"):
|
||||
self.context_queue.task_done()
|
||||
|
||||
Reference in New Issue
Block a user