from __future__ import annotations

import asyncio
import logging
import threading
import typing
# Type of the value held by a `ValueWatcher`.
T = typing.TypeVar("T")
# Used by `wait_for_not_none` to narrow `ValueWatcher[X | None]` to `X`.
S = typing.TypeVar("S")

_logger = logging.getLogger(__name__)


class ValueWatcher(typing.Generic[T]):
    """
    Thread-safe observable value with async watchers.

    Watchers can await value changes via methods like `wait_for` and
    `wait_for_change`. Alternatively, they can add callbacks via `on_change` and
    `on_value`.

    Any thread can set `.value`, and the watcher will react accordingly.
    `on_change` callbacks run synchronously on the thread that performed the
    assignment; async watchers are notified through their own event loop.
    """

    def __init__(
        self,
        initial_value: T,
        *,
        on_change: typing.Callable[[T, T], None] | None = None,
    ) -> None:
        """
        Args:
            initial_value: The initial value.
            on_change: Called when the value changes. Good for debug logging.
        """
        self._lock = threading.Lock()
        self._on_changes: list[typing.Callable[[T, T], None]] = []
        if on_change:
            self._on_changes.append(on_change)
        # Every watcher gets its own (loop, queue) pair. Storing the loop lets
        # the setter use `call_soon_threadsafe` for cross-thread notification.
        # Queue items are (old, new) tuples.
        self._watch_queues: list[
            tuple[asyncio.AbstractEventLoop, asyncio.Queue[tuple[T, T]]]
        ] = []
        # Hold references to fire-and-forget tasks to prevent GC.
        self._background_tasks: set[asyncio.Task[T]] = set()
        self._value = initial_value

    @property
    def value(self) -> T:
        """The current value, read under the internal lock."""
        with self._lock:
            return self._value

    @value.setter
    def value(self, new_value: T) -> None:
        """
        Store `new_value` and notify watchers and callbacks.

        Assigning a value equal (`==`) to the current one is a no-op. The
        comparison runs while the internal lock is held, so `__eq__` must not
        access this watcher.
        """
        with self._lock:
            if new_value == self._value:
                return
            old_value = self._value
            self._value = new_value
            # Snapshot lists under lock to avoid iteration issues.
            queues = list(self._watch_queues)
            callbacks = list(self._on_changes)
        # Notify all watchers outside the lock to avoid deadlock.
        self._notify(old_value, new_value, queues, callbacks)

    def set_if(
        self,
        new_value: T,
        condition: typing.Callable[[T], bool],
    ) -> bool:
        """
        Atomically set the value only if the current value satisfies the
        condition.

        Args:
            new_value: The value to store.
            condition: Predicate applied to the current value. It is called
                while the internal (non-reentrant) lock is held, so it must
                not read or write this watcher.

        Returns:
            True if the condition held (including when `new_value` already
            equals the current value, in which case nobody is notified),
            False otherwise.
        """
        with self._lock:
            if not condition(self._value):
                return False
            if new_value == self._value:
                return True
            old_value = self._value
            self._value = new_value
            queues = list(self._watch_queues)
            callbacks = list(self._on_changes)
        self._notify(old_value, new_value, queues, callbacks)
        return True

    def _notify(
        self,
        old_value: T,
        new_value: T,
        queues: list[
            tuple[asyncio.AbstractEventLoop, asyncio.Queue[tuple[T, T]]]
        ],
        callbacks: list[typing.Callable[[T, T], None]],
    ) -> None:
        """
        Deliver one (old, new) change to the given watcher queues and
        callbacks. Must be called WITHOUT holding `self._lock`.
        """
        change = (old_value, new_value)
        for loop, queue in queues:
            try:
                # `call_soon_threadsafe` wakes the target loop's selector
                # immediately. A plain `put_nowait` wouldn't poke the
                # self-pipe, so a cross-thread watcher could stall until
                # something else wakes its loop.
                #
                # In other words, without `call_soon_threadsafe`, a watcher
                # could get the changed value notification long after the
                # value actually changed.
                loop.call_soon_threadsafe(queue.put_nowait, change)
            except RuntimeError:
                # Target event loop is closed.
                pass
        for on_change in callbacks:
            try:
                on_change(old_value, new_value)
            except Exception:
                # Keep going so one failing callback doesn't skip the rest,
                # but leave a trace instead of hiding the failure.
                _logger.exception(
                    "ValueWatcher on_change callback %r raised", on_change
                )

    def on_change(self, callback: typing.Callable[[T, T], None]) -> None:
        """
        Add a callback that's called when the value changes.

        Args:
            callback: Called with (old_value, new_value) on each change.
        """
        with self._lock:
            self._on_changes.append(callback)

    def on_value(self, value: T, callback: typing.Callable[[], None]) -> None:
        """
        One-shot callback for when the value equals `value`. Requires a
        running event loop (internally spawns a background task).

        The callback is not invoked if the background wait is cancelled or
        fails.

        Args:
            value: The value to wait for.
            callback: Called when the internal value equals `value`.

        Raises:
            RuntimeError: If no event loop is running in the current thread.
        """
        # Look the loop up before building the coroutine: without a running
        # loop this fails cleanly instead of leaving a never-awaited
        # `wait_for` coroutine behind.
        loop = asyncio.get_running_loop()
        task = loop.create_task(self.wait_for(value))
        self._background_tasks.add(task)

        def _done(t: asyncio.Task[T]) -> None:
            self._background_tasks.discard(t)
            if not t.cancelled() and t.exception() is None:
                callback()

        task.add_done_callback(_done)

    async def wait_for(
        self,
        value: T,
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> T:
        """
        Wait for the internal value to equal the given value.

        Args:
            value: Return when the internal value is equal to this.
            immediate: If True and the internal value is already equal to the
                given value, return immediately. Defaults to True.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`.
                None means wait forever.
        """
        return await self._wait_for_condition(
            lambda v: v == value,
            immediate=immediate,
            timeout=timeout,
        )

    async def wait_for_not(
        self,
        value: T,
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> T:
        """
        Wait for the internal value to not equal the given value.

        Args:
            value: Return when the internal value is not equal to this.
            immediate: If True and the internal value is already not equal to
                the given value, return immediately. Defaults to True.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`.
                None means wait forever.
        """
        return await self._wait_for_condition(
            lambda v: v != value,
            immediate=immediate,
            timeout=timeout,
        )

    async def wait_for_not_none(
        self: ValueWatcher[S | None],
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> S:
        """
        Wait for the internal value to be not None.

        Args:
            immediate: If True and the internal value is already not None,
                return immediately. Defaults to True.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`.
                None means wait forever.
        """
        result = await self._wait_for_condition(
            lambda v: v is not None,
            immediate=immediate,
            timeout=timeout,
        )
        # Explicit raise (not `assert`) so the narrowing survives `python -O`.
        if result is None:
            raise AssertionError("unreachable")
        return result

    async def _wait_for_condition(
        self,
        condition: typing.Callable[[T], bool],
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> T:
        """
        Wait until `condition(current_value)` is true, then return the
        matching value. Handles the TOCTOU gap between checking the current
        value and subscribing to the change queue.

        Args:
            condition: Predicate applied to the current value and to every
                new value delivered afterwards.
            immediate: If False, the current value is ignored and only
                subsequent changes are considered.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`.
                None means wait forever.
        """
        # Fast path: no task needed if the value already matches.
        if immediate:
            # Read once to avoid a TOCTOU race between check and return.
            current = self.value
            if condition(current):
                return current

        async def _wait() -> T:
            with self._watch() as queue:
                # Re-check after queue registration to close the gap
                # between the fast path above and the queue being live.
                if immediate:
                    # Read once to avoid a TOCTOU race between check and
                    # return.
                    current = self.value
                    if condition(current):
                        return current
                while True:
                    _, new = await queue.get()
                    if condition(new):
                        return new

        return await asyncio.wait_for(_wait(), timeout=timeout)

    async def wait_for_change(
        self,
        *,
        timeout: float | None = None,
    ) -> T:
        """
        Wait for the internal value to change and return the new value.

        Args:
            timeout: Seconds to wait before raising `asyncio.TimeoutError`.
                None means wait forever.
        """

        async def _wait() -> T:
            with self._watch() as queue:
                _, new = await queue.get()
                return new

        return await asyncio.wait_for(_wait(), timeout=timeout)

    def _watch(self) -> _WatchContextManager[T]:
        """
        Watch for all changes to the value. This method returns a context
        manager so it must be used in a `with` statement.

        Its return value is a queue that yields tuples of the old and new
        values. Must be called from a running event loop; that loop is the
        one later used to deliver notifications to the queue.
        """
        loop = asyncio.get_running_loop()
        queue = asyncio.Queue[tuple[T, T]]()
        with self._lock:
            self._watch_queues.append((loop, queue))
        return _WatchContextManager(
            on_exit=lambda: self._remove_queue(queue),
            queue=queue,
        )

    def _remove_queue(self, queue: asyncio.Queue[tuple[T, T]]) -> None:
        """
        Remove a queue from the watch list in a thread-safe manner.
        """
        with self._lock:
            # Compare by identity and rebind to a fresh list.
            self._watch_queues = [
                entry for entry in self._watch_queues if entry[1] is not queue
            ]
class _WatchContextManager(typing.Generic[T]):
    """
    Scope guard around a single watcher queue.

    Entering the `with` block hands back the queue, whose items are tuples of
    the old and new values. Leaving the block, normally or via an exception,
    runs the supplied `on_exit` hook so the queue can be dropped once nobody
    is watching it any more.
    """

    def __init__(
        self,
        on_exit: typing.Callable[[], None],
        queue: asyncio.Queue[tuple[T, T]],
    ) -> None:
        """
        Args:
            on_exit: Zero-argument hook invoked when the `with` block ends.
            queue: The queue handed to the body of the `with` block.
        """
        self._queue, self._on_exit = queue, on_exit

    def __enter__(self) -> asyncio.Queue[tuple[T, T]]:
        # IMPORTANT: hand back the plain queue, never an async generator.
        # An async generator can lead to "Task was destroyed but it is
        # pending!" warnings when the event loop closes.
        return self._queue

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: object,
    ) -> None:
        # Always run the cleanup hook. Returning None lets any in-flight
        # exception keep propagating.
        self._on_exit()
        return None