1from __future__ import absolute_import
2
3import logging
4from omegaml.runtimes.proxies.baseproxy import RuntimeProxyBase
5
# module-level logger, named after this module's import path
logger = logging.getLogger(name=__name__)
7
8
[docs]
class OmegaModelProxy(RuntimeProxyBase):
    """
    proxy to a remote model in a celery worker

    The proxy provides the same methods as the model but will
    execute the methods using celery tasks and return celery
    AsyncResult objects

    Usage::

        om = Omega()
        # train a model
        # result is AsyncResult, use .get() to return its result
        result = om.runtime.model('foo').fit('datax', 'datay')
        result.get()

        # predict
        result = om.runtime.model('foo').predict('datax')
        # result is AsyncResult, use .get() to return its result
        print(result.get())

    Notes:
        The actual methods of ModelProxy are defined in its mixins

    See Also:
        * ModelMixin
        * GridSearchMixin
    """

    # Implementation note:
    #
    # We decided to implement each method call explicitly in both
    # this class (mixins) and the celery tasks. While it would be possible to
    # implement a generic method and task that passes the method and
    # arguments to be called, maintainability would suffer and the
    # documentation would become very unspecific. We think it is much
    # cleaner to have an explicit interface at the risk of missing
    # features. If need should arise we can still implement a generic
    # method call.

    def __init__(self, modelname, runtime=None):
        """ create a proxy for the model stored as ``modelname``

        Args:
            modelname (str): the name of the model in the models store
            runtime: the runtime used to execute tasks; its
                ``omega.models`` store is used as this proxy's store
        """
        # NOTE(review): runtime defaults to None, yet runtime.omega.models is
        # dereferenced unconditionally, so a runtime is effectively required.
        super().__init__(modelname, runtime=runtime, store=runtime.omega.models)
        self.modelname = modelname

    def task(self, name):
        """
        return the task from the runtime with requirements applied

        Args:
            name (str): the name of the task

        Returns:
            the task as returned by ``self.runtime.task(name)``
        """
        return self.runtime.task(name)

    def experiment(self, experiment=None, label=None, provider=None, **tracker_kwargs):
        """ return the experiment for this model

        If an experiment does not exist yet, it will be created. The
        experiment is automatically set to track this model, unless another
        experiment has already been set to track this model for the same label.
        If a previous experiment has been set to track this model it will be
        returned. If an experiment name is passed it will be used.

        Args:
            experiment (str): the experiment name, defaults to the modelname
            label (str): the runtime label, defaults to the runtime's default
                label, or 'default' if the runtime has none
            provider (str): the provider to use, defaults to 'default'
            tracker_kwargs (dict): additional kwargs to pass to the tracker

        Returns:
            OmegaTrackingProxy() instance
        """
        label = label or self.runtime._default_label or 'default'
        exp = None
        if experiment is None:
            # no explicit experiment requested: reuse the one already
            # tracking this model for this label, if any
            exp = self.experiments(label=label).get(label)
        if exp is None:
            exp = self.runtime.experiment(experiment or self.modelname,
                                          provider=provider, **tracker_kwargs)
        # only start tracking if no experiment tracks this model for this label
        # yet; label is always truthy here, so filtering by it is equivalent
        # to testing membership in the unfiltered mapping
        if label not in self.experiments(label=label):
            exp.track(self.modelname, label=label)
        return exp

    def experiments(self, label=None, raw=False):
        """ return list of experiments tracking this model

        Args:
            label (None|str): the label for which to return the experiments, or None for all
            raw (bool): if True return the metadata for the experiment, else return the OmegaTrackingProxy

        Returns:
            instances (dict): mapping of label => instance of OmegaTrackingProxy if not raw, else Metadata,
               includes a dummy label '_all_', listing all experiments that track this model.

        .. versionchanged:: 0.17
            returns a dict instead of a list
        """
        store = self.store

        def resolve(name):
            # the experiment's metadata if raw, else its tracking proxy
            if raw:
                return store.metadata(f'experiments/{name}')
            return self.runtime.experiment(name)

        tracking = store.metadata(self.modelname).attributes.get('tracking', {})
        # labeled experiments; 'experiments' and 'monitors' hold collections,
        # not labels, and '_all_' is reserved for the dummy entry below.
        # Filter by the requested label *before* resolving, so that only the
        # requested experiments are instantiated / looked up.
        result = {
            tracked_label: resolve(name)
            for tracked_label, name in tracking.items()
            if tracked_label not in ('experiments', 'monitors', '_all_')
            and (not label or tracked_label == label)
        }
        if not label or label == '_all_':
            # dummy label listing every experiment that tracks this model
            result['_all_'] = [resolve(name) for name in tracking.get('experiments', [])]
        return result

    def monitor(self, label=None, provider=None, **tracker_kwargs):
        """ return the monitor for this model

        Args:
            label (str): the runtime label, passed on to ``experiment()``
            provider (str): the provider to use, passed on to ``experiment()``
            tracker_kwargs (dict): additional kwargs passed on to ``experiment()``

        Returns:
            OmegaMonitorProxy() instance
        """
        exp = self.experiment(label=label, provider=provider, **tracker_kwargs)
        return exp.as_monitor(self.modelname)