Files
dnd-helpers/tests/test_rag.py
T

99 lines
2.6 KiB
Python
Raw Normal View History

import os
from reportlab.pdfgen import canvas
from src.rag.manager import RAGManager
def create_dummy_phb(pdf_path: str) -> None:
    """
    Creates a dummy PDF file to simulate the Player's Handbook for verification.

    The PDF has three pages (Fireball, Grappling, General Rules). Each page
    carries a heading followed by one or more lines of rules text, so that
    retrieval tests can check that a query lands on the matching page.

    Args:
        pdf_path: Destination path of the generated PDF. Any missing parent
            directories are created first.
    """
    print(f"Creating dummy PHB at {pdf_path}...")

    # Make sure the target directory (e.g. "data/") exists; otherwise
    # canvas.save() fails on a fresh checkout where it was never created.
    parent_dir = os.path.dirname(pdf_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # One entry per page, in page order: (heading, body lines).
    pages = [
        (
            "Fireball",
            [
                "A bright streak flashes from your pointing finger to a point you choose within range.",
                "Each creature in a 20-foot-radius sphere centered on that point must make a Dexterity saving throw.",
            ],
        ),
        (
            "Grappling",
            [
                "Grappling is a special option available to any attack that hits with a melee attack.",
                "A creature is grappled if it is the target of the grapple attack and the attack hits.",
            ],
        ),
        (
            "General Rules",
            [
                "Combat is resolved in rounds, and each round represents 6 seconds of in-game time.",
            ],
        ),
    ]

    c = canvas.Canvas(pdf_path)
    for heading, body_lines in pages:
        c.drawString(100, 750, heading)
        # Body text starts 20pt below the heading and steps down 20pt per line.
        for line_no, text in enumerate(body_lines, start=1):
            c.drawString(100, 750 - 20 * line_no, text)
        c.showPage()
    c.save()
    print(f"Dummy PHB created successfully at {pdf_path}")
def test_rag():
    """
    End-to-end check of RAGManager ingestion and retrieval on a dummy PHB.

    Generates a small PDF and ingests it into a fresh index. For each query
    about Fireball and Grappling, it asserts that at least one result comes
    back and that the top snippet mentions the expected heading.

    The PDF and the persisted index are written to a throwaway temporary
    directory. As a result, the test:
    - does not depend on a pre-existing ``data/`` directory;
    - does not leave files in the working tree;
    - never reuses an index left over from an earlier run.
    """
    import shutil
    import tempfile

    work_dir = tempfile.mkdtemp(prefix="rag_test_")
    try:
        pdf_path = os.path.join(work_dir, "phb_dummy.pdf")
        create_dummy_phb(pdf_path)

        # Initialize RAG Manager on an index directory that does not exist
        # yet, so every run starts from an empty index.
        rag = RAGManager(persist_dir=os.path.join(work_dir, "rag_index_test"))

        # Task 2.2: Ingest PDF
        print("\nTesting Ingestion...")
        rag.ingest_pdf(pdf_path)

        # Task 2.3: Retrieve Logic
        print("\nTesting Retrieval...")
        # (query, keyword the top-ranked snippet must contain)
        cases = [
            ("What is Fireball?", "Fireball"),
            ("How does grappling work?", "Grappling"),
        ]
        for query, keyword in cases:
            results = rag.retrieve(query)
            print(f"Query: {query}")
            for res in results:
                print(f"Source: {res.source} | Snippet: {res.snippet[:100]}...")
            assert len(results) > 0, (
                f"Should have retrieved at least one result for '{keyword}'"
            )
            assert keyword in results[0].snippet, (
                f"The top result should contain '{keyword}'"
            )

        print("\n✅ RAG Verification Successful!")
    finally:
        # Best-effort cleanup: ignore_errors keeps a cleanup failure from
        # masking the real test outcome. NOTE(review): if RAGManager keeps
        # index files open, confirm whether it needs an explicit close.
        shutil.rmtree(work_dir, ignore_errors=True)
if __name__ == "__main__":
    # Executing this file directly runs the same verification pytest collects.
    test_rag()