1from collections import Counter
2
3import json
4import logging
5import re
6from pgvector.sqlalchemy import Vector
7from sqlalchemy import Column, Integer, String, text, ForeignKey, select, Index, LargeBinary, and_
8from sqlalchemy.dialects.postgresql import JSONB
9from sqlalchemy.engine import RowMapping
10from sqlalchemy.orm import Session, relationship, declarative_base
11
12from omegaml.backends.genai.dbmigrate import DatabaseMigrator
13from omegaml.backends.genai.index import VectorStoreBackend
14from omegaml.util import tryOr
15
16logger = logging.getLogger(__name__)
17
18
class PGVectorBackend(VectorStoreBackend):
    """ Vector store backend using PostgreSQL with the pgvector extension.

    Stores documents and their chunk embeddings in two tables per
    collection ('<collection>_docs' and '<collection>_docs_chunks') and
    provides similarity search via pgvector distance operators.

    Usage:
        docker run pgvector/pgvector:pg16
    """
    KIND = 'pgvector.conx'
    PROMOTE = 'metadata'
    # filter keys in find_similar() that are matched against Document.attributes (JSONB)
    _attributes_keys = ('tags', 'source', 'type')

    @classmethod
    def supports(cls, obj, name, insert=False, data_store=None, model_store=None, meta=None, *args, **kwargs):
        """ Return True if obj is storable content (with meta) or a pgvector:// URL """
        valid_types = (meta is not None and isinstance(obj, (str, list, tuple, dict)))
        # bool() ensures a plain bool even if _is_valid_url returns a truthy match object
        return bool(valid_types or _is_valid_url(obj))

    def insert_chunks(self, chunks, name, embeddings, attributes, data=None, **kwargs):
        """ Insert one document and its chunk embeddings into a collection

        Args:
            chunks (list): text chunks of the document
            name (str): the collection name
            embeddings (list): one embedding vector per chunk
            attributes (dict): document attributes; the 'source' and 'tags'
                keys are ensured to exist for later filtering
            data: optional raw document payload stored alongside the chunks
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size, **kwargs)
        with self._get_connection(name, session=True) as session:
            attributes = attributes or {}
            source = attributes.get('source', '')
            # guarantee the keys used by find_similar()/attributes() exist
            attributes.setdefault('source', source)
            attributes.setdefault('tags', [])
            doc = Document(source=source, attributes=attributes, data=data)
            session.add(doc)
            # chunk_text: avoid shadowing the imported sqlalchemy.text
            for chunk_text, embedding in zip(chunks, embeddings):
                session.add(Chunk(document=doc, text=chunk_text, embedding=embedding))
            session.commit()

    def list(self, name):
        """ List all documents inside a collection

        Returns:
            list: row mappings with Document.id, .source, .attributes,
              ordered by source
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size)
        with self._get_connection(name, session=True) as session:
            query = (select(Document.id,
                            Document.source,
                            Document.attributes)
                     .order_by(Document.source))
            result = session.execute(query)
            return list(result.mappings().all())

    def find_similar(self, name, obj, top=5, filter=None, distance=None, max_distance=None, **kwargs):
        """ Find similar documents in a collection based on the provided object.

        Args:
            name (str): The name of the collection to search in.
            obj (list or str): The object to find similar documents for. If a string, it will be embedded.
            top (int): The number of top similar documents to return.
            filter (dict): Optional filter criteria to apply to the search. Filters for the 'tags', 'source', and 'type'
                keys will be applied to the 'attributes' field of the documents. This can be configured by setting
                the '_attributes_keys' class variable to the desired keys or by explicitly passing the filter as the
                'attributes' key in the filter dictionary.
            distance (str): The distance metric to use ('l2' or 'cos'). Defaults to 'l2'.
            max_distance (float): Optional maximum distance to filter results.
            **kwargs: Additional keyword arguments, if filter is not provided, kwargs will be used as
                filter.

        Returns:
            list: A list of dictionaries containing the similar documents and their attributes.
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size)
        METRIC_MAP = {
            'l2': lambda target: Chunk.embedding.l2_distance(target),
            'cos': lambda target: Chunk.embedding.cosine_distance(target),
        }
        metric = distance or 'l2'
        filter = filter or kwargs
        with self._get_connection(name, session=True) as session:
            # raises KeyError for unknown metrics (only 'l2' and 'cos' are supported)
            distance_fn = METRIC_MAP[metric]
            query = (select(Document.id,
                            Document.source,
                            Document.attributes,
                            Chunk.text,
                            distance_fn(obj).label('distance'))
                     .join(Chunk.document))
            # explicit filter['attributes'] wins; otherwise collect known keys from filter
            attributes_filter = filter.get('attributes', None)
            attributes_filter = attributes_filter or {k: v for k, v in (filter or {}).items()
                                                      if k in self._attributes_keys}
            filters = []
            for key, value in attributes_filter.items():
                if isinstance(value, list):
                    # JSONB containment: document attribute must contain all listed values
                    filters.append(Document.attributes.contains({key: value}))
                else:
                    filters.append(Document.attributes[key].astext == str(value))
            if filters:
                query = query.where(and_(*filters))
            if max_distance is not None:
                # same distance expression as the selected 'distance' column
                query = query.where(distance_fn(obj) <= max_distance)
            query = query.order_by('distance')
            logger.debug(f'PGVector query: {query.compile()} {query.compile().params}')
            if isinstance(top, int) and top > 0:
                query = query.limit(top)
            results = session.execute(query)
            return list(results.mappings().all())

    def embeddings(self, name):
        """ List all embeddings inside a collection

        Returns:
            list: row mappings with Chunk.id, .text, .embedding and the
              owning Document.source, .attributes, ordered by source
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size)
        with self._get_connection(name, session=True) as session:
            query = (select(Chunk.id,
                            Chunk.text,
                            Chunk.embedding,
                            Document.source,
                            Document.attributes)
                     .join(Chunk.document)
                     .order_by(Document.source))
            result = session.execute(query)
            return list(result.mappings().all())

    def attributes(self, name, key=None):
        """
        List all attributes inside a collection

        Args:
            name (str): The name of the collection to search in.
            key (str, optional): If provided, filter the attributes by this key.

        Returns:
            dict: A dictionary where keys are attribute names and values are dictionaries of attribute counts
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size)
        with self._get_connection(name, session=True) as session:
            # this failed using sqlalchemy orm, so using raw SQL
            key_clause = 'key = :key_filter' if key else 'TRUE'
            query = text(f"""
            SELECT
                key,
                value,
                COUNT(*) AS count
            FROM
                {Document.__tablename__} AS docs,
                jsonb_each(docs.attributes) AS kv(key, value_array)
            LEFT JOIN LATERAL (
                SELECT
                    value_array::text AS value
                WHERE
                    jsonb_typeof(value_array) = 'string' OR
                    jsonb_typeof(value_array) = 'number' OR
                    jsonb_typeof(value_array) = 'boolean'
                UNION ALL
                SELECT
                    jsonb_array_elements_text(value_array) AS value
                WHERE
                    jsonb_typeof(value_array) = 'array'
            ) AS values ON TRUE
            WHERE
                {key_clause}
            GROUP BY
                key, value
            ORDER BY
                key, value;
            """)
            result = session.execute(query, {'key_filter': key} if key else {})
            rows = result.all()
        results = {}
        # attr_key/attr_value: avoid clobbering the 'key' parameter
        for attr_key, attr_value, count in rows:
            counter = results.setdefault(attr_key, Counter())
            # remove string quotes safely
            # -- rationale: without this value_array::text in above SQL adds quotes to strings, e.g. '' => '""'
            #    json.loads() removes the quotes safely
            attr_value = tryOr(lambda: json.loads(attr_value), attr_value)
            counter[attr_value] += count
        return {k: dict(counts) for k, counts in results.items()}

    def delete(self, name, obj=None, filter=None, drop=False, **kwargs):
        """ Delete documents (and their chunks) from a collection

        Args:
            name (str): the collection name
            obj: a document mapping/dict with an 'id', a document id (int),
                a source (str), or None to delete all documents
            filter: unused, kept for interface compatibility
            drop (bool): if True, drop the collection's tables entirely
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size, **kwargs)
        if drop:
            with self._get_connection(name, session=True) as session:
                # drop chunks first to satisfy the foreign key constraint
                Chunk.__table__.drop(session.get_bind(), checkfirst=False)
                Document.__table__.drop(session.get_bind(), checkfirst=False)
            return
        with self._get_connection(name, session=True) as session:
            # resolve the document ids to delete
            if isinstance(obj, (dict, RowMapping)):
                docs_query = select(Document.id).where(Document.id == obj['id'])
            elif isinstance(obj, int):
                docs_query = select(Document.id).where(Document.id == obj)
            elif isinstance(obj, str):
                docs_query = select(Document.id).where(Document.source == obj)
            else:
                docs_query = select(Document.id)
            doc_ids = session.execute(docs_query).scalars().all()
            if doc_ids:
                session.query(Chunk).filter(Chunk.document_id.in_(doc_ids)).delete(synchronize_session='fetch')
                session.query(Document).filter(Document.id.in_(doc_ids)).delete(synchronize_session='fetch')
            session.commit()

    def _create_collection(self, name, collection, vector_size, **kwargs):
        """ Create (or map) the tables backing a collection

        Ensures the pgvector extension, the docs/chunks tables, pending
        migrations and the HNSW similarity index exist.

        Returns:
            tuple: (Document, Chunk) declarative model classes
        """
        Base = declarative_base()
        collection = collection or self._default_collection(name)
        docs_table = f'{collection}_docs'
        chunks_table = f'{docs_table}_chunks'

        class Document(Base):
            __tablename__ = docs_table
            id = Column(Integer, primary_key=True)
            source = Column(String)
            attributes = Column(JSONB, nullable=True)
            # raw payload; stored as bytea on postgresql, plain string elsewhere
            data = Column('data', String().with_variant(LargeBinary, 'postgresql'))

        class Chunk(Base):
            __tablename__ = chunks_table

            id = Column(Integer, primary_key=True)
            text = Column(String)
            embedding = Column(Vector(vector_size))
            document_id = Column(Integer, ForeignKey(f'{docs_table}.id'))
            document = relationship('Document')

        with self._get_connection(name, session=True) as session:
            session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
            session.commit()
            Base.metadata.create_all(session.get_bind())
            session.commit()
            migrator = DatabaseMigrator(session.connection())
            migrator.run_migrations([Document, Chunk])
        with self._get_connection(name, session=True) as session:
            try:
                # index name must be per-table: postgres index names are schema-scoped,
                # so a single global 'l2_index' name would collide across collections,
                # silently leaving every collection after the first without an ANN index
                index = Index(
                    f'{chunks_table}_l2_index',
                    Chunk.embedding,
                    postgresql_using='hnsw',
                    postgresql_with={'m': 16, 'ef_construction': 64},
                    postgresql_ops={'embedding': 'vector_l2_ops'}
                )
                index.create(session.get_bind(), checkfirst=True)
            except Exception as e:
                # best effort: the index is an optimization, not a requirement;
                # log instead of silently swallowing so failures are diagnosable
                logger.debug('PGVector index creation skipped for %s: %s', chunks_table, e)
            session.commit()
        return Document, Chunk

    def _get_connection(self, name, session=False):
        """ Build an engine from the stored connection string

        Args:
            name (str): the collection name (used to look up metadata)
            session (bool): if True return an ORM Session, else a raw connection

        Returns:
            Session or Connection
        """
        from sqlalchemy import create_engine
        meta = self.data_store.metadata(name)
        connection_str = meta.kind_meta['connection']
        # normalize 'sqla+pgvector://' / 'pgvector://' to a postgresql:// URL
        connection_str = connection_str.replace('pgvector', 'postgresql').replace('sqla+', '')
        import sqlalchemy
        # https://docs.sqlalchemy.org/en/14/changelog/migration_20.html
        kwargs = {} if sqlalchemy.__version__.startswith('2.') else dict(future=True)
        engine = create_engine(connection_str, **kwargs)
        if session:
            connection = Session(bind=engine)
        else:
            connection = engine.connect()
        return connection
288
289
290def _is_valid_url(url):
291 return isinstance(url, str) and re.match(r'(sqla\+)?pgvector://', str(url))