1from __future__ import absolute_import
  2
  3import logging
  4import sys
  5import warnings
  6from uuid import uuid4
  7
  8from omegaml.util import is_dataframe, is_ndarray, is_series
  9
 10logger = logging.getLogger(__file__)
 11
 12
[docs]
 13class ModelMixin(object):
 14    """ mixin methods to OmegaModelProxy
 15    """
 16
[docs]
 17    def fit(self, Xname, Yname=None, **kwargs):
 18        """
 19        fit the model
 20
 21        Calls :code:`.fit(X, Y, **kwargs)`. If instead of dataset names actual data
 22        is given, the data is stored using _fitX/fitY prefixes and a unique
 23        name.
 24
 25        After fitting, a new model version is stored with its attributes
 26        fitX and fitY pointing to the datasets, as well as the sklearn
 27        version used.
 28
 29        :param Xname: name of X dataset or data
 30        :param Yname: name of Y dataset or data
 31        :return: the model (self) or the string representation (python clients)
 32        """
 33        omega_fit = self.task('omegaml.tasks.omega_fit')
 34        Xname = self._ensure_data_is_stored(Xname, prefix='_fitX')
 35        if Yname is not None:
 36            Yname = self._ensure_data_is_stored(Yname, prefix='_fitY')
 37        return omega_fit.delay(self.modelname, Xname, Yname=Yname, **kwargs) 
 38
[docs]
 39    def partial_fit(self, Xname, Yname=None, **kwargs):
 40        """
 41        update the model
 42
 43        Calls :code:`.partial_fit(X, Y, **kwargs)`. If instead of dataset names actual
 44        data  is given, the data is stored using _fitX/fitY prefixes and
 45        a unique name.
 46
 47        After fitting, a new model version is stored with its attributes
 48        fitX and fitY pointing to the datasets, as well as the sklearn
 49        version used.
 50
 51        :param Xname: name of X dataset or data
 52        :param Yname: name of Y dataset or data
 53        :return: the model (self) or the string representation (python clients)
 54        """
 55        omega_fit = self.task('omegaml.tasks.omega_partial_fit')
 56        Xname = self._ensure_data_is_stored(Xname, prefix='_fitX')
 57        if Yname is not None:
 58            Yname = self._ensure_data_is_stored(Yname, prefix='_fitY')
 59        return omega_fit.delay(self.modelname, Xname, Yname=Yname, **kwargs) 
 60
 77
 98
[docs]
 99    def predict(self, Xpath_or_data, rName=None, **kwargs):
100        """
101        predict
102
103        Calls :code:`.predict(X)`. If rName is given the result is
104        stored as object rName
105
106        :param Xname: name of the X dataset
107        :param rName: name of the resulting dataset (optional)
108        :return: the data returned by .predict, or the metadata of the rName
109            dataset if rName was given
110        """
111        omega_predict = self.task('omegaml.tasks.omega_predict')
112        Xname = self._ensure_data_is_stored(Xpath_or_data)
113        return omega_predict.delay(self.modelname, Xname, rName=rName, **kwargs) 
114
[docs]
115    def predict_proba(self, Xpath_or_data, rName=None, **kwargs):
116        """
117        predict probabilities
118
119        Calls :code:`.predict_proba(X)`. If rName is given the result is
120        stored as object rName
121
122        :param Xname: name of the X dataset
123        :param rName: name of the resulting dataset (optional)
124        :return: the data returned by .predict_proba, or the metadata of the rName
125           dataset if rName was given
126        """
127        omega_predict_proba = self.task(
128            'omegaml.tasks.omega_predict_proba')
129        Xname = self._ensure_data_is_stored(Xpath_or_data)
130        return omega_predict_proba.delay(self.modelname, Xname, rName=rName, **kwargs) 
131
[docs]
132    def complete(self, Xname, rName=None, **kwargs):
133        """
134        complete
135
136        Calls :code:`.complete(X)`. If rName is given the result is
137        stored as object rName
138
139        :param Xname: name of the X dataset
140        :param rName: name of the resulting dataset (optional)
141        :return: the data returned by .complete, or the metadata of the rName
142            dataset if rName was given
143        """
144        omega_complete = self.task('omegaml.tasks.omega_complete')
145        Xname = self._ensure_data_is_stored(Xname)
146        return omega_complete.delay(self.modelname, Xname, rName=rName, **kwargs) 
147
[docs]
148    def embed(self, Xname, rName=None, **kwargs):
149        """
150        embed
151
152        Calls :code:`.embed(X)`. If rName is given the result is
153        stored as object rName
154
155        :param Xname: name of the X dataset
156        :param rName: name of the resulting dataset (optional)
157        :return: the data returned by .embed, or the metadata of the rName
158            dataset if rName was given
159        """
160        omega_embed = self.task('omegaml.tasks.omega_embed')
161        Xname = self._ensure_data_is_stored(Xname)
162        return omega_embed.delay(self.modelname, Xname, rName=rName, **kwargs) 
163
[docs]
164    def score(self, Xname, Yname=None, rName=None, **kwargs):
165        """
166        calculate score
167
168        Calls :code:`.score(X, y, **kwargs)`. If rName is given the result is
169        stored as object rName
170
171        :param Xname: name of the X dataset
172        :param yName: name of the y dataset
173        :param rName: name of the resulting dataset (optional)
174        :return: the data returned by .score, or the metadata of the rName
175           dataset if rName was given
176        """
177        omega_score = self.task('omegaml.tasks.omega_score')
178        Xname = self._ensure_data_is_stored(Xname)
179        YName = self._ensure_data_is_stored(Yname)
180        return omega_score.delay(self.modelname, Xname, Yname=YName, rName=rName, **kwargs) 
181
[docs]
182    def decision_function(self, Xname, rName=None, **kwargs):
183        """
184        calculate score
185
186        Calls :code:`.decision_function(X, y, **kwargs)`. If rName is given the result is
187        stored as object rName
188
189        :param Xname: name of the X dataset
190        :param rName: name of the resulting dataset (optional)
191        :return: the data returned by .score, or the metadata of the rName
192           dataset if rName was given
193        """
194        omega_decision_function = self.task('omegaml.tasks.omega_decision_function')
195        Xname = self._ensure_data_is_stored(Xname)
196        return omega_decision_function.delay(self.modelname, Xname, rName=rName, **kwargs) 
197
198    def reduce(self, rName=None, **kwargs):
199        omega_reduce = self.task('omegaml.tasks.omega_reduce')
200        return omega_reduce.delay(modelName=self.modelname, rName=rName, **kwargs)
201
202    def _ensure_data_is_stored(self, name_or_data, prefix='_temp', as_payload=False):
203        from omegaml.mixins.store.passthrough import PassthroughDataset
204
205        if as_payload:
206            return PassthroughDataset(name_or_data)
207        elif isinstance(name_or_data, str):
208            name = name_or_data
209        elif isinstance(name_or_data, PassthroughDataset):
210            return name_or_data
211        elif isinstance(name_or_data, (list, tuple, dict)):
212            if sys.getsizeof(name_or_data) <= PassthroughDataset.MAX_SIZE:
213                return PassthroughDataset(name_or_data)
214            else:
215                warnings.warn(
216                    f'size of dataset is larger than {PassthroughDataset.MAX_SIZE} bytes, storing in om.datasets')
217                name = '%s_%s' % (prefix, uuid4().hex)
218                self.runtime.omega.datasets.put(name_or_data, name)
219        elif is_dataframe(name_or_data) or is_series(name_or_data):
220            name = '%s_%s' % (prefix, uuid4().hex)
221            self.runtime.omega.datasets.put(name_or_data, name)
222        elif is_ndarray(name_or_data):
223            name = '%s_%s' % (prefix, uuid4().hex)
224            self.runtime.omega.datasets.put(name_or_data, name)
225        else:
226            raise TypeError(
227                'invalid type for Xpath_or_data', type(name_or_data))
228        return name