Source code for omegaml.backends.genai.mongovector

  1from collections import Counter
  2
  3import re
  4from bson import ObjectId
  5
  6from omegaml.backends.genai.index import VectorStoreBackend
  7from omegaml.util import mongo_compatible
  8
  9
class MongoDBVectorStore(VectorStoreBackend):
    """
    MongoDB vector store for storing documents and their embeddings.

    Documents are stored in a ``vecdb_<name>_docs`` collection, their text
    chunks and embeddings in a ``vecdb_<name>_chunks`` collection. Similarity
    search runs as a MongoDB aggregation pipeline that computes the euclidean
    (l2) distance between each stored embedding and the query embedding.
    """
    KIND = 'vector.conx'
    PROMOTE = 'metadata'

    @classmethod
    def supports(cls, obj, name, insert=False, data_store=None, model_store=None, *args, **kwargs):
        # supports vector:// or vector+mongodb:// connection strings
        return bool(re.match(r'^vector(\+mongodb)?://', str(obj)))

    def _documents(self, name):
        # collection of document metadata (source, attributes), one per insert
        return self.data_store.collection(f'vecdb_{name}_docs')

    def _chunks(self, name):
        # collection of text chunks + embeddings, linked to docs via document_id
        return self.data_store.collection(f'vecdb_{name}_chunks')

    def list(self, name):
        """
        List all documents inside a collection.

        Args:
            name (str): the name of the vector collection

        Returns:
            list: one dict per document with keys 'id', 'source', 'attributes'
        """
        docs = self._documents(name).find({},
                                          {'_id': 1, 'source': 1, 'attributes': 1})
        return [{'id': str(doc['_id']),
                 'source': doc.get('source', ''),
                 'attributes': doc.get('attributes', {})}
                for doc in docs]

    def insert_chunks(self, chunks, name, embeddings, attributes=None, **kwargs):
        """
        Insert a document's chunks and their embeddings.

        Args:
            chunks (list): the text chunks of the document
            name (str): the name of the vector collection
            embeddings (list): one embedding vector per chunk
            attributes (dict, optional): document attributes; 'source'
                defaults to '' and 'tags' to []
        """
        attributes = attributes or {}
        source = attributes.get('source', '')
        attributes.setdefault('source', source)
        attributes.setdefault('tags', [])

        # Insert the document metadata
        doc_id = self._documents(name).insert_one({
            'source': source,
            'attributes': attributes,
        }).inserted_id

        # Insert all chunks and their embeddings in a single round trip
        # (previously one insert_one per chunk)
        chunk_docs = [{
            'document_id': doc_id,
            'text': text,
            'embedding': mongo_compatible(embedding),
        } for text, embedding in zip(chunks, embeddings)]
        if chunk_docs:
            self._chunks(name).insert_many(chunk_docs)

    def find_similar(self, name, obj, top=5, filter=None, distance='l2', max_distance=None, **kwargs):
        """
        Find the chunks most similar to a query embedding.

        Args:
            name (str): the name of the vector collection
            obj (list): the query embedding
            top (int): maximum number of results to return
            filter (dict, optional): mapping of attribute name to a value or
                list of values to match against the document attributes
            distance (str): distance metric; only 'l2' (euclidean) is
                currently implemented, other values are ignored
            max_distance (float, optional): if given, only chunks within this
                distance are returned

        Returns:
            list: chunk documents including 'text', 'source', 'attributes'
                and the computed 'distance', sorted by ascending distance
        """
        obj = mongo_compatible(obj)
        # join each chunk with its parent document to expose source/attributes
        lookup = [
            {
                '$lookup': {
                    'from': self._documents(name).name,
                    'localField': 'document_id',
                    'foreignField': '_id',
                    'as': 'document'
                }
            },
            {
                '$unwind': '$document'
            },
        ]
        if filter:
            # FIX: test each *value* for list-ness, not the filter dict
            # itself (the old isinstance(filter, list) was always False,
            # double-wrapping list values so $in could never match).
            # NOTE(review): all keys end up in a single condition dict, so
            # multiple filter keys are effectively ANDed despite the $or.
            match = [{
                '$match': {
                    '$or': [
                        {f'document.attributes.{key}':
                             {'$in': list(values) if isinstance(values, (list, tuple, set)) else [values]}
                         for key, values in filter.items()}
                    ]
                }
            }]
        else:
            match = []
        # euclidean (l2) distance: sqrt(sum_i (embedding[i] - obj[i])^2)
        project = [{
            '$project': {
                'document_id': 1,
                'text': 1,
                'embedding': 1,
                'source': '$document.source',
                'attributes': '$document.attributes',
                'distance': {
                    '$sqrt': {
                        '$sum': {
                            '$map': {
                                'input': {
                                    '$range': [0, len(obj)],
                                },
                                'as': 'i',
                                'in': {
                                    '$pow': [
                                        {'$subtract': [
                                            {'$arrayElemAt': ['$embedding', '$$i']},
                                            {'$arrayElemAt': [{'$literal': obj}, '$$i']}
                                        ]},
                                        2
                                    ]
                                }
                            }
                        }
                    }
                },
            }
        }]
        # apply the max_distance cut before sorting/limiting; with an
        # ascending sort on the same field the result set is identical to
        # filtering afterwards, but the sort/limit then work on fewer docs
        subset = [
            {
                '$match': {
                    'distance': {'$lte': float(max_distance)}
                }
            }
        ] if max_distance is not None else []
        sort = [
            {
                '$sort': {'distance': 1}
            },
            {
                '$limit': top
            }
        ]
        pipeline = lookup + match + project + subset + sort
        # Execute the aggregation pipeline ($limit already caps the result)
        return list(self._chunks(name).aggregate(pipeline))

    def delete(self, name, obj=None, filter=None, **kwargs):
        """
        Delete documents and their chunks.

        Args:
            name (str): the name of the vector collection
            obj (dict|str|int|float, optional): a dict with an 'id' key, a
                source string, or a numeric id identifying the document(s)
            filter (dict, optional): additional filter on the documents
                collection

        Raises:
            ValueError: if obj is given but not of a supported type
        """
        filter = filter or {}
        if isinstance(obj, dict) and 'id' in obj:
            filter.update({'_id': ObjectId(obj.get('id'))})
        elif isinstance(obj, str):
            filter.update({'source': obj})
        elif isinstance(obj, (int, float)):
            filter.update({'_id': ObjectId(str(obj))})
        elif obj is not None:
            raise ValueError("Object must be a dict with 'id', or a string matching source")
        # FIX: materialize the ids BEFORE deleting the documents; the
        # previous lazy cursor was only iterated after delete_many, by which
        # time the documents were gone and the chunks were left orphaned
        doc_ids = [doc['_id'] for doc in self._documents(name).find(filter, {'_id': 1})]
        self._documents(name).delete_many(filter)
        self._chunks(name).delete_many({
            'document_id': {'$in': doc_ids}
        })

    def attributes(self, name, key=None):
        """
        Get the attributes of the vector store.

        Args:
            name (str): the name of the vector collection
            key (str, optional): If provided, returns attributes for this key only.

        Returns:
            dict: A dictionary of attributes for the vector store, where each value
                is a dictionary of value counts.
        """
        # written as a MongoDB aggregation pipeline so counting happens server-side
        key_filter = key
        pipeline = [
            {
                "$project": {
                    "keyValuePairs": {"$objectToArray": "$attributes"}
                    # Convert attributes to an array of key-value pairs
                }
            },
            {
                "$unwind": "$keyValuePairs"  # Deconstruct the array to output a document for each key-value pair
            },
            {
                "$unwind": "$keyValuePairs.v"  # Deconstruct array values (e.g. 'tags'); scalars pass through
            },
            {
                "$group": {
                    "_id": {
                        "key": "$keyValuePairs.k",  # Group by the key
                        "value": "$keyValuePairs.v"  # Group by the value
                    },
                    "count": {"$sum": 1}  # Count occurrences
                }
            },
            {
                "$project": {
                    "key": "$_id.key",  # Restructure the output
                    "value": "$_id.value",
                    "count": 1,
                    "_id": 0  # Exclude the default _id field
                }
            },
            {
                "$sort": {
                    "key": 1,  # Sort by key
                    "value": 1  # Sort by value
                }
            }
        ]
        data = self._documents(name).aggregate(pipeline)
        results = {}
        for item in data:
            key, value, count = item.get('key'), item.get('value'), item.get('count')
            if key_filter and key != key_filter:
                continue
            counter = results.setdefault(key, Counter())
            counter[value] += count
        results = {key: dict(counts) for key, counts in results.items()}
        return results