1from __future__ import absolute_import
2
3import logging
4from copy import deepcopy
5from socket import gethostname
6
7from celery import Celery
8from celery.events import EventReceiver
9
10from omegaml.mongoshim import mongo_url
11from omegaml.util import dict_merge
12
13logger = logging.getLogger(__name__)
14
15
class CeleryTask(object):
    """
    A thin wrapper for a Celery.Task object

    This is so that we can collect common delay arguments on the
    .task() call
    """

    def __init__(self, task, kwargs):
        """
        Args:
            task (Celery.Task): the celery task object
            kwargs (dict): optional, the kwargs to pass to apply_async
        """
        self.task = task
        # copy to avoid mutating the caller's dict
        self.kwargs = dict(kwargs)

    def _apply_kwargs(self, task_kwargs, celery_kwargs):
        # update task_kwargs from runtime's passed on kwargs
        # update celery_kwargs to match celery routing semantics
        task_kwargs.update(self.kwargs.get('task', {}))
        celery_kwargs.update(self.kwargs.get('routing', {}))
        if 'label' in celery_kwargs:
            # celery routes by queue name; omegaml exposes it as 'label'
            celery_kwargs['queue'] = celery_kwargs.pop('label')

    def _apply_auth(self, args, kwargs, celery_kwargs):
        # apply the active authentication environment's task auth data
        from omegaml.client.auth import AuthenticationEnv
        AuthenticationEnv.active().taskauth(args, kwargs, celery_kwargs)

    def apply_async(self, args=None, kwargs=None, **celery_kwargs):
        """
        Args:
            args (tuple): the task args
            kwargs (dict): the task kwargs, passed as task.apply_async(kwargs=kwargs)
            celery_kwargs (dict): apply_async kwargs, passed as task.apply_async(..., **celery_kwargs)

        Returns:
            AsyncResult
        """
        args = args or tuple()
        kwargs = kwargs or {}
        self._apply_kwargs(kwargs, celery_kwargs)
        self._apply_auth(args, kwargs, celery_kwargs)
        return self.task.apply_async(args=args, kwargs=kwargs, **celery_kwargs)

    def delay(self, *args, **kwargs):
        """
        submit the task with args and kwargs to pass on

        This calls task.apply_async and passes on the self.kwargs.
        """
        return self.apply_async(args=args, kwargs=kwargs)

    def signature(self, args=None, kwargs=None, immutable=False, **celery_kwargs):
        """ return the task signature with all kwargs and celery_kwargs applied

        Args:
            args (tuple): the task args
            kwargs (dict): the task kwargs
            immutable (bool): passed on to task.signature()
            celery_kwargs (dict): additional kwargs passed to task.signature()

        Returns:
            celery.canvas.Signature
        """
        # fix: default args/kwargs like apply_async does; previously calling
        # signature() without kwargs raised AttributeError (None.update)
        # inside _apply_kwargs
        args = args or tuple()
        kwargs = kwargs or {}
        self._apply_kwargs(kwargs, celery_kwargs)
        sig = self.task.signature(args=args, kwargs=kwargs, immutable=immutable,
                                  **celery_kwargs)
        return sig

    def run(self, *args, **kwargs):
        # alias for delay()
        return self.delay(*args, **kwargs)
81
82
83class OmegaRuntime(object):
84 """
85 omegaml compute cluster gateway
86 """
87
    def __init__(self, omega, bucket=None, defaults=None, celeryconf=None):
        """ initialize the runtime for a given omega instance

        Args:
            omega (Omega): the omega instance this runtime belongs to
            bucket (str): optional, the bucket to run tasks against
            defaults: optional, a settings object; defaults to
               omegaml.util.settings()
            celeryconf (dict): optional, a celery configuration overriding
               defaults.OMEGA_CELERY_CONFIG
        """
        from omegaml.util import settings

        self.omega = omega
        defaults = defaults or settings()
        self.bucket = bucket
        # use pure-python (no binary) serialization if forced by settings or
        # if the scientific stack is not importable on this client
        self.pure_python = getattr(defaults, 'OMEGA_FORCE_PYTHON_CLIENT', False)
        self.pure_python = self.pure_python or self._client_is_pure_python()
        self._create_celery_app(defaults, celeryconf=celeryconf)
        # temporary requirements, use .require() to set
        self._require_kwargs = dict(task={}, routing={})
        # fixed default arguments, use .require(always=True) to set
        self._task_default_kwargs = dict(task={}, routing={})
        # default routing label
        self._default_label = self.celeryapp.conf.get('CELERY_DEFAULT_QUEUE')
103
104 def __repr__(self):
105 return 'OmegaRuntime({})'.format(self.omega.__repr__())
106
107 @property
108 def auth(self):
109 return None
110
111 @property
112 def _common_kwargs(self):
113 common = deepcopy(self._task_default_kwargs)
114 common['task'].update(pure_python=self.pure_python, __bucket=self.bucket)
115 common['task'].update(self._require_kwargs['task'])
116 common['routing'].update(self._require_kwargs['routing'])
117 return common
118
119 @property
120 def _inspect(self):
121 return self.celeryapp.control.inspect()
122
123 @property
124 def is_local(self):
125 return self.celeryapp.conf['CELERY_ALWAYS_EAGER']
126
127 def mode(self, local=None, logging=None):
128 """ specify runtime modes
129
130 Args:
131 local (bool): if True, all execution will run locally, else on
132 the configured remote cluster
133 logging (bool|str|tuple): if True, will set the root logger output
134 at INFO level; a single string is the name of the logger,
135 typically a module name; a tuple (logger, level) will select
136 logger and the level. Valid levels are INFO, WARNING, ERROR,
137 CRITICAL, DEBUG
138
139 Usage::
140
141 # run all runtime tasks locally
142 om.runtime.mode(local=True)
143
144 # enable logging both in local and remote mode
145 om.runtime.mode(logging=True)
146
147 # select a specific module and level)
148 om.runtime.mode(logging=('sklearn', 'DEBUG'))
149
150 # disable logging
151 om.runtime.mode(logging=False)
152 """
153 if isinstance(local, bool):
154 self.celeryapp.conf['CELERY_ALWAYS_EAGER'] = local
155 self._task_default_kwargs['task']['__logging'] = logging
156 return self
157
    def _create_celery_app(self, defaults, celeryconf=None):
        """ create and configure the Celery app for this runtime

        Args:
            defaults: a settings object providing OMEGA_CELERY_IMPORTS,
               OMEGA_CELERY_CONFIG, OMEGA_LOCAL_RUNTIME
            celeryconf (dict): optional, used instead of
               defaults.OMEGA_CELERY_CONFIG
        """
        # initialize celery as a runtimes
        taskpkgs = defaults.OMEGA_CELERY_IMPORTS
        # copy so the passed-in/defaults config is never mutated
        celeryconf = dict(celeryconf or defaults.OMEGA_CELERY_CONFIG)
        # ensure we use current value
        celeryconf['CELERY_ALWAYS_EAGER'] = bool(defaults.OMEGA_LOCAL_RUNTIME)
        if celeryconf['CELERY_RESULT_BACKEND'].startswith('mongodb://'):
            # rebuild the result backend url from the omega instance
            celeryconf['CELERY_RESULT_BACKEND'] = mongo_url(self.omega, drop_kwargs=['uuidRepresentation'])
        # initialize ssl configuration
        if celeryconf.get('BROKER_USE_SSL'):
            # celery > 5 requires ssl options to be specific
            # https://docs.celeryq.dev/en/stable/userguide/configuration.html#std-setting-broker_use_ssl
            # https://github.com/celery/kombu/issues/1493
            # https://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
            # https://www.openssl.org/docs/man3.0/man3/SSL_CTX_set_default_verify_paths.html
            # env variables:
            # SSL_CERT_FILE, CA_CERTS_PATH
            self._apply_broker_ssl(celeryconf)
        self.celeryapp = Celery('omegaml')
        self.celeryapp.config_from_object(celeryconf)
        # needed to get it to actually load the tasks
        # https://stackoverflow.com/a/35735471
        self.celeryapp.autodiscover_tasks(taskpkgs, force=True)
        self.celeryapp.finalize()
182
    def _apply_broker_ssl(self, celeryconf):
        # hook to apply broker ssl options
        # no-op here; a subclass may update celeryconf (e.g. BROKER_USE_SSL)
        # with specific ssl options, see references in _create_celery_app
        pass
186
187 def _client_is_pure_python(self):
188 try:
189 import pandas as pd
190 import numpy as np
191 import sklearn
192 except Exception as e:
193 logging.getLogger().info(e)
194 return True
195 else:
196 return False
197
198 def _sanitize_require(self, value):
199 # convert value into dict(label=value)
200 if isinstance(value, str):
201 return dict(label=value)
202 if isinstance(value, (list, tuple)):
203 return dict(*value)
204 return value
205
206 def require(self, label=None, always=False, routing=None, task=None,
207 logging=None, override=True, **kwargs):
208 """
209 specify requirements for the task execution
210
211 Use this to specify resource or routing requirements on the next task
212 call sent to the runtime. Any requirements will be reset after the
213 call has been submitted.
214
215 Args:
216 always (bool): if True requirements will persist across task calls. defaults to False
217 label (str): the label required by the worker to have a runtime task dispatched to it.
218 'local' is equivalent to calling self.mode(local=True).
219 task (dict): if specified applied to the task's kwargs, passed as task.apply_async(..., kwargs=task)
220 routing (dict): if specified applied to the task's routing, passed as task.apply_async(..., **routing)
221 logging (str|tuple): if specified, same as runtime.mode(logging=...)
222 override (bool): if True overrides previously set .require(), defaults to True
223 kwargs: requirements specification that the runtime understands
224
225 Usage:
226 om.runtime.require(label='gpu').model('foo').fit(...)
227
228 See Also:
229 - CeleryTask.apply_async
230 - celery.app.task.Task.apply_async, specifically the kwargs= and **options
231
232 Returns:
233 self
234 """
235 if label:
236 # avoid overriding a previous local() call by an erronous label
237 assert isinstance(label, str), "label must be valid, run om.runtime.labels() to list of active workers"
238 if label == 'local':
239 self.mode(local=True)
240 elif override:
241 self.mode(local=False)
242 # update routing, don't replace (#416)
243 routing = routing or {}
244 routing.update({'label': label or self._default_label})
245 task = task or {}
246 routing = routing or {}
247 if task or routing:
248 if not override:
249 # override not allowed, remove previously existing
250 ex_task = dict(**self._task_default_kwargs['task'],
251 **self._require_kwargs['task'])
252 ex_routing = dict(**self._task_default_kwargs['routing'],
253 **self._require_kwargs['routing'])
254 exists_or_none = lambda k, d: k not in d or d.get(k, False) is None
255 task = {k: v for k, v in task.items() if exists_or_none(k, ex_task)}
256 routing = {k: v for k, v in routing.items() if exists_or_none(k, ex_routing)}
257 if always:
258 self._task_default_kwargs['routing'].update(routing)
259 self._task_default_kwargs['task'].update(task)
260 else:
261 self._require_kwargs['routing'].update(routing)
262 self._require_kwargs['task'].update(task)
263 else:
264 # FIXME this does not work as expected (will only reset if both task and routing are False)
265 if not task:
266 self._require_kwargs['task'] = {}
267 if not routing:
268 self._require_kwargs['routing'] = {}
269 if logging is not None:
270 self.mode(logging=logging)
271 return self
272
273 def model(self, modelname, require=None):
274 """
275 return a model for remote execution
276
277 Args:
278 modelname (str): the name of the object in om.models
279 require (dict): routing requirements for this job
280
281 Returns:
282 OmegaModelProxy
283 """
284 from omegaml.runtimes.proxies.modelproxy import OmegaModelProxy
285 self.require(**self._sanitize_require(require)) if require else None
286 return OmegaModelProxy(modelname, runtime=self)
287
288 def job(self, jobname, require=None):
289 """
290 return a job for remote execution
291
292 Args:
293 jobname (str): the name of the object in om.jobs
294 require (dict): routing requirements for this job
295
296 Returns:
297 OmegaJobProxy
298 """
299 from omegaml.runtimes.proxies.jobproxy import OmegaJobProxy
300
301 self.require(**self._sanitize_require(require)) if require else None
302 return OmegaJobProxy(jobname, runtime=self)
303
304 def script(self, scriptname, require=None):
305 """
306 return a script for remote execution
307
308 Args:
309 scriptname (str): the name of object in om.scripts
310 require (dict): routing requirements for this job
311
312 Returns:
313 OmegaScriptProxy
314 """
315 from omegaml.runtimes.proxies.scriptproxy import OmegaScriptProxy
316
317 self.require(**self._sanitize_require(require)) if require else None
318 return OmegaScriptProxy(scriptname, runtime=self)
319
320 def experiment(self, experiment, provider=None, implied_run=True, recreate=False, **tracker_kwargs):
321 """ set the tracking backend and experiment
322
323 Args:
324 experiment (str): the name of the experiment
325 provider (str): the name of the provider
326 tracker_kwargs (dict): additional kwargs for the tracker
327 recreate (bool): if True, recreate the experiment (i.e. drop and recreate,
328 this is useful to change the provider or other settings. All previous data will
329 be kept)
330
331 Returns:
332 OmegaTrackingProxy
333 """
334 from omegaml.runtimes.proxies.trackingproxy import OmegaTrackingProxy
335 # tracker implied_run means we are using the currently active run, i.e. with block will call exp.start()
336 tracker = OmegaTrackingProxy(experiment, provider=provider, runtime=self, implied_run=implied_run,
337 recreate=recreate, **tracker_kwargs)
338 return tracker
339
340 def task(self, name, **kwargs):
341 """
342 retrieve the task function from the celery instance
343
344 Args:
345 name (str): a registered celery task as ``module.tasks.task_name``
346 kwargs (dict): routing keywords to CeleryTask.apply_async
347
348 Returns:
349 CeleryTask
350 """
351 taskfn = self.celeryapp.tasks.get(name)
352 assert taskfn is not None, "cannot find task {name} in Celery runtime".format(**locals())
353 kwargs = dict_merge(self._common_kwargs, dict(routing=kwargs))
354 task = CeleryTask(taskfn, kwargs)
355 self._require_kwargs = dict(routing={}, task={})
356 return task
357
358 def result(self, task_id, wait=True):
359 from celery.result import AsyncResult
360 promise = AsyncResult(task_id, app=self.celeryapp)
361 return promise.get() if wait else promise
362
363 def settings(self, require=None):
364 """ return the runtimes's cluster settings
365 """
366 self.require(**require) if require else None
367 return self.task('omegaml.tasks.omega_settings').delay().get()
368
369 def ping(self, *args, require=None, wait=True, timeout=10, **kwargs):
370 """
371 ping the runtime
372
373 Args:
374 args (tuple): task args
375 require (dict): routing requirements for this job
376 wait (bool): if True, wait for the task to return, else return
377 AsyncResult
378 timeout (int): if wait is True, the timeout in seconds, defaults to 10
379 kwargs (dict): task kwargs, as accepted by CeleryTask.apply_async
380
381 Returns:
382 * response (dict) for wait=True
383 * AsyncResult for wait=False
384 """
385 self.require(**require) if require else None
386 promise = self.task('omegaml.tasks.omega_ping').delay(*args, **kwargs)
387 return promise.get(timeout=timeout) if wait else promise
388
389 def enable_hostqueues(self):
390 """ enable a worker-specific queue on every worker host
391
392 Returns:
393 list of labels (one entry for each hostname)
394 """
395 control = self.celeryapp.control
396 inspect = control.inspect()
397 active = inspect.active()
398 queues = []
399 for worker in active.keys():
400 hostname = worker.split('@')[-1]
401 control.cancel_consumer(hostname)
402 control.add_consumer(hostname, destination=[worker])
403 queues.append(hostname)
404 return queues
405
406 def workers(self):
407 """ list of workers
408
409 Returns:
410 dict of workers => list of active tasks
411
412 See Also:
413 celery Inspect.active()
414 """
415 local_worker = {gethostname(): [{'name': 'local', 'is_local': True}]}
416 celery_workers = self._inspect.active() or {}
417 return dict_merge(local_worker, celery_workers)
418
419 def queues(self):
420 """ list queues
421
422 Returns:
423 dict of workers => list of queues
424
425 See Also:
426 celery Inspect.active_queues()
427 """
428 local_q = {gethostname(): [{'name': 'local', 'is_local': True}]}
429 celery_qs = self._inspect.active_queues() or {}
430 return dict_merge(local_q, celery_qs)
431
432 def labels(self):
433 """ list available labels
434
435 Returns:
436 dict of workers => list of lables
437 """
438 return {worker: [q.get('name') for q in queues]
439 for worker, queues in self.queues().items()}
440
441 def stats(self):
442 """ worker statistics
443
444 Returns:
445 dict of workers => dict of stats
446
447 See Also:
448 celery Inspect.stats()
449 """
450 return self._inspect.stats()
451
452 def status(self):
453 """ current cluster status
454
455 This collects key information from .labels(), .stats() and the latest
456 worker heartbeat. Note that loadavg is only available if the worker has
457 recently sent a heartbeat and may not be accurate across the cluster.
458
459 Returns:
460 snapshot (dict): a snapshot of the cluster status
461 '<worker>': {
462 'loadavg': [0.0, 0.0, 0.0], # load average in % seen by the worker (1, 5, 15 min)
463 'processes': 1, # number of active worker processes
464 'concurrency': 1, # max concurrency
465 'uptime': 0, # uptime in seconds
466 'processed': Counter(task=n), # number of tasks processed
467 'queues': ['default'], # list of queues (labels) the worker is listening on
468 }
469 """
470 labels = self.labels()
471 stats = self.stats()
472 heartbeat = self.events.latest()
473 snapshot = {
474 worker: {
475 'loadavg': heartbeat.get('loadavg', []),
476 'processes': stats[worker]['pool']['processes'],
477 'concurrency': stats[worker]['pool']['max-concurrency'],
478 'uptime': stats[worker]['uptime'],
479 'processed': stats[worker]['total'],
480 'queues': labels[worker],
481 } for worker in labels if worker in stats
482 }
483 return snapshot
484
485 @property
486 def events(self):
487 return CeleryEventStream(self.celeryapp)
488
489 def callback(self, script_name, always=False, **kwargs):
490 """ Add a callback to a registered script
491
492 The callback will be triggered upon successful or failed
493 execution of the runtime tasks. The script syntax is::
494
495 # script.py
496 def run(om, state=None, result=None, **kwargs):
497 # state (str): 'SUCCESS'|'ERROR'
498 # result (obj): the task's serialized result
499
500 Args:
501 script_name (str): the name of the script (in om.scripts)
502 always (bool): if True always apply this callback, defaults to False
503 **kwargs: and other kwargs to pass on to the script
504
505 Returns:
506 self
507 """
508 success_sig = (self.script(script_name)
509 .task(as_callback=True)
510 .signature(args=['SUCCESS', script_name],
511 kwargs=kwargs,
512 immutable=False))
513 error_sig = (self.script(script_name)
514 .task(as_callback=True)
515 .signature(args=['ERROR', script_name],
516 kwargs=kwargs,
517 immutable=False))
518
519 if always:
520 self._task_default_kwargs['routing']['link'] = success_sig
521 self._task_default_kwargs['routing']['link_error'] = error_sig
522 else:
523 self._require_kwargs['routing']['link'] = success_sig
524 self._require_kwargs['routing']['link_error'] = error_sig
525 return self
526
527
class CeleryEventStream:
    """ a buffered stream of celery worker events

    Captures events from the celery broker and keeps the most recent
    events in an in-memory buffer of at most .max_size entries.
    """

    def __init__(self, app, limit=None, timeout=5, wakeup=False):
        """
        Args:
            app (Celery): the celery app whose broker connection is used
            limit (int): default maximum number of events per .listen() call
            timeout (int): default capture timeout in seconds
            wakeup (bool): passed to EventReceiver.capture()
        """
        self.app = app
        self.limit = limit
        self.timeout = timeout
        self.wakeup = wakeup
        # keep at most this many events in the buffer
        self.max_size = 100
        self.buffer = []

    def handle(self, event):
        # default handler: buffer the event, trimming to the newest max_size
        self.buffer.append(event)
        if len(self.buffer) > self.max_size:
            self.buffer = self.buffer[-self.max_size:]

    def listen(self, handlers=None, limit=None, timeout=None):
        """ capture events from the broker

        Args:
            handlers (dict): event type => callable; defaults to buffering
               worker heartbeats via .handle
            limit (int): maximum number of events, defaults to self.limit
            timeout (int): capture timeout in seconds, defaults to self.timeout
        """
        handlers = handlers or {'worker-heartbeat': self.handle}
        limit = limit or self.limit
        timeout = timeout or self.timeout
        # connect to the broker using Kombu (celery's underlying messaging system)
        with self.app.connection() as conn:
            receiver = EventReceiver(conn, handlers=handlers, app=self.app)
            receiver.capture(limit=limit, timeout=timeout, wakeup=self.wakeup)

    def latest(self, timeout=None):
        # NOTE(review): if no event ever arrives this keeps listening
        # (or propagates a timeout error from capture) — confirm intended
        while not self.buffer:
            self.listen(limit=1, timeout=timeout)
        return self.buffer[-1]
556
557
# apply mixins
from omegaml.runtimes.mixins.taskcanvas import canvas_chain, canvas_group, canvas_chord
from omegaml.runtimes.mixins.swagger import SwaggerGenerator

# attach canvas-style task composition and swagger generation to the runtime;
# assigned here (rather than defined on the class) to keep this module
# focused on task routing
OmegaRuntime.sequence = canvas_chain
OmegaRuntime.parallel = canvas_group
OmegaRuntime.mapreduce = canvas_chord
OmegaRuntime.swagger = SwaggerGenerator.build_swagger