# What Python's asyncio primitives get wrong about shared state
#
# 6 min read. Original article ↗

from __future__ import annotations

import asyncio
import logging
import threading
import typing

T = typing.TypeVar("T")

# Used by `wait_for_not_none` to narrow `ValueWatcher[X | None]` to `X`.
S = typing.TypeVar("S")


class ValueWatcher(typing.Generic[T]):
    """
    Thread-safe observable value with async watchers.

    Watchers can await value changes via methods like `wait_for` and
    `wait_for_change`. Alternatively, they can add callbacks via `on_change`
    and `on_value`.

    Any thread can set `.value`, and the watcher will react accordingly.

    Assignments that compare equal (`==`) to the current value are ignored:
    nothing is stored and nobody is notified.
    """

    def __init__(
        self,
        initial_value: T,
        *,
        on_change: typing.Callable[[T, T], None] | None = None,
    ) -> None:
        """
        Args:
            initial_value: The initial value.
            on_change: Called with (old_value, new_value) when the value
                changes. Good for debug logging.
        """
        # Guards `_value`, `_on_changes` and `_watch_queues`. It is a plain
        # (non-reentrant) lock, so user code must never run while it is held
        # except for the `set_if` condition, which is documented there.
        self._lock = threading.Lock()

        self._on_changes: list[typing.Callable[[T, T], None]] = []
        if on_change is not None:
            self._on_changes.append(on_change)

        # Every watcher gets its own (loop, queue) pair. Storing the loop lets
        # the setter use `call_soon_threadsafe` for cross-thread notification.
        # Queue items are (old, new) tuples.
        self._watch_queues: list[
            tuple[asyncio.AbstractEventLoop, asyncio.Queue[tuple[T, T]]]
        ] = []

        # Hold references to fire-and-forget tasks to prevent GC.
        self._background_tasks: set[asyncio.Task[T]] = set()

        self._value = initial_value

    @property
    def value(self) -> T:
        """The current value, read under the lock."""
        with self._lock:
            return self._value

    @value.setter
    def value(self, new_value: T) -> None:
        """
        Store `new_value` and notify watchers and callbacks.

        Does nothing if `new_value` compares equal to the current value.
        """
        with self._lock:
            if new_value == self._value:
                return
            old_value = self._value
            self._value = new_value
            # Snapshot lists under lock to avoid iteration issues.
            queues = list(self._watch_queues)
            callbacks = list(self._on_changes)

        # Notify all watchers outside the lock to avoid deadlock.
        self._notify(old_value, new_value, queues, callbacks)

    def set_if(
        self,
        new_value: T,
        condition: typing.Callable[[T], bool],
    ) -> bool:
        """
        Atomically set the value only if the current value satisfies the
        condition.

        Args:
            new_value: The value to store.
            condition: Predicate evaluated on the current value. It is called
                while the internal lock is held, so it must not touch this
                watcher (e.g. read `.value`), or it will deadlock.

        Returns:
            False if the condition rejected the current value. True
            otherwise, including the case where the current value already
            equals `new_value` (in which case nobody is notified).
        """
        with self._lock:
            if not condition(self._value):
                return False
            if new_value == self._value:
                return True
            old_value = self._value
            self._value = new_value
            queues = list(self._watch_queues)
            callbacks = list(self._on_changes)

        self._notify(old_value, new_value, queues, callbacks)
        return True

    def _notify(
        self,
        old_value: T,
        new_value: T,
        queues: list[
            tuple[asyncio.AbstractEventLoop, asyncio.Queue[tuple[T, T]]]
        ],
        callbacks: list[typing.Callable[[T, T], None]],
    ) -> None:
        """
        Deliver one (old_value, new_value) change to the given watcher queues
        and then to the given callbacks.

        Must be called WITHOUT the lock held: callbacks are user code and may
        read `.value`, which takes the lock.
        """
        change = (old_value, new_value)
        for loop, queue in queues:
            try:
                # `call_soon_threadsafe` wakes the target loop's selector
                # immediately. A plain `put_nowait` wouldn't poke the
                # self-pipe, so a cross-thread watcher could stall until
                # something else wakes its loop.
                #
                # In other words, without `call_soon_threadsafe`, a watcher
                # could get the changed value notification long after the
                # value actually changed.
                loop.call_soon_threadsafe(queue.put_nowait, change)
            except RuntimeError:
                # Target event loop is closed; that watcher is gone.
                continue

        for callback in callbacks:
            try:
                callback(old_value, new_value)
            except Exception:
                # One failing callback must not skip the rest, but the
                # failure is reported instead of being silently discarded.
                logging.getLogger(__name__).exception(
                    "ValueWatcher on_change callback %r raised", callback
                )

    def on_change(self, callback: typing.Callable[[T, T], None]) -> None:
        """
        Add a callback that's called when the value changes.

        Args:
            callback: Called with (old_value, new_value) on each change.
        """
        with self._lock:
            self._on_changes.append(callback)

    def on_value(self, value: T, callback: typing.Callable[[], None]) -> None:
        """
        One-shot callback for when the value equals `value`. Requires a
        running event loop (internally spawns a background task).

        The callback is skipped if the background task is cancelled or ends
        with an exception.

        Args:
            value: The value to wait for.
            callback: Called when the internal value equals `value`.
        """
        task = asyncio.create_task(self.wait_for(value))
        self._background_tasks.add(task)

        def _done(t: asyncio.Task[T]) -> None:
            self._background_tasks.discard(t)
            if not t.cancelled() and t.exception() is None:
                callback()

        task.add_done_callback(_done)

    async def wait_for(
        self,
        value: T,
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> T:
        """
        Wait for the internal value to equal the given value.

        Args:
            value: Return when the internal value is equal to this.
            immediate: If True and the internal value is already equal to the
                given value, return immediately. Defaults to True.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`.
                None means wait forever.
        """
        return await self._wait_for_condition(
            lambda v: v == value,
            immediate=immediate,
            timeout=timeout,
        )

    async def wait_for_not(
        self,
        value: T,
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> T:
        """
        Wait for the internal value to not equal the given value.

        Args:
            value: Return when the internal value is not equal to this.
            immediate: If True and the internal value is already not equal to
                the given value, return immediately. Defaults to True.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`.
                None means wait forever.
        """
        return await self._wait_for_condition(
            lambda v: v != value,
            immediate=immediate,
            timeout=timeout,
        )

    async def wait_for_not_none(
        self: ValueWatcher[S | None],
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> S:
        """
        Wait for the internal value to be not None.

        Args:
            immediate: If True and the internal value is already not None,
                return immediately. Defaults to True.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`.
                None means wait forever.
        """
        result = await self._wait_for_condition(
            lambda v: v is not None,
            immediate=immediate,
            timeout=timeout,
        )
        # The condition guarantees a non-None result; this check only narrows
        # the type for static checkers.
        if result is None:
            raise AssertionError("unreachable")
        return result

    async def _wait_for_condition(
        self,
        condition: typing.Callable[[T], bool],
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> T:
        """
        Wait until `condition(current_value)` is true, then return the
        matching value. Handles the TOCTOU gap between checking the current
        value and subscribing to the change queue.

        Args:
            condition: Predicate applied to the current value and to every
                new value reported afterwards.
            immediate: If True, the current value is tested before waiting.
                If False, only values set after subscribing can match.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`.
                None means wait forever.
        """
        # Fast path: no task needed if the value already matches.
        if immediate:
            # Read once to avoid a TOCTOU race between check and return.
            current = self.value
            if condition(current):
                return current

        async def _wait() -> T:
            with self._watch() as queue:
                # Re-check after queue registration to close the gap
                # between the fast path above and the queue being live.
                if immediate:
                    # Read once to avoid a TOCTOU race between check and
                    # return.
                    current = self.value
                    if condition(current):
                        return current
                while True:
                    _, new = await queue.get()
                    if condition(new):
                        return new

        return await asyncio.wait_for(_wait(), timeout=timeout)

    async def wait_for_change(
        self,
        *,
        timeout: float | None = None,
    ) -> T:
        """
        Wait for the internal value to change and return the new value.

        Args:
            timeout: Seconds to wait before raising `asyncio.TimeoutError`.
                None means wait forever.
        """

        async def _wait() -> T:
            with self._watch() as queue:
                _, new = await queue.get()
                return new

        return await asyncio.wait_for(_wait(), timeout=timeout)

    def _watch(self) -> _WatchContextManager[T]:
        """
        Watch for all changes to the value. This method returns a context
        manager so it must be used in a `with` statement.

        Its return value is a queue that yields tuples of the old and new
        values. The queue is registered together with the currently running
        event loop, so this must be called from inside a running loop.
        """
        loop = asyncio.get_running_loop()
        queue = asyncio.Queue[tuple[T, T]]()
        with self._lock:
            self._watch_queues.append((loop, queue))
        return _WatchContextManager(
            on_exit=lambda: self._remove_queue(queue),
            queue=queue,
        )

    def _remove_queue(self, queue: asyncio.Queue[tuple[T, T]]) -> None:
        """
        Remove a queue from the watch list in a thread-safe manner.

        Queues are matched by identity. The list is rebuilt rather than
        mutated, so snapshots taken earlier are unaffected.
        """
        with self._lock:
            self._watch_queues = [
                entry for entry in self._watch_queues if entry[1] is not queue
            ]

class _WatchContextManager(typing.Generic[T]):
    """
    Synchronous context manager tying a watcher queue to a `with` block.

    Entering the block hands back the queue, which yields tuples of the old
    and new values. Leaving the block, normally or through an exception,
    runs the `on_exit` hook, which is used to delete the queue once it is no
    longer being watched.
    """

    def __init__(
        self,
        on_exit: typing.Callable[[], None],
        queue: asyncio.Queue[tuple[T, T]],
    ) -> None:
        self._events = queue
        self._release = on_exit

    def __enter__(self) -> asyncio.Queue[tuple[T, T]]:
        # IMPORTANT: Do not return an async generator. That can lead to "Task
        # was destroyed but it is pending!" warnings when the event loop
        # closes.
        return self._events

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: object,
    ) -> None:
        # Returns None, so an exception raised inside the block propagates.
        self._release()