343 lines
12 KiB
Python
343 lines
12 KiB
Python
"""RAG helper — index and search documents using the app's own database.
|
|
|
|
Embeddings are generated via the Druppie platform (module-llm), vector
|
|
storage lives in this app's own Postgres with pgvector. No shared
|
|
database, no cross-app data leakage.
|
|
|
|
Usage:
|
|
from app.rag import RAG
|
|
from app.database import SessionLocal
|
|
from druppie_sdk import DruppieClient
|
|
|
|
druppie = DruppieClient()
|
|
db = SessionLocal()
|
|
|
|
rag = RAG(db, druppie)
|
|
rag.create_index("knowledge-base")
|
|
rag.index_documents("knowledge-base", [
|
|
{"content": "Full text...", "source_name": "Policy 2024", "source_page": 42},
|
|
])
|
|
results = rag.search("knowledge-base", "wat is het beleid?")
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
import uuid
|
|
|
|
import numpy as np
|
|
from pgvector.sqlalchemy import Vector
|
|
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text, func, text
|
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.database import Base
|
|
|
|
# Logger for this module, named after the module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Upper bound on how many texts go into a single embedding request.
EMBEDDING_BATCH_SIZE: int = 100
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SQLAlchemy models (created in the app's own database)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class VectorIndex(Base):
    """A named vector index together with its chunking and embedding settings.

    ``vector_documents`` and ``vector_chunks`` rows point at this table through
    ``index_id`` foreign keys declared with ``ondelete="CASCADE"``.
    """

    __tablename__ = "vector_indices"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Unique name; the RAG helper looks indices up by this column.
    name = Column(String, nullable=False, unique=True)
    description = Column(String, nullable=False, default="")
    # Embedding model name; "default" is the value for which no explicit
    # model is sent to the embedding service.
    embedding_model = Column(String, nullable=False, default="default")
    # Length of the stored embedding vectors; 0 until embeddings are stored.
    dimensions = Column(Integer, nullable=False, default=0)
    # Chunking parameters, measured in characters of the document text.
    chunk_size = Column(Integer, nullable=False, default=2048)
    chunk_overlap = Column(Integer, nullable=False, default=256)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    # Set by the database on insert and by SQLAlchemy (onupdate) on ORM updates.
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
|
|
|
|
|
class VectorDocument(Base):
    """One source document added to an index; its text is stored as chunks.

    Deleting the parent ``vector_indices`` row cascades to this table via the
    ``ondelete="CASCADE"`` foreign key.
    """

    __tablename__ = "vector_documents"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    index_id = Column(UUID(as_uuid=True), ForeignKey("vector_indices.id", ondelete="CASCADE"), nullable=False)
    source_name = Column(String, nullable=False)
    source_type = Column(String, nullable=False, default="text")
    # Number of chunks the document's content was split into.
    chunk_count = Column(Integer, nullable=False, default=0)
    # JSON metadata stored in the "metadata" column; the Python attribute is
    # ``metadata_`` because ``metadata`` is reserved on declarative classes.
    metadata_ = Column("metadata", JSONB, nullable=False, default=dict)
    created_at = Column(DateTime(timezone=True), server_default=func.now())


# Secondary index for looking up the documents of one vector index.
Index("idx_vector_documents_index_id", VectorDocument.index_id)
|
|
|
|
|
|
class VectorChunk(Base):
    """One text chunk of a document together with its embedding vector.

    ``index_id`` is stored alongside ``document_id`` so chunks can be filtered
    by index without joining ``vector_documents``. Both foreign keys cascade
    on delete.
    """

    __tablename__ = "vector_chunks"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    index_id = Column(UUID(as_uuid=True), ForeignKey("vector_indices.id", ondelete="CASCADE"), nullable=False)
    document_id = Column(UUID(as_uuid=True), ForeignKey("vector_documents.id", ondelete="CASCADE"), nullable=False)
    # 0-based position of the chunk within its document.
    chunk_index = Column(Integer, nullable=False)
    content = Column(Text, nullable=False)
    # pgvector column without a fixed dimension in the schema; nullable, and
    # rows with NULL embeddings are skipped by similarity search.
    embedding = Column(Vector())
    # Source fields copied from the originating document dict.
    source_name = Column(String, nullable=False, default="")
    source_page = Column(Integer)
    source_section = Column(String)
    # JSON metadata ("metadata" column); attribute renamed because
    # ``metadata`` is reserved on declarative classes.
    metadata_ = Column("metadata", JSONB, nullable=False, default=dict)
    created_at = Column(DateTime(timezone=True), server_default=func.now())


# Secondary indexes for the per-index and per-document chunk lookups.
Index("idx_vector_chunks_index_id", VectorChunk.index_id)
Index("idx_vector_chunks_document_id", VectorChunk.document_id)
|
|
|
|
_SAFE_KEY_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RAG class
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class RAG:
    """RAG helper that stores vectors in the app's own database.

    Documents are chunked locally, embedded through the Druppie ``llm``
    module (``druppie.call("llm", "embed", ...)``) and stored in the
    ``vector_indices`` / ``vector_documents`` / ``vector_chunks`` tables via
    the SQLAlchemy session given to the constructor. Similarity search uses
    pgvector's cosine-distance operator ``<=>``.
    """

    def __init__(self, db: Session, druppie):
        """Bind the helper to a database session and a Druppie client.

        Args:
            db: SQLAlchemy session used for every query. Writing methods
                commit it; ``index_documents`` rolls it back on failure.
            druppie: Client whose ``call(module, method, kwargs)`` is used to
                request embeddings.
        """
        self._db = db
        self._druppie = druppie

    # -- index management -------------------------------------------------

    def create_index(
        self,
        name: str,
        description: str = "",
        chunk_size: int = 2048,
        chunk_overlap: int = 256,
    ) -> dict:
        """Create the index *name*, or update its settings if it already exists.

        New chunk settings only affect documents indexed afterwards.

        Args:
            name: Unique index name.
            description: Free-text description.
            chunk_size: Approximate chunk length in characters, overlap included.
            chunk_overlap: Characters carried over from the previous chunk.

        Returns:
            ``{"index_id": <uuid as str>, "name": name}``.

        Raises:
            ValueError: if ``chunk_size < 1`` or ``chunk_overlap`` is outside
                ``[0, chunk_size)``; such settings leave ``index_documents``
                with a non-positive split size.
        """
        self._validate_chunking(chunk_size, chunk_overlap)
        idx = self._find_index(name)
        if idx is None:
            idx = VectorIndex(
                name=name,
                description=description,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            self._db.add(idx)
        else:
            idx.description = description
            idx.chunk_size = chunk_size
            idx.chunk_overlap = chunk_overlap
        self._db.commit()
        return {"index_id": str(idx.id), "name": name}

    def index_documents(
        self,
        index_name: str,
        documents: list[dict],
        embedding_model: str = "default",
    ) -> dict:
        """Chunk, embed and store *documents* in the index *index_name*.

        Each document dict requires ``"content"``. Optional keys are
        ``"source_name"`` (default ``"unknown"``), ``"source_type"`` (default
        ``"text"``), ``"source_page"``, ``"source_section"`` and
        ``"metadata"`` (a JSON-serialisable dict).

        The first embeddings stored in an index fix its dimension and record
        *embedding_model* as the model ``search`` uses for queries.

        Everything is written in one transaction. If chunking, the embedding
        call or validation fails, the session is rolled back and the error
        re-raised, so no chunk rows without vectors are left behind.

        Returns:
            ``{"index_name", "documents_indexed", "chunks_created"}``.

        Raises:
            ValueError: if the index does not exist, its chunk settings are
                invalid, or the embedding dimension differs from the index's.
            RuntimeError: if the embedding service returns a wrong number of
                vectors, or empty or inconsistently sized ones.
        """
        idx = self._find_index(index_name)
        if idx is None:
            raise ValueError(f"Index '{index_name}' not found — call create_index first")

        # Split at (chunk_size - overlap): the chunker prepends up to `overlap`
        # characters of the previous chunk, which brings each chunk back to
        # roughly `chunk_size`.
        split_size = idx.chunk_size - idx.chunk_overlap
        if split_size < 1:
            raise ValueError(
                f"Index '{index_name}' has chunk_overlap ({idx.chunk_overlap}) "
                f">= chunk_size ({idx.chunk_size}); recreate it with valid settings"
            )

        try:
            chunks, texts = self._stage_documents(idx, documents, split_size)
            if texts:
                self._attach_embeddings(idx, chunks, texts, embedding_model)
            self._db.commit()
        except Exception:
            # Drop the flushed document/chunk rows: a later commit by the
            # caller must not persist chunks that never received an embedding.
            self._db.rollback()
            raise

        return {
            "index_name": index_name,
            "documents_indexed": len(documents),
            "chunks_created": len(chunks),
        }

    # -- search -------------------------------------------------------------

    def search(
        self,
        index_name: str,
        query: str,
        top_k: int = 5,
        similarity_threshold: float = 0.0,
        filter_metadata: dict | None = None,
    ) -> list[dict]:
        """Return the *top_k* chunks of *index_name* most similar to *query*.

        The score is ``1 - cosine distance`` (pgvector ``<=>``). The query is
        embedded with the model recorded on the index.

        Args:
            index_name: Name of an existing index.
            query: Free-text query.
            top_k: Maximum number of results.
            similarity_threshold: If > 0, drop chunks scoring below it.
            filter_metadata: Exact-match filters on top-level chunk metadata
                keys. Values are compared as text; booleans are compared as
                ``true``/``false``, the text ``->>`` produces.

        Raises:
            ValueError: for an unknown index, an invalid filter key, or a
                query embedding whose dimension differs from the index's.
        """
        idx = self._find_index(index_name)
        if idx is None:
            raise ValueError(f"Index '{index_name}' not found")

        if filter_metadata:
            for key in filter_metadata:
                if not _SAFE_KEY_RE.match(key):
                    raise ValueError(f"Invalid metadata filter key: '{key}'")

        query_embedding = self._get_embeddings([query], idx.embedding_model or "default")[0]
        if idx.dimensions and len(query_embedding) != idx.dimensions:
            raise ValueError(
                f"Query embedding dimension {len(query_embedding)} does not match "
                f"index '{index_name}' dimension {idx.dimensions}"
            )

        # CAST(... AS vector) instead of `:qvec::vector`: text() treats
        # `:name` as a bind parameter and refuses names followed by `:`, so
        # the `::` cast form is not bound reliably.
        distance = "(embedding <=> CAST(:qvec AS vector))"
        conditions = ["index_id = :idx", "embedding IS NOT NULL"]
        params = {"idx": idx.id, "qvec": str(query_embedding), "topk": top_k}

        if similarity_threshold > 0:
            conditions.append(f"1 - {distance} >= :threshold")
            params["threshold"] = similarity_threshold

        for i, (key, value) in enumerate((filter_metadata or {}).items()):
            # Bind the key as well as the value; nothing user-supplied is
            # interpolated into the SQL text.
            conditions.append(f"metadata->>CAST(:fk{i} AS text) = :fv{i}")
            params[f"fk{i}"] = key
            params[f"fv{i}"] = self._filter_text(value)

        sql = (
            "SELECT id, content, source_name, source_page, source_section, "
            f"metadata, chunk_index, 1 - {distance} AS score "
            "FROM vector_chunks "
            f"WHERE {' AND '.join(conditions)} "
            f"ORDER BY {distance} LIMIT :topk"
        )
        rows = self._db.execute(text(sql), params).fetchall()

        return [
            {
                "chunk_id": str(row.id),
                "content": row.content,
                "score": float(row.score),
                "source_name": row.source_name,
                "source_page": row.source_page,
                "source_section": row.source_section,
                "metadata": row.metadata if isinstance(row.metadata, dict) else json.loads(row.metadata or "{}"),
                "chunk_index": row.chunk_index,
            }
            for row in rows
        ]

    def delete_index(self, name: str) -> dict:
        """Delete the index *name*; its documents and chunks go via ON DELETE CASCADE.

        Raises:
            ValueError: if the index does not exist.
        """
        idx = self._find_index(name)
        if idx is None:
            raise ValueError(f"Index '{name}' not found")
        self._db.delete(idx)
        self._db.commit()
        return {"deleted": True, "name": name}

    def list_indices(self) -> list[dict]:
        """Describe every index, newest first."""
        indices = self._db.query(VectorIndex).order_by(VectorIndex.created_at.desc()).all()
        return [
            {
                "name": idx.name,
                "description": idx.description,
                "dimensions": idx.dimensions,
                "chunk_size": idx.chunk_size,
                "chunk_overlap": idx.chunk_overlap,
                "embedding_model": idx.embedding_model,
            }
            for idx in indices
        ]

    # -- internals ------------------------------------------------------------

    def _find_index(self, name: str) -> VectorIndex | None:
        """Return the index row named *name*, or ``None``."""
        return self._db.query(VectorIndex).filter(VectorIndex.name == name).first()

    @staticmethod
    def _validate_chunking(chunk_size: int, chunk_overlap: int) -> None:
        """Raise ValueError unless ``chunk_size >= 1`` and ``0 <= chunk_overlap < chunk_size``."""
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                f"chunk_overlap must be >= 0 and < chunk_size ({chunk_size}), got {chunk_overlap}"
            )

    @staticmethod
    def _filter_text(value) -> str:
        """Render a metadata filter value as the text Postgres ``->>`` would yield."""
        if isinstance(value, bool):  # check before str(): str(True) is "True"
            return "true" if value else "false"
        return str(value)

    def _stage_documents(
        self, idx: VectorIndex, documents: list[dict], split_size: int
    ) -> tuple[list, list[str]]:
        """Add document and chunk rows for *documents*; return (chunks, chunk texts)."""
        chunks: list = []
        texts: list[str] = []
        for doc in documents:
            source_name = doc.get("source_name", "unknown")
            metadata = doc.get("metadata") or {}
            pieces = _chunk_text(doc["content"], split_size, idx.chunk_overlap)

            vdoc = VectorDocument(
                index_id=idx.id,
                source_name=source_name,
                source_type=doc.get("source_type", "text"),
                chunk_count=len(pieces),
                metadata_=metadata,
            )
            self._db.add(vdoc)
            self._db.flush()  # assigns vdoc.id for the chunks' foreign key

            for position, piece in enumerate(pieces):
                chunk = VectorChunk(
                    index_id=idx.id,
                    document_id=vdoc.id,
                    chunk_index=position,
                    content=piece,
                    source_name=source_name,
                    source_page=doc.get("source_page"),
                    source_section=doc.get("source_section"),
                    metadata_=metadata,
                )
                self._db.add(chunk)
                chunks.append(chunk)
                texts.append(piece)

        self._db.flush()
        return chunks, texts

    def _attach_embeddings(
        self, idx: VectorIndex, chunks: list, texts: list[str], embedding_model: str
    ) -> None:
        """Embed *texts* and store each vector on the chunk at the same position.

        Records dimension and model on *idx* for its first embeddings, and
        rejects vectors whose dimension differs from the index's.
        """
        embeddings = self._get_embeddings(texts, embedding_model)
        dimensions = len(embeddings[0])  # non-empty: _get_embeddings checks counts
        if dimensions == 0 or any(len(vec) != dimensions for vec in embeddings):
            raise RuntimeError("Embedding service returned empty or inconsistently sized vectors")

        if idx.dimensions == 0:
            idx.dimensions = dimensions
            idx.embedding_model = embedding_model
        elif dimensions != idx.dimensions:
            raise ValueError(
                f"Embedding dimension {dimensions} does not match index "
                f"'{idx.name}' dimension {idx.dimensions}"
            )
        elif embedding_model != idx.embedding_model:
            logger.warning(
                "Index '%s' was built with embedding model '%s'; indexing with '%s'",
                idx.name,
                idx.embedding_model,
                embedding_model,
            )

        for chunk, vector in zip(chunks, embeddings):
            chunk.embedding = np.array(vector)

    def _get_embeddings(self, texts: list[str], model: str) -> list[list[float]]:
        """Embed *texts* in batches of EMBEDDING_BATCH_SIZE, preserving order.

        *model* is only sent to the service when it is not ``"default"``.

        Raises:
            RuntimeError: if a batch does not come back with one vector per text.
        """
        embeddings: list[list[float]] = []
        for start in range(0, len(texts), EMBEDDING_BATCH_SIZE):
            batch = texts[start:start + EMBEDDING_BATCH_SIZE]
            payload = {"texts": batch}
            if model and model != "default":
                payload["model"] = model
            result = self._druppie.call("llm", "embed", payload)
            vectors = result.get("embeddings") or []
            if len(vectors) != len(batch):
                raise RuntimeError(
                    f"Embedding service returned {len(vectors)} embeddings for {len(batch)} texts"
                )
            embeddings.extend(vectors)
        return embeddings
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Text chunking (recursive split at sentence boundaries with overlap)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _chunk_text(text_content: str, chunk_size: int, chunk_overlap: int) -> list[str]:
|
|
if len(text_content) <= chunk_size:
|
|
return [text_content]
|
|
|
|
separators = ["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " "]
|
|
return _recursive_split(text_content, chunk_size, chunk_overlap, separators)
|
|
|
|
|
|
def _recursive_split(
|
|
text_content: str,
|
|
chunk_size: int,
|
|
chunk_overlap: int,
|
|
separators: list[str],
|
|
) -> list[str]:
|
|
if len(text_content) <= chunk_size:
|
|
return [text_content.strip()] if text_content.strip() else []
|
|
|
|
separator = separators[0] if separators else ""
|
|
remaining = separators[1:] if len(separators) > 1 else []
|
|
|
|
if not separator:
|
|
step = max(1, chunk_size - chunk_overlap)
|
|
parts = [text_content[i:i + chunk_size] for i in range(0, len(text_content), step)]
|
|
return [p.strip() for p in parts if p.strip()]
|
|
|
|
parts = text_content.split(separator)
|
|
chunks = []
|
|
current = ""
|
|
|
|
for part in parts:
|
|
candidate = current + separator + part if current else part
|
|
if len(candidate) <= chunk_size:
|
|
current = candidate
|
|
else:
|
|
if current.strip():
|
|
chunks.append(current.strip())
|
|
if len(part) > chunk_size and remaining:
|
|
chunks.extend(_recursive_split(part, chunk_size, chunk_overlap, remaining))
|
|
current = ""
|
|
else:
|
|
current = part
|
|
|
|
if current.strip():
|
|
chunks.append(current.strip())
|
|
|
|
if chunk_overlap > 0 and len(chunks) > 1:
|
|
result = [chunks[0]]
|
|
for i in range(1, len(chunks)):
|
|
prev_tail = chunks[i - 1][-chunk_overlap:] if len(chunks[i - 1]) > chunk_overlap else chunks[i - 1]
|
|
result.append(prev_tail + " " + chunks[i])
|
|
chunks = result
|
|
|
|
return chunks
|