87 lines
3.0 KiB
Python
87 lines
3.0 KiB
Python
|
|
import os
|
||
|
|
from typing import List, Optional
|
||
|
|
|
||
|
|
import chromadb
|
||
|
|
import pdfplumber
|
||
|
|
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
|
||
|
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||
|
|
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||
|
|
|
||
|
|
from src.llm.models import ContextUpdate
|
||
|
|
|
||
|
|
|
||
|
|
class RAGManager:
    """Ingest PDF pages into a persistent ChromaDB collection and query them.

    Vectors live in the Chroma collection ``phb_collection`` under
    ``persist_dir``. The collection is opened with ``get_or_create_collection``
    on a ``PersistentClient``, so content ingested by an earlier instance is
    visible to :meth:`retrieve` without re-ingesting.
    """

    def __init__(
        self,
        persist_dir: str = "data/rag_index",
        embed_model_name: str = "BAAI/bge-small-en-v1.5",
    ):
        """Open (or create) the persistent vector store and wrap it in an index.

        Args:
            persist_dir: Directory handed to ``chromadb.PersistentClient``.
            embed_model_name: HuggingFace model id used for embeddings. It is
                installed on the global ``llama_index.core.Settings``, so it
                affects every llama_index index in the process.
        """
        self.persist_dir = persist_dir
        self.db = chromadb.PersistentClient(path=self.persist_dir)
        self.collection_name = "phb_collection"

        # Initialize Chroma Vector Store
        self.vector_store = ChromaVectorStore(
            chroma_collection=self.db.get_or_create_collection(self.collection_name)
        )

        # Initialize Storage Context (used when new documents are ingested)
        self.storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        # Use a local HuggingFace embedding model to avoid API key issues.
        # NOTE: this mutates process-wide llama_index Settings.
        Settings.embed_model = HuggingFaceEmbedding(model_name=embed_model_name)

        # Wrap whatever is already stored in the collection. retrieve() treats
        # None as "not ready".
        self.index: Optional[VectorStoreIndex] = None
        try:
            # from_vector_store builds its own StorageContext around the vector
            # store. Also passing storage_context is redundant, and on some
            # llama_index versions it raises TypeError (duplicate keyword).
            self.index = VectorStoreIndex.from_vector_store(self.vector_store)
        except Exception as exc:
            # Startup stays best-effort, but report why instead of hiding it.
            print(f"Could not load existing index from {self.persist_dir}: {exc}")

    def ingest_pdf(self, pdf_path: str, source_label: str = "PHB"):
        """Extract text from each page of a PDF and store its embeddings in ChromaDB.

        Each page whose extracted text is non-blank becomes one ``Document``
        with metadata ``{"source": f"{source_label} p. {n}", "page": n}``, where
        ``n`` is the 1-based page number. ``from_documents`` may apply
        llama_index's default node parser, so a long page could still end up
        as several chunks sharing that metadata.

        Re-ingesting the same file replaces its stored pages instead of adding
        duplicates. Each page gets a stable document id built from the file's
        absolute path and page number, and vectors already stored under that
        id are deleted before the new ones are inserted.

        Args:
            pdf_path: Path to the PDF file to ingest.
            source_label: Prefix for the human-readable ``source`` metadata.
        """
        abs_path = os.path.abspath(pdf_path)
        documents = []
        with pdfplumber.open(pdf_path) as pdf:
            for page_number, page in enumerate(pdf.pages, start=1):
                text = page.extract_text()
                # Skip pages with no usable text (None, "" or whitespace only).
                if not text or not text.strip():
                    continue
                documents.append(
                    Document(
                        id_=f"{abs_path}#page={page_number}",
                        text=text,
                        metadata={
                            "source": f"{source_label} p. {page_number}",
                            "page": page_number,
                        },
                    )
                )

        if not documents:
            print(f"No text extracted from {pdf_path}")
            return

        # The collection is persistent. Without this cleanup, every repeated
        # ingest would add another copy of each page, and retrieve() would
        # return duplicate snippets.
        for doc in documents:
            self.vector_store.delete(doc.id_)

        # Create index from documents (writes through to the same Chroma store)
        self.index = VectorStoreIndex.from_documents(
            documents, storage_context=self.storage_context
        )
        print(f"Successfully ingested {pdf_path} into the vector store.")

    def retrieve(self, query: str, top_k: int = 3) -> List[ContextUpdate]:
        """Return up to ``top_k`` snippets most similar to ``query``.

        Each hit becomes a ``ContextUpdate`` carrying the original query, the
        chunk text and its ``source`` metadata ("Unknown Source" if absent).

        Returns:
            An empty list when no index is available or when ``top_k < 1``.
        """
        if self.index is None:
            print("Index not initialized. Please ingest documents first.")
            return []
        if top_k < 1:
            # Nothing requested; avoid asking the vector store for 0 results.
            return []

        retriever = self.index.as_retriever(similarity_top_k=top_k)
        return [
            ContextUpdate(
                query=query,
                snippet=hit.text,
                source=hit.metadata.get("source", "Unknown Source"),
            )
            for hit in retriever.retrieve(query)
        ]
|