Files
dnd-helpers/src/llm/processor.py
T

93 lines
3.2 KiB
Python
Raw Normal View History

import logging
import os
from typing import Any, Dict, Optional

from openai import OpenAI
from pydantic import ValidationError

from .models import ExtractionResult
from .prompts import EXTRACTION_SYSTEM_PROMPT, NOISE_FILTER_SYSTEM_PROMPT
class LLMProcessor:
    """Two-stage LLM pipeline: noise-filter a raw transcript, then extract structured data.

    Both stages go through an ``openai.OpenAI`` client, so any
    OpenAI-compatible endpoint (e.g. vLLM behind ``base_url``) can be used.
    Failures are handled best-effort: they are logged and an empty result
    ("" or ``ExtractionResult()``) is returned instead of raising.
    """

    # Errors are reported through logging rather than printed to stdout.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "gpt-4o",
    ):
        """
        Initializes the LLMProcessor.

        :param api_key: OpenAI API key. If None, falls back to the OPENAI_API_KEY
            environment variable.
        :param base_url: OpenAI-compatible base URL (e.g., for vLLM). If None,
            falls back to the OPENAI_BASE_URL environment variable.
        :param model: Chat model name used for both pipeline stages.
        """
        self.client = OpenAI(
            api_key=api_key or os.environ.get("OPENAI_API_KEY"),
            base_url=base_url or os.environ.get("OPENAI_BASE_URL"),
        )
        self.model = model

    def _call_llm(
        self,
        system_prompt: str,
        user_prompt: str,
        response_format: Optional[Any] = None,
    ) -> str:
        """
        Send one system + user message pair to the chat-completions endpoint.

        :param system_prompt: Text placed in the system message.
        :param user_prompt: Text placed in the user message.
        :param response_format: Optional ``response_format`` payload forwarded to
            the API. When None, the parameter is omitted from the request.
        :return: Text content of the first choice, or "" if the request failed or
            the reply carried no text content.
        """
        request: Dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        }
        # Only forward response_format when the caller supplied one, instead of
        # sending an explicit null for every plain-text call.
        if response_format is not None:
            request["response_format"] = response_format
        try:
            response = self.client.chat.completions.create(**request)
            content = response.choices[0].message.content
        except Exception as exc:
            # Best-effort boundary: callers treat "" as "no output".
            self._logger.exception("LLM Error: %s", exc)
            return ""
        # content may be None (e.g. a refusal or tool-call reply); normalise it
        # so the declared ``str`` return type always holds.
        return content or ""

    def filter_transcript(self, text: str) -> str:
        """
        Stage 1: Raw Transcript -> Filtered Text.

        :param text: Raw transcript text.
        :return: The model's filtered text, or "" if the call failed or returned
            no text content.
        """
        return self._call_llm(NOISE_FILTER_SYSTEM_PROMPT, text)

    def extract_structured_data(self, filtered_text: str) -> ExtractionResult:
        """
        Stage 2: Filtered Text -> Structured Data.

        Uses the SDK's structured-output helper ``beta.chat.completions.parse``,
        passing the ``ExtractionResult`` Pydantic model as ``response_format`` so
        the reply is parsed straight into it. Models or servers without
        structured-output support are expected to fail here; if so, the failure
        is logged and an empty result is returned (manual JSON parsing would be
        the alternative for such backends).

        :param filtered_text: Output of ``filter_transcript``.
        :return: The parsed ``ExtractionResult``, or ``ExtractionResult()`` if the
            request/parsing fails or the model returns no parsed output.
        """
        try:
            completion = self.client.beta.chat.completions.parse(
                model=self.model,
                messages=[
                    {"role": "system", "content": EXTRACTION_SYSTEM_PROMPT},
                    {"role": "user", "content": filtered_text},
                ],
                response_format=ExtractionResult,
            )
            message = completion.choices[0].message
        except Exception as exc:
            self._logger.exception("Extraction Error: %s", exc)
            # Return an empty ExtractionResult if the request or parsing fails.
            return ExtractionResult()
        parsed = message.parsed
        # ``parsed`` is None when the model refuses (``refusal`` is then set) or
        # returns no content; don't leak None through the declared return type.
        if parsed is None:
            self._logger.warning(
                "Extraction Error: no parsed output (refusal=%r)",
                getattr(message, "refusal", None),
            )
            return ExtractionResult()
        return parsed

    def process_pipeline(self, raw_text: str) -> ExtractionResult:
        """
        Executes the two-stage pipeline: Raw Transcript -> Filtered Text -> Structured Data.

        :param raw_text: Unprocessed transcript text.
        :return: Structured data extracted from the transcript, or an empty
            ``ExtractionResult()`` if the input is blank or the filter stage
            produced no text.
        """
        # Nothing to process: skip both network round-trips.
        if not (raw_text and raw_text.strip()):
            return ExtractionResult()
        filtered_text = self.filter_transcript(raw_text)
        # "" signals a failed filter call; whitespace-only output is treated the
        # same way rather than sent to the extractor.
        if not filtered_text.strip():
            return ExtractionResult()
        return self.extract_structured_data(filtered_text)