Files

224 lines
7.2 KiB
Python

"""
Metro Warden Plugin Registry — discovers, loads, and manages plugins.
Plugins are discovered from the ``plugins/`` package hierarchy. Each plugin
module must expose a class that inherits from :class:`plugins.base.BasePlugin`.
The registry stores plugin metadata and lifecycle state.
"""
from __future__ import annotations
import importlib
import inspect
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum, auto
from typing import Dict, List, Optional, Type
# Module-level logger; every registry message is emitted under this module's dotted name.
log = logging.getLogger(__name__)
class PluginStatus(Enum):
    """Lifecycle state of a plugin as tracked by the registry.

    Members carry the consecutive integer values 1-5 in declaration order.
    """

    # Freshly found during discovery; no instance created yet.
    DISCOVERED = 1
    # Treated as "already loaded" by PluginRegistry.load(); the visible
    # registry code never assigns it itself.
    LOADED = 2
    # Instantiated and on_load() completed successfully.
    ACTIVE = 3
    # on_unload() was invoked (even if it raised) and the instance released.
    STOPPED = 4
    # Construction or on_load() raised; the record's ``error`` holds the text.
    ERROR = 5
@dataclass
class PluginRecord:
    """Metadata + runtime state for a single plugin.

    A record starts in ``PluginStatus.DISCOVERED``; ``instance``,
    ``loaded_at`` and ``error`` are filled in later as the plugin is loaded,
    unloaded, or fails.
    """

    name: str
    version: str
    description: str
    plugin_class: Type  # concrete plugin class that gets instantiated on load
    module_path: str  # dotted module path the class was discovered in
    status: PluginStatus = PluginStatus.DISCOVERED
    instance: Optional[object] = None  # live plugin object while loaded, else None
    error: Optional[str] = None  # str() of the most recent load failure, if any
    loaded_at: Optional[datetime] = None  # time of the last successful load
    tags: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return a plain-data snapshot of this record.

        ``plugin_class`` and ``instance`` are omitted; ``status`` is rendered
        as the enum member name and ``loaded_at`` as an ISO-8601 string (or
        ``None``). ``tags`` is returned as a fresh list so that consumers of
        the snapshot (e.g. bus subscribers) cannot mutate this record -- or
        any list it shares with the plugin class -- through it.
        """
        return {
            "name": self.name,
            "version": self.version,
            "description": self.description,
            "module_path": self.module_path,
            "status": self.status.name,
            "error": self.error,
            "loaded_at": self.loaded_at.isoformat() if self.loaded_at else None,
            # Copy instead of exposing the live list; tolerate a None value.
            "tags": list(self.tags or ()),
        }
# Built-in plugin module paths to scan on startup. Each one follows the
# ``plugins.<package>.plugin`` naming convention.
BUILTIN_PLUGIN_MODULES = [
    f"plugins.{package}.plugin"
    for package in ("network", "dns", "firewall", "system")
]
class PluginRegistry:
    """
    Discovers, instantiates, and manages the lifecycle of Metro Warden plugins.

    Usage::

        registry = PluginRegistry(bus=bus, state=state)
        registry.discover()
        registry.load_all()

    *bus* and *state* are passed unchanged to every plugin constructor. When a
    *bus* is given, the registry publishes ``registry.plugin.loaded`` and
    ``registry.plugin.unloaded`` events through its ``publish_sync(topic,
    data)`` method.
    """

    def __init__(self, bus=None, state=None) -> None:
        self._bus = bus
        self._state = state
        # Plugin name -> record, kept in discovery order.
        self._records: Dict[str, PluginRecord] = {}

    # ------------------------------------------------------------------
    # Discovery
    # ------------------------------------------------------------------

    def discover(self, extra_modules: Optional[List[str]] = None) -> int:
        """
        Scan built-in plugin modules (plus any *extra_modules*) and register
        discovered plugin classes.

        Modules that fail to import are logged and skipped. A plugin whose
        name is already registered is skipped (the first registration wins).

        Returns the number of newly discovered plugins.
        """
        modules = list(BUILTIN_PLUGIN_MODULES)
        if extra_modules:
            modules.extend(extra_modules)
        found = 0
        for module_path in modules:
            found += self._scan_module(module_path)
        log.info("discovery complete — %d plugins found", found)
        return found

    def _scan_module(self, module_path: str) -> int:
        """Import *module_path* and register any concrete BasePlugin subclasses.

        Returns how many new plugins were registered from this module. Any
        failure while importing the module is logged and yields 0, so one
        broken plugin module cannot abort discovery of the others.
        """
        from plugins.base import BasePlugin  # local import to avoid circular deps

        try:
            module = importlib.import_module(module_path)
        except ImportError as exc:
            log.warning("could not import plugin module %r: %s", module_path, exc)
            return 0
        except Exception:
            # Errors raised while executing the module body (SyntaxError,
            # RuntimeError, ...) are not ImportErrors; isolate them as well.
            log.exception("error while importing plugin module %r", module_path)
            return 0

        found = 0
        for _name, obj in inspect.getmembers(module, inspect.isclass):
            if (
                not issubclass(obj, BasePlugin)
                or obj is BasePlugin
                or inspect.isabstract(obj)
            ):
                continue
            record = PluginRecord(
                name=obj.name,
                version=obj.version,
                description=obj.description,
                plugin_class=obj,
                module_path=module_path,
                # Copy so the record never aliases the class-level ``tags`` list.
                tags=list(getattr(obj, "tags", None) or ()),
            )
            if record.name in self._records:
                log.debug("plugin %r already registered, skipping", record.name)
                continue
            self._records[record.name] = record
            log.debug("discovered plugin %r v%s", record.name, record.version)
            found += 1
        return found

    # ------------------------------------------------------------------
    # Loading / unloading
    # ------------------------------------------------------------------

    def load(self, name: str) -> bool:
        """Instantiate the named plugin and call its ``on_load()`` hook.

        Returns True when the plugin is loaded (now or already), False when
        the name is unknown or construction / ``on_load()`` raised. On
        failure the record's status becomes ERROR and ``error`` holds the
        exception text.
        """
        record = self._records.get(name)
        if record is None:
            log.error("plugin %r not found in registry", name)
            return False
        if record.status in (PluginStatus.ACTIVE, PluginStatus.LOADED):
            log.debug("plugin %r already loaded", name)
            return True
        try:
            instance = record.plugin_class(bus=self._bus, state=self._state)
            instance.on_load()
        except Exception as exc:
            record.status = PluginStatus.ERROR
            record.error = str(exc)
            log.exception("failed to load plugin %r: %s", name, exc)
            return False
        record.instance = instance
        record.status = PluginStatus.ACTIVE
        record.loaded_at = datetime.now(timezone.utc)
        record.error = None
        log.info("loaded plugin %r v%s", name, record.version)
        # Notify outside the try: a bus failure must not flag a running
        # plugin as ERROR (a later load() would then build a second instance
        # without ever unloading the first).
        self._notify_bus("registry.plugin.loaded", record.to_dict())
        return True

    def unload(self, name: str) -> bool:
        """Call the named plugin's ``on_unload()`` hook and mark it STOPPED.

        Returns False when the name is unknown or the plugin has no live
        instance. An exception from ``on_unload()`` is logged; the instance
        is still released and the status still set to STOPPED.
        """
        record = self._records.get(name)
        if record is None or record.instance is None:
            return False
        try:
            record.instance.on_unload()
        except Exception as exc:
            log.exception("error during unload of %r: %s", name, exc)
        record.instance = None
        record.status = PluginStatus.STOPPED
        log.info("unloaded plugin %r", name)
        self._notify_bus("registry.plugin.unloaded", record.to_dict())
        return True

    def load_all(self) -> Dict[str, bool]:
        """Load all discovered plugins. Returns a ``{name: success}`` mapping."""
        return {name: self.load(name) for name in list(self._records)}

    def unload_all(self) -> None:
        """Unload every plugin that currently has a live instance."""
        for name in list(self._records):
            self.unload(name)

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------

    def get(self, name: str) -> Optional[PluginRecord]:
        """Return the record registered under *name*, or None."""
        return self._records.get(name)

    def all_records(self) -> List[PluginRecord]:
        """Return every registered record, in discovery order."""
        return list(self._records.values())

    def active_plugins(self) -> List[PluginRecord]:
        """Return the records whose status is ACTIVE."""
        return [r for r in self._records.values() if r.status == PluginStatus.ACTIVE]

    def plugin_instance(self, name: str) -> Optional[object]:
        """Return the live instance of plugin *name*, or None if absent/unloaded."""
        record = self._records.get(name)
        return record.instance if record else None

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _notify_bus(self, topic: str, data: dict) -> None:
        """Publish *data* on *topic* when a bus is configured (best effort).

        Uses an explicit ``is None`` test: a bus object that happens to be
        falsy (e.g. defines ``__len__`` returning 0) must still receive
        events. Publishing errors are logged rather than propagated so they
        cannot corrupt registry state.
        """
        if self._bus is None:
            return
        try:
            self._bus.publish_sync(topic, data)
        except Exception:
            log.exception("failed to publish %r on plugin bus", topic)