Compare commits
13 Commits
58bab75bb5
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 15dfbfb467 | |||
| 49127d695a | |||
| 2363cde160 | |||
| afa8d17f10 | |||
| 1cfba3a0ae | |||
| 1098bdb2f9 | |||
| 58f736a5f8 | |||
| b25f82cefc | |||
| b83d9b5e6a | |||
| 679eca3fef | |||
| 954f2f50d8 | |||
| f4c98fb2b9 | |||
| d0fcdfab01 |
@@ -1,7 +1,8 @@
|
|||||||
# D&D Helpers Configuration
|
# D&D Helpers Configuration
|
||||||
OPENAI_API_KEY=no-key-required
|
OPENAI_API_KEY=no-key-required
|
||||||
OPENAI_BASE_URL=https://vllm.tipsy.codes/v1
|
OPENAI_BASE_URL=https://vllm.tipsy.codes/v1
|
||||||
LLM_MODEL=Intel/gemma-4-31B-it-int4-AutoRound
|
LLM_MODEL=google/gemma-4-26b-a4b-it
|
||||||
|
LLM_BACKEND=vllm
|
||||||
#LLM_BACKEND=ollama
|
#LLM_BACKEND=ollama
|
||||||
#LLM_MODEL=gemma:2b
|
#LLM_MODEL=gemma:2b
|
||||||
WHISPER_MODEL=base
|
WHISPER_MODEL=base
|
||||||
|
|||||||
+2
-1
@@ -1,2 +1,3 @@
|
|||||||
artifacts/
|
artifacts/
|
||||||
__pycache__
|
**/__pycache__/
|
||||||
|
data
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
3.12
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
from src.rag.manager import RAGManager
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv=None):
    """Run the D&D Helpers command-line interface.

    Parses the ingestion flags and feeds each requested source into the RAG
    system, in the order PDF, single file, directory. When no ingestion flag
    is supplied, a greeting is printed and nothing else happens.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            ``None``, in which case argparse reads ``sys.argv[1:]`` exactly
            as before.
    """
    parser = argparse.ArgumentParser(description="D&D Helpers CLI")
    parser.add_argument(
        "--ingest-pdf",
        type=str,
        help="Path to a PDF file to ingest into the RAG system",
    )
    parser.add_argument(
        "--ingest-file",
        type=str,
        help="Path to a markdown file to ingest into the RAG system",
    )
    parser.add_argument(
        "--ingest-dir",
        type=str,
        help="Path to a directory of markdown files to ingest into the RAG system",
    )

    args = parser.parse_args(argv)

    if not any([args.ingest_pdf, args.ingest_file, args.ingest_dir]):
        print("Hello from dnd-helpers!")
        return

    # Built only once we know there is something to ingest, so a bare
    # invocation does not pay for constructing the manager.
    rag_manager = RAGManager()

    if args.ingest_pdf:
        print(f"Ingesting PDF: {args.ingest_pdf}...")
        rag_manager.ingest_pdf(args.ingest_pdf)
        print("PDF ingestion complete.")

    if args.ingest_file:
        print(f"Ingesting File: {args.ingest_file}...")
        rag_manager.ingest_file(args.ingest_file)
        print("File ingestion complete.")

    if args.ingest_dir:
        print(f"Ingesting Directory: {args.ingest_dir}...")
        rag_manager.ingest_directory(args.ingest_dir)
        print("Directory ingestion complete.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
[project]
|
||||||
|
name = "dnd-helpers"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = []
|
||||||
+6
-1
@@ -1,8 +1,13 @@
|
|||||||
# Core dependencies for D&D Helpers
|
# Core dependencies for D&D Helpers
|
||||||
faster-whisper
|
whisperx
|
||||||
sounddevice
|
sounddevice
|
||||||
pydantic
|
pydantic
|
||||||
textual
|
textual
|
||||||
typer
|
typer
|
||||||
openai
|
openai
|
||||||
python-dotenv
|
python-dotenv
|
||||||
|
llama-index
|
||||||
|
chromadb
|
||||||
|
pdfplumber
|
||||||
|
llama-index-embeddings-huggingface
|
||||||
|
llama-index-vector_stores-chroma
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -44,6 +44,22 @@ class CharacterStateUpdate(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextUpdate(BaseModel):
    """A retrieved context snippet together with the query and source it came from."""

    # All three fields are required (no default is supplied).
    query: str = Field(
        ...,
        description="The search query used to retrieve the context",
    )
    snippet: str = Field(
        ...,
        description="The relevant text snippet retrieved from the source",
    )
    source: str = Field(
        ...,
        description="The source of the snippet (e.g., 'PHB p. 12')",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class FilterResult(BaseModel):
    """Result of the transcript filtering stage: the cleaned transcript text."""

    # Required field; no default is supplied.
    filtered_text: str = Field(
        ...,
        description="Cleaned transcript used for structured data extraction",
    )
|
||||||
|
|
||||||
|
|
||||||
class ExtractionResult(BaseModel):
|
class ExtractionResult(BaseModel):
|
||||||
lore_updates: List[LoreUpdate] = Field(
|
lore_updates: List[LoreUpdate] = Field(
|
||||||
default_factory=list, description="List of discovered lore facts", alias="lore"
|
default_factory=list, description="List of discovered lore facts", alias="lore"
|
||||||
@@ -58,6 +74,11 @@ class ExtractionResult(BaseModel):
|
|||||||
description="List of significant plot points or events",
|
description="List of significant plot points or events",
|
||||||
alias="events",
|
alias="events",
|
||||||
)
|
)
|
||||||
|
context_updates: List[ContextUpdate] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="List of context updates",
|
||||||
|
alias="context",
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
populate_by_name = True
|
populate_by_name = True
|
||||||
|
|||||||
+105
-45
@@ -1,11 +1,20 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from posix import system
|
||||||
|
from this import s
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from .models import ExtractionResult
|
from .models import ExtractionResult, FilterResult
|
||||||
from .prompts import EXTRACTION_SYSTEM_PROMPT, NOISE_FILTER_SYSTEM_PROMPT
|
from .prompts import (
|
||||||
|
EXTRACTION_SYSTEM_PROMPT,
|
||||||
|
NOISE_FILTER_SYSTEM_PROMPT,
|
||||||
|
QUERY_ANSWER_SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMProcessor:
|
class LLMProcessor:
|
||||||
@@ -36,6 +45,7 @@ class LLMProcessor:
|
|||||||
final_base_url = base_url or os.environ.get("OPENAI_BASE_URL")
|
final_base_url = base_url or os.environ.get("OPENAI_BASE_URL")
|
||||||
final_api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
final_api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
logger.info(f"Using LLM backend: {backend}")
|
||||||
try:
|
try:
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=final_api_key,
|
api_key=final_api_key,
|
||||||
@@ -47,82 +57,132 @@ class LLMProcessor:
|
|||||||
# but we can ensure the client is initialized.
|
# but we can ensure the client is initialized.
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error initializing LLM client for backend {backend}: {e}")
|
logger.error(f"Error initializing LLM client for backend {backend}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
self.model = model or os.environ.get("LLM_MODEL", "gpt-4o")
|
self.model = model or os.environ.get("LLM_MODEL", "gpt-4o")
|
||||||
|
|
||||||
|
def _strip_markdown_code_blocks(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
Strips markdown code blocks (e.g., ```json ... ```) from the content.
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Remove opening and closing code blocks
|
||||||
|
content = re.sub(
|
||||||
|
r"^```(?:json)?\n?|```$", "", content, flags=re.MULTILINE
|
||||||
|
).strip()
|
||||||
|
return content
|
||||||
|
|
||||||
def _call_llm(
    self,
    system_prompt: str,
    user_prompt: str,
    context: Optional[str] = None,
    response_format: Optional[Any] = None,
) -> str:
    """
    Generic method to call the LLM.

    Builds the message list (the system prompt, an optional second system
    message carrying the context, then the user prompt), sends it to the
    chat completions endpoint and returns the reply with markdown code
    fences stripped.

    Args:
        system_prompt: Content of the leading system message.
        user_prompt: Content of the user message.
        context: Optional text added as a second system message; skipped
            when empty or None.
        response_format: Passed through unchanged to the completions call.

    Returns:
        The stripped reply text, or "" when the call fails or the reply
        carries no text.
    """
    messages = [
        {"role": "system", "content": system_prompt},
    ]
    if context:
        messages.append(
            {
                "role": "system",
                "content": f"Context from previous conversation:\n{context}",
            }
        )

    messages.append({"role": "user", "content": user_prompt})

    # Debugging: Dump inputs
    logger.debug("--- LLM CALL START ---")
    logger.debug(f"Model: {self.model}")
    logger.debug(f"Messages: {messages}")
    if response_format:
        logger.debug(f"Response Format: {response_format}")
    logger.debug("--- LLM CALL END ---")

    try:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            response_format=response_format,
            extra_body={"enable_thinking": False},
        )
        content = response.choices[0].message.content

        # Debugging: Dump outputs
        logger.debug("--- LLM RESPONSE START ---")
        logger.debug(f"Content: {content}")
        logger.debug("--- LLM RESPONSE END ---")

        # A reply may carry no text at all; report that explicitly instead
        # of letting the fence-stripping step fail on None.
        if content is None:
            logger.error("LLM Error: response contained no text content")
            return ""

        return self._strip_markdown_code_blocks(content)
    except Exception as e:
        # Best-effort contract: callers get "" rather than an exception.
        logger.error(f"LLM Error: {e}")
        return ""
|
||||||
|
|
||||||
def filter_transcript(self, text: str) -> str:
|
def generate_answer(self, query: str, context: str) -> str:
    """
    Produces a natural language answer to a DM query.

    The query is sent as the user prompt under QUERY_ANSWER_SYSTEM_PROMPT,
    and the given context is forwarded to the LLM call.
    """
    answer = self._call_llm(
        system_prompt=QUERY_ANSWER_SYSTEM_PROMPT,
        user_prompt=query,
        context=context,
    )
    return answer
|
||||||
|
|
||||||
|
def filter_transcript(
|
||||||
|
self, text: str, context: Optional[str] = None
|
||||||
|
) -> FilterResult:
|
||||||
"""
|
"""
|
||||||
Stage 1: Raw Transcript -> Filtered Text.
|
Stage 1: Raw Transcript -> Filtered Text.
|
||||||
"""
|
"""
|
||||||
result = self._call_llm(NOISE_FILTER_SYSTEM_PROMPT, text)
|
result = self._call_llm(
|
||||||
print(f"LLM Processor (Filter): {text} -> {result}")
|
NOISE_FILTER_SYSTEM_PROMPT,
|
||||||
return result
|
text,
|
||||||
|
context=context,
|
||||||
def extract_structured_data(self, filtered_text: str) -> ExtractionResult:
|
|
||||||
"""
|
|
||||||
Stage 2: Filtered Text -> Structured Data.
|
|
||||||
"""
|
|
||||||
print(f"LLM Processor (Extract): Calling extraction for: {filtered_text}")
|
|
||||||
try:
|
|
||||||
# Using standard chat.completions.create with JSON mode for better compatibility with vLLM
|
|
||||||
print("LLM Processor (Extract): Sending request to backend...")
|
|
||||||
response = self.client.chat.completions.create(
|
|
||||||
model=self.model,
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": EXTRACTION_SYSTEM_PROMPT},
|
|
||||||
{"role": "user", "content": filtered_text},
|
|
||||||
],
|
|
||||||
response_format={"type": "json_object"},
|
response_format={"type": "json_object"},
|
||||||
extra_body={"include_reasoning": False},
|
|
||||||
)
|
)
|
||||||
print("LLM Processor (Extract): Response received from backend.")
|
logger.info(f"LLM Processor (Filter): {text} -> {result}")
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
try:
|
||||||
print(f"LLM Processor (Extract): Raw JSON response: {content}")
|
data = json.loads(result)
|
||||||
data = json.loads(content)
|
return FilterResult(**data)
|
||||||
|
except (json.JSONDecodeError, ValidationError) as e:
|
||||||
|
logger.error(f"Filter Parsing Error: {e}")
|
||||||
|
return FilterResult(contextual_info="", filtered_text=result)
|
||||||
|
|
||||||
|
def extract_structured_data(
|
||||||
|
self, filtered_text: str, context: Optional[str] = None
|
||||||
|
) -> ExtractionResult:
|
||||||
|
"""
|
||||||
|
Stage 2: Filtered Text -> Structured Data.
|
||||||
|
"""
|
||||||
|
logger.info(f"LLM Processor (Extract): Calling extraction for: {filtered_text}")
|
||||||
|
try:
|
||||||
|
system_prompt = EXTRACTION_SYSTEM_PROMPT
|
||||||
|
if context:
|
||||||
|
system_prompt += f"\n{context}"
|
||||||
|
|
||||||
|
result = self._call_llm(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_prompt=filtered_text,
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
)
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = json.loads(result)
|
||||||
|
|
||||||
# Map the JSON data to the Pydantic model
|
# Map the JSON data to the Pydantic model
|
||||||
return ExtractionResult(**data)
|
return ExtractionResult(**data)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Extraction Error: {e}")
|
logger.error(f"Extraction Error: {e}")
|
||||||
# Return an empty ExtractionResult if parsing fails
|
# Return an empty ExtractionResult if parsing fails
|
||||||
return ExtractionResult()
|
return ExtractionResult()
|
||||||
|
|
||||||
def process_pipeline(self, raw_text: str) -> ExtractionResult:
|
|
||||||
"""
|
|
||||||
Executes the two-stage pipeline: Raw Transcript -> Filtered Text -> Structured Data.
|
|
||||||
"""
|
|
||||||
filtered_text = self.filter_transcript(raw_text)
|
|
||||||
if not filtered_text:
|
|
||||||
return ExtractionResult()
|
|
||||||
|
|
||||||
return self.extract_structured_data(filtered_text)
|
|
||||||
|
|||||||
+29
-7
@@ -1,8 +1,19 @@
|
|||||||
# System prompts for the LLM pipeline
|
QUERY_ANSWER_SYSTEM_PROMPT = """
|
||||||
|
You are a helpful D&D Game Master's assistant. Your goal is to provide accurate, concise, and helpful answers to the DM's questions based on the provided context (conversation history and RAG snippets).
|
||||||
|
|
||||||
|
Guidelines:
|
||||||
|
- Use the provided context as your primary source of truth.
|
||||||
|
- If the answer is not in the context, state that you don't have enough information, but feel free to provide general D&D 5e rules as a fallback.
|
||||||
|
- Keep responses natural and professional.
|
||||||
|
- Be concise.
|
||||||
|
"""
|
||||||
|
|
||||||
NOISE_FILTER_SYSTEM_PROMPT = """
|
NOISE_FILTER_SYSTEM_PROMPT = """
|
||||||
You are a D&D Game Master's assistant. Given a transcript, remove all out-of-character (OOC) chatter, logistical discussions (e.g., 'Where is my d20?'), and non-relevant noise.
|
You are a D&D Game Master's assistant. Given a transcript, remove all out-of-character (OOC) chatter, logistical discussions (e.g., 'Where is my d20?'), and non-relevant noise.
|
||||||
Output only the in-character dialogue and game-relevant events.
|
|
||||||
|
You must output your response as a JSON object with the following keys:
|
||||||
|
- "filtered_text": The cleaned transcript. IMPORTANT: Keep all player questions, requests for rule clarifications, and mentions of spells, NPCs, or locations in this field, as they are used to trigger knowledge base lookups.
|
||||||
|
|
||||||
Keep the original speakers' names if they are present in the transcript.
|
Keep the original speakers' names if they are present in the transcript.
|
||||||
Do not add any commentary or summaries. Just filter the text.
|
Do not add any commentary or summaries. Just filter the text.
|
||||||
"""
|
"""
|
||||||
@@ -10,15 +21,14 @@ Do not add any commentary or summaries. Just filter the text.
|
|||||||
EXTRACTION_SYSTEM_PROMPT = """
|
EXTRACTION_SYSTEM_PROMPT = """
|
||||||
You are a D&D session analyzer. Your goal is to extract structured data from a filtered transcript.
|
You are a D&D session analyzer. Your goal is to extract structured data from a filtered transcript.
|
||||||
Extract any changes to character states (HP, status effects, inventory) and any new lore facts (NPCs, locations, world-building).
|
Extract any changes to character states (HP, status effects, inventory) and any new lore facts (NPCs, locations, world-building).
|
||||||
|
In addition to extracting updates to character state and lore, look for the opportunity to provide useful context,
|
||||||
DO NOT THINK.
|
such as the answer to a player's question or the resolution of a lore fact.
|
||||||
|
|
||||||
CONSTRAINTS:
|
CONSTRAINTS:
|
||||||
- OUTPUT ONLY VALID JSON.
|
- OUTPUT ONLY VALID JSON.
|
||||||
- DO NOT include any commentary, explanations, or "thought" blocks.
|
|
||||||
- DO NOT include any keys other than "lore", "character_state", and "events".
|
|
||||||
- If no relevant information is found, return empty lists for all keys.
|
- If no relevant information is found, return empty lists for all keys.
|
||||||
- If a character name is not specified (e.g., "Your character"), use "Player Character".
|
- If a character name is not specified (e.g., "Your character"), use "Player Character".
|
||||||
|
- Do not repeat lore if it is already known; only provide new or updated facts.
|
||||||
|
|
||||||
Strict Output Format:
|
Strict Output Format:
|
||||||
Return a JSON object with exactly these keys:
|
Return a JSON object with exactly these keys:
|
||||||
@@ -26,6 +36,7 @@ Return a JSON object with exactly these keys:
|
|||||||
- "category": (string) 'NPC', 'Location', 'WorldBuilding', or 'Plot'
|
- "category": (string) 'NPC', 'Location', 'WorldBuilding', or 'Plot'
|
||||||
- "entity_name": (string) The name of the NPC, Location, or entity
|
- "entity_name": (string) The name of the NPC, Location, or entity
|
||||||
- "content": (string) The actual lore fact or description
|
- "content": (string) The actual lore fact or description
|
||||||
|
- "context": (string, optional) Helpful information for the DM (e.g., descriptions of characters, spell details, game mechanics) discovered via the knowledge base or the transcript.
|
||||||
2. "character_state": A list of objects. Each object MUST have:
|
2. "character_state": A list of objects. Each object MUST have:
|
||||||
- "character_name": (string) Name of the character
|
- "character_name": (string) Name of the character
|
||||||
- "hp_change": (integer, optional) Change in HP
|
- "hp_change": (integer, optional) Change in HP
|
||||||
@@ -33,6 +44,10 @@ Return a JSON object with exactly these keys:
|
|||||||
- "status_effects_removed": (list of strings)
|
- "status_effects_removed": (list of strings)
|
||||||
- "inventory_changes": (list of objects with "item", "quantity", "action")
|
- "inventory_changes": (list of objects with "item", "quantity", "action")
|
||||||
3. "events": A list of strings. Each string should be a concise description of a significant plot development.
|
3. "events": A list of strings. Each string should be a concise description of a significant plot development.
|
||||||
|
4. "context": A list of objects. Each object MUST have:
|
||||||
|
- "query": (string) The original query or topic this context relates to
|
||||||
|
- "snippet": (string) The contextual information or rule explanation
|
||||||
|
- "source": (string) The source of the context (e.g., "players handbook, page 68")
|
||||||
|
|
||||||
Example Output:
|
Example Output:
|
||||||
{
|
{
|
||||||
@@ -40,7 +55,7 @@ Example Output:
|
|||||||
{
|
{
|
||||||
"category": "NPC",
|
"category": "NPC",
|
||||||
"entity_name": "Thorne",
|
"entity_name": "Thorne",
|
||||||
"content": "A gruff dwarf who runs the local tavern."
|
"content": "A gruff dwarf who runs the local tavern."
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"character_state": [
|
"character_state": [
|
||||||
@@ -54,6 +69,13 @@ Example Output:
|
|||||||
],
|
],
|
||||||
"events": [
|
"events": [
|
||||||
"The party discovered the secret entrance to the crypt."
|
"The party discovered the secret entrance to the crypt."
|
||||||
|
],
|
||||||
|
"context": [
|
||||||
|
{
|
||||||
|
"query": "fireball",
|
||||||
|
"snippet": "fireball does 1d6 damage, with an area of effect of 10 feet circle",
|
||||||
|
"source": "players handbook, page 68"
|
||||||
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,67 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from src.llm.models import ContextUpdate
|
||||||
|
from src.rag.manager import RAGManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextPipeline:
    """Feeds transcript messages through RAG retrieval and forwards the results.

    Each ``(speaker, text)`` item taken from the transcript queue triggers a
    ``rag_manager.retrieve(text, summarize=True)`` call; every returned
    insight is put on the context queue.
    """

    def __init__(self, rag_manager: "RAGManager"):
        """Store the RAG manager used for lookups."""
        self.rag_manager = rag_manager

    async def process_message(
        self, speaker: str, text: str, context_queue: asyncio.Queue
    ) -> None:
        """
        Processes a single message and pushes summarized insights to the context queue.

        Args:
            speaker: Speaker label of the message (not used for the lookup).
            text: Message text used as the retrieval query.
            context_queue: Queue that receives each retrieved insight.

        Errors are logged and swallowed so one failed lookup cannot stop
        the caller's loop.
        """
        try:
            # Use RAGManager.retrieve with summarize=True to get concise insights
            # Run in a thread to avoid blocking the event loop
            insights = await asyncio.to_thread(
                self.rag_manager.retrieve, text, summarize=True
            )

            if insights:
                logger.info(
                    f"ContextPipeline: Found {len(insights)} insights for text: {text}"
                )
                for insight in insights:
                    await context_queue.put(insight)
            else:
                logger.debug(f"ContextPipeline: No insights found for text: {text}")

        except Exception as e:
            logger.error(f"ContextPipeline error processing message: {e}")

    async def run(
        self,
        transcript_queue: asyncio.Queue,
        context_queue: asyncio.Queue,
        stop_event: asyncio.Event,
        poll_interval: float = 0.1,
    ) -> None:
        """
        Main loop that listens to the transcript queue and triggers RAG lookups.

        Args:
            transcript_queue: Source of ``(speaker, text)`` items.
            context_queue: Destination for retrieved insights.
            stop_event: The loop exits once this event is set.
            poll_interval: Longest time, in seconds, to wait for an item
                before re-checking ``stop_event``.
        """
        logger.info("Context Pipeline started.")
        while not stop_event.is_set():
            try:
                # Bounded wait: an idle queue must not keep the loop from
                # noticing that stop_event has been set.
                item = await asyncio.wait_for(
                    transcript_queue.get(), timeout=poll_interval
                )
            except asyncio.TimeoutError:
                continue
            except Exception as e:
                logger.error(f"Context Pipeline loop error: {e}")
                # Back off so a persistently failing queue cannot spin the loop.
                await asyncio.sleep(poll_interval)
                continue

            try:
                # Get raw text from transcript queue (speaker, text)
                speaker, text = item

                # For now, implement the basic flow: every message triggers a lookup.
                # If performance becomes an issue, a filter can be added here.
                await self.process_message(speaker, text, context_queue)
            except Exception as e:
                logger.error(f"Context Pipeline loop error: {e}")
            finally:
                # Always balance the get(), even for malformed items, so
                # transcript_queue.join() cannot hang.
                transcript_queue.task_done()

        logger.info("Context Pipeline stopped.")
|
||||||
+245
-34
@@ -1,13 +1,40 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from src.llm.models import ExtractionResult
|
import numpy as np
|
||||||
|
|
||||||
|
from src.llm.models import (
|
||||||
|
CharacterStateUpdate,
|
||||||
|
ContextUpdate,
|
||||||
|
ExtractionResult,
|
||||||
|
LoreUpdate,
|
||||||
|
)
|
||||||
from src.llm.processor import LLMProcessor
|
from src.llm.processor import LLMProcessor
|
||||||
|
from src.llm.prompts import EXTRACTION_SYSTEM_PROMPT, NOISE_FILTER_SYSTEM_PROMPT
|
||||||
|
from src.persistence.characters import update_character_state
|
||||||
|
from src.persistence.lore import update_lore
|
||||||
|
from src.rag.manager import RAGManager
|
||||||
from src.stt.listener import AudioListener
|
from src.stt.listener import AudioListener
|
||||||
from src.stt.transcriber import Transcriber
|
from src.stt.transcriber import Transcriber
|
||||||
from src.ui.tui import ConfirmationApp
|
from src.ui.tui import ConfirmationApp
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
# Configure logging to write to a file instead of stdout
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler("pipeline.log"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Suppress verbose logging from STT libraries to keep the TUI clean
|
||||||
|
logging.getLogger("whisper").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("faster_whisper").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("pyannote").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("whisperx").setLevel(logging.ERROR)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -17,15 +44,45 @@ class PipelineOrchestrator:
|
|||||||
|
|
||||||
# Modules
|
# Modules
|
||||||
self.listener = AudioListener(loop=self.loop)
|
self.listener = AudioListener(loop=self.loop)
|
||||||
self.transcriber = Transcriber()
|
self.transcriber = Transcriber(model_size="base", device="cuda")
|
||||||
self.processor = LLMProcessor()
|
self.processor = LLMProcessor()
|
||||||
|
self.rag_manager = RAGManager()
|
||||||
|
|
||||||
# Queues
|
# Queues
|
||||||
self.transcript_queue = asyncio.Queue()
|
self.stt_to_clean_queue = asyncio.Queue()
|
||||||
self.proposal_queue = asyncio.Queue()
|
self.ui_to_llm_queue = asyncio.Queue()
|
||||||
|
self.clean_to_llm_queue = asyncio.Queue()
|
||||||
|
self.llm_to_ui_queue = asyncio.Queue()
|
||||||
|
self.log_queue = asyncio.Queue()
|
||||||
|
self.persistence_queue = asyncio.Queue()
|
||||||
|
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
|
# Conversation history for context
|
||||||
|
self.history = [] # List of strings (transcripts)
|
||||||
|
self.history_max_words = 1000
|
||||||
|
|
||||||
|
# STT Sliding Window Buffer
|
||||||
|
self.audio_buffer = [] # List of audio chunks
|
||||||
|
self.buffer_max_seconds = 30
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.buffer_max_samples = self.buffer_max_seconds * self.sample_rate
|
||||||
|
self.last_processed_end_time = 0.0
|
||||||
|
|
||||||
|
def _get_combined_context(self) -> str:
|
||||||
|
"""
|
||||||
|
Returns the trimmed conversation history as a context string.
|
||||||
|
"""
|
||||||
|
full_history_text = " ".join(self.history)
|
||||||
|
words = full_history_text.split()
|
||||||
|
if len(words) > self.history_max_words:
|
||||||
|
kept_words = words[-self.history_max_words :]
|
||||||
|
context_text = " ".join(kept_words)
|
||||||
|
else:
|
||||||
|
context_text = full_history_text
|
||||||
|
|
||||||
|
return f"Conversation History:\n{context_text}\n\n"
|
||||||
|
|
||||||
async def stt_worker(self):
|
async def stt_worker(self):
|
||||||
"""
|
"""
|
||||||
Worker that handles STT: Audio -> Text.
|
Worker that handles STT: Audio -> Text.
|
||||||
@@ -36,12 +93,37 @@ class PipelineOrchestrator:
|
|||||||
# Get audio chunk from listener
|
# Get audio chunk from listener
|
||||||
audio_chunk = await self.listener.get_chunk()
|
audio_chunk = await self.listener.get_chunk()
|
||||||
|
|
||||||
# Transcribe
|
# Maintain sliding window buffer
|
||||||
text = self.transcriber.transcribe(audio_chunk)
|
self.audio_buffer.append(audio_chunk)
|
||||||
|
current_buffer_samples = sum(len(c) for c in self.audio_buffer)
|
||||||
|
|
||||||
if text:
|
if current_buffer_samples > self.buffer_max_samples:
|
||||||
logger.info(f"Transcribed: {text}")
|
# Remove oldest chunks until we are within the buffer limit
|
||||||
await self.transcript_queue.put(text)
|
while (
|
||||||
|
sum(len(c) for c in self.audio_buffer) > self.buffer_max_samples
|
||||||
|
):
|
||||||
|
self.audio_buffer.pop(0)
|
||||||
|
|
||||||
|
# Concatenate buffer for transcription
|
||||||
|
full_audio = np.concatenate(self.audio_buffer)
|
||||||
|
|
||||||
|
# Transcribe (WhisperX now returns a list of (speaker, text, start, end))
|
||||||
|
results = await asyncio.to_thread(
|
||||||
|
self.transcriber.transcribe, full_audio
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter for only new segments that start after the last processed segment
|
||||||
|
new_segments = [
|
||||||
|
res for res in results if res[2] >= self.last_processed_end_time
|
||||||
|
]
|
||||||
|
|
||||||
|
if new_segments:
|
||||||
|
for speaker, text, start, end in new_segments:
|
||||||
|
logger.info(f"Transcribed: [{speaker}] {text}")
|
||||||
|
await self.stt_to_clean_queue.put((speaker, text))
|
||||||
|
self.last_processed_end_time = max(
|
||||||
|
self.last_processed_end_time, end
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"STT Worker error: {e}")
|
logger.error(f"STT Worker error: {e}")
|
||||||
@@ -49,33 +131,101 @@ class PipelineOrchestrator:
|
|||||||
# Small sleep to prevent tight loop if get_chunk is fast
|
# Small sleep to prevent tight loop if get_chunk is fast
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
async def llm_worker(self):
|
async def clean_worker(self):
|
||||||
"""
|
"""
|
||||||
Worker that handles LLM: Text -> Proposal.
|
Worker that handles Text Cleaning: Raw STT -> Filtered Text.
|
||||||
"""
|
"""
|
||||||
logger.info("LLM Worker started.")
|
logger.info("Clean Worker started.")
|
||||||
while self.is_running:
|
while self.is_running:
|
||||||
try:
|
try:
|
||||||
# Get raw text from transcript queue
|
# Get raw transcript from STT
|
||||||
raw_text = await self.transcript_queue.get()
|
speaker, raw_text = await self.stt_to_clean_queue.get()
|
||||||
|
logger.info(f"Clean Worker: Filtering text from {speaker}: {raw_text}")
|
||||||
|
|
||||||
logger.info(f"LLM Worker: Processing text: {raw_text}")
|
# RAG Retrieval for context
|
||||||
|
context = await asyncio.to_thread(self.rag_manager.retrieve, raw_text)
|
||||||
|
|
||||||
# Process via LLM (Filter -> Extract)
|
# Filtering using the processor
|
||||||
# Note: this is currently a synchronous call, which blocks the loop.
|
filter_result = await asyncio.to_thread(
|
||||||
result = self.processor.process_pipeline(raw_text)
|
self.processor.filter_transcript,
|
||||||
|
raw_text,
|
||||||
if (
|
context=context,
|
||||||
result.lore_updates
|
|
||||||
or result.character_updates
|
|
||||||
or result.significant_events
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
f"LLM Worker: Proposal generated. Putting into proposal queue. (Lore: {len(result.lore_updates)}, Char: {len(result.character_updates)})"
|
|
||||||
)
|
)
|
||||||
await self.proposal_queue.put(result)
|
|
||||||
|
# Push filtered text to LLM queue
|
||||||
|
if filter_result.filtered_text:
|
||||||
|
await self.clean_to_llm_queue.put(
|
||||||
|
(speaker, filter_result.filtered_text)
|
||||||
|
)
|
||||||
|
logger.info(f"Clean Worker: Pushed filtered text to LLM queue.")
|
||||||
else:
|
else:
|
||||||
logger.info("LLM Worker: No relevant game data extracted.")
|
logger.info("Clean Worker: No filtered text to push.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Clean Worker error: {e}")
|
||||||
|
|
||||||
|
# Small sleep to prevent tight loop
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
async def llm_worker(self):
    """
    Worker that handles LLM: Filtered Text/UI Input -> Structured Data & UI Updates.

    Two feeder tasks merge ``clean_to_llm_queue`` and ``ui_to_llm_queue`` into a
    single internal queue so that items are processed one at a time. Each item
    is a ``(speaker, text)`` pair; items coming from the UI are tagged with the
    speaker ``"UI"``. ``LoreUpdate`` / ``CharacterStateUpdate`` objects arriving
    from the UI are not sent to the LLM; they are forwarded to
    ``persistence_queue`` instead.

    For each text item the worker retrieves RAG context, logs the text to
    ``log_queue``, runs structured extraction, and pushes the extraction result
    (and then each of its context updates) onto ``llm_to_ui_queue``.
    """
    logger.info("LLM Worker started.")

    # Internal queue to serialize processing from multiple sources
    internal_queue = asyncio.Queue()

    async def feed_clean():
        # Forward filtered transcript items unchanged.
        while self.is_running:
            try:
                item = await self.clean_to_llm_queue.get()
                await internal_queue.put(item)
            except Exception as e:
                logger.error(f"LLM Feeder (Clean) error: {e}")

    async def feed_ui():
        # Route UI traffic: confirmed updates go to persistence,
        # everything else goes to the LLM tagged with the "UI" speaker.
        while self.is_running:
            try:
                item = await self.ui_to_llm_queue.get()
                if isinstance(item, (LoreUpdate, CharacterStateUpdate)):
                    await self.persistence_queue.put(item)
                else:
                    await internal_queue.put(("UI", item))
            except Exception as e:
                logger.error(f"LLM Feeder (UI) error: {e}")

    # Start feeder tasks
    feeders = [
        asyncio.create_task(feed_clean()),
        asyncio.create_task(feed_ui()),
    ]

    try:
        while self.is_running:
            try:
                speaker, text = await internal_queue.get()
                logger.info(f"LLM Worker: Processing text from {speaker}: {text}")

                # RAG Retrieval for context
                context = await asyncio.to_thread(self.rag_manager.retrieve, text)

                # Log the text sent to the LLM for UI affordance
                await self.log_queue.put(f"[{speaker}] {text}")

                # Structured extraction using the processor
                extraction_result = await asyncio.to_thread(
                    self.processor.extract_structured_data,
                    text,
                    context=context,
                )

                # Send the entire result to UI for confirmation
                await self.llm_to_ui_queue.put(extraction_result)

                # UI Notification: Context Updates
                for context_update in extraction_result.context_updates:
                    await self.llm_to_ui_queue.put(context_update)
                    logger.info("LLM Worker: Pushed context update to UI.")

            except Exception as e:
                logger.error(f"LLM Worker error: {e}")

            # Small sleep
            await asyncio.sleep(0.1)
    finally:
        # Clean up feeders. This must also run when this worker is cancelled
        # (CancelledError is not caught by ``except Exception`` above),
        # otherwise the feeder tasks would be left pending.
        for f in feeders:
            f.cancel()
|
||||||
|
|
||||||
|
async def persistence_worker(self):
    """
    Worker that handles persistence: Confirmed updates -> Disk & RAG.

    Takes updates from ``persistence_queue``. A ``LoreUpdate`` is written with
    ``update_lore`` and the returned file path is ingested into the RAG index;
    a ``CharacterStateUpdate`` is applied with ``update_character_state``.
    Blocking persistence calls run in a thread via ``asyncio.to_thread``.

    Every item that was successfully taken from the queue is acknowledged with
    ``task_done()`` even if persisting it fails, so a ``join()`` on the queue
    cannot hang on a failed update.
    """
    logger.info("Persistence Worker started.")
    while self.is_running:
        got_item = False
        try:
            update = await self.persistence_queue.get()
            got_item = True

            if isinstance(update, LoreUpdate):
                file_path = await asyncio.to_thread(update_lore, update)
                await asyncio.to_thread(self.rag_manager.ingest_file, file_path)
                logger.info(
                    f"Persistence Worker: Lore updated and ingested into RAG: {update.entity_name}"
                )
            elif isinstance(update, CharacterStateUpdate):
                await asyncio.to_thread(update_character_state, update)
                logger.info(
                    f"Persistence Worker: Character {update.character_name} state updated."
                )
        except Exception as e:
            logger.error(f"Persistence Worker error: {e}")
        finally:
            # Acknowledge only items we actually received; calling task_done()
            # without a matching get() would raise ValueError on asyncio.Queue.
            if got_item and hasattr(self.persistence_queue, "task_done"):
                self.persistence_queue.task_done()

        await asyncio.sleep(0.1)
|
||||||
|
|
||||||
async def tui_worker(self):
|
async def tui_worker(self):
|
||||||
"""
|
"""
|
||||||
Worker that handles TUI: Proposal -> Persistence.
|
Worker that handles TUI: UI interactions.
|
||||||
"""
|
"""
|
||||||
logger.info("TUI Worker started.")
|
logger.info("TUI Worker started.")
|
||||||
try:
|
try:
|
||||||
# Launch TUI exactly once.
|
# Launch TUI.
|
||||||
# Pass the proposal queue to the app.
|
# Use the new queues for the TUI.
|
||||||
app = ConfirmationApp(proposal_queue=self.proposal_queue)
|
app = ConfirmationApp(
|
||||||
|
ui_to_llm_queue=self.ui_to_llm_queue,
|
||||||
|
llm_to_ui_queue=self.llm_to_ui_queue,
|
||||||
|
log_queue=self.log_queue,
|
||||||
|
)
|
||||||
await app.run_async()
|
await app.run_async()
|
||||||
|
self.stop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"TUI Worker error: {e}")
|
logger.error(f"TUI Worker error: {e}")
|
||||||
|
self.stop()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""
|
"""
|
||||||
@@ -103,10 +292,18 @@ class PipelineOrchestrator:
|
|||||||
self.is_running = True
|
self.is_running = True
|
||||||
self.listener.start()
|
self.listener.start()
|
||||||
|
|
||||||
|
# Initialize Context Pipeline
|
||||||
|
from src.pipeline.context_pipeline import ContextPipeline
|
||||||
|
|
||||||
|
self.context_pipeline = ContextPipeline(self.rag_manager)
|
||||||
|
stop_event = asyncio.Event()
|
||||||
|
|
||||||
# Start workers as background tasks
|
# Start workers as background tasks
|
||||||
tasks = [
|
tasks = [
|
||||||
asyncio.create_task(self.stt_worker()),
|
asyncio.create_task(self.stt_worker()),
|
||||||
|
asyncio.create_task(self.clean_worker()),
|
||||||
asyncio.create_task(self.llm_worker()),
|
asyncio.create_task(self.llm_worker()),
|
||||||
|
asyncio.create_task(self.persistence_worker()),
|
||||||
asyncio.create_task(self.tui_worker()),
|
asyncio.create_task(self.tui_worker()),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -118,6 +315,7 @@ class PipelineOrchestrator:
|
|||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
stop_event.set()
|
||||||
self.listener.stop()
|
self.listener.stop()
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
@@ -127,6 +325,19 @@ class PipelineOrchestrator:
|
|||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""
|
"""
|
||||||
Stops the pipeline.
|
Stops.
|
||||||
"""
|
"""
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import asyncio

    async def main():
        """Build the orchestrator on the running loop and run it to completion."""
        # Inside a coroutine the running loop is the one we want;
        # get_running_loop() never creates a new loop implicitly.
        loop = asyncio.get_running_loop()
        orchestrator = PipelineOrchestrator(loop)
        try:
            await orchestrator.run()
        finally:
            # Runs on normal exit and on cancellation (e.g. Ctrl-C).
            orchestrator.stop()

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run() cancels the main task on Ctrl-C and then re-raises
        # KeyboardInterrupt here, not inside the coroutine.
        pass
|
||||||
|
|||||||
@@ -0,0 +1,193 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
import pdfplumber
|
||||||
|
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
|
||||||
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||||
|
|
||||||
|
from src.llm.models import ContextUpdate
|
||||||
|
from src.llm.processor import LLMProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class RAGManager:
    """
    Thin wrapper around a ChromaDB-backed llama_index vector index.

    Provides ingestion of PDFs and markdown files and retrieval of relevant
    snippets as ``ContextUpdate`` objects, optionally summarized by an LLM.
    """

    def __init__(self, persist_dir: str = "data/rag_index"):
        """
        Open (or create) the persistent Chroma collection and load the index.

        Args:
            persist_dir: Directory used by the persistent Chroma client.
        """
        self.persist_dir = persist_dir
        self.db = chromadb.PersistentClient(path=self.persist_dir)
        self.collection_name = "phb_collection"

        # Initialize Chroma Vector Store
        self.vector_store = ChromaVectorStore(
            chroma_collection=self.db.get_or_create_collection(self.collection_name)
        )

        # Initialize Storage Context
        self.storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        # Use a local HuggingFace embedding model to avoid API key issues during verification.
        # NOTE: this assigns the shared llama_index ``Settings`` object, not a per-instance value.
        Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

        # Load index if it exists, otherwise leave it unset until the first ingest.
        try:
            self.index = VectorStoreIndex.from_vector_store(
                self.vector_store, storage_context=self.storage_context
            )
        except Exception:
            # Best effort: a missing/unloadable index is handled by the
            # ingest_* methods (which build it) and by retrieve() (returns []).
            self.index = None

    def ingest_pdf(self, pdf_path: str, source_label: str = "PHB"):
        """
        Parses a PDF, chunks it per page, and stores embeddings in ChromaDB.

        Args:
            pdf_path: Path of the PDF to ingest.
            source_label: Prefix used in each page's ``source`` metadata
                (``"<source_label> p. <page>"``). Defaults to ``"PHB"``.
        """
        documents = []
        with pdfplumber.open(pdf_path) as pdf:
            for page_number, page in enumerate(pdf.pages, start=1):
                text = page.extract_text()
                if not text:
                    # Pages without extractable text are skipped.
                    continue
                # One document per page; page-level chunking is a simple
                # starting point compared to a recursive character splitter.
                documents.append(
                    Document(
                        text=text,
                        metadata={
                            "source": f"{source_label} p. {page_number}",
                            "page": page_number,
                        },
                    )
                )

        if not documents:
            print(f"No text extracted from {pdf_path}")
            return

        # Create index from documents
        self.index = VectorStoreIndex.from_documents(
            documents, storage_context=self.storage_context
        )
        print(f"Successfully ingested {pdf_path} into the vector store.")

    def ingest_file(self, file_path: str):
        """
        Loads a single markdown file into the index.

        The file's base name is stored as the document's ``source`` metadata.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()

        # Use the filename as the source
        source = os.path.basename(file_path)
        doc = Document(text=text, metadata={"source": source})

        if not self.index:
            # If index doesn't exist, initialize it
            self.index = VectorStoreIndex.from_documents(
                [doc], storage_context=self.storage_context
            )
        else:
            # Insert into existing index
            self.index.insert(doc)

        print(f"Successfully ingested {file_path} into the vector store.")

    def ingest_directory(self, dir_path: str):
        """
        Recursively loads all markdown files (``*.md``) in a directory into the index.
        """
        files_processed = 0
        for root, _, files in os.walk(dir_path):
            for file in files:
                if file.endswith(".md"):
                    self.ingest_file(os.path.join(root, file))
                    files_processed += 1

        print(
            f"Successfully ingested {files_processed} files from {dir_path} into the vector store."
        )

    def summarize_results(self, query: str, nodes: List[Any]) -> List["ContextUpdate"]:
        """
        Uses an LLM to transform raw snippets into concise "insights", filtering out irrelevant content.

        Args:
            query: The query the snippets were retrieved for.
            nodes: Retrieved nodes; each is read via ``.metadata`` and ``.text``.

        Returns:
            One ``ContextUpdate`` per insight returned by the LLM, or an empty
            list when there are no nodes, no insights, or the LLM response
            cannot be parsed into the expected shape.
        """
        if not nodes:
            return []

        import json

        processor = LLMProcessor()

        # Construct the context from retrieved nodes
        context_text = "\n\n".join(
            f"Source: {node.metadata.get('source', 'Unknown')}\nContent: {node.text}"
            for node in nodes
        )

        system_prompt = (
            "You are a precise research assistant. Your task is to analyze provided text snippets "
            "and extract only the information that is directly relevant to the user's query. "
            "1. If a snippet is irrelevant to the query, discard it completely. "
            "2. For relevant information, synthesize it into a concise, single-sentence 'insight'. "
            "3. Do not simply repeat the raw text; summarize it for clarity and brevity. "
            "4. If no snippets are relevant to the query, return an empty list. "
            "5. Be factual and do not hallucinate. Use only the provided snippets."
        )

        user_prompt = (
            f"Query: {query}\n\n"
            f"Snippets:\n{context_text}\n\n"
            "Return a JSON object with a key 'insights' containing a list of objects, each with 'snippet' and 'source'."
        )

        result = processor._call_llm(
            system_prompt,
            user_prompt,
            response_format={"type": "json_object"},
        )

        try:
            data = json.loads(result)
            # Expecting a format like {"insights": [{"snippet": "...", "source": "..."}, ...]}
            insights = data.get("insights", []) if isinstance(data, dict) else data

            if not insights:
                print(f"Summarization: No relevant insights found for query: {query}")
                return []

            return [
                ContextUpdate(
                    query=query, snippet=item["snippet"], source=item["source"]
                )
                for item in insights
            ]
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            print(f"Summarization parsing error: {e}")
            return []

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        summarize: bool = False,
        min_score: float = 0.5,
    ) -> List["ContextUpdate"]:
        """
        Retrieves the top-K most relevant snippets for a given query,
        keeping only those with a similarity score >= ``min_score``.

        Args:
            query: Text to search the index for.
            top_k: Maximum number of nodes requested from the retriever.
            summarize: When true, pass the filtered nodes through
                ``summarize_results`` instead of returning raw snippets.
            min_score: Minimum similarity score a node must have to be kept.
                Nodes whose score is ``None`` are dropped, since they cannot
                be shown to meet the threshold.

        Returns:
            A list of ``ContextUpdate`` objects (empty if the index has not
            been initialized or nothing passes the threshold).
        """
        if not self.index:
            print("Index not initialized. Please ingest documents first.")
            return []

        # Create a retriever
        retriever = self.index.as_retriever(similarity_top_k=top_k)
        nodes = retriever.retrieve(query)

        # Filter nodes by similarity score
        nodes = [
            node
            for node in nodes
            if node.score is not None and node.score >= min_score
        ]

        if summarize:
            return self.summarize_results(query, nodes)

        return [
            ContextUpdate(
                query=query,
                snippet=node.text,
                source=node.metadata.get("source", "Unknown Source"),
            )
            for node in nodes
        ]
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
+1
-1
@@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
# Do not call basicConfig here, as it's called in the orchestrator
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+65
-23
@@ -1,69 +1,111 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from faster_whisper import WhisperModel
|
import numpy as np
|
||||||
|
import whisperx
|
||||||
|
from whisperx.diarize import DiarizationPipeline
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
# Do not call basicConfig here, as it's called in the orchestrator
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Transcriber:
|
class Transcriber:
|
||||||
"""
|
"""
|
||||||
Converts audio chunks (numpy arrays) into text using faster-whisper.
|
Converts audio chunks (numpy arrays) into text and identifies speakers using WhisperX.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_size="base", device="cpu", compute_type="int8"):
|
def __init__(
|
||||||
|
self, model_size="base", device="cpu", compute_type="int8", language="en"
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the faster-whisper model.
|
Initializes the WhisperX model and diarization pipeline.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
|
model_size (str): The size of the model to use (e.g., "tiny", "base", "small").
|
||||||
device (str): The device to run the model on ("cpu" or "cuda").
|
device (str): The device to run the model on ("cpu" or "cuda").
|
||||||
compute_type (str): The compute type to use (e.g., "int8", "float16").
|
compute_type (str): The compute type to use (e.g., "int8", "float16").
|
||||||
|
language (str): The language code for alignment (e.g., "en").
|
||||||
"""
|
"""
|
||||||
|
self.device = device
|
||||||
|
self.compute_type = compute_type
|
||||||
|
self.language = language
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loading faster-whisper model: {model_size} on {device} ({compute_type})..."
|
f"Loading WhisperX model: {model_size} on {device} ({compute_type})..."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self.model = WhisperModel(
|
# Load transcription model
|
||||||
|
self.model = whisperx.load_model(
|
||||||
model_size, device=device, compute_type=compute_type
|
model_size, device=device, compute_type=compute_type
|
||||||
)
|
)
|
||||||
logger.info("Model loaded successfully.")
|
|
||||||
|
# Load alignment model (required for accurate speaker assignment)
|
||||||
|
# model_dir=None allows automatic model selection based on the language
|
||||||
|
self.align_model, self.align_metadata = whisperx.load_align_model(
|
||||||
|
device=device, model_dir=None, language_code=self.language
|
||||||
|
)
|
||||||
|
|
||||||
|
self.diarize_model = DiarizationPipeline()
|
||||||
|
|
||||||
|
logger.info("WhisperX and Diarization models loaded successfully.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load faster-whisper model: {e}")
|
logger.error(f"Failed to load WhisperX models: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def transcribe(self, audio_chunk):
|
def transcribe(self, audio_chunk):
|
||||||
"""
|
"""
|
||||||
Transcribes a single audio chunk.
|
Transcribes an audio chunk and performs speaker diarization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_chunk (np.ndarray): The audio data as a numpy array.
|
audio_chunk (np.ndarray): The audio data as a numpy array.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The transcribed text.
|
list: A list of tuples (speaker_id, text, start, end).
|
||||||
"""
|
"""
|
||||||
if audio_chunk is None:
|
if audio_chunk is None:
|
||||||
return ""
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# faster-whisper expects audio in float32 and 1D array
|
# WhisperX expects audio in float32 and 1D array
|
||||||
audio_data = audio_chunk.astype("float32").flatten()
|
audio = audio_chunk.astype("float32").flatten()
|
||||||
|
|
||||||
# Transcribe the audio
|
# 1. Perform transcription
|
||||||
segments, info = self.model.transcribe(audio_data, beam_size=5)
|
# batch_size is set to 16 for efficiency; can be adjusted based on VRAM
|
||||||
|
result = self.model.transcribe(audio, batch_size=16)
|
||||||
|
|
||||||
# Combine segments into a single string
|
# 2. Perform alignment
|
||||||
text = " ".join([segment.text.strip() for segment in segments])
|
# Alignment is necessary for the assign_words_to_speakers step
|
||||||
|
result_a = whisperx.align(
|
||||||
|
result["segments"],
|
||||||
|
self.align_model,
|
||||||
|
self.align_metadata,
|
||||||
|
audio,
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
|
||||||
return text.strip()
|
# 3. Perform diarization
|
||||||
|
diarize_segments = self.diarize_model(audio)
|
||||||
|
|
||||||
|
# 4. Align transcription segments with speakers
|
||||||
|
result_final = whisperx.assign_word_speakers(diarize_segments, result_a)
|
||||||
|
|
||||||
|
# Extract (speaker_id, text, start, end) tuples from the final result
|
||||||
|
output = []
|
||||||
|
for segment in result_final.get("segments", []):
|
||||||
|
speaker = segment.get("speaker", "Unknown")
|
||||||
|
text = segment.get("text", "").strip()
|
||||||
|
start = segment.get("start", 0.0)
|
||||||
|
end = segment.get("end", 0.0)
|
||||||
|
if text:
|
||||||
|
output.append((speaker, text, start, end))
|
||||||
|
|
||||||
|
return output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Transcription error: {e}")
|
logger.error(f"Transcription/Diarization error: {e}")
|
||||||
return ""
|
return []
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""
|
"""
|
||||||
Explicitly release model resources if necessary.
|
Explicitly release model resources if necessary.
|
||||||
"""
|
"""
|
||||||
# faster-whisper's WhisperModel doesn't have a standard close(),
|
|
||||||
# but we'll provide this for consistency.
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
+236
-231
@@ -3,315 +3,320 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
from textual.app import App, ComposeResult
|
from textual.app import App, ComposeResult
|
||||||
from textual.containers import Container, Horizontal, Vertical
|
from textual.containers import Container, Horizontal, Vertical
|
||||||
from textual.widgets import Button, DataTable, Footer, Input, Label, Static
|
from textual.message import Message
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.widgets import (
|
||||||
|
Button,
|
||||||
|
DataTable,
|
||||||
|
Footer,
|
||||||
|
Input,
|
||||||
|
Label,
|
||||||
|
ListItem,
|
||||||
|
ListView,
|
||||||
|
Static,
|
||||||
|
)
|
||||||
|
|
||||||
from src.llm.models import CharacterStateUpdate, ExtractionResult, LoreUpdate
|
from src.llm.models import CharacterStateUpdate, ContextUpdate, ExtractionResult, LoreUpdate
|
||||||
from src.persistence.characters import update_character_state
|
from src.persistence.characters import update_character_state
|
||||||
from src.persistence.lore import update_lore
|
from src.persistence.lore import update_lore
|
||||||
|
|
||||||
|
|
||||||
|
class EditModal(ModalScreen):
    """
    Modal screen with a single text input plus Save / Cancel buttons.

    On Save the current input value is handed to the ``on_save`` callback and
    the modal is dismissed; on Cancel the modal is dismissed without calling it.
    """

    def __init__(self, initial_text: str, on_save: callable):
        super().__init__()
        # Callback first, then the text used to pre-fill the input.
        self.on_save = on_save
        self.initial_text = initial_text

    def compose(self) -> ComposeResult:
        """Yield the label, the pre-filled input and the action buttons."""
        with Vertical(id="modal-container"):
            yield Label("Edit Fact Content:")
            yield Input(value=self.initial_text, id="edit-input")
            with Horizontal(id="modal-actions"):
                yield Button("Save", id="btn-save")
                yield Button("Cancel", id="btn-cancel")

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle the Save and Cancel buttons; any other button is ignored."""
        pressed_id = event.button.id
        if pressed_id not in ("btn-save", "btn-cancel"):
            return

        if pressed_id == "btn-save":
            current_value = self.query_one("#edit-input", Input).value
            self.on_save(current_value)

        # Both Save and Cancel close the modal.
        self.dismiss()
|
||||||
|
|
||||||
|
|
||||||
class ConfirmationApp(App):
|
class ConfirmationApp(App):
|
||||||
CSS = """
|
CSS = """
|
||||||
Screen {
|
#main-container {
|
||||||
|
layout: vertical;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
#content-wrapper {
|
||||||
layout: horizontal;
|
layout: horizontal;
|
||||||
|
height: 100%;
|
||||||
}
|
}
|
||||||
|
|
||||||
#left-pane {
|
#left-pane {
|
||||||
width: 40%;
|
width: 70%;
|
||||||
border: solid;
|
layout: vertical;
|
||||||
padding: 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#right-pane {
|
#right-pane {
|
||||||
width: 60%;
|
width: 30%;
|
||||||
border: solid;
|
|
||||||
padding: 1;
|
|
||||||
layout: vertical;
|
layout: vertical;
|
||||||
|
border: solid white;
|
||||||
}
|
}
|
||||||
|
|
||||||
#details-container {
|
#pending-facts-table {
|
||||||
height: auto;
|
height: 40%;
|
||||||
margin-bottom: 1;
|
border: solid white;
|
||||||
}
|
}
|
||||||
|
|
||||||
#actions-container {
|
#llm-input-container {
|
||||||
|
height: 10%;
|
||||||
|
border: solid white;
|
||||||
|
padding: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#context-pane {
|
||||||
|
height: 50%;
|
||||||
|
border: solid white;
|
||||||
|
}
|
||||||
|
|
||||||
|
#log-pane {
|
||||||
|
height: 30%;
|
||||||
|
border: solid white;
|
||||||
|
background: #111;
|
||||||
|
}
|
||||||
|
|
||||||
|
#log-footer {
|
||||||
|
height: 70%;
|
||||||
|
border: solid white;
|
||||||
|
}
|
||||||
|
|
||||||
|
#modal-container {
|
||||||
|
width: 60%;
|
||||||
height: auto;
|
height: auto;
|
||||||
layout: horizontal;
|
border: double white;
|
||||||
|
background: #222;
|
||||||
|
padding: 2;
|
||||||
align: center middle;
|
align: center middle;
|
||||||
}
|
}
|
||||||
|
|
||||||
#edit-container {
|
#modal-actions {
|
||||||
display: none;
|
|
||||||
height: auto;
|
height: auto;
|
||||||
layout: vertical;
|
margin-top: 1;
|
||||||
border: solid;
|
align: right middle;
|
||||||
padding: 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Button {
|
#edit-input {
|
||||||
margin: 0 1;
|
margin: 1 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#llm-input {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
ListItem Static {
|
||||||
|
border: solid grey;
|
||||||
|
margin: 1 0;
|
||||||
|
padding: 1;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
BINDINGS = [
|
BINDINGS = [
|
||||||
("q", "quit", "Quit"),
|
("q", "quit", "Quit"),
|
||||||
|
("a", "accept", "Accept"),
|
||||||
|
("r", "reject", "Reject"),
|
||||||
|
("e", "edit", "Edit"),
|
||||||
|
("enter", "send", "Send"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
result: Optional[ExtractionResult] = None,
|
result: Optional[ExtractionResult] = None,
|
||||||
proposal_queue: Optional[asyncio.Queue] = None,
|
ui_to_llm_queue: Optional[asyncio.Queue] = None,
|
||||||
|
llm_to_ui_queue: Optional[asyncio.Queue] = None,
|
||||||
|
log_queue: Optional[asyncio.Queue] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.result = result
|
self.result = result
|
||||||
self.proposal_queue = proposal_queue
|
self.ui_to_llm_queue = ui_to_llm_queue
|
||||||
|
self.llm_to_ui_queue = llm_to_ui_queue
|
||||||
|
self.log_queue = log_queue
|
||||||
self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = []
|
self.pending_updates: List[Union[LoreUpdate, CharacterStateUpdate]] = []
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
# Populate pending updates from result
|
|
||||||
self.pending_updates.extend(result.lore_updates)
|
self.pending_updates.extend(result.lore_updates)
|
||||||
self.pending_updates.extend(result.character_updates)
|
self.pending_updates.extend(result.character_updates)
|
||||||
|
|
||||||
self.selected_index = -1
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
def compose(self) -> ComposeResult:
|
||||||
yield Container(
|
yield Vertical(
|
||||||
Horizontal(
|
Horizontal(
|
||||||
Vertical(
|
Vertical(
|
||||||
DataTable(id="update-table"),
|
DataTable(id="pending-facts-table"),
|
||||||
|
Vertical(
|
||||||
|
Input(placeholder="Message LLM...", id="llm-input"),
|
||||||
|
id="llm-input-container",
|
||||||
|
),
|
||||||
|
ListView(id="context-pane"),
|
||||||
id="left-pane",
|
id="left-pane",
|
||||||
),
|
),
|
||||||
Vertical(
|
Vertical(
|
||||||
Vertical(
|
ListView(id="log-pane"),
|
||||||
Label("Details:", id="details-label"),
|
Static("LATEST LLM INPUTS", id="log-footer"),
|
||||||
Static("No update selected", id="details-text"),
|
|
||||||
id="details-container",
|
|
||||||
),
|
|
||||||
Vertical(
|
|
||||||
Label("Edit Value:"),
|
|
||||||
Input(id="edit-input"),
|
|
||||||
Button("Save Edit", id="save-edit"),
|
|
||||||
id="edit-container",
|
|
||||||
),
|
|
||||||
Horizontal(
|
|
||||||
Button("Accept", id="btn-accept"),
|
|
||||||
Button("Reject", id="btn-reject"),
|
|
||||||
Button("Edit", id="btn-edit"),
|
|
||||||
id="actions-container",
|
|
||||||
),
|
|
||||||
id="right-pane",
|
id="right-pane",
|
||||||
),
|
),
|
||||||
|
id="content-wrapper",
|
||||||
),
|
),
|
||||||
Footer(),
|
id="main-container",
|
||||||
)
|
)
|
||||||
|
yield Footer()
|
||||||
|
|
||||||
def on_mount(self) -> None:
|
def on_mount(self) -> None:
|
||||||
table = self.query_one("#update-table", DataTable)
|
table = self.query_one("#pending-facts-table", DataTable)
|
||||||
table.cursor_type = "row"
|
table.cursor_type = "row"
|
||||||
table.add_columns("Type", "Target", "Update")
|
table.add_columns("Type", "Target", "Content")
|
||||||
|
|
||||||
for i, update in enumerate(self.pending_updates):
|
for i, update in enumerate(self.pending_updates):
|
||||||
if isinstance(update, LoreUpdate):
|
self.add_update_to_table(update, i)
|
||||||
table.add_row(
|
|
||||||
"Lore", update.entity_name or "General", update.content, key=str(i)
|
|
||||||
)
|
|
||||||
elif isinstance(update, CharacterStateUpdate):
|
|
||||||
change_text = f"HP: {update.hp_change or 0}"
|
|
||||||
if update.status_effects_added:
|
|
||||||
change_text += f", Added: {', '.join(update.status_effects_added)}"
|
|
||||||
if update.status_effects_removed:
|
|
||||||
change_text += (
|
|
||||||
f", Removed: {', '.join(update.status_effects_removed)}"
|
|
||||||
)
|
|
||||||
table.add_row("Char", update.character_name, change_text, key=str(i))
|
|
||||||
|
|
||||||
if self.pending_updates:
|
if self.ui_to_llm_queue:
|
||||||
self.handle_row_highlight(0)
|
# We don't need a poller for this, just the action_send
|
||||||
self.query_one("#btn-accept", Button).focus()
|
|
||||||
|
|
||||||
if self.proposal_queue:
|
|
||||||
self.run_worker(self.poll_proposal_queue, thread=False)
|
|
||||||
|
|
||||||
async def poll_proposal_queue(self) -> None:
|
|
||||||
"""
|
|
||||||
Background worker that polls the proposal queue for new extraction results.
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
result = await self.proposal_queue.get()
|
|
||||||
self.add_result(result)
|
|
||||||
except Exception as e:
|
|
||||||
# Log error but keep the worker running
|
|
||||||
self.log(f"Error polling proposal queue: {e}")
|
|
||||||
finally:
|
|
||||||
# Signal that the item has been processed
|
|
||||||
if hasattr(self.proposal_queue, "task_done"):
|
|
||||||
self.proposal_queue.task_done()
|
|
||||||
|
|
||||||
def add_result(self, result: ExtractionResult) -> None:
|
|
||||||
"""
|
|
||||||
Adds results from the LLM processor to the TUI table.
|
|
||||||
"""
|
|
||||||
table = self.query_one("#update-table", DataTable)
|
|
||||||
start_index = len(self.pending_updates)
|
|
||||||
|
|
||||||
for update in result.lore_updates + result.character_updates:
|
|
||||||
self.pending_updates.append(update)
|
|
||||||
actual_index = len(self.pending_updates) - 1
|
|
||||||
|
|
||||||
if isinstance(update, LoreUpdate):
|
|
||||||
table.add_row(
|
|
||||||
"Lore",
|
|
||||||
update.entity_name or "General",
|
|
||||||
update.content,
|
|
||||||
key=str(actual_index),
|
|
||||||
)
|
|
||||||
elif isinstance(update, CharacterStateUpdate):
|
|
||||||
change_text = f"HP: {update.hp_change or 0}"
|
|
||||||
if update.status_effects_added:
|
|
||||||
change_text += f", Added: {', '.join(update.status_effects_added)}"
|
|
||||||
if update.status_effects_removed:
|
|
||||||
change_text += (
|
|
||||||
f", Removed: {', '.join(update.status_effects_removed)}"
|
|
||||||
)
|
|
||||||
table.add_row(
|
|
||||||
"Char", update.character_name, change_text, key=str(actual_index)
|
|
||||||
)
|
|
||||||
|
|
||||||
# If the table was previously empty and we added updates, focus the first one.
|
|
||||||
if start_index == 0 and self.pending_updates:
|
|
||||||
self.handle_row_highlight(0)
|
|
||||||
self.query_one("#btn-accept", Button).focus()
|
|
||||||
|
|
||||||
def on_data_table_row_highlighted(self, event: DataTable.RowHighlighted) -> None:
|
|
||||||
self.handle_row_highlight(event.cursor_row)
|
|
||||||
|
|
||||||
def handle_row_highlight(self, row: int) -> None:
|
|
||||||
self.selected_index = row
|
|
||||||
if self.selected_index < 0 or self.selected_index >= len(self.pending_updates):
|
|
||||||
return
|
|
||||||
|
|
||||||
update = self.pending_updates[self.selected_index]
|
|
||||||
|
|
||||||
details_text = self.query_one("#details-text", Static)
|
|
||||||
if isinstance(update, LoreUpdate):
|
|
||||||
details_text.update(
|
|
||||||
f"Category: {update.category}\nTarget: {update.entity_name}\nContent: {update.content}"
|
|
||||||
)
|
|
||||||
elif isinstance(update, CharacterStateUpdate):
|
|
||||||
details_text.update(
|
|
||||||
f"Character: {update.character_name}\nHP Change: {update.hp_change}\nAdded Effects: {update.status_effects_added}\nRemoved Effects: {update.status_effects_removed}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reset to detail view
|
|
||||||
self.query_one("#edit-container", Vertical).styles.display = "none"
|
|
||||||
self.query_one("#details-container", Vertical).styles.display = "block"
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
    """Apply the pressed button's action to the currently selected update.

    Buttons are dispatched on ``event.button.id``:

    * ``btn-accept``: persist the update (``update_lore`` or
      ``update_character_state``) and remove it from the pending list.
    * ``btn-reject``: remove it without persisting.
    * ``btn-edit``: switch to edit mode for the update.
    * ``save-edit``: store the edited value on the update.

    Nothing happens when no valid row is selected.
    """
    index = self.selected_index
    # handle_row_highlight stores whatever row it is given, so the index can
    # be stale or out of range, not just the -1 "nothing selected" marker.
    # Checking the full range avoids an IndexError, and avoids negative
    # values silently addressing an entry from the end of the list.
    if not 0 <= index < len(self.pending_updates):
        return

    update = self.pending_updates[index]
    button_id = event.button.id

    if button_id == "btn-accept":
        if isinstance(update, LoreUpdate):
            update_lore(update)
        elif isinstance(update, CharacterStateUpdate):
            update_character_state(update)
        self.remove_update(index)
    elif button_id == "btn-reject":
        self.remove_update(index)
    elif button_id == "btn-edit":
        self.show_edit_mode(update)
    elif button_id == "save-edit":
        self.save_edit(update)
|
|
||||||
|
|
||||||
def show_edit_mode(self, update: Union[LoreUpdate, CharacterStateUpdate]) -> None:
    """Pre-fill the ``#edit-input`` field from ``update`` and show the edit pane.

    Lore updates expose their ``content``; character updates expose only
    their HP change (``0`` when unset). The details pane is hidden afterwards.
    """
    editor = self.query_one("#edit-input", Input)

    if isinstance(update, LoreUpdate):
        editor.value = update.content
    elif isinstance(update, CharacterStateUpdate):
        # The TUI deliberately keeps editing simple: HP change only.
        editor.value = str(update.hp_change or 0)

    # Reveal the editor first, then hide the read-only details.
    for selector, mode in (("#edit-container", "block"), ("#details-container", "none")):
        self.query_one(selector, Vertical).styles.display = mode
|
|
||||||
|
|
||||||
def save_edit(self, update: Union[LoreUpdate, CharacterStateUpdate]) -> None:
|
|
||||||
new_val = self.query_one("#edit-input", Input).value
|
|
||||||
|
|
||||||
if isinstance(update, LoreUpdate):
|
|
||||||
update.content = new_val
|
|
||||||
elif isinstance(update, CharacterStateUpdate):
|
|
||||||
try:
|
|
||||||
update.hp_change = int(new_val)
|
|
||||||
except ValueError:
|
|
||||||
# Ignore invalid integer input
|
|
||||||
pass
|
pass
|
||||||
|
if self.llm_to_ui_queue:
|
||||||
|
# Use Textual workers so the task isn't garbage-collected and
|
||||||
|
# exceptions are surfaced via the worker manager.
|
||||||
|
self.run_worker(self.poll_llm_updates(), exclusive=False)
|
||||||
|
if self.log_queue:
|
||||||
|
self.run_worker(self.poll_log_updates(), exclusive=False)
|
||||||
|
|
||||||
# Refresh the table
|
self.query_one("#llm-input", Input).focus()
|
||||||
table = self.query_one("#update-table", DataTable)
|
|
||||||
# Textual DataTable doesn't have a simple 'update_row', so we clear and refill
|
|
||||||
# or we can use update_cell.
|
|
||||||
|
|
||||||
# Update the table row
|
def add_update_to_table(
|
||||||
|
self, update: Union[LoreUpdate, CharacterStateUpdate], index: int
|
||||||
|
) -> None:
|
||||||
|
table = self.query_one("#pending-facts-table", DataTable)
|
||||||
if isinstance(update, LoreUpdate):
|
if isinstance(update, LoreUpdate):
|
||||||
table.update_cell(self.selected_index, 2, update.content)
|
table.add_row(
|
||||||
|
"Lore", update.entity_name or "General", update.content, key=str(index)
|
||||||
|
)
|
||||||
elif isinstance(update, CharacterStateUpdate):
|
elif isinstance(update, CharacterStateUpdate):
|
||||||
change_text = f"HP: {update.hp_change or 0}"
|
change_text = f"HP: {update.hp_change or 0}"
|
||||||
if update.status_effects_added:
|
if update.status_effects_added:
|
||||||
change_text += f", Added: {', '.join(update.status_effects_added)}"
|
change_text += f", Added: {', '.join(update.status_effects_added)}"
|
||||||
if update.status_effects_removed:
|
if update.status_effects_removed:
|
||||||
change_text += f", Removed: {', '.join(update.status_effects_removed)}"
|
change_text += f", Removed: {', '.join(update.status_effects_removed)}"
|
||||||
table.update_cell(self.selected_index, 2, change_text)
|
table.add_row("Char", update.character_name, change_text, key=str(index))
|
||||||
|
|
||||||
self.show_edit_mode(update) # just to refresh the value maybe? No,’
|
async def poll_llm_updates(self) -> None:
|
||||||
# Actually let's go back to detail view
|
while True:
|
||||||
self.query_one("#edit-container", Vertical).styles.display = "none"
|
try:
|
||||||
self.query_one("#details-container", Vertical).styles.display = "block"
|
update = await self.llm_to_ui_queue.get()
|
||||||
|
if isinstance(update, ExtractionResult):
|
||||||
|
self.handle_proposal_result(update)
|
||||||
|
elif isinstance(update, ContextUpdate):
|
||||||
|
display_text = f"Query: {update.query}\nSource: {update.source}\n\n{update.snippet}"
|
||||||
|
context_list = self.query_one("#context-pane", ListView)
|
||||||
|
# Insert at the top to show most recent first.
|
||||||
|
await context_list.insert(0, [ListItem(Static(display_text))])
|
||||||
|
if hasattr(self.llm_to_ui_queue, "task_done"):
|
||||||
|
self.llm_to_ui_queue.task_done()
|
||||||
|
except Exception as e:
|
||||||
|
self.log(f"Error polling LLM updates: {e}")
|
||||||
|
|
||||||
# Update details text
|
async def poll_log_updates(self) -> None:
|
||||||
details_text = self.query_one("#details-text", Static)
|
while True:
|
||||||
|
try:
|
||||||
|
log_text = await self.log_queue.get()
|
||||||
|
log_list = self.query_one("#log-pane", ListView)
|
||||||
|
# See poll_llm_updates: wrap the ListItem in a list.
|
||||||
|
# Insert at the top to show most recent first.
|
||||||
|
await log_list.insert(0, [ListItem(Static(log_text))])
|
||||||
|
if hasattr(self.log_queue, "task_done"):
|
||||||
|
self.log_queue.task_done()
|
||||||
|
except Exception as e:
|
||||||
|
self.log(f"Error polling log updates: {e}")
|
||||||
|
|
||||||
|
def handle_proposal_result(self, result: ExtractionResult) -> None:
    """Queue every update proposed in ``result`` and add a table row for each.

    Lore updates are processed before character updates. Each update is
    appended to ``self.pending_updates`` and passed to
    ``add_update_to_table`` together with its position in that list.
    """
    # The table widget is looked up inside add_update_to_table, so no
    # separate (and previously unused) query is made here.
    for update in result.lore_updates + result.character_updates:
        index = len(self.pending_updates)
        self.pending_updates.append(update)
        self.add_update_to_table(update, index)
|
||||||
|
|
||||||
|
async def poll_context_queue(self) -> None:
    """Obsolete: intentionally does nothing."""
    return None
|
||||||
|
|
||||||
|
async def poll_response_queue(self) -> None:
    """Obsolete: intentionally does nothing."""
    return None
|
||||||
|
|
||||||
|
def on_input_submitted(self, event: Input.Submitted) -> None:
    """Run the send action when the ``llm-input`` field is submitted."""
    if event.input.id != "llm-input":
        return
    self.action_send()
|
||||||
|
|
||||||
|
def action_send(self) -> None:
    """Push the text of ``#llm-input`` onto ``ui_to_llm_queue`` and clear the field.

    Nothing is sent, and the field keeps its text, when the input is empty
    or no outgoing queue is set.
    """
    prompt_box = self.query_one("#llm-input", Input)
    message = prompt_box.value
    if not message or not self.ui_to_llm_queue:
        return

    self.ui_to_llm_queue.put_nowait(message)
    prompt_box.value = ""
|
||||||
|
|
||||||
|
async def action_accept(self) -> None:
    """Accept the update under the table cursor.

    The update is put on ``ui_to_llm_queue`` (when a queue is set) and then
    removed from the pending list. A cursor row outside ``pending_updates``
    is ignored.
    """
    cursor = self.query_one("#pending-facts-table", DataTable).cursor_row
    if not 0 <= cursor < len(self.pending_updates):
        return

    accepted = self.pending_updates[cursor]
    if self.ui_to_llm_queue:
        self.ui_to_llm_queue.put_nowait(accepted)

    self.remove_update(cursor)
|
||||||
|
|
||||||
|
def action_reject(self) -> None:
    """Drop the update under the table cursor from the pending list.

    A cursor row outside ``pending_updates`` is ignored.
    """
    cursor = self.query_one("#pending-facts-table", DataTable).cursor_row
    if 0 <= cursor < len(self.pending_updates):
        self.remove_update(cursor)
|
||||||
|
|
||||||
|
def action_edit(self) -> None:
|
||||||
|
table = self.query_one("#pending-facts-table", DataTable)
|
||||||
|
row_index = table.cursor_row
|
||||||
|
if row_index < 0 or row_index >= len(self.pending_updates):
|
||||||
|
return
|
||||||
|
|
||||||
|
update = self.pending_updates[row_index]
|
||||||
|
initial_text = ""
|
||||||
if isinstance(update, LoreUpdate):
|
if isinstance(update, LoreUpdate):
|
||||||
details_text.update(
|
initial_text = update.content
|
||||||
f"Category: {update.category}\nTarget: {update.entity_name}\nContent: {update.content}"
|
|
||||||
)
|
|
||||||
elif isinstance(update, CharacterStateUpdate):
|
elif isinstance(update, CharacterStateUpdate):
|
||||||
details_text.update(
|
initial_text = str(update.hp_change or 0)
|
||||||
f"Character: {update.character_name}\nHP Change: {update.hp_change}\nAdded Effects: {update.status_effects_added}\nRemoved Effects: {update.status_effects_removed}"
|
|
||||||
)
|
def save_callback(new_text: str):
|
||||||
|
if isinstance(update, LoreUpdate):
|
||||||
|
update.content = new_text
|
||||||
|
elif isinstance(update, CharacterStateUpdate):
|
||||||
|
try:
|
||||||
|
update.hp_change = int(new_text)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Update the table
|
||||||
|
self.refresh_table()
|
||||||
|
|
||||||
|
self.push_screen(EditModal(initial_text, save_callback))
|
||||||
|
|
||||||
def remove_update(self, index: int) -> None:
|
def remove_update(self, index: int) -> None:
|
||||||
# Remove from the pending list
|
|
||||||
del self.pending_updates[index]
|
del self.pending_updates[index]
|
||||||
|
self.refresh_table()
|
||||||
|
|
||||||
# Clear and refill the table
|
def refresh_table(self) -> None:
|
||||||
table = self.query_one("#update-table", DataTable)
|
table = self.query_one("#pending-facts-table", DataTable)
|
||||||
table.clear()
|
table.clear()
|
||||||
|
|
||||||
for i, update in enumerate(self.pending_updates):
|
for i, update in enumerate(self.pending_updates):
|
||||||
if isinstance(update, LoreUpdate):
|
self.add_update_to_table(update, i)
|
||||||
table.add_row(
|
|
||||||
"Lore", update.entity_name or "General", update.content, key=str(i)
|
|
||||||
)
|
|
||||||
elif isinstance(update, CharacterStateUpdate):
|
|
||||||
change_text = f"HP: {update.hp_change or 0}"
|
|
||||||
if update.status_effects_added:
|
|
||||||
change_text += f", Added: {', '.join(update.status_effects_added)}"
|
|
||||||
if update.status_effects_removed:
|
|
||||||
change_text += (
|
|
||||||
f", Removed: {', '.join(update.status_effects_removed)}"
|
|
||||||
)
|
|
||||||
table.add_row("Char", update.character_name, change_text, key=str(i))
|
|
||||||
|
|
||||||
if self.pending_updates:
|
|
||||||
self.handle_row_highlight(0)
|
|
||||||
self.query_one("#btn-accept", Button).focus()
|
|
||||||
else:
|
|
||||||
self.selected_index = -1
|
|
||||||
self.query_one("#details-text", Static).update("All updates processed.")
|
|
||||||
|
|||||||
@@ -0,0 +1,98 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from reportlab.pdfgen import canvas
|
||||||
|
|
||||||
|
from src.rag.manager import RAGManager
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_phb(pdf_path: str):
    """
    Creates a dummy PDF file to simulate the Player's Handbook for verification.

    The PDF has three pages (Fireball, Grappling, General Rules), each with a
    title line followed by one or two lines of rule text.

    Args:
        pdf_path: Destination of the generated PDF. Missing parent
            directories are created.
    """
    print(f"Creating dummy PHB at {pdf_path}...")

    # The target directory may not exist yet (e.g. a fresh checkout), and
    # writing the PDF would then fail, so create it up front.
    parent_dir = os.path.dirname(pdf_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # One entry per page: (title, body lines drawn below the title).
    pages = (
        (
            "Fireball",
            (
                "A bright streak flashes from your pointing finger to a point you choose within range.",
                "Each creature in a 20-foot-radius sphere centered on that point must make a Dexterity saving throw.",
            ),
        ),
        (
            "Grappling",
            (
                "Grappling is a special option available to any attack that hits with a melee attack.",
                "A creature is grappled if it is the target of the grapple attack and the attack hits.",
            ),
        ),
        (
            "General Rules",
            (
                "Combat is resolved in rounds, and each round represents 6 seconds of in-game time.",
            ),
        ),
    )

    c = canvas.Canvas(pdf_path)
    for title, body_lines in pages:
        c.drawString(100, 750, title)
        # Body lines start at y=730 and step down 20 units per line.
        for line_number, text in enumerate(body_lines, start=1):
            c.drawString(100, 750 - 20 * line_number, text)
        c.showPage()

    c.save()
    print(f"Dummy PHB created successfully at {pdf_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag():
    """Build a dummy PHB, ingest it, and check retrieval for two sample queries.

    For each query the retrieved results are printed, then two things are
    asserted: at least one result came back, and the top result's snippet
    contains the expected keyword.
    """
    dummy_pdf = "data/phb_dummy.pdf"
    create_dummy_phb(dummy_pdf)

    # Initialize RAG Manager
    manager = RAGManager(persist_dir="data/rag_index_test")

    # Task 2.2: Ingest PDF
    print("\nTesting Ingestion...")
    manager.ingest_pdf(dummy_pdf)

    # Task 2.3: Retrieve Logic
    print("\nTesting Retrieval...")

    # (query, keyword expected in the top snippet, message if empty, message if mismatch)
    checks = (
        (
            "What is Fireball?",
            "Fireball",
            "Should have retrieved at least one result for 'Fireball'",
            "The top result should contain 'Fireball'",
        ),
        (
            "How does grappling work?",
            "Grappling",
            "Should have retrieved at least one result for 'Grappling'",
            "The top result should contain 'Grappling'",
        ),
    )

    for question, keyword, empty_message, mismatch_message in checks:
        hits = manager.retrieve(question)
        print(f"Query: {question}")
        for hit in hits:
            print(f"Source: {hit.source} | Snippet: {hit.snippet[:100]}...")

        assert len(hits) > 0, empty_message
        assert keyword in hits[0].snippet, mismatch_message

    print("\n✅ RAG Verification Successful!")
|
||||||
|
|
||||||
|
|
||||||
|
# Run the RAG verification when this file is executed as a script.
if __name__ == "__main__":
    test_rag()
|
||||||
Reference in New Issue
Block a user