1from __future__ import absolute_import
2
3import logging
4import sys
5import warnings
6from uuid import uuid4
7
8from omegaml.util import is_dataframe, is_ndarray, is_series
9
# Name the logger after the module's dotted import path rather than its file
# path, so it takes part in the package's logger hierarchy and can be
# configured through its parent loggers.
logger = logging.getLogger(__name__)
11
12
[docs]
class ModelMixin(object):
    """ mixin methods to OmegaModelProxy

    Every public method looks up a runtime task by name via ``self.task()``
    and submits it with ``.delay()``. The value returned by ``.delay()`` is
    handed back to the caller unchanged.

    Data arguments may be dataset names or in-memory data, see
    :meth:`_ensure_data_is_stored` for how in-memory data is handled.

    The class this is mixed into must provide ``self.task``,
    ``self.modelname`` and ``self.runtime``.
    """

    def fit(self, Xname, Yname=None, **kwargs):
        """
        fit the model

        Submits the ``omegaml.tasks.omega_fit`` task. If instead of dataset
        names actual data is given, the data is handled by
        :meth:`_ensure_data_is_stored` using the _fitX/_fitY prefixes.

        After fitting, a new model version is expected to be stored with its
        attributes fitX and fitY pointing to the datasets (this is done by
        the task, not by this method -- confirm against the task
        implementation).

        :param Xname: name of X dataset or data
        :param Yname: name of Y dataset or data (optional)
        :param kwargs: passed on to the task
        :return: the result of ``.delay()`` on the task
        """
        omega_fit = self.task('omegaml.tasks.omega_fit')
        Xname = self._ensure_data_is_stored(Xname, prefix='_fitX')
        if Yname is not None:
            Yname = self._ensure_data_is_stored(Yname, prefix='_fitY')
        return omega_fit.delay(self.modelname, Xname, Yname=Yname, **kwargs)

    def partial_fit(self, Xname, Yname=None, **kwargs):
        """
        update the model

        Submits the ``omegaml.tasks.omega_partial_fit`` task. If instead of
        dataset names actual data is given, the data is handled by
        :meth:`_ensure_data_is_stored` using the _fitX/_fitY prefixes.

        After fitting, a new model version is expected to be stored with its
        attributes fitX and fitY pointing to the datasets (this is done by
        the task, not by this method -- confirm against the task
        implementation).

        :param Xname: name of X dataset or data
        :param Yname: name of Y dataset or data (optional)
        :param kwargs: passed on to the task
        :return: the result of ``.delay()`` on the task
        """
        omega_partial_fit = self.task('omegaml.tasks.omega_partial_fit')
        Xname = self._ensure_data_is_stored(Xname, prefix='_fitX')
        if Yname is not None:
            Yname = self._ensure_data_is_stored(Yname, prefix='_fitY')
        return omega_partial_fit.delay(self.modelname, Xname, Yname=Yname, **kwargs)

    def predict(self, Xpath_or_data, rName=None, **kwargs):
        """
        predict

        Submits the ``omegaml.tasks.omega_predict`` task. rName is passed on
        to the task as the name under which the result should be stored.

        :param Xpath_or_data: name of the X dataset or data
        :param rName: name of the resulting dataset (optional)
        :param kwargs: passed on to the task
        :return: the result of ``.delay()`` on the task
        """
        omega_predict = self.task('omegaml.tasks.omega_predict')
        Xname = self._ensure_data_is_stored(Xpath_or_data)
        return omega_predict.delay(self.modelname, Xname, rName=rName, **kwargs)

    def predict_proba(self, Xpath_or_data, rName=None, **kwargs):
        """
        predict probabilities

        Submits the ``omegaml.tasks.omega_predict_proba`` task. rName is
        passed on to the task as the name under which the result should be
        stored.

        :param Xpath_or_data: name of the X dataset or data
        :param rName: name of the resulting dataset (optional)
        :param kwargs: passed on to the task
        :return: the result of ``.delay()`` on the task
        """
        omega_predict_proba = self.task(
            'omegaml.tasks.omega_predict_proba')
        Xname = self._ensure_data_is_stored(Xpath_or_data)
        return omega_predict_proba.delay(self.modelname, Xname, rName=rName, **kwargs)

    def complete(self, Xname, rName=None, **kwargs):
        """
        complete

        Submits the ``omegaml.tasks.omega_complete`` task. rName is passed
        on to the task as the name under which the result should be stored.

        :param Xname: name of the X dataset or data
        :param rName: name of the resulting dataset (optional)
        :param kwargs: passed on to the task
        :return: the result of ``.delay()`` on the task
        """
        omega_complete = self.task('omegaml.tasks.omega_complete')
        Xname = self._ensure_data_is_stored(Xname)
        return omega_complete.delay(self.modelname, Xname, rName=rName, **kwargs)

    def score(self, Xname, Yname=None, rName=None, **kwargs):
        """
        calculate score

        Submits the ``omegaml.tasks.omega_score`` task. rName is passed on
        to the task as the name under which the result should be stored.

        :param Xname: name of the X dataset or data
        :param Yname: name of the y dataset or data (optional). If omitted,
            None is passed on to the task as Yname
        :param rName: name of the resulting dataset (optional)
        :param kwargs: passed on to the task
        :return: the result of ``.delay()`` on the task
        """
        omega_score = self.task('omegaml.tasks.omega_score')
        Xname = self._ensure_data_is_stored(Xname)
        # Yname is optional: only resolve it when given, as fit() does.
        # Passing None to _ensure_data_is_stored raises TypeError.
        if Yname is not None:
            Yname = self._ensure_data_is_stored(Yname)
        return omega_score.delay(self.modelname, Xname, Yname=Yname, rName=rName, **kwargs)

    def decision_function(self, Xname, rName=None, **kwargs):
        """
        calculate the decision function

        Submits the ``omegaml.tasks.omega_decision_function`` task. rName is
        passed on to the task as the name under which the result should be
        stored.

        :param Xname: name of the X dataset or data
        :param rName: name of the resulting dataset (optional)
        :param kwargs: passed on to the task
        :return: the result of ``.delay()`` on the task
        """
        omega_decision_function = self.task('omegaml.tasks.omega_decision_function')
        Xname = self._ensure_data_is_stored(Xname)
        return omega_decision_function.delay(self.modelname, Xname, rName=rName, **kwargs)

    def reduce(self, rName=None, **kwargs):
        """
        reduce

        Submits the ``omegaml.tasks.omega_reduce`` task. Unlike the other
        methods no dataset is resolved here, the model name is passed as the
        ``modelName`` keyword.

        :param rName: name of the resulting dataset (optional)
        :param kwargs: passed on to the task
        :return: the result of ``.delay()`` on the task
        """
        omega_reduce = self.task('omegaml.tasks.omega_reduce')
        return omega_reduce.delay(modelName=self.modelname, rName=rName, **kwargs)

    def _ensure_data_is_stored(self, name_or_data, prefix='_temp', as_payload=False):
        """
        resolve a dataset name or in-memory data into something a task accepts

        * as_payload is True: the value is wrapped in a PassthroughDataset
        * a str is taken to be a dataset name and returned unchanged
        * a PassthroughDataset is returned unchanged
        * a list, tuple or dict is wrapped in a PassthroughDataset if
          ``sys.getsizeof()`` reports at most ``PassthroughDataset.MAX_SIZE``,
          else a warning is issued and the data is stored
        * a dataframe, series or ndarray is always stored

        Stored data is put into ``self.runtime.omega.datasets`` under the
        name ``<prefix>_<uuid4 hex>``.

        :param name_or_data: dataset name or data
        :param prefix: prefix of the generated dataset name
        :param as_payload: if True always wrap the value as a
            PassthroughDataset
        :return: the dataset name (str) or a PassthroughDataset
        :raises TypeError: if the value is of none of the supported types
        """
        # dataset names are the common case and need nothing from the
        # passthrough module, so return before importing it
        if not as_payload and isinstance(name_or_data, str):
            return name_or_data

        from omegaml.mixins.store.passthrough import PassthroughDataset

        if as_payload:
            return PassthroughDataset(name_or_data)
        if isinstance(name_or_data, PassthroughDataset):
            return name_or_data
        if isinstance(name_or_data, (list, tuple, dict)):
            # NOTE(review): sys.getsizeof is shallow, it measures the
            # container only and not the objects it holds -- confirm this is
            # the intended measure for MAX_SIZE
            if sys.getsizeof(name_or_data) <= PassthroughDataset.MAX_SIZE:
                return PassthroughDataset(name_or_data)
            warnings.warn(f'size of dataset is larger than {PassthroughDataset.MAX_SIZE} bytes, storing in om.datasets')
        elif not (is_dataframe(name_or_data) or is_series(name_or_data)
                  or is_ndarray(name_or_data)):
            raise TypeError(
                'invalid type for Xpath_or_data', type(name_or_data))
        # everything that gets here is stored under a unique generated name
        name = '%s_%s' % (prefix, uuid4().hex)
        self.runtime.omega.datasets.put(name_or_data, name)
        return name
213