import hashlib
import json
from itertools import product
from uuid import uuid4

import pandas as pd

from omegaml.documents import make_QueryCache
from omegaml.mdataframe import MDataFrame, MSeries
from omegaml.store import qops
from omegaml.store.filtered import FilteredCollection
from omegaml.util import make_tuple, extend_instance


class ApplyMixin(object):
    """
    Implements the apply() mixin, supporting arbitrary functions to build aggregation pipelines

    Note that .apply() does not execute immediately. Instead it builds an aggregation pipeline
    that is executed on MDataFrame.value. Note that .apply() calls cannot be cascaded yet, i.e.
    a later .apply() will override a previous .apply().

    See ApplyContext for usage examples.
    """

    def __init__(self, *args, **kwargs):
        super(ApplyMixin, self).__init__(*args, **kwargs)
        self._init_mixin(*args, **kwargs)

    def _init_mixin(self, *args, **kwargs):
        self.apply_fn = kwargs.get('apply_fn', None)
        # set to True if the pipeline is a facet operation
        self.is_from_facet = kwargs.get('is_from_facet', False)
        # index columns
        self.index_columns = kwargs.get('index_columns', [])
        # db alias
        self._db_alias = kwargs.get('db_alias', self._ensure_db_connection())
        # cache used on persist()
        self.cache = kwargs.get('cache', ApplyCache(self._db_alias))

    def _ensure_db_connection(self):
        # must import _dbs, _connections locally to ensure mongoshim has been applied
        from mongoengine.connection import _dbs, _connections

        seek_db = self.collection.database
        for alias, db in _dbs.items():
            if db is seek_db:
                self._db_alias = alias
                break
        else:
            # fake connection register
            alias = self._db_alias = 'omega-{}'.format(uuid4().hex)
            _connections[alias] = seek_db.client
            _dbs[alias] = seek_db
        return self._db_alias

    def nocache(self):
        self.cache = None
        return self

    def reset_cache(self, full=False):
        """
        Reset the apply cache

        :param full: if True resets all caches for the collection, if False only removes
           the cache for this specific .apply operation
        :return: self
        """
        QueryCache = make_QueryCache(db_alias=self._db_alias)
        if full:
            QueryCache.objects.filter(value__collection=self.collection.name).delete()
        else:
            pipeline = self._build_pipeline()
            key = self._make_cache_key(self.collection, pipeline)
            QueryCache.objects.filter(key=key).delete()
        return self

    def _make_cache_key(self, collection, pipeline):
        # remove the randomly named $out stage so equivalent pipelines produce the same key
        if '$out' in pipeline[-1] and pipeline[-1]['$out'].startswith('cache'):
            pipeline = list(pipeline)[:-1]
        spipeline = json.dumps(pipeline, sort_keys=True)
        data = '{}_{}'.format(collection.name, spipeline).encode('utf-8')
        # SEC: CWE-916
        # - status: wontfix
        # - reason: hashcode is used purely for name resolution, not a security function
        key = hashlib.md5(data).hexdigest()
        return key

    def _getcopy_kwargs(self, **kwargs):
        kwargs = super(ApplyMixin, self)._getcopy_kwargs(**kwargs)
        kwargs.update(is_from_facet=self.is_from_facet,
                      index_columns=self.index_columns,
                      cache=self.cache,
                      apply_fn=self.apply_fn)
        return kwargs

    def noapply(self):
        self.apply_fn = None
        return self

    def apply(self, fn, inplace=False, preparefn=None):
        if inplace:
            obj = self
        else:
            kwargs = self._getcopy_kwargs()
            kwargs.update(preparefn=preparefn)
            if isinstance(self, MSeries):
                obj = MSeries(self.collection, **kwargs)
            else:
                obj = MDataFrame(self.collection, **kwargs)
        obj.apply_fn = fn
        return obj

    def persist(self):
        """
        Execute the pipeline and store the result in the cache

        Any pipeline of the same operations, in the same order, on
        the same collection will return the same result.
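
        A usage sketch (``mdf`` and ``myfn`` are hypothetical)::

            key = mdf.apply(myfn).persist()   # execute once, store via $out
            mdf.apply(myfn).value             # subsequent reads hit the cache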
        """
        # generate a cache key
        pipeline = self._build_pipeline()
        key = self._make_cache_key(self.collection, pipeline)
        outname = 'cache_{}'.format(uuid4().hex)
        value = {
            'collection': self.collection.name,
            'result': outname,
        }
        # do usual processing, store result
        # -- note we pass pipeline to avoid processing iterators twice
        pipeline.append({
            '$out': outname,
        })
        cursor = self._get_cursor(pipeline=pipeline, use_cache=False)
        # consume cursor to store output (via $out)
        for v in cursor:
            pass
        # set cache
        self.cache.set(key, value)
        return key

    def set_index(self, columns):
        self.index_columns = make_tuple(columns)
        return self

    def inspect(self, explain=False, *args, **kwargs):
        if self.apply_fn:
            details = {
                'pipeline': self._build_pipeline()
            }
            if explain:
                details.update(self.__dict__)
            return details
        return super(ApplyMixin, self).inspect(*args, explain=explain, **kwargs)

    def _execute(self):
        ctx = ApplyContext(self, columns=self.columns)
        try:
            result = self.apply_fn(ctx)
        except Exception as e:
            msg = '; '.join([repr(stage) for stage in ctx.stages] + [repr(e)])
            raise RuntimeError(msg)
        if result is None or isinstance(result, ApplyContext):
            result = result or ctx
            self.index_columns = self.index_columns or result.index_columns
            return result
        elif isinstance(result, list):
            return result
        elif isinstance(result, dict):
            # expect a mapping of col=ApplyContext, each with its own list of stages
            # -- build a combined context by adding each expression
            #    this ensures any multi-stage projections are carried forward
            facets = {}
            for col, expr in result.items():
                if isinstance(expr, ApplyContext):
                    facets[col] = list(expr)
                    project = {
                        '$project': {
                            col: '$' + expr.columns[0]
                        },
                    }
                    facets[col].append(project)
                else:
                    facets[col] = expr
            facet = {
                '$facet': facets
            }
            self.is_from_facet = True
            return [facet]
        raise ValueError('Cannot build pipeline from apply result of type {}'.format(type(result)))

    def _build_pipeline(self):
        pipeline = []
        stages = self._execute()
        pipeline.extend(stages)
        self._amend_pipeline(pipeline)
        return pipeline

    def _amend_pipeline(self, pipeline):
        """ amend pipeline with default ops on coll.aggregate() calls """
        if self.sort_order:
            sort = qops.SORT(**dict(qops.make_sortkey(self.sort_order)))
            pipeline.append(sort)
        return pipeline

    def _get_cached_cursor(self, pipeline=None, use_cache=True):
        pipeline = pipeline or self._build_pipeline()
        if use_cache and self.cache:
            key = self._make_cache_key(self.collection, pipeline)
            entry = self.cache.get(key)
            if entry is not None:
                # read result
                outname = entry.value['result']
                return self.collection.database[outname].find()

    def _get_cursor(self, pipeline=None, use_cache=True):
        # for apply functions, call the apply function, expecting a pipeline in return
        if self.apply_fn:
            pipeline = pipeline or self._build_pipeline()
            cursor = self._get_cached_cursor(pipeline=pipeline, use_cache=use_cache)
            if cursor is None:
                filter_criteria = self._get_filter_criteria()
                cursor = FilteredCollection(self.collection).aggregate(pipeline, filter=filter_criteria,
                                                                       allowDiskUse=True)
        else:
            cursor = super(ApplyMixin, self)._get_cursor()
        return cursor

    def _get_dataframe_from_cursor(self, cursor):
        df = super(ApplyMixin, self)._get_dataframe_from_cursor(cursor)
        if self.is_from_facet:
            # if this was a facet pipeline (i.e. a multi-column mapping), combine the results;
            # $facet returns a single document with one entry per sub-pipeline
            frames = []
            for col in df.columns:
                coldf = pd.DataFrame(df[col].iloc[0]).set_index('_id')
                frames.append(coldf)
            df = pd.concat(frames, axis=1).reset_index()
            df = self._restore_dataframe_proper(df)
        # TODO write a unit test for this condition
        if self.index_columns and all(col in df.columns for col in self.index_columns):
            df.set_index(list(self.index_columns), inplace=True)
        return df


class ApplyContext(object):
    """
    Enable apply functions

    .apply(fn) will call fn(ctx) where ctx is an ApplyContext.
    The context provides methods to build aggregation pipelines in a
    Pandas-style manner. ApplyContext is extensible by adding an extension
    class to defaults.OMEGA_MDF_APPLY_MIXINS.

    Note that unlike a Pandas DataFrame, ApplyContext does not itself contain any data.
    Rather it is part of an expression tree, i.e. the aggregation pipeline. Thus any
    expressions applied are translated into operations on the expression tree. The
    expression tree is evaluated on MDataFrame.value, at which point neither the
    ApplyContext nor the function that created it are active anymore.

    Examples::

        mdf.apply(lambda v: v * 5)  => multiply every column in dataframe
        mdf.apply(lambda v: v['foo'].dt.week)  => get week of date for column foo
        mdf.apply(lambda v: dict(a=v['foo'].dt.week,
                                 b=v['bar'] * 5))  => run multiple pipelines and get results

        The callable passed to apply can be any function. It can either return None,
        the context passed in, or a list of pipeline stages.

        # apply any of the below functions
        mdf.apply(customfn)

        # same as lambda v: v.dt.week
        def customfn(ctx):
            return ctx.dt.week

        # simple pipeline
        def customfn(ctx):
            ctx.project(x={'$multiply': ['$x', 5]})
            ctx.project(y={'$divide': ['$x', 2]})

        # complex pipeline
        def customfn(ctx):
            return [
                { '$match': ... },
                { '$project': ... },
            ]
    """

    def __init__(self, caller, columns=None, index=None):
        self.caller = caller
        self.columns = columns
        self.index_columns = index or []
        self.computed = []
        self.stages = []
        self.expressions = []
        self._apply_mixins()

    def _apply_mixins(self):
        """
        apply mixins in defaults.OMEGA_MDF_APPLY_MIXINS
        """
        from omegaml import settings
        defaults = settings()
        for mixin, applyto in defaults.OMEGA_MDF_APPLY_MIXINS:
            if any(v in self.caller._applyto for v in applyto.split(',')):
                extend_instance(self, mixin)

    def __iter__(self):
        # return pipeline stages
        for stage in self.stages:
            if isinstance(stage, ApplyContext):
                for sub_stage in stage:
                    yield sub_stage
            else:
                yield stage

    def __getitem__(self, sel):
        """
        return a stage subset on a column
        """
        subctx = ApplyContext(self.caller, columns=make_tuple(sel), index=self.index_columns)
        self.add(subctx)
        return subctx

    def __setitem__(self, sel, val):
        """
        add a projection to a sub context

        ctx['col'] = value-expression
        """
        mapping = {
            col: v
            for (col, v) in zip(make_tuple(sel), make_tuple(val))}
        self.project(mapping)

    def __repr__(self):
        return 'ApplyContext(stages={}, expressions={})'.format(self.stages, self.expressions)

    def add(self, stage):
        """
        Add a processing stage to the pipeline

        see https://docs.mongodb.com/manual/meta/aggregation-quick-reference/
        """
        self.stages.append(stage)
        return self

    def project_keeper_columns(self):
        # keep index, computed
        index = {
            col: '$' + col
            for col in self.index_columns}
        computed = {
            col: '$' + col
            for col in self.computed}
        keep = {}
        keep.update(index)
        keep.update(computed)
        project = self.project(keep, keep=True)
        return project

    def _getLastStageKind(self, kind):
        # see if there is already an open stage of the given kind
        for stage in self.stages[::-1]:
            if kind in stage:
                return stage

    def _getProjection(self, append=False):
        stage = self._getLastStageKind('$project')
        if stage is None or append:
            stage = {
                '$project': {
                    '_id': 1,
                }
            }
            self.stages.append(stage)
        return stage

    def _getGroupBy(self, by=None, append=False):
        stage = self._getLastStageKind('$group')
        if stage and stage['$group']['_id'] != by and by != '$$last':
            # if a different groupby criteria, add a new one
            stage = None
        if stage is None and by == '$$last':
            by = None
        if stage is None or append:
            stage = {
                '$group': {
                    '_id': by,
                }
            }
            self.stages.append(stage)
        return stage

    def groupby(self, by, expr=None, append=None, **kwargs):
        """
        add a groupby accumulation using $group

        :param by: the groupby column(s); a list is transformed into the
           mapping ``column => $column``
        :param expr: a mapping of ``column => accumulator expression``
        :param append: unused, kept for signature compatibility
        :param kwargs: if expr is None, the accumulator expressions as kwargs
        :return: self
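
        A usage sketch (column names are hypothetical)::

            def fn(ctx):
                # sum 'points' per 'team' into accumulator column 'total'
                ctx.groupby('team', total={'$sum': '$points'})

            mdf.apply(fn).value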
        """
        by = make_tuple(by)
        self.index_columns = self.index_columns + list(by)
        # define groupby
        by = {col: '$' + col for col in by}
        stage = self._getGroupBy(by)
        groupby = stage['$group']
        # add accumulators
        expr = expr or {
            col: colExpr
            for col, colExpr in kwargs.items()}
        groupby.update(expr)
        # add a projection to extract groupby values
        extractId = {
            col: '$_id.' + col
            for col in by}
        # add a projection to keep accumulator columns
        keepCols = {
            col: 1
            for col in expr}
        keepCols.update(extractId)
        self.project(keepCols, append=True)
        # sort by groupby keys
        self.add({
            '$sort': {
                col: 1
                for col in by}
        })
        return self

    def project(self, expr=None, append=False, keep=False, **kwargs):
        """
        add a projection using $project

        :param expr: the column-operator mapping
        :param append: if True always adds a new $project stage, otherwise extends the last one
        :param keep: if True keeps an existing projection of the same column instead of
           adding a new stage
        :param kwargs: if expr is None, the column-operator mapping as kwargs
        :return: ApplyContext
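
        A usage sketch (column name ``x`` is hypothetical)::

            def fn(ctx):
                ctx.project(x={'$multiply': ['$x', 5]})

            mdf.apply(fn).value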
        """
        # get last $project stage in pipeline
        stage = self._getProjection(append=append)
        expr = expr or kwargs
        self.expressions.append(expr)
        for k, v in expr.items():
            # only append to stage if no other column projection was there
            project = stage.get('$project')
            if k not in project:
                project.update({
                    k: v
                })
            elif not keep:
                # if a column is already projected, add a new projection stage
                stage = self._getProjection(append=True)
                project = stage.get('$project')
                project.update({
                    k: v
                })
        return self


class ApplyArithmetics(object):
    """
    Math operators for ApplyContext

    * :code:`__mul__` (*)
    * :code:`__add__` (+)
    * :code:`__sub__` (-)
    * :code:`__truediv__`, :code:`__div__` (/)
    * :code:`__floordiv__` (//)
    * :code:`__mod__` (%)
    * :code:`__pow__` (pow)
    * :code:`__ceil__` (ceil)
    * :code:`__floor__` (floor)
    * :code:`__trunc__` (trunc)
    * :code:`__abs__` (abs)
    * :code:`sqrt` (math.sqrt)

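    A usage sketch (assuming a numeric column ``x``)::

        mdf['x'].apply(lambda v: (v + 1) * 2)  # adds two $project stages
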
    """

    def __arithmop__(op, wrap_op=None):
        """
        return a pipeline $project stage math operator as
        { col:
            { '$operator': [ values, ...] }
          ...
        }

        If wrap_op is specified, will wrap the $operator clause as
        { col:
            { '$wrap_op': { '$operator': [ values, ...] } }
          ...
        }
        """

        def inner(self, other):
            terms = []
            for term in make_tuple(other):
                if isinstance(term, str):
                    term = '$' + term
                terms.append(term)

            def wrap(expr):
                if wrap_op is not None:
                    expr = {
                        wrap_op: expr
                    }
                return expr

            mapping = {
                col: wrap({
                    op: ['$' + col] + terms,
                }) for col in self.columns}
            keepCols = {
                col: '$' + col
                for col in self.index_columns}
            mapping.update(keepCols)
            self.project(mapping)
            return self

        return inner

    #: multiply
    __mul__ = __arithmop__('$multiply')
    #: add
    __add__ = __arithmop__('$add')
    #: subtract
    __sub__ = __arithmop__('$subtract')
    #: divide
    __div__ = __arithmop__('$divide')
    __truediv__ = __arithmop__('$divide')
    #: integer divide
    __floordiv__ = __arithmop__('$divide', wrap_op='$floor')
    #: modulo (%)
    __mod__ = __arithmop__('$mod')
    #: pow
    __pow__ = __arithmop__('$pow')
    #: ceil
    __ceil__ = __arithmop__('$ceil')
    #: floor
    __floor__ = __arithmop__('$floor')
    #: truncate
    __trunc__ = __arithmop__('$trunc')
    #: absolute
    __abs__ = __arithmop__('$abs')
    #: square root
    sqrt = __arithmop__('$sqrt')


class ApplyDateTime(object):
    """
    Datetime operators for ApplyContext
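
    A usage sketch (assuming a datetime column ``date``)::

        mdf.apply(lambda v: v['date'].dt.week)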
    """

    @property
    def dt(self):
        return self

    def __dtop__(op):
        """
        return a datetime $project operator as
        { col:
            { '$operator': '$col' }
          ...
        }
        """

        def inner(self, columns=None):
            columns = make_tuple(columns or self.columns)
            mapping = {
                col: {
                    op: '$' + col,
                }
                for col in columns}
            self.project(mapping)
            return self

        inner.__doc__ = op.replace('$', '')
        return inner

    # mongodb mappings
    _year = __dtop__('$year')
    _month = __dtop__('$month')
    _week = __dtop__('$week')
    _dayOfWeek = __dtop__('$dayOfWeek')
    _dayOfMonth = __dtop__('$dayOfMonth')
    _dayOfYear = __dtop__('$dayOfYear')
    _hour = __dtop__('$hour')
    _minute = __dtop__('$minute')
    _second = __dtop__('$second')
    _millisecond = __dtop__('$millisecond')
    _isoDayOfWeek = __dtop__('$isoDayOfWeek')
    _isoWeek = __dtop__('$isoWeek')
    _isoWeekYear = __dtop__('$isoWeekYear')

    # .dt accessor convenience similar to pandas.dt
    # see https://pandas.pydata.org/pandas-docs/stable/api.html#datetimelike-properties
    year = property(_year)
    month = property(_month)
    day = property(_dayOfMonth)
    hour = property(_hour)
    minute = property(_minute)
    second = property(_second)
    millisecond = property(_millisecond)
    week = property(_isoWeek)
    dayofyear = property(_dayOfYear)
    dayofweek = property(_dayOfWeek)


class ApplyString(object):
    """
    String operators
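
    A usage sketch (assuming a text column ``name``)::

        mdf.apply(lambda v: v['name'].str.upper())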
    """

    @property
    def str(self):
        return self

    def __strexpr__(op, unwind=False, base=None, max_terms=None):
        """
        return a pipeline $project string operator as
        { col:
            { '$operator': [ values, ...] }
          ...
        }
        """

        def inner(self, other, *args):
            # get all values passed and build terms from them
            values = list(make_tuple(other) + args)
            terms = []
            for term in values:
                if isinstance(term, str):
                    # if the term is a column name, add as a column reference
                    if term in self.columns:
                        term = '$' + term
                    # allow to specify values explicitly by $$<value> => <value>
                    term = term.replace('$$', '')
                terms.append(term)
            # limit number of terms if requested
            if max_terms:
                terms = terms[:max_terms]
            # add projection of output columns to operator
            mapping = {
                col: {
                    op: terms if base is None else ['$' + col] + terms,
                } for col in self.columns}
            self.project(mapping)
            # unwind all columns if requested
            if unwind:
                exprs = [{'$unwind': {
                    'path': '$' + col
                }} for col in self.columns]
                self.stages.extend(exprs)
            return self

        inner.__doc__ = op.replace('$', '')
        return inner

    def __strunary__(op, unwind=False):
        """
        return a string $project operator as
        { col:
            { '$operator': '$col' }
          ...
        }
        """

        def inner(self, columns=None):
            columns = make_tuple(columns or self.columns)
            mapping = {
                col: {
                    op: '$' + col,
                }
                for col in columns}
            self.project(mapping)
            if unwind:
                # unwind all columns, as in __strexpr__
                self.stages.extend({
                    '$unwind': {
                        'path': '$' + col
                    }
                } for col in columns)
            return self

        inner.__doc__ = op.replace('$', '')

        return inner

    def isequal(self, other):
        # strcasecmp returns 0 for equality, 1 and -1 for greater/less than
        # https://docs.mongodb.com/manual/reference/operator/aggregation/strcasecmp/
        self.strcasecmp(other)
        mapping = {
            col: {
                '$cond': {
                    'if': {'$eq': ['$' + col, 0]},
                    'then': True,
                    'else': False,
                }
            }
            for col in self.columns}
        self.project(mapping)
        return self

    concat = __strexpr__('$concat', base=True)
    split = __strexpr__('$split', unwind=True, base=True, max_terms=2)
    usplit = __strexpr__('$split', unwind=False, base=True, max_terms=2)
    upper = __strunary__('$toUpper')
    lower = __strunary__('$toLower')
    substr = __strexpr__('$substr', base=True)
    strcasecmp = __strexpr__('$strcasecmp', base=True)
    len = __strunary__('$strLenBytes')
    index = __strexpr__('$indexOfBytes', base=True)


class ApplyAccumulators(object):
    def agg(self, map=None, **kwargs):
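        """
        add aggregation accumulators using $group

        A usage sketch (column name ``x`` is hypothetical)::

            mdf.apply(lambda v: v.agg(dict(x=['min', 'max']))).value
        """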
        stage = self._getGroupBy(by='$$last')
        specs = map or kwargs
        for col, colExpr in specs.items():
            if isinstance(colExpr, dict):
                # specify an arbitrary expression
                groupby = stage['$group']
                groupby[col] = colExpr
            elif isinstance(colExpr, str):
                # specify some known operator
                if hasattr(self, colExpr):
                    method = getattr(self, colExpr)
                    method(col)
                else:
                    raise SyntaxError('{} is not known'.format(colExpr))
            elif isinstance(colExpr, (tuple, list)):
                # specify a list of some known operators
                for statExpr in colExpr:
                    if hasattr(self, statExpr):
                        method = getattr(self, statExpr)
                        method(col)
                    else:
                        raise SyntaxError('{} is not known'.format(statExpr))
            elif callable(colExpr):
                # specify a callable that returns an expression
                groupby = stage['$group']
                groupby[col] = colExpr(col)
            else:
                raise SyntaxError('{} on column {} is unknown or invalid'.format(colExpr, col))
        return self

    def __statop__(op, opname=None):
        opname = opname or op.replace('$', '')

        def inner(self, columns=None):
            columns = make_tuple(columns or self.columns)
            stage = self._getGroupBy(by='$$last')
            groupby = stage['$group']
            groupby.update({
                '{}_{}'.format(col, opname): {
                    op: '$' + col
                } for col in columns
            })
            self.computed.extend(groupby.keys())
            self.project_keeper_columns()
            return self

        return inner

    sum = __statop__('$sum')
    avg = __statop__('$avg')
    mean = __statop__('$avg')
    min = __statop__('$min')
    max = __statop__('$max')
    std = __statop__('$stdDevSamp', 'std')


class ApplyCache(object):
    """
    A cache that works on collections and pipelines
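
    Keys are derived from the (collection, pipeline) pair; values record the
    $out collection that holds the cached result (see ApplyMixin.persist).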
    """

    def __init__(self, db_alias):
        self._db_alias = db_alias

    def set(self, key, value):
        # https://stackoverflow.com/a/22003440/890242
        QueryCache = make_QueryCache(self._db_alias)
        QueryCache.objects(key=key).update_one(set__key="{}".format(key),
                                               set__value=value, upsert=True)

    def get(self, key):
        QueryCache = make_QueryCache(self._db_alias)
        try:
            result = QueryCache.objects.get(key=key)
        except QueryCache.DoesNotExist:
            result = None
        return result


class ApplyStatistics(object):
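    """
    Pandas-style statistics for MDataFrame

    Provides .quantile(), .cov() and .corr(); each builds an aggregation
    pipeline via .apply() and reshapes the result into the familiar pandas
    layout using a preparefn.
    """
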
    def quantile(self, q=.5):
        def preparefn(val):
            return val.pivot(columns='var', index='percentile', values='value')

        return self.apply(self._percentile(q), preparefn=preparefn)

    def cov(self):
        def preparefn(val):
            val = val.pivot(columns='y', index='x', values='cov')
            val.index.name = None
            val.columns.name = None
            return val

        return self.apply(self._covariance, preparefn=preparefn)

    def corr(self):
        def preparefn(val):
            val = val.pivot(columns='y', index='x', values='rho')
            val.index.name = None
            val.columns.name = None
            return val

        return self.apply(self._pearson, preparefn=preparefn)

    def _covariance(self, ctx):
        # source http://ci.columbia.edu/ci/premba_test/c0331/s7/s7_5.html
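        # sample covariance: cov(x, y) = sum((x - mean(x)) * (y - mean(y))) / (n - 1)
        # -- a $facet sub-pipeline is built for every (x, y) column pair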
        facets = {}
        means = {}
        unwinds = []
        count = len(ctx.caller.noapply()) - 1
        for x, y in product(ctx.columns, ctx.columns):
            xcol = '$' + x
            ycol = '$' + y
            # only calculate the same column's mean once
            if xcol not in means:
                means[xcol] = ctx.caller[x].noapply().mean().values[0, 0]
            if ycol not in means:
                means[ycol] = ctx.caller[y].noapply().mean().values[0, 0]
            sumands = {
                xcol: {
                    '$subtract': [xcol, means[xcol]]
                },
                ycol: {
                    '$subtract': [ycol, means[ycol]]
                }
            }
            multiply = {
                '$multiply': [sumands[xcol], sumands[ycol]]
            }
            agg = {
                '$group': {
                    '_id': None,
                    'value': {
                        '$sum': multiply
                    }
                }
            }
            project = {
                '$project': {
                    'cov': {
                        '$divide': ['$value', count],
                    },
                    'x': x,
                    'y': y,
                }
            }
            pipeline = [agg, project]
            outcol = '{}_{}'.format(x, y)
            facets[outcol] = pipeline
            unwinds.append({'$unwind': '$' + outcol})
        facet = {
            '$facet': facets,
        }
        expand = [{
            '$project': {
                'value': {
                    '$objectToArray': '$$CURRENT',
                }
            }
        }, {
            '$unwind': '$value'
        }, {
            '$replaceRoot': {
                'newRoot': '$value.v'
            }
        }]
        return [facet, *unwinds, *expand]

    def _pearson(self, ctx):
        # source http://ilearnasigoalong.blogspot.ch/2017/10/calculating-correlation-inside-mongodb.html
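        # Pearson's rho:
        #   rho(x, y) = (n * sum(xy) - sum(x) * sum(y)) /
        #               sqrt((n * sum(x^2) - sum(x)^2) * (n * sum(y^2) - sum(y)^2))
        # -- a $facet sub-pipeline is built for every (x, y) column pair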
        facets = {}
        unwinds = []
        for x, y in product(ctx.columns, ctx.columns):
            xcol = '$' + x
            ycol = '$' + y
            sumcolumns = {'$group': {'_id': None,
                                     'count': {'$sum': 1},
                                     'sumx': {'$sum': xcol},
                                     'sumy': {'$sum': ycol},
                                     'sumxsquared': {'$sum': {'$multiply': [xcol, xcol]}},
                                     'sumysquared': {'$sum': {'$multiply': [ycol, ycol]}},
                                     'sumxy': {'$sum': {'$multiply': [xcol, ycol]}}
                                     }}

            multiply_sumx_sumy = {'$multiply': ["$sumx", "$sumy"]}
            multiply_sumxy_count = {'$multiply': ["$sumxy", "$count"]}
            partone = {'$subtract': [multiply_sumxy_count, multiply_sumx_sumy]}

            multiply_sumxsquared_count = {'$multiply': ["$sumxsquared", "$count"]}
            sumx_squared = {'$multiply': ["$sumx", "$sumx"]}
            subparttwo = {'$subtract': [multiply_sumxsquared_count, sumx_squared]}

            multiply_sumysquared_count = {'$multiply': ["$sumysquared", "$count"]}
            sumy_squared = {'$multiply': ["$sumy", "$sumy"]}
            subpartthree = {'$subtract': [multiply_sumysquared_count, sumy_squared]}

            parttwo = {'$sqrt': {'$multiply': [subparttwo, subpartthree]}}

            rho = {'$project': {
                'rho': {
                    '$divide': [partone, parttwo]
                },
                'x': x,
                'y': y
            }}
            pipeline = [sumcolumns, rho]
            outcol = '{}_{}'.format(x, y)
            facets[outcol] = pipeline
            unwinds.append({'$unwind': '$' + outcol})
        facet = {
            '$facet': facets,
        }
        expand = [{
            '$project': {
                'value': {
                    '$objectToArray': '$$CURRENT',
                }
            }
        }, {
            '$unwind': '$value'
        }, {
            '$replaceRoot': {
                'newRoot': '$value.v'
            }
        }]
        return [facet, *unwinds, *expand]

    def _percentile(self, pctls=None):
        """
        calculate percentiles for all columns
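
        Uses the nearest-rank method: each percentile is taken as the value at
        index floor(n * p) of the column's sorted values array.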
        """
        pctls = pctls or [.25, .5, .75]
        if not isinstance(pctls, (list, tuple)):
            pctls = [pctls]

        def calc(col, p, outcol):
            # sort values
            sort = {
                '$sort': {
                    col: 1,
                }
            }
            # group/push to get an array of all values
            group = {
                '$group': {
                    '_id': col,
                    'values': {
                        '$push': "$" + col
                    },
                }
            }
            # find value at requested percentile
            perc = {
                '$arrayElemAt': [
                    '$values', {
                        '$floor': {
                            '$multiply': [{
                                '$size': '$values'
                            }, p]
                        }}
                ]
            }
            # map percentile value to output column
            project = {
                '$project': {
                    'var': col,
                    'percentile': 'p{}'.format(p),
                    'value': perc,
                }
            }
            return [sort, group, project]

        def inner(ctx):
            # for each column and requested percentile, build a pipeline
            # all pipelines will be combined into a $facet stage to
            # calculate every column/percentile tuple in parallel
            facets = {}
            unwind = []
            # for each column build a pipeline to calculate the percentiles
            for col in ctx.columns:
                for p in pctls:
                    # e.g. outcol for percentile .25 of column abc => abc_p25
                    outcol = '{}_p{}'.format(col, p).replace('0.', '')
                    facets[outcol] = calc(col, p, outcol)
                    unwind.append({'$unwind': '$' + outcol})
            # process per-column pipelines in parallel, resulting in one
            # document for each variable + percentile combination
            facet = {
                '$facet': facets
            }
            # expand single document into one document per variable + percentile combo
            # the resulting set of documents contains var/percentile/value
            expand = [{
                '$project': {
                    'value': {
                        '$objectToArray': '$$CURRENT',
                    }
                }
            }, {
                '$unwind': '$value'
            }, {
                '$replaceRoot': {
                    'newRoot': '$value.v'
                }
            }]
            pipeline = [facet, *unwind, *expand]
            return pipeline

        return inner