Source code for omegaml.runtimes.mixins.taskcanvas

  1from contextlib import contextmanager
  2
  3from celery import chain, group, chord
  4
  5
class CanvasTask:
    """
    support for canvas tasks

    Records the tasks created while a canvas context is active and submits
    them together through the canvas primitive given as ``canvasfn``
    (this module uses celery's chain, group and chord).

    See Also:
        * om.runtime.sequence
        * om.runtime.parallel
        * om.runtime.mapreduce
    """

    def __init__(self, canvasfn):
        """
        Args:
            canvasfn: the canvas primitive used by run() to combine the
                recorded signatures
        """
        # tasks, replaced one by one with their signatures by apply_async()
        self.sigs = []
        self.canvasfn = canvasfn
        # set by the context manager that creates this canvas
        self.runtime = None

    def add(self, task):
        """ record a task; it becomes the target of the next apply_async() """
        self.sigs.append(task)

    def delay(self, *args, **kwargs):
        """ shortcut to apply_async(args=args, kwargs=kwargs) """
        return self.apply_async(args=args, kwargs=kwargs)

    def apply_async(self, args=None, kwargs=None, **celery_kwargs):
        """
        turn the most recently added task into a signature

        Nothing is submitted here, the signature is only stored in place
        of the task until run() is called.

        Args:
            args: positional arguments for the task signature
            kwargs: keyword arguments for the task signature
            **celery_kwargs: passed on to task.signature()

        Returns:
            the signature created for the most recently added task
        """
        task = self.sigs[-1]
        # immutable means results are not passed on from task to task;
        # only signatures created for a chord stay mutable
        immutable = self.canvasfn is not chord
        sig = task.signature(args=args, kwargs=kwargs, **celery_kwargs,
                             immutable=immutable)
        self.sigs[-1] = sig
        return sig

    def run(self):
        """
        submit the recorded signatures through the canvas primitive

        For a chord the last signature is the body and all others form
        the header. Otherwise all signatures are passed to the canvas
        primitive and applied asynchronously.

        Returns:
            the result of the submission, extended with .collect() and
            .getall(), see _easy_collect
        """
        if self.canvasfn is chord:
            result = self.canvasfn(self.sigs[:-1])(self.sigs[-1])
        else:
            result = self.canvasfn(*self.sigs).apply_async()
        # add easy result collection
        self._easy_collect(result)
        return result

    def _easy_collect(self, result):
        """
        attach .collect() and .getall() helpers to a result

        * result.collect() returns the set made up of the result and all
          of its parents
        * result.getall() returns the .get() value of each of these,
          ordered from the result itself back through its parents. If
          the first value is a list, that list alone is returned.

        Args:
            result: the result object, linked to earlier results by
                its .parent attribute
        """
        # traverse graph of results and return a single list
        # see https://docs.celeryproject.org/en/stable/userguide/canvas.html

        def lineage(res):
            # follow .parent links in a fixed order (result first) so that
            # getall() does not depend on set iteration order
            nodes = [res]
            while nodes[-1].parent:
                nodes.append(nodes[-1].parent)
            return nodes

        def flatten(values):
            return values[0] if isinstance(values[0], list) else values

        result.collect = lambda: set(lineage(result))
        result.getall = lambda: flatten([r.get() for r in lineage(result)])
def make_canvased(canvasfn):
    """
    create the context manager method for one canvas primitive

    Args:
        canvasfn: the canvas primitive handed to CanvasTask

    Returns:
        function taking the runtime as ``self``, usable as a context manager
    """

    @contextmanager
    def canvased(self):
        """
        context manager to support sequenced, parallel and mapreduce tasks

        Usage::

            # run tasks async, in sequence
            with om.runtime.sequence() as crt:
                crt.model('mymodel').fit(...)
                crt.model('mymodel').predict(...)
                result = crt.run()

            # run tasks async, in parallel
            with om.runtime.parallel() as crt:
                crt.model('mymodel').predict(...)
                crt.model('mymodel').predict(...)
                result = crt.run()

            # run tasks async, in parallel with a final step
            with om.runtime.mapreduce() as crt:
                # map tasks
                crt.model('mymodel').predict(...)
                crt.model('mymodel').predict(...)
                # reduce results - combined is a virtualobj function
                crt.model('combined').reduce(...)
                result = crt.run()

            # combined is a virtual obj function, e.g.
            @virtualobj
            def combined(data=None, **kwargs):
                # data is the list of results from each map step
                return data

        Note that the statements inside the context are
        executed in sequence, as any normal python code. However,
        the actual tasks are only executed on calling crt.run().
        crt.run() is only available inside the context; on exit the
        runtime's .task and .run are reset to what they were on entry.

        Args:
            self: the runtime

        Returns:
            OmegaRuntime, for use within context
        """
        canvas = CanvasTask(canvasfn)
        _orig_task = self.task
        # remember any run already set, e.g. by an enclosing canvas context,
        # so that leaving this context does not discard it (None if unset)
        _orig_run = getattr(self, 'run', None)

        def canvas_task(*args, **kwargs):
            # create the task as usual, but record it on the canvas and hand
            # back the canvas so that .delay()/.apply_async() build signatures
            task = _orig_task(*args, **kwargs)
            canvas.add(task)
            return canvas

        self.task = canvas_task
        canvas.runtime = self
        canvas.runtime.run = canvas.run
        try:
            yield canvas.runtime
        finally:
            canvas.runtime.task = _orig_task
            canvas.runtime.run = _orig_run

    return canvased


canvas_chain = make_canvased(chain)
canvas_group = make_canvased(group)
canvas_chord = make_canvased(chord)