mirror of
https://github.com/samjage/metro-warden.git
synced 2026-06-06 01:00:41 +00:00
224 lines
7.2 KiB
Python
224 lines
7.2 KiB
Python
"""
Metro Warden Plugin Registry — discovers, loads, and manages plugins.

Plugins are discovered from the ``plugins/`` package hierarchy. Each plugin
module must expose a class that inherits from :class:`plugins.base.BasePlugin`.
The registry stores plugin metadata and lifecycle state.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import importlib
|
|
import inspect
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from enum import Enum, auto
|
|
from typing import Dict, List, Optional, Type
|
|
|
|
# Module-level logger for registry events (discovery, load/unload, errors).
log = logging.getLogger(__name__)
|
|
|
|
|
|
class PluginStatus(Enum):
    """Lifecycle state of a plugin tracked by the registry."""

    # Values are written out explicitly; they are exactly what auto() assigns.
    DISCOVERED = 1  # found by discovery, not yet loaded
    LOADED = 2  # treated by the registry as "already loaded"
    ACTIVE = 3  # instantiated and its on_load() hook succeeded
    STOPPED = 4  # unloaded; its instance has been released
    ERROR = 5  # the last load attempt raised; see the record's error field
|
|
|
|
|
|
@dataclass
class PluginRecord:
    """Metadata + runtime state for a single plugin.

    The descriptive fields (name, version, description, plugin_class,
    module_path, tags) describe the discovered plugin class. ``status``,
    ``instance``, ``error`` and ``loaded_at`` track its lifecycle and are
    updated as the plugin is loaded and unloaded.
    """

    name: str
    version: str
    description: str
    plugin_class: Type  # the concrete plugin class to instantiate
    module_path: str  # dotted module path the class was discovered in
    status: PluginStatus = PluginStatus.DISCOVERED
    instance: Optional[object] = None  # live plugin object while loaded, else None
    error: Optional[str] = None  # message of the most recent load failure, if any
    loaded_at: Optional[datetime] = None  # timestamp of the last successful load
    tags: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return a serialisable snapshot of this record.

        ``plugin_class`` and ``instance`` are omitted. ``status`` is rendered
        by enum name and ``loaded_at`` as an ISO-8601 string (or None).

        ``tags`` is copied, so mutating the returned dict cannot alter this
        record. The record's list may be shared with the plugin class that
        declared it, so this also protects the class attribute.
        """
        return {
            "name": self.name,
            "version": self.version,
            "description": self.description,
            "module_path": self.module_path,
            "status": self.status.name,
            "error": self.error,
            "loaded_at": (
                self.loaded_at.isoformat() if self.loaded_at is not None else None
            ),
            "tags": list(self.tags),
        }
|
|
|
|
|
|
# Built-in plugin module paths to scan on startup.
# PluginRegistry.discover() scans these first, then any extra_modules supplied
# by the caller; each entry is a dotted path passed to importlib.import_module().
BUILTIN_PLUGIN_MODULES = [
    "plugins.network.plugin",
    "plugins.dns.plugin",
    "plugins.firewall.plugin",
    "plugins.system.plugin",
]
|
|
|
|
|
|
class PluginRegistry:
    """
    Discovers, instantiates, and manages the lifecycle of Metro Warden plugins.

    Plugins are keyed by their class-level ``name``. If two classes share a
    name, the first one registered wins. When a ``bus`` is supplied,
    load/unload events are published to it via ``publish_sync``; publishing
    is best-effort and never breaks the registry operation itself.

    Usage::

        registry = PluginRegistry(bus=bus, state=state)
        registry.discover()
        registry.load_all()
    """

    def __init__(self, bus=None, state=None) -> None:
        # ``bus`` and ``state`` are opaque to the registry. Both are forwarded
        # to each plugin's constructor; ``bus`` is also used for notifications.
        self._bus = bus
        self._state = state
        # Insertion-ordered, so load_all()/unload_all() follow discovery order.
        self._records: Dict[str, PluginRecord] = {}

    # ------------------------------------------------------------------
    # Discovery
    # ------------------------------------------------------------------

    def discover(self, extra_modules: Optional[List[str]] = None) -> int:
        """
        Scan built-in plugin modules (plus any *extra_modules*) and register
        discovered plugin classes.

        Modules that fail to import are logged and skipped. Names already in
        the registry are skipped too, so repeated calls are safe.

        Returns the number of newly discovered plugins.
        """
        modules = list(BUILTIN_PLUGIN_MODULES)
        if extra_modules:
            modules.extend(extra_modules)

        found = sum(self._scan_module(module_path) for module_path in modules)

        log.info("discovery complete — %d plugins found", found)
        return found

    def _scan_module(self, module_path: str) -> int:
        """Import *module_path* and register any concrete BasePlugin subclasses.

        Returns the number of plugins newly added to the registry. A module
        that raises while importing contributes zero plugins, so a single
        broken plugin cannot abort discovery of the others.
        """
        from plugins.base import BasePlugin  # local import to avoid circular deps

        try:
            module = importlib.import_module(module_path)
        except ImportError as exc:
            log.warning("could not import plugin module %r: %s", module_path, exc)
            return 0
        except Exception:
            # Import-time errors in plugin code (SyntaxError, failing top-level
            # statements, ...) would otherwise propagate out of discover().
            log.exception("error while importing plugin module %r", module_path)
            return 0

        found = 0
        for _name, obj in inspect.getmembers(module, inspect.isclass):
            if (
                not issubclass(obj, BasePlugin)
                or obj is BasePlugin
                or inspect.isabstract(obj)
            ):
                continue
            if obj.name in self._records:
                log.debug("plugin %r already registered, skipping", obj.name)
                continue
            self._records[obj.name] = PluginRecord(
                name=obj.name,
                version=obj.version,
                description=obj.description,
                plugin_class=obj,
                module_path=module_path,
                # Copy so the record never shares (and mutates) the class attribute.
                tags=list(getattr(obj, "tags", [])),
            )
            log.debug("discovered plugin %r v%s", obj.name, obj.version)
            found += 1

        return found

    # ------------------------------------------------------------------
    # Loading / unloading
    # ------------------------------------------------------------------

    def load(self, name: str) -> bool:
        """Instantiate the named plugin and call its ``on_load()`` hook.

        Returns True if the plugin ends up loaded, including when it already
        was. Returns False if the name is unknown, or if construction or
        ``on_load()`` raised. In the failure case the record is marked ERROR
        and ``record.error`` holds the exception message.
        """
        record = self._records.get(name)
        if record is None:
            log.error("plugin %r not found in registry", name)
            return False

        if record.status in (PluginStatus.ACTIVE, PluginStatus.LOADED):
            log.debug("plugin %r already loaded", name)
            return True

        try:
            instance = record.plugin_class(bus=self._bus, state=self._state)
            instance.on_load()
        except Exception as exc:
            record.status = PluginStatus.ERROR
            record.error = str(exc)
            log.exception("failed to load plugin %r: %s", name, exc)
            return False

        record.instance = instance
        record.status = PluginStatus.ACTIVE
        record.loaded_at = datetime.now(timezone.utc)
        record.error = None
        log.info("loaded plugin %r v%s", name, record.version)
        # Deliberately outside the try: a notification problem must not flag an
        # already-loaded plugin as ERROR while its instance stays live.
        self._notify_bus("registry.plugin.loaded", record.to_dict())
        return True

    def unload(self, name: str) -> bool:
        """Call ``on_unload()`` and deactivate the named plugin.

        Returns False if the plugin is unknown or has no live instance;
        otherwise returns True. The plugin is marked STOPPED even when its
        ``on_unload()`` hook raises (the error is logged).
        """
        record = self._records.get(name)
        if record is None or record.instance is None:
            return False

        try:
            record.instance.on_unload()
        except Exception as exc:
            log.exception("error during unload of %r: %s", name, exc)

        record.instance = None
        record.status = PluginStatus.STOPPED
        log.info("unloaded plugin %r", name)
        self._notify_bus("registry.plugin.unloaded", record.to_dict())
        return True

    def load_all(self) -> Dict[str, bool]:
        """Load all discovered plugins. Returns {name: success} mapping."""
        # Snapshot the names in case a plugin's on_load() touches the registry.
        return {name: self.load(name) for name in list(self._records)}

    def unload_all(self) -> None:
        """Unload every plugin that currently has a live instance."""
        for name in list(self._records):
            self.unload(name)

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------

    def get(self, name: str) -> Optional[PluginRecord]:
        """Return the record registered under *name*, or None."""
        return self._records.get(name)

    def all_records(self) -> List[PluginRecord]:
        """Return every registered record, in registration order."""
        return list(self._records.values())

    def active_plugins(self) -> List[PluginRecord]:
        """Return the records whose status is ACTIVE."""
        return [r for r in self._records.values() if r.status is PluginStatus.ACTIVE]

    def plugin_instance(self, name: str) -> Optional[object]:
        """Return the live instance of plugin *name*, or None if not loaded/unknown."""
        record = self._records.get(name)
        return record.instance if record else None

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _notify_bus(self, topic: str, data: dict) -> None:
        """Publish *data* on *topic* if a bus was supplied.

        Notifications are best-effort. A failing bus is logged rather than
        allowed to break the load/unload operation that triggered it.
        """
        if self._bus is None:
            return
        try:
            self._bus.publish_sync(topic, data)
        except Exception:
            log.exception("failed to publish %r on plugin bus", topic)
|