1from collections import Counter
2
3import re
4from bson import ObjectId
5
6from omegaml.backends.genai.index import VectorStoreBackend
7from omegaml.util import mongo_compatible
8
9
class MongoDBVectorStore(VectorStoreBackend):
    """
    MongoDB vector store for storing documents and their embeddings.

    Documents are stored per store ``name`` in two collections:
    ``vecdb_<name>_docs`` holds one metadata record per inserted source,
    ``vecdb_<name>_chunks`` holds the text chunks and their embeddings.
    Similarity search runs as an aggregation pipeline that computes the
    euclidean (l2) distance in-database.
    """
    KIND = 'vector.conx'
    PROMOTE = 'metadata'

    @classmethod
    def supports(cls, obj, name, insert=False, data_store=None, model_store=None, *args, **kwargs):
        # supports vector:// or vector+mongodb:// URIs
        return bool(re.match(r'^vector(\+mongodb)?://', str(obj)))

    def _documents(self, name):
        # collection of document metadata records (source, attributes)
        return self.data_store.collection(f'vecdb_{name}_docs')

    def _chunks(self, name):
        # collection of text chunks with embeddings, linked by document_id
        return self.data_store.collection(f'vecdb_{name}_chunks')

    def list(self, name):
        """
        List all documents inside a collection.

        Args:
            name (str): the name of the vector store

        Returns:
            list: one dict per document, with keys 'id' (str), 'source'
              and 'attributes'
        """
        projection = {'_id': 1, 'source': 1, 'attributes': 1}
        docs = self._documents(name).find({}, projection)
        return [{'id': str(doc['_id']),
                 'source': doc.get('source', ''),
                 'attributes': doc.get('attributes', {})}
                for doc in docs]

    def insert_chunks(self, chunks, name, embeddings, attributes=None, **kwargs):
        """
        Insert a document and its chunks with embeddings.

        Args:
            chunks (list): the text chunks of the document
            name (str): the name of the vector store
            embeddings (list): one embedding vector per chunk
            attributes (dict, optional): document attributes; 'source' and
              'tags' are defaulted to '' and [] respectively
        """
        # work on a copy so the caller's dict is not mutated
        attributes = dict(attributes or {})
        attributes.setdefault('source', '')
        attributes.setdefault('tags', [])
        source = attributes['source']
        # insert the document metadata record
        doc_id = self._documents(name).insert_one({
            'source': source,
            'attributes': attributes,
        }).inserted_id
        # insert all chunks and their embeddings in a single batch
        chunk_docs = [{
            'document_id': doc_id,
            'text': text,
            'embedding': mongo_compatible(embedding),
        } for text, embedding in zip(chunks, embeddings)]
        if chunk_docs:
            self._chunks(name).insert_many(chunk_docs)

    def find_similar(self, name, obj, top=5, filter=None, distance='l2', max_distance=None, **kwargs):
        """
        Find the chunks most similar to a query embedding.

        Args:
            name (str): the name of the vector store
            obj (list): the query embedding vector
            top (int): maximum number of results to return
            filter (dict, optional): mapping of attribute key to a value or
              list of values, matched against the parent document's attributes
            distance (str): distance metric; only euclidean ('l2') is computed
            max_distance (float, optional): drop results beyond this distance

        Returns:
            list: chunk documents with 'source', 'attributes' and 'distance',
              sorted by ascending distance
        """
        obj = mongo_compatible(obj)
        # join each chunk with its parent document to expose source/attributes
        lookup = [
            {
                '$lookup': {
                    'from': self._documents(name).name,
                    'localField': 'document_id',
                    'foreignField': '_id',
                    'as': 'document'
                }
            },
            {
                '$unwind': '$document'
            },
        ]
        if filter:
            # each filter value may be a scalar or a list of accepted values
            # (fix: previously tested isinstance(filter, list), which is always
            #  False for a dict and broke list-valued filters)
            match = [{
                '$match': {
                    '$or': [
                        {f'document.attributes.{key}':
                             {'$in': values if isinstance(values, list) else [values]}
                         for key, values in filter.items()}
                    ]
                }
            }]
        else:
            match = []
        # compute the euclidean distance element-wise, in-database
        project = [{
            '$project': {
                'document_id': 1,
                'text': 1,
                'embedding': 1,
                'source': '$document.source',
                'attributes': '$document.attributes',
                'distance': {
                    '$sqrt': {
                        '$sum': {
                            '$map': {
                                'input': {
                                    '$range': [0, len(obj)],
                                },
                                'as': 'i',
                                'in': {
                                    '$pow': [
                                        {'$subtract': [
                                            {'$arrayElemAt': ['$embedding', '$$i']},
                                            {'$arrayElemAt': [{'$literal': obj}, '$$i']}
                                        ]},
                                        2
                                    ]
                                }
                            }
                        }
                    }
                },
            }
        }]
        sort = [
            {
                '$sort': {'distance': 1}
            },
            {
                '$limit': top
            }
        ]
        subset = [
            {
                '$match': {
                    'distance': {'$lte': float(max_distance)}
                }
            }
        ] if max_distance is not None else []
        pipeline = lookup + match + project + sort + subset
        # execute the aggregation pipeline; the [0:top] slice is a safeguard
        # only, the $limit stage already bounds the result size
        results = list(self._chunks(name).aggregate(pipeline))[0:top]
        return results

    def delete(self, name, obj=None, filter=None, **kwargs):
        """
        Delete documents and their chunks.

        Args:
            name (str): the name of the vector store
            obj: a dict with an 'id' key, a source string, or a numeric id;
              if None, only ``filter`` applies (all documents by default)
            filter (dict, optional): additional mongodb filter on documents

        Raises:
            ValueError: if obj is of an unsupported type
        """
        # copy so the caller's filter dict is not mutated
        filter = dict(filter or {})
        if isinstance(obj, dict) and 'id' in obj:
            filter.update({'_id': ObjectId(obj.get('id'))})
        elif isinstance(obj, str):
            filter.update({'source': obj})
        elif isinstance(obj, (int, float)):
            filter.update({'_id': ObjectId(str(obj))})
        elif obj is not None:
            raise ValueError("Object must be a dict with 'id', or a string matching source")
        # materialize the matching ids BEFORE deleting the documents -- a lazy
        # cursor evaluated after delete_many() would yield no ids, leaving the
        # chunks orphaned (this was the previous behavior)
        doc_ids = [doc['_id'] for doc in self._documents(name).find(filter, {'_id': 1})]
        self._documents(name).delete_many(filter)
        self._chunks(name).delete_many({
            'document_id': {'$in': doc_ids}
        })

    def attributes(self, name, key=None):
        """
        Get the attributes of the vector store.

        Args:
            key (str, optional): If provided, returns attributes for this key only.

        Returns:
            dict: A dictionary of attributes for the vector store, where each
              value is a dictionary of value counts.
        """
        # count attribute values in-database via an aggregation pipeline
        key_filter = key
        pipeline = [
            {
                "$project": {
                    # convert attributes to an array of key-value pairs
                    "keyValuePairs": {"$objectToArray": "$attributes"}
                }
            },
            {
                # one output document per key-value pair
                "$unwind": "$keyValuePairs"
            },
            {
                # deconstruct list-valued attributes (e.g. 'tags') into
                # one document per element
                "$unwind": "$keyValuePairs.v"
            },
            {
                "$group": {
                    "_id": {
                        "key": "$keyValuePairs.k",  # group by the key
                        "value": "$keyValuePairs.v"  # and by the value
                    },
                    "count": {"$sum": 1}  # count occurrences
                }
            },
            {
                "$project": {
                    "key": "$_id.key",  # restructure the output
                    "value": "$_id.value",
                    "count": 1,
                    "_id": 0  # exclude the default _id field
                }
            },
            {
                "$sort": {
                    "key": 1,  # sort by key,
                    "value": 1  # then by value
                }
            }
        ]
        data = self._documents(name).aggregate(pipeline)
        results = {}
        for item in data:
            key, value, count = item.get('key'), item.get('value'), item.get('count')
            # key filtering is applied client-side after aggregation
            if key_filter and key != key_filter:
                continue
            counter = results.setdefault(key, Counter())
            counter[value] += count
        return {key: dict(counts) for key, counts in results.items()}