Source code for omegaml.runtimes.runtime

  1from __future__ import absolute_import
  2
  3import logging
  4from copy import deepcopy
  5from socket import gethostname
  6
  7from celery import Celery
  8from celery.events import EventReceiver
  9
 10from omegaml.mongoshim import mongo_url
 11from omegaml.util import dict_merge
 12
 13logger = logging.getLogger(__name__)
 14
 15
 16class CeleryTask(object):
 17    """
 18    A thin wrapper for a Celery.Task object
 19
 20    This is so that we can collect common delay arguments on the
 21    .task() call
 22    """
 23
 24    def __init__(self, task, kwargs):
 25        """
 26
 27        Args:
 28            task (Celery.Task): the celery task object
 29            kwargs (dict): optional, the kwargs to pass to apply_async
 30        """
 31        self.task = task
 32        self.kwargs = dict(kwargs)
 33
 34    def _apply_kwargs(self, task_kwargs, celery_kwargs):
 35        # update task_kwargs from runtime's passed on kwargs
 36        # update celery_kwargs to match celery routing semantics
 37        task_kwargs.update(self.kwargs.get('task', {}))
 38        celery_kwargs.update(self.kwargs.get('routing', {}))
 39        if 'label' in celery_kwargs:
 40            celery_kwargs['queue'] = celery_kwargs['label']
 41            del celery_kwargs['label']
 42
 43    def _apply_auth(self, args, kwargs, celery_kwargs):
 44        from omegaml.client.auth import AuthenticationEnv
 45        AuthenticationEnv.active().taskauth(args, kwargs, celery_kwargs)
 46
 47    def apply_async(self, args=None, kwargs=None, **celery_kwargs):
 48        """
 49
 50        Args:
 51            args (tuple): the task args
 52            kwargs (dict): the task kwargs, passed as task.apply_async(kwargs=kwargs)
 53            celery_kwargs (dict): apply_async kwargs, passed as task.apply_async(..., **celery_kwargs)
 54
 55        Returns:
 56            AsyncResult
 57        """
 58        args = args or tuple()
 59        kwargs = kwargs or {}
 60        self._apply_kwargs(kwargs, celery_kwargs)
 61        self._apply_auth(args, kwargs, celery_kwargs)
 62        return self.task.apply_async(args=args, kwargs=kwargs, **celery_kwargs)
 63
 64    def delay(self, *args, **kwargs):
 65        """
 66        submit the task with args and kwargs to pass on
 67
 68        This calls task.apply_async and passes on the self.kwargs.
 69        """
 70        return self.apply_async(args=args, kwargs=kwargs)
 71
 72    def signature(self, args=None, kwargs=None, immutable=False, **celery_kwargs):
 73        """ return the task signature with all kwargs and celery_kwargs applied
 74        """
 75        self._apply_kwargs(kwargs, celery_kwargs)
 76        sig = self.task.signature(args=args, kwargs=kwargs, **celery_kwargs, immutable=immutable)
 77        return sig
 78
 79    def run(self, *args, **kwargs):
 80        return self.delay(*args, **kwargs)
 81
 82
class OmegaRuntime(object):
    """
    omegaml compute cluster gateway
    """

    def __init__(self, omega, bucket=None, defaults=None, celeryconf=None):
        """
        Args:
            omega (Omega): the omega instance this runtime belongs to
            bucket (str): optional, the bucket passed on to every task
            defaults: optional, defaults to omegaml.util.settings()
            celeryconf (dict): optional, the celery configuration, defaults
               to defaults.OMEGA_CELERY_CONFIG
        """
        from omegaml.util import settings

        self.omega = omega
        defaults = defaults or settings()
        self.bucket = bucket
        self.pure_python = getattr(defaults, 'OMEGA_FORCE_PYTHON_CLIENT', False)
        self.pure_python = self.pure_python or self._client_is_pure_python()
        self._create_celery_app(defaults, celeryconf=celeryconf)
        # temporary requirements, use .require() to set
        self._require_kwargs = dict(task={}, routing={})
        # fixed default arguments, use .require(always=True) to set
        self._task_default_kwargs = dict(task={}, routing={})
        # default routing label
        self._default_label = self.celeryapp.conf.get('CELERY_DEFAULT_QUEUE')

    def __repr__(self):
        return 'OmegaRuntime({})'.format(self.omega.__repr__())

    @property
    def auth(self):
        # hook for authenticated runtimes; the base runtime has no auth
        return None

    @property
    def _common_kwargs(self):
        # merge the persistent defaults with the temporary .require() settings;
        # deepcopy so callers cannot mutate the stored defaults
        common = deepcopy(self._task_default_kwargs)
        common['task'].update(pure_python=self.pure_python, __bucket=self.bucket)
        common['task'].update(self._require_kwargs['task'])
        common['routing'].update(self._require_kwargs['routing'])
        return common

    @property
    def _inspect(self):
        # celery worker inspection API
        return self.celeryapp.control.inspect()

    @property
    def is_local(self):
        # True if tasks execute eagerly in-process instead of on the cluster
        return self.celeryapp.conf['CELERY_ALWAYS_EAGER']

    def mode(self, local=None, logging=None):
        """ specify runtime modes

        Args:
            local (bool): if True, all execution will run locally, else on
                the configured remote cluster
            logging (bool|str|tuple): if True, will set the root logger output
                at INFO level; a single string is the name of the logger,
                typically a module name; a tuple (logger, level) will select
                logger and the level. Valid levels are INFO, WARNING, ERROR,
                CRITICAL, DEBUG

        Usage::

            # run all runtime tasks locally
            om.runtime.mode(local=True)

            # enable logging both in local and remote mode
            om.runtime.mode(logging=True)

            # select a specific module and level
            om.runtime.mode(logging=('sklearn', 'DEBUG'))

            # disable logging
            om.runtime.mode(logging=False)
        """
        if isinstance(local, bool):
            self.celeryapp.conf['CELERY_ALWAYS_EAGER'] = local
        # NOTE(review): __logging is set unconditionally, so mode(local=...)
        # resets a previously set logging config to None — confirm intended
        self._task_default_kwargs['task']['__logging'] = logging
        return self

    def _create_celery_app(self, defaults, celeryconf=None):
        # initialize celery as a runtimes
        taskpkgs = defaults.OMEGA_CELERY_IMPORTS
        celeryconf = dict(celeryconf or defaults.OMEGA_CELERY_CONFIG)
        # ensure we use current value
        celeryconf['CELERY_ALWAYS_EAGER'] = bool(defaults.OMEGA_LOCAL_RUNTIME)
        if celeryconf['CELERY_RESULT_BACKEND'].startswith('mongodb://'):
            celeryconf['CELERY_RESULT_BACKEND'] = mongo_url(self.omega, drop_kwargs=['uuidRepresentation'])
        # initialize ssl configuration
        if celeryconf.get('BROKER_USE_SSL'):
            # celery > 5 requires ssl options to be specific
            # https://docs.celeryq.dev/en/stable/userguide/configuration.html#std-setting-broker_use_ssl
            # https://github.com/celery/kombu/issues/1493
            # https://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
            # https://www.openssl.org/docs/man3.0/man3/SSL_CTX_set_default_verify_paths.html
            # env variables:
            # SSL_CERT_FILE, CA_CERTS_PATH
            self._apply_broker_ssl(celeryconf)
        self.celeryapp = Celery('omegaml')
        self.celeryapp.config_from_object(celeryconf)
        # needed to get it to actually load the tasks
        # https://stackoverflow.com/a/35735471
        self.celeryapp.autodiscover_tasks(taskpkgs, force=True)
        self.celeryapp.finalize()

    def _apply_broker_ssl(self, celeryconf):
        # hook to apply broker ssl options
        pass

    def _client_is_pure_python(self):
        # a client without the ML stack installed is considered pure python
        try:
            import pandas as pd
            import numpy as np
            import sklearn
        except Exception as e:
            # fix: use the module-level logger instead of the root logger
            logger.info(e)
            return True
        else:
            return False

    def _sanitize_require(self, value):
        # convert value into dict(label=value)
        if isinstance(value, str):
            return dict(label=value)
        if isinstance(value, (list, tuple)):
            # NOTE(review): dict(*value) only works for a 1-element sequence
            # wrapping a mapping or pair-iterable — confirm intended semantics
            return dict(*value)
        return value

    def require(self, label=None, always=False, routing=None, task=None,
                logging=None, override=True, **kwargs):
        """
        specify requirements for the task execution

        Use this to specify resource or routing requirements on the next task
        call sent to the runtime. Any requirements will be reset after the
        call has been submitted.

        Args:
            always (bool): if True requirements will persist across task calls. defaults to False
            label (str): the label required by the worker to have a runtime task dispatched to it.
               'local' is equivalent to calling self.mode(local=True).
            task (dict): if specified applied to the task's kwargs, passed as task.apply_async(..., kwargs=task)
            routing (dict): if specified applied to the task's routing, passed as task.apply_async(..., **routing)
            logging (str|tuple): if specified, same as runtime.mode(logging=...)
            override (bool): if True overrides previously set .require(), defaults to True
            kwargs: requirements specification that the runtime understands

        Usage:
            om.runtime.require(label='gpu').model('foo').fit(...)

        See Also:
            - CeleryTask.apply_async
            - celery.app.task.Task.apply_async, specifically the kwargs= and **options

        Returns:
            self
        """
        if label:
            # avoid overriding a previous local() call by an erroneous label
            assert isinstance(label, str), "label must be valid, run om.runtime.labels() to list of active workers"
            if label == 'local':
                self.mode(local=True)
            elif override:
                self.mode(local=False)
            # update routing, don't replace (#416)
            routing = routing or {}
            routing.update({'label': label or self._default_label})
        task = task or {}
        routing = routing or {}
        if task or routing:
            if not override:
                # override not allowed, remove previously existing
                ex_task = dict(**self._task_default_kwargs['task'],
                               **self._require_kwargs['task'])
                ex_routing = dict(**self._task_default_kwargs['routing'],
                                  **self._require_kwargs['routing'])
                exists_or_none = lambda k, d: k not in d or d.get(k, False) is None
                task = {k: v for k, v in task.items() if exists_or_none(k, ex_task)}
                routing = {k: v for k, v in routing.items() if exists_or_none(k, ex_routing)}
            if always:
                # persist across task calls
                self._task_default_kwargs['routing'].update(routing)
                self._task_default_kwargs['task'].update(task)
            else:
                # applies to the next task call only
                self._require_kwargs['routing'].update(routing)
                self._require_kwargs['task'].update(task)
        else:
            # FIXME this does not work as expected (will only reset if both task and routing are False)
            if not task:
                self._require_kwargs['task'] = {}
            if not routing:
                self._require_kwargs['routing'] = {}
        if logging is not None:
            self.mode(logging=logging)
        return self

    def model(self, modelname, require=None):
        """
        return a model for remote execution

        Args:
            modelname (str): the name of the object in om.models
            require (dict): routing requirements for this job

        Returns:
            OmegaModelProxy
        """
        from omegaml.runtimes.proxies.modelproxy import OmegaModelProxy
        self.require(**self._sanitize_require(require)) if require else None
        return OmegaModelProxy(modelname, runtime=self)

    def job(self, jobname, require=None):
        """
        return a job for remote execution

        Args:
            jobname (str): the name of the object in om.jobs
            require (dict): routing requirements for this job

        Returns:
            OmegaJobProxy
        """
        from omegaml.runtimes.proxies.jobproxy import OmegaJobProxy

        self.require(**self._sanitize_require(require)) if require else None
        return OmegaJobProxy(jobname, runtime=self)

    def script(self, scriptname, require=None):
        """
        return a script for remote execution

        Args:
            scriptname (str): the name of object in om.scripts
            require (dict): routing requirements for this job

        Returns:
            OmegaScriptProxy
        """
        from omegaml.runtimes.proxies.scriptproxy import OmegaScriptProxy

        self.require(**self._sanitize_require(require)) if require else None
        return OmegaScriptProxy(scriptname, runtime=self)

    def experiment(self, experiment, provider=None, implied_run=True, recreate=False, **tracker_kwargs):
        """ set the tracking backend and experiment

        Args:
            experiment (str): the name of the experiment
            provider (str): the name of the provider
            tracker_kwargs (dict): additional kwargs for the tracker
            recreate (bool): if True, recreate the experiment (i.e. drop and recreate,
               this is useful to change the provider or other settings. All previous data will
               be kept)

        Returns:
            OmegaTrackingProxy
        """
        from omegaml.runtimes.proxies.trackingproxy import OmegaTrackingProxy
        # tracker implied_run means we are using the currently active run, i.e.
        # with block will call exp.start()
        tracker = OmegaTrackingProxy(experiment, provider=provider, runtime=self, implied_run=implied_run,
                                     recreate=recreate, **tracker_kwargs)
        return tracker

    def task(self, name, **kwargs):
        """
        retrieve the task function from the celery instance

        Args:
            name (str): a registered celery task as ``module.tasks.task_name``
            kwargs (dict): routing keywords to CeleryTask.apply_async

        Returns:
            CeleryTask
        """
        taskfn = self.celeryapp.tasks.get(name)
        assert taskfn is not None, "cannot find task {name} in Celery runtime".format(**locals())
        kwargs = dict_merge(self._common_kwargs, dict(routing=kwargs))
        task = CeleryTask(taskfn, kwargs)
        # temporary requirements are consumed by the task call
        self._require_kwargs = dict(routing={}, task={})
        return task

    def result(self, task_id, wait=True):
        """ retrieve a task result by its id, optionally waiting for it """
        from celery.result import AsyncResult
        promise = AsyncResult(task_id, app=self.celeryapp)
        return promise.get() if wait else promise

    def settings(self, require=None):
        """ return the runtimes's cluster settings
        """
        # consistency fix: sanitize require as model()/job()/script() do,
        # so a plain label string is accepted too
        self.require(**self._sanitize_require(require)) if require else None
        return self.task('omegaml.tasks.omega_settings').delay().get()

    def ping(self, *args, require=None, wait=True, timeout=10, **kwargs):
        """
        ping the runtime

        Args:
            args (tuple): task args
            require (dict): routing requirements for this job
            wait (bool): if True, wait for the task to return, else return
               AsyncResult
            timeout (int): if wait is True, the timeout in seconds, defaults to 10
            kwargs (dict): task kwargs, as accepted by CeleryTask.apply_async

        Returns:
            * response (dict) for wait=True
            * AsyncResult for wait=False
        """
        # consistency fix: sanitize require as model()/job()/script() do
        self.require(**self._sanitize_require(require)) if require else None
        promise = self.task('omegaml.tasks.omega_ping').delay(*args, **kwargs)
        return promise.get(timeout=timeout) if wait else promise

    def enable_hostqueues(self):
        """ enable a worker-specific queue on every worker host

        Returns:
            list of labels (one entry for each hostname)
        """
        control = self.celeryapp.control
        inspect = control.inspect()
        active = inspect.active()
        queues = []
        for worker in active.keys():
            hostname = worker.split('@')[-1]
            control.cancel_consumer(hostname)
            control.add_consumer(hostname, destination=[worker])
            queues.append(hostname)
        return queues

    def workers(self):
        """ list of workers

        Returns:
            dict of workers => list of active tasks

        See Also:
            celery Inspect.active()
        """
        local_worker = {gethostname(): [{'name': 'local', 'is_local': True}]}
        celery_workers = self._inspect.active() or {}
        return dict_merge(local_worker, celery_workers)

    def queues(self):
        """ list queues

        Returns:
            dict of workers => list of queues

        See Also:
            celery Inspect.active_queues()
        """
        local_q = {gethostname(): [{'name': 'local', 'is_local': True}]}
        celery_qs = self._inspect.active_queues() or {}
        return dict_merge(local_q, celery_qs)

    def labels(self):
        """ list available labels

        Returns:
            dict of workers => list of labels
        """
        return {worker: [q.get('name') for q in queues]
                for worker, queues in self.queues().items()}

    def stats(self):
        """ worker statistics

        Returns:
            dict of workers => dict of stats

        See Also:
            celery Inspect.stats()
        """
        return self._inspect.stats()

    def status(self):
        """ current cluster status

        This collects key information from .labels(), .stats() and the latest
        worker heartbeat. Note that loadavg is only available if the worker has
        recently sent a heartbeat and may not be accurate across the cluster.

        Returns:
            snapshot (dict): a snapshot of the cluster status
                '<worker>': {
                    'loadavg': [0.0, 0.0, 0.0], # load average in % seen by the worker (1, 5, 15 min)
                    'processes': 1, # number of active worker processes
                    'concurrency': 1, # max concurrency
                    'uptime': 0, # uptime in seconds
                    'processed': Counter(task=n), # number of tasks processed
                    'queues': ['default'], # list of queues (labels) the worker is listening on
                }
        """
        labels = self.labels()
        stats = self.stats()
        # NOTE(review): events.latest() blocks until a heartbeat is captured
        heartbeat = self.events.latest()
        snapshot = {
            worker: {
                'loadavg': heartbeat.get('loadavg', []),
                'processes': stats[worker]['pool']['processes'],
                'concurrency': stats[worker]['pool']['max-concurrency'],
                'uptime': stats[worker]['uptime'],
                'processed': stats[worker]['total'],
                'queues': labels[worker],
            } for worker in labels if worker in stats
        }
        return snapshot

    @property
    def events(self):
        # a fresh event stream bound to this runtime's celery app
        return CeleryEventStream(self.celeryapp)

    def callback(self, script_name, always=False, **kwargs):
        """ Add a callback to a registered script

        The callback will be triggered upon successful or failed
        execution of the runtime tasks. The script syntax is::

            # script.py
            def run(om, state=None, result=None, **kwargs):
                # state (str): 'SUCCESS'|'ERROR'
                # result (obj): the task's serialized result

        Args:
            script_name (str): the name of the script (in om.scripts)
            always (bool): if True always apply this callback, defaults to False
            **kwargs: and other kwargs to pass on to the script

        Returns:
            self
        """
        success_sig = (self.script(script_name)
                       .task(as_callback=True)
                       .signature(args=['SUCCESS', script_name],
                                  kwargs=kwargs,
                                  immutable=False))
        error_sig = (self.script(script_name)
                     .task(as_callback=True)
                     .signature(args=['ERROR', script_name],
                                kwargs=kwargs,
                                immutable=False))

        if always:
            self._task_default_kwargs['routing']['link'] = success_sig
            self._task_default_kwargs['routing']['link_error'] = error_sig
        else:
            self._require_kwargs['routing']['link'] = success_sig
            self._require_kwargs['routing']['link_error'] = error_sig
        return self
class CeleryEventStream:
    """Capture celery worker events into a bounded in-memory buffer.

    Connects to the celery app's broker and records received events
    (by default worker heartbeats), keeping at most ``max_size`` of
    the most recent events.
    """

    def __init__(self, app, limit=None, timeout=5, wakeup=False):
        self.app = app
        self.limit = limit
        self.timeout = timeout
        self.wakeup = wakeup
        self.max_size = 100
        self.buffer = []

    def handle(self, event):
        # record the event, discarding the oldest entries beyond max_size
        self.buffer.append(event)
        overflow = len(self.buffer) - self.max_size
        if overflow > 0:
            self.buffer = self.buffer[overflow:]

    def listen(self, handlers=None, limit=None, timeout=None):
        # connect to the broker using Kombu (celery's underlying messaging
        # system) and capture events via the given handlers
        event_handlers = handlers if handlers else {'worker-heartbeat': self.handle}
        capture_limit = limit if limit else self.limit
        capture_timeout = timeout if timeout else self.timeout
        with self.app.connection() as connection:
            # the EventReceiver dispatches incoming events to the handlers
            receiver = EventReceiver(connection, handlers=event_handlers, app=self.app)
            receiver.capture(limit=capture_limit, timeout=capture_timeout, wakeup=self.wakeup)

    def latest(self, timeout=None):
        # block until at least one event has been captured, then return it
        while not self.buffer:
            self.listen(limit=1, timeout=timeout)
        return self.buffer[-1]


# apply mixins
from omegaml.runtimes.mixins.taskcanvas import canvas_chain, canvas_group, canvas_chord
from omegaml.runtimes.mixins.swagger import SwaggerGenerator

OmegaRuntime.sequence = canvas_chain
OmegaRuntime.parallel = canvas_group
OmegaRuntime.mapreduce = canvas_chord
OmegaRuntime.swagger = SwaggerGenerator.build_swagger