Interact with celery ongoing task Interact with celery ongoing task python python

Interact with celery ongoing task


It look like the Control Bus pattern.

For a better scalability and in order to reduce the RPC call, I recommend to reverse the logic. The PAUSE, RESUME, CANCEL command are push to the Celery tasks through a control bus when the state change occurs. The Celery app will store the current state of the Celery app in a store (could be in memory, on the filesystem..). If task states must be kept even after a stop/start of the app, It will involve more work in order to keep both app synchronized (eg. synchronization at startup).


I'd like to demonstrate a general approach to implementing pause-able (and resume-able) ongoing celery tasks through the workflow pattern. Note: Original answer written here. Re-writing here due to this post being very relevant.

Concept

With celery workflows - you can design your entire operation to be divided into a chain of tasks. It doesn't necessarily have to be purely a chain, but it should follow the general concept of one task after another task (or task group) finishes.

Once you have a workflow like that, you can finally define points to pause at throughout your workflow. At each of these points, you can check whether or not the frontend user has requested the operation to pause and continue accordingly. The concept is this:-

A complex and time consuming operation O is split into 5 celery tasks - T1, T2, T3, T4, and T5 - each of these tasks (except the first one) depend on the return value of the previous task.

Let's assume we define points to pause after every single task, so the workflow looks like-

  • T1 executes
  • T1 completes, check if user has requested pause
    • If user has not requested pause - continue
    • If user has requested pause, serialize the remaining workflow chain and store it somewhere to continue later

... and so on. Since there's a pause point after each task, that check is performed after every one of them (except the last one of course).

But this is only theory, I struggled to find an implementation of this anywhere online so here's what I came up with-

Implementation

from typing import Any, Optionalfrom celery import shared_taskfrom celery.canvas import Signature, chain, signature@shared_task(bind=True)def pause_or_continue(    self, retval: Optional[Any] = None, clause: dict = None, callback: dict = None):    # Task to use for deciding whether to pause the operation chain    if signature(clause)(retval):        # Pause requested, call given callback with retval and remaining chain        # chain should be reversed as the order of execution follows from end to start        signature(callback)(retval, self.request.chain[::-1])        self.request.chain = None    else:        # Continue to the next task in chain        return retvaldef tappable(ch: chain, clause: Signature, callback: Signature, nth: Optional[int] = 1):    '''    Make a operation workflow chain pause-able/resume-able by inserting    the pause_or_continue task for every nth task in given chain    ch: chain        The workflow chain    clause: Signature        Signature of a task that takes one argument - return value of        last executed task in workflow (if any - othewise `None` is passsed)        - and returns a boolean, indicating whether or not the operation should continue        Should return True if operation should continue normally, or be paused    callback: Signature        Signature of a task that takes 2 arguments - return value of        last executed task in workflow (if any - othewise `None` is passsed) and        remaining chain of the operation workflow as a json dict object        No return value is expected        This task will be called when `clause` returns `True` (i.e task is pausing)        The return value and the remaining chain can be handled accordingly by this task    nth: Int        Check `clause` after every nth task in the chain        Default value is 1, i.e check `clause` after every task        Hence, by default, user given `clause` is called and checked        after every task    NOTE: The passed in chain is mutated in place    Returns the mutated chain    '''    newch = []    for n, sig in enumerate(ch.tasks):        if n != 0 and n % nth == nth - 1:            newch.append(pause_or_continue.s(clause=clause, callback=callback))        newch.append(sig)    ch.tasks = tuple(newch)    return ch

Explanation - pause_or_continue

Here pause_or_continue is the aforementioned pause point. It's a task that will be called at specific intervals (intervals as in task intervals, not as in time intervals). This task then calls a user provided function (actually a task) - clause - to check whether or not the task should continue.

If the clause function (actually a task) returns True, the user provided callback function is called, the latest return value (if any - None otherwise) is passed onto this callback, as well as the remaining chain of tasks. The callback does what it needs to do and pause_or_continue sets self.request.chain to None, which tells celery "The task chain is now empty - everything is finished".

If the clause function (actually a task) returns False, the return value from the previous task (if any - None otherwise) is returned back for the next task to receive - and the chain goes on. Hence the workflow continues.

Why are clause and callback task signatures and not regular functions?

Both clause and callback are being called directly - without delay or apply_async. It is executed in the current process, in the current context. So it behaves exactly like a normal function, then why use signatures?

The answer is serialization. You can't conveniently pass a regular function object to a celery task. But you can pass a task signature. That's exactly what I'm doing here. Both clause and callback should be a regular signature object of a celery task.

What is self.request.chain?

self.request.chain stores a list of dicts (representing jsons as the celery task serializer is json by default) - each of them representing a task signature. Each task from this list is executed in reverse order. Which is why, the list is reversed before passing to the user provided callback function (actually a task) - the user probably expects the order of tasks to be left to right.

Quick note: Irrelevant to this discussion, but if you're using the link parameter from apply_async to construct a chain instead of the chain primitive itself. self.request.callback is the property to be modified (i.e set to None to remove callback and stop chain) instead of self.request.chain

Explanation - tappable

tappable is just a basic function that takes a chain (which is the only workflow primitive covered here, for brevity) and inserts pause_or_continue after every nth task. You can insert them wherever you want really, it is upto you to define pause points in your operation. This is just an example!

For each chain object, the actual signatures of tasks (in order, going from left to right) is stored in the .tasks property. It's a tuple of task signatures. So all we have to do, is take this tuple, convert into a list, insert the pause points and convert back to a tuple to assign to the chain. Then return the modified chain object.

The clause and callback is also attached to the pause_or_continue signature. Normal celery stuff.

That covers the primary concept, but to showcase a real project using this pattern (and also to showcase the resuming part of a paused task), here's a small demo of all the necessary resources

Usage

This example usage assumes the concept of a basic web server with a database. Whenever an operation (i.e workflow chain) is started, it's assigned an id and stored into the database. The schema of that table looks like-

-- Create operations table-- Keeps track of operations and the users that started themCREATE TABLE operations (  id INTEGER PRIMARY KEY AUTOINCREMENT,  requester_id INTEGER NOT NULL,  completion TEXT NOT NULL,  workflow_store TEXT,  result TEXT,  FOREIGN KEY (requester_id) REFERENCES user (id));

The only field that needs to be known about right now, is completion. It just stores the status of the operation-

  • When the operation starts and a db entry is created, this is set to IN PROGRESS
  • When a user requests pause, the route controller (i.e view) modifies this to REQUESTING PAUSE
  • When the operation actually gets paused and callback (from tappable, inside pause_or_continue) is called, the callback should modify this to PAUSED
  • When the task is completed, this should be modified to COMPLETED

An example of clause

@celery.task()def should_pause(_, operation_id: int):    # This is the `clause` to be used for `tappable`    # i.e it lets celery know whether to pause or continue    db = get_db()    # Check the database to see if user has requested pause on the operation    operation = db.execute(        "SELECT * FROM operations WHERE id = ?", (operation_id,)    ).fetchone()    return operation["completion"] == "REQUESTING PAUSE"

This is the task to call at the pause points, to determine whether or not to pause. It's a function that takes 2 parameters.....well sort of. The first one is mandatory, tappable requires the clause to have one (and exactly one) argument - so it can pass the previous task's return value to it (even if that return value is None). In this example, the return value isn't required to be used - so we can just ignore it.

The second parameter is an operation id. See, all this clause does - is check a database for the operation (the workflow) entry and see if it has the status REQUESTING PAUSE. To do that, it needs to know the operation id. But clause should be a task with one argument, what gives?

Well, good thing signatures can be partial. When the task is first started and a tappable chain is created. The operation id is known and hence we can do should_pause.s(operation_id) to get the signature of a task that takes one parameter, that being the return value of the previous task. That qualifies as a clause!

An example of callback

import osimport jsonfrom typing import Any, List@celery.task()def save_state(retval: Any, chains: dict, operation_id: int):    # This is the `callback` to be used for `tappable`    # i.e this is called when an operation is pausing    db = get_db()    # Prepare directories to store the workflow    operation_dir = os.path.join(app.config["OPERATIONS"], f"{operation_id}")    workflow_file = os.path.join(operation_dir, "workflow.json")    if not os.path.isdir(operation_dir):        os.makedirs(operation_dir, exist_ok=True)        # Store the remaining workflow chain, serialized into json    with open(workflow_file, "w") as f:        json.dump(chains, f)    # Store the result from the last task and the workflow json path    db.execute(        """        UPDATE operations        SET completion = ?,            workflow_store = ?,            result = ?        WHERE id = ?        """,        ("PAUSED", workflow_file, f"{retval}", operation_id),    )    db.commit()

And here's the task to be called when the task is being paused. Remember, this should take the last executed task's return value and the remaining list of signatures (in order, from left to right). There's an extra param - operation_id - once again. The explanation for this is the same as the one for clause.

This function stores the remaining chain in a json file (since it's a list of dicts). Remember, you can use a different serializer - I'm using json since it's the default task serializer used by celery.

After storing the remaining chain, it updates the completion status to PAUSED and also logs the path to the json file into the db.

Now, let's see these in action-

An example of starting the workflow

def start_operation(user_id, *operation_args, **operation_kwargs):    db = get_db()    operation_id: int = db.execute(        "INSERT INTO operations (requester_id, completion) VALUES (?, ?)",        (user_id, "IN PROGRESS"),    ).lastrowid    # Convert a regular workflow chain to a tappable one    tappable_workflow = tappable(        (T1.s() | T2.s() | T3.s() | T4.s() | T5.s(operation_id)),        should_pause.s(operation_id),        save_state.s(operation_id),    )    # Start the chain (i.e send task to celery to run asynchronously)    tappable_workflow(*operation_args, **operation_kwargs)    db.commit()    return operation_id

A function that takes in a user id and starts an operation workflow. This is more or less an impractical dummy function modeled around a view/route controller. But I think it gets the general idea through.

Assume T[1-4] are all unit tasks of the operation, each taking the previous task's return as an argument. Just an example of a regular celery chain, feel free to go wild with your chains.

T5 is a task that saves the final result (result from T4) to the database. So along with the return value from T4 it needs the operation_id. Which is passed into the signature.

An example of pausing the workflow

def pause(operation_id):    db = get_db()    operation = db.execute(        "SELECT * FROM operations WHERE id = ?", (operation_id,)    ).fetchone()    if operation and operation["completion"] == "IN PROGRESS":        # Pause only if the operation is in progress        db.execute(            """            UPDATE operations            SET completion = ?            WHERE id = ?            """,            ("REQUESTING PAUSE", operation_id),        )        db.commit()        return 'success'    return 'invalid id'

This employs the previously mentioned concept of modifying the db entry to change completion to REQUESTING PAUSE. Once this is committed, the next time pause_or_continue calls should_pause, it'll know that the user has requested the operation to pause and it'll do so accordingly.

An example of resuming the workflow

def resume(operation_id):    db = get_db()    operation = db.execute(        "SELECT * FROM operations WHERE id = ?", (operation_id,)    ).fetchone()    if operation and operation["completion"] == "PAUSED":        # Resume only if the operation is paused        with open(operation["workflow_store"]) as f:            # Load the remaining workflow from the json            workflow_json = json.load(f)        # Load the chain from the json (i.e deserialize)        workflow_chain = chain(signature(x) for x in serialized_ch)        # Start the chain and feed in the last executed task result        workflow_chain(operation["result"])        db.execute(            """            UPDATE operations            SET completion = ?            WHERE id = ?            """,            ("IN PROGRESS", operation_id),        )        db.commit()        return 'success'    return 'invalid id'

Recall that, when the operation is paused - the remaining workflow is stored in a json. Since we are currently restricting the workflow to a chain object. We know this json is a list of signatures that should be turned into a chain. So, we deserialize it accordingly and send it to the celery worker.

Note that, this remaining workflow still has the pause_or_continue tasks as they were originally - so this workflow itself, is once again pause-able/resume-able. When it pauses, the workflow.json will simply be updated.