# Copyright cocotb contributors
# Licensed under the Revised BSD License, see LICENSE for details.
# SPDX-License-Identifier: BSD-3-Clause
"""TaskManager and related code."""
from __future__ import annotations
import sys
from asyncio import CancelledError
from bdb import BdbQuit
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar, overload
from cocotb.task import Task, current_task
from cocotb.triggers import Event, NullTrigger
if sys.version_info >= (3, 11):
    from typing import Self
else:
    # BaseExceptionGroup is a builtin only from Python 3.11; use the backport package before that.
    # NOTE(review): ``Self`` is not imported on this branch. That is only safe because
    # ``from __future__ import annotations`` keeps the ``-> Self`` annotation unevaluated;
    # typing.get_type_hints() on TaskManager.__aenter__ would fail here. Confirm intended.
    from exceptiongroup import BaseExceptionGroup

if sys.version_info >= (3, 10):
    from typing import ParamSpec

    # NOTE(review): P is not referenced in the visible part of this module — presumably
    # used elsewhere or left over; TODO confirm before removing.
    P = ParamSpec("P")

# Result type of the awaitables run by TaskManager child Tasks.
T = TypeVar("T")
class TaskManager:
    r"""An :term:`asynchronous context manager` which runs :term:`coroutine function`\ s or :term:`awaitable`\ s concurrently until all finish.

    See :ref:`task_manager_tutorial` for detailed usage information.

    Args:
        default_continue_on_error: Default value for *continue_on_error* for child Tasks started by this TaskManager.
        context_continue_on_error: Value for *continue_on_error* for the context block itself.
            If not specified, defaults to the value of *default_continue_on_error*.

    .. versionadded:: 2.1
    """

    def __init__(
        self,
        *,
        default_continue_on_error: bool = False,
        context_continue_on_error: bool | None = None,
    ) -> None:
        self._default_continue_on_error = default_continue_on_error
        self._context_continue_on_error: bool = (
            context_continue_on_error
            if context_continue_on_error is not None
            else default_continue_on_error
        )

        # We don't keep all children Tasks around, just the ones that haven't finished yet,
        # so we have to save the exceptions for the ExceptionGroup we raise at the end of the context block.
        # A dict with ``None`` values is used as an insertion-ordered set: the same exception object
        # recorded twice (from a child Task and from the context block) is kept once, and the
        # ExceptionGroup lists exceptions in the order they were recorded rather than in hash order.
        self._exceptions: dict[BaseException, None] = {}

        # dict value is per-Task continue_on_error setting
        self._remaining_tasks: dict[Task[Any], bool] = {}
        # Set while there are no unfinished child Tasks; __aexit__ waits on it.
        self._none_remaining = Event()

        # Children were cancelled due to a child Task/block failure.
        self._cancelled: bool = False
        # We started __aexit__. Ensures we don't add more Tasks after the context block has started finishing.
        self._finishing: bool = False
        # For protecting against adding Tasks before entering the context block
        self._entered: bool = False
        # The parent Task which entered the context block.
        # Used to ensure only the parent Task can add child Tasks,
        # and to cancel the parent Task if a child Task fails while the block is still running.
        self._parent_task: Task[Any]

        # Start with no remaining tasks
        self._none_remaining.set()

    def _check_can_add_task(self) -> None:
        """Raise :exc:`RuntimeError` if a child Task cannot be added right now.

        Checked conditions, in order: the manager was cancelled due to an error,
        the context block has started finishing, the context block has not been
        entered yet, or the caller is not the Task which entered the block.
        """
        if self._cancelled:
            raise RuntimeError("Cannot add new Tasks to TaskManager after error")
        elif self._finishing:
            raise RuntimeError("Cannot add new Tasks to TaskManager after finishing")
        elif not self._entered:
            raise RuntimeError(
                "Cannot add new Tasks to TaskManager before entering context"
            )
        # Only reached once entered, so _parent_task is guaranteed to be assigned.
        if current_task() is not self._parent_task:
            raise RuntimeError("Cannot add new Tasks to TaskManager from another Task")

    def start_soon(
        self,
        aw: Awaitable[T],
        *,
        name: str | None = None,
        continue_on_error: bool | None = None,
    ) -> Task[T]:
        r"""Await the *aw* argument concurrently.

        Args:
            aw: A :class:`~collections.abc.Awaitable` to :keyword:`await` concurrently.
            name: A name to associate with the :class:`!Task` awaiting *aw*.
            continue_on_error: Value of *continue_on_error* for this Task only.
                If not specified, defaults to the value of the :class:`!TaskManager`'s *default_continue_on_error* argument.

        Returns:
            A :class:`~cocotb.task.Task` which is awaiting *aw* concurrently.

        Raises:
            RuntimeError: If the TaskManager was cancelled due to an error, is finishing,
                has not been entered, or is called from a Task other than the one which entered it.
        """
        # Ensure that we can add a new Task to this TaskManager before creating the Task.
        # We would have to close it if it failed.
        self._check_can_add_task()

        # Create the Task and tie it to this TaskManager via the done callback.
        task = Task[T](aw, name=name)
        task._add_done_callback(self._done_callback)

        # Track the Task and store per-Task continue_on_error setting
        self._remaining_tasks[task] = (
            continue_on_error
            if continue_on_error is not None
            else self._default_continue_on_error
        )
        self._none_remaining.clear()

        # Start the Task to running.
        task.start_soon()
        return task

    @overload
    def fork(
        self,
        coro_func: Callable[[], Awaitable[T]],
        /,
    ) -> Task[T]: ...

    @overload
    def fork(
        self, *, continue_on_error: bool
    ) -> Callable[[Callable[[], Awaitable[T]]], Task[T]]: ...

    def fork(
        self,
        coro_func: Callable[..., Awaitable[T]] | None = None,
        *,
        continue_on_error: bool | None = None,
    ) -> Task[T] | Callable[[Callable[[], Awaitable[T]]], Task[T]]:
        r"""Decorate a coroutine function to run it concurrently.

        .. note::
            This does not necessarily have to be a coroutine function.
            Any callable which returns a :class:`~collections.abc.Awaitable` can be used.

        Args:
            coro_func: A :term:`coroutine function` to run concurrently. Typically this is the decorated function.
            continue_on_error: Value of *continue_on_error* for this Task only.
                If not specified, defaults to the value of the :class:`!TaskManager`'s *default_continue_on_error* argument.
                Passing this requires calling the :meth:`fork` method before decorating the coroutine function.

        Returns:
            A :class:`~cocotb.task.Task` which is running *coro_func* concurrently.

        Raises:
            TypeError: If called with neither *coro_func* nor *continue_on_error*.
            RuntimeError: Under the same conditions as :meth:`start_soon`.
                In that case *coro_func* is not called.

        .. code-block:: python

            async with TaskManager() as tm:

                @tm.fork
                async def my_func():
                    # Do stuff
                    ...

                @tm.fork(continue_on_error=True)
                async def other_func():
                    # Do other stuff in parallel to my_func
                    ...
        """
        # Handle the case where fork is called as a function and returns the decorator.
        if coro_func is None:
            if continue_on_error is None:
                raise TypeError(
                    "Missing required keyword-only argument: 'continue_on_error'"
                )

            def deco(
                coro: Callable[[], Awaitable[T]],
            ) -> Task[T]:
                return self.fork(  # type: ignore[call-overload]
                    coro,
                    continue_on_error=continue_on_error,
                )

            return deco

        # Validate before calling coro_func, so a rejected fork does not leave behind
        # an awaitable that is never awaited (e.g. an un-awaited coroutine object).
        self._check_can_add_task()
        # Not every callable has __name__ (e.g. functools.partial); None selects the default Task name.
        name = getattr(coro_func, "__name__", None)
        return self.start_soon(
            coro_func(), name=name, continue_on_error=continue_on_error
        )

    def _done_callback(self, task: Task[Any]) -> None:
        """Callback run when a child Task finishes.

        Stops tracking *task*, records its exception (cancelled Tasks are not treated as failures),
        and cancels everything else unless *task* was started with *continue_on_error*.
        """
        continue_on_error = self._remaining_tasks.pop(task)
        if not self._remaining_tasks:
            self._none_remaining.set()

        # If a child Task failed, cancel all other child Tasks.
        if not task.cancelled() and (exc := task.exception()) is not None:
            self._exceptions[exc] = None
            if not continue_on_error:
                self._cancel()

    def _cancel(self) -> None:
        """Cancel all unfinished child Tasks (and the parent Task if the block is still running).

        Idempotent: only the first call has any effect.
        """
        if self._cancelled:
            return
        self._cancelled = True

        # If a child Task fails while we are in the middle of a TaskManager block,
        # cancel the parent Task to force the block to end.
        if not self._finishing:
            self._parent_task.cancel()

        # Cancel all child Tasks. Iterate over a snapshot: if a cancellation completes
        # synchronously, _done_callback pops from _remaining_tasks during this loop.
        for task in tuple(self._remaining_tasks):
            task.cancel()

    async def __aenter__(self) -> Self:
        """Enter the context block, recording the current Task as the parent Task.

        Raises:
            RuntimeError: If this TaskManager has already finished.
        """
        if self._finishing:
            raise RuntimeError("Cannot re-enter finished TaskManager context")
        self._entered = True
        self._parent_task = current_task()
        return self

    async def __aexit__(
        self, exc_type: object, exc: BaseException | None, traceback: object
    ) -> None:
        """Wait for all child Tasks to finish, then report failures.

        Special exceptions (:exc:`KeyboardInterrupt`, :exc:`SystemExit`, :exc:`~bdb.BdbQuit`,
        :exc:`GeneratorExit`) and external cancellation of the parent Task are propagated
        immediately after cancelling children. A :exc:`~asyncio.CancelledError` caused by a
        child Task failure is suppressed.

        Raises:
            BaseExceptionGroup: If any child Task or the context block failed.
                This is an :exc:`ExceptionGroup` when all failures are :exc:`Exception`\ s.
        """
        self._finishing = True

        # Propagate special exceptions immediately.
        # The GeneratorExit case is there so if the block is cancelled due to a child Task failure,
        # and the block squashes the CancelledError and does an await, a GeneratorExit is thrown at the await, which may end up here.
        if isinstance(exc, (KeyboardInterrupt, SystemExit, BdbQuit, GeneratorExit)):
            self._cancel()
            return None  # re-raise exception

        if self._cancelled:
            # The context block was cancelled due to a child Task failure.
            if isinstance(exc, CancelledError):
                # Suppress CancelledError in this case to allow child Tasks to finish cancelling.
                exc = None
                self._parent_task._uncancel()
            else:
                # The context block ignored the cancellation and either threw some other exception or finished successfully.
                # We await a token NullTrigger to allow the parent Task to kill the Task due to ignored CancelledError.
                await NullTrigger()
                # This will never run as a GeneratorExit will be thrown at the above await.
                raise RuntimeError("Reached unreachable code")  # pragma: no cover
        elif exc is not None:
            if isinstance(exc, CancelledError):
                # Something else cancelled the parent task. Propagate CancelledError immediately.
                self._cancel()
                return None  # re-raise CancelledError
            elif not self._context_continue_on_error:
                # Block finished with an exception and we are not continuing on error.
                self._cancel()

        # Wait for all Tasks to finish / finish cancelling.
        try:
            await self._none_remaining.wait()
        except CancelledError:
            # Cancel all child Tasks if the current Task is cancelled by the user while
            # waiting for all child Tasks to finish. If the TaskManager is already
            # cancelling due to a child Task failure, this will no-op.
            self._cancel()
            raise
        except BaseException:
            # The current Task failed while waiting for child Tasks to finish.
            # This is likely because we ignored a CancelledError since there is no other
            # way to fail this await AFAICT. Cancel children and let it pass up as there's
            # nothing we can do.
            self._cancel()
            raise

        # Build BaseExceptionGroup if there were any errors. Ignore CancelledError.
        if exc is not None and not isinstance(exc, CancelledError):
            self._exceptions[exc] = None
        if self._exceptions:
            # BaseExceptionGroup constructor will automatically return an ExceptionGroup if all elements are Exceptions.
            raise BaseExceptionGroup(
                "TaskManager finished with errors", tuple(self._exceptions)
            )

        # Return True to handle suppressing CancelledError if there were no exceptions
        return True  # type: ignore[return-value] # __aexit__ can return True