Source code for omegaml.backends.genai.pgvector

  1from collections import Counter
  2
  3import json
  4import logging
  5import re
  6from pgvector.sqlalchemy import Vector
  7from sqlalchemy import Column, Integer, String, text, ForeignKey, select, Index, LargeBinary, and_
  8from sqlalchemy.dialects.postgresql import JSONB
  9from sqlalchemy.engine import RowMapping
 10from sqlalchemy.orm import Session, relationship, declarative_base
 11
 12from omegaml.backends.genai.dbmigrate import DatabaseMigrator
 13from omegaml.backends.genai.index import VectorStoreBackend
 14from omegaml.util import tryOr
 15
 16logger = logging.getLogger(__name__)
 17
 18
class PGVectorBackend(VectorStoreBackend):
    """ Vector store backend for PostgreSQL with the pgvector extension

    Usage:
        # run a pgvector-enabled PostgreSQL, e.g.
        docker run pgvector/pgvector:pg16
    """
    KIND = 'pgvector.conx'
    PROMOTE = 'metadata'
    # filter keys accepted by find_similar() at the top level; they are
    # matched against the Document.attributes JSONB column
    _attributes_keys = ('tags', 'source', 'type')

    @classmethod
    def supports(cls, obj, name, insert=False, data_store=None, model_store=None, meta=None, *args, **kwargs):
        # supported for document-like objects on an existing collection (meta
        # present), or when obj is a (sqla+)pgvector:// connection string
        valid_types = (meta is not None and isinstance(obj, (str, list, tuple, dict)))
        return valid_types or _is_valid_url(obj)

    def insert_chunks(self, chunks, name, embeddings, attributes, data=None, **kwargs):
        """ Insert a document and its chunk embeddings into a collection

        Args:
            chunks (list): the chunk texts of the document
            name (str): the collection name
            embeddings (list): one embedding vector per chunk
            attributes (dict): optional document attributes; the 'source' and
                'tags' keys are defaulted if missing
            data: optional raw document payload stored with the document
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size, **kwargs)
        Session = self._get_connection(name, session=True)
        with Session as session:
            attributes = attributes or {}
            source = attributes.get('source', '')
            attributes.setdefault('source', source)
            attributes.setdefault('tags', [])
            doc = Document(source=source, attributes=attributes, data=data)
            session.add(doc)
            # -- 'chunk_text' avoids shadowing sqlalchemy's text() imported
            #    at module level
            for chunk_text, embedding in zip(chunks, embeddings):
                chunk = Chunk(document=doc, text=chunk_text, embedding=embedding)
                session.add(chunk)
            session.commit()

    def list(self, name):
        """ List all documents inside a collection

        Args:
            name (str): the collection name

        Returns:
            list: mappings of document id, source and attributes,
                ordered by source
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size)
        Session = self._get_connection(name, session=True)
        data = []
        with Session as session:
            query = (select(Document.id,
                            Document.source,
                            Document.attributes)
                     .order_by(Document.source))
            result = session.execute(query)
            data = list(result.mappings().all())
        return data

    def find_similar(self, name, obj, top=5, filter=None, distance=None, max_distance=None, **kwargs):
        """ Find similar documents in a collection based on the provided object.

        Args:
            name (str): The name of the collection to search in.
            obj (list or str): The object to find similar documents for. If a string, it will be embedded.
            top (int): The number of top similar documents to return.
            filter (dict): Optional filter criteria to apply to the search. Filters for the 'tags', 'source', and 'type'
                keys will be applied to the 'attributes' field of the documents. This can be configured by setting
                the '_attributes_keys' class variable to the desired keys or by explicitly passing the filter as the
                'attributes' key in the filter dictionary.
            distance (str): The distance metric to use ('l2' or 'cos'). Defaults to 'l2'.
            max_distance (float): Optional maximum distance to filter results.
            **kwargs: Additional keyword arguments, if filter is not provided, kwargs will be used as
                filter.

        Returns:
            list: A list of dictionaries containing the similar documents and their attributes.
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size)
        Session = self._get_connection(name, session=True)
        METRIC_MAP = {
            'l2': lambda target: Chunk.embedding.l2_distance(target),
            'cos': lambda target: Chunk.embedding.cosine_distance(target),
        }
        metric = distance or 'l2'
        filter = filter or kwargs
        with Session as session:
            distance_fn = METRIC_MAP[metric]
            query = (select(Document.id,
                            Document.source,
                            Document.attributes,
                            Chunk.text,
                            distance_fn(obj).label('distance'))
                     .join(Chunk.document))
            # explicit 'attributes' filter takes precedence over top-level keys
            attributes_filter = filter.get('attributes', None)
            attributes_filter = attributes_filter or {k: v for k, v in (filter or {}).items() if
                                                      k in self._attributes_keys}
            # add attributes filter
            filters = []
            for key, value in attributes_filter.items():
                if isinstance(value, list):
                    # match documents whose attributes JSONB contains all values
                    filters.append(Document.attributes.contains({key: value}))
                else:
                    filters.append(Document.attributes[key].astext == str(value))
            if filters:
                query = query.where(and_(*filters))
            # add max_distance filter if provided
            if max_distance is not None:
                query = query.where(distance_fn(obj) <= max_distance)
            query = query.order_by('distance')
            logger.debug(f'PGVector query: {query.compile()} {query.compile().params}')
            if isinstance(top, int) and top > 0:
                query = query.limit(top)
            results = session.execute(query)
            data = list(results.mappings().all())
        return data

    def embeddings(self, name):
        """ List all embeddings inside a collection

        Args:
            name (str): the collection name

        Returns:
            list: mappings of chunk id, text, embedding and the parent
                document's source and attributes, ordered by source
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size)
        Session = self._get_connection(name, session=True)
        data = []
        with Session as session:
            query = (select(Chunk.id,
                            Chunk.text,
                            Chunk.embedding,
                            Document.source,
                            Document.attributes)
                     .join(Chunk.document)
                     .order_by(Document.source))
            result = session.execute(query)
            data = list(result.mappings().all())
        return data

    def attributes(self, name, key=None):
        """ List all attributes inside a collection

        Args:
            name (str): The name of the collection to search in.
            key (str, optional): If provided, filter the attributes by this key.

        Returns:
            dict: A dictionary where keys are attribute names and values are dictionaries of attribute counts
        """
        collection, vector_size, model = self._get_collection(name)
        Document, Chunk = self._create_collection(name, collection, vector_size)
        Session = self._get_connection(name, session=True)
        data = []
        with Session as session:
            # this failed using sqlalchemy orm, so using raw SQL
            filter = 'key = :key_filter' if key else 'TRUE'
            query = text(f"""
            SELECT
                key,
                value,
                COUNT(*) AS count
            FROM
                {Document.__tablename__} AS docs,
                jsonb_each(docs.attributes) AS kv(key, value_array)
                LEFT JOIN LATERAL (
                    SELECT
                        value_array::text AS value
                    WHERE
                        jsonb_typeof(value_array) = 'string' OR
                        jsonb_typeof(value_array) = 'number' OR
                        jsonb_typeof(value_array) = 'boolean'
                    UNION ALL
                    SELECT
                        jsonb_array_elements_text(value_array) AS value
                    WHERE
                        jsonb_typeof(value_array) = 'array'
                ) AS values ON TRUE
            WHERE
                {filter}
            GROUP BY
                key, value
            ORDER BY
                key, value;
            """)
            result = session.execute(query, {'key_filter': key} if key else {})
            data = list(row for row in result.all())
        results = {}
        # -- 'attr_key' avoids clobbering the 'key' parameter
        for attr_key, value, count in data:
            counter = results.setdefault(attr_key, Counter())
            # remove string quotes safely
            # -- rationale: without this value_array::text in above SQL adds quotes to strings, e.g. '' => '""'
            #    json.loads() removes the quotes safely
            value = tryOr(lambda: json.loads(value), value)
            counter[value] += count
        results = {attr_key: dict(counts) for attr_key, counts in results.items()}
        return results

    def delete(self, name, obj=None, filter=None, drop=False, **kwargs):
        """ Delete documents and their chunks from a collection

        Args:
            name (str): the collection name
            obj: a document mapping with an 'id' key, a document id (int),
                a document source (str), or None to delete all documents
            filter: unused, reserved for future use
            drop (bool): if True, drop the collection's tables instead of
                deleting rows
        """
        collection, vector_size, model = self._get_collection(name)
        Session = self._get_connection(name, session=True)
        Document, Chunk = self._create_collection(name, collection, vector_size, **kwargs)
        if drop:
            with Session as session:
                # NOTE(review): checkfirst=False raises if the tables do not
                # exist -- confirm this is intended
                Chunk.__table__.drop(session.get_bind(), checkfirst=False)
                Document.__table__.drop(session.get_bind(), checkfirst=False)
            return
        with Session as session:
            # resolve the set of document ids to delete
            if isinstance(obj, (dict, RowMapping)):
                docs_query = select(Document.id).where(Document.id == obj['id'])
            elif isinstance(obj, int):
                docs_query = select(Document.id).where(Document.id == obj)
            elif isinstance(obj, str):
                docs_query = select(Document.id).where(Document.source == obj)
            else:
                docs_query = select(Document.id)
            doc_ids = session.execute(docs_query).scalars().all()
            if doc_ids:
                # delete chunks first to satisfy the foreign key constraint
                session.query(Chunk).filter(Chunk.document_id.in_(doc_ids)).delete(synchronize_session='fetch')
                session.query(Document).filter(Document.id.in_(doc_ids)).delete(synchronize_session='fetch')
            session.commit()

    def _create_collection(self, name, collection, vector_size, **kwargs):
        """ Create or map the document and chunk tables for a collection

        Ensures the pgvector extension, the tables, any pending migrations
        and (best effort) an HNSW index on the embedding column exist.

        Returns:
            tuple: the (Document, Chunk) declarative model classes
        """
        Base = declarative_base()
        collection = collection or self._default_collection(name)
        docs_table = f'{collection}_docs'
        chunks_table = f'{docs_table}_chunks'

        class Document(Base):
            __tablename__ = docs_table
            id = Column(Integer, primary_key=True)
            source = Column(String)
            attributes = Column(JSONB, nullable=True)
            data = Column('data', String().with_variant(LargeBinary, 'postgresql'))

        class Chunk(Base):
            __tablename__ = chunks_table

            id = Column(Integer, primary_key=True)
            text = Column(String)
            embedding = Column(Vector(vector_size))
            document_id = Column(Integer, ForeignKey(f'{docs_table}.id'))
            document = relationship('Document')

        with self._get_connection(name, session=True) as session:
            session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
            session.commit()
            Base.metadata.create_all(session.get_bind())
            session.commit()
            migrator = DatabaseMigrator(session.connection())
            migrator.run_migrations([Document, Chunk])
        with self._get_connection(name, session=True) as session:
            try:
                index = Index(
                    'l2_index',
                    Chunk.embedding,
                    postgresql_using='hnsw',
                    postgresql_with={'m': 16, 'ef_construction': 64},
                    postgresql_ops={'embedding': 'vector_l2_ops'}
                )
                index.create(session.get_bind())
            except Exception as e:
                # best effort: the index typically exists already; log the
                # reason instead of silently swallowing the error
                logger.debug(f'PGVector index creation skipped for {chunks_table}: {e}')
            session.commit()
        return Document, Chunk

    def _get_connection(self, name, session=False):
        """ Create a sqlalchemy Session or Connection for a collection

        Args:
            name (str): the collection name
            session (bool): if True return an orm Session, else a Connection

        Returns:
            Session or Connection bound to a newly created engine
        """
        from sqlalchemy import create_engine
        meta = self.data_store.metadata(name)
        connection_str = meta.kind_meta['connection']
        # (sqla+)pgvector:// urls map to postgresql://
        connection_str = connection_str.replace('pgvector', 'postgresql').replace('sqla+', '')
        import sqlalchemy
        # https://docs.sqlalchemy.org/en/14/changelog/migration_20.html
        kwargs = {} if sqlalchemy.__version__.startswith('2.') else dict(future=True)
        engine = create_engine(connection_str, **kwargs)
        if session:
            connection = Session(bind=engine)
        else:
            connection = engine.connect()
        return connection
288 289 290def _is_valid_url(url): 291 return isinstance(url, str) and re.match(r'(sqla\+)?pgvector://', str(url))