import json
import re

from omegaml.backends.genai.index import VectorStoreBackend
from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, Integer, String, text, ForeignKey, select
from sqlalchemy.orm import Session, relationship


class PGVectorBackend(VectorStoreBackend):
    """ vector store backend for PostgreSQL with the pgvector extension

    Requires a pgvector-enabled PostgreSQL instance, e.g.::

        docker run pgvector/pgvector:pg16
    """
    KIND = 'pgvector.conx'
    PROMOTE = 'metadata'

    @classmethod
    def supports(cls, obj, name, insert=False, data_store=None, model_store=None, *args, **kwargs):
        # the backend is selected for pgvector:// or sqla+pgvector:// connection strings
        return _is_valid_url(obj)  # or data_store.exists(name)

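    # Example (sketch): connection strings recognized by supports() / _is_valid_url(), e.g.
    #   pgvector://user:password@localhost:5432/mydb
    #   sqla+pgvector://user:password@localhost:5432/mydb
    # user, password, host, port and database shown here are placeholders, not defaults
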
    def insert_chunks(self, chunks, name, embeddings, attributes, **kwargs):
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size, **kwargs)
        Session = self._get_connection(name, session=True)
        with Session as session:
            # store the document once, then one row per chunk with its embedding
            source = (attributes or {}).pop('source', None)
            attributes = json.dumps(attributes or {})
            doc = Document(source=source, attributes=attributes)
            session.add(doc)
            for chunk_text, embedding in zip(chunks, embeddings):
                chunk = Chunk(document=doc, text=chunk_text, embedding=embedding)
                session.add(chunk)
            session.commit()

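    # Example (sketch): inserting pre-computed embeddings; 'backend', 'mydocs' and the
    # vectors are hypothetical placeholders
    #   backend.insert_chunks(chunks=['first chunk', 'second chunk'],
    #                         name='mydocs',
    #                         embeddings=[[0.1, 0.2], [0.3, 0.4]],
    #                         attributes={'source': 'report.pdf', 'lang': 'en'})
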
    def find_similar(self, name, obj, top=5, filter=None, distance=None, **kwargs):
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size)
        Session = self._get_connection(name, session=True)
        METRIC_MAP = {
            'l2': lambda obj: Chunk.embedding.l2_distance(obj),
            'cos': lambda obj: Chunk.embedding.cosine_distance(obj),
        }
        with Session as session:
            distance = METRIC_MAP[distance or 'l2'](obj)
            # rank chunks by distance to the query embedding, nearest first
            chunks = (select(Document.id,
                             Document.source,
                             Document.attributes,
                             Chunk.text,
                             # Chunk.embedding,
                             distance.label('distance'))
                      .join(Chunk.document)
                      .order_by(distance))
            if isinstance(top, int) and top > 0:
                chunks = chunks.limit(top)
            results = session.execute(chunks)
            data = list(results.mappings().all())
        return data

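    # Example (sketch): find_similar() returns one mapping per chunk, ordered by distance
    # (nearest first); the values shown are placeholders
    #   {'id': 1, 'source': 'report.pdf', 'attributes': '{"lang": "en"}',
    #    'text': 'first chunk', 'distance': 0.12}
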
    def delete(self, name, obj=None, filter=None, **kwargs):
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size, **kwargs)
        Session = self._get_connection(name, session=True)
        with Session as session:
            # removes all chunks and documents in the collection; obj and filter are not applied
            session.query(Chunk).delete()
            session.query(Document).delete()
            session.commit()

    def _create_collection(self, name, collection, vector_size, **kwargs):
        from sqlalchemy.orm import declarative_base
        Base = declarative_base()
        collection = collection or self._default_collection(name)
        docs_table = f'{collection}_docs'
        chunks_table = f'{docs_table}_chunks'

        class Document(Base):
            __tablename__ = docs_table
            id = Column(Integer, primary_key=True)
            source = Column(String)
            attributes = Column(String)

        class Chunk(Base):
            __tablename__ = chunks_table
            id = Column(Integer, primary_key=True)
            text = Column(String)
            embedding = Column(Vector(vector_size))
            document_id = Column(Integer, ForeignKey(f'{docs_table}.id'))
            document = relationship('Document')

        Session = self._get_connection(name, session=True)
        with Session as session:
            # ensure the pgvector extension exists before creating tables with vector columns
            session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
            session.commit()
            Base.metadata.create_all(session.get_bind())
            session.commit()
        return Document, Chunk

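    # Example (sketch): for a collection named 'mydocs', _create_collection() creates
    #   mydocs_docs         (id, source, attributes)
    #   mydocs_docs_chunks  (id, text, embedding, document_id)
    # 'mydocs' is a placeholder collection name
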
    def _get_connection(self, name, session=False):
        from sqlalchemy import create_engine
        meta = self.data_store.metadata(name)
        connection_str = meta.kind_meta['connection']
        # translate the stored pgvector:// or sqla+pgvector:// url into a postgresql:// url
        connection_str = connection_str.replace('pgvector', 'postgresql').replace('sqla+', '')
        import sqlalchemy
        # https://docs.sqlalchemy.org/en/14/changelog/migration_20.html
        # on sqlalchemy 1.x, request 2.0-style engine behavior via future=True; on 2.x it is the default
        kwargs = {} if sqlalchemy.__version__.startswith('2.') else dict(future=True)
        engine = create_engine(connection_str, **kwargs)
        if session:
            connection = Session(bind=engine)
        else:
            connection = engine.connect()
        return connection
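
    # Example (sketch): _get_connection() rewrites the stored connection string before
    # creating the engine; credentials and names are placeholders
    #   sqla+pgvector://user:password@localhost:5432/mydb
    #   => postgresql://user:password@localhost:5432/mydb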


def _is_valid_url(url):
    return isinstance(url, str) and re.match(r'(sqla\+)?pgvector://', str(url))