How to limit concurrency with Python asyncio?
If I'm not mistaken you're searching for asyncio.Semaphore. Example of usage:
import asyncio
from random import randint

async def download(code):
    wait_time = randint(1, 3)
    print('downloading {} will take {} second(s)'.format(code, wait_time))
    await asyncio.sleep(wait_time)  # I/O, context will switch to main function
    print('downloaded {}'.format(code))

sem = asyncio.Semaphore(3)

async def safe_download(i):
    async with sem:  # semaphore limits num of simultaneous downloads
        return await download(i)

async def main():
    tasks = [
        asyncio.ensure_future(safe_download(i))  # creating task starts coroutine
        for i in range(9)
    ]
    await asyncio.gather(*tasks)  # await moment all downloads done

if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()
Output:
downloading 0 will take 3 second(s)
downloading 1 will take 3 second(s)
downloading 2 will take 1 second(s)
downloaded 2
downloading 3 will take 3 second(s)
downloaded 1
downloaded 0
downloading 4 will take 2 second(s)
downloading 5 will take 1 second(s)
downloaded 5
downloaded 3
downloading 6 will take 3 second(s)
downloading 7 will take 1 second(s)
downloaded 4
downloading 8 will take 2 second(s)
downloaded 7
downloaded 8
downloaded 6
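On newer Python versions (3.7+), the same semaphore pattern can be driven with asyncio.run. A minimal sketch (function names and sleep times here are just illustrative; the semaphore is created inside the running coroutine, which avoids "attached to a different loop" errors on Python versions before 3.10):

import asyncio
import random

async def download(sem, code):
    async with sem:  # at most 3 downloads run at once
        wait_time = random.uniform(0.01, 0.05)
        await asyncio.sleep(wait_time)  # stand-in for real I/O
        return code

async def main():
    sem = asyncio.Semaphore(3)  # create inside the running event loop
    return await asyncio.gather(*(download(sem, i) for i in range(9)))

results = asyncio.run(main())
print(results)

Note that asyncio.gather returns results in the order the awaitables were passed in, regardless of the order in which they finished.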
An example of async downloading with aiohttp
can be found here. Note that aiohttp
has a Semaphore equivalent built in, which you can see an example of here. It has a default limit of 100 connections.
I used Mikhail's answer and ended up with this little gem
async def gather_with_concurrency(n, *tasks):
    semaphore = asyncio.Semaphore(n)

    async def sem_task(task):
        async with semaphore:
            return await task

    return await asyncio.gather(*(sem_task(task) for task in tasks))
Which you would run instead of a normal gather:
await gather_with_concurrency(100, *my_coroutines)
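As a self-contained demonstration (the helper is repeated here so the snippet runs on its own, and job is just a stand-in for real work), the elapsed time shows that only n coroutines run at once:

import asyncio
import time

async def gather_with_concurrency(n, *tasks):
    semaphore = asyncio.Semaphore(n)

    async def sem_task(task):
        async with semaphore:
            return await task

    return await asyncio.gather(*(sem_task(task) for task in tasks))

async def job(i):
    await asyncio.sleep(0.05)  # stand-in for real I/O
    return i * i

async def main():
    return await gather_with_concurrency(3, *(job(i) for i in range(9)))

start = time.monotonic()
results = asyncio.run(main())
elapsed = time.monotonic() - start
print(results, elapsed)

With 9 jobs of 0.05 s each and a concurrency of 3, the total runtime is roughly 0.15 s rather than 0.05 s, confirming that the semaphore throttles execution while gather still returns results in input order.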
Before reading the rest of this answer, please note that the idiomatic way of limiting the number of parallel tasks with asyncio is asyncio.Semaphore
, as shown in Mikhail's answer and elegantly abstracted in Andrei's answer. This answer contains working, but somewhat more complicated, ways of achieving the same. I am leaving the answer because in some cases this approach can have advantages over a semaphore, specifically when the work to be done is very large or unbounded and you cannot create all the coroutines in advance. In that case the second (queue-based) solution in this answer is what you want. But in most regular situations, such as parallel downloads through aiohttp, you should use a semaphore instead.
You basically need a fixed-size pool of download tasks. asyncio
doesn't come with a pre-made task pool, but it is easy to create one: simply keep a set of tasks and don't allow it to grow past the limit. Although the question states your reluctance to go down that route, the code ends up much more elegant:
import asyncio
import random

async def download(code):
    wait_time = random.randint(1, 3)
    print('downloading {} will take {} second(s)'.format(code, wait_time))
    await asyncio.sleep(wait_time)  # I/O, context will switch to main function
    print('downloaded {}'.format(code))

async def main(loop):
    no_concurrent = 3
    dltasks = set()
    i = 0
    while i < 9:
        if len(dltasks) >= no_concurrent:
            # Wait for some download to finish before adding a new one
            _done, dltasks = await asyncio.wait(
                dltasks, return_when=asyncio.FIRST_COMPLETED)
        dltasks.add(loop.create_task(download(i)))
        i += 1
    # Wait for the remaining downloads to finish
    await asyncio.wait(dltasks)
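The same fixed-size-pool pattern, sketched as a self-contained program for Python 3.7+ using asyncio.run and asyncio.create_task (the download body is a stand-in, and collecting results into a list is added here just to make the behavior observable):

import asyncio

async def download(code):
    await asyncio.sleep(0.05)  # stand-in for real I/O
    return code

async def main():
    no_concurrent = 3
    dltasks = set()
    results = []
    for i in range(9):
        if len(dltasks) >= no_concurrent:
            # Wait for at least one download to finish before adding a new one
            done, dltasks = await asyncio.wait(
                dltasks, return_when=asyncio.FIRST_COMPLETED)
            results.extend(t.result() for t in done)
    	dltasks.add(asyncio.create_task(download(i)))
    # Wait for the remaining downloads to finish
    done, _ = await asyncio.wait(dltasks)
    results.extend(t.result() for t in done)
    return results

results = asyncio.run(main())
print(sorted(results))

Because tasks finish in an arbitrary order, the results are collected unordered; sorting recovers the input set.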
An alternative is to create a fixed number of coroutines doing the downloading, much like a fixed-size thread pool, and feed them work using an asyncio.Queue
. This removes the need to manually limit the number of downloads, which will be automatically limited by the number of coroutines invoking download()
:
# download() defined as above

async def download_worker(q):
    while True:
        code = await q.get()
        await download(code)
        q.task_done()

async def main(loop):
    q = asyncio.Queue()
    workers = [loop.create_task(download_worker(q)) for _ in range(3)]
    i = 0
    while i < 9:
        await q.put(i)
        i += 1
    await q.join()  # wait for all tasks to be processed
    for worker in workers:
        worker.cancel()
    await asyncio.gather(*workers, return_exceptions=True)
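Here is the queue pattern as a self-contained sketch for Python 3.7+ (names are illustrative; a shared results list is added so the output can be checked). q.join() blocks until task_done() has been called once per queued item, after which the idle workers are cancelled:

import asyncio

async def download(code, results):
    await asyncio.sleep(0.05)  # stand-in for real I/O
    results.append(code)

async def download_worker(q, results):
    while True:
        code = await q.get()
        await download(code, results)
        q.task_done()

async def main():
    q = asyncio.Queue()
    results = []
    workers = [asyncio.create_task(download_worker(q, results)) for _ in range(3)]
    for i in range(9):
        await q.put(i)
    await q.join()  # wait until every queued item has been processed
    for worker in workers:
        worker.cancel()
    # return_exceptions=True swallows the CancelledError from each worker
    await asyncio.gather(*workers, return_exceptions=True)
    return results

results = asyncio.run(main())
print(sorted(results))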
As for your other question, the obvious choice would be aiohttp.