1from contextlib import contextmanager
  2
  3from celery import chain, group, chord
  4
  5
class CanvasTask:
    """
    support for canvas tasks

    Collects runtime tasks into a celery canvas instead of submitting them
    one by one. Tasks are added via ``add()``. A subsequent ``delay()`` /
    ``apply_async()`` replaces the most recently added task by its signature
    rather than submitting it. ``run()`` combines all collected signatures
    using the canvas function and submits the canvas.

    See Also:
        * om.runtime.sequence
        * om.runtime.parallel
        * om.runtime.mapreduce
    """

    def __init__(self, canvasfn):
        """
        Args:
            canvasfn: the celery canvas primitive used to combine the
                collected signatures, e.g. chain, group or chord
        """
        # collected tasks; each entry is replaced by its signature on apply_async
        self.sigs = []
        self.canvasfn = canvasfn
        # the runtime this canvas belongs to, set by the owning context manager
        self.runtime = None

    def add(self, task):
        """Append a task to the canvas (turned into a signature on apply_async)."""
        self.sigs.append(task)

    def delay(self, *args, **kwargs):
        """Shortcut for ``apply_async(args=args, kwargs=kwargs)``, like celery's ``Task.delay``."""
        return self.apply_async(args=args, kwargs=kwargs)

    def apply_async(self, args=None, kwargs=None, **celery_kwargs):
        """
        Replace the most recently added task by its signature.

        The task is not submitted here; it is only executed as part of the
        canvas on ``run()``.

        Args:
            args: positional task arguments
            kwargs: keyword task arguments
            **celery_kwargs: further options passed to ``task.signature()``

        Returns:
            the signature that now takes the task's place in the canvas
        """
        task = self.sigs[-1]
        # in a chord the final (reduce) step must receive the results of the
        # map steps, so signatures stay mutable; otherwise signatures are
        # immutable, meaning results are not passed on from task to task
        immutable = self.canvasfn is not chord
        sig = task.signature(args=args, kwargs=kwargs, **celery_kwargs, immutable=immutable)
        self.sigs[-1] = sig
        return sig

    def run(self):
        """
        Combine the collected signatures and submit the canvas.

        For a chord, all but the last signature form the header and the last
        signature is the callback. For any other canvas function, all
        signatures are passed as positional arguments.

        Returns:
            the celery result, extended by ``collect()`` and ``getall()``
        """
        if self.canvasfn is chord:
            result = self.canvasfn(self.sigs[:-1])(self.sigs[-1])
        else:
            result = self.canvasfn(*self.sigs).apply_async()
        # add easy result collection
        self._easy_collect(result)
        return result

    def _easy_collect(self, result):
        """
        Attach ``collect()`` and ``getall()`` helpers to a canvas result.

        ``result.collect()`` returns the set of ``result`` and all of its
        ancestors, following ``.parent``. ``result.getall()`` returns the
        values of these results, ordered from the first step of the canvas
        to the last. If there is only a single result and its value is a
        list (e.g. a group's result), that list is returned directly.

        See https://docs.celeryproject.org/en/stable/userguide/canvas.html
        """
        def ordered():
            # walk from the final result up its parents iteratively (no
            # recursion limit on long chains), then reverse so results are
            # listed in execution order (first -> last); a set would lose order
            nodes = [result]
            while nodes[-1].parent:
                nodes.append(nodes[-1].parent)
            nodes.reverse()
            return nodes

        def getall():
            values = [node.get() for node in ordered()]
            # a single list-valued result (e.g. a group) is returned unwrapped;
            # with several results every step's value is kept
            if len(values) == 1 and isinstance(values[0], list):
                return values[0]
            return values

        result.collect = lambda: set(ordered())
        result.getall = getall
 53
 54
 55
 56def make_canvased(canvasfn):
 57    @contextmanager
 58    def canvased(self):
 59        """
 60        context manager to support sequenced, parallel and mapreduce tasks
 61
 62        Usage::
 63
 64            # run tasks async, in sequence
 65            with om.runtime.sequence() as crt:
 66                crt.model('mymodel').fit(...)
 67                crt.model('mymodel').predict(...)
 68                result = crt.run()
 69
 70            # run tasks async, in parallel
 71            with om.runtime.parallel() as crt:
 72                crt.model('mymodel').predict(...)
 73                crt.model('mymodel').predict(...)
 74                result = crt.run()
 75
 76            # run tasks async, in parallel with a final step
 77            with om.runtime.mapreduce() as crt:
 78                # map tasks
 79                crt.model('mymodel').predict(...)
 80                crt.model('mymodel').predict(...)
 81                # reduce results - combined is a virtualobj function
 82                crt.model('combined').reduce(...)
 83                result = crt.run()
 84
 85            # combined is a virtual obj function, e.g.
 86            @virtualobj
 87            def combined(data=None, **kwargs):
 88                # data is the list of results from each map step
 89                return data
 90
 91            Note that the statements inside the context are
 92            executed in sequence, as any normal python code. However,
 93            the actual tasks are only executed on calling crt.run()
 94
 95        Args:
 96            self: the runtime
 97
 98        Returns:
 99            OmegaRuntime, for use within context
100        """
101        canvas = CanvasTask(canvasfn)
102        _orig_task = self.task
103
104        def canvas_task(*args, **kwargs):
105            task = _orig_task(*args, **kwargs)
106            canvas.add(task)
107            return canvas
108
109        self.task = canvas_task
110        canvas.runtime = self
111        canvas.runtime.run = canvas.run
112        try:
113            yield canvas.runtime
114        finally:
115            canvas.runtime.task = _orig_task
116            canvas.runtime.run = None
117
118    return canvased
119
120
# Context-manager factories, one per celery canvas primitive
# (presumably exposed as om.runtime.sequence / parallel / mapreduce).
canvas_chain, canvas_group, canvas_chord = (
    make_canvased(canvasfn) for canvasfn in (chain, group, chord)
)