import hashlib
import json
import pandas as pd
from itertools import product
from uuid import uuid4

from omegaml.documents import make_QueryCache
from omegaml.mdataframe import MDataFrame, MSeries
from omegaml.store import qops
from omegaml.store.filtered import FilteredCollection
from omegaml.util import make_tuple, extend_instance


class ApplyMixin(object):
    """
    Implements the .apply() mixin, supporting arbitrary functions to build aggregation pipelines

    Note that .apply() does not execute immediately. Instead it builds an aggregation pipeline
    that is executed on MDataFrame.value. Note that .apply() calls cannot be cascaded yet, i.e.
    a later .apply() will override a previous .apply().

    See ApplyContext for usage examples.
    """

    def __init__(self, *args, **kwargs):
        super(ApplyMixin, self).__init__(*args, **kwargs)
        self._init_mixin(*args, **kwargs)

    def _init_mixin(self, *args, **kwargs):
        self.apply_fn = kwargs.get('apply_fn', None)
        # set to True if the pipeline is a facet operation
        self.is_from_facet = kwargs.get('is_from_facet', False)
        # index columns
        self.index_columns = kwargs.get('index_columns', [])
        # db alias -- only resolve the connection if no alias was given
        self._db_alias = kwargs.get('db_alias') or self._ensure_db_connection()
        # cache used on persist()
        self.cache = kwargs.get('cache', ApplyCache(self._db_alias))

    def _ensure_db_connection(self):
        from mongoengine.connection import _dbs, _connections

        seek_db = self.collection.database
        for alias, db in _dbs.items():
            if db is seek_db:
                self._db_alias = alias
                break
        else:
            # connection not registered yet - register it under a synthetic alias
            alias = self._db_alias = 'omega-{}'.format(uuid4().hex)
            _connections[alias] = seek_db.client
            _dbs[alias] = seek_db
        return self._db_alias

    def nocache(self):
        """ disable the apply cache for this instance """
        self.cache = None
        return self

    def reset_cache(self, full=False):
        """
        Reset the apply cache

        :param full: if True resets all caches for the collection, if False only
            removes the cache for this specific .apply operation
        :return: self
        """
        QueryCache = make_QueryCache(db_alias=self._db_alias)
        if full:
            QueryCache.objects.filter(value__collection=self.collection.name).delete()
        else:
            pipeline = self._build_pipeline()
            key = self._make_cache_key(self.collection, pipeline)
            QueryCache.objects.filter(key=key).delete()
        return self

    def _make_cache_key(self, collection, pipeline):
        # remove random output value
        if '$out' in pipeline[-1] and pipeline[-1]['$out'].startswith('cache'):
            pipeline = list(pipeline)[:-1]
        spipeline = json.dumps(pipeline, sort_keys=True)
        data = '{}_{}'.format(collection.name, spipeline).encode('utf-8')
        # SEC: CWE-916
        # - status: wontfix
        # - reason: hashcode is used purely for name resolution, not a security function
        key = hashlib.md5(data).hexdigest()
        return key
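
    # A minimal sketch of how cache keys are derived; 'sales' is a hypothetical
    # collection name:
    #
    #   pipeline = [{'$match': {'x': 1}}]
    #   key = hashlib.md5('sales_{}'.format(
    #       json.dumps(pipeline, sort_keys=True)).encode('utf-8')).hexdigest()
    #
    # The same operations, in the same order, on the same collection always
    # yield the same key, which is what makes persist() results reusable.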

    def _getcopy_kwargs(self, **kwargs):
        kwargs = super(ApplyMixin, self)._getcopy_kwargs(**kwargs)
        kwargs.update(is_from_facet=self.is_from_facet,
                      index_columns=self.index_columns,
                      cache=self.cache,
                      apply_fn=self.apply_fn)
        return kwargs

    def noapply(self):
        """ remove the apply function """
        self.apply_fn = None
        return self

    def apply(self, fn, inplace=False, preparefn=None):
        """
        set the apply function to build the aggregation pipeline

        :param fn: a callable that receives an ApplyContext
        :param inplace: if True set the apply function on this frame,
            else return a copy
        :param preparefn: a callable applied to the resulting DataFrame
        :return: the MDataFrame or MSeries with the apply function set
        """
        if inplace:
            obj = self
        else:
            kwargs = self._getcopy_kwargs()
            kwargs.update(preparefn=preparefn)
            if isinstance(self, MSeries):
                obj = MSeries(self.collection, **kwargs)
            else:
                obj = MDataFrame(self.collection, **kwargs)
        obj.apply_fn = fn
        return obj
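
    # A minimal usage sketch, assuming an MDataFrame `mdf` with numeric
    # columns (hypothetical data):
    #
    #   doubled = mdf.apply(lambda v: v * 2)   # only builds the pipeline
    #   df = doubled.value                     # executes the aggregation
    #
    # The callable receives an ApplyContext; nothing is executed until
    # .value is accessed. See ApplyContext for more examples.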

    def persist(self):
        """
        Execute and store results in cache

        Any pipeline of the same operations, in the same order, on
        the same collection will return the same result.
        """
        # generate a cache key
        pipeline = self._build_pipeline()
        key = self._make_cache_key(self.collection, pipeline)
        outname = 'cache_{}'.format(uuid4().hex)
        value = {
            'collection': self.collection.name,
            'result': outname,
        }
        # do usual processing, store result
        # -- note we pass pipeline to avoid processing iterators twice
        pipeline.append({
            '$out': outname,
        })
        cursor = self._get_cursor(pipeline=pipeline, use_cache=False)
        # consume cursor to store output (via $out)
        for v in cursor:
            pass
        # set cache
        self.cache.set(key, value)
        return key

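    # A usage sketch for cached results, under the same assumptions as above:
    #
    #   mdf.apply(lambda v: v * 2).persist()         # run once, stored via $out
    #   df = mdf.apply(lambda v: v * 2).value        # served from the cache
    #   mdf.apply(lambda v: v * 2).nocache().value   # bypass the cache
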
    def set_index(self, columns):
        """ set the index columns of the result """
        self.index_columns = make_tuple(columns)
        return self

    def inspect(self, explain=False, *args, **kwargs):
        if self.apply_fn:
            details = {
                'pipeline': self._build_pipeline()
            }
            if explain:
                details.update(self.__dict__)
            return details
        return super(ApplyMixin, self).inspect(*args, explain=explain, **kwargs)

    def _execute(self):
        ctx = ApplyContext(self, columns=self.columns)
        try:
            result = self.apply_fn(ctx)
        except Exception as e:
            # report the stages built so far along with the causing exception
            msg = '\n'.join([repr(stage) for stage in ctx.stages] + [repr(e)])
            raise RuntimeError(msg)
        if result is None or isinstance(result, ApplyContext):
            result = result or ctx
            self.index_columns = self.index_columns or result.index_columns
            return result
        elif isinstance(result, list):
            return result
        elif isinstance(result, dict):
            # expect a mapping of col=ApplyContext each with its own list of stages
            # -- build a combined context by adding each expression
            #    this ensures any multi-stage projections are carried forward
            facets = {}
            for col, expr in result.items():
                if isinstance(expr, ApplyContext):
                    facets[col] = list(expr)
                    project = {
                        '$project': {
                            col: '$' + expr.columns[0]
                        },
                    }
                    facets[col].append(project)
                else:
                    facets[col] = expr
            facet = {
                '$facet': facets
            }
            self.is_from_facet = True
            return [facet]
        raise ValueError('Cannot build pipeline from apply result of type {}'.format(type(result)))

    def _build_pipeline(self):
        pipeline = []
        stages = self._execute()
        pipeline.extend(stages)
        self._amend_pipeline(pipeline)
        return pipeline

    def _amend_pipeline(self, pipeline):
        """ amend pipeline with default ops on coll.aggregate() calls """
        if self.sort_order:
            sort = qops.SORT(**dict(qops.make_sortkey(self.sort_order)))
            pipeline.append(sort)
        return pipeline

    def _get_cached_cursor(self, pipeline=None, use_cache=True):
        pipeline = pipeline or self._build_pipeline()
        if use_cache and self.cache:
            key = self._make_cache_key(self.collection, pipeline)
            entry = self.cache.get(key)
            if entry is not None:
                # read result
                outname = entry.value['result']
                return self.collection.database[outname].find()

    def _get_cursor(self, pipeline=None, use_cache=True):
        # for apply functions, call the apply function, expecting a pipeline in return
        if self.apply_fn:
            pipeline = pipeline or self._build_pipeline()
            cursor = self._get_cached_cursor(pipeline=pipeline, use_cache=use_cache)
            if cursor is None:
                filter_criteria = self._get_filter_criteria()
                cursor = FilteredCollection(self.collection).aggregate(pipeline,
                                                                       filter=filter_criteria,
                                                                       allowDiskUse=True)
        else:
            cursor = super(ApplyMixin, self)._get_cursor()
        return cursor

    def _get_dataframe_from_cursor(self, cursor):
        df = super(ApplyMixin, self)._get_dataframe_from_cursor(cursor)
        if self.is_from_facet:
            # if this was a facet pipeline (i.e. a multi-column mapping), combine:
            # $facet returns a single document with one array field per facet pipeline
            frames = []
            for col in df.columns:
                coldf = pd.DataFrame(df[col].iloc[0]).set_index('_id')
                frames.append(coldf)
            df = pd.concat(frames, axis=1).reset_index()
            df = self._restore_dataframe_proper(df)
        # TODO write a unit test for this condition
        if self.index_columns and all(col in df.columns for col in self.index_columns):
            df.set_index(list(self.index_columns), inplace=True)
        return df


class ApplyContext(object):
    """
    Enable apply functions

    .apply(fn) will call fn(ctx) where ctx is an ApplyContext.
    The context supports methods to apply functions in a Pandas-style apply manner.
    ApplyContext is extensible by adding an extension class to defaults.OMEGA_MDF_APPLY_MIXINS.

    Note that unlike a Pandas DataFrame, ApplyContext does not itself contain any data.
    Rather it is part of an expression tree, i.e. the aggregation pipeline. Thus any
    expressions applied are translated into operations on the expression tree. The
    expression tree is evaluated on MDataFrame.value, at which point neither the
    ApplyContext nor the function that created it are active anymore.

    Examples::

        mdf.apply(lambda v: v * 5) => multiply every column in dataframe
        mdf.apply(lambda v: v['foo'].dt.week) => get week of date for column foo
        mdf.apply(lambda v: dict(a=v['foo'].dt.week,
                                 b=v['bar'] * 5)) => run multiple pipelines and get results

        The callable passed to apply can be any function. It can either return None,
        the context passed in or a list of pipeline stages.

        # apply any of the below functions
        mdf.apply(customfn)

        # same as lambda v: v.dt.week
        def customfn(ctx):
            return ctx.dt.week

        # simple pipeline
        def customfn(ctx):
            ctx.project(x={'$multiply': ['$x', 5]})
            ctx.project(y={'$divide': ['$x', 2]})

        # complex pipeline
        def customfn(ctx):
            return [
                { '$match': ... },
                { '$project': ... },
            ]
    """

    def __init__(self, caller, columns=None, index=None):
        self.caller = caller
        self.columns = columns
        self.index_columns = index or []
        self.computed = []
        self.stages = []
        self.expressions = []
        self._apply_mixins()

    def _apply_mixins(self):
        """
        apply mixins in defaults.OMEGA_MDF_APPLY_MIXINS
        """
        from omegaml import settings
        defaults = settings()
        for mixin, applyto in defaults.OMEGA_MDF_APPLY_MIXINS:
            if any(v in self.caller._applyto for v in applyto.split(',')):
                extend_instance(self, mixin)

    def __iter__(self):
        # return pipeline stages, expanding nested contexts
        for stage in self.stages:
            if isinstance(stage, ApplyContext):
                for sub_stage in stage:
                    yield sub_stage
            else:
                yield stage

    def __getitem__(self, sel):
        """
        return a stage subset on a column
        """
        subctx = ApplyContext(self.caller, columns=make_tuple(sel), index=self.index_columns)
        self.add(subctx)
        return subctx

    def __setitem__(self, sel, val):
        """
        add a projection to a sub context

        ctx['col'] = value-expression
        """
        mapping = {
            col: v
            for (col, v) in zip(make_tuple(sel), make_tuple(val))}
        self.project(mapping)

    def __repr__(self):
        return 'ApplyContext(stages={}, expressions={})'.format(self.stages, self.expressions)

    def add(self, stage):
        """
        Add a processing stage to the pipeline

        see https://docs.mongodb.com/manual/meta/aggregation-quick-reference/
        """
        self.stages.append(stage)
        return self

    def project_keeper_columns(self):
        # keep index, computed
        index = {
            col: '$' + col
            for col in self.index_columns}
        computed = {
            col: '$' + col
            for col in self.computed}
        keep = {}
        keep.update(index)
        keep.update(computed)
        project = self.project(keep, keep=True)
        return project

    def _getLastStageKind(self, kind):
        # see if there is already an open stage of the given kind
        for stage in self.stages[::-1]:
            if kind in stage:
                return stage

    def _getProjection(self, append=False):
        stage = self._getLastStageKind('$project')
        if stage is None or append:
            stage = {
                '$project': {
                    '_id': 1,
                }
            }
            self.stages.append(stage)
        return stage

    def _getGroupBy(self, by=None, append=False):
        # by='$$last' means reuse the most recent $group stage, if any
        stage = self._getLastStageKind('$group')
        if stage and stage['$group']['_id'] != by and by != '$$last':
            # if a different groupby criteria, add a new one
            stage = None
        if stage is None and by == '$$last':
            by = None
        if stage is None or append:
            stage = {
                '$group': {
                    '_id': by,
                }
            }
            self.stages.append(stage)
        return stage

    def groupby(self, by, expr=None, append=None, **kwargs):
        """
        add a groupby accumulation using $group

        :param by: the groupby column(s); a single column name or a list of
            names, transformed into the $group _id
        :param expr: a mapping of column => accumulator expression
        :param append: unused (reserved)
        :param kwargs: accumulator expressions by column, used if expr is None
        :return: ApplyContext
        """
        by = make_tuple(by)
        self.index_columns = self.index_columns + list(by)
        # define groupby
        by = {col: '$' + col for col in by}
        stage = self._getGroupBy(by)
        groupby = stage['$group']
        # add accumulators
        expr = expr or {
            col: colExpr
            for col, colExpr in kwargs.items()}
        groupby.update(expr)
        # add a projection to extract groupby values
        extractId = {
            col: '$_id.' + col
            for col in by}
        # add a projection to keep accumulator columns
        keepCols = {
            col: 1
            for col in expr}
        keepCols.update(extractId)
        self.project(keepCols, append=True)
        # sort by groupby keys
        self.add({
            '$sort': {
                col: 1
                for col in by}
        })
        return self
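
    # A minimal sketch, assuming columns 'product' and 'amount'
    # (hypothetical data):
    #
    #   def total(ctx):
    #       ctx.groupby('product', amount={'$sum': '$amount'})
    #
    #   mdf.apply(total).value
    #
    # builds [{'$group': {'_id': {'product': '$product'},
    #                     'amount': {'$sum': '$amount'}}}, ...] plus the
    # projection and $sort stages added above.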

    def project(self, expr=None, append=False, keep=False, **kwargs):
        """
        add a projection using $project

        :param expr: the column-operator mapping
        :param append: if True add a new $project stage, otherwise amend the last one
        :param keep: if True keep an existing projection of the same column
            instead of opening a new stage
        :param kwargs: if expr is None, the column-operator mapping as kwargs
        :return: ApplyContext
        """
        # get last $project stage in pipeline
        stage = self._getProjection(append=append)
        expr = expr or kwargs
        self.expressions.append(expr)
        for k, v in expr.items():
            # only append to stage if no other column projection was there
            project = stage.get('$project')
            if k not in project:
                project.update({
                    k: v
                })
            elif not keep:
                # if a column is already projected, add a new projection stage
                stage = self._getProjection(append=True)
                project = stage.get('$project')
                project.update({
                    k: v
                })
        return self
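
    # A minimal sketch, assuming a numeric column 'x' (hypothetical data):
    #
    #   def scale(ctx):
    #       ctx.project(x={'$multiply': ['$x', 5]})
    #       ctx.project(y={'$divide': ['$x', 2]})
    #
    # both expressions are merged into a single $project stage since they
    # target different output columns; projecting 'x' twice would open a
    # second $project stage instead.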


class ApplyArithmetics(object):
    """
    Math operators for ApplyContext

    * :code:`__mul__` (*)
    * :code:`__add__` (+)
    * :code:`__sub__` (-)
    * :code:`__div__` (/)
    * :code:`__floordiv__` (//)
    * :code:`__mod__` (%)
    * :code:`__pow__` (pow)
    * :code:`__ceil__` (ceil)
    * :code:`__floor__` (floor)
    * :code:`__trunc__` (trunc)
    * :code:`__abs__` (abs)
    * :code:`sqrt` (math.sqrt)

    """

    def __arithmop__(op, wrap_op=None):
        """
        return a pipeline $project stage math operator as
           { col:
              { '$operator': [ values, ...] }
              ...
           }

        If wrap_op is specified, will wrap the $operator clause as
           { col:
              { '$wrap_op': { '$operator': [ values, ...] } }
              ...
           }
        """

        def inner(self, other):
            terms = []
            for term in make_tuple(other):
                if isinstance(term, str):
                    term = '$' + term
                terms.append(term)

            def wrap(expr):
                if wrap_op is not None:
                    expr = {
                        wrap_op: expr
                    }
                return expr

            mapping = {
                col: wrap({
                    op: ['$' + col] + terms,
                }) for col in self.columns}
            keepCols = {
                col: '$' + col
                for col in self.index_columns}
            mapping.update(keepCols)
            self.project(mapping)
            return self

        return inner

    #: multiply
    __mul__ = __arithmop__('$multiply')
    #: add
    __add__ = __arithmop__('$add')
    #: subtract
    __sub__ = __arithmop__('$subtract')
    #: divide
    __div__ = __arithmop__('$divide')
    __truediv__ = __arithmop__('$divide')
    #: divide integer
    __floordiv__ = __arithmop__('$divide', wrap_op='$floor')
    #: modulo (%)
    __mod__ = __arithmop__('$mod')
    #: pow
    __pow__ = __arithmop__('$pow')
    #: ceil
    __ceil__ = __arithmop__('$ceil')
    #: floor
    __floor__ = __arithmop__('$floor')
    #: truncate
    __trunc__ = __arithmop__('$trunc')
    #: absolute
    __abs__ = __arithmop__('$abs')
    #: square root
    sqrt = __arithmop__('$sqrt')
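
# A minimal sketch, assuming numeric columns (hypothetical data):
#
#   mdf.apply(lambda v: v * 5)       # per column: {'$multiply': ['$col', 5]}
#   mdf.apply(lambda v: v['x'] + 2)  # operates on column x only
#
# string operands are treated as column references, e.g. v['x'] + 'y'
# translates to {'$add': ['$x', '$y']}.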


class ApplyDateTime(object):
    """
    Datetime operators for ApplyContext
    """

    @property
    def dt(self):
        return self

    def __dtop__(op):
        """
        return a datetime $project operator as
           { col:
              { '$operator': '$col' }
              ...
           }
        """

        def inner(self, columns=None):
            columns = make_tuple(columns or self.columns)
            mapping = {
                col: {
                    op: '$' + col,
                }
                for col in columns}
            self.project(mapping)
            return self

        inner.__doc__ = op.replace('$', '')
        return inner

    # mongodb mappings
    _year = __dtop__('$year')
    _month = __dtop__('$month')
    _week = __dtop__('$week')
    _dayOfWeek = __dtop__('$dayOfWeek')
    _dayOfMonth = __dtop__('$dayOfMonth')
    _dayOfYear = __dtop__('$dayOfYear')
    _hour = __dtop__('$hour')
    _minute = __dtop__('$minute')
    _second = __dtop__('$second')
    _millisecond = __dtop__('$millisecond')
    _isoDayOfWeek = __dtop__('$isoDayOfWeek')
    _isoWeek = __dtop__('$isoWeek')
    _isoWeekYear = __dtop__('$isoWeekYear')

    # .dt accessor convenience similar to pandas.dt
    # see https://pandas.pydata.org/pandas-docs/stable/api.html#datetimelike-properties
    year = property(_year)
    month = property(_month)
    day = property(_dayOfMonth)
    hour = property(_hour)
    minute = property(_minute)
    second = property(_second)
    millisecond = property(_millisecond)
    week = property(_isoWeek)
    dayofyear = property(_dayOfYear)
    dayofweek = property(_dayOfWeek)
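
# A minimal sketch, assuming a datetime column 'when' (hypothetical data):
#
#   mdf.apply(lambda v: v['when'].dt.week)  # {'$project': {'when': {'$isoWeek': '$when'}}}
#   mdf.apply(lambda v: v['when'].dt.year)
#
# each .dt property projects the corresponding MongoDB date operator on the
# selected columns.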


class ApplyString(object):
    """
    String operators for ApplyContext
    """

    @property
    def str(self):
        return self

    def __strexpr__(op, unwind=False, base=None, max_terms=None):
        """
        return a pipeline $project string operator as
           { col:
              { '$operator': [ values, ...] }
              ...
           }
        """

        def inner(self, other, *args):
            # get all values passed and build terms from them
            values = list(make_tuple(other) + args)
            terms = []
            for term in values:
                if isinstance(term, str):
                    # if the term is a column name, reference it as a column
                    if term in self.columns:
                        term = '$' + term
                    # allow specifying literal values explicitly as $$<value> => <value>
                    term = term.replace('$$', '')
                terms.append(term)
            # limit number of terms if requested
            if max_terms:
                terms = terms[:max_terms]
            # add projection of output columns to operator
            mapping = {
                col: {
                    op: terms if base is None else ['$' + col] + terms,
                } for col in self.columns}
            self.project(mapping)
            # unwind all columns if requested
            if unwind:
                exprs = [{'$unwind': {
                    'path': '$' + col
                }} for col in self.columns]
                self.stages.extend(exprs)
            return self

        inner.__doc__ = op.replace('$', '')
        return inner

    def __strunary__(op, unwind=False):
        """
        return a string $project operator as
           { col:
              { '$operator': '$col' }
              ...
           }
        """

        def inner(self, columns=None):
            columns = make_tuple(columns or self.columns)
            mapping = {
                col: {
                    op: '$' + col,
                }
                for col in columns}
            self.project(mapping)
            if unwind:
                # unwind each column into one document per array element
                self.stages.extend({
                    '$unwind': {
                        'path': '$' + col
                    }
                } for col in columns)
            return self

        inner.__doc__ = op.replace('$', '')
        return inner

    def isequal(self, other):
        self.strcasecmp(other)
        # strcasecmp returns 0 for equality, 1 and -1 for greater/less than
        # https://docs.mongodb.com/manual/reference/operator/aggregation/strcasecmp/
        mapping = {
            col: {
                '$cond': {
                    'if': {'$eq': ['$' + col, 0]},
                    'then': True,
                    'else': False,
                }
            }
            for col in self.columns}
        self.project(mapping)
        return self

    concat = __strexpr__('$concat', base=True)
    split = __strexpr__('$split', unwind=True, base=True, max_terms=2)
    usplit = __strexpr__('$split', unwind=False, base=True, max_terms=2)
    upper = __strunary__('$toUpper')
    lower = __strunary__('$toLower')
    substr = __strexpr__('$substr', base=True)
    strcasecmp = __strexpr__('$strcasecmp', base=True)
    len = __strunary__('$strLenBytes')
    index = __strexpr__('$indexOfBytes', base=True)
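
# A minimal sketch, assuming a string column 'name' (hypothetical data):
#
#   mdf.apply(lambda v: v['name'].str.upper())      # {'$toUpper': '$name'}
#   mdf.apply(lambda v: v['name'].str.concat('!'))  # {'$concat': ['$name', '!']}
#   mdf.apply(lambda v: v['name'].str.isequal('foo'))
#
# plain string arguments that match a column name are treated as column
# references; prefix with $$ to force a literal value.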


class ApplyAccumulators(object):
    def agg(self, map=None, **kwargs):
        """ add accumulators from a column => expression mapping or kwargs """
        stage = self._getGroupBy(by='$$last')
        specs = map or kwargs
        for col, colExpr in specs.items():
            if isinstance(colExpr, dict):
                # specify an arbitrary expression
                groupby = stage['$group']
                groupby[col] = colExpr
            elif isinstance(colExpr, str):
                # specify some known operator
                if hasattr(self, colExpr):
                    method = getattr(self, colExpr)
                    method(col)
                else:
                    raise SyntaxError('{} is not known'.format(colExpr))
            elif isinstance(colExpr, (tuple, list)):
                # specify a list of known operators
                for statExpr in colExpr:
                    if hasattr(self, statExpr):
                        method = getattr(self, statExpr)
                        method(col)
                    else:
                        raise SyntaxError('{} is not known'.format(statExpr))
            elif callable(colExpr):
                # specify a callable that returns an expression
                groupby = stage['$group']
                groupby[col] = colExpr(col)
            else:
                raise SyntaxError('{} on column {} is unknown or invalid'.format(colExpr, col))
        return self

    def __statop__(op, opname=None):
        opname = opname or op.replace('$', '')

        def inner(self, columns=None):
            columns = make_tuple(columns or self.columns)
            stage = self._getGroupBy(by='$$last')
            groupby = stage['$group']
            groupby.update({
                '{}_{}'.format(col, opname): {
                    op: '$' + col
                } for col in columns
            })
            self.computed.extend(groupby.keys())
            self.project_keeper_columns()
            return self

        return inner

    sum = __statop__('$sum')
    avg = __statop__('$avg')
    mean = __statop__('$avg')
    min = __statop__('$min')
    max = __statop__('$max')
    std = __statop__('$stdDevSamp', 'std')
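
# A minimal sketch, assuming numeric columns 'x' and 'y' (hypothetical data):
#
#   mdf.apply(lambda v: v.agg(dict(x='sum', y=['min', 'max'])))
#
# yields columns x_sum, y_min and y_max; a dict value is passed through as a
# raw $group expression, a callable receives the column name and must return
# an expression.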


class ApplyCache(object):
    """
    A Cache that works on collections and pipelines
    """
    def __init__(self, db_alias):
        self._db_alias = db_alias

    def set(self, key, value):
        # https://stackoverflow.com/a/22003440/890242
        QueryCache = make_QueryCache(self._db_alias)
        QueryCache.objects(key=key).update_one(set__key="{}".format(key),
                                               set__value=value, upsert=True)

    def get(self, key):
        QueryCache = make_QueryCache(self._db_alias)
        try:
            result = QueryCache.objects.get(key=key)
        except QueryCache.DoesNotExist:
            result = None
        return result
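
# A minimal usage sketch (hypothetical values):
#
#   cache = ApplyCache('omega')
#   cache.set('somekey', {'collection': 'sales', 'result': 'cache_ab12'})
#   entry = cache.get('somekey')   # a QueryCache document, or None
#   entry.value['result']          # => 'cache_ab12'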


class ApplyStatistics(object):
    def quantile(self, q=.5):
        def preparefn(val):
            return val.pivot(columns='var', index='percentile', values='value')
        return self.apply(self._percentile(q), preparefn=preparefn)

    def cov(self):
        def preparefn(val):
            val = val.pivot(columns='y', index='x', values='cov')
            val.index.name = None
            val.columns.name = None
            return val
        return self.apply(self._covariance, preparefn=preparefn)

    def corr(self):
        def preparefn(val):
            val = val.pivot(columns='y', index='x', values='rho')
            val.index.name = None
            val.columns.name = None
            return val
        return self.apply(self._pearson, preparefn=preparefn)
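
    # A minimal usage sketch, assuming this mixin is applied to an MDataFrame
    # with numeric columns (hypothetical data):
    #
    #   mdf.quantile([.25, .5, .75]).value  # DataFrame of percentiles by column
    #   mdf.cov().value                     # pairwise covariance matrix
    #   mdf.corr().value                    # pairwise Pearson correlation matrix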

    def _covariance(self, ctx):
        # pairwise sample covariance, one $facet pipeline per column pair
        # source http://ci.columbia.edu/ci/premba_test/c0331/s7/s7_5.html
        facets = {}
        means = {}
        unwinds = []
        # sample covariance divides by n - 1
        count = len(ctx.caller.noapply()) - 1
        for x, y in product(ctx.columns, ctx.columns):
            xcol = '$' + x
            ycol = '$' + y
            # only calculate each column's mean once
            if xcol not in means:
                means[xcol] = ctx.caller[x].noapply().mean().values[0, 0]
            if ycol not in means:
                means[ycol] = ctx.caller[y].noapply().mean().values[0, 0]
            sumands = {
                xcol: {
                    '$subtract': [xcol, means[xcol]]
                },
                ycol: {
                    '$subtract': [ycol, means[ycol]]
                }
            }
            multiply = {
                '$multiply': [sumands[xcol], sumands[ycol]]
            }
            agg = {
                '$group': {
                    '_id': None,
                    'value': {
                        '$sum': multiply
                    }
                }
            }
            project = {
                '$project': {
                    'cov': {
                        '$divide': ['$value', count],
                    },
                    'x': x,
                    'y': y,
                }
            }
            pipeline = [agg, project]
            outcol = '{}_{}'.format(x, y)
            facets[outcol] = pipeline
            unwinds.append({'$unwind': '$' + outcol})
        facet = {
            '$facet': facets,
        }
        # expand the single $facet document into one document per column pair
        expand = [{
            '$project': {
                'value': {
                    '$objectToArray': '$$CURRENT',
                }
            }
        }, {
            '$unwind': '$value'
        }, {
            '$replaceRoot': {
                'newRoot': '$value.v'
            }
        }]
        return [facet, *unwinds, *expand]

    def _pearson(self, ctx):
        # pairwise Pearson correlation, one $facet pipeline per column pair
        # source http://ilearnasigoalong.blogspot.ch/2017/10/calculating-correlation-inside-mongodb.html
        facets = {}
        unwinds = []
        for x, y in product(ctx.columns, ctx.columns):
            xcol = '$' + x
            ycol = '$' + y
            sumcolumns = {'$group': {'_id': None,
                                     'count': {'$sum': 1},
                                     'sumx': {'$sum': xcol},
                                     'sumy': {'$sum': ycol},
                                     'sumxsquared': {'$sum': {'$multiply': [xcol, xcol]}},
                                     'sumysquared': {'$sum': {'$multiply': [ycol, ycol]}},
                                     'sumxy': {'$sum': {'$multiply': [xcol, ycol]}}
                                     }}

            multiply_sumx_sumy = {'$multiply': ["$sumx", "$sumy"]}
            multiply_sumxy_count = {'$multiply': ["$sumxy", "$count"]}
            partone = {'$subtract': [multiply_sumxy_count, multiply_sumx_sumy]}

            multiply_sumxsquared_count = {'$multiply': ["$sumxsquared", "$count"]}
            sumx_squared = {'$multiply': ["$sumx", "$sumx"]}
            subparttwo = {'$subtract': [multiply_sumxsquared_count, sumx_squared]}

            multiply_sumysquared_count = {'$multiply': ["$sumysquared", "$count"]}
            sumy_squared = {'$multiply': ["$sumy", "$sumy"]}
            subpartthree = {'$subtract': [multiply_sumysquared_count, sumy_squared]}

            parttwo = {'$sqrt': {'$multiply': [subparttwo, subpartthree]}}

            rho = {'$project': {
                'rho': {
                    '$divide': [partone, parttwo]
                },
                'x': x,
                'y': y
            }}
            pipeline = [sumcolumns, rho]
            outcol = '{}_{}'.format(x, y)
            facets[outcol] = pipeline
            unwinds.append({'$unwind': '$' + outcol})
        facet = {
            '$facet': facets,
        }
        # expand the single $facet document into one document per column pair
        expand = [{
            '$project': {
                'value': {
                    '$objectToArray': '$$CURRENT',
                }
            }
        }, {
            '$unwind': '$value'
        }, {
            '$replaceRoot': {
                'newRoot': '$value.v'
            }
        }]
        return [facet, *unwinds, *expand]

    def _percentile(self, pctls=None):
        """
        calculate percentiles for all columns
        """
        pctls = pctls or [.25, .5, .75]
        if not isinstance(pctls, (list, tuple)):
            pctls = [pctls]

        def calc(col, p, outcol):
            # sort values
            sort = {
                '$sort': {
                    col: 1,
                }
            }
            # group/push to get an array of all values
            group = {
                '$group': {
                    '_id': col,
                    'values': {
                        '$push': "$" + col
                    },
                }
            }
            # find the value at the requested percentile
            perc = {
                '$arrayElemAt': [
                    '$values', {
                        '$floor': {
                            '$multiply': [{
                                '$size': '$values'
                            }, p]
                        }
                    }
                ]
            }
            # map percentile value to output column
            project = {
                '$project': {
                    'var': col,
                    'percentile': 'p{}'.format(p),
                    'value': perc,
                }
            }
            return [sort, group, project]

        def inner(ctx):
            # for each column and requested percentile, build a pipeline
            # all pipelines will be combined into a $facet stage to
            # calculate every column/percentile tuple in parallel
            facets = {}
            unwind = []
            # for each column build a pipeline to calculate the percentiles
            for col in ctx.columns:
                for p in pctls:
                    # e.g. outcol for pctl .25 of column abc => abc_p25
                    outcol = '{}_p{}'.format(col, p).replace('0.', '')
                    facets[outcol] = calc(col, p, outcol)
                    unwind.append({'$unwind': '$' + outcol})
            # process per-column pipelines in parallel, resulting in one
            # document for each variable + percentile combination
            facet = {
                '$facet': facets
            }
            # expand single document into one document per variable + percentile combo
            # the resulting set of documents contains var/percentile/value
            expand = [{
                '$project': {
                    'value': {
                        '$objectToArray': '$$CURRENT',
                    }
                }
            }, {
                '$unwind': '$value'
            }, {
                '$replaceRoot': {
                    'newRoot': '$value.v'
                }
            }]
            pipeline = [facet, *unwind, *expand]
            return pipeline

        return inner
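
    # A sketch of the resulting documents, assuming columns 'x' and 'y' with
    # q=[.25, .5] (hypothetical data): the pipeline yields one document per
    # column/percentile pair, e.g.
    #
    #   {'var': 'x', 'percentile': 'p0.25', 'value': 1.0}
    #   {'var': 'x', 'percentile': 'p0.5', 'value': 2.0}
    #   ...
    #
    # which quantile()'s preparefn pivots into a percentile x column DataFrame.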