Source code for omegaml.backends.genai.pgvector

  1import json
  2import re
  3from omegaml.backends.genai.index import VectorStoreBackend
  4from pgvector.sqlalchemy import Vector
  5from sqlalchemy import Column, Integer, String, text, ForeignKey, select
  6from sqlalchemy.orm import Session, relationship
  7
  8
class PGVectorBackend(VectorStoreBackend):
    """
    Vector store backend storing documents and chunk embeddings in PostgreSQL
    with the pgvector extension.

    A development server can be started from the pgvector image, e.g.::

        docker run pgvector/pgvector:pg16

    Each collection uses two tables: ``<collection>_docs`` (one row per
    inserted document: source + JSON-encoded attributes) and
    ``<collection>_docs_chunks`` (one row per chunk text and its embedding,
    linked to its document via ``document_id``).
    """
    KIND = 'pgvector.conx'
    PROMOTE = 'metadata'

    @classmethod
    def supports(cls, obj, name, insert=False, data_store=None, model_store=None,
                 *args, **kwargs):
        """Return a truthy value if *obj* is a ``[sqla+]pgvector://`` URL string."""
        return _is_valid_url(obj)  # or data_store.exists(name)

    def insert_chunks(self, chunks, name, embeddings, attributes, **kwargs):
        """Store one document and its chunks in the collection of dataset *name*.

        Args:
            chunks: iterable of chunk texts
            name: name of the vector store dataset
            embeddings: iterable of embedding vectors, paired positionally with
                *chunks* (zip stops at the shorter of the two)
            attributes: optional dict of document attributes. The ``source``
                key, if present, goes into its own column; the remaining keys
                are stored as a JSON string. The caller's dict is not modified.
            kwargs: passed on to ``_create_collection``
        """
        collection, vector_size, _ = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size, **kwargs)
        # work on a copy so popping 'source' does not mutate the caller's dict
        attributes = dict(attributes or {})
        source = attributes.pop('source', None)
        with self._get_connection(name, session=True) as session:
            doc = Document(source=source, attributes=json.dumps(attributes))
            session.add(doc)
            for chunk_text, embedding in zip(chunks, embeddings):
                session.add(Chunk(document=doc, text=chunk_text, embedding=embedding))
            session.commit()

    def find_similar(self, name, obj, top=5, filter=None, distance=None, **kwargs):
        """Return the stored chunks nearest to the query embedding *obj*.

        Args:
            name: name of the vector store dataset
            obj: query embedding vector
            top: maximum number of results; if not a positive int, all
                chunks are returned
            filter: currently not applied
            distance: ``'l2'`` (default) or ``'cos'``; any other value
                raises KeyError

        Returns:
            list of row mappings with keys ``id``, ``source``, ``attributes``
            (the JSON string as stored), ``text`` and ``distance``, ordered by
            ascending distance
        """
        collection, vector_size, _ = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size)
        metric_map = {
            'l2': lambda vec: Chunk.embedding.l2_distance(vec),
            'cos': lambda vec: Chunk.embedding.cosine_distance(vec),
        }
        distance_expr = metric_map[distance or 'l2'](obj)
        query = (select(Document.id,
                        Document.source,
                        Document.attributes,
                        Chunk.text,
                        distance_expr.label('distance'))
                 .join(Chunk.document)
                 .order_by(distance_expr))
        if isinstance(top, int) and top > 0:
            query = query.limit(top)
        with self._get_connection(name, session=True) as session:
            # .all() fetches every row before the session is closed
            return list(session.execute(query).mappings().all())

    def delete(self, name, obj=None, filter=None, **kwargs):
        """Delete all documents and chunks of the collection of dataset *name*.

        Note: *obj* and *filter* are currently not applied; every row of both
        tables is removed.
        """
        collection, vector_size, _ = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size, **kwargs)
        with self._get_connection(name, session=True) as session:
            # chunks first: they reference documents through a foreign key
            session.query(Chunk).delete()
            session.query(Document).delete()
            session.commit()

    def _create_collection(self, name, collection, vector_size, **kwargs):
        """Build the ORM classes for a collection and create its tables if missing.

        Also runs ``CREATE EXTENSION IF NOT EXISTS vector``.

        Returns:
            tuple (Document, Chunk) of declarative model classes
        """
        from sqlalchemy.orm import declarative_base
        Base = declarative_base()
        collection = collection or self._default_collection(name)
        docs_table = f'{collection}_docs'
        chunks_table = f'{docs_table}_chunks'

        class Document(Base):
            __tablename__ = docs_table
            id = Column(Integer, primary_key=True)
            source = Column(String)
            attributes = Column(String)  # JSON-encoded dict

        class Chunk(Base):
            __tablename__ = chunks_table
            id = Column(Integer, primary_key=True)
            text = Column(String)
            embedding = Column(Vector(vector_size))
            document_id = Column(Integer, ForeignKey(f'{docs_table}.id'))
            document = relationship('Document')

        with self._get_connection(name, session=True) as session:
            session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
            session.commit()
            Base.metadata.create_all(session.get_bind())
            session.commit()
        return Document, Chunk

    def _get_connection(self, name, session=False):
        """Return a new ORM Session (``session=True``) or a Connection for *name*.

        The URL is read from the dataset metadata (``kind_meta['connection']``)
        and translated by ``_sqlalchemy_url``. A new engine is created per call.
        """
        from sqlalchemy import create_engine
        import sqlalchemy
        meta = self.data_store.metadata(name)
        connection_str = self._sqlalchemy_url(meta.kind_meta['connection'])
        # https://docs.sqlalchemy.org/en/14/changelog/migration_20.html
        engine_kwargs = {} if sqlalchemy.__version__.startswith('2.') else dict(future=True)
        engine = create_engine(connection_str, **engine_kwargs)
        if session:
            return Session(bind=engine)
        return engine.connect()

    @staticmethod
    def _sqlalchemy_url(connection_str):
        """Translate a ``[sqla+]pgvector...`` URL into a ``postgresql...`` URL.

        Only the scheme prefix is rewritten, so host, database or credential
        parts containing 'pgvector' or 'sqla+' stay intact, e.g.
        ``sqla+pgvector://u:pw@pgvector:5432/db`` becomes
        ``postgresql://u:pw@pgvector:5432/db``.
        """
        connection_str = re.sub(r'^sqla\+', '', connection_str)
        return re.sub(r'^pgvector', 'postgresql', connection_str)
110 111 112def _is_valid_url(url): 113 return isinstance(url, str) and re.match(r'(sqla\+)?pgvector://', str(url))