Flask Celery task locking


In your question, you point out this warning from the Celery example you used:

In order for this to work correctly you need to be using a cache backend where the .add operation is atomic. memcached is known to work well for this purpose.

And you mention that you don't really understand what this means. Indeed, the code you show demonstrates that you've not heeded that warning, because your code uses an inappropriate backend.

Consider this code:

with memcache_lock(lock_id, self.app.oid) as acquired:
    if acquired:
        # do some work

What you want here is for acquired to be true for only one thread at a time. If two threads enter the with block at the same time, only one should "win" and have acquired be true. The winning thread can then proceed with its work, while the other thread has to skip the work and try again later to acquire the lock. To ensure that only one thread can have acquired true, .add must be atomic.
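For reference, the context manager from the Celery documentation looks roughly like this (adapted from memory to use your Flask cache object, so treat it as a sketch rather than a verbatim copy):

from contextlib import contextmanager
from time import monotonic

LOCK_EXPIRE = 60 * 10  # lock expires after 10 minutes no matter what

@contextmanager
def memcache_lock(lock_id, oid):
    timeout_at = monotonic() + LOCK_EXPIRE - 3
    # .add only succeeds if lock_id is not already set; this is the
    # step that must be atomic for the lock to work
    status = cache.add(lock_id, oid, LOCK_EXPIRE)
    try:
        yield status
    finally:
        if monotonic() < timeout_at and status:
            # only release the lock if we acquired it and it has not
            # expired out from under us (someone else may own it now)
            cache.delete(lock_id)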

Here's some pseudocode of what .add(key, value) does:

1. if <key> is already in the cache:
2.   return False
3. else:
4.   set the cache so that <key> has the value <value>
5.   return True

If the execution of .add is not atomic, here is what could happen when two threads A and B both execute .add("foo", "bar"). Assume the cache is empty at the start.

  1. Thread A executes line 1 (if "foo" is already in the cache), finds that "foo" is not in the cache, and jumps to line 3, but before it can go further the thread scheduler switches control to thread B.
  2. Thread B also executes line 1, also finds that "foo" is not in the cache, so it jumps to line 3 and then executes lines 4 and 5, setting the key "foo" to the value "bar" and returning True.
  3. Eventually, the scheduler gives control back to thread A, which continues with lines 3, 4 and 5, also sets the key "foo" to the value "bar", and also returns True.

What you have here is two .add calls that both return True. If these calls are made within memcache_lock, then two threads can have acquired be true, so two threads can do work at the same time, and memcache_lock is not doing what it should be doing: allowing only one thread to work at a time.
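To see the contrast, here is a minimal sketch of an atomic check-and-set, assuming a local Redis server and the redis-py client (neither is in your code as posted). SET with nx=True performs the "if absent, set" check as a single operation inside the Redis server, so no interleaving of threads can make both calls succeed:

import redis

r = redis.StrictRedis()
print(r.set("foo", "bar", nx=True))  # True: "foo" was absent, this caller wins
print(r.set("foo", "baz", nx=True))  # None: "foo" already exists, this caller loses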

You are not using a cache that ensures that .add is atomic. You initialize it like this:

cache = Cache(app, config={'CACHE_TYPE': 'simple'})

The simple backend is scoped to a single process, is not thread-safe, and has an .add operation which is not atomic. (This does not involve Mongo at all, by the way. If you wanted your cache to be backed by Mongo, you'd have to specify a backend specifically made to send data to a Mongo database.)

So you have to switch to another backend, one that guarantees that .add is atomic. You could follow the lead of the Celery example and use the memcached backend, which does have an atomic .add operation. I don't use Flask, but I've done essentially what you are doing with Django and Celery, and used the Redis backend successfully to provide the kind of locking you need here.
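As a sketch of what the switch looks like (assuming Flask-Caching's standard configuration keys and a server running on localhost; adjust the addresses to your deployment):

# memcached, following the Celery example
cache = Cache(app, config={
    'CACHE_TYPE': 'memcached',
    'CACHE_MEMCACHED_SERVERS': ['127.0.0.1:11211'],
})

# or Redis
cache = Cache(app, config={
    'CACHE_TYPE': 'redis',
    'CACHE_REDIS_URL': 'redis://localhost:6379/0',
})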


With this setup, you should still expect to see workers receiving the task, since the lock is checked inside the task itself. The only difference is that the work won't be performed if the lock is already held by another worker.
In the example given in the docs, this is the desired behavior: if a lock already exists, the task simply does nothing and finishes as successful. What you want is slightly different: you want the work to be queued up instead of ignored.

In order to get the desired effect, you would need to make sure that the task will be picked up by a worker and performed some time in the future. One way to accomplish this would be with retrying.

@task(bind=True, name='my-task')
def my_task(self):
    lock_id = self.name
    with memcache_lock(lock_id, self.app.oid) as acquired:
        if acquired:
            # do work if we got the lock
            print('acquired is {}'.format(acquired))
            return 'result'
    # otherwise, the lock was already in use
    raise self.retry(countdown=60)  # redeliver message to the queue, so the work can be done later
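One caveat to add (my assumption about your workload, not part of the original example): Celery's Task.max_retries defaults to 3, so a task that keeps losing the lock will eventually stop retrying and be marked as failed. If the work must run eventually, you can lift the cap:

@task(bind=True, name='my-task', max_retries=None)  # None disables the retry limit
def my_task(self):
    ...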


I also found this to be a surprisingly hard problem. Inspired mainly by Sebastian's work on implementing a distributed locking algorithm in Redis, I wrote up a decorator function.

A key point to bear in mind about this approach is that we lock tasks at the level of the task's argument space, e.g. we allow multiple game update/process order tasks to run concurrently, but only one per game. That's what argument_signature achieves in the code below. You can see documentation on how we use this in our stack at this gist:

import base64
from contextlib import contextmanager
import json
import pickle as pkl
import uuid

from backend.config import Config
from redis import StrictRedis
from redis_cache import RedisCache
from redlock import Redlock

rds = StrictRedis(Config.REDIS_HOST, decode_responses=True, charset="utf-8")
rds_cache = StrictRedis(Config.REDIS_HOST, decode_responses=False, charset="utf-8")
redis_cache = RedisCache(redis_client=rds_cache, prefix="rc", serializer=pkl.dumps, deserializer=pkl.loads)
dlm = Redlock([{"host": Config.REDIS_HOST}])

TASK_LOCK_MSG = "Task execution skipped -- another task already has the lock"

DEFAULT_ASSET_EXPIRATION = 8 * 24 * 60 * 60  # by default keep cached values around for 8 days
DEFAULT_CACHE_EXPIRATION = 1 * 24 * 60 * 60  # we can keep cached values around for a shorter period of time

REMOVE_ONLY_IF_OWNER_SCRIPT = """
if redis.call("get",KEYS[1]) == ARGV[1] then
    return redis.call("del",KEYS[1])
else
    return 0
end
"""


@contextmanager
def redis_lock(lock_name, expires=60):
    # https://breadcrumbscollector.tech/what-is-celery-beat-and-how-to-use-it-part-2-patterns-and-caveats/
    random_value = str(uuid.uuid4())
    lock_acquired = bool(
        rds.set(lock_name, random_value, ex=expires, nx=True)
    )
    yield lock_acquired
    if lock_acquired:
        rds.eval(REMOVE_ONLY_IF_OWNER_SCRIPT, 1, lock_name, random_value)


def argument_signature(*args, **kwargs):
    arg_list = [str(x) for x in args]
    kwarg_list = [f"{str(k)}:{str(v)}" for k, v in kwargs.items()]
    return base64.b64encode(f"{'_'.join(arg_list)}-{'_'.join(kwarg_list)}".encode()).decode()


def task_lock(func=None, main_key="", timeout=None):
    def _dec(run_func):
        def _caller(*args, **kwargs):
            with redis_lock(f"{main_key}_{argument_signature(*args, **kwargs)}", timeout) as acquired:
                if not acquired:
                    return TASK_LOCK_MSG
                return run_func(*args, **kwargs)
        return _caller
    return _dec(func) if func is not None else _dec

Implementation in our task definitions file:

@celery.task(name="async_test_task_lock")
@task_lock(main_key="async_test_task_lock", timeout=UPDATE_GAME_DATA_TIMEOUT)
def async_test_task_lock(game_id):
    print(f"processing game_id {game_id}")
    time.sleep(TASK_LOCK_TEST_SLEEP)

How we test against a local Celery cluster:

import time
from unittest import TestCase

from backend.tasks.definitions import async_test_task_lock, TASK_LOCK_TEST_SLEEP
from backend.tasks.redis_handlers import rds, TASK_LOCK_MSG


class TestTaskLocking(TestCase):
    def test_task_locking(self):
        rds.flushall()
        # tasks with different arguments lock different keys, so both run
        res1 = async_test_task_lock.delay(3)
        res2 = async_test_task_lock.delay(5)
        self.assertFalse(res1.ready())
        self.assertFalse(res2.ready())
        # while game_id=5 is still running, further tasks for it are skipped
        res3 = async_test_task_lock.delay(5)
        res4 = async_test_task_lock.delay(5)
        self.assertEqual(res3.get(), TASK_LOCK_MSG)
        self.assertEqual(res4.get(), TASK_LOCK_MSG)
        time.sleep(TASK_LOCK_TEST_SLEEP)
        # once the first task finishes, the lock is released and a new run starts
        res5 = async_test_task_lock.delay(3)
        self.assertFalse(res5.ready())

(As a goodie, there's also a quick example of how to set up a redis_cache.)
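For illustration, usage of that redis_cache instance looks something like the following (the function and ttl choice here are hypothetical; python-redis-cache exposes caching as a decorator):

@redis_cache.cache(ttl=DEFAULT_CACHE_EXPIRATION)
def expensive_lookup(game_id):
    # body only runs on a cache miss; the pickled result is stored in Redis
    ...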