from __future__ import absolute_import

import numpy as np
import pandas as pd
from bson import Code
from numpy import isscalar
from pymongo.collection import Collection
from uuid import uuid4

from omegaml.store import qops
from omegaml.store.filtered import FilteredCollection
from omegaml.store.query import Filter, MongoQ
from omegaml.store.queryops import MongoQueryOps
from omegaml.util import make_tuple, make_list, restore_index, \
    cursor_to_dataframe, restore_index_columns_order, PickableCollection, extend_instance, json_normalize, ensure_index

INSPECT_CACHE = []

class MGrouper(object):
    """
    a Grouper for MDataFrames
    """
    STATS_MAP = {
        'std': 'stdDevSamp',
        'mean': 'avg',
    }

    def __init__(self, mdataframe, collection, columns, sort=True):
        self.mdataframe = mdataframe
        self.collection = collection
        self.columns = make_tuple(columns)
        self.should_sort = sort

    def __getattr__(self, attr):
        if attr in self.columns:
            return MSeriesGroupby(self, self.collection, attr)

        def statfunc():
            columns = self.columns or self._non_group_columns()
            return self.agg({col: attr for col in columns})

        return statfunc

    def __getitem__(self, item):
        return self.__getattr__(item)

    def agg(self, specs):
        """
        shortcut for .aggregate
        """
        return self.aggregate(specs)

    def aggregate(self, specs, **kwargs):
        """
        aggregate by given specs

        See the following link for a list of supported operations.
        https://docs.mongodb.com/manual/reference/operator/aggregation/group/

        :param specs: a dictionary of { column : function | list[functions] }
           pairs.
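
        Example (illustrative; assumes the collection has 'price' and 'qty'
        columns and this grouper was created by .groupby('cat'))::

            mdf.groupby('cat').aggregate({'price': 'mean', 'qty': ['sum', 'std']})
            # => DataFrame with columns price_mean, qty_sum, qty_std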
        """

        def add_stats(specs, column, stat):
            specs['%s_%s' % (column, stat)] = {
                '$%s' % MGrouper.STATS_MAP.get(stat, stat): '$%s' % column}

        # generate $group command
        _specs = {}
        for column, stats in specs.items():
            stats = make_tuple(stats)
            for stat in stats:
                add_stats(_specs, column, stat)
        groupby = qops.GROUP(columns=self.columns,
                             **_specs)
        # execute and return a dataframe
        pipeline = self._amend_pipeline([groupby])
        data = self.collection.aggregate(pipeline, allowDiskUse=True)

        def get_data():
            # we need this to build a pipeline for from_records
            # to process, otherwise the cursor will be exhausted already
            for group in data:
                _id = group.pop('_id')
                if isinstance(_id, dict):
                    group.update(_id)
                yield group

        df = pd.DataFrame.from_records(get_data())
        columns = make_list(self.columns)
        if columns:
            df = df.set_index(columns, drop=True)
        return df

    def _amend_pipeline(self, pipeline):
        """ amend pipeline with default ops on coll.aggregate() calls """
        if self.should_sort:
            sort = qops.SORT(**dict(qops.make_sortkey('_id')))
            pipeline.append(sort)
        return pipeline

    def _non_group_columns(self):
        """ return all columns of the mdataframe that are not group columns """
        return [col for col in self.mdataframe.columns
                if col not in self.columns and col != '_id'
                and not col.startswith('_idx')
                and not col.startswith('_om#')]

    def _count(self):
        count_columns = self._non_group_columns()
        if len(count_columns) == 0:
            count_columns.append('_'.join(self.columns) + '_count')
        groupby = {
            "$group": {
                "_id": {k: "$%s" % k for k in self.columns},
            }
        }
        # add a $sum:1 counter for each count column
        for k in count_columns:
            groupby['$group']['%s' % k] = {"$sum": 1}
        pipeline = self._amend_pipeline([groupby])
        if self.should_sort:
            sort = qops.SORT(**dict(qops.make_sortkey('_id')))
            pipeline.append(sort)
        return list(self.collection.aggregate(pipeline, allowDiskUse=True))

    def count(self):
        """ return counts by group columns """
        counts = self._count()
        # remove mongo object _id
        for group in counts:
            group.update(group.pop('_id'))
        # transform results to dataframe, then return as pandas would
        resultdf = pd.DataFrame(counts).set_index(make_list(self.columns),
                                                  drop=True)
        return resultdf

    def __iter__(self):
        """ for each group yields the group key and the filtered MDataFrame """
        # reduce count to only one column
        groups = getattr(self, self.columns[0])._count()
        for group in groups:
            keys = group.get('_id')
            data = self.mdataframe._clone(query=keys)
            yield keys, data


class MLocIndexer(object):
    """
    implements the LocIndexer for MDataFrames
    """

    def __init__(self, mdataframe, positional=False):
        self.mdataframe = mdataframe
        # if positional, any loc[spec] will be applied on the rowid only
        self.positional = positional
        # indicator will be set true if loc specs are from a range type (list, tuple, np.ndarray)
        self._from_range = False

    def __getitem__(self, specs):
        """
        access by index

        use as mdf.loc[specs] where specs is any of

        * a list or tuple of scalar index values, e.g. .loc[(1,2,3)]
        * a slice of values, e.g. .loc[1:5]
        * a list of slices, e.g. .loc[1:5, 2:3]

        :return: the sliced part of the MDataFrame
        """
        filterq, projection = self._get_filter(specs)
        df = self.mdataframe
        if filterq:
            df = self.mdataframe.query(filterq)
            df.from_loc_indexer = True
            df.from_loc_range = self._from_range
        if projection:
            df = df[projection]
        if isinstance(self.mdataframe, MSeries):
            df = df._as_mseries(df.columns[0])
        if getattr(df, 'immediate_loc', False):
            df = df.value
        return df

    def __setitem__(self, specs, value):
        raise NotImplementedError()

    def _get_filter(self, specs):
        filterq = []
        projection = None
        if self.positional:
            idx_cols = ['_om#rowid']
        else:
            idx_cols = self.mdataframe._get_frame_index()
        flt_kwargs = {}
        enumerable_types = (list, tuple, np.ndarray)
        if isinstance(specs, np.ndarray):
            specs = specs.tolist()
        if (isinstance(specs, enumerable_types)
                and isscalar(specs[0]) and len(idx_cols) == 1
                and not any(isinstance(s, slice) for s in specs)):
            # single column index with list of scalar values
            if (self.positional and isinstance(specs, tuple) and len(specs) == 2
                    and all(isscalar(v) for v in specs)):
                # iloc[int, int] is a cell access
                flt_kwargs[idx_cols[0]] = specs[0]
                projection = self._get_projection(specs[1])
            else:
                flt_kwargs['{}__in'.format(idx_cols[0])] = specs
                self._from_range = True
        elif isinstance(specs, (int, str)):
            flt_kwargs[idx_cols[0]] = specs
        else:
            specs = make_tuple(specs)
            # list/tuple of slices or scalar values, or MultiIndex
            for i, spec in enumerate(specs):
                if i < len(idx_cols):
                    col = idx_cols[i]
                    if isinstance(spec, slice):
                        self._from_range = True
                        start, stop = spec.start, spec.stop
                        if start is not None:
                            flt_kwargs['{}__gte'.format(col)] = start
                        if stop is not None:
                            if isinstance(stop, int):
                                # .loc slices are inclusive of stop, .iloc follows
                                # Python slicing (exclusive stop)
                                stop -= int(self.positional)
                            flt_kwargs['{}__lte'.format(col)] = stop
                    elif isinstance(spec, enumerable_types) and isscalar(spec[0]):
                        self._from_range = True
                        # single column index with list of scalar values
                        # -- convert to list for PyMongo serialization
                        if isinstance(spec, np.ndarray):
                            spec = spec.tolist()
                        flt_kwargs['{}__in'.format(col)] = spec
                    elif isscalar(col):
                        flt_kwargs[col] = spec
                else:
                    # we're out of index columns, let's look at columns
                    cur_proj = self._get_projection(spec)
                    # if we get a single column, that's the projection
                    # if we had a single column, now need to add more, create a list
                    # if we have a list already, extend that
                    if projection is None:
                        projection = cur_proj
                    elif isinstance(projection, list):
                        projection.extend(cur_proj)
                    else:
                        projection = [projection, cur_proj]
        if flt_kwargs:
            filterq.append(MongoQ(**flt_kwargs))
        finalq = None
        for q in filterq:
            if finalq:
                finalq |= q
            else:
                finalq = q
        return finalq, projection

    def _get_projection(self, spec):
        columns = self.mdataframe.columns
        if np.isscalar(spec):
            return [spec]
        if isinstance(spec, (tuple, list)):
            # ensure all requested columns actually exist
            assert all(col in columns for col in spec)
            return spec
        if isinstance(spec, slice):
            start, stop = spec.start, spec.stop
            if all(isinstance(v, int) for v in (start, stop)):
                start, stop, step = spec.indices(len(columns))
            else:
                start = columns.index(start) if start is not None else 0
                stop = columns.index(stop) + 1 if stop is not None else len(columns)
            return columns[slice(start, stop)]
        raise IndexError


class MPosIndexer(MLocIndexer):
    """
    implements the position-based indexer for MDataFrames
    """

    def __init__(self, mdataframe):
        super(MPosIndexer, self).__init__(mdataframe, positional=True)

    def _get_projection(self, spec):
        columns = self.mdataframe.columns
        if np.isscalar(spec):
            return columns[spec]
        if isinstance(spec, (tuple, list)):
            # select columns by their position
            return [col for i, col in enumerate(columns) if i in spec]
        if isinstance(spec, slice):
            start, stop = spec.start, spec.stop
            if start and not isinstance(start, int):
                start = 0
            if stop and not isinstance(stop, int):
                # sliced ranges are inclusive
                stop = len(columns)
            return columns[slice(start, stop)]
        raise IndexError


class MSeriesGroupby(MGrouper):
    """
    like a MGrouper but limited to one column
    """

    def count(self):
        """
        return series count

        :return: counts by group
        """
        # MGrouper will insert a _count column, see _count(). we remove
        # that column again and return a series named as the group column
        resultdf = super(MSeriesGroupby, self).count()
        count_column = [col for col in resultdf.columns
                        if col.endswith('_count')][0]
        new_column = count_column.replace('_count', '')
        resultdf = resultdf.rename(columns={count_column: new_column})
        return resultdf[new_column]


class MDataFrame(object):
    """
    A DataFrame for mongodb

    Performs out-of-core, lazy computation on a mongodb cluster.
    Behaves like a pandas DataFrame. Actual results are returned
    as pandas DataFrames.
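
    Example (illustrative; assumes ``coll`` is a pymongo Collection
    holding a dataframe previously stored by omegaml)::

        mdf = MDataFrame(coll)
        mdf['x'].mean()            # aggregate computed by mongodb
        mdf.query(x__gt=5).value   # resolve to a pandas DataFrame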
    """

    STATFUNCS = ['mean', 'std', 'min', 'max', 'sum', 'var']

    def __init__(self, collection, columns=None, query=None,
                 limit=None, skip=None, sort_order=None,
                 force_columns=None, immediate_loc=False, auto_inspect=False,
                 normalize=False, raw=False,
                 parser=None,
                 preparefn=None, from_loc_range=False, metadata=None, **kwargs):
        self.collection = PickableCollection(collection)
        # columns in frame
        self.columns = make_tuple(columns) if columns else self._get_fields(raw=raw)
        self.columns = [str(col) for col in self.columns]
        # columns to sort by, defaults to not sorted
        self.sort_order = sort_order
        # top n documents to fetch
        self.head_limit = limit
        # top n documents to skip before returning
        self.skip_topn = skip
        # filter criteria
        self.filter_criteria = query or {}
        # force columns -- on output add columns not present
        self.force_columns = force_columns or []
        # was this created from the loc indexer?
        self.from_loc_indexer = kwargs.get('from_loc_indexer', False)
        # was the loc index used a range? Else a single value
        self.from_loc_range = from_loc_range
        # setup query for filter criteria, if provided
        if self.filter_criteria:
            # make sure we have a filtered collection with the criteria given
            if isinstance(self.filter_criteria, dict):
                self.query_inplace(**self.filter_criteria)
            elif isinstance(self.filter_criteria, Filter):
                self.query_inplace(self.filter_criteria)
            else:
                raise ValueError('Invalid query specification of type {}'.format(type(self.filter_criteria)))
        # if immediate_loc is True, .loc and .iloc always evaluate
        self.immediate_loc = immediate_loc
        # __array__ returns this value if set, otherwise it is computed and cached
        self._evaluated = None
        # set true to automatically capture inspects on .value. retrieve using .inspect(cached=True)
        self.auto_inspect = auto_inspect
        self._inspect_cache = INSPECT_CACHE
        # apply mixins
        self._applyto = str(self.__class__)
        self._apply_mixins()
        # parser to parse documents to dataframe
        self._parser = json_normalize if normalize else parser
        # prepare function to be applied just before returning from .value
        self._preparefn = preparefn
        # keep technical fields like _id, _idx etc
        self._raw = raw
        # metadata stored by omegaml (equiv. of metadata.kind_meta)
        self.metadata = metadata or dict()

    def _apply_mixins(self, *args, **kwargs):
        """
        apply mixins in defaults.OMEGA_MDF_MIXINS
        """
        from omegaml import settings
        defaults = settings()
        for mixin, applyto in defaults.OMEGA_MDF_MIXINS:
            if any(v in self._applyto for v in applyto.split(',')):
                extend_instance(self, mixin, *args, **kwargs)

    def __getstate__(self):
        # pickle support. note that the hard work is done in PickableCollection
        data = dict(self.__dict__)
        data.update(_evaluated=None)
        data.update(_inspect_cache=None)
        data.update(auto_inspect=self.auto_inspect)
        data.update(_preparefn=self._preparefn)
        data.update(_parser=self._parser)
        data.update(_raw=self._raw)
        data.update(collection=self.collection)
        return data

    def __reduce__(self):
        state = self.__getstate__()
        args = self.collection,
        return _mdf_remake, args, state

    def __setstate__(self, state):
        # pickle support. note that the hard work is done in PickableCollection
        for k, v in state.items():
            setattr(self, k, v)

    def _getcopy_kwargs(self, without=None):
        """ return all parameters required on a copy of this MDataFrame """
        kwargs = dict(columns=self.columns,
                      sort_order=self.sort_order,
                      limit=self.head_limit,
                      skip=self.skip_topn,
                      from_loc_indexer=self.from_loc_indexer,
                      from_loc_range=self.from_loc_range,
                      immediate_loc=self.immediate_loc,
                      metadata=self.metadata,
                      query=self.filter_criteria,
                      auto_inspect=self.auto_inspect,
                      parser=self._parser,
                      preparefn=self._preparefn)
        [kwargs.pop(k) for k in make_tuple(without or [])]
        return kwargs

    def __array__(self, dtype=None):
        # FIXME inefficient. make MDataFrame a drop-in replacement for any numpy ndarray
        # this evaluates every single time
        if self._evaluated is None:
            self._evaluated = array = self.value.values
        else:
            array = self._evaluated
        return array

    def __getattr__(self, attr):
        if attr in MDataFrame.STATFUNCS:
            return self.statfunc(attr)
        if attr in self.columns:
            kwargs = self._getcopy_kwargs()
            kwargs.update(columns=attr)
            return MSeries(self.collection, **kwargs)
        raise AttributeError(attr)

    def __getitem__(self, cols_or_slice):
        """
        select and project by column, columns, slice, masked-style filter

        Masked-style filters work similar to pd.DataFrame/Series masks
        but do not actually return masks but an instance of Filter. A
        Filter is a delayed evaluation on the data frame.

            # select all rows where any column is == 5
            mdf = MDataFrame(coll)
            flt = mdf == 5
            mdf[flt]
            # => MDataFrame with the filter applied

        :param cols_or_slice: single column (str), multi-columns (list),
          slice to select columns, or a masked-style filter (Filter)
        :return: filtered MDataFrame or MSeries
        """
        if isinstance(cols_or_slice, str):
            # column name => MSeries
            return self._as_mseries(cols_or_slice)
        elif isinstance(cols_or_slice, int):
            # column number => MSeries
            column = self.columns[cols_or_slice]
            return self._as_mseries(column)
        elif isinstance(cols_or_slice, (tuple, list)):
            # list of column names => MDataFrame subset on columns
            kwargs = self._getcopy_kwargs()
            kwargs.update(columns=cols_or_slice)
            return MDataFrame(self.collection, **kwargs)
        elif isinstance(cols_or_slice, Filter):
            kwargs = self._getcopy_kwargs()
            kwargs.update(query=cols_or_slice.query)
            return MDataFrame(self.collection, **kwargs)
        elif isinstance(cols_or_slice, np.ndarray):
            return self.iloc[cols_or_slice]
        raise ValueError('unknown accessor type %s' % type(cols_or_slice))

    def __setitem__(self, column, value):
        # True for any scalar type, numeric, bool, string
        if np.isscalar(value):
            self.collection.update_many(filter=self.filter_criteria,
                                        update=qops.SET(column, value))
            self.columns.append(column)
        return self

    def _clone(self, collection=None, **kwargs):
        # convenience method to clone itself with updates
        collection = collection if collection is not None else self.collection
        return self.__class__(collection, **kwargs,
                              **self._getcopy_kwargs(without=list(kwargs.keys())))

    def statfunc(self, stat):
        aggr = MGrouper(self, self.collection, [], sort=False)
        return getattr(aggr, stat)

    def groupby(self, columns, sort=True):
        """
        Group by a given set of columns

        :param columns: the list of columns
        :param sort: if True sort by group key
        :return: MGrouper
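
        Example (illustrative; assumes columns 'cat' and 'price' exist)::

            mdf.groupby('cat').count()
            mdf.groupby('cat').agg({'price': 'mean'})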
        """
        return MGrouper(self, self.collection, columns, sort=sort)

    def _get_fields(self, raw=False):
        result = []
        doc = self.collection.find_one()
        if doc is not None:
            if raw:
                result = list(doc.keys())
            else:
                result = [str(col) for col in doc.keys()
                          if col != '_id'
                          and not col.startswith('_idx')
                          and not col.startswith('_om#')]
        return result

    def _get_frame_index(self):
        """ return the dataframe's index columns """
        doc = self.collection.find_one()
        if doc is None:
            result = []
        else:
            result = restore_index_columns_order(doc.keys())
        return result

    def _get_frame_om_fields(self):
        """ return the dataframe's omega special fields columns """
        doc = self.collection.find_one()
        if doc is None:
            result = []
        else:
            result = [k for k in list(doc.keys()) if k.startswith('_om#')]
        return result

    def _as_mseries(self, column):
        kwargs = self._getcopy_kwargs()
        kwargs.update(columns=make_tuple(column))
        return MSeries(self.collection, **kwargs)

    def inspect(self, explain=False, cached=False, cursor=None, raw=False):
        """
        inspect this dataframe's actual mongodb query

        :param explain: if True explains access path
        """
        if not cached:
            if isinstance(self.collection, FilteredCollection):
                query = self.collection.query
            else:
                query = '*'
            if explain:
                cursor = cursor or self._get_cursor()
                explain = cursor.explain()
            data = {
                'projection': self.columns,
                'query': query,
                'explain': explain or 'specify explain=True'
            }
        else:
            data = self._inspect_cache
        if not (raw or explain):
            data = pd.DataFrame(json_normalize(data))
        return data

    def count(self):
        """
        projected number of rows when resolving
        """
        nrows = len(self)
        counts = pd.Series({
            col: nrows
            for col in self.columns}, index=self.columns)
        return counts

    def __len__(self):
        """
        the projected number of rows when resolving
        """
        # we reduce to just 1 column to speed up the query
        short = self._clone()
        short = short[self.columns[0]] if self.columns else short
        return sum(1 for d in short._get_cursor())

    @property
    def shape(self):
        """
        return shape of dataframe
        """
        return len(self), len(self.columns)

    @property
    def ndim(self):
        return len(self.shape)

    @property
    def value(self):
        """
        resolve the query and return a Pandas DataFrame

        :return: the result of the query as a pandas DataFrame
        """
        cursor = self._get_cursor()
        df = self._get_dataframe_from_cursor(cursor)
        if self.auto_inspect:
            self._inspect_cache.append(self.inspect(explain=True, cursor=cursor, raw=True))
        # this ensures the equiv. of pandas df.loc[n] is a Series
        if self.from_loc_indexer:
            if len(df) == 1 and not self.from_loc_range:
                idx = df.index
                df = df.T
                df = df[df.columns[0]]
                if df.ndim == 1 and len(df) == 1 and not isinstance(idx, pd.MultiIndex):
                    # single row single dimension, numeric index only
                    df = df.iloc[0]
            elif (df.ndim == 1 or df.shape[1] == 1) and not self.from_loc_range:
                df = df[df.columns[0]]
        if self._preparefn:
            df = self._preparefn(df)
        return df

    def reset(self):
        # TODO if head(), tail(), query(), .loc/iloc should return a new MDataFrame instance to avoid having a reset need
        self.head_limit = None
        self.skip_topn = None
        self.filter_criteria = {}
        self.force_columns = []
        self.sort_order = None
        self.from_loc_indexer = False
        return self

    def _get_dataframe_from_cursor(self, cursor):
        """
        from the given cursor return a DataFrame
        """
        df = cursor_to_dataframe(cursor, parser=self._parser)
        df = self._restore_dataframe_proper(df)
        return df

    @property
    def _index_meta(self):
        return self.metadata.get('idx_meta') or dict()

    def _restore_dataframe_proper(self, df):
        df = restore_index(df, self._index_meta)
        if '_id' in df.columns and not self._raw:
            df.drop('_id', axis=1, inplace=True)
        if self.force_columns:
            missing = set(self.force_columns) - set(self.columns)
            for col in missing:
                df[col] = np.nan
        return df

    def _get_cursor(self):
        projection = make_tuple(self.columns)
        projection += make_tuple(self._get_frame_index())
        if not self.sort_order:
            # implicit sort
            projection += make_tuple(self._get_frame_om_fields())
        cursor = self.collection.find(projection=projection)
        if self.sort_order:
            cursor.sort(qops.make_sortkey(make_tuple(self.sort_order)))
        if self.head_limit:
            cursor.limit(self.head_limit)
        if self.skip_topn:
            cursor.skip(self.skip_topn)
        return cursor

    def sort(self, columns):
        """
        sort by specified columns

        :param columns: str of single column or a list of columns. Sort order
                        is specified as the + (ascending) or - (descending)
                        prefix to the column name. Default sort order is
                        ascending.
        :return: the MDataFrame
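
        Example (illustrative; assumes columns 'x' and 'y' exist)::

            mdf.sort('-x')           # descending by x
            mdf.sort(['x', '-y'])    # ascending x, then descending y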
        """
        self._evaluated = None
        self.sort_order = make_tuple(columns)
        return self

    def head(self, limit=10):
        """
        return up to limit number of rows

        :param limit: the number of rows to return. Defaults to 10
        :return: the MDataFrame
        """
        return self._clone(limit=limit)

    def tail(self, limit=10):
        """
        return up to limit number of rows from last inserted values

        :param limit: the number of rows to return. Defaults to 10
        :return: the MDataFrame
        """
        tail_n = max(len(self) - limit, 0)
        return self._clone(skip=tail_n)

    def skip(self, topn):
        """
        skip the topn number of rows

        :param topn: the number of rows to skip.
        :return: the MDataFrame
        """
        return self._clone(skip=topn)

    def merge(self, right, on=None, left_on=None, right_on=None,
              how='inner', target=None, suffixes=('_x', '_y'),
              sort=False, inspect=False, filter=None):
        """
        merge this dataframe with another dataframe. the output is saved
        as a new collection with the given target name (a generated name
        if not specified).

        :param right: the other MDataFrame
        :param on: the name of the key column to merge by (single column only)
        :param left_on: the name of the key column on this dataframe
        :param right_on: the name of the key column on the other dataframe
        :param how: the method to merge. supported are left, inner, right.
           Defaults to inner
        :param target: the name of the collection to store the merge results
           in. If not provided a temporary name will be created.
        :param suffixes: the suffixes to apply to identical left and right
           columns
        :param sort: if True the merge results will be sorted. If False the
           MongoDB natural order is implied.
        :returns: the MDataFrame for the target collection
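
        Example (illustrative; assumes both frames share a 'userid' column)::

            merged = left_mdf.merge(right_mdf, on='userid', how='left',
                                    target='merged_result')
            merged.value   # resolve to a pandas DataFrame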
        """
        # validate input
        supported_how = ["left", 'inner', 'right']
        assert how in supported_how, "only %s merges are currently supported" % supported_how
        for key in [on, left_on, right_on]:
            if key:
                assert isinstance(
                    key, str), "only single column merge keys are supported (%s)" % key
        if isinstance(right, (Collection, PickableCollection, FilteredCollection)):
            right = MDataFrame(right)
        assert isinstance(
            right, MDataFrame), "both must be MDataFrames, got right=%s" % type(right)
        if how == 'right':
            # A right B == B left A
            return right.merge(self, on=on, left_on=right_on, right_on=left_on,
                               how='left', target=target, suffixes=suffixes)
        # generate lookup parameters
        on = on or '_id'
        right_name = self._get_collection_name_of(right, right)
        target_name = self._get_collection_name_of(
            target, '_temp.merge.%s' % uuid4().hex)
        target_field = (
                "%s_%s" % (right_name.replace('.', '_'), right_on or on))
        """
        TODO enable filter criteria on right dataframe. requires changing LOOKUP syntax from
             equality to arbitrary match

        if right.filter_criteria:
            right_filter = [qops.MATCH(self._get_filter_criteria(**right.filter_criteria))]
        else:
            right_filter = None
        """
        right_filter = None
        lookup = qops.LOOKUP(right_name,
                             key=on,
                             left_key=left_on,
                             right_key=right_on,
                             target=target_field)
        # unwind merged documents from arrays to top-level document fields
        unwind = qops.UNWIND(target_field, preserve=how != 'inner')
        # get all fields from left, right
        project = {}
        for left_col in self.columns:
            source_left_col = left_col
            if left_col == '_id':
                project[left_col] = 1
                continue
            if left_col.startswith('_idx'):
                continue
            if left_col.startswith('_om#'):
                continue
            if left_col != (on or left_on) and left_col in right.columns:
                left_col = '%s%s' % (left_col, suffixes[0])
            project[left_col] = "$%s" % source_left_col
        for right_col in right.columns:
            if right_col == '_id':
                continue
            if right_col.startswith('_idx'):
                continue
            if right_col.startswith('_om#'):
                continue
            if right_col == (on or right_on) and right_col == (on or left_on):
                # if the merge field is the same in both frames, we already
                # have it from left
                continue
            if right_col in self.columns:
                # suffix identical columns from the right frame
                out_col = '%s%s' % (right_col, suffixes[1])
            else:
                out_col = '%s' % right_col
            project[out_col] = '$%s.%s' % (target_field, right_col)
        expected_columns = list(project.keys())
        if '_id' not in project:
            project['_id'] = 0  # never copy objectids to avoid duplicate keys, unless requested
        project = {"$project": project}
        # store merged documents and return an MDataFrame to it
        out = qops.OUT(target_name)
        pipeline = [lookup, unwind, project]
        if filter:
            query = qops.MATCH(self._get_filter_criteria(**filter))
            pipeline.append(query)
        if sort:
            sort_cols = make_list(on or [left_on, right_on])
            sort_key = qops.make_sortkey(sort_cols)
            sort = qops.SORT(**dict(sort_key))
            pipeline.append(sort)
        pipeline.append(out)
        if inspect:
            result = pipeline
        else:
            result = self.collection.aggregate(pipeline, allowDiskUse=True)
            result = MDataFrame(self.collection.database[target_name],
                                force_columns=expected_columns)
        return result

    def append(self, other):
        if isinstance(other, Collection):
            other = MDataFrame(other)
        assert isinstance(
            other, MDataFrame), "both must be MDataFrames, got other={}".format(type(other))
        outname = self.collection.name
        mrout = {
            'merge': outname,
            'nonAtomic': True,
        }
        mapfn = Code("""
        function() {
           this._id = ObjectId();
           if(this['_om#rowid']) {
              this['_om#rowid'] += %s;
           }
           emit(this._id, this);
        }
        """ % len(self))
        reducefn = Code("""
        function(key, value) {
           return value;
        }
        """)
        finfn = Code("""
        function(key, value) {
           return value;
        }
        """)
        other.collection.map_reduce(mapfn, reducefn, mrout, finalize=finfn, jsMode=True)
        unwind = {
            "$replaceRoot": {
                "newRoot": {
                    "$ifNull": ["$value", "$$CURRENT"],
                }
            }
        }
        output = qops.OUT(outname)
        pipeline = [unwind, output]
        self.collection.aggregate(pipeline, allowDiskUse=True)
        return self

    def _get_collection_name_of(self, some, default=None):
        """
        determine the collection name of the given parameter

        returns the collection name if some is a MDataFrame, a Collection
        or a string. Otherwise returns default
        """
        if isinstance(some, MDataFrame):
            name = some.collection.name
        elif isinstance(some, Collection):
            name = some.name
        else:
            name = default
        return name

    def _get_filter_criteria(self, *args, **kwargs):
        """
        return mongo query from filter specs

        this uses a Filter to produce the query from the kwargs.

        :param args: a Q object or logical combination of Q objects
           (optional)
        :param kwargs: all AND filter criteria
        """
        if len(args) > 0:
            q = args[0]
            if isinstance(q, MongoQ):
                filter_criteria = Filter(self.collection, q).query
            elif isinstance(q, Filter):
                filter_criteria = Filter(self.collection, q.q).query
        else:
            filter_criteria = Filter(self.collection, **kwargs).query
        return filter_criteria

    def query_inplace(self, *args, **kwargs):
        """
        filters this MDataFrame and returns it.

        Any subsequent operation on the dataframe will have the filter
        applied. To reset the filter call .reset() without arguments.

        :param args: a Q object or logical combination of Q objects
           (optional)
        :param kwargs: all AND filter criteria
        :return: self
        """
        self._evaluated = None
        self.filter_criteria = self._get_filter_criteria(*args, **kwargs)
        self.collection = FilteredCollection(
            self.collection, query=self.filter_criteria)
        return self

    def query(self, *args, **kwargs):
        """
        return a new MDataFrame with a filter criteria

        Any subsequent operation on the new dataframe will have the filter
        applied. To reset the filter call .reset() without arguments.

        Note: Unlike pandas DataFrames, a filtered MDataFrame operates
        on the same collection as the original DataFrame

        :param args: a Q object or logical combination of Q objects
           (optional)
        :param kwargs: all AND filter criteria
        :return: a new MDataFrame with the filter applied
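
        Example (illustrative; assumes columns 'x' and 'y' exist)::

            mdf.query(x__gt=5)                     # all rows where x > 5
            mdf.query(MongoQ(x=1) | MongoQ(y=2))   # x == 1 or y == 2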
        """
        effective_filter = dict(self.filter_criteria)
        filter_criteria = self._get_filter_criteria(*args, **kwargs)
        if '$and' in effective_filter:
            effective_filter['$and'].extend(filter_criteria.get('$and'))
        else:
            effective_filter.update(filter_criteria)
        coll = FilteredCollection(self.collection, query=effective_filter)
        return self._clone(collection=coll, query=effective_filter)

    def create_index(self, keys, **kwargs):
        """
        create an index the easy way
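
        Example (illustrative; assumes columns 'x' and 'y' exist, and that
        the '-' prefix means descending, as in .sort())::

            mdf.create_index('x')
            mdf.create_index(['x', '-y'])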
        """
        keys, kwargs = MongoQueryOps().make_index(keys)
        result = ensure_index(self.collection, keys, **kwargs)
        return result

    def list_indexes(self):
        """
        list all indexes of this collection
        """
        return cursor_to_dataframe(self.collection.list_indexes())

    def iterchunks(self, chunksize=100):
        """
        return an iterator of chunks

        Args:
            chunksize (int): number of rows in each chunk

        Yields:
            a dataframe of max. length chunksize
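
        Example (illustrative; process is a hypothetical callback)::

            for chunkdf in mdf.iterchunks(chunksize=500):
                process(chunkdf)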
        """
        chunksize = int(chunksize)
        i = 0
        while True:
            chunkdf = self.skip(i).head(chunksize).value
            if len(chunkdf) == 0:
                break
            i += chunksize
            yield chunkdf

    def itertuples(self, chunksize=1000):
        chunksize = int(chunksize)
        __doc__ = pd.DataFrame.itertuples.__doc__

        for chunkdf in self.iterchunks(chunksize=chunksize):
            for row in chunkdf.itertuples():
                yield row

    def iterrows(self, chunksize=1000):
        chunksize = int(chunksize)
        __doc__ = pd.DataFrame.iterrows.__doc__

        for chunkdf in self.iterchunks(chunksize=chunksize):
            if isinstance(chunkdf, pd.DataFrame):
                for row in chunkdf.iterrows():
                    yield row
            else:
                # Series does not have iterrows
                for i in range(0, len(chunkdf), chunksize):
                    yield chunkdf.iloc[i:i + chunksize]

    def iteritems(self):
        if not hasattr(pd.DataFrame, 'iteritems'):
            raise NotImplementedError('MDataFrame.iteritems has been removed since Pandas 2.0. Use .items instead.')
        __doc__ = pd.DataFrame.iteritems.__doc__
        return self.items()

    def items(self):
        __doc__ = pd.DataFrame.items.__doc__

        for col in self.columns:
            yield col, self[col].value

    @property
    def loc(self):
        """
        Access by index

        Use as mdf.loc[index_value]

        :return: MLocIndexer
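
        Example (illustrative; assumes an integer index)::

            mdf.loc[5]      # the row with index value 5
            mdf.loc[1:5]    # rows with index values 1 through 5 (inclusive)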
        """
        self._evaluated = None
        indexer = MLocIndexer(self)
        return indexer

    @property
    def iloc(self):
        self._evaluated = None
        indexer = MPosIndexer(self)
        return indexer

    def rows(self, start=None, end=None, chunksize=1000):
        # equivalent to .iloc[start:end].iterchunks()
        start = int(start) if start is not None else None
        end = int(end) if end is not None else None
        chunksize = int(chunksize)
        return self.iloc[slice(start, end)].iterchunks(chunksize)

    def __repr__(self):
        kwargs = ', '.join('{}={}'.format(k, v) for k, v in self._getcopy_kwargs().items())
        return "MDataFrame(collection={collection.name}, {kwargs})".format(collection=self.collection,
                                                                           kwargs=kwargs)


class MSeries(MDataFrame):
    """
    Series implementation for MDataFrames

    behaves like a DataFrame but limited to one column.
    """

    def __init__(self, *args, **kwargs):
        super(MSeries, self).__init__(*args, **kwargs)
        # true if only unique values apply
        self.is_unique = False
        # apply mixins
        self._applyto = str(self.__class__)
        self._apply_mixins(*args, **kwargs)

    def __getitem__(self, cols_or_slice):
        if isinstance(cols_or_slice, Filter):
            return MSeries(self.collection, columns=self.columns,
                           query=cols_or_slice.query)
        return super(MSeries, self).__getitem__(cols_or_slice)

    @property
    def name(self):
        return self.columns[0]

    def unique(self):
        """
        return the unique set of values for the series

        :return: MSeries
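
        Example (illustrative; assumes this series is mdf['x'])::

            mdf['x'].unique().value   # distinct values of x, as a list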
        """
        self.is_unique = True
        return self

    def _get_cursor(self):
        if self.is_unique:
            # this way indexes get applied
            cursor = self.collection.distinct(make_tuple(self.columns)[0])
        else:
            cursor = super(MSeries, self)._get_cursor()
        return cursor

    @property
    def value(self):
        """
        return the value of the series

        this is a Series unless unique() was called. If unique() was
        called, only distinct values are returned as a list, matching
        the behavior of a Series

        :return: pandas.Series
        """
        cursor = self._get_cursor()
        column = make_tuple(self.columns)[0]
        if self.is_unique:
            # the .distinct() cursor returns a list of values
            # this is to make sure we return the same thing as pandas
            val = [v for v in cursor]
        else:
            val = self._get_dataframe_from_cursor(cursor)
            val = val[column]
            val.name = self.name
            if len(val) == 1 and self.from_loc_indexer:
                val = val.iloc[0]
        if self.auto_inspect:
            self._inspect_cache.append(self.inspect(explain=True, cursor=cursor, raw=True))
        if self._preparefn:
            val = self._preparefn(val)
        return val

    def __repr__(self):
        kwargs = ', '.join('{}={}'.format(k, v) for k, v in self._getcopy_kwargs().items())
        return "MSeries(collection={collection.name}, {kwargs})".format(collection=self.collection,
                                                                        kwargs=kwargs)

    @property
    def shape(self):
        return len(self),


def _mdf_remake(collection):
    # recreate a pickled MDataFrame
    mdf = MDataFrame(collection)
    return mdf