Implement RAG summarization and context pipeline
- Add ContextPipeline for async RAG lookups - Implement RAG result summarization via LLMProcessor - Add CLI flag for PDF ingestion - Strip markdown code blocks from LLM responses - Update TUI context display to use ListItems
This commit is contained in:
+89
-2
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import chromadb
|
||||
import pdfplumber
|
||||
@@ -8,6 +8,7 @@ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
from src.llm.models import ContextUpdate
|
||||
from src.llm.processor import LLMProcessor
|
||||
|
||||
|
||||
class RAGManager:
|
||||
@@ -64,7 +65,90 @@ class RAGManager:
|
||||
)
|
||||
print(f"Successfully ingested {pdf_path} into the vector store.")
|
||||
|
||||
def retrieve(self, query: str, top_k: int = 3) -> List[ContextUpdate]:
|
||||
def ingest_file(self, file_path: str):
|
||||
"""
|
||||
Loads a single markdown file into the index.
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
# Use the filename as the source
|
||||
source = os.path.basename(file_path)
|
||||
doc = Document(text=text, metadata={"source": source})
|
||||
|
||||
# If index doesn't exist, initialize it
|
||||
if not self.index:
|
||||
self.index = VectorStoreIndex.from_documents(
|
||||
[doc], storage_context=self.storage_context
|
||||
)
|
||||
else:
|
||||
# Insert into existing index
|
||||
self.index.insert(doc)
|
||||
|
||||
print(f"Successfully ingested {file_path} into the vector store.")
|
||||
|
||||
def summarize_results(self, query: str, nodes: List[Any]) -> List[ContextUpdate]:
|
||||
"""
|
||||
Uses an LLM to transform raw snippets into concise "insights", filtering out irrelevant content.
|
||||
"""
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
processor = LLMProcessor()
|
||||
|
||||
# Construct the context from retrieved nodes
|
||||
context_text = "\n\n".join(
|
||||
[
|
||||
f"Source: {node.metadata.get('source', 'Unknown')}\nContent: {node.text}"
|
||||
for node in nodes
|
||||
]
|
||||
)
|
||||
|
||||
system_prompt = (
|
||||
"You are a precise research assistant. Your task is to analyze provided text snippets "
|
||||
"and extract only the information that is directly relevant to the user's query. "
|
||||
"1. If a snippet is irrelevant to the query, discard it completely. "
|
||||
"2. For relevant information, synthesize it into a concise, single-sentence 'insight'. "
|
||||
"3. Do not simply repeat the raw text; summarize it for clarity and brevity. "
|
||||
"4. If no snippets are relevant to the query, return an empty list. "
|
||||
"5. Be factual and do not hallucinate. Use only the provided snippets."
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
f"Query: {query}\n\n"
|
||||
f"Snippets:\n{context_text}\n\n"
|
||||
"Return a JSON object with a key 'insights' containing a list of objects, each with 'snippet' and 'source'."
|
||||
)
|
||||
|
||||
result = processor._call_llm(
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
import json
|
||||
|
||||
try:
|
||||
data = json.loads(result)
|
||||
# Expecting a format like {"insights": [{"snippet": "...", "source": "..."}, ...]}
|
||||
insights = data.get("insights", []) if isinstance(data, dict) else data
|
||||
|
||||
if not insights:
|
||||
print(f"Summarization: No relevant insights found for query: {query}")
|
||||
|
||||
return [
|
||||
ContextUpdate(
|
||||
query=query, snippet=item["snippet"], source=item["source"]
|
||||
)
|
||||
for item in insights
|
||||
]
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||
print(f"Summarization parsing error: {e}")
|
||||
return []
|
||||
|
||||
def retrieve(
|
||||
self, query: str, top_k: int = 5, summarize: bool = False
|
||||
) -> List[ContextUpdate]:
|
||||
"""
|
||||
Retrieves the top-K most relevant snippets for a given query.
|
||||
"""
|
||||
@@ -76,6 +160,9 @@ class RAGManager:
|
||||
retriever = self.index.as_retriever(similarity_top_k=top_k)
|
||||
nodes = retriever.retrieve(query)
|
||||
|
||||
if summarize:
|
||||
return self.summarize_results(query, nodes)
|
||||
|
||||
results = []
|
||||
for node in nodes:
|
||||
# Extract metadata
|
||||
|
||||
Reference in New Issue
Block a user