1from __future__ import absolute_import
2
3import logging
4
5from omegaml.runtimes.proxies.baseproxy import RuntimeProxyBase
6
7logger = logging.getLogger(__name__)
8
9
[docs]
class OmegaModelProxy(RuntimeProxyBase):
    """
    proxy to a remote model in a celery worker

    The proxy provides the same methods as the model but will
    execute the methods using celery tasks and return celery
    AsyncResult objects

    Usage::

        om = Omega()
        # train a model
        # result is AsyncResult, use .get() to return its result
        result = om.runtime.model('foo').fit('datax', 'datay')
        result.get()

        # predict
        result = om.runtime.model('foo').predict('datax')
        # result is AsyncResult, use .get() to return its result
        print(result.get())

    Notes:
        The actual methods of ModelProxy are defined in its mixins

    See Also:
        * ModelMixin
        * GridSearchMixin
    """

    # Implementation note:
    #
    # We decided to implement each method call explicitly in both
    # this class (mixins) and the celery tasks. While it would be possible to
    # implement a generic method and task that passes the method and
    # arguments to be called, maintainability would suffer and the
    # documentation would become very unspecific. We think it is much
    # cleaner to have an explicit interface at the chance of missing
    # features. If need should arise we can still implement a generic
    # method call.

    def __init__(self, modelname, runtime=None):
        """ create a proxy for the model stored as modelname

        Args:
            modelname (str): the name of the model
            runtime: the runtime this proxy submits its tasks to. Although
                it defaults to None, a runtime is required in practice
                because its ``omega.models`` store is read here.
        """
        super().__init__(modelname, runtime=runtime, store=runtime.omega.models)
        self.modelname = modelname

    def task(self, name):
        """
        return the task from the runtime with requirements applied

        Args:
            name (str): the name of the task, passed on to the runtime

        Returns:
            the result of ``self.runtime.task(name)``
        """
        return self.runtime.task(name)

    def experiment(self, experiment=None, label=None, provider=None, **tracker_kwargs):
        """ return the experiment for this model

        If an experiment does not exist yet, it will be created. The
        experiment is automatically set to track this model, unless another
        experiment has already been set to track this model for the same label.
        If a previous experiment has been set to track this model it will be
        returned. If an experiment name is passed it will be used.

        Args:
            experiment (str): the experiment name, defaults to the modelname
            label (str): the runtime label, defaults to 'default'
            provider (str): the provider to use, defaults to 'default'
            tracker_kwargs (dict): additional kwargs to pass to the tracker

        Returns:
            OmegaTrackingProxy() instance
        """
        label = label or self.runtime._default_label or 'default'
        # 'default' and 'local' both resolve to the runtime's default label
        if label in ('default', 'local'):
            label = self.runtime._default_label
        # an explicitly named experiment bypasses the lookup of the
        # experiment that already tracks this model for the label
        exps = self.experiments(label=label) if experiment is None else None
        exp = exps.get(label) if exps else None
        experiment = experiment or self.modelname
        if exp is None:
            exp = self.runtime.experiment(experiment, provider=provider, **tracker_kwargs)
            # never replace a tracker that is already registered for this label
            if label not in self.experiments():
                exp.track(self.modelname, label=label)
        return exp

    def experiments(self, label=None, raw=False):
        """ return the experiments tracking this model

        Args:
            label (None|str): the label for which to return the experiments, or None for all
            raw (bool): if True return the metadata for the experiment, else return the OmegaTrackingProxy

        Returns:
            instances (dict): mapping of label => instance of OmegaTrackingProxy if not raw, else Metadata,
                includes a dummy label '_all_', listing all experiments that track this model.

        .. versionchanged:: 0.17
            returns a dict instead of a list
        """
        store = self.store
        tracking = store.metadata(self.modelname).attributes.get('tracking', {})

        def resolve(name):
            # raw => the experiment's Metadata, else the tracking proxy
            if raw:
                return store.metadata(f'experiments/{name}')
            return self.runtime.experiment(name)

        # 'experiments' and 'monitors' are bookkeeping entries, not labels
        found = {
            tracked_label: resolve(name)
            for tracked_label, name in tracking.items()
            if tracked_label not in ('experiments', 'monitors')
        }
        # plain key assignment (instead of dict(**a, **b)) so that a label
        # that is not a string, or that is named '_all_', cannot raise a
        # TypeError while merging
        found['_all_'] = [resolve(name) for name in tracking.get('experiments', [])]
        return {key: value for key, value in found.items() if not label or key == label}

    def monitor(self, label=None, provider=None, **tracker_kwargs):
        """ return the monitor for this model

        The monitor is obtained from the experiment returned by
        :meth:`experiment` for the same label, provider and tracker kwargs.

        Args:
            label (str): the runtime label, passed on to :meth:`experiment`
            provider (str): the provider to use, passed on to :meth:`experiment`
            tracker_kwargs (dict): additional kwargs to pass to the tracker

        Returns:
            OmegaMonitorProxy() instance
        """
        exp = self.experiment(label=label, provider=provider, **tracker_kwargs)
        return exp.as_monitor(self.modelname)