Source code for cocotb._task_manager

# Copyright cocotb contributors
# Licensed under the Revised BSD License, see LICENSE for details.
# SPDX-License-Identifier: BSD-3-Clause

"""TaskManager and related code."""

from __future__ import annotations

import inspect
import sys
from asyncio import CancelledError
from bdb import BdbQuit
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any, TypeVar, overload

from cocotb._base_triggers import NullTrigger
from cocotb.task import Task, current_task
from cocotb.triggers import Event, Trigger

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from exceptiongroup import BaseExceptionGroup

if sys.version_info >= (3, 10):
    from typing import ParamSpec

    P = ParamSpec("P")


T = TypeVar("T")


async def _waiter(aw: Awaitable[T]) -> T:
    return await aw


class TaskManager:
    r"""An :term:`asynchronous context manager` which runs :term:`coroutine function`\ s or :term:`awaitable`\ s concurrently until all finish.

    See :ref:`task_manager_tutorial` for detailed usage information.

    Args:
        default_continue_on_error: Default value for *continue_on_error* for child Tasks started by this TaskManager.
        context_continue_on_error: Value for *continue_on_error* for the context block itself.

            If not specified, defaults to the value of *default_continue_on_error*.

    .. versionadded:: 2.1
    """

    def __init__(
        self,
        *,
        default_continue_on_error: bool = False,
        context_continue_on_error: bool | None = None,
    ) -> None:
        self._default_continue_on_error = default_continue_on_error
        self._context_continue_on_error: bool = (
            context_continue_on_error
            if context_continue_on_error is not None
            else default_continue_on_error
        )

        # Errors collected from failed child Tasks and from the context block.
        # A dict is used as an insertion-ordered set. Keys are de-duplicated by
        # identity, so an exception object reported by both a child Task and the
        # context block appears only once. The resulting exception group lists
        # errors in the order they were observed, rather than in arbitrary set
        # order.
        self._exceptions: dict[BaseException, None] = {}
        # dict value is per-Task continue_on_error setting
        self._remaining_tasks: dict[Task[Any], bool] = {}
        self._none_remaining = Event()
        self._cancelled: bool = False
        self._finishing: bool = False
        self._finished: bool = False
        self._entered: bool = False
        # parent task will not exist if we aren't using this as a context manager
        self._parent_task: Task[Any] | None = None

        # Start with no remaining tasks
        self._none_remaining.set()

    def _ensure_can_add(self) -> None:
        """Check that a new child Task may be added right now.

        Raises:
            RuntimeError: If the TaskManager has been cancelled due to an error, is
                finishing, has not been entered yet, or if the caller is not the
                Task that entered the context block.
        """
        if self._cancelled:
            raise RuntimeError("Cannot add new Tasks to TaskManager after error")
        elif self._finishing:
            raise RuntimeError("Cannot add new Tasks to TaskManager after finishing")
        elif not self._entered:
            raise RuntimeError(
                "Cannot add new Tasks to TaskManager before entering context"
            )
        if current_task() is not self._parent_task:
            raise RuntimeError("Cannot add new Tasks to TaskManager from another Task")

    def start_soon(
        self,
        aw: Awaitable[T],
        *,
        name: str | None = None,
        continue_on_error: bool | None = None,
    ) -> Task[T]:
        r"""Await the *aw* argument concurrently.

        Args:
            aw: A :class:`~collections.abc.Awaitable` to :keyword:`await` concurrently.
            name: A name to associate with the :class:`!Task` awaiting *aw*.
            continue_on_error: Value of *continue_on_error* for this Task only.
                If not specified, defaults to the value of the :class:`!TaskManager`'s *default_continue_on_error* argument.

        Returns:
            A :class:`~cocotb.task.Task` which is awaiting *aw* concurrently.

        Raises:
            RuntimeError: If Tasks cannot currently be added to this TaskManager.
            TypeError: If *aw* is not an awaitable.
        """
        self._ensure_can_add()
        if isinstance(aw, Coroutine):
            # This must come before Awaitable since Coroutine is a subclass of Awaitable
            pass
        elif isinstance(aw, Awaitable):
            # Task needs a coroutine, so wrap plain awaitables in one.
            aw = _waiter(aw)
        else:
            raise TypeError(
                f"start_soon() expected an Awaitable, got {type(aw).__name__}"
            )
        task = Task[T](aw, name=name)
        self._add_task(task, continue_on_error=continue_on_error)
        return task

    @overload
    def fork(
        self,
        coro_func: Callable[[], Coroutine[Trigger, None, T]],
        /,
    ) -> Task[T]: ...

    @overload
    def fork(
        self, *, continue_on_error: bool
    ) -> Callable[[Callable[[], Coroutine[Trigger, None, T]]], Task[T]]: ...

    def fork(
        self,
        coro_func: Callable[..., Coroutine[Trigger, None, T]] | None = None,
        *,
        continue_on_error: bool | None = None,
    ) -> Task[T] | Callable[[Callable[[], Coroutine[Trigger, None, T]]], Task[T]]:
        r"""Decorate a coroutine function to run it concurrently.

        Args:
            coro_func: A :term:`coroutine function` to run concurrently.
                Typically only passed as a decorator.
            continue_on_error: Value of *continue_on_error* for this Task only.
                If not specified, defaults to the value of the :class:`!TaskManager`'s *default_continue_on_error* argument.
                Passing this requires calling the :meth:`fork` method before decorating the coroutine function.

        Returns:
            A :class:`~cocotb.task.Task` which is running *coro_func* concurrently.
            The Task is named after *coro_func*'s ``__name__``. Callables without one
            (e.g. :func:`functools.partial` objects) get no explicit name.

        Raises:
            RuntimeError: If Tasks cannot currently be added to this TaskManager.
            TypeError: If called without *coro_func* and without *continue_on_error*,
                or if *coro_func* is not a coroutine function.

        .. code-block:: python

            async with TaskManager() as tm:

                @tm.fork
                async def my_func():
                    # Do stuff
                    ...

                @tm.fork(continue_on_error=True)
                async def other_func():
                    # Do other stuff in parallel to my_func
                    ...
        """
        self._ensure_can_add()
        if coro_func is None:
            # Decorator-factory form: ``@tm.fork(continue_on_error=...)``.
            if continue_on_error is None:
                raise TypeError(
                    "Missing required keyword-only argument: 'continue_on_error'"
                )

            def deco(
                coro: Callable[[], Coroutine[Trigger, None, T]],
            ) -> Task[T]:
                return self.fork(  # type: ignore[call-overload]
                    coro,
                    continue_on_error=continue_on_error,
                )

            return deco

        if not inspect.iscoroutinefunction(coro_func):
            raise TypeError(
                f"fork() expected a coroutine function, got {type(coro_func).__name__}"
            )
        # inspect.iscoroutinefunction() also accepts functools.partial objects,
        # which have no __name__; fall back to letting Task pick its default name.
        task = Task[T](coro_func(), name=getattr(coro_func, "__name__", None))
        self._add_task(task, continue_on_error=continue_on_error)
        return task

    def _add_task(
        self,
        task: Task[Any],
        *,
        continue_on_error: bool | None = None,
    ) -> None:
        """Register *task* as a child of this TaskManager and schedule it.

        Args:
            task: The newly created child Task.
            continue_on_error: Per-Task override. ``None`` means use
                *default_continue_on_error*.
        """
        # Track the Task and store per-Task continue_on_error setting
        task._add_done_callback(self._done_callback)
        if continue_on_error is None:
            continue_on_error = self._default_continue_on_error
        self._remaining_tasks[task] = continue_on_error
        self._none_remaining.clear()
        # Schedule the Task to run soon
        task._ensure_started()

    def _done_callback(self, task: Task[Any]) -> None:
        """Callback run when a child Task finishes."""
        continue_on_error = self._remaining_tasks.pop(task)
        if not self._remaining_tasks:
            self._none_remaining.set()

        # If a child Task failed, cancel all other child Tasks.
        if not task.cancelled() and (exc := task.exception()) is not None:
            self._exceptions[exc] = None
            if not continue_on_error:
                self._cancel()

    def _cancel(self) -> None:
        """Cancel all unfinished child Tasks (and the parent block if still running).

        Idempotent: only the first call has any effect.
        """
        if self._cancelled:
            return
        self._cancelled = True

        # If a child Task fails while we are in the middle of a TaskManager block,
        # cancel the parent Task to force the block to end.
        if not self._finishing and self._parent_task is not None:
            self._parent_task.cancel()

        # Cancel all child Tasks. Iterate over a snapshot: _done_callback pops
        # from _remaining_tasks, so if cancelling a Task runs its done callback
        # before the loop finishes, mutating the live dict here would raise.
        for task in list(self._remaining_tasks):
            task.cancel()

    async def __aenter__(self) -> Self:
        """Enter the context block, recording the current Task as the parent Task.

        Raises:
            RuntimeError: If this TaskManager has already finished, or is already
                inside an active context block.
        """
        if self._finishing:
            raise RuntimeError("Cannot re-enter finished TaskManager context")
        if self._entered:
            # Entering twice would silently overwrite the parent Task and run
            # __aexit__ twice for the same set of children.
            raise RuntimeError("Cannot re-enter active TaskManager context")
        self._entered = True
        self._parent_task = current_task()
        return self

    async def __aexit__(
        self, exc_type: object, exc: BaseException | None, traceback: object
    ) -> None:
        """Wait for all child Tasks to finish, then report any collected errors.

        Raises:
            BaseExceptionGroup: If any child Task or the context block itself
                failed (CancelledError from the block is not reported).
        """
        self._finishing = True

        if self._parent_task is not None and self._cancelled:
            # The context block was cancelled due to a child Task failure.
            if isinstance(exc, CancelledError):
                # Suppress CancelledError in this case to allow child Tasks to finish cancelling.
                exc = None
                assert self._parent_task is not None
                self._parent_task._uncancel()
            else:
                # The context block ignored the cancellation. Hard fail the test.
                # There is a special case in Task if a cancelled Task finishes without
                # raising CancelledError, it fails the test. So we just await *something*
                # here to hit that code path.
                # TODO Make this force a test failure.
                self._cancel()
                await NullTrigger()

        elif exc is not None:
            if isinstance(exc, CancelledError):
                # Something else cancelled the parent task. Propagate CancelledError.
                self._cancel()
                return None  # re-raise CancelledError
            elif isinstance(exc, (KeyboardInterrupt, SystemExit, BdbQuit)):
                # Certain BaseExceptions should be immediately propagated like they are in Task.
                self._cancel()
                return None  # re-raise exception
            elif not self._context_continue_on_error:
                # Block finished with an exception and we are not continuing on error.
                self._cancel()

        # Wait for all Tasks to finish / finish cancelling.
        try:
            await self._none_remaining.wait()
        except CancelledError:
            # Cancel all child Tasks if the current Task is cancelled by the user while
            # waiting for all child Tasks to finish. If the TaskManager is already
            # cancelling due to a child Task failure, this will no-op.
            self._cancel()
            raise
        except BaseException:
            # The current Task failed while waiting for child Tasks to finish.
            # This is likely because we ignored a CancelledError since there is no other
            # way to fail this await AFAICT. Cancel children and let it pass up as there's
            # nothing we can do.
            # TODO Make this force a test failure. Special case KeyboardInterrupt/SystemExit/BdbQuit
            self._cancel()
            raise

        self._finished = True

        # Build BaseExceptionGroup if there were any errors. Ignore CancelledError.
        if exc is not None and not isinstance(exc, CancelledError):
            self._exceptions[exc] = None
        if self._exceptions:
            # BaseExceptionGroup constructor will automatically return an ExceptionGroup if all elements are Exceptions.
            raise BaseExceptionGroup(
                "TaskManager finished with errors", tuple(self._exceptions)
            )

        # Return True to handle suppressing CancelledError if there were no exceptions
        return True  # type: ignore[return-value]  # __aexit__ can return True