"""
Metro Warden Plugin Registry — discovers, loads, and manages plugins.

Plugins are discovered from the ``plugins/`` package hierarchy.  Each plugin
module must expose a class that inherits from :class:`plugins.base.BasePlugin`.
The registry stores plugin metadata and lifecycle state.
"""

from __future__ import annotations

import importlib
import inspect
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum, auto
from typing import Dict, List, Optional, Type

log = logging.getLogger(__name__)


class PluginStatus(Enum):
    """Lifecycle state of a registered plugin."""

    DISCOVERED = auto()  # class found during discovery, not yet instantiated
    LOADED = auto()  # NOTE(review): never assigned in this module; load() goes straight to ACTIVE
    ACTIVE = auto()  # instance created and on_load() returned without raising
    STOPPED = auto()  # unload() ran and the instance reference was dropped
    ERROR = auto()  # instantiation or on_load() raised; see PluginRecord.error


@dataclass
class PluginRecord:
    """Metadata + runtime state for a single plugin."""

    name: str
    version: str
    description: str
    plugin_class: Type
    module_path: str  # dotted path of the module the class was discovered in
    status: PluginStatus = PluginStatus.DISCOVERED
    instance: Optional[object] = None  # live plugin object while loaded, else None
    error: Optional[str] = None  # str() of the most recent load failure
    loaded_at: Optional[datetime] = None  # UTC time of the most recent successful load
    tags: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return a serialisable snapshot of this record.

        ``plugin_class`` and ``instance`` are omitted.  ``tags`` is copied so
        that whoever receives the snapshot (e.g. bus subscribers) cannot
        mutate the record's own list.
        """
        return {
            "name": self.name,
            "version": self.version,
            "description": self.description,
            "module_path": self.module_path,
            "status": self.status.name,
            "error": self.error,
            "loaded_at": self.loaded_at.isoformat() if self.loaded_at else None,
            "tags": list(self.tags),
        }


# Built-in plugin module paths to scan on startup
BUILTIN_PLUGIN_MODULES = [
    "plugins.network.plugin",
    "plugins.dns.plugin",
    "plugins.firewall.plugin",
    "plugins.system.plugin",
]


class PluginRegistry:
    """
    Discovers, instantiates, and manages the lifecycle of Metro Warden plugins.

    Usage::

        registry = PluginRegistry(bus=bus, state=state)
        registry.discover()
        registry.load_all()
    """

    def __init__(self, bus=None, state=None) -> None:
        """Create an empty registry.

        *bus* and *state* are passed as keyword arguments to every plugin
        constructor.  *bus*, when not None, also receives lifecycle
        notifications through ``bus.publish_sync(topic, data)``.
        """
        self._bus = bus
        self._state = state
        self._records: Dict[str, PluginRecord] = {}

    # ------------------------------------------------------------------
    # Discovery
    # ------------------------------------------------------------------

    def discover(self, extra_modules: Optional[List[str]] = None) -> int:
        """
        Scan built-in plugin modules (plus any *extra_modules*) and register
        discovered plugin classes.

        Returns the number of newly discovered plugins.  Plugins whose name
        is already registered are skipped and not counted.
        """
        modules = list(BUILTIN_PLUGIN_MODULES)
        if extra_modules:
            modules.extend(extra_modules)

        found = 0
        for module_path in modules:
            found += self._scan_module(module_path)

        log.info("discovery complete — %d plugins found", found)
        return found

    def _scan_module(self, module_path: str) -> int:
        """Import *module_path* and register any concrete BasePlugin subclasses.

        A module that fails to import is logged and contributes zero plugins,
        so one broken plugin module cannot abort discovery of the others.
        Returns the number of records added.
        """
        from plugins.base import BasePlugin  # local import to avoid circular deps

        try:
            module = importlib.import_module(module_path)
        except ImportError as exc:
            log.warning("could not import plugin module %r: %s", module_path, exc)
            return 0
        except Exception:
            # Errors raised by the plugin module's own top-level code
            # (SyntaxError, NameError, ...) are treated like a missing module.
            log.exception("error while importing plugin module %r", module_path)
            return 0

        found = 0
        for _name, obj in inspect.getmembers(module, inspect.isclass):
            if (
                not issubclass(obj, BasePlugin)
                or obj is BasePlugin
                or inspect.isabstract(obj)
            ):
                continue

            # Copy the class-level tags so records (and the snapshots they
            # publish) never alias the plugin class's own list.
            raw_tags = getattr(obj, "tags", None) or ()
            tags = [raw_tags] if isinstance(raw_tags, str) else list(raw_tags)

            record = PluginRecord(
                name=obj.name,
                version=obj.version,
                description=obj.description,
                plugin_class=obj,
                module_path=module_path,
                tags=tags,
            )
            if record.name in self._records:
                log.debug("plugin %r already registered, skipping", record.name)
                continue
            self._records[record.name] = record
            log.debug("discovered plugin %r v%s", record.name, record.version)
            found += 1
        return found

    # ------------------------------------------------------------------
    # Loading / unloading
    # ------------------------------------------------------------------

    def load(self, name: str) -> bool:
        """Instantiate and call on_load() for the named plugin.

        Returns True if the plugin is active after the call (including when
        it was already ACTIVE or LOADED).  Returns False if *name* is unknown,
        or if the constructor or ``on_load()`` raised.  In the latter case
        the record is marked ERROR and ``record.error`` holds the message.
        """
        record = self._records.get(name)
        if record is None:
            log.error("plugin %r not found in registry", name)
            return False

        if record.status in (PluginStatus.ACTIVE, PluginStatus.LOADED):
            log.debug("plugin %r already loaded", name)
            return True

        # Only the plugin's own code belongs in the try: a failure in the
        # post-load notification must not mark a loaded plugin as ERROR.
        try:
            instance = record.plugin_class(bus=self._bus, state=self._state)
            instance.on_load()
        except Exception as exc:
            record.status = PluginStatus.ERROR
            record.error = str(exc)
            log.error("failed to load plugin %r: %s", name, exc, exc_info=True)
            return False

        record.instance = instance
        record.status = PluginStatus.ACTIVE
        record.loaded_at = datetime.now(timezone.utc)
        record.error = None
        log.info("loaded plugin %r v%s", name, record.version)
        self._notify_bus("registry.plugin.loaded", record.to_dict())
        return True

    def unload(self, name: str) -> bool:
        """Call on_unload() and deactivate the named plugin.

        Returns False if *name* is unknown or has no live instance.  An
        exception from ``on_unload()`` is logged.  The plugin is then still
        marked STOPPED and its instance is dropped.
        """
        record = self._records.get(name)
        if record is None or record.instance is None:
            return False

        try:
            record.instance.on_unload()
        except Exception as exc:
            log.error("error during unload of %r: %s", name, exc)

        record.instance = None
        record.status = PluginStatus.STOPPED
        log.info("unloaded plugin %r", name)
        self._notify_bus("registry.plugin.unloaded", record.to_dict())
        return True

    def load_all(self) -> Dict[str, bool]:
        """Load all discovered plugins. Returns {name: success} mapping."""
        # Snapshot the keys so the loop is independent of dict mutation.
        return {name: self.load(name) for name in list(self._records)}

    def unload_all(self) -> None:
        """Unload every active plugin (records without an instance are skipped)."""
        for name in list(self._records):
            self.unload(name)

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------

    def get(self, name: str) -> Optional[PluginRecord]:
        """Return the record registered under *name*, or None."""
        return self._records.get(name)

    def all_records(self) -> List[PluginRecord]:
        """Return every registered record, in registration order."""
        return list(self._records.values())

    def active_plugins(self) -> List[PluginRecord]:
        """Return the records whose status is ACTIVE."""
        return [r for r in self._records.values() if r.status == PluginStatus.ACTIVE]

    def plugin_instance(self, name: str):
        """Return the live instance for *name*, or None if unknown / not loaded."""
        record = self._records.get(name)
        return record.instance if record else None

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _notify_bus(self, topic: str, data: dict) -> None:
        """Publish *data* on *topic* if a bus was supplied.

        Best-effort: a publish failure is logged rather than raised.  This
        keeps a broken bus from corrupting plugin state or aborting
        ``unload_all()`` partway through.
        """
        if self._bus is None:
            return
        try:
            self._bus.publish_sync(topic, data)
        except Exception:
            log.exception("failed to publish %r to bus", topic)