1from contextlib import contextmanager
2
3from celery import chain, group, chord
4
5
[docs]
class CanvasTask:
    """
    Collect runtime tasks and submit them together as one celery canvas.

    Every task requested from the runtime inside a canvas context is appended
    to ``sigs``. Calling ``delay()``/``apply_async()`` on this object does not
    submit anything: it converts the most recently added task into a celery
    signature. ``run()`` then combines all signatures using ``canvasfn`` and
    submits the whole canvas at once.

    See Also:
        * om.runtime.sequence
        * om.runtime.parallel
        * om.runtime.mapreduce
    """

    def __init__(self, canvasfn):
        # tasks added so far; each entry is replaced by its signature once
        # delay()/apply_async() has been called for it
        self.sigs = []
        # celery canvas primitive used by run(), e.g. chain, group or chord
        self.canvasfn = canvasfn
        # the runtime this canvas belongs to (assigned by the caller)
        self.runtime = None

    def add(self, task):
        """Queue a task on the canvas; it is only submitted by run()."""
        self.sigs.append(task)

    def delay(self, *args, **kwargs):
        """Shortcut for ``apply_async(args=args, kwargs=kwargs)``."""
        return self.apply_async(args=args, kwargs=kwargs)

    def apply_async(self, args=None, kwargs=None, **celery_kwargs):
        """
        Turn the most recently added task into a signature (without sending it).

        Args:
            args: positional arguments for the task
            kwargs: keyword arguments for the task
            **celery_kwargs: further options passed to ``task.signature()``.
                ``immutable`` may be given explicitly; by default it is False
                for a chord and True for any other canvas.

        Returns:
            the signature, which also replaces the last entry of ``self.sigs``
        """
        # Immutable signatures do not receive the previous task's result, so
        # chain/group steps run only with their own arguments. Chord
        # signatures stay mutable so the final (body) task receives the
        # results of the header tasks.
        celery_kwargs.setdefault('immutable', self.canvasfn is not chord)
        sig = self.sigs[-1].signature(args=args, kwargs=kwargs, **celery_kwargs)
        self.sigs[-1] = sig
        return sig

    def run(self):
        """
        Submit all collected signatures as a single canvas.

        For a chord, all signatures but the last form the header and the last
        one is the body, called with the header results. For any other
        canvas, all signatures are passed to ``canvasfn`` and the resulting
        canvas is applied asynchronously.

        Returns:
            the celery result of the canvas, extended with ``collect()`` and
            ``getall()`` helpers (see ``_easy_collect``)
        """
        if self.canvasfn is chord:
            result = self.canvasfn(self.sigs[:-1])(self.sigs[-1])
        else:
            result = self.canvasfn(*self.sigs).apply_async()
        # add easy result collection
        self._easy_collect(result)
        return result

    def _easy_collect(self, result):
        """
        Attach ``collect()`` and ``getall()`` helpers to a canvas result.

        Both helpers walk the result's ``.parent`` links at call time.
        See https://docs.celeryproject.org/en/stable/userguide/canvas.html

        * ``result.collect()`` returns the set of all results in the lineage.
        * ``result.getall()`` calls ``.get()`` on each result, ordered from
          the first (root) to the last, and returns the values as a list. If
          only a single result exists and its value is a list (e.g. a group
          result), that list is returned directly.
        """
        def lineage():
            # follow parent links from the final result back to the root,
            # then reverse so the order matches execution order
            nodes = [result]
            parent = result.parent
            while parent:
                nodes.append(parent)
                parent = parent.parent
            nodes.reverse()
            return nodes

        def getall():
            values = [r.get() for r in lineage()]
            # only unwrap when there is nothing else to lose; unwrapping based
            # on the first of several values would drop the remaining results
            if len(values) == 1 and isinstance(values[0], list):
                return values[0]
            return values

        result.collect = lambda: set(lineage())
        result.getall = getall
53
54
55
def make_canvased(canvasfn):
    """
    Build a context manager that collects runtime tasks into a celery canvas.

    Args:
        canvasfn: the celery canvas primitive used to combine the collected
            task signatures, e.g. ``chain``, ``group`` or ``chord``

    Returns:
        a ``contextlib.contextmanager``-decorated function ``canvased(self)``
        that takes the runtime as its only argument
    """
    @contextmanager
    def canvased(self):
        """
        context manager to support sequenced, parallel and mapreduce tasks

        Usage::

            # run tasks async, in sequence
            with om.runtime.sequence() as crt:
                crt.model('mymodel').fit(...)
                crt.model('mymodel').predict(...)
                result = crt.run()

            # run tasks async, in parallel
            with om.runtime.parallel() as crt:
                crt.model('mymodel').predict(...)
                crt.model('mymodel').predict(...)
                result = crt.run()

            # run tasks async, in parallel with a final step
            with om.runtime.mapreduce() as crt:
                # map tasks
                crt.model('mymodel').predict(...)
                crt.model('mymodel').predict(...)
                # reduce results - combined is a virtualobj function
                crt.model('combined').reduce(...)
                result = crt.run()

            # combined is a virtual obj function, e.g.
            @virtualobj
            def combined(data=None, **kwargs):
                # data is the list of results from each map step
                return data

        Note that the statements inside the context are
        executed in sequence, as any normal python code. However,
        the actual tasks are only executed on calling crt.run()

        On exit, the runtime's ``task`` and ``run`` attributes are restored
        to exactly what they were on entry, so nested contexts and later
        use of the runtime are not affected.

        Args:
            self: the runtime

        Returns:
            OmegaRuntime, for use within context
        """
        canvas = CanvasTask(canvasfn)
        _orig_task = self.task
        # Snapshot the per-instance attributes we are about to override.
        # _missing marks "no instance attribute" so the override can be
        # removed on exit instead of leaving a stale value behind.
        _missing = object()
        instance_attrs = getattr(self, '__dict__', {})
        saved = {name: instance_attrs.get(name, _missing)
                 for name in ('task', 'run')}

        def canvas_task(*args, **kwargs):
            # resolve the task as usual, but queue it on the canvas and return
            # the canvas, so the caller's .delay()/.apply_async() only builds
            # a signature instead of submitting the task
            task = _orig_task(*args, **kwargs)
            canvas.add(task)
            return canvas

        self.task = canvas_task
        canvas.runtime = self
        self.run = canvas.run
        try:
            yield self
        finally:
            for name, value in saved.items():
                if value is _missing:
                    # there was no instance attribute before; drop the
                    # override so normal (class-level) lookup applies again
                    try:
                        delattr(self, name)
                    except AttributeError:
                        pass
                else:
                    setattr(self, name, value)

    return canvased
119
120
# One runtime context manager per celery canvas primitive:
# - chain: the collected tasks run one after another
# - group: the collected tasks run in parallel
# - chord: all but the last task run in parallel, and their results are
#   passed to the last task
canvas_chain, canvas_group, canvas_chord = map(make_canvased, (chain, group, chord))