1from __future__ import absolute_import
2
3import logging
4from celery import Celery
5from celery.events import EventReceiver
6from copy import deepcopy
7from socket import gethostname
8
9from omegaml.mongoshim import mongo_url
10from omegaml.util import dict_merge
11
12logger = logging.getLogger(__name__)
13
14
class CeleryTask(object):
    """
    A thin wrapper for a Celery.Task object

    This is so that we can collect common delay arguments on the
    .task() call
    """

    def __init__(self, task, kwargs):
        """
        Args:
            task (Celery.Task): the celery task object
            kwargs (dict): optional, the kwargs to pass to apply_async,
               as dict(task=..., routing=...)
        """
        self.task = task
        # copy so later mutation by the runtime does not leak into this task
        self.kwargs = dict(kwargs)

    def _apply_kwargs(self, task_kwargs, celery_kwargs):
        # update task_kwargs from runtime's passed on kwargs
        # update celery_kwargs to match celery routing semantics
        task_kwargs.update(self.kwargs.get('task', {}))
        celery_kwargs.update(self.kwargs.get('routing', {}))
        if 'label' in celery_kwargs:
            # omegaml 'label' is celery's 'queue'
            celery_kwargs['queue'] = celery_kwargs['label']
            del celery_kwargs['label']

    def _apply_auth(self, args, kwargs, celery_kwargs):
        # apply the active authentication environment to the task call
        from omegaml.client.auth import AuthenticationEnv
        AuthenticationEnv.active().taskauth(args, kwargs, celery_kwargs)

    def apply_async(self, args=None, kwargs=None, **celery_kwargs):
        """
        submit the task

        Args:
            args (tuple): the task args
            kwargs (dict): the task kwargs
            celery_kwargs (dict): apply_async kwargs, e.g. routing

        Returns:
            AsyncResult
        """
        args = args or tuple()
        kwargs = kwargs or {}
        self._apply_kwargs(kwargs, celery_kwargs)
        self._apply_auth(args, kwargs, celery_kwargs)
        return self.task.apply_async(args=args, kwargs=kwargs, **celery_kwargs)

    def delay(self, *args, **kwargs):
        """
        submit the task with args and kwargs to pass on

        This calls task.apply_async and passes on the self.kwargs.
        """
        return self.apply_async(args=args, kwargs=kwargs)

    def signature(self, args=None, kwargs=None, immutable=False, **celery_kwargs):
        """ return the task signature with all kwargs and celery_kwargs applied

        Args:
            args (tuple): the task args, defaults to ()
            kwargs (dict): the task kwargs, defaults to {}
            immutable (bool): if True the signature does not accept results
                of preceding tasks in a canvas, defaults to False
            celery_kwargs (dict): signature kwargs, e.g. routing

        Returns:
            celery.canvas.Signature
        """
        # normalize defaults as in apply_async; previously kwargs=None
        # raised AttributeError inside _apply_kwargs (None.update)
        args = args or tuple()
        kwargs = kwargs or {}
        self._apply_kwargs(kwargs, celery_kwargs)
        sig = self.task.signature(args=args, kwargs=kwargs, immutable=immutable,
                                  **celery_kwargs)
        return sig

    def run(self, *args, **kwargs):
        # alias for delay()
        return self.delay(*args, **kwargs)
80
81
class OmegaRuntime(object):
    """
    omegaml compute cluster gateway

    Wraps a Celery application to submit omegaml tasks (models, jobs,
    scripts) either locally (eager mode) or to a remote worker cluster.

    Args:
        omega (Omega): the omega instance this runtime belongs to
        bucket (str): optional, the storage bucket tasks operate on
        defaults (object): optional, the settings object, defaults to
            omegaml.util.settings()
        celeryconf (dict): optional, the celery configuration, defaults to
            defaults.OMEGA_CELERY_CONFIG
    """

    def __init__(self, omega, bucket=None, defaults=None, celeryconf=None):
        from omegaml.util import settings

        self.omega = omega
        defaults = defaults or settings()
        self.bucket = bucket
        # pure python clients cannot rely on binary dependencies (numpy et al.)
        self.pure_python = getattr(defaults, 'OMEGA_FORCE_PYTHON_CLIENT', False)
        self.pure_python = self.pure_python or self._client_is_pure_python()
        self._create_celery_app(defaults, celeryconf=celeryconf)
        # temporary requirements, use .require() to set
        self._require_kwargs = dict(task={}, routing={})
        # fixed default arguments, use .require(always=True) to set
        self._task_default_kwargs = dict(task={}, routing={})
        # default routing label
        self._default_label = self.celeryapp.conf.get('CELERY_DEFAULT_QUEUE')

    def __repr__(self):
        return 'OmegaRuntime({})'.format(self.omega.__repr__())

    @property
    def auth(self):
        # no authentication by default; subclasses may override
        return None

    @property
    def _common_kwargs(self):
        # merged task/routing kwargs: defaults + per-call requirements
        common = deepcopy(self._task_default_kwargs)
        common['task'].update(pure_python=self.pure_python, __bucket=self.bucket)
        common['task'].update(self._require_kwargs['task'])
        common['routing'].update(self._require_kwargs['routing'])
        return common

    @property
    def _inspect(self):
        return self.celeryapp.control.inspect()

    @property
    def is_local(self):
        # True if tasks execute eagerly in the local process
        return self.celeryapp.conf['CELERY_ALWAYS_EAGER']

    def mode(self, local=None, logging=None):
        """ specify runtime modes

        Args:
            local (bool): if True, all execution will run locally, else on
               the configured remote cluster
            logging (bool|str|tuple): if True, will set the root logger output
               at INFO level; a single string is the name of the logger,
               typically a module name; a tuple (logger, level) will select
               logger and the level. Valid levels are INFO, WARNING, ERROR,
               CRITICAL, DEBUG

        Usage::

            # run all runtime tasks locally
            om.runtime.mode(local=True)

            # enable logging both in local and remote mode
            om.runtime.mode(logging=True)

            # select a specific module and level)
            om.runtime.mode(logging=('sklearn', 'DEBUG'))

            # disable logging
            om.runtime.mode(logging=False)
        """
        if isinstance(local, bool):
            self.celeryapp.conf['CELERY_ALWAYS_EAGER'] = local
        if logging is not None:
            # only update when explicitly specified so that e.g.
            # .mode(local=True) does not clobber a previous logging setting;
            # use logging=False to disable logging explicitly
            self._task_default_kwargs['task']['__logging'] = logging
        return self

    def _create_celery_app(self, defaults, celeryconf=None):
        # initialize celery as a runtimes
        taskpkgs = defaults.OMEGA_CELERY_IMPORTS
        celeryconf = dict(celeryconf or defaults.OMEGA_CELERY_CONFIG)
        # ensure we use current value
        celeryconf['CELERY_ALWAYS_EAGER'] = bool(defaults.OMEGA_LOCAL_RUNTIME)
        if celeryconf.get('CELERY_RESULT_BACKEND', '').startswith('mongodb://'):
            celeryconf['CELERY_RESULT_BACKEND'] = mongo_url(self.omega, drop_kwargs=['uuidRepresentation'])
        # initialize ssl configuration
        if celeryconf.get('BROKER_USE_SSL'):
            # celery > 5 requires ssl options to be specific
            # https://docs.celeryq.dev/en/stable/userguide/configuration.html#std-setting-broker_use_ssl
            # https://github.com/celery/kombu/issues/1493
            # https://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
            # https://www.openssl.org/docs/man3.0/man3/SSL_CTX_set_default_verify_paths.html
            # env variables:
            # SSL_CERT_FILE, CA_CERTS_PATH
            self._apply_broker_ssl(celeryconf)
        self.celeryapp = Celery('omegaml')
        self.celeryapp.config_from_object(celeryconf)
        # needed to get it to actually load the tasks
        # https://stackoverflow.com/a/35735471
        self.celeryapp.autodiscover_tasks(taskpkgs, force=True)
        self.celeryapp.finalize()

    def _apply_broker_ssl(self, celeryconf):
        # hook to apply broker ssl options
        pass

    def _client_is_pure_python(self):
        # a client is pure python if the scientific stack is not importable
        try:
            import pandas as pd
            import numpy as np
            import sklearn
        except Exception as e:
            logging.getLogger().info(e)
            return True
        else:
            return False

    def _sanitize_require(self, value):
        # convert value into a dict understood by .require()
        if isinstance(value, str):
            # 'label' => dict(label='label')
            return dict(label=value)
        if isinstance(value, (list, tuple)):
            # a sequence of (key, value) pairs => dict(pairs)
            # (previously dict(*value), which raised TypeError for any
            # sequence of pairs)
            return dict(value)
        return value

    def require(self, label=None, always=False, routing=None, task=None,
                logging=None, override=True, **kwargs):
        """
        specify requirements for the task execution

        Use this to specify resource or routing requirements on the next task
        call sent to the runtime. Any requirements will be reset after the
        call has been submitted.

        Args:
            always (bool): if True requirements will persist across task calls. defaults to False
            label (str): the label required by the worker to have a runtime task dispatched to it.
              'local' is equivalent to calling self.mode(local=True).
            task (dict): if specified applied to the task kwargs
            logging (str|tuple): if specified, same as runtime.mode(logging=...)
            override (bool): if True overrides previously set .require(), defaults to True
            kwargs: requirements specification that the runtime understands

        Usage:
            om.runtime.require(label='gpu').model('foo').fit(...)

        Returns:
            self
        """
        if label is not None:
            if label == 'local':
                self.mode(local=True)
            elif override:
                self.mode(local=False)
            # update routing, don't replace (#416)
            routing = routing or {}
            routing.update({'label': label or self._default_label})
        task = task or {}
        routing = routing or {}
        if task or routing:
            if not override:
                # override not allowed, remove previously existing
                ex_task = dict(**self._task_default_kwargs['task'],
                               **self._require_kwargs['task'])
                ex_routing = dict(**self._task_default_kwargs['routing'],
                                  **self._require_kwargs['routing'])
                exists_or_none = lambda k, d: k not in d or d.get(k, False) is None
                task = {k: v for k, v in task.items() if exists_or_none(k, ex_task)}
                routing = {k: v for k, v in routing.items() if exists_or_none(k, ex_routing)}
            if always:
                self._task_default_kwargs['routing'].update(routing)
                self._task_default_kwargs['task'].update(task)
            else:
                self._require_kwargs['routing'].update(routing)
                self._require_kwargs['task'].update(task)
        else:
            # FIXME this does not work as expected (will only reset if both task and routing are False)
            if not task:
                self._require_kwargs['task'] = {}
            if not routing:
                self._require_kwargs['routing'] = {}
        if logging is not None:
            self.mode(logging=logging)
        return self

    def model(self, modelname, require=None):
        """
        return a model for remote execution

        Args:
            modelname (str): the name of the object in om.models
            require (dict): routing requirements for this job

        Returns:
            OmegaModelProxy
        """
        from omegaml.runtimes.proxies.modelproxy import OmegaModelProxy
        if require:
            self.require(**self._sanitize_require(require))
        return OmegaModelProxy(modelname, runtime=self)

    def job(self, jobname, require=None):
        """
        return a job for remote execution

        Args:
            jobname (str): the name of the object in om.jobs
            require (dict): routing requirements for this job

        Returns:
            OmegaJobProxy
        """
        from omegaml.runtimes.proxies.jobproxy import OmegaJobProxy

        if require:
            self.require(**self._sanitize_require(require))
        return OmegaJobProxy(jobname, runtime=self)

    def script(self, scriptname, require=None):
        """
        return a script for remote execution

        Args:
            scriptname (str): the name of object in om.scripts
            require (dict): routing requirements for this job

        Returns:
            OmegaScriptProxy
        """
        from omegaml.runtimes.proxies.scriptproxy import OmegaScriptProxy

        if require:
            self.require(**self._sanitize_require(require))
        return OmegaScriptProxy(scriptname, runtime=self)

    def experiment(self, experiment, provider=None, implied_run=True, recreate=False, **tracker_kwargs):
        """ set the tracking backend and experiment

        Args:
            experiment (str): the name of the experiment
            provider (str): the name of the provider
            tracker_kwargs (dict): additional kwargs for the tracker
            recreate (bool): if True, recreate the experiment (i.e. drop and recreate,
               this is useful to change the provider or other settings. All previous data will
               be kept)

        Returns:
            OmegaTrackingProxy
        """
        from omegaml.runtimes.proxies.trackingproxy import OmegaTrackingProxy
        # tracker implied_run means we are using the currently active run, i.e. with block will call exp.start()
        tracker = OmegaTrackingProxy(experiment, provider=provider, runtime=self, implied_run=implied_run,
                                     recreate=recreate, **tracker_kwargs)
        return tracker

    def task(self, name, **kwargs):
        """
        retrieve the task function from the celery instance

        Args:
            name (str): a registered celery task as ``module.tasks.task_name``
            kwargs (dict): routing keywords to CeleryTask.apply_async

        Returns:
            CeleryTask
        """
        taskfn = self.celeryapp.tasks.get(name)
        assert taskfn is not None, "cannot find task {name} in Celery runtime".format(**locals())
        kwargs = dict_merge(self._common_kwargs, dict(routing=kwargs))
        task = CeleryTask(taskfn, kwargs)
        # reset temporary requirements; they apply to a single task call only
        self._require_kwargs = dict(routing={}, task={})
        return task

    def result(self, task_id, wait=True):
        """ return the result of a previously submitted task

        Args:
            task_id (str): the task id
            wait (bool): if True block until the result is available, else
               return the AsyncResult

        Returns:
            the task result, or AsyncResult if wait is False
        """
        from celery.result import AsyncResult
        promise = AsyncResult(task_id, app=self.celeryapp)
        return promise.get() if wait else promise

    def settings(self, require=None):
        """ return the runtimes's cluster settings
        """
        if require:
            self.require(**require)
        return self.task('omegaml.tasks.omega_settings').delay().get()

    def ping(self, *args, require=None, wait=True, timeout=10, **kwargs):
        """
        ping the runtime

        Args:
            args (tuple): task args
            require (dict): routing requirements for this job
            wait (bool): if True, wait for the task to return, else return
               AsyncResult
            timeout (int): if wait is True, the timeout in seconds, defaults to 10
            kwargs (dict): task kwargs, as accepted by CeleryTask.apply_async

        Returns:
            * response (dict) for wait=True
            * AsyncResult for wait=False
        """
        if require:
            self.require(**require)
        promise = self.task('omegaml.tasks.omega_ping').delay(*args, **kwargs)
        return promise.get(timeout=timeout) if wait else promise

    def enable_hostqueues(self):
        """ enable a worker-specific queue on every worker host

        Returns:
            list of labels (one entry for each hostname)
        """
        control = self.celeryapp.control
        inspect = control.inspect()
        # inspect.active() returns None when no worker responds
        active = inspect.active() or {}
        queues = []
        for worker in active.keys():
            hostname = worker.split('@')[-1]
            control.cancel_consumer(hostname)
            control.add_consumer(hostname, destination=[worker])
            queues.append(hostname)
        return queues

    def workers(self):
        """ list of workers

        Returns:
            dict of workers => list of active tasks

        See Also:
            celery Inspect.active()
        """
        local_worker = {gethostname(): [{'name': 'local', 'is_local': True}]}
        celery_workers = self._inspect.active() or {}
        return dict_merge(local_worker, celery_workers)

    def queues(self):
        """ list queues

        Returns:
            dict of workers => list of queues

        See Also:
            celery Inspect.active_queues()
        """
        local_q = {gethostname(): [{'name': 'local', 'is_local': True}]}
        celery_qs = self._inspect.active_queues() or {}
        return dict_merge(local_q, celery_qs)

    def labels(self):
        """ list available labels

        Returns:
            dict of workers => list of labels
        """
        return {worker: [q.get('name') for q in queues]
                for worker, queues in self.queues().items()}

    def stats(self):
        """ worker statistics

        Returns:
            dict of workers => dict of stats

        See Also:
            celery Inspect.stats()
        """
        return self._inspect.stats()

    def status(self):
        """ current cluster status

        This collects key information from .labels(), .stats() and the latest
        worker heartbeat. Note that loadavg is only available if the worker has
        recently sent a heartbeat and may not be accurate across the cluster.

        Returns:
            snapshot (dict): a snapshot of the cluster status
            '<worker>': {
                'loadavg': [0.0, 0.0, 0.0], # load average in % seen by the worker (1, 5, 15 min)
                'processes': 1, # number of active worker processes
                'concurrency': 1, # max concurrency
                'uptime': 0, # uptime in seconds
                'processed': Counter(task=n), # number of tasks processed
                'queues': ['default'], # list of queues (labels) the worker is listening on
            }
        """
        labels = self.labels()
        stats = self.stats()
        heartbeat = self.events.latest()
        snapshot = {
            worker: {
                'loadavg': heartbeat.get('loadavg', []),
                'processes': stats[worker]['pool']['processes'],
                'concurrency': stats[worker]['pool']['max-concurrency'],
                'uptime': stats[worker]['uptime'],
                'processed': stats[worker]['total'],
                'queues': labels[worker],
            } for worker in labels if worker in stats
        }
        return snapshot

    @property
    def events(self):
        # a fresh event stream bound to this runtime's celery app
        return CeleryEventStream(self.celeryapp)

    def callback(self, script_name, always=False, **kwargs):
        """ Add a callback to a registered script

        The callback will be triggered upon successful or failed
        execution of the runtime tasks. The script syntax is::

            # script.py
            def run(om, state=None, result=None, **kwargs):
                # state (str): 'SUCCESS'|'ERROR'
                # result (obj): the task's serialized result

        Args:
            script_name (str): the name of the script (in om.scripts)
            always (bool): if True always apply this callback, defaults to False
            **kwargs: and other kwargs to pass on to the script

        Returns:
            self
        """
        success_sig = (self.script(script_name)
                       .task(as_callback=True)
                       .signature(args=['SUCCESS', script_name],
                                  kwargs=kwargs,
                                  immutable=False))
        error_sig = (self.script(script_name)
                     .task(as_callback=True)
                     .signature(args=['ERROR', script_name],
                                kwargs=kwargs,
                                immutable=False))

        if always:
            self._task_default_kwargs['routing']['link'] = success_sig
            self._task_default_kwargs['routing']['link_error'] = error_sig
        else:
            self._require_kwargs['routing']['link'] = success_sig
            self._require_kwargs['routing']['link_error'] = error_sig
        return self
518
519
class CeleryEventStream:
    """ a bounded in-memory buffer of celery worker events

    Listens on the celery broker for worker events (by default
    worker-heartbeat) and keeps the most recent events in a fixed-size
    buffer.
    """

    def __init__(self, app, limit=None, timeout=5, wakeup=False):
        """
        Args:
            app (Celery): the celery application to listen on
            limit (int): optional, max number of events to capture per listen
            timeout (int): seconds to wait for events per listen, defaults to 5
            wakeup (bool): if True send a wakeup signal to workers on capture
        """
        self.app = app
        self.limit = limit
        self.timeout = timeout
        self.wakeup = wakeup
        # keep at most max_size events
        self.max_size = 100
        self.buffer = []

    def handle(self, event):
        # buffer the event, trimming to the most recent max_size entries
        self.buffer.append(event)
        if len(self.buffer) > self.max_size:
            self.buffer = self.buffer[-1 * self.max_size:]

    def listen(self, handlers=None, limit=None, timeout=None):
        # Connect to the broker using Kombu (Celery's underlying messaging system)
        handlers = handlers or {'worker-heartbeat': self.handle}
        limit = limit or self.limit
        timeout = timeout or self.timeout
        with self.app.connection() as conn:
            # Create the EventReceiver to listen to all events
            recv = EventReceiver(conn, handlers=handlers, app=self.app)
            recv.capture(limit=limit, timeout=timeout, wakeup=self.wakeup)

    def latest(self, timeout=None):
        """ return the most recent buffered event

        Listens once (bounded by timeout) if the buffer is empty. Returns
        an empty dict if no event arrived within the timeout, instead of
        retrying forever (previously this looped indefinitely when no
        worker ever sent a heartbeat).
        """
        if not self.buffer:
            self.listen(limit=1, timeout=timeout or self.timeout)
        return self.buffer[-1] if self.buffer else {}
548
549
# apply mixins: attach canvas (chain/group/chord) and swagger helpers
# as OmegaRuntime methods
from omegaml.runtimes.mixins.swagger import SwaggerGenerator
from omegaml.runtimes.mixins.taskcanvas import canvas_chain, canvas_chord, canvas_group

OmegaRuntime.sequence = canvas_chain
OmegaRuntime.parallel = canvas_group
OmegaRuntime.mapreduce = canvas_chord
OmegaRuntime.swagger = SwaggerGenerator.build_swagger