vergunningzoeker-4eb1d867/app/rag.py

343 lines
12 KiB
Python

"""RAG helper — index and search documents using the app's own database.
Embeddings are generated via the Druppie platform (module-llm), vector
storage lives in this app's own Postgres with pgvector. No shared
database, no cross-app data leakage.
Usage:
from app.rag import RAG
from app.database import SessionLocal
from druppie_sdk import DruppieClient
druppie = DruppieClient()
db = SessionLocal()
rag = RAG(db, druppie)
rag.create_index("knowledge-base")
rag.index_documents("knowledge-base", [
{"content": "Full text...", "source_name": "Policy 2024", "source_page": 42},
])
results = rag.search("knowledge-base", "wat is het beleid?")
"""
import json
import logging
import re
import uuid
import numpy as np
from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text, func, text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Session
from app.database import Base
# Maximum number of texts sent in a single ``llm.embed`` call
# (RAG._get_embeddings walks its input in slices of this size).
EMBEDDING_BATCH_SIZE: int = 100

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# SQLAlchemy models (created in the app's own database)
# ---------------------------------------------------------------------------
class VectorIndex(Base):
    """A named collection of chunked, embedded documents (one row per index).

    ``chunk_size`` and ``chunk_overlap`` are read by ``RAG.index_documents``
    when splitting text; ``dimensions`` is filled in the first time
    embeddings are stored for the index.
    """

    __tablename__ = "vector_indices"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Unique because the RAG helper looks indices up by name.
    name = Column(String, nullable=False, unique=True)
    description = Column(String, nullable=False, default="")
    # "default" is the sentinel for which RAG._get_embeddings omits the
    # model argument from the embed call.
    embedding_model = Column(String, nullable=False, default="default")
    # 0 until the first batch of embeddings is stored for this index.
    dimensions = Column(Integer, nullable=False, default=0)
    # Both measured in characters (the chunker compares them to len(text)).
    chunk_size = Column(Integer, nullable=False, default=2048)
    chunk_overlap = Column(Integer, nullable=False, default=256)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    # Refreshed by SQLAlchemy on ORM-issued UPDATEs via onupdate.
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
class VectorDocument(Base):
    """One source document added to an index; its text is stored as VectorChunk rows."""

    __tablename__ = "vector_documents"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Deleting the parent index deletes its documents at the database level.
    index_id = Column(UUID(as_uuid=True), ForeignKey("vector_indices.id", ondelete="CASCADE"), nullable=False)
    # RAG.index_documents stores "unknown" when a document has no source_name.
    source_name = Column(String, nullable=False)
    # Free-form label supplied by the caller; "text" if omitted.
    source_type = Column(String, nullable=False, default="text")
    # Number of chunks the document was split into when it was indexed.
    chunk_count = Column(Integer, nullable=False, default=0)
    # Python attribute is ``metadata_`` because ``metadata`` is reserved on
    # declarative classes; the database column is still named "metadata".
    metadata_ = Column("metadata", JSONB, nullable=False, default=dict)
    created_at = Column(DateTime(timezone=True), server_default=func.now())


# Secondary index on vector_documents.index_id.
Index("idx_vector_documents_index_id", VectorDocument.index_id)
class VectorChunk(Base):
    """A single chunk of document text together with its embedding vector."""

    __tablename__ = "vector_chunks"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Stored alongside document_id so search can filter on index_id directly.
    index_id = Column(UUID(as_uuid=True), ForeignKey("vector_indices.id", ondelete="CASCADE"), nullable=False)
    document_id = Column(UUID(as_uuid=True), ForeignKey("vector_documents.id", ondelete="CASCADE"), nullable=False)
    # 0-based position of the chunk within its document.
    chunk_index = Column(Integer, nullable=False)
    content = Column(Text, nullable=False)
    # pgvector column declared without a fixed dimension. Chunks are inserted
    # first and embedded afterwards; search skips rows where this is NULL.
    embedding = Column(Vector())
    # Source fields are copied from the document so search results can cite it.
    source_name = Column(String, nullable=False, default="")
    source_page = Column(Integer)
    source_section = Column(String)
    # Same ``metadata_`` / "metadata" naming as on VectorDocument.
    metadata_ = Column("metadata", JSONB, nullable=False, default=dict)
    created_at = Column(DateTime(timezone=True), server_default=func.now())


# Secondary indexes for per-index and per-document lookups.
# NOTE(review): no approximate-nearest-neighbour index on ``embedding`` is
# declared here, so similarity ordering is computed over every chunk of the
# index (an HNSW/IVFFlat index would also need a fixed vector dimension).
Index("idx_vector_chunks_index_id", VectorChunk.index_id)
Index("idx_vector_chunks_document_id", VectorChunk.document_id)
_SAFE_KEY_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
# ---------------------------------------------------------------------------
# RAG class
# ---------------------------------------------------------------------------
class RAG:
    """RAG helper that stores vectors in the app's own database.

    Text is chunked locally, embedded through the Druppie ``llm.embed`` call
    and persisted via the SQLAlchemy session given to the constructor
    (tables ``vector_indices``, ``vector_documents``, ``vector_chunks``).
    Methods that write commit the session on success.
    """

    def __init__(self, db: Session, druppie):
        """Bind the helper to a database session and a Druppie client.

        Args:
            db: SQLAlchemy session used for every query and commit.
            druppie: Client exposing ``call(module, method, payload)``; only
                ``call("llm", "embed", ...)`` is used.
        """
        self._db = db
        self._druppie = druppie

    def create_index(
        self,
        name: str,
        description: str = "",
        chunk_size: int = 2048,
        chunk_overlap: int = 256,
    ) -> dict:
        """Create index *name*, or update the settings of an existing one.

        New chunk settings only affect documents indexed afterwards; existing
        chunks are not re-split.

        Args:
            name: Unique index name.
            description: Free-form description.
            chunk_size: Approximate chunk length in characters, overlap included.
            chunk_overlap: Characters carried over from the previous chunk.

        Returns:
            ``{"index_id": <uuid str>, "name": name}``.

        Raises:
            ValueError: If ``chunk_size < 1`` or ``chunk_overlap`` is not in
                ``[0, chunk_size)``. ``index_documents`` splits text at
                ``chunk_size - chunk_overlap``, which must stay positive.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be at least 1, got {chunk_size}")
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                f"chunk_overlap must be >= 0 and smaller than chunk_size "
                f"({chunk_size}), got {chunk_overlap}"
            )
        idx = self._find_index(name)
        if idx is None:
            idx = VectorIndex(
                name=name,
                description=description,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            self._db.add(idx)
        else:
            idx.description = description
            idx.chunk_size = chunk_size
            idx.chunk_overlap = chunk_overlap
        self._db.commit()
        return {"index_id": str(idx.id), "name": name}

    def index_documents(
        self,
        index_name: str,
        documents: list[dict],
        embedding_model: str = "default",
    ) -> dict:
        """Chunk, embed and store *documents* in index *index_name*.

        Each document dict requires ``"content"``. The keys ``"source_name"``,
        ``"source_type"``, ``"source_page"``, ``"source_section"`` and
        ``"metadata"`` are optional. All documents are embedded and committed
        together. On any failure the session is rolled back, so no partially
        indexed rows stay pending.

        The first call that stores embeddings records the vector dimension and
        *embedding_model* on the index. ``search`` embeds queries with that
        recorded model.

        Returns:
            ``{"index_name", "documents_indexed", "chunks_created"}``.

        Raises:
            ValueError: If the index does not exist, or the returned vectors
                do not match the dimension recorded on the index.
            KeyError: If a document has no ``"content"``.
            RuntimeError: If the embedding service returns the wrong number
                of vectors.
        """
        idx = self._find_index(index_name)
        if idx is None:
            raise ValueError(f"Index '{index_name}' not found — call create_index first")

        chunk_overlap = idx.chunk_overlap
        # The overlap is prepended to each chunk after splitting, so split at
        # (chunk_size - overlap) to keep the final chunks near chunk_size.
        # max(1, ...) guards rows saved before create_index validated this.
        split_size = max(1, idx.chunk_size - chunk_overlap)

        all_chunks = []
        all_texts = []
        try:
            for doc in documents:
                source_name = doc.get("source_name", "unknown")
                # ``or {}`` so an explicit None still satisfies the NOT NULL column.
                metadata = doc.get("metadata") or {}
                texts = _chunk_text(doc["content"], split_size, chunk_overlap)
                vdoc = VectorDocument(
                    index_id=idx.id,
                    source_name=source_name,
                    source_type=doc.get("source_type", "text"),
                    chunk_count=len(texts),
                    metadata_=metadata,
                )
                self._db.add(vdoc)
                self._db.flush()  # populates vdoc.id for the chunks' foreign key
                for position, chunk_text in enumerate(texts):
                    chunk = VectorChunk(
                        index_id=idx.id,
                        document_id=vdoc.id,
                        chunk_index=position,
                        content=chunk_text,
                        source_name=source_name,
                        source_page=doc.get("source_page"),
                        source_section=doc.get("source_section"),
                        metadata_=metadata,
                    )
                    self._db.add(chunk)
                    all_chunks.append(chunk)
                    all_texts.append(chunk_text)
            self._db.flush()

            if all_texts:
                # _get_embeddings guarantees one vector per text, so zip below
                # cannot silently leave chunks without an embedding.
                embeddings = self._get_embeddings(all_texts, embedding_model)
                self._record_dimensions(idx, embeddings, embedding_model)
                for chunk, vector in zip(all_chunks, embeddings):
                    chunk.embedding = np.array(vector)
            self._db.commit()
        except Exception:
            # Discard flushed-but-uncommitted rows so a later commit by the
            # caller cannot persist a half-indexed batch.
            self._db.rollback()
            raise

        return {
            "index_name": index_name,
            "documents_indexed": len(documents),
            "chunks_created": len(all_chunks),
        }

    def search(
        self,
        index_name: str,
        query: str,
        top_k: int = 5,
        similarity_threshold: float = 0.0,
        filter_metadata: dict | None = None,
    ) -> list[dict]:
        """Return the *top_k* chunks of *index_name* most similar to *query*.

        ``score`` is ``1 - cosine distance`` (pgvector ``<=>``), so higher is
        more similar. Results are ordered by descending score.

        Args:
            index_name: Index to search.
            query: Free-text query, embedded with the index's recorded model.
            top_k: Maximum number of results.
            similarity_threshold: When > 0, drop results scoring below it.
            filter_metadata: Exact-match filters on top-level metadata keys.
                Keys must be identifier-like. A value of ``None`` matches
                chunks where the key is missing or JSON null. Booleans match
                JSON ``true``/``false``. Other values are compared with their
                ``str()`` form.

        Raises:
            ValueError: If the index does not exist or a filter key is not
                identifier-like.
        """
        idx = self._find_index(index_name)
        if idx is None:
            raise ValueError(f"Index '{index_name}' not found")
        if filter_metadata:
            for key in filter_metadata:
                if not _SAFE_KEY_RE.match(key):
                    raise ValueError(f"Invalid metadata filter key: '{key}'")

        # Queries must be embedded with the same model as the stored chunks,
        # otherwise the vectors are not comparable.
        query_embedding = self._get_embeddings([query], idx.embedding_model or "default")[0]

        # CAST(:qvec AS vector) rather than ``:qvec::vector``: in text() a
        # ``::`` cast directly after a bind name is not reliably parsed as
        # the parameter ``qvec``.
        distance = "(embedding <=> CAST(:qvec AS vector))"
        conditions = ["index_id = :idx", "embedding IS NOT NULL"]
        params = {
            "idx": idx.id,
            "qvec": self._vector_literal(query_embedding),
            "topk": top_k,
        }
        if similarity_threshold > 0:
            conditions.append(f"1 - {distance} >= :threshold")
            params["threshold"] = similarity_threshold
        for i, (key, value) in enumerate((filter_metadata or {}).items()):
            # Keys are bound parameters as well; no caller text is spliced
            # into the SQL string.
            key_param = f"fk{i}"
            params[key_param] = key
            if value is None:
                conditions.append(f"metadata->>CAST(:{key_param} AS text) IS NULL")
            else:
                value_param = f"fv{i}"
                conditions.append(f"metadata->>CAST(:{key_param} AS text) = :{value_param}")
                params[value_param] = self._metadata_text(value)

        sql = (
            "SELECT id, content, source_name, source_page, source_section, "
            f"metadata, chunk_index, 1 - {distance} AS score "
            "FROM vector_chunks "
            f"WHERE {' AND '.join(conditions)} "
            f"ORDER BY {distance} LIMIT :topk"
        )
        rows = self._db.execute(text(sql), params).fetchall()
        return [
            {
                "chunk_id": str(row.id),
                "content": row.content,
                "score": float(row.score),
                "source_name": row.source_name,
                "source_page": row.source_page,
                "source_section": row.source_section,
                "metadata": row.metadata if isinstance(row.metadata, dict) else json.loads(row.metadata or "{}"),
                "chunk_index": row.chunk_index,
            }
            for row in rows
        ]

    def delete_index(self, name: str) -> dict:
        """Delete index *name*; its documents and chunks go via ON DELETE CASCADE.

        Raises:
            ValueError: If the index does not exist.
        """
        idx = self._find_index(name)
        if idx is None:
            raise ValueError(f"Index '{name}' not found")
        self._db.delete(idx)
        self._db.commit()
        return {"deleted": True, "name": name}

    def list_indices(self) -> list[dict]:
        """Describe all indices, newest first."""
        indices = self._db.query(VectorIndex).order_by(VectorIndex.created_at.desc()).all()
        return [
            {
                "name": idx.name,
                "description": idx.description,
                "dimensions": idx.dimensions,
                "chunk_size": idx.chunk_size,
                "chunk_overlap": idx.chunk_overlap,
                "embedding_model": idx.embedding_model,
            }
            for idx in indices
        ]

    # -- helpers --------------------------------------------------------------

    def _find_index(self, name: str):
        """Return the VectorIndex row named *name*, or None."""
        return self._db.query(VectorIndex).filter(VectorIndex.name == name).first()

    @staticmethod
    def _record_dimensions(idx, embeddings: list, embedding_model: str) -> None:
        """Validate vector sizes and record dimension/model on a fresh index.

        Raises:
            ValueError: If the vectors have mixed sizes or differ from the
                dimension already recorded on *idx*.
        """
        sizes = {len(vector) for vector in embeddings}
        if len(sizes) != 1:
            raise ValueError(f"Embedding service returned vectors of mixed sizes: {sorted(sizes)}")
        (dimensions,) = sizes
        if idx.dimensions == 0:
            if dimensions > 0:
                idx.dimensions = dimensions
                idx.embedding_model = embedding_model or "default"
        elif dimensions != idx.dimensions:
            raise ValueError(
                f"Embedding dimension {dimensions} does not match index "
                f"'{idx.name}' dimension {idx.dimensions}"
            )

    @staticmethod
    def _vector_literal(vector) -> str:
        """Format *vector* as pgvector text input, e.g. ``[0.1,0.2]``."""
        return "[" + ",".join(str(float(x)) for x in vector) + "]"

    @staticmethod
    def _metadata_text(value) -> str:
        """Render a filter value the way Postgres ``->>`` renders JSON scalars."""
        if isinstance(value, bool):
            return "true" if value else "false"
        return str(value)

    def _get_embeddings(self, texts: list[str], model: str) -> list[list[float]]:
        """Embed *texts* in batches of EMBEDDING_BATCH_SIZE via ``llm.embed``.

        The model argument is omitted when *model* is empty or "default".

        Raises:
            RuntimeError: If a batch does not yield exactly one vector per text.
        """
        all_embeddings = []
        for start in range(0, len(texts), EMBEDDING_BATCH_SIZE):
            batch = texts[start:start + EMBEDDING_BATCH_SIZE]
            payload = {"texts": batch}
            if model and model != "default":
                payload["model"] = model
            result = self._druppie.call("llm", "embed", payload) or {}
            batch_embeddings = result.get("embeddings", [])
            if len(batch_embeddings) != len(batch):
                raise RuntimeError(
                    f"Embedding service returned {len(batch_embeddings)} "
                    f"embeddings for {len(batch)} texts"
                )
            all_embeddings.extend(batch_embeddings)
        return all_embeddings
# ---------------------------------------------------------------------------
# Text chunking (recursive split at sentence boundaries with overlap)
# ---------------------------------------------------------------------------
def _chunk_text(text_content: str, chunk_size: int, chunk_overlap: int) -> list[str]:
if len(text_content) <= chunk_size:
return [text_content]
separators = ["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " "]
return _recursive_split(text_content, chunk_size, chunk_overlap, separators)
def _recursive_split(
text_content: str,
chunk_size: int,
chunk_overlap: int,
separators: list[str],
) -> list[str]:
if len(text_content) <= chunk_size:
return [text_content.strip()] if text_content.strip() else []
separator = separators[0] if separators else ""
remaining = separators[1:] if len(separators) > 1 else []
if not separator:
step = max(1, chunk_size - chunk_overlap)
parts = [text_content[i:i + chunk_size] for i in range(0, len(text_content), step)]
return [p.strip() for p in parts if p.strip()]
parts = text_content.split(separator)
chunks = []
current = ""
for part in parts:
candidate = current + separator + part if current else part
if len(candidate) <= chunk_size:
current = candidate
else:
if current.strip():
chunks.append(current.strip())
if len(part) > chunk_size and remaining:
chunks.extend(_recursive_split(part, chunk_size, chunk_overlap, remaining))
current = ""
else:
current = part
if current.strip():
chunks.append(current.strip())
if chunk_overlap > 0 and len(chunks) > 1:
result = [chunks[0]]
for i in range(1, len(chunks)):
prev_tail = chunks[i - 1][-chunk_overlap:] if len(chunks[i - 1]) > chunk_overlap else chunks[i - 1]
result.append(prev_tail + " " + chunks[i])
chunks = result
return chunks