Compare commits

..

10 Commits

Author SHA1 Message Date
charles afa8d17f10 Mostly working 2026-05-28 00:08:52 -07:00
charles 1cfba3a0ae Add LLM input logging and UI log pane
- Add log_queue to PipelineOrchestrator and log LLM inputs to UI
- Use entity_name for lore update logs instead of topic
- Pass log_queue into ConfirmationApp to display logs in UI
- Introduce a log pane and left/right pane layout in the UI
- Poll and render log messages via a new poll_log_updates worker
- Run log polling with Textual workers to avoid GC issues
- Fix ListView insertion by wrapping ListItem in a list
- Relax RAG similarity threshold from 0.7 to 0.5
2026-05-27 23:09:11 -07:00
charles 1098bdb2f9 Stable state 2026-05-27 22:30:20 -07:00
charles 58f736a5f8 refactor(ui): rewrite ConfirmationApp with three-pane layout
- Implement Pending Facts, LLM Input, and Context Pane using Textual
- Add keyboard shortcuts for Accept, Reject, and Edit actions
2026-05-27 20:05:29 -07:00
charles b25f82cefc Implement RAG summarization and context pipeline
- Add ContextPipeline for async RAG lookups
- Implement RAG result summarization via LLMProcessor
- Add CLI flag for PDF ingestion
- Strip markdown code blocks from LLM responses
- Update TUI context display to use ListItems
2026-05-27 00:17:47 -07:00
charles b83d9b5e6a Update UI and prompts 2026-05-26 23:25:53 -07:00
charles 679eca3fef fix: suppress whisperx.asr warnings 2026-05-26 22:17:50 -07:00
charles 954f2f50d8 feat: implement RAG capabilities and Context Pane integration
- Add RAG capabilities using LlamaIndex and ChromaDB
- Implement RAGManager for PHB indexing and retrieval
- Integrate RAG pipeline into orchestrator to trigger queries based on extracted entities
- Update TUI to include a 3-column layout with a real-time Context Pane
- Define ContextUpdate data models in src/llm/models.py
- Update requirements.txt with new dependencies
2026-05-26 22:07:12 -07:00
charles f4c98fb2b9 Migrate to WhisperX for speaker diarization
Implement a sliding window audio buffer and update the transcriber to
use WhisperX for transcription, alignment, and speaker identification.
Update the pipeline to handle and store speaker-attributed transcripts.

Additionally, update the LLM processor's reasoning parameter to
"enable_thinking".
2026-05-26 21:48:30 -07:00
charles d0fcdfab01 Improvements 2026-05-26 21:07:58 -07:00
30 changed files with 1099 additions and 327 deletions
+2 -1
View File
@@ -1,2 +1,3 @@
artifacts/
__pycache__
**/__pycache__/
data
+1
View File
@@ -0,0 +1 @@
3.12
+48
View File
@@ -0,0 +1,48 @@
import argparse
from src.rag.manager import RAGManager
def main():
    """Run the D&D Helpers command-line interface.

    Recognised options (each takes a path):
      --ingest-pdf   ingest a PDF into the RAG system
      --ingest-file  ingest a single markdown file into the RAG system
      --ingest-dir   ingest a directory of markdown files into the RAG system

    When none of the options is supplied, a greeting is printed and the
    function returns without constructing the RAG manager.
    """
    parser = argparse.ArgumentParser(description="D&D Helpers CLI")
    parser.add_argument(
        "--ingest-pdf",
        type=str,
        help="Path to a PDF file to ingest into the RAG system",
    )
    parser.add_argument(
        "--ingest-file",
        type=str,
        help="Path to a markdown file to ingest into the RAG system",
    )
    parser.add_argument(
        "--ingest-dir",
        type=str,
        help="Path to a directory of markdown files to ingest into the RAG system",
    )
    args = parser.parse_args()

    # Decide whether there is any work before building the RAG manager: its
    # constructor opens the persistent vector store and loads an embedding
    # model, which is wasted effort when we only print the greeting.
    if not any([args.ingest_pdf, args.ingest_file, args.ingest_dir]):
        print("Hello from dnd-helpers!")
        return

    rag_manager = RAGManager()

    if args.ingest_pdf:
        print(f"Ingesting PDF: {args.ingest_pdf}...")
        rag_manager.ingest_pdf(args.ingest_pdf)
        print("PDF ingestion complete.")

    if args.ingest_file:
        print(f"Ingesting File: {args.ingest_file}...")
        rag_manager.ingest_file(args.ingest_file)
        print("File ingestion complete.")

    if args.ingest_dir:
        print(f"Ingesting Directory: {args.ingest_dir}...")
        rag_manager.ingest_directory(args.ingest_dir)
        print("Directory ingestion complete.")


if __name__ == "__main__":
    main()
+7
View File
@@ -0,0 +1,7 @@
[project]
name = "dnd-helpers"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = []
+4 -1
View File
@@ -1,8 +1,11 @@
# Core dependencies for D&D Helpers
faster-whisper
whisperx
sounddevice
pydantic
textual
typer
openai
python-dotenv
llama-index
chromadb
pdfplumber
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+25
View File
@@ -44,6 +44,26 @@ class CharacterStateUpdate(BaseModel):
)
class ContextUpdate(BaseModel):
    # One retrieved piece of context: the query that led to it, the snippet
    # text itself, and a label for where the snippet came from.
    # All three fields are required (declared with `...`).
    query: str = Field(..., description="The search query used to retrieve the context")
    snippet: str = Field(
        ..., description="The relevant text snippet retrieved from the source"
    )
    source: str = Field(
        ..., description="The source of the snippet (e.g., 'PHB p. 12')"
    )
class FilterResult(BaseModel):
    # Result of the transcript noise-filtering stage. Both fields are
    # required (declared with `...`).
    # contextual_info: material kept for the user's interest only.
    # filtered_text: the cleaned transcript passed on to structured extraction.
    contextual_info: str = Field(
        ...,
        description="Information interesting to the user but not useful for structured extraction",
    )
    filtered_text: str = Field(
        ..., description="Cleaned transcript used for structured data extraction"
    )
class ExtractionResult(BaseModel):
lore_updates: List[LoreUpdate] = Field(
default_factory=list, description="List of discovered lore facts", alias="lore"
@@ -58,6 +78,11 @@ class ExtractionResult(BaseModel):
description="List of significant plot points or events",
alias="events",
)
context_updates: List[ContextUpdate] = Field(
default_factory=list,
description="List of context updates",
alias="context",
)
class Config:
populate_by_name = True
+92 -35
View File
@@ -1,11 +1,20 @@
import logging
import os
from posix import system
from this import s
from typing import Any, Dict, Optional
from openai import OpenAI
from pydantic import ValidationError
from .models import ExtractionResult
from .prompts import EXTRACTION_SYSTEM_PROMPT, NOISE_FILTER_SYSTEM_PROMPT
from .models import ExtractionResult, FilterResult
from .prompts import (
EXTRACTION_SYSTEM_PROMPT,
NOISE_FILTER_SYSTEM_PROMPT,
QUERY_ANSWER_SYSTEM_PROMPT,
)
logger = logging.getLogger(__name__)
class LLMProcessor:
@@ -47,7 +56,7 @@ class LLMProcessor:
# but we can ensure the client is initialized.
pass
except Exception as e:
print(f"Error initializing LLM client for backend {backend}: {e}")
logger.error(f"Error initializing LLM client for backend {backend}: {e}")
raise
self.model = model or os.environ.get("LLM_MODEL", "gpt-4o")
@@ -56,73 +65,121 @@ class LLMProcessor:
self,
system_prompt: str,
user_prompt: str,
context: Optional[str] = None,
response_format: Optional[Any] = None,
) -> str:
"""
Generic method to call the LLM.
"""
messages = [
{"role": "system", "content": system_prompt},
]
if context:
messages.append(
{
"role": "system",
"content": f"Context from previous conversation:\n{context}",
}
)
messages.append({"role": "user", "content": user_prompt})
try:
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
messages=messages,
response_format=response_format,
extra_body={"include_reasoning": False},
extra_body={"enable_thinking": False},
)
return response.choices[0].message.content
content = response.choices[0].message.content
# Strip markdown code blocks if present
if content.startswith("```"):
import re
content = re.sub(
r"^```(?:json)?\n?|```$", "", content, flags=re.MULTILINE
).strip()
return content
except Exception as e:
print(f"LLM Error: {e}")
logger.error(f"LLM Error: {e}")
return ""
def filter_transcript(self, text: str) -> str:
def generate_answer(self, query: str, context: str) -> str:
"""
Generates a natural language answer to a DM query.
"""
return self._call_llm(
QUERY_ANSWER_SYSTEM_PROMPT,
query,
context=context,
)
def filter_transcript(
self, text: str, context: Optional[str] = None
) -> FilterResult:
"""
Stage 1: Raw Transcript -> Filtered Text.
"""
result = self._call_llm(NOISE_FILTER_SYSTEM_PROMPT, text)
print(f"LLM Processor (Filter): {text} -> {result}")
return result
result = self._call_llm(
NOISE_FILTER_SYSTEM_PROMPT,
text,
context=context,
response_format={"type": "json_object"},
)
logger.info(f"LLM Processor (Filter): {text} -> {result}")
def extract_structured_data(self, filtered_text: str) -> ExtractionResult:
import json
try:
data = json.loads(result)
return FilterResult(**data)
except (json.JSONDecodeError, ValidationError) as e:
logger.error(f"Filter Parsing Error: {e}")
return FilterResult(contextual_info="", filtered_text=result)
def extract_structured_data(
self, filtered_text: str, context: Optional[str] = None
) -> ExtractionResult:
"""
Stage 2: Filtered Text -> Structured Data.
"""
print(f"LLM Processor (Extract): Calling extraction for: {filtered_text}")
logger.info(f"LLM Processor (Extract): Calling extraction for: {filtered_text}")
try:
# Using standard chat.completions.create with JSON mode for better compatibility with vLLM
print("LLM Processor (Extract): Sending request to backend...")
logger.info("LLM Processor (Extract): Sending request to backend...")
system_prompt = EXTRACTION_SYSTEM_PROMPT
if context:
system_prompt += f"\n{context}"
messages = [
{"role": "system", "content": system_prompt},
]
messages.append({"role": "user", "content": filtered_text})
for message in messages:
logger.info(f"LLM Processor (Extract): Message: {message}")
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": EXTRACTION_SYSTEM_PROMPT},
{"role": "user", "content": filtered_text},
],
messages=messages,
response_format={"type": "json_object"},
extra_body={"include_reasoning": False},
extra_body={"enable_thinking": False},
)
print("LLM Processor (Extract): Response received from backend.")
logger.info("LLM Processor (Extract): Response received from backend.")
import json
content = response.choices[0].message.content
print(f"LLM Processor (Extract): Raw JSON response: {content}")
logger.info(f"LLM Processor (Extract): Raw JSON response: {content}")
data = json.loads(content)
# Map the JSON data to the Pydantic model
return ExtractionResult(**data)
except Exception as e:
print(f"Extraction Error: {e}")
logger.error(f"Extraction Error: {e}")
# Return an empty ExtractionResult if parsing fails
return ExtractionResult()
def process_pipeline(self, raw_text: str) -> ExtractionResult:
"""
Executes the two-stage pipeline: Raw Transcript -> Filtered Text -> Structured Data.
"""
filtered_text = self.filter_transcript(raw_text)
if not filtered_text:
return ExtractionResult()
return self.extract_structured_data(filtered_text)
+27 -3
View File
@@ -1,8 +1,20 @@
# System prompts for the LLM pipeline
QUERY_ANSWER_SYSTEM_PROMPT = """
You are a helpful D&D Game Master's assistant. Your goal is to provide accurate, concise, and helpful answers to the DM's questions based on the provided context (conversation history and RAG snippets).
Guidelines:
- Use the provided context as your primary source of truth.
- If the answer is not in the context, state that you don't have enough information, but feel free to provide general D&D 5e rules as a fallback.
- Keep responses natural and professional.
- Be concise.
"""
NOISE_FILTER_SYSTEM_PROMPT = """
You are a D&D Game Master's assistant. Given a transcript, remove all out-of-character (OOC) chatter, logistical discussions (e.g., 'Where is my d20?'), and non-relevant noise.
Output only the in-character dialogue and game-relevant events.
You must output your response as a JSON object with the following keys:
- "contextual_info": Information that is interesting or relevant to the story/session but doesn't fit into lore, character state, or significant events (e.g., flavor text, atmospheric descriptions, player commentary that adds context).
- "filtered_text": The cleaned transcript. IMPORTANT: Keep all player questions, requests for rule clarifications, and mentions of spells, NPCs, or locations in this field, as they are used to trigger knowledge base lookups.
Keep the original speakers' names if they are present in the transcript.
Do not add any commentary or summaries. Just filter the text.
"""
@@ -26,6 +38,7 @@ Return a JSON object with exactly these keys:
- "category": (string) 'NPC', 'Location', 'WorldBuilding', or 'Plot'
- "entity_name": (string) The name of the NPC, Location, or entity
- "content": (string) The actual lore fact or description
- "context": (string, optional) Helpful information for the DM (e.g., descriptions of characters, spell details, game mechanics) discovered via the knowledge base or the transcript.
2. "character_state": A list of objects. Each object MUST have:
- "character_name": (string) Name of the character
- "hp_change": (integer, optional) Change in HP
@@ -33,6 +46,10 @@ Return a JSON object with exactly these keys:
- "status_effects_removed": (list of strings)
- "inventory_changes": (list of objects with "item", "quantity", "action")
3. "events": A list of strings. Each string should be a concise description of a significant plot development.
4. "context": A list of objects. Each object MUST have:
- "query": (string) The original query or topic this context relates to
- "snippet": (string) The contextual information or rule explanation
- "source": (string) The source of the context (e.g., "players handbook, page 68")
Example Output:
{
@@ -40,7 +57,7 @@ Example Output:
{
"category": "NPC",
"entity_name": "Thorne",
"content": "A gruff dwarf who runs the local tavern."
"content": "A gruff dwarf who runs the local tavern.",
}
],
"character_state": [
@@ -54,6 +71,13 @@ Example Output:
],
"events": [
"The party discovered the secret entrance to the crypt."
],
"context": [
{
"query": "fireball",
"snippet": "fireball does 1d6 damage, with an area of effect of 10 feet circle",
"source": "players handbook, page 68"
}
]
}
Binary file not shown.
Binary file not shown.
+67
View File
@@ -0,0 +1,67 @@
import asyncio
import logging
from typing import Tuple
from src.llm.models import ContextUpdate
from src.rag.manager import RAGManager
logger = logging.getLogger(__name__)


class ContextPipeline:
    """Turn transcript messages into RAG insights published on a queue.

    Each ``(speaker, text)`` item taken from the transcript queue triggers a
    summarised lookup through the RAG manager; every insight returned is
    placed on the context queue.
    """

    def __init__(self, rag_manager: "RAGManager"):
        """Store the RAG manager whose ``retrieve`` method performs lookups."""
        self.rag_manager = rag_manager

    async def process_message(
        self, speaker: str, text: str, context_queue: asyncio.Queue
    ):
        """Look up ``text`` and push each resulting insight to ``context_queue``.

        Args:
            speaker: Speaker label of the message (not used for the lookup).
            text: The message text used as the retrieval query.
            context_queue: Queue that receives one item per insight.

        Any exception raised by the lookup is logged and swallowed so a single
        bad message cannot stop the pipeline.
        """
        try:
            # retrieve() is synchronous; run it in a worker thread so the
            # event loop stays responsive.
            insights = await asyncio.to_thread(
                self.rag_manager.retrieve, text, summarize=True
            )

            if insights:
                logger.info(
                    f"ContextPipeline: Found {len(insights)} insights for text: {text}"
                )
                for insight in insights:
                    await context_queue.put(insight)
            else:
                logger.debug(f"ContextPipeline: No insights found for text: {text}")
        except Exception as e:
            logger.error(f"ContextPipeline error processing message: {e}")

    async def run(
        self,
        transcript_queue: asyncio.Queue,
        context_queue: asyncio.Queue,
        stop_event: asyncio.Event,
    ):
        """Consume ``transcript_queue`` until ``stop_event`` is set.

        Args:
            transcript_queue: Queue of ``(speaker, text)`` items.
            context_queue: Queue that receives the insights.
            stop_event: Setting this event ends the loop, even while the
                transcript queue is empty.

        Every dequeued item is acknowledged with ``task_done()`` whether or
        not it was processed successfully, so ``transcript_queue.join()``
        cannot hang on a failed item.
        """
        logger.info("Context Pipeline started.")

        # Waiting on the queue alone would never notice the stop event while
        # the queue is idle, so each iteration races the two.
        stop_waiter = asyncio.ensure_future(stop_event.wait())
        getter = None
        try:
            while not stop_event.is_set():
                getter = asyncio.ensure_future(transcript_queue.get())
                await asyncio.wait(
                    {getter, stop_waiter}, return_when=asyncio.FIRST_COMPLETED
                )
                if not getter.done():
                    # Stop was requested before a message arrived.
                    break

                try:
                    speaker, text = getter.result()
                    # Every message triggers a lookup; a pre-filter can be
                    # added here if that proves too expensive.
                    await self.process_message(speaker, text, context_queue)
                except Exception as e:
                    logger.error(f"Context Pipeline loop error: {e}")
                    # Small sleep to avoid a tight loop on repeated failures
                    await asyncio.sleep(0.1)
                finally:
                    transcript_queue.task_done()
        finally:
            stop_waiter.cancel()
            # Cancelling a pending get() leaves any later item in the queue.
            if getter is not None and not getter.done():
                getter.cancel()

        logger.info("Context Pipeline stopped.")
+234 -34
View File
@@ -1,13 +1,40 @@
import asyncio
import logging
import os
from pathlib import Path
from typing import List, Optional
from src.llm.models import ExtractionResult
import numpy as np
from src.llm.models import (
CharacterStateUpdate,
ContextUpdate,
ExtractionResult,
LoreUpdate,
)
from src.llm.processor import LLMProcessor
from src.llm.prompts import EXTRACTION_SYSTEM_PROMPT, NOISE_FILTER_SYSTEM_PROMPT
from src.persistence.characters import update_character_state
from src.persistence.lore import update_lore
from src.rag.manager import RAGManager
from src.stt.listener import AudioListener
from src.stt.transcriber import Transcriber
from src.ui.tui import ConfirmationApp
logging.basicConfig(level=logging.INFO)
# Configure logging to write to a file instead of stdout
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
handlers=[
logging.FileHandler("pipeline.log"),
],
)
# Suppress verbose logging from STT libraries to keep the TUI clean
logging.getLogger("whisper").setLevel(logging.WARNING)
logging.getLogger("faster_whisper").setLevel(logging.WARNING)
logging.getLogger("pyannote").setLevel(logging.WARNING)
logging.getLogger("whisperx").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)
@@ -17,15 +44,44 @@ class PipelineOrchestrator:
# Modules
self.listener = AudioListener(loop=self.loop)
self.transcriber = Transcriber()
self.transcriber = Transcriber(model_size="small")
self.processor = LLMProcessor()
self.rag_manager = RAGManager()
# Queues
self.transcript_queue = asyncio.Queue()
self.proposal_queue = asyncio.Queue()
self.stt_to_clean_queue = asyncio.Queue()
self.ui_to_llm_queue = asyncio.Queue()
self.clean_to_llm_queue = asyncio.Queue()
self.llm_to_ui_queue = asyncio.Queue()
self.log_queue = asyncio.Queue()
self.is_running = False
# Conversation history for context
self.history = [] # List of strings (transcripts)
self.history_max_words = 1000
# STT Sliding Window Buffer
self.audio_buffer = [] # List of audio chunks
self.buffer_max_seconds = 30
self.sample_rate = 16000
self.buffer_max_samples = self.buffer_max_seconds * self.sample_rate
self.last_processed_end_time = 0.0
def _get_combined_context(self) -> str:
"""
Returns the trimmed conversation history as a context string.
"""
full_history_text = " ".join(self.history)
words = full_history_text.split()
if len(words) > self.history_max_words:
kept_words = words[-self.history_max_words :]
context_text = " ".join(kept_words)
else:
context_text = full_history_text
return f"Conversation History:\n{context_text}\n\n"
async def stt_worker(self):
"""
Worker that handles STT: Audio -> Text.
@@ -36,12 +92,35 @@ class PipelineOrchestrator:
# Get audio chunk from listener
audio_chunk = await self.listener.get_chunk()
# Transcribe
text = self.transcriber.transcribe(audio_chunk)
# Maintain sliding window buffer
self.audio_buffer.append(audio_chunk)
current_buffer_samples = sum(len(c) for c in self.audio_buffer)
if text:
logger.info(f"Transcribed: {text}")
await self.transcript_queue.put(text)
if current_buffer_samples > self.buffer_max_samples:
# Remove oldest chunks until we are within the buffer limit
while (
sum(len(c) for c in self.audio_buffer) > self.buffer_max_samples
):
self.audio_buffer.pop(0)
# Concatenate buffer for transcription
full_audio = np.concatenate(self.audio_buffer)
# Transcribe (WhisperX now returns a list of (speaker, text, start, end))
results = self.transcriber.transcribe(full_audio)
# Filter for only new segments that start after the last processed segment
new_segments = [
res for res in results if res[2] >= self.last_processed_end_time
]
if new_segments:
for speaker, text, start, end in new_segments:
logger.info(f"Transcribed: [{speaker}] {text}")
await self.stt_to_clean_queue.put((speaker, text))
self.last_processed_end_time = max(
self.last_processed_end_time, end
)
except Exception as e:
logger.error(f"STT Worker error: {e}")
@@ -49,33 +128,110 @@ class PipelineOrchestrator:
# Small sleep to prevent tight loop if get_chunk is fast
await asyncio.sleep(0.1)
async def llm_worker(self):
async def clean_worker(self):
"""
Worker that handles LLM: Text -> Proposal.
Worker that handles Text Cleaning: Raw STT -> Filtered Text.
"""
logger.info("LLM Worker started.")
logger.info("Clean Worker started.")
while self.is_running:
try:
# Get raw text from transcript queue
raw_text = await self.transcript_queue.get()
# Get raw transcript from STT
speaker, raw_text = await self.stt_to_clean_queue.get()
logger.info(f"Clean Worker: Filtering text from {speaker}: {raw_text}")
logger.info(f"LLM Worker: Processing text: {raw_text}")
# RAG Retrieval for context
context = await asyncio.to_thread(self.rag_manager.retrieve, raw_text)
# Process via LLM (Filter -> Extract)
# Note: this is currently a synchronous call, which blocks the loop.
result = self.processor.process_pipeline(raw_text)
if (
result.lore_updates
or result.character_updates
or result.significant_events
):
logger.info(
f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(result.lore_updates)}, Char: {len(result.character_updates)})"
# Filtering using the processor
filter_result = await asyncio.to_thread(
self.processor.filter_transcript,
raw_text,
context=context,
)
await self.proposal_queue.put(result)
# Push filtered text to LLM queue
if filter_result.filtered_text:
await self.clean_to_llm_queue.put(
(speaker, filter_result.filtered_text)
)
logger.info(f"Clean Worker: Pushed filtered text to LLM queue.")
else:
logger.info("LLM Worker: No relevant game data extracted.")
logger.info("Clean Worker: No filtered text to push.")
except Exception as e:
logger.error(f"Clean Worker error: {e}")
# Small sleep to prevent tight loop
await asyncio.sleep(0.1)
async def llm_worker(self):
"""
Worker that handles LLM: Filtered Text/UI Input -> Structured Data & UI Updates.
"""
logger.info("LLM Worker started.")
# Internal queue to serialize processing from multiple sources
internal_queue = asyncio.Queue()
async def feed_clean():
while self.is_running:
try:
item = await self.clean_to_llm_queue.get()
await internal_queue.put(item)
except Exception as e:
logger.error(f"LLM Feeder (Clean) error: {e}")
async def feed_ui():
while self.is_running:
try:
text = await self.ui_to_llm_queue.get()
await internal_queue.put(("UI", text))
except Exception as e:
logger.error(f"LLM Feeder (UI) error: {e}")
# Start feeder tasks
feeders = [
asyncio.create_task(feed_clean()),
asyncio.create_task(feed_ui()),
]
while self.is_running:
try:
speaker, text = await internal_queue.get()
logger.info(f"LLM Worker: Processing text from {speaker}: {text}")
# RAG Retrieval for context
context = await asyncio.to_thread(self.rag_manager.retrieve, text)
# Log the text sent to the LLM for UI affordance
await self.log_queue.put(f"[{speaker}] {text}")
# Structured extraction using the processor
extraction_result = await asyncio.to_thread(
self.processor.extract_structured_data,
text,
context=context,
)
# Persistence: Lore Updates
for lore_update in extraction_result.lore_updates:
file_path = await asyncio.to_thread(update_lore, lore_update)
await asyncio.to_thread(self.rag_manager.ingest_file, file_path)
logger.info(
f"LLM Worker: Lore updated and ingested into RAG: {lore_update.entity_name}"
)
# Persistence: Character State Updates
for char_update in extraction_result.character_updates:
await asyncio.to_thread(update_character_state, char_update)
logger.info(
f"LLM Worker: Character {char_update.character_name} state updated."
)
# UI Notification: Context Updates
for context_update in extraction_result.context_updates:
await self.llm_to_ui_queue.put(context_update)
logger.info(f"LLM Worker: Pushed context update to UI.")
except Exception as e:
logger.error(f"LLM Worker error: {e}")
@@ -83,18 +239,54 @@ class PipelineOrchestrator:
# Small sleep
await asyncio.sleep(0.1)
# Clean up feeders
for f in feeders:
f.cancel()
def _get_wiki_context(self) -> str:
    """
    Collect every markdown file under the lore directory into one context string.

    Each readable file contributes a section made of its path relative to
    the lore directory followed by its content; sections are separated by
    blank lines. Files that cannot be read are logged and skipped. A fixed
    placeholder sentence is returned when no section was collected.
    """
    from src.persistence.lore import DATA_LORE_DIR

    sections = []
    # rglob walks the lore directory recursively
    for md_path in DATA_LORE_DIR.rglob("*.md"):
        try:
            body = md_path.read_text(encoding="utf-8")
            sections.append(
                f"File: {md_path.relative_to(DATA_LORE_DIR)}\nContent:\n{body}"
            )
        except Exception as e:
            logger.error(f"Error reading wiki file {md_path}: {e}")

    if not sections:
        return "No wiki knowledge available."
    return "\n\n".join(sections)
async def tui_worker(self):
"""
Worker that handles TUI: Proposal -> Persistence.
Worker that handles TUI: UI interactions.
"""
logger.info("TUI Worker started.")
try:
# Launch TUI exactly once.
# Pass the proposal queue to the app.
app = ConfirmationApp(proposal_queue=self.proposal_queue)
# Launch TUI.
# Use the new queues for the TUI.
app = ConfirmationApp(
ui_to_llm_queue=self.ui_to_llm_queue,
llm_to_ui_queue=self.llm_to_ui_queue,
log_queue=self.log_queue,
)
await app.run_async()
self.stop()
except Exception as e:
logger.error(f"TUI Worker error: {e}")
self.stop()
except asyncio.CancelledError:
pass
async def run(self):
"""
@@ -103,9 +295,16 @@ class PipelineOrchestrator:
self.is_running = True
self.listener.start()
# Initialize Context Pipeline
from src.pipeline.context_pipeline import ContextPipeline
self.context_pipeline = ContextPipeline(self.rag_manager)
stop_event = asyncio.Event()
# Start workers as background tasks
tasks = [
asyncio.create_task(self.stt_worker()),
asyncio.create_task(self.clean_worker()),
asyncio.create_task(self.llm_worker()),
asyncio.create_task(self.tui_worker()),
]
@@ -118,6 +317,7 @@ class PipelineOrchestrator:
pass
finally:
self.is_running = False
stop_event.set()
self.listener.stop()
for task in tasks:
task.cancel()
@@ -127,6 +327,6 @@ class PipelineOrchestrator:
def stop(self):
"""
Stops the pipeline.
Stops.
"""
self.is_running = False
+193
View File
@@ -0,0 +1,193 @@
import os
from typing import Any, List, Optional
import chromadb
import pdfplumber
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from src.llm.models import ContextUpdate
from src.llm.processor import LLMProcessor
class RAGManager:
    """Ingest documents into, and retrieve snippets from, a persistent vector index.

    The index is a LlamaIndex ``VectorStoreIndex`` stored in a ChromaDB
    collection under ``persist_dir``. Retrieval returns ``ContextUpdate``
    objects, optionally condensed into short insights by an LLM.
    """

    def __init__(self, persist_dir: str = "data/rag_index"):
        """Open (or create) the vector store and load the index if possible.

        Args:
            persist_dir: Directory used by the persistent ChromaDB client.

        Side effect: sets the global LlamaIndex ``Settings.embed_model``.
        """
        self.persist_dir = persist_dir
        self.db = chromadb.PersistentClient(path=self.persist_dir)
        self.collection_name = "phb_collection"
        # Built on first use by _get_processor(), so ingestion-only runs never
        # construct an LLM client.
        self._processor = None

        # Initialize Chroma Vector Store
        self.vector_store = ChromaVectorStore(
            chroma_collection=self.db.get_or_create_collection(self.collection_name)
        )

        # Initialize Storage Context
        self.storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        # Use a local HuggingFace embedding model to avoid API key issues during verification
        Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

        # Load index if it exists, otherwise leave it unset until first ingestion
        try:
            self.index = VectorStoreIndex.from_vector_store(
                self.vector_store, storage_context=self.storage_context
            )
        except Exception:
            # Deliberate best effort: a missing index is rebuilt on ingestion.
            self.index = None

    def _get_processor(self):
        """Return the LLM processor used for summarisation, creating it once."""
        processor = getattr(self, "_processor", None)
        if processor is None:
            processor = LLMProcessor()
            self._processor = processor
        return processor

    def ingest_pdf(self, pdf_path: str):
        """
        Parses a PDF and stores one embedded document per page in ChromaDB.

        Pages without extractable text are skipped. Nothing is indexed when
        no page yields text.
        """
        documents = []
        with pdfplumber.open(pdf_path) as pdf:
            for i, page in enumerate(pdf.pages):
                text = page.extract_text()
                if text:
                    # Page-level chunking: one document per page, labelled
                    # with its 1-based page number.
                    doc = Document(
                        text=text, metadata={"source": f"PHB p. {i + 1}", "page": i + 1}
                    )
                    documents.append(doc)

        if not documents:
            print(f"No text extracted from {pdf_path}")
            return

        # Create index from documents
        self.index = VectorStoreIndex.from_documents(
            documents, storage_context=self.storage_context
        )
        print(f"Successfully ingested {pdf_path} into the vector store.")

    def ingest_file(self, file_path: str):
        """
        Loads a single markdown file into the index.

        The file's base name is recorded as the document's source.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()

        # Use the filename as the source
        source = os.path.basename(file_path)
        doc = Document(text=text, metadata={"source": source})

        if not self.index:
            # No index yet: build one from this document
            self.index = VectorStoreIndex.from_documents(
                [doc], storage_context=self.storage_context
            )
        else:
            # Insert into existing index
            self.index.insert(doc)

        print(f"Successfully ingested {file_path} into the vector store.")

    def ingest_directory(self, dir_path: str):
        """
        Recursively loads all markdown (``.md``) files in a directory into the index.
        """
        files_processed = 0
        for root, _, files in os.walk(dir_path):
            for file in files:
                if file.endswith(".md"):
                    self.ingest_file(os.path.join(root, file))
                    files_processed += 1

        print(
            f"Successfully ingested {files_processed} files from {dir_path} into the vector store."
        )

    def summarize_results(self, query: str, nodes: List[Any]) -> "List[ContextUpdate]":
        """
        Uses an LLM to transform raw snippets into concise "insights", filtering out irrelevant content.

        Args:
            query: The query the snippets were retrieved for.
            nodes: Retrieved nodes; each must expose ``metadata`` and ``text``.

        Returns:
            One ``ContextUpdate`` per insight reported by the LLM, or an
            empty list when there are no nodes, no insights, or the LLM
            response cannot be parsed.
        """
        if not nodes:
            return []

        processor = self._get_processor()

        # Construct the context from retrieved nodes
        context_text = "\n\n".join(
            f"Source: {node.metadata.get('source', 'Unknown')}\nContent: {node.text}"
            for node in nodes
        )

        system_prompt = (
            "You are a precise research assistant. Your task is to analyze provided text snippets "
            "and extract only the information that is directly relevant to the user's query. "
            "1. If a snippet is irrelevant to the query, discard it completely. "
            "2. For relevant information, synthesize it into a concise, single-sentence 'insight'. "
            "3. Do not simply repeat the raw text; summarize it for clarity and brevity. "
            "4. If no snippets are relevant to the query, return an empty list. "
            "5. Be factual and do not hallucinate. Use only the provided snippets."
        )

        user_prompt = (
            f"Query: {query}\n\n"
            f"Snippets:\n{context_text}\n\n"
            "Return a JSON object with a key 'insights' containing a list of objects, each with 'snippet' and 'source'."
        )

        result = processor._call_llm(
            system_prompt,
            user_prompt,
            response_format={"type": "json_object"},
        )

        import json

        try:
            data = json.loads(result)
            # Expecting a format like {"insights": [{"snippet": "...", "source": "..."}, ...]}
            insights = data.get("insights", []) if isinstance(data, dict) else data

            if not insights:
                print(f"Summarization: No relevant insights found for query: {query}")
                return []

            return [
                ContextUpdate(
                    query=query, snippet=item["snippet"], source=item["source"]
                )
                for item in insights
            ]
        except (ValueError, KeyError, TypeError) as e:
            # ValueError covers malformed JSON and model validation failures;
            # KeyError/TypeError cover insights of an unexpected shape.
            print(f"Summarization parsing error: {e}")
            return []

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        summarize: bool = False,
        min_score: float = 0.5,
    ) -> "List[ContextUpdate]":
        """
        Retrieves the top-K most relevant snippets for a given query,
        keeping only those with a similarity score >= ``min_score``.

        Args:
            query: The text to search for.
            top_k: Maximum number of nodes requested from the retriever.
            summarize: When True, the surviving nodes are condensed through
                ``summarize_results`` instead of being returned verbatim.
            min_score: Minimum similarity score a node needs to be kept.

        Returns:
            A list of ``ContextUpdate`` objects; empty when the index has not
            been initialised or nothing passes the threshold.
        """
        if self.index is None:
            print("Index not initialized. Please ingest documents first.")
            return []

        # Create a retriever
        retriever = self.index.as_retriever(similarity_top_k=top_k)
        nodes = retriever.retrieve(query)

        # A node without a score cannot be shown to meet the threshold, so it
        # is dropped rather than compared (None >= float raises TypeError).
        nodes = [
            node
            for node in nodes
            if node.score is not None and node.score >= min_score
        ]

        if summarize:
            return self.summarize_results(query, nodes)

        return [
            ContextUpdate(
                query=query,
                snippet=node.text,
                source=node.metadata.get("source", "Unknown Source"),
            )
            for node in nodes
        ]
Binary file not shown.
Binary file not shown.
Binary file not shown.
+1 -1
View File
@@ -5,7 +5,7 @@ import numpy as np
import sounddevice as sd
import torch
logging.basicConfig(level=logging.INFO)
# Do not call basicConfig here, as it's called in the orchestrator
logger = logging.getLogger(__name__)
+65 -23
View File
@@ -1,69 +1,111 @@
import logging
import os
from faster_whisper import WhisperModel
import numpy as np
import whisperx
from whisperx.diarize import DiarizationPipeline
logging.basicConfig(level=logging.INFO)
# Do not call basicConfig here, as it's called in the orchestrator
logger = logging.getLogger(__name__)
class Transcriber:
"""
Converts audio chunks (numpy arrays) into text using faster-whisper.
Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX.
"""
def __init__(self, model_size="base", device="cpu", compute_type="int8"):
def __init__(
self, model_size="base", device="cpu", compute_type="int8", language="en"
):
"""
Initializes the faster-whisper model.
Initializes the WhisperX model and diarization pipeline.
Args:
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
device (str): The device to run the model on ("cpu" or "cuda").
compute_type (str): The compute type to use (e.g., "int8", "float16").
language (str): The language code for alignment (e.g., "en").
"""
self.device = device
self.compute_type = compute_type
self.language = language
logger.info(
f"Loading faster-whisper model: {model_size} on {device} ({compute_type})..."
f"Loading WhisperX model: {model_size} on {device} ({compute_type})..."
)
try:
self.model = WhisperModel(
# Load transcription model
self.model = whisperx.load_model(
model_size, device=device, compute_type=compute_type
)
logger.info("Model loaded successfully.")
# Load alignment model (required for accurate speaker assignment)
# model_dir=None allows automatic model selection based on the language
self.align_model, self.align_metadata = whisperx.load_align_model(
device=device, model_dir=None, language_code=self.language
)
self.diarize_model = DiarizationPipeline()
logger.info("WhisperX and Diarization models loaded successfully.")
except Exception as e:
logger.error(f"Failed to load faster-whisper model: {e}")
logger.error(f"Failed to load WhisperX models: {e}")
raise
def transcribe(self, audio_chunk):
"""
Transcribes a single audio chunk.
Transcribes an audio chunk and performs speaker diarization.
Args:
audio_chunk (np.ndarray): The audio data as a numpy array.
Returns:
str: The transcribed text.
list: A list of tuples (speaker_id, text).
"""
if audio_chunk is None:
return ""
return []
try:
# faster-whisper expects audio in float32 and 1D array
audio_data = audio_chunk.astype("float32").flatten()
# WhisperX expects audio in float32 and 1D array
audio = audio_chunk.astype("float32").flatten()
# Transcribe the audio
segments, info = self.model.transcribe(audio_data, beam_size=5)
# 1. Perform transcription
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
result = self.model.transcribe(audio, batch_size=16)
# Combine segments into a single string
text = " ".join([segment.text.strip() for segment in segments])
# 2. Perform alignment
# Alignment is necessary for the assign_words_to_speakers step
result_a = whisperx.align(
result["segments"],
self.align_model,
self.align_metadata,
audio,
self.device,
)
return text.strip()
# 3. Perform diarization
diarize_segments = self.diarize_model(audio)
# 4. Align transcription segments with speakers
result_final = whisperx.assign_word_speakers(diarize_segments, result_a)
# Extract (speaker_id, text, start, end) tuples from the final result
output = []
for segment in result_final.get("segments", []):
speaker = segment.get("speaker", "Unknown")
text = segment.get("text", "").strip()
start = segment.get("start", 0.0)
end = segment.get("end", 0.0)
if text:
output.append((speaker, text, start, end))
return output
except Exception as e:
logger.error(f"Transcription error: {e}")
return ""
logger.error(f"Transcription/Diarization error: {e}")
return []
def close(self):
"""
Explicitly release model resources if necessary.
"""
# faster-whisper's WhisperModel doesn't have a standard close(),
# but we'll provide this for consistency.
pass
Binary file not shown.
Binary file not shown.
Binary file not shown.
+236 -230
View File
@@ -3,315 +3,321 @@ from typing import List, Optional, Union
from textual.app import App, ComposeResult
from textual.containers import Container, Horizontal, Vertical
from textual.widgets import Button, DataTable, Footer, Input, Label, Static
from textual.message import Message
from textual.screen import ModalScreen
from textual.widgets import (
Button,
DataTable,
Footer,
Input,
Label,
ListItem,
ListView,
Static,
)
from src.llm.models import CharacterStateUpdate, ExtractionResult, LoreUpdate
from src.persistence.characters import update_character_state
from src.persistence.lore import update_lore
class EditModal(ModalScreen):
    """
    Modal screen with a single text field for editing a fact's content.

    Args:
        initial_text (str): Text the input field is pre-filled with.
        on_save (callable): Called with the field's current value when the
            "Save" button is pressed.
    """

    def __init__(self, initial_text: str, on_save: callable):
        super().__init__()
        self.on_save = on_save
        self.initial_text = initial_text

    def compose(self) -> ComposeResult:
        """Yields the prompt label, the edit field and the Save/Cancel row."""
        with Vertical(id="modal-container"):
            yield Label("Edit Fact Content:")
            yield Input(value=self.initial_text, id="edit-input")
            with Horizontal(id="modal-actions"):
                yield Button("Save", id="btn-save")
                yield Button("Cancel", id="btn-cancel")

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Runs the save callback (Save only) and dismisses the modal."""
        pressed = event.button.id
        if pressed not in ("btn-save", "btn-cancel"):
            return
        if pressed == "btn-save":
            self.on_save(self.query_one("#edit-input", Input).value)
        self.dismiss()
class ConfirmationApp(App):
CSS = """
Screen {
#main-container {
layout: vertical;
height: 100%;
}
#content-wrapper {
layout: horizontal;
height: 100%;
}
#left-pane {
width: 40%;
border: solid;
padding: 1;
width: 70%;
layout: vertical;
}
#right-pane {
width: 60%;
border: solid;
padding: 1;
width: 30%;
layout: vertical;
border: solid white;
}
#details-container {
height: auto;
margin-bottom: 1;
#pending-facts-table {
height: 40%;
border: solid white;
}
#actions-container {
#llm-input-container {
height: 10%;
border: solid white;
padding: 1;
}
#context-pane {
height: 50%;
border: solid white;
}
#log-pane {
height: 30%;
border: solid white;
background: #111;
}
#log-footer {
height: 70%;
border: solid white;
}
#modal-container {
width: 60%;
height: auto;
layout: horizontal;
border: double white;
background: #222;
padding: 2;
align: center middle;
}
#edit-container {
display: none;
#modal-actions {
height: auto;
layout: vertical;
border: solid;
padding: 1;
margin-top: 1;
align: right middle;
}
Button {
margin: 0 1;
#edit-input {
margin: 1 0;
}
#llm-input {
width: 100%;
}
ListItem Static {
border: solid grey;
margin: 1 0;
padding: 1;
}
"""
BINDINGS = [
("q", "quit", "Quit"),
("a", "accept", "Accept"),
("r", "reject", "Reject"),
("e", "edit", "Edit"),
("enter", "send", "Send"),
]
def __init__(
self,
result: Optional[ExtractionResult] = None,
proposal_queue: Optional[asyncio.Queue] = None,
ui_to_llm_queue: Optional[asyncio.Queue] = None,
llm_to_ui_queue: Optional[asyncio.Queue] = None,
log_queue: Optional[asyncio.Queue] = None,
):
super().__init__()
self.result = result
self.proposal_queue = proposal_queue
self.ui_to_llm_queue = ui_to_llm_queue
self.llm_to_ui_queue = llm_to_ui_queue
self.log_queue = log_queue
self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = []
if result:
# Populate pending updates from result
self.pending_updates.extend(result.lore_updates)
self.pending_updates.extend(result.character_updates)
self.selected_index = -1
def compose(self) -> ComposeResult:
yield Container(
yield Vertical(
Horizontal(
Vertical(
DataTable(id="update-table"),
DataTable(id="pending-facts-table"),
Vertical(
Input(placeholder="Message LLM...", id="llm-input"),
id="llm-input-container",
),
ListView(id="context-pane"),
id="left-pane",
),
Vertical(
Vertical(
Label("Details:", id="details-label"),
Static("No update selected", id="details-text"),
id="details-container",
),
Vertical(
Label("Edit Value:"),
Input(id="edit-input"),
Button("Save Edit", id="save-edit"),
id="edit-container",
),
Horizontal(
Button("Accept", id="btn-accept"),
Button("Reject", id="btn-reject"),
Button("Edit", id="btn-edit"),
id="actions-container",
),
ListView(id="log-pane"),
Static("LATEST LLM INPUTS", id="log-footer"),
id="right-pane",
),
id="content-wrapper",
),
Footer(),
id="main-container",
)
yield Footer()
def on_mount(self) -> None:
table = self.query_one("#update-table", DataTable)
table = self.query_one("#pending-facts-table", DataTable)
table.cursor_type = "row"
table.add_columns("Type", "Target", "Update")
table.add_columns("Type", "Target", "Content")
for i, update in enumerate(self.pending_updates):
if isinstance(update, LoreUpdate):
table.add_row(
"Lore", update.entity_name or "General", update.content, key=str(i)
)
elif isinstance(update, CharacterStateUpdate):
change_text = f"HP: {update.hp_change or 0}"
if update.status_effects_added:
change_text += f", Added: {', '.join(update.status_effects_added)}"
if update.status_effects_removed:
change_text += (
f", Removed: {', '.join(update.status_effects_removed)}"
)
table.add_row("Char", update.character_name, change_text, key=str(i))
self.add_update_to_table(update, i)
if self.pending_updates:
self.handle_row_highlight(0)
self.query_one("#btn-accept", Button).focus()
if self.proposal_queue:
self.run_worker(self.poll_proposal_queue, thread=False)
async def poll_proposal_queue(self) -> None:
"""
Background worker that polls the proposal queue for new extraction results.
"""
while True:
try:
result = await self.proposal_queue.get()
self.add_result(result)
except Exception as e:
# Log error but keep the worker running
self.log(f"Error polling proposal queue: {e}")
finally:
# Signal that the item has been processed
if hasattr(self.proposal_queue, "task_done"):
self.proposal_queue.task_done()
def add_result(self, result: ExtractionResult) -> None:
"""
Adds results from the LLM processor to the TUI table.
"""
table = self.query_one("#update-table", DataTable)
start_index = len(self.pending_updates)
for update in result.lore_updates + result.character_updates:
self.pending_updates.append(update)
actual_index = len(self.pending_updates) - 1
if isinstance(update, LoreUpdate):
table.add_row(
"Lore",
update.entity_name or "General",
update.content,
key=str(actual_index),
)
elif isinstance(update, CharacterStateUpdate):
change_text = f"HP: {update.hp_change or 0}"
if update.status_effects_added:
change_text += f", Added: {', '.join(update.status_effects_added)}"
if update.status_effects_removed:
change_text += (
f", Removed: {', '.join(update.status_effects_removed)}"
)
table.add_row(
"Char", update.character_name, change_text, key=str(actual_index)
)
# If the table was previously empty and we added updates, focus the first one.
if start_index == 0 and self.pending_updates:
self.handle_row_highlight(0)
self.query_one("#btn-accept", Button).focus()
def on_data_table_row_highlighted(self, event: DataTable.RowHighlighted) -> None:
self.handle_row_highlight(event.cursor_row)
def handle_row_highlight(self, row: int) -> None:
self.selected_index = row
if self.selected_index < 0 or self.selected_index >= len(self.pending_updates):
return
update = self.pending_updates[self.selected_index]
details_text = self.query_one("#details-text", Static)
if isinstance(update, LoreUpdate):
details_text.update(
f"Category: {update.category}\nTarget: {update.entity_name}\nContent: {update.content}"
)
elif isinstance(update, CharacterStateUpdate):
details_text.update(
f"Character: {update.character_name}\nHP Change: {update.hp_change}\nAdded Effects: {update.status_effects_added}\nRemoved Effects: {update.status_effects_removed}"
)
# Reset to detail view
self.query_one("#edit-container", Vertical).styles.display = "none"
self.query_one("#details-container", Vertical).styles.display = "block"
def on_button_pressed(self, event: Button.Pressed) -> None:
if self.selected_index == -1:
return
update = self.pending_updates[self.selected_index]
if event.button.id == "btn-accept":
if isinstance(update, LoreUpdate):
update_lore(update)
elif isinstance(update, CharacterStateUpdate):
update_character_state(update)
self.remove_update(self.selected_index)
elif event.button.id == "btn-reject":
self.remove_update(self.selected_index)
elif event.button.id == "btn-edit":
self.show_edit_mode(update)
elif event.button.id == "save-edit":
self.save_edit(update)
def show_edit_mode(self, update: Union[LoreUpdate, CharacterStateUpdate]) -> None:
edit_input = self.query_one("#edit-input", Input)
if isinstance(update, LoreUpdate):
edit_input.value = update.content
elif isinstance(update, CharacterStateUpdate):
# For simplicity, only allow editing HP change in this TUI
edit_input.value = str(update.hp_change or 0)
self.query_one("#edit-container", Vertical).styles.display = "block"
self.query_one("#details-container", Vertical).styles.display = "none"
def save_edit(self, update: Union[LoreUpdate, CharacterStateUpdate]) -> None:
new_val = self.query_one("#edit-input", Input).value
if isinstance(update, LoreUpdate):
update.content = new_val
elif isinstance(update, CharacterStateUpdate):
try:
update.hp_change = int(new_val)
except ValueError:
# Ignore invalid integer input
if self.ui_to_llm_queue:
# We don't need a poller for this, just the action_send
pass
if self.llm_to_ui_queue:
# Use Textual workers so the task isn't garbage-collected and
# exceptions are surfaced via the worker manager.
self.run_worker(self.poll_llm_updates(), exclusive=False)
if self.log_queue:
self.run_worker(self.poll_log_updates(), exclusive=False)
# Refresh the table
table = self.query_one("#update-table", DataTable)
# Textual DataTable doesn't have a simple 'update_row', so we clear and refill
# or we can use update_cell.
self.query_one("#llm-input", Input).focus()
# Update the table row
def add_update_to_table(
self, update: Union[LoreUpdate, CharacterStateUpdate], index: int
) -> None:
table = self.query_one("#pending-facts-table", DataTable)
if isinstance(update, LoreUpdate):
table.update_cell(self.selected_index, 2, update.content)
table.add_row(
"Lore", update.entity_name or "General", update.content, key=str(index)
)
elif isinstance(update, CharacterStateUpdate):
change_text = f"HP: {update.hp_change or 0}"
if update.status_effects_added:
change_text += f", Added: {', '.join(update.status_effects_added)}"
if update.status_effects_removed:
change_text += f", Removed: {', '.join(update.status_effects_removed)}"
table.update_cell(self.selected_index, 2, change_text)
table.add_row("Char", update.character_name, change_text, key=str(index))
self.show_edit_mode(update) # just to refresh the value maybe? No,
# Actually let's go back to detail view
self.query_one("#edit-container", Vertical).styles.display = "none"
self.query_one("#details-container", Vertical).styles.display = "block"
async def poll_llm_updates(self) -> None:
    """
    Background worker that renders items from ``llm_to_ui_queue`` in the
    context pane, newest first.

    Loops until the awaiting ``get()`` raises (e.g. on cancellation). A
    failure while rendering one item is logged and does not stop the loop.
    """
    while True:
        # get() sits outside the try block so a failing queue ends the worker
        # instead of being logged in a tight, never-ending loop.
        update = await self.llm_to_ui_queue.get()
        try:
            display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
            context_list = self.query_one("#context-pane", ListView)
            # ListView.insert takes an *iterable* of ListItems; passing a
            # bare ListItem raises TypeError because ListItem is not iterable.
            # Insert at the top to show most recent first.
            await context_list.insert(0, [ListItem(Static(display_text))])
        except Exception as e:
            self.log(f"Error polling LLM updates: {e}")
        finally:
            # Acknowledge the item even when rendering failed; otherwise a
            # queue.join() elsewhere would wait forever on this item.
            if hasattr(self.llm_to_ui_queue, "task_done"):
                self.llm_to_ui_queue.task_done()
# Update details text
details_text = self.query_one("#details-text", Static)
async def poll_log_updates(self) -> None:
    """
    Background worker that renders items from ``log_queue`` in the log pane,
    newest first.

    Loops until the awaiting ``get()`` raises (e.g. on cancellation). A
    failure while rendering one item is logged and does not stop the loop.
    """
    while True:
        # get() sits outside the try block so a failing queue ends the worker
        # instead of being logged in a tight, never-ending loop.
        log_text = await self.log_queue.get()
        try:
            log_list = self.query_one("#log-pane", ListView)
            # ListView.insert takes an iterable of ListItems, so the single
            # item is wrapped in a list.
            # Insert at the top to show most recent first.
            await log_list.insert(0, [ListItem(Static(log_text))])
        except Exception as e:
            self.log(f"Error polling log updates: {e}")
        finally:
            # Acknowledge the item even when rendering failed; otherwise a
            # queue.join() elsewhere would wait forever on this item.
            if hasattr(self.log_queue, "task_done"):
                self.log_queue.task_done()
def handle_proposal_result(self, result: "ExtractionResult") -> None:
    """
    Appends every update in an extraction result to the pending list and
    adds a matching row to the table.

    Lore updates are processed before character updates. Each row is keyed by
    the update's position in ``pending_updates``.

    Args:
        result (ExtractionResult): Provides ``lore_updates`` and
            ``character_updates``.
    """
    # The table is not looked up here: add_update_to_table does the insert.
    for update in result.lore_updates + result.character_updates:
        index = len(self.pending_updates)
        self.pending_updates.append(update)
        self.add_update_to_table(update, index)
async def poll_context_queue(self) -> None:
    """Obsolete: does nothing."""
    return None
async def poll_response_queue(self) -> None:
    """Obsolete: does nothing."""
    return None
def on_input_submitted(self, event: Input.Submitted) -> None:
    """Triggers the send action when the "llm-input" field is submitted."""
    if event.input.id != "llm-input":
        return
    self.action_send()
def action_send(self) -> None:
    """
    Puts the text of the "llm-input" field on ``ui_to_llm_queue`` and clears
    the field. Does nothing when the field is empty or no queue is set.
    """
    message_box = self.query_one("#llm-input", Input)
    message = message_box.value
    if not (message and self.ui_to_llm_queue):
        return
    self.ui_to_llm_queue.put_nowait(message)
    message_box.value = ""
def action_accept(self) -> None:
table = self.query_one("#pending-facts-table", DataTable)
row_index = table.cursor_row
if row_index < 0 or row_index >= len(self.pending_updates):
return
update = self.pending_updates[row_index]
if isinstance(update, LoreUpdate):
details_text.update(
f"Category: {update.category}\nTarget: {update.entity_name}\nContent: {update.content}"
)
update_lore(update)
elif isinstance(update, CharacterStateUpdate):
details_text.update(
f"Character: {update.character_name}\nHP Change: {update.hp_change}\nAdded Effects: {update.status_effects_added}\nRemoved Effects: {update.status_effects_removed}"
)
update_character_state(update)
self.remove_update(row_index)
def action_reject(self) -> None:
    """
    Discards the pending update under the table cursor without persisting it.
    Does nothing when the cursor row is outside ``pending_updates``.
    """
    cursor = self.query_one("#pending-facts-table", DataTable).cursor_row
    if 0 <= cursor < len(self.pending_updates):
        self.remove_update(cursor)
def action_edit(self) -> None:
    """
    Opens an EditModal for the pending update under the table cursor.

    Lore updates are edited through their content. Character updates are
    edited through their HP change only, pre-filled with "0" when unset.
    Does nothing when the cursor row is outside ``pending_updates``.
    """
    cursor = self.query_one("#pending-facts-table", DataTable).cursor_row
    if not 0 <= cursor < len(self.pending_updates):
        return
    target = self.pending_updates[cursor]

    if isinstance(target, LoreUpdate):
        initial_text = target.content
    elif isinstance(target, CharacterStateUpdate):
        initial_text = str(target.hp_change or 0)
    else:
        initial_text = ""

    def save_callback(new_text: str):
        # Write the edited value back onto the update object itself.
        if isinstance(target, LoreUpdate):
            target.content = new_text
        elif isinstance(target, CharacterStateUpdate):
            try:
                target.hp_change = int(new_text)
            except ValueError:
                # Non-integer input leaves the HP change as it was.
                pass
        # Rebuild the table from pending_updates.
        self.refresh_table()

    self.push_screen(EditModal(initial_text, save_callback))
def remove_update(self, index: int) -> None:
# Remove from the pending list
del self.pending_updates[index]
self.refresh_table()
# Clear and refill the table
table = self.query_one("#update-table", DataTable)
def refresh_table(self) -> None:
table = self.query_one("#pending-facts-table", DataTable)
table.clear()
for i, update in enumerate(self.pending_updates):
if isinstance(update, LoreUpdate):
table.add_row(
"Lore", update.entity_name or "General", update.content, key=str(i)
)
elif isinstance(update, CharacterStateUpdate):
change_text = f"HP: {update.hp_change or 0}"
if update.status_effects_added:
change_text += f", Added: {', '.join(update.status_effects_added)}"
if update.status_effects_removed:
change_text += (
f", Removed: {', '.join(update.status_effects_removed)}"
)
table.add_row("Char", update.character_name, change_text, key=str(i))
if self.pending_updates:
self.handle_row_highlight(0)
self.query_one("#btn-accept", Button).focus()
else:
self.selected_index = -1
self.query_one("#details-text", Static).update("All updates processed.")
self.add_update_to_table(update, i)
+98
View File
@@ -0,0 +1,98 @@
import os
from reportlab.pdfgen import canvas
from src.rag.manager import RAGManager
def create_dummy_phb(pdf_path: str):
    """
    Creates a dummy PDF file to simulate the Player's Handbook for verification.

    The PDF has three pages (Fireball, Grappling, General Rules), each with a
    title line followed by its body lines. The parent directory of
    ``pdf_path`` is created if it does not exist yet.

    Args:
        pdf_path (str): Destination path of the generated PDF.
    """
    # (title, body lines) for each page, in page order.
    pages = (
        (
            "Fireball",
            (
                "A bright streak flashes from your pointing finger to a point you choose within range.",
                "Each creature in a 20-foot-radius sphere centered on that point must make a Dexterity saving throw.",
            ),
        ),
        (
            "Grappling",
            (
                "Grappling is a special option available to any attack that hits with a melee attack.",
                "A creature is grappled if it is the target of the grapple attack and the attack hits.",
            ),
        ),
        (
            "General Rules",
            (
                "Combat is resolved in rounds, and each round represents 6 seconds of in-game time.",
            ),
        ),
    )

    print(f"Creating dummy PHB at {pdf_path}...")

    # Saving fails if the target directory is missing (e.g. a fresh checkout
    # without "data/"), so make sure it exists first.
    parent_dir = os.path.dirname(pdf_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    c = canvas.Canvas(pdf_path)
    for title, body_lines in pages:
        c.drawString(100, 750, title)
        # Body lines go below the title, 20 units apart (730, 710, ...).
        for offset, line in enumerate(body_lines, start=1):
            c.drawString(100, 750 - 20 * offset, line)
        c.showPage()
    c.save()

    print(f"Dummy PHB created successfully at {pdf_path}")
def test_rag():
    """
    End-to-end check of RAGManager: builds the dummy PHB, ingests it, and
    asserts that each query's top result contains the expected keyword.
    """
    pdf_path = "data/phb_dummy.pdf"
    create_dummy_phb(pdf_path)

    rag = RAGManager(persist_dir="data/rag_index_test")

    # Task 2.2: Ingest PDF
    print("\nTesting Ingestion...")
    rag.ingest_pdf(pdf_path)

    # Task 2.3: Retrieve Logic
    print("\nTesting Retrieval...")
    # (query, keyword expected in the top snippet)
    checks = (
        ("What is Fireball?", "Fireball"),
        ("How does grappling work?", "Grappling"),
    )
    for query, keyword in checks:
        results = rag.retrieve(query)
        print(f"Query: {query}")
        for res in results:
            print(f"Source: {res.source} | Snippet: {res.snippet[:100]}...")
        assert len(results) > 0, (
            f"Should have retrieved at least one result for '{keyword}'"
        )
        assert keyword in results[0].snippet, (
            f"The top result should contain '{keyword}'"
        )

    print("\n✅ RAG Verification Successful!")
# Run the RAG verification only when this file is executed as a script.
if __name__ == "__main__":
    test_rag()