Source code for omegaml.runtimes.runtime

  1from __future__ import absolute_import
  2
  3import logging
  4from celery import Celery
  5from celery.events import EventReceiver
  6from copy import deepcopy
  7from socket import gethostname
  8
  9from omegaml.mongoshim import mongo_url
 10from omegaml.util import dict_merge
 11
 12logger = logging.getLogger(__name__)
 13
 14
 15class CeleryTask(object):
 16    """
 17    A thin wrapper for a Celery.Task object
 18
 19    This is so that we can collect common delay arguments on the
 20    .task() call
 21    """
 22
 23    def __init__(self, task, kwargs):
 24        """
 25
 26        Args:
 27            task (Celery.Task): the celery task object
 28            kwargs (dict): optional, the kwargs to pass to apply_async
 29        """
 30        self.task = task
 31        self.kwargs = dict(kwargs)
 32
 33    def _apply_kwargs(self, task_kwargs, celery_kwargs):
 34        # update task_kwargs from runtime's passed on kwargs
 35        # update celery_kwargs to match celery routing semantics
 36        task_kwargs.update(self.kwargs.get('task', {}))
 37        celery_kwargs.update(self.kwargs.get('routing', {}))
 38        if 'label' in celery_kwargs:
 39            celery_kwargs['queue'] = celery_kwargs['label']
 40            del celery_kwargs['label']
 41
 42    def _apply_auth(self, args, kwargs, celery_kwargs):
 43        from omegaml.client.auth import AuthenticationEnv
 44        AuthenticationEnv.active().taskauth(args, kwargs, celery_kwargs)
 45
 46    def apply_async(self, args=None, kwargs=None, **celery_kwargs):
 47        """
 48
 49        Args:
 50            args (tuple): the task args
 51            kwargs (dict): the task kwargs
 52            celery_kwargs (dict): apply_async kwargs, e.g. routing
 53
 54        Returns:
 55            AsyncResult
 56        """
 57        args = args or tuple()
 58        kwargs = kwargs or {}
 59        self._apply_kwargs(kwargs, celery_kwargs)
 60        self._apply_auth(args, kwargs, celery_kwargs)
 61        return self.task.apply_async(args=args, kwargs=kwargs, **celery_kwargs)
 62
 63    def delay(self, *args, **kwargs):
 64        """
 65        submit the task with args and kwargs to pass on
 66
 67        This calls task.apply_async and passes on the self.kwargs.
 68        """
 69        return self.apply_async(args=args, kwargs=kwargs)
 70
 71    def signature(self, args=None, kwargs=None, immutable=False, **celery_kwargs):
 72        """ return the task signature with all kwargs and celery_kwargs applied
 73        """
 74        self._apply_kwargs(kwargs, celery_kwargs)
 75        sig = self.task.signature(args=args, kwargs=kwargs, **celery_kwargs, immutable=immutable)
 76        return sig
 77
 78    def run(self, *args, **kwargs):
 79        return self.delay(*args, **kwargs)
 80
 81
class OmegaRuntime(object):
    """
    omegaml compute cluster gateway
    """

    def __init__(self, omega, bucket=None, defaults=None, celeryconf=None):
        """
        Args:
            omega (Omega): the omega instance this runtime is attached to
            bucket (str): optional, the bucket passed on to every task
            defaults: optional, settings object, defaults to omegaml.util.settings()
            celeryconf (dict): optional, celery configuration, defaults to
                defaults.OMEGA_CELERY_CONFIG
        """
        from omegaml.util import settings

        self.omega = omega
        defaults = defaults or settings()
        self.bucket = bucket
        # pure python means the client does not have the ML stack installed
        # (see _client_is_pure_python) or explicitly forces it via settings
        self.pure_python = getattr(defaults, 'OMEGA_FORCE_PYTHON_CLIENT', False)
        self.pure_python = self.pure_python or self._client_is_pure_python()
        self._create_celery_app(defaults, celeryconf=celeryconf)
        # temporary requirements, use .require() to set
        self._require_kwargs = dict(task={}, routing={})
        # fixed default arguments, use .require(always=True) to set
        self._task_default_kwargs = dict(task={}, routing={})
        # default routing label
        self._default_label = self.celeryapp.conf.get('CELERY_DEFAULT_QUEUE')

    def __repr__(self):
        return 'OmegaRuntime({})'.format(self.omega.__repr__())

    @property
    def auth(self):
        # no authentication by default; subclasses may override
        return None

    @property
    def _common_kwargs(self):
        # merge persistent defaults (.require(always=True)) with the one-shot
        # .require() kwargs; deepcopy so callers cannot mutate the defaults
        common = deepcopy(self._task_default_kwargs)
        common['task'].update(pure_python=self.pure_python, __bucket=self.bucket)
        common['task'].update(self._require_kwargs['task'])
        common['routing'].update(self._require_kwargs['routing'])
        return common

    @property
    def _inspect(self):
        # a fresh celery Inspect instance to query the cluster
        return self.celeryapp.control.inspect()

    @property
    def is_local(self):
        # True if tasks execute eagerly in-process (see .mode(local=True))
        return self.celeryapp.conf['CELERY_ALWAYS_EAGER']
[docs] 126 def mode(self, local=None, logging=None): 127 """ specify runtime modes 128 129 Args: 130 local (bool): if True, all execution will run locally, else on 131 the configured remote cluster 132 logging (bool|str|tuple): if True, will set the root logger output 133 at INFO level; a single string is the name of the logger, 134 typically a module name; a tuple (logger, level) will select 135 logger and the level. Valid levels are INFO, WARNING, ERROR, 136 CRITICAL, DEBUG 137 138 Usage:: 139 140 # run all runtime tasks locally 141 om.runtime.mode(local=True) 142 143 # enable logging both in local and remote mode 144 om.runtime.mode(logging=True) 145 146 # select a specific module and level) 147 om.runtime.mode(logging=('sklearn', 'DEBUG')) 148 149 # disable logging 150 om.runtime.mode(logging=False) 151 """ 152 if isinstance(local, bool): 153 self.celeryapp.conf['CELERY_ALWAYS_EAGER'] = local 154 self._task_default_kwargs['task']['__logging'] = logging 155 return self
    def _create_celery_app(self, defaults, celeryconf=None):
        # initialize celery as a runtimes
        taskpkgs = defaults.OMEGA_CELERY_IMPORTS
        celeryconf = dict(celeryconf or defaults.OMEGA_CELERY_CONFIG)
        # ensure we use current value
        celeryconf['CELERY_ALWAYS_EAGER'] = bool(defaults.OMEGA_LOCAL_RUNTIME)
        if celeryconf['CELERY_RESULT_BACKEND'].startswith('mongodb://'):
            # rebuild the result backend url from the current omega instance
            celeryconf['CELERY_RESULT_BACKEND'] = mongo_url(self.omega, drop_kwargs=['uuidRepresentation'])
        # initialize ssl configuration
        if celeryconf.get('BROKER_USE_SSL'):
            # celery > 5 requires ssl options to be specific
            # https://docs.celeryq.dev/en/stable/userguide/configuration.html#std-setting-broker_use_ssl
            # https://github.com/celery/kombu/issues/1493
            # https://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
            # https://www.openssl.org/docs/man3.0/man3/SSL_CTX_set_default_verify_paths.html
            # env variables:
            # SSL_CERT_FILE, CA_CERTS_PATH
            self._apply_broker_ssl(celeryconf)
        self.celeryapp = Celery('omegaml')
        self.celeryapp.config_from_object(celeryconf)
        # needed to get it to actually load the tasks
        # https://stackoverflow.com/a/35735471
        self.celeryapp.autodiscover_tasks(taskpkgs, force=True)
        self.celeryapp.finalize()

    def _apply_broker_ssl(self, celeryconf):
        # hook to apply broker ssl options
        pass

    def _client_is_pure_python(self):
        # a client counts as "pure python" if the ML stack cannot be imported
        try:
            import pandas as pd
            import numpy as np
            import sklearn
        except Exception as e:
            logging.getLogger().info(e)
            return True
        else:
            return False

    def _sanitize_require(self, value):
        # convert value into dict(label=value)
        if isinstance(value, str):
            return dict(label=value)
        if isinstance(value, (list, tuple)):
            # NOTE(review): dict(*value) only works for a 0- or 1-element
            # sequence (e.g. [mapping]); any longer sequence raises TypeError.
            # Presumably dict(value) (iterable of key/value pairs) was
            # intended -- confirm against callers before changing.
            return dict(*value)
        return value
    def require(self, label=None, always=False, routing=None, task=None,
                logging=None, override=True, **kwargs):
        """
        specify requirements for the task execution

        Use this to specify resource or routing requirements on the next task
        call sent to the runtime. Any requirements will be reset after the
        call has been submitted.

        Args:
            always (bool): if True requirements will persist across task calls. defaults to False
            label (str): the label required by the worker to have a runtime task dispatched to it.
                'local' is equivalent to calling self.mode(local=True).
            task (dict): if specified applied to the task kwargs
            logging (str|tuple): if specified, same as runtime.mode(logging=...)
            override (bool): if True overrides previously set .require(), defaults to True
            kwargs: requirements specification that the runtime understands

        Usage:
            om.runtime.require(label='gpu').model('foo').fit(...)

        Returns:
            self
        """
        if label is not None:
            if label == 'local':
                # 'local' is not a worker queue but a mode switch
                self.mode(local=True)
            elif override:
                # any other label implies remote execution
                self.mode(local=False)
            # update routing, don't replace (#416)
            routing = routing or {}
            routing.update({'label': label or self._default_label})
        task = task or {}
        routing = routing or {}
        if task or routing:
            if not override:
                # override not allowed, remove previously existing
                ex_task = dict(**self._task_default_kwargs['task'],
                               **self._require_kwargs['task'])
                ex_routing = dict(**self._task_default_kwargs['routing'],
                                  **self._require_kwargs['routing'])
                # keep a key only if it is new, or previously set to None
                exists_or_none = lambda k, d: k not in d or d.get(k, False) is None
                task = {k: v for k, v in task.items() if exists_or_none(k, ex_task)}
                routing = {k: v for k, v in routing.items() if exists_or_none(k, ex_routing)}
            if always:
                # persistent defaults, survive across task submissions
                self._task_default_kwargs['routing'].update(routing)
                self._task_default_kwargs['task'].update(task)
            else:
                # one-shot requirements, reset by the next .task() call
                self._require_kwargs['routing'].update(routing)
                self._require_kwargs['task'].update(task)
        else:
            # FIXME this does not work as expected (will only reset if both task and routing are False)
            if not task:
                self._require_kwargs['task'] = {}
            if not routing:
                self._require_kwargs['routing'] = {}
        if logging is not None:
            self.mode(logging=logging)
        return self
264
[docs] 265 def model(self, modelname, require=None): 266 """ 267 return a model for remote execution 268 269 Args: 270 modelname (str): the name of the object in om.models 271 require (dict): routing requirements for this job 272 273 Returns: 274 OmegaModelProxy 275 """ 276 from omegaml.runtimes.proxies.modelproxy import OmegaModelProxy 277 self.require(**self._sanitize_require(require)) if require else None 278 return OmegaModelProxy(modelname, runtime=self)
279
[docs] 280 def job(self, jobname, require=None): 281 """ 282 return a job for remote execution 283 284 Args: 285 jobname (str): the name of the object in om.jobs 286 require (dict): routing requirements for this job 287 288 Returns: 289 OmegaJobProxy 290 """ 291 from omegaml.runtimes.proxies.jobproxy import OmegaJobProxy 292 293 self.require(**self._sanitize_require(require)) if require else None 294 return OmegaJobProxy(jobname, runtime=self)
295
[docs] 296 def script(self, scriptname, require=None): 297 """ 298 return a script for remote execution 299 300 Args: 301 scriptname (str): the name of object in om.scripts 302 require (dict): routing requirements for this job 303 304 Returns: 305 OmegaScriptProxy 306 """ 307 from omegaml.runtimes.proxies.scriptproxy import OmegaScriptProxy 308 309 self.require(**self._sanitize_require(require)) if require else None 310 return OmegaScriptProxy(scriptname, runtime=self)
311
[docs] 312 def experiment(self, experiment, provider=None, implied_run=True, recreate=False, **tracker_kwargs): 313 """ set the tracking backend and experiment 314 315 Args: 316 experiment (str): the name of the experiment 317 provider (str): the name of the provider 318 tracker_kwargs (dict): additional kwargs for the tracker 319 recreate (bool): if True, recreate the experiment (i.e. drop and recreate, 320 this is useful to change the provider or other settings. All previous data will 321 be kept) 322 323 Returns: 324 OmegaTrackingProxy 325 """ 326 from omegaml.runtimes.proxies.trackingproxy import OmegaTrackingProxy 327 # tracker implied_run means we are using the currently active run, i.e. with block will call exp.start() 328 tracker = OmegaTrackingProxy(experiment, provider=provider, runtime=self, implied_run=implied_run, 329 recreate=recreate, **tracker_kwargs) 330 return tracker
331
[docs] 332 def task(self, name, **kwargs): 333 """ 334 retrieve the task function from the celery instance 335 336 Args: 337 name (str): a registered celery task as ``module.tasks.task_name`` 338 kwargs (dict): routing keywords to CeleryTask.apply_async 339 340 Returns: 341 CeleryTask 342 """ 343 taskfn = self.celeryapp.tasks.get(name) 344 assert taskfn is not None, "cannot find task {name} in Celery runtime".format(**locals()) 345 kwargs = dict_merge(self._common_kwargs, dict(routing=kwargs)) 346 task = CeleryTask(taskfn, kwargs) 347 self._require_kwargs = dict(routing={}, task={}) 348 return task
349 350 def result(self, task_id, wait=True): 351 from celery.result import AsyncResult 352 promise = AsyncResult(task_id, app=self.celeryapp) 353 return promise.get() if wait else promise 354
[docs] 355 def settings(self, require=None): 356 """ return the runtimes's cluster settings 357 """ 358 self.require(**require) if require else None 359 return self.task('omegaml.tasks.omega_settings').delay().get()
360
[docs] 361 def ping(self, *args, require=None, wait=True, timeout=10, **kwargs): 362 """ 363 ping the runtime 364 365 Args: 366 args (tuple): task args 367 require (dict): routing requirements for this job 368 wait (bool): if True, wait for the task to return, else return 369 AsyncResult 370 timeout (int): if wait is True, the timeout in seconds, defaults to 10 371 kwargs (dict): task kwargs, as accepted by CeleryTask.apply_async 372 373 Returns: 374 * response (dict) for wait=True 375 * AsyncResult for wait=False 376 """ 377 self.require(**require) if require else None 378 promise = self.task('omegaml.tasks.omega_ping').delay(*args, **kwargs) 379 return promise.get(timeout=timeout) if wait else promise
380
[docs] 381 def enable_hostqueues(self): 382 """ enable a worker-specific queue on every worker host 383 384 Returns: 385 list of labels (one entry for each hostname) 386 """ 387 control = self.celeryapp.control 388 inspect = control.inspect() 389 active = inspect.active() 390 queues = [] 391 for worker in active.keys(): 392 hostname = worker.split('@')[-1] 393 control.cancel_consumer(hostname) 394 control.add_consumer(hostname, destination=[worker]) 395 queues.append(hostname) 396 return queues
397
[docs] 398 def workers(self): 399 """ list of workers 400 401 Returns: 402 dict of workers => list of active tasks 403 404 See Also: 405 celery Inspect.active() 406 """ 407 local_worker = {gethostname(): [{'name': 'local', 'is_local': True}]} 408 celery_workers = self._inspect.active() or {} 409 return dict_merge(local_worker, celery_workers)
410
[docs] 411 def queues(self): 412 """ list queues 413 414 Returns: 415 dict of workers => list of queues 416 417 See Also: 418 celery Inspect.active_queues() 419 """ 420 local_q = {gethostname(): [{'name': 'local', 'is_local': True}]} 421 celery_qs = self._inspect.active_queues() or {} 422 return dict_merge(local_q, celery_qs)
423
[docs] 424 def labels(self): 425 """ list available labels 426 427 Returns: 428 dict of workers => list of lables 429 """ 430 return {worker: [q.get('name') for q in queues] 431 for worker, queues in self.queues().items()}
432
[docs] 433 def stats(self): 434 """ worker statistics 435 436 Returns: 437 dict of workers => dict of stats 438 439 See Also: 440 celery Inspect.stats() 441 """ 442 return self._inspect.stats()
443
[docs] 444 def status(self): 445 """ current cluster status 446 447 This collects key information from .labels(), .stats() and the latest 448 worker heartbeat. Note that loadavg is only available if the worker has 449 recently sent a heartbeat and may not be accurate across the cluster. 450 451 Returns: 452 snapshot (dict): a snapshot of the cluster status 453 '<worker>': { 454 'loadavg': [0.0, 0.0, 0.0], # load average in % seen by the worker (1, 5, 15 min) 455 'processes': 1, # number of active worker processes 456 'concurrency': 1, # max concurrency 457 'uptime': 0, # uptime in seconds 458 'processed': Counter(task=n), # number of tasks processed 459 'queues': ['default'], # list of queues (labels) the worker is listening on 460 } 461 """ 462 labels = self.labels() 463 stats = self.stats() 464 heartbeat = self.events.latest() 465 snapshot = { 466 worker: { 467 'loadavg': heartbeat.get('loadavg', []), 468 'processes': stats[worker]['pool']['processes'], 469 'concurrency': stats[worker]['pool']['max-concurrency'], 470 'uptime': stats[worker]['uptime'], 471 'processed': stats[worker]['total'], 472 'queues': labels[worker], 473 } for worker in labels if worker in stats 474 } 475 return snapshot
    @property
    def events(self):
        # a fresh event stream bound to this celery app; used by .status()
        # to read the latest worker heartbeat
        return CeleryEventStream(self.celeryapp)
[docs] 481 def callback(self, script_name, always=False, **kwargs): 482 """ Add a callback to a registered script 483 484 The callback will be triggered upon successful or failed 485 execution of the runtime tasks. The script syntax is:: 486 487 # script.py 488 def run(om, state=None, result=None, **kwargs): 489 # state (str): 'SUCCESS'|'ERROR' 490 # result (obj): the task's serialized result 491 492 Args: 493 script_name (str): the name of the script (in om.scripts) 494 always (bool): if True always apply this callback, defaults to False 495 **kwargs: and other kwargs to pass on to the script 496 497 Returns: 498 self 499 """ 500 success_sig = (self.script(script_name) 501 .task(as_callback=True) 502 .signature(args=['SUCCESS', script_name], 503 kwargs=kwargs, 504 immutable=False)) 505 error_sig = (self.script(script_name) 506 .task(as_callback=True) 507 .signature(args=['ERROR', script_name], 508 kwargs=kwargs, 509 immutable=False)) 510 511 if always: 512 self._task_default_kwargs['routing']['link'] = success_sig 513 self._task_default_kwargs['routing']['link_error'] = error_sig 514 else: 515 self._require_kwargs['routing']['link'] = success_sig 516 self._require_kwargs['routing']['link_error'] = error_sig 517 return self
class CeleryEventStream:
    """ a small buffered reader for celery worker events """

    def __init__(self, app, limit=None, timeout=5, wakeup=False):
        self.app = app
        self.limit = limit
        self.timeout = timeout
        self.wakeup = wakeup
        # keep at most max_size events in memory
        self.max_size = 100
        self.buffer = []

    def handle(self, event):
        # append and trim to the most recent max_size events
        self.buffer.append(event)
        overflow = len(self.buffer) - self.max_size
        if overflow > 0:
            self.buffer = self.buffer[overflow:]

    def listen(self, handlers=None, limit=None, timeout=None):
        # Connect to the broker using Kombu (Celery's underlying messaging system)
        # by default only worker heartbeats are captured into the buffer
        with self.app.connection() as conn:
            # Create the EventReceiver to listen to all events
            recv = EventReceiver(conn,
                                 handlers=handlers or {'worker-heartbeat': self.handle},
                                 app=self.app)
            recv.capture(limit=limit or self.limit,
                         timeout=timeout or self.timeout,
                         wakeup=self.wakeup)

    def latest(self, timeout=None):
        # block until at least one event has been captured
        while not self.buffer:
            self.listen(limit=1, timeout=timeout)
        return self.buffer[-1]


# apply mixins
from omegaml.runtimes.mixins.taskcanvas import canvas_chain, canvas_group, canvas_chord
from omegaml.runtimes.mixins.swagger import SwaggerGenerator

OmegaRuntime.sequence = canvas_chain
OmegaRuntime.parallel = canvas_group
OmegaRuntime.mapreduce = canvas_chord
OmegaRuntime.swagger = SwaggerGenerator.build_swagger