Implement RAG summarization and context pipeline
- Add ContextPipeline for async RAG lookups - Implement RAG result summarization via LLMProcessor - Add CLI flag for PDF ingestion - Strip markdown code blocks from LLM responses - Update TUI context display to use ListItems
This commit is contained in:
@@ -1,4 +1,24 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
from src.rag.manager import RAGManager
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="D&D Helpers CLI")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ingest-pdf",
|
||||||
|
type=str,
|
||||||
|
help="Path to a PDF file to ingest into the RAG system",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.ingest_pdf:
|
||||||
|
print(f"Ingesting PDF: {args.ingest_pdf}...")
|
||||||
|
rag_manager = RAGManager()
|
||||||
|
rag_manager.ingest_pdf(args.ingest_pdf)
|
||||||
|
print("PDF ingestion complete.")
|
||||||
|
|
||||||
print("Hello from dnd-helpers!")
|
print("Hello from dnd-helpers!")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+11
-1
@@ -87,7 +87,17 @@ class LLMProcessor:
|
|||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
extra_body={"enable_thinking": False},
|
extra_body={"enable_thinking": False},
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
|
|
||||||
|
# Strip markdown code blocks if present
|
||||||
|
if content.startswith("```"):
|
||||||
|
import re
|
||||||
|
|
||||||
|
content = re.sub(
|
||||||
|
r"^```(?:json)?\n?|```$", "", content, flags=re.MULTILINE
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
return content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM Error: {e}")
|
logger.error(f"LLM Error: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
+6
-4
@@ -4,8 +4,8 @@ NOISE_FILTER_SYSTEM_PROMPT = """
|
|||||||
You are a D&D Game Master's assistant. Given a transcript, remove all out-of-character (OOC) chatter, logistical discussions (e.g., 'Where is my d20?'), and non-relevant noise.
|
You are a D&D Game Master's assistant. Given a transcript, remove all out-of-character (OOC) chatter, logistical discussions (e.g., 'Where is my d20?'), and non-relevant noise.
|
||||||
|
|
||||||
You must output your response as a JSON object with the following keys:
|
You must output your response as a JSON object with the following keys:
|
||||||
- "contextual_info": Information that is interesting or relevant to the story/session but doesn't fit into lore, character state, or significant events (e.g., flavor text, atmospheric descriptions, player questions, or player commentary that adds context).
|
- "contextual_info": Information that is interesting or relevant to the story/session but doesn't fit into lore, character state, or significant events (e.g., flavor text, atmospheric descriptions, player commentary that adds context).
|
||||||
- "filtered_text": The cleaned transcript used for structured data extraction.
|
- "filtered_text": The cleaned transcript. IMPORTANT: Keep all player questions, requests for rule clarifications, and mentions of spells, NPCs, or locations in this field, as they are used to trigger knowledge base lookups.
|
||||||
|
|
||||||
Keep the original speakers' names if they are present in the transcript.
|
Keep the original speakers' names if they are present in the transcript.
|
||||||
Do not add any commentary or summaries. Just filter the text.
|
Do not add any commentary or summaries. Just filter the text.
|
||||||
@@ -39,7 +39,8 @@ Return a JSON object with exactly these keys:
|
|||||||
- "inventory_changes": (list of objects with "item", "quantity", "action")
|
- "inventory_changes": (list of objects with "item", "quantity", "action")
|
||||||
3. "events": A list of strings. Each string should be a concise description of a significant plot development.
|
3. "events": A list of strings. Each string should be a concise description of a significant plot development.
|
||||||
4. "context": A list of objects. Each object MUST have:
|
4. "context": A list of objects. Each object MUST have:
|
||||||
- "contextual_info": (string) The contextual information for the event
|
- "query": (string) The original query or topic this context relates to
|
||||||
|
- "snippet": (string) The contextual information or rule explanation
|
||||||
- "source": (string) The source of the context (e.g., "players handbook, page 68")
|
- "source": (string) The source of the context (e.g., "players handbook, page 68")
|
||||||
|
|
||||||
Example Output:
|
Example Output:
|
||||||
@@ -65,7 +66,8 @@ Example Output:
|
|||||||
],
|
],
|
||||||
"context": [
|
"context": [
|
||||||
{
|
{
|
||||||
"contextual_info": "fireball does 1d6 damage, with an area of effect of 10 feet circle",
|
"query": "fireball",
|
||||||
|
"snippet": "fireball does 1d6 damage, with an area of effect of 10 feet circle",
|
||||||
"source": "players handbook, page 68"
|
"source": "players handbook, page 68"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,67 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from src.llm.models import ContextUpdate
|
||||||
|
from src.rag.manager import RAGManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextPipeline:
|
||||||
|
def __init__(self, rag_manager: RAGManager):
|
||||||
|
self.rag_manager = rag_manager
|
||||||
|
|
||||||
|
async def process_message(
|
||||||
|
self, speaker: str, text: str, context_queue: asyncio.Queue
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Processes a single message and pushes summarized insights to the context queue.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Use RAGManager.retrieve with summarize=True to get concise insights
|
||||||
|
# Run in a thread to avoid blocking the event loop
|
||||||
|
insights = await asyncio.to_thread(
|
||||||
|
self.rag_manager.retrieve, text, summarize=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if insights:
|
||||||
|
logger.info(
|
||||||
|
f"ContextPipeline: Found {len(insights)} insights for text: {text}"
|
||||||
|
)
|
||||||
|
for insight in insights:
|
||||||
|
await context_queue.put(insight)
|
||||||
|
else:
|
||||||
|
logger.debug(f"ContextPipeline: No insights found for text: {text}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"ContextPipeline error processing message: {e}")
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
transcript_queue: asyncio.Queue,
|
||||||
|
context_queue: asyncio.Queue,
|
||||||
|
stop_event: asyncio.Event,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Main loop that listens to the transcript queue and triggers RAG lookups.
|
||||||
|
"""
|
||||||
|
logger.info("Context Pipeline started.")
|
||||||
|
while not stop_event.is_set():
|
||||||
|
try:
|
||||||
|
# Get raw text from transcript queue (speaker, text)
|
||||||
|
speaker, text = await transcript_queue.get()
|
||||||
|
|
||||||
|
# For now, implement the basic flow: every message triggers a lookup.
|
||||||
|
# If performance becomes an issue, a filter can be added here.
|
||||||
|
await self.process_message(speaker, text, context_queue)
|
||||||
|
|
||||||
|
# Mark the task as done
|
||||||
|
transcript_queue.task_done()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Context Pipeline loop error: {e}")
|
||||||
|
|
||||||
|
# Small sleep to avoid tight loop
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
logger.info("Context Pipeline stopped.")
|
||||||
@@ -147,7 +147,9 @@ class PipelineOrchestrator:
|
|||||||
if filter_result.filtered_text:
|
if filter_result.filtered_text:
|
||||||
try:
|
try:
|
||||||
snippets = await asyncio.to_thread(
|
snippets = await asyncio.to_thread(
|
||||||
self.rag_manager.retrieve, filter_result.filtered_text
|
self.rag_manager.retrieve,
|
||||||
|
filter_result.filtered_text,
|
||||||
|
summarize=True,
|
||||||
)
|
)
|
||||||
rag_snippets = snippets
|
rag_snippets = snippets
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -197,8 +199,8 @@ class PipelineOrchestrator:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# f. Also push the RAG snippets used for extraction to the context queue
|
# f. Push the distilled RAG snippets from extraction to the context queue
|
||||||
for snippet in rag_snippets:
|
for snippet in extraction_result.context_updates:
|
||||||
await self.context_queue.put(snippet)
|
await self.context_queue.put(snippet)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -235,7 +237,9 @@ class PipelineOrchestrator:
|
|||||||
logger.info(f"RAG: Triggering query for: {query}")
|
logger.info(f"RAG: Triggering query for: {query}")
|
||||||
try:
|
try:
|
||||||
# Run retrieval in a thread to avoid blocking the event loop
|
# Run retrieval in a thread to avoid blocking the event loop
|
||||||
updates = await asyncio.to_thread(self.rag_manager.retrieve, query)
|
updates = await asyncio.to_thread(
|
||||||
|
self.rag_manager.retrieve, query, summarize=True
|
||||||
|
)
|
||||||
for update in updates:
|
for update in updates:
|
||||||
await self.context_queue.put(update)
|
await self.context_queue.put(update)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -295,10 +299,21 @@ class PipelineOrchestrator:
|
|||||||
self.is_running = True
|
self.is_running = True
|
||||||
self.listener.start()
|
self.listener.start()
|
||||||
|
|
||||||
|
# Initialize Context Pipeline
|
||||||
|
from src.pipeline.context_pipeline import ContextPipeline
|
||||||
|
|
||||||
|
self.context_pipeline = ContextPipeline(self.rag_manager)
|
||||||
|
stop_event = asyncio.Event()
|
||||||
|
|
||||||
# Start workers as background tasks
|
# Start workers as background tasks
|
||||||
tasks = [
|
tasks = [
|
||||||
asyncio.create_task(self.stt_worker()),
|
asyncio.create_task(self.stt_worker()),
|
||||||
asyncio.create_task(self.llm_worker()),
|
asyncio.create_task(self.llm_worker()),
|
||||||
|
asyncio.create_task(
|
||||||
|
self.context_pipeline.run(
|
||||||
|
self.transcript_queue, self.context_queue, stop_event
|
||||||
|
)
|
||||||
|
),
|
||||||
asyncio.create_task(self.tui_worker()),
|
asyncio.create_task(self.tui_worker()),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -310,6 +325,7 @@ class PipelineOrchestrator:
|
|||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
stop_event.set()
|
||||||
self.listener.stop()
|
self.listener.stop()
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|||||||
+89
-2
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
import pdfplumber
|
import pdfplumber
|
||||||
@@ -8,6 +8,7 @@ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|||||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||||
|
|
||||||
from src.llm.models import ContextUpdate
|
from src.llm.models import ContextUpdate
|
||||||
|
from src.llm.processor import LLMProcessor
|
||||||
|
|
||||||
|
|
||||||
class RAGManager:
|
class RAGManager:
|
||||||
@@ -64,7 +65,90 @@ class RAGManager:
|
|||||||
)
|
)
|
||||||
print(f"Successfully ingested {pdf_path} into the vector store.")
|
print(f"Successfully ingested {pdf_path} into the vector store.")
|
||||||
|
|
||||||
def retrieve(self, query: str, top_k: int = 3) -> List[ContextUpdate]:
|
def ingest_file(self, file_path: str):
|
||||||
|
"""
|
||||||
|
Loads a single markdown file into the index.
|
||||||
|
"""
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
text = f.read()
|
||||||
|
|
||||||
|
# Use the filename as the source
|
||||||
|
source = os.path.basename(file_path)
|
||||||
|
doc = Document(text=text, metadata={"source": source})
|
||||||
|
|
||||||
|
# If index doesn't exist, initialize it
|
||||||
|
if not self.index:
|
||||||
|
self.index = VectorStoreIndex.from_documents(
|
||||||
|
[doc], storage_context=self.storage_context
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Insert into existing index
|
||||||
|
self.index.insert(doc)
|
||||||
|
|
||||||
|
print(f"Successfully ingested {file_path} into the vector store.")
|
||||||
|
|
||||||
|
def summarize_results(self, query: str, nodes: List[Any]) -> List[ContextUpdate]:
|
||||||
|
"""
|
||||||
|
Uses an LLM to transform raw snippets into concise "insights", filtering out irrelevant content.
|
||||||
|
"""
|
||||||
|
if not nodes:
|
||||||
|
return []
|
||||||
|
|
||||||
|
processor = LLMProcessor()
|
||||||
|
|
||||||
|
# Construct the context from retrieved nodes
|
||||||
|
context_text = "\n\n".join(
|
||||||
|
[
|
||||||
|
f"Source: {node.metadata.get('source', 'Unknown')}\nContent: {node.text}"
|
||||||
|
for node in nodes
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"You are a precise research assistant. Your task is to analyze provided text snippets "
|
||||||
|
"and extract only the information that is directly relevant to the user's query. "
|
||||||
|
"1. If a snippet is irrelevant to the query, discard it completely. "
|
||||||
|
"2. For relevant information, synthesize it into a concise, single-sentence 'insight'. "
|
||||||
|
"3. Do not simply repeat the raw text; summarize it for clarity and brevity. "
|
||||||
|
"4. If no snippets are relevant to the query, return an empty list. "
|
||||||
|
"5. Be factual and do not hallucinate. Use only the provided snippets."
|
||||||
|
)
|
||||||
|
|
||||||
|
user_prompt = (
|
||||||
|
f"Query: {query}\n\n"
|
||||||
|
f"Snippets:\n{context_text}\n\n"
|
||||||
|
"Return a JSON object with a key 'insights' containing a list of objects, each with 'snippet' and 'source'."
|
||||||
|
)
|
||||||
|
|
||||||
|
result = processor._call_llm(
|
||||||
|
system_prompt,
|
||||||
|
user_prompt,
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
)
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(result)
|
||||||
|
# Expecting a format like {"insights": [{"snippet": "...", "source": "..."}, ...]}
|
||||||
|
insights = data.get("insights", []) if isinstance(data, dict) else data
|
||||||
|
|
||||||
|
if not insights:
|
||||||
|
print(f"Summarization: No relevant insights found for query: {query}")
|
||||||
|
|
||||||
|
return [
|
||||||
|
ContextUpdate(
|
||||||
|
query=query, snippet=item["snippet"], source=item["source"]
|
||||||
|
)
|
||||||
|
for item in insights
|
||||||
|
]
|
||||||
|
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||||
|
print(f"Summarization parsing error: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def retrieve(
|
||||||
|
self, query: str, top_k: int = 5, summarize: bool = False
|
||||||
|
) -> List[ContextUpdate]:
|
||||||
"""
|
"""
|
||||||
Retrieves the top-K most relevant snippets for a given query.
|
Retrieves the top-K most relevant snippets for a given query.
|
||||||
"""
|
"""
|
||||||
@@ -76,6 +160,9 @@ class RAGManager:
|
|||||||
retriever = self.index.as_retriever(similarity_top_k=top_k)
|
retriever = self.index.as_retriever(similarity_top_k=top_k)
|
||||||
nodes = retriever.retrieve(query)
|
nodes = retriever.retrieve(query)
|
||||||
|
|
||||||
|
if summarize:
|
||||||
|
return self.summarize_results(query, nodes)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
# Extract metadata
|
# Extract metadata
|
||||||
|
|||||||
+12
-3
@@ -3,7 +3,16 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
from textual.app import App, ComposeResult
|
from textual.app import App, ComposeResult
|
||||||
from textual.containers import Container, Horizontal, Vertical
|
from textual.containers import Container, Horizontal, Vertical
|
||||||
from textual.widgets import Button, DataTable, Footer, Input, Label, ListView, Static
|
from textual.widgets import (
|
||||||
|
Button,
|
||||||
|
DataTable,
|
||||||
|
Footer,
|
||||||
|
Input,
|
||||||
|
Label,
|
||||||
|
ListItem,
|
||||||
|
ListView,
|
||||||
|
Static,
|
||||||
|
)
|
||||||
|
|
||||||
from src.llm.models import CharacterStateUpdate, ExtractionResult, LoreUpdate
|
from src.llm.models import CharacterStateUpdate, ExtractionResult, LoreUpdate
|
||||||
from src.persistence.characters import update_character_state
|
from src.persistence.characters import update_character_state
|
||||||
@@ -181,8 +190,8 @@ class ConfirmationApp(App):
|
|||||||
# Format the update for display
|
# Format the update for display
|
||||||
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
|
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
|
||||||
|
|
||||||
# Add a new Static widget to the ListView
|
# Add a new ListItem widget to the top of the ListView for 'most recent'
|
||||||
context_pane.mount(Static(display_text))
|
context_pane.mount(ListItem(Label(display_text), index=0))
|
||||||
|
|
||||||
if hasattr(self.context_queue, "task_done"):
|
if hasattr(self.context_queue, "task_done"):
|
||||||
self.context_queue.task_done()
|
self.context_queue.task_done()
|
||||||
|
|||||||
Reference in New Issue
Block a user