mirror of
https://github.com/samjage/metro-warden.git
synced 2026-06-06 01:00:41 +00:00
224 lines
7.5 KiB
Python
224 lines
7.5 KiB
Python
"""
Metro Warden Event Bus — asyncio pub/sub with wildcard topic support.

Topics follow a dot-separated hierarchy: "network.interfaces", "system.cpu", etc.
Wildcard "*" matches exactly one segment; "**" matches one or more segments.
"""
|
|
|
|
from __future__ import annotations

import asyncio
import fnmatch
import inspect
import logging
import uuid
from collections import defaultdict
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
|
|
|
log = logging.getLogger(__name__)

# Signature shared by every subscriber: ``handler(topic, data)``. A handler may
# return ``None`` (plain function) or an awaitable that the bus awaits.
# ``Optional[...]`` is used instead of ``X | None`` because this alias is a
# runtime expression: ``from __future__ import annotations`` only defers
# annotations, and ``typing`` generics support ``|`` only from Python 3.10.
Handler = Callable[[str, Any], Optional[Awaitable[None]]]
|
|
|
|
|
|
@dataclass
class Subscription:
    """A handler registered against a topic pattern.

    Instances are built by ``EventBus.subscribe``; ``id`` is the value that
    call hands back for later unsubscription.
    """

    # Unique identifier of this subscription.
    id: str
    # Dot-separated pattern the handler was registered with.
    topic_pattern: str
    # Callable invoked as ``handler(topic, data)``.
    handler: Handler
    # Defaults to the creation time, timezone-aware UTC.
    created_at: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
|
|
|
|
|
|
@dataclass
class Event:
    """A single message published on the bus."""

    # Unique identifier of this event.
    id: str
    # Dot-separated topic the event was published to.
    topic: str
    # Arbitrary payload supplied by the publisher.
    data: Any
    # Defaults to the creation time, timezone-aware UTC.
    timestamp: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )

    def to_dict(self) -> dict:
        """Return the event as a plain dict with an ISO-8601 ``timestamp``.

        ``data`` is included as-is (not copied or serialized).
        """
        as_dict = dict(id=self.id, topic=self.topic, data=self.data)
        as_dict["timestamp"] = self.timestamp.isoformat()
        return as_dict
|
|
|
|
|
|
class EventBus:
    """
    Asyncio-based pub/sub event bus supporting wildcard topics.

    Topics are dot-separated (``"network.interfaces"``). Subscription
    patterns are matched one segment at a time:

    - ``*`` (or any fnmatch-style glob such as ``cp?``) matches exactly one
      segment and never crosses a ``.``;
    - a segment that is exactly ``**`` matches one or more whole segments.

    Usage::

        bus = EventBus()

        async def on_network(topic, data):
            print(f"{topic}: {data}")

        sub_id = bus.subscribe("network.*", on_network)
        await bus.publish("network.interfaces", {"eth0": "up"})
        bus.unsubscribe(sub_id)
    """

    def __init__(self, history_limit: int = 1000) -> None:
        """
        Create an empty bus.

        :param history_limit: maximum number of published events retained for
            :meth:`get_history`; the oldest events are dropped first.
        :raises ValueError: if *history_limit* is negative.
        """
        if history_limit < 0:
            raise ValueError(f"history_limit must be >= 0, got {history_limit}")
        self._subscriptions: Dict[str, Subscription] = {}
        # index from pattern to set of subscription ids for fast lookup
        self._pattern_index: Dict[str, Set[str]] = defaultdict(set)
        # A bounded deque evicts the oldest event in O(1), instead of
        # re-slicing a list on every publish once the limit is reached.
        self._history: deque[Event] = deque(maxlen=history_limit)
        self._history_limit: int = history_limit
        # Strong references to tasks spawned by publish_sync(): the event loop
        # keeps only weak references, so an otherwise unreferenced task may be
        # garbage-collected before it finishes.
        self._background_tasks: Set[asyncio.Task] = set()
        # NOTE(review): not used by any method of this class; kept in case
        # subclasses or external code rely on it.
        self._lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def subscribe(self, topic_pattern: str, handler: Handler) -> str:
        """
        Subscribe *handler* to all events whose topic matches *topic_pattern*.

        Patterns are matched one dot-separated segment at a time:

        - ``network.*`` matches ``network.interfaces`` but not ``network.dns.query``
        - ``network.**`` matches any subtopic under ``network`` (one or more
          further segments), but not ``network`` itself
        - ``*`` matches any single-segment topic
        - other fnmatch wildcards (``?``, ``[abc]``) also stay within one segment

        Returns a subscription ID that can be passed to :meth:`unsubscribe`.

        :raises TypeError: if *handler* is not callable (fail fast instead of
            erroring on every later publish).
        """
        if not callable(handler):
            raise TypeError(
                f"handler must be callable, got {type(handler).__name__}"
            )
        sub_id = str(uuid.uuid4())
        self._subscriptions[sub_id] = Subscription(
            id=sub_id, topic_pattern=topic_pattern, handler=handler
        )
        self._pattern_index[topic_pattern].add(sub_id)
        log.debug("subscribed %s -> pattern=%r", sub_id[:8], topic_pattern)
        return sub_id

    def unsubscribe(self, subscription_id: str) -> bool:
        """
        Remove a subscription by its ID.

        Returns ``True`` if the subscription existed and was removed.
        """
        sub = self._subscriptions.pop(subscription_id, None)
        if sub is None:
            return False
        ids = self._pattern_index.get(sub.topic_pattern)
        if ids is not None:
            ids.discard(subscription_id)
            # Drop patterns with no subscribers left so they are not matched.
            if not ids:
                del self._pattern_index[sub.topic_pattern]
        log.debug("unsubscribed %s", subscription_id[:8])
        return True

    def unsubscribe_all(self, handler: Handler) -> int:
        """
        Remove all subscriptions registered for *handler*. Returns count removed.

        Handlers are compared with ``==`` rather than ``is``. Every access to
        ``obj.method`` creates a new bound-method object, so an identity check
        would never match a bound method passed to :meth:`subscribe`. Bound
        methods compare equal when they wrap the same function and instance;
        plain functions still compare by identity.
        """
        to_remove = [
            sid for sid, sub in self._subscriptions.items() if sub.handler == handler
        ]
        for sid in to_remove:
            self.unsubscribe(sid)
        return len(to_remove)

    async def publish(self, topic: str, data: Any = None) -> int:
        """
        Publish an event to *topic*.

        The event is recorded in the history. All matching handlers are then
        dispatched concurrently via ``asyncio.gather``. A handler that raises
        does not stop the others; its exception is logged with its traceback.

        Returns the number of matching handlers (including any that raised).
        """
        event = Event(id=str(uuid.uuid4()), topic=topic, data=data)
        self._record(event)

        matching = self._find_matching_subs(topic)
        if not matching:
            log.debug("publish %r — no subscribers", topic)
            return 0

        results = await asyncio.gather(
            *(self._dispatch(sub, event) for sub in matching),
            return_exceptions=True,
        )
        for result in results:
            if isinstance(result, Exception):
                log.error(
                    "handler error on topic %r: %s", topic, result, exc_info=result
                )

        log.debug("publish %r — notified %d handlers", topic, len(matching))
        return len(matching)

    def publish_sync(self, topic: str, data: Any = None) -> None:
        """
        Publish from synchronous code.

        If an event loop is running in this thread, the publish is scheduled
        as a task on it (fire-and-forget) and this call returns immediately.
        Otherwise the publish runs to completion on a fresh loop via
        ``asyncio.run`` before this call returns.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop — run synchronously in a new one
            asyncio.run(self.publish(topic, data))
            return
        task = loop.create_task(self.publish(topic, data))
        # Hold a reference until the task is done (see __init__).
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def get_history(
        self,
        topic_filter: Optional[str] = None,
        limit: Optional[int] = 100,
    ) -> List[Event]:
        """
        Return recent events (oldest first), optionally filtered by topic pattern.

        :param topic_filter: pattern with the same syntax as :meth:`subscribe`;
            a falsy value disables filtering.
        :param limit: maximum number of most-recent events to return. ``None``
            returns every retained event; zero or a negative value returns an
            empty list.
        """
        # Guard explicitly: events[-0:] would be the whole list, not empty.
        if limit is not None and limit <= 0:
            return []
        if topic_filter:
            events = [
                e for e in self._history if self._topic_matches(e.topic, topic_filter)
            ]
        else:
            events = list(self._history)
        return events if limit is None else events[-limit:]

    @property
    def subscription_count(self) -> int:
        """Number of active subscriptions."""
        return len(self._subscriptions)

    @property
    def patterns(self) -> List[str]:
        """Distinct patterns that currently have at least one subscription."""
        return list(self._pattern_index.keys())

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _record(self, event: Event) -> None:
        """Append *event* to the history; the bounded deque evicts the oldest."""
        self._history.append(event)

    def _find_matching_subs(self, topic: str) -> List[Subscription]:
        """
        Return the subscriptions whose pattern matches *topic*.

        Results are in subscription order, so dispatch order is deterministic.
        Each distinct pattern is tested once, however many subscriptions share it.
        """
        matching_patterns = {
            pattern
            for pattern in self._pattern_index
            if self._topic_matches(topic, pattern)
        }
        if not matching_patterns:
            return []
        return [
            sub
            for sub in self._subscriptions.values()
            if sub.topic_pattern in matching_patterns
        ]

    @staticmethod
    def _topic_matches(topic: str, pattern: str) -> bool:
        """
        Match *topic* against *pattern*, one dot-separated segment at a time.

        - A pattern segment that is exactly ``**`` consumes one or more topic
          segments, so ``network.**`` matches ``network.interfaces.eth0`` but
          not ``network``.
        - Any other pattern segment is an fnmatch glob (``*``, ``?``, ``[seq]``)
          matched case-sensitively against exactly one topic segment, so ``*``
          never spans a ``.``.
        """
        if pattern == topic:
            return True
        # Running fnmatch over the whole string would let "*" match dots too,
        # so split into segments and match them pairwise instead.
        topic_parts = topic.split(".")
        pattern_parts = pattern.split(".")
        n_topic, n_pattern = len(topic_parts), len(pattern_parts)
        # ok[i][j] is True when topic_parts[i:] matches pattern_parts[j:].
        # Dynamic programming keeps several "**" segments polynomial rather
        # than exponential backtracking.
        ok = [[False] * (n_pattern + 1) for _ in range(n_topic + 1)]
        ok[n_topic][n_pattern] = True
        for i in range(n_topic - 1, -1, -1):
            for j in range(n_pattern - 1, -1, -1):
                segment = pattern_parts[j]
                if segment == "**":
                    # Consume topic segment i, then either stop or keep going.
                    ok[i][j] = ok[i + 1][j + 1] or ok[i + 1][j]
                else:
                    ok[i][j] = ok[i + 1][j + 1] and fnmatch.fnmatchcase(
                        topic_parts[i], segment
                    )
        return ok[0][0]

    @staticmethod
    async def _dispatch(sub: Subscription, event: Event) -> None:
        """
        Invoke one handler with ``(topic, data)``, awaiting its result if awaitable.

        Exceptions propagate to the caller; :meth:`publish` collects them via
        ``asyncio.gather(..., return_exceptions=True)``.
        """
        result = sub.handler(event.topic, event.data)
        # inspect.isawaitable also covers Futures, Tasks and custom objects
        # with __await__, which asyncio.iscoroutine would skip.
        if inspect.isawaitable(result):
            await result
|