Source code for omegaml.runtimes.proxies.modelproxy

from __future__ import absolute_import

import logging

from omegaml.runtimes.proxies.baseproxy import RuntimeProxyBase

  7logger = logging.getLogger(__name__)
  8
  9
class OmegaModelProxy(RuntimeProxyBase):
    """
    proxy to a remote model in a celery worker

    The proxy provides the same methods as the model but will
    execute the methods using celery tasks and return celery
    AsyncResult objects

    Usage::

        om = Omega()
        # train a model
        # result is AsyncResult, use .get() to return its result
        result = om.runtime.model('foo').fit('datax', 'datay')
        result.get()

        # predict
        result = om.runtime.model('foo').predict('datax')
        # result is AsyncResult, use .get() to return its result
        print(result.get())

    Notes:
        The actual methods of ModelProxy are defined in its mixins

    See Also:
        * ModelMixin
        * GridSearchMixin
    """

    # Implementation note:
    #
    # We decided to implement each method call explicitly in both
    # this class (mixins) and the celery tasks. While it would be possible to
    # implement a generic method and task that passes the method and
    # arguments to be called, maintainability would suffer and the
    # documentation would become very unspecific. We think it is much
    # cleaner to have an explicit interface at the risk of missing
    # features. If need should arise we can still implement a generic
    # method call.

    def __init__(self, modelname, runtime=None):
        """
        Args:
            modelname (str): the name of the model, stored as ``self.modelname``
            runtime: the runtime used to execute tasks. Note that although it
               defaults to None, ``runtime.omega.models`` is dereferenced
               immediately (it becomes this proxy's store), so a runtime
               must always be passed.
        """
        super().__init__(modelname, runtime=runtime, store=runtime.omega.models)
        self.modelname = modelname

    def task(self, name):
        """
        return the task from the runtime with requirements applied

        Args:
            name (str): the task name, passed through to ``self.runtime.task()``

        Returns:
            the object returned by ``self.runtime.task(name)``
        """
        return self.runtime.task(name)

    def experiment(self, experiment=None, label=None, provider=None, **tracker_kwargs):
        """ return the experiment for this model

        If an experiment does not exist yet, it will be created. The
        experiment is automatically set to track this model, unless another
        experiment has already been set to track this model for the same label.
        If a previous experiment has been set to track this model for the
        label, it will be returned. If an experiment name is passed it will
        be used instead of looking up an existing one.

        Args:
            experiment (str): the experiment name, defaults to the modelname
            label (str): the runtime label, defaults to the runtime's default
               label, or 'default' if the runtime has none. 'default' and
               'local' are treated as aliases of that default label.
            provider (str): the provider to use, defaults to 'default'
            tracker_kwargs (dict): additional kwargs to pass to the tracker

        Returns:
            OmegaTrackingProxy() instance
        """
        # fall back to 'default' so the label can never end up as None when
        # the runtime has no default label configured
        default_label = self.runtime._default_label or 'default'
        label = label or default_label
        label = default_label if label in ('default', 'local') else label
        # only look up an experiment already tracking this label when no
        # explicit experiment name was requested
        exps = self.experiments(label=label) if experiment is None else None
        exp = exps.get(label) if exps else None
        if exp is None:
            exp = self.runtime.experiment(experiment or self.modelname, provider=provider, **tracker_kwargs)
        # re-read the tracking info; do not take over a label that another
        # experiment already tracks for this model
        if label not in self.experiments():
            exp.track(self.modelname, label=label)
        return exp

    def experiments(self, label=None, raw=False):
        """ return list of experiments tracking this model

        Args:
            label (None|str): the label for which to return the experiments, or None for all
            raw (bool): if True return the metadata for the experiment, else return the OmegaTrackingProxy

        Returns:
            instances (dict): mapping of label => instance of OmegaTrackingProxy if not raw, else Metadata,
            includes a dummy label '_all_', listing the experiments recorded under the
            model's tracking 'experiments' entry.

        .. versionchanged:: 0.17
            returns a dict instead of a list
        """
        store = self.store

        def resolve(name):
            # metadata of the experiment if raw was requested, else its proxy
            return store.metadata(f'experiments/{name}') if raw else self.runtime.experiment(name)

        # a missing or None 'tracking' attribute means nothing is tracked
        tracking = store.metadata(self.modelname).attributes.get('tracking') or {}
        # 'experiments' and 'monitors' are bookkeeping entries, not labels
        by_label = {
            tracked_label: resolve(name)
            for tracked_label, name in tracking.items()
            if tracked_label not in ('experiments', 'monitors')
        }
        unlabeled = {
            '_all_': [resolve(name) for name in tracking.get('experiments', [])],
        }
        # dict-literal merge: unlike dict(**a, **b) it does not raise
        # TypeError on a duplicate '_all_' key
        all_exps = {**by_label, **unlabeled}
        return {k: v for k, v in all_exps.items() if not label or k == label}

    def monitor(self, label=None, provider=None, **tracker_kwargs):
        """ return the monitor for this model

        The monitor is obtained from this model's experiment, as returned by
        ``self.experiment(label=label, provider=provider, **tracker_kwargs)``.

        Args:
            label (str): the runtime label, see ``experiment()``
            provider (str): the tracking provider, see ``experiment()``
            tracker_kwargs (dict): additional kwargs to pass to the tracker

        Returns:
            OmegaMonitorProxy() instance
        """
        exp = self.experiment(label=label, provider=provider, **tracker_kwargs)
        return exp.as_monitor(self.modelname)