From 4d6e21fb9934354cf4234e1f00d91a3618eb7961 Mon Sep 17 00:00:00 2001 From: graepaul_amdeng Date: Thu, 25 Jun 2026 10:48:18 -0700 Subject: [PATCH 1/4] Quick self-review and adjustments --- nodescraper/models/analyzerargs.py | 15 +- nodescraper/pluginrecipe/__init__.py | 16 -- nodescraper/pluginrecipe/all_plugins.py | 4 +- nodescraper/pluginrecipe/discovery.py | 228 ++++++++++++--------- nodescraper/pluginrecipe/node_status.py | 4 +- nodescraper/pluginrecipe/pluginrecipe.py | 38 ++-- nodescraper/pluginregistry.py | 244 +++++++++++++++++------ 7 files changed, 350 insertions(+), 199 deletions(-) diff --git a/nodescraper/models/analyzerargs.py b/nodescraper/models/analyzerargs.py index f1782801..b73ee7b8 100644 --- a/nodescraper/models/analyzerargs.py +++ b/nodescraper/models/analyzerargs.py @@ -25,7 +25,7 @@ ############################################################################### from typing import Any -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, ConfigDict, model_serializer, model_validator class AnalyzerArgs(BaseModel): @@ -37,7 +37,16 @@ class AnalyzerArgs(BaseModel): """ - model_config = {"extra": "forbid", "exclude_none": True} + model_config = ConfigDict(extra="forbid") + + @model_serializer + def serialize_exclude_none(self) -> dict: + """Serialize the model to a dictionary, excluding None values. + + Returns: + A dictionary representation of the model with None values excluded. + """ + return self.model_dump(exclude_none=True) @model_validator(mode="before") @classmethod @@ -89,5 +98,5 @@ def build_from_model(cls, datamodel): NotImplementedError: Not implemented error """ raise NotImplementedError( - "Setting analyzer args from datamodel is not implemented for class: %s", cls.__name__ + f"Setting analyzer args from datamodel is not implemented for class: {cls.__name__}", ) diff --git a/nodescraper/pluginrecipe/__init__.py b/nodescraper/pluginrecipe/__init__.py index b37e91fd..b28ca993 100644 --- a/nodescraper/pluginrecipe/__init__.py +++ b/nodescraper/pluginrecipe/__init__.py @@ -8,15 +8,6 @@ from nodescraper.models import PluginConfig from .all_plugins import AllPlugins -from .discovery import ( - load_plugin_class, - plugin_has_analyzer, - plugin_has_collector, - plugin_names_matching, - plugins_with_analyzer, - plugins_with_collector, - registered_plugin_names, -) from .node_status import NodeStatus from .pluginrecipe import ( ANALYZE_ONLY, @@ -40,12 +31,5 @@ "PluginConfig", "PluginRecipe", "PluginRunFlags", - "load_plugin_class", "merge_plugin_configs", - "plugin_has_analyzer", - "plugin_has_collector", - "plugin_names_matching", - "plugins_with_analyzer", - "plugins_with_collector", - "registered_plugin_names", ] diff --git a/nodescraper/pluginrecipe/all_plugins.py b/nodescraper/pluginrecipe/all_plugins.py index 491ced5d..948c12e7 100644 --- a/nodescraper/pluginrecipe/all_plugins.py +++ b/nodescraper/pluginrecipe/all_plugins.py @@ -7,7 +7,7 @@ ############################################################################### from __future__ import annotations -from .discovery import registered_plugin_names +from .discovery import PluginDiscovery from .pluginrecipe import PluginRecipe @@ -21,4 +21,4 @@ def plugin_names(cls) -> tuple[str, ...]: Returns: tuple[str, ...]: Sorted names of all plugins in the plugin registry. """ - return registered_plugin_names() + return PluginDiscovery().registered_plugin_names() diff --git a/nodescraper/pluginrecipe/discovery.py b/nodescraper/pluginrecipe/discovery.py index be0eaa03..0abcd1f6 100644 --- a/nodescraper/pluginrecipe/discovery.py +++ b/nodescraper/pluginrecipe/discovery.py @@ -4,105 +4,153 @@ # # Copyright (c) 2025 Advanced Micro Devices, Inc. # +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ############################################################################### -from __future__ import annotations - -from typing import Iterable - - -def registered_plugin_names() -> tuple[str, ...]: - """Return all plugin names known to :class:`~nodescraper.pluginregistry.PluginRegistry`. - Returns: - tuple[str, ...]: Sorted registered plugin names. - """ - from nodescraper.pluginregistry import PluginRegistry - - return tuple(sorted(PluginRegistry().plugins)) - - -def plugin_names_matching(names: Iterable[str]) -> tuple[str, ...]: - """Return plugin names from ``names`` that are registered at runtime. - - Args: - names (Iterable[str]): Candidate plugin names to filter. - - Returns: - tuple[str, ...]: Sorted subset of ``names`` present in the plugin registry. - """ - available = set(registered_plugin_names()) - return tuple(sorted(name for name in names if name in available)) - - -def load_plugin_class(plugin_name: str) -> type | None: - """Return a registered plugin class by name. - - Args: - plugin_name (str): Registered plugin name. - - Returns: - type | None: Plugin class, or ``None`` if the name is not registered. - """ - from nodescraper.pluginregistry import PluginRegistry - - return PluginRegistry().plugins.get(plugin_name) - - -def plugin_has_collector(plugin_name: str) -> bool: - """Return whether the plugin exposes a collector task. +from __future__ import annotations - Args: - plugin_name (str): Registered plugin name. +import threading +from typing import TYPE_CHECKING, Iterable - Returns: - bool: ``True`` when the plugin class defines ``COLLECTOR``. - """ - plugin_class = load_plugin_class(plugin_name) - if plugin_class is None: - return False - collectors = getattr(plugin_class, "get_collector_classes", None) - if callable(collectors): - return bool(collectors()) - collector = getattr(plugin_class, "COLLECTOR", None) - if collector is None: - return False - if isinstance(collector, (tuple, list)): - return len(collector) > 0 - return True - - -def plugin_has_analyzer(plugin_name: str) -> bool: - """Return whether the plugin exposes an analyzer task. - - Args: - plugin_name (str): Registered plugin name. - - Returns: - bool: ``True`` when the plugin class defines ``ANALYZER``. - """ - plugin_class = load_plugin_class(plugin_name) - return plugin_class is not None and getattr(plugin_class, "ANALYZER", None) is not None +from nodescraper.pluginregistry import PluginRegistry +if TYPE_CHECKING: + from nodescraper.interfaces import PluginInterface -def plugins_with_collector(plugin_names: Iterable[str]) -> tuple[str, ...]: - """Filter plugin names to those that define a collector. - Args: - plugin_names (Iterable[str]): Candidate plugin names. +class PluginDiscovery: + """Allows for the discovery of plugins and their capabilities. These external plugins must be + registered with the :class:`~nodescraper.pluginregistry.PluginRegistry` before they can be discovered. - Returns: - tuple[str, ...]: Sorted plugin names with a ``COLLECTOR`` implementation. + This class can use a cache to avoid repeated PluginRegistry lookups, which can be expensive. + If use_cache is False, it will query the PluginRegistry each time. """ - return tuple(sorted(name for name in plugin_names if plugin_has_collector(name))) - -def plugins_with_analyzer(plugin_names: Iterable[str]) -> tuple[str, ...]: - """Filter plugin names to those that define an analyzer. + _plugin_cache: dict[str, type[PluginInterface]] | None = None + _cache_lock = threading.Lock() + COLLECTOR_ATTRIBUTE = "COLLECTOR" + ANALYZER_ATTRIBUTE = "ANALYZER" - Args: - plugin_names (Iterable[str]): Candidate plugin names. + def __init__(self, use_cache: bool = True) -> None: + """Initialize the PluginDiscovery instance. - Returns: - tuple[str, ...]: Sorted plugin names with an ``ANALYZER`` implementation. - """ - return tuple(sorted(name for name in plugin_names if plugin_has_analyzer(name))) + Args: + use_cache: If True, cache plugin lookups to improve performance. Defaults to True. + """ + self._use_cache = use_cache + + def load_plugin_class(self, plugin_name: str) -> type | None: + if not self._use_cache: + return PluginRegistry().plugins.get(plugin_name) + + if self._plugin_cache is None: + with self._cache_lock: + if self._plugin_cache is None: + self._plugin_cache = PluginRegistry().plugins + + return self._plugin_cache.get(plugin_name) + + def plugin_has_collector(self, plugin_name: str) -> bool: + """Check if a plugin has a COLLECTOR attribute. + + Args: + plugin_name: The name of the plugin to check. + + Returns: + True if the plugin exists and has a COLLECTOR attribute, False otherwise. + """ + plugin_class = self.load_plugin_class(plugin_name) + return ( + plugin_class is not None + and getattr(plugin_class, self.COLLECTOR_ATTRIBUTE, None) is not None + ) + + def plugin_has_analyzer(self, plugin_name: str) -> bool: + """Check if a plugin has an ANALYZER attribute. + + Args: + plugin_name: The name of the plugin to check. + + Returns: + True if the plugin exists and has an ANALYZER attribute, False otherwise. + """ + plugin_class = self.load_plugin_class(plugin_name) + return ( + plugin_class is not None + and getattr(plugin_class, self.ANALYZER_ATTRIBUTE, None) is not None + ) + + def plugins_with_collector(self, plugin_names: Iterable[str]) -> tuple[str, ...]: + """Filter a list of plugin names to those that have a COLLECTOR attribute. + + Args: + plugin_names: An iterable of plugin names to filter. + + Returns: + A sorted tuple of plugin names that have a COLLECTOR attribute. + """ + return tuple(sorted(name for name in plugin_names if self.plugin_has_collector(name))) + + def plugins_with_analyzer(self, plugin_names: Iterable[str]) -> tuple[str, ...]: + """Filter a list of plugin names to those that have an ANALYZER attribute. + + Args: + plugin_names: An iterable of plugin names to filter. + + Returns: + A sorted tuple of plugin names that have an ANALYZER attribute. + """ + return tuple(sorted(name for name in plugin_names if self.plugin_has_analyzer(name))) + + def clear_cache(self) -> None: + """Clears the plugin cache, forcing future lookups to query the PluginRegistry again. + + Thread-safe: Acquires the cache lock to ensure no other thread is accessing the cache. + """ + with self._cache_lock: + self._plugin_cache = None + + def registered_plugin_names(self) -> tuple[str, ...]: + """Return all plugin names known to :class:`~nodescraper.pluginregistry.PluginRegistry`. + + Returns: + tuple[str, ...]: Sorted registered plugin names. + """ + if not self._use_cache: + return tuple(sorted(PluginRegistry().plugins.keys())) + + if self._plugin_cache is None: + with self._cache_lock: + if self._plugin_cache is None: + self._plugin_cache = PluginRegistry().plugins + + return tuple(sorted(self._plugin_cache.keys())) + + def plugin_names_matching(self, names: Iterable[str]) -> tuple[str, ...]: + """Return plugin names from ``names`` that are registered at runtime. + + Args: + names (Iterable[str]): Candidate plugin names to filter. + + Returns: + tuple[str, ...]: Sorted subset of ``names`` present in the plugin registry. + """ + available = set(self.registered_plugin_names()) + return tuple(sorted(name for name in names if name in available)) diff --git a/nodescraper/pluginrecipe/node_status.py b/nodescraper/pluginrecipe/node_status.py index 42be8d01..d0237758 100644 --- a/nodescraper/pluginrecipe/node_status.py +++ b/nodescraper/pluginrecipe/node_status.py @@ -7,7 +7,7 @@ ############################################################################### from __future__ import annotations -from .discovery import plugin_names_matching +from .discovery import PluginDiscovery from .pluginrecipe import PluginRecipe _NODE_STATUS_PLUGINS = ( @@ -35,4 +35,4 @@ def plugin_names(cls) -> tuple[str, ...]: Returns: tuple[str, ...]: Sorted node-status plugin names registered in the plugin registry. """ - return plugin_names_matching(_NODE_STATUS_PLUGINS) + return PluginDiscovery().plugin_names_matching(_NODE_STATUS_PLUGINS) diff --git a/nodescraper/pluginrecipe/pluginrecipe.py b/nodescraper/pluginrecipe/pluginrecipe.py index e9b5a3ba..5c9bc34e 100644 --- a/nodescraper/pluginrecipe/pluginrecipe.py +++ b/nodescraper/pluginrecipe/pluginrecipe.py @@ -5,27 +5,29 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. # ############################################################################### + from __future__ import annotations import abc -from dataclasses import dataclass from typing import Any, Iterable +from pydantic import BaseModel, ConfigDict + from nodescraper.models import PluginConfig +from nodescraper.pluginrecipe.discovery import PluginDiscovery -@dataclass(frozen=True) -class PluginRunFlags: - """Collection/analysis toggles passed to nodescraper ``DataPlugin.run``.""" +class PluginRunFlags(BaseModel): + model_config = ConfigDict(frozen=True) collection: bool = True analysis: bool = True def as_config(self) -> dict[str, bool]: - """Return nodescraper per-plugin config fields. + """Return collection and analysis fields for one plugin entry. Returns: - dict[str, bool]: ``collection`` and ``analysis`` entries for a plugin config. + dict[str, bool]: ``collection`` and ``analysis`` entries for a plugin entry. """ return {"collection": self.collection, "analysis": self.analysis} @@ -71,7 +73,7 @@ def description(cls) -> str: @classmethod def flags_for_plugin(cls, plugin_name: str) -> PluginRunFlags: - """Return collection/analysis flags for one plugin. + """Return collection and analysis flags for the given plugin. Args: plugin_name (str): Registered plugin name. @@ -91,18 +93,17 @@ def extra_plugin_args(cls, plugin_name: str) -> dict[str, Any]: Returns: dict[str, Any]: Extra plugin config kwargs merged into the entry dict. """ - del plugin_name + _plugin_name = plugin_name # Avoid unused variable warning return {} @classmethod def plugin_entry(cls, plugin_name: str) -> dict[str, Any]: - """Build the nodescraper plugin config entry for one plugin. - + """Build the per-plugin config entry for one plugin.plugins: type[PluginInterface] | None Args: plugin_name (str): Registered plugin name. Returns: - dict[str, Any]: Per-plugin config passed to node-scraper. + dict[str, Any]: Per-plugin config for a single plugin. """ entry: dict[str, Any] = dict(cls.flags_for_plugin(plugin_name).as_config()) entry.update(cls.extra_plugin_args(plugin_name)) @@ -110,11 +111,10 @@ def plugin_entry(cls, plugin_name: str) -> dict[str, Any]: @classmethod def plugin_config(cls) -> PluginConfig: - """Build a node-scraper plugin config at runtime. + """Build the full plugin config for this recipe. Returns: - PluginConfig: Plugin config with ``name``, ``desc``, ``global_args``, - ``plugins``, and ``result_collators`` fields. + PluginConfig: Config with recipe name, description, and per-plugin entries. """ return PluginConfig( name=cls.name(), @@ -138,11 +138,10 @@ def filter_plugin_names(cls, names: Iterable[str]) -> tuple[str, ...]: names (Iterable[str]): Candidate plugin names. Returns: - tuple[str, ...]: Sorted names with a ``COLLECTOR`` implementation. + tuple[str, ...]: Sorted names that implement collection. """ - from .discovery import plugins_with_collector - return plugins_with_collector(names) + return PluginDiscovery().plugins_with_collector(names) class AnalyzerOnlyPluginRecipe(PluginRecipe): @@ -158,11 +157,10 @@ def filter_plugin_names(cls, names: Iterable[str]) -> tuple[str, ...]: names (Iterable[str]): Candidate plugin names. Returns: - tuple[str, ...]: Sorted names with an ``ANALYZER`` implementation. + tuple[str, ...]: Sorted names that implement analysis. """ - from .discovery import plugins_with_analyzer - return plugins_with_analyzer(names) + return PluginDiscovery().plugins_with_analyzer(names) def merge_plugin_configs(*configs: PluginConfig | dict[str, Any]) -> PluginConfig: diff --git a/nodescraper/pluginregistry.py b/nodescraper/pluginregistry.py index 5abc2f84..32b1c5a3 100644 --- a/nodescraper/pluginregistry.py +++ b/nodescraper/pluginregistry.py @@ -27,7 +27,9 @@ import importlib.metadata import inspect import pkgutil +import threading import types +from importlib.metadata import EntryPoints from typing import Iterable, Optional import nodescraper.connection as internal_connections @@ -39,8 +41,32 @@ PluginResultCollator, ) +# Entry point group names +ENTRY_POINT_PLUGINS = "nodescraper.plugins" +ENTRY_POINT_CONNECTION_MANAGERS = "nodescraper.connection_managers" + class PluginRegistry: + """This class dynamically loads plugins. Internal plugins are loaded by default using + the ``nodescraper.plugins``, ``nodescraper.connection``, and ``nodescraper.resultcollators`` packages. + A caller of node-scraper can also specify entry points for plugins and connection managers. The + user could also define entrypoints which ``nodescraper.connection_managers`` or ``nodescraper.plugins`` + entry point groups. The PluginRegistry will load these plugins and connection managers as well. + """ + + # Class-level caches for entry points (shared across all instances) + _entry_point_plugins_cache: Optional[dict[str, type]] = None + _entry_point_connection_managers_cache: Optional[dict[str, type]] = None + # Cache for loaded modules to avoid re-importing + _module_cache: dict[str, types.ModuleType] = {} + # Cache for entry points by group name + _entry_points_cache: dict[str, EntryPoints] = {} + + # Global cache control switch + _use_cache: bool = True + + # Single lock for all cache operations to ensure atomicity + _cache_lock = threading.RLock() def __init__( self, @@ -59,7 +85,11 @@ def __init__( ``nodescraper.connection_managers`` entry-point group. Defaults to True. """ if load_internal_plugins: - self.plugin_pkg = [internal_plugins, internal_connections, internal_collators] + self.plugin_pkg = [ + internal_plugins, + internal_connections, + internal_collators, + ] else: self.plugin_pkg = [] @@ -105,12 +135,22 @@ def load_plugins( def _recurse_pkg(pkg: types.ModuleType, base_class: type) -> None: for _, module_name, ispkg in pkgutil.iter_modules(pkg.__path__, pkg.__name__ + "."): - module = importlib.import_module(module_name) + # Check module cache first with thread safety (if caching enabled) + if PluginRegistry._use_cache: + with PluginRegistry._cache_lock: + if module_name in PluginRegistry._module_cache: + module = PluginRegistry._module_cache[module_name] + else: + module = importlib.import_module(module_name) + PluginRegistry._module_cache[module_name] = module + else: + module = importlib.import_module(module_name) + for _, plugin in inspect.getmembers( module, - lambda x: inspect.isclass(x) - and issubclass(x, base_class) - and not inspect.isabstract(x), + lambda x: PluginRegistry._valid_sub_class_check( + in_cls=x, base_class=base_class + ), ): if hasattr(plugin, "is_valid") and not plugin.is_valid(): continue @@ -122,6 +162,44 @@ def _recurse_pkg(pkg: types.ModuleType, base_class: type) -> None: _recurse_pkg(pkg, base_class) return registry + @staticmethod + def _valid_sub_class_check(in_cls: type, base_class: type) -> bool: + """Check if a class is a subclass of the specified base class. + + Args: + cls (type): The class to check. + base_class (type): The base class to check against. + + Returns: + bool: True if cls is a subclass of base_class, False otherwise. + """ + return ( + inspect.isclass(in_cls) + and issubclass(in_cls, base_class) + and not inspect.isabstract(in_cls) + ) + + @staticmethod + def _load_connection_managers_uncached() -> dict[str, type]: + """Internal: Load connection managers without caching logic.""" + managers: dict[str, type] = {} + eps: Iterable = PluginRegistry.load_entry_points(ENTRY_POINT_CONNECTION_MANAGERS) + + for entry_point in eps: + loaded = entry_point.load() # type: ignore[attr-defined, union-attr] + if not PluginRegistry._valid_sub_class_check( + in_cls=loaded, base_class=ConnectionManager + ): + continue + if hasattr(loaded, "is_valid") and not loaded.is_valid(): + continue + cls = loaded + managers[cls.__name__] = cls + ep_name = getattr(entry_point, "name", None) + if ep_name and ep_name != cls.__name__: + managers[ep_name] = cls + return managers + @staticmethod def load_connection_managers_from_entry_points() -> dict[str, type]: """Load ConnectionManager subclasses from ``nodescraper.connection_managers`` entry points. @@ -132,41 +210,77 @@ def load_connection_managers_from_entry_points() -> dict[str, type]: Returns: dict[str, type]: Map of lookup key to connection manager class. """ - managers: dict[str, type] = {} + # Return cached result if caching is enabled and cache exists + if ( + PluginRegistry._use_cache + and PluginRegistry._entry_point_connection_managers_cache is not None + ): + return PluginRegistry._entry_point_connection_managers_cache + + # If caching disabled, skip lock and always reload + if not PluginRegistry._use_cache: + return PluginRegistry._load_connection_managers_uncached() + + with PluginRegistry._cache_lock: + # Check again inside the lock to prevent duplicate work + if PluginRegistry._entry_point_connection_managers_cache is not None: + return PluginRegistry._entry_point_connection_managers_cache + + managers = PluginRegistry._load_connection_managers_uncached() + + # Cache the result + PluginRegistry._entry_point_connection_managers_cache = managers + return managers + @staticmethod + def _load_entry_points_uncached(entry_point: str) -> EntryPoints: + """Internal: Load entry points without caching logic.""" try: - eps: Iterable - try: - eps = importlib.metadata.entry_points( # type: ignore[call-arg] - group="nodescraper.connection_managers" - ) - except TypeError: - all_eps = importlib.metadata.entry_points() # type: ignore[assignment] - eps = all_eps.get("nodescraper.connection_managers", []) # type: ignore[assignment, attr-defined, arg-type] - - for entry_point in eps: - try: - loaded = entry_point.load() # type: ignore[attr-defined, union-attr] - if not ( - inspect.isclass(loaded) - and issubclass(loaded, ConnectionManager) - and not inspect.isabstract(loaded) - ): - continue - if hasattr(loaded, "is_valid") and not loaded.is_valid(): - continue - cls = loaded - managers[cls.__name__] = cls - ep_name = getattr(entry_point, "name", None) - if ep_name and ep_name != cls.__name__: - managers[ep_name] = cls - except Exception: - pass + eps: EntryPoints = importlib.metadata.entry_points(group=entry_point) # type: ignore[call-arg] + except TypeError: + all_eps: EntryPoints = importlib.metadata.entry_points() # type: ignore[assignment] + eps = all_eps.get(entry_point, []) # type: ignore[assignment, attr-defined, arg-type] + return eps - except Exception: - pass + @staticmethod + def load_entry_points(entry_point: str) -> EntryPoints: + # Return cached result if caching is enabled and cache exists + if PluginRegistry._use_cache and entry_point in PluginRegistry._entry_points_cache: + return PluginRegistry._entry_points_cache[entry_point] - return managers + # If caching disabled, skip lock and always reload + if not PluginRegistry._use_cache: + return PluginRegistry._load_entry_points_uncached(entry_point) + + with PluginRegistry._cache_lock: + # Check again inside the lock to prevent duplicate work + if entry_point in PluginRegistry._entry_points_cache: + return PluginRegistry._entry_points_cache[entry_point] + + eps = PluginRegistry._load_entry_points_uncached(entry_point) + + # Cache the result + PluginRegistry._entry_points_cache[entry_point] = eps + return eps + + @staticmethod + def _load_plugins_uncached() -> dict[str, type]: + """Internal: Load plugins without caching logic.""" + plugins = {} + eps: Iterable = PluginRegistry.load_entry_points(ENTRY_POINT_PLUGINS) + + for entry_point in eps: + plugin_class = entry_point.load() # type: ignore[attr-defined, union-attr] + + if not PluginRegistry._valid_sub_class_check( + in_cls=plugin_class, base_class=PluginInterface + ): + continue + if hasattr(plugin_class, "is_valid") and not plugin_class.is_valid(): + continue + + plugins[plugin_class.__name__] = plugin_class + return plugins @staticmethod def load_plugins_from_entry_points() -> dict[str, type]: @@ -175,35 +289,33 @@ def load_plugins_from_entry_points() -> dict[str, type]: Returns: dict[str, type]: A dictionary mapping plugin names to their classes. """ - plugins = {} + # Return cached result if caching is enabled and cache exists + if PluginRegistry._use_cache and PluginRegistry._entry_point_plugins_cache is not None: + return PluginRegistry._entry_point_plugins_cache.copy() - try: - eps: Iterable - # Python 3.10+ supports group parameter - try: - eps = importlib.metadata.entry_points(group="nodescraper.plugins") # type: ignore[call-arg] - except TypeError: - # Python 3.9 - entry_points() returns dict-like object - all_eps = importlib.metadata.entry_points() # type: ignore[assignment] - eps = all_eps.get("nodescraper.plugins", []) # type: ignore[assignment, attr-defined, arg-type] - - for entry_point in eps: - try: - plugin_class = entry_point.load() # type: ignore[attr-defined, union-attr] - - if ( - inspect.isclass(plugin_class) - and issubclass(plugin_class, PluginInterface) - and not inspect.isabstract(plugin_class) - ): - if hasattr(plugin_class, "is_valid") and not plugin_class.is_valid(): - continue - - plugins[plugin_class.__name__] = plugin_class - except Exception: - pass - - except Exception: - pass + # If caching disabled, skip lock and always reload + if not PluginRegistry._use_cache: + return PluginRegistry._load_plugins_uncached() - return plugins + with PluginRegistry._cache_lock: + # Check again inside the lock to prevent duplicate work + if PluginRegistry._entry_point_plugins_cache is not None: + return PluginRegistry._entry_point_plugins_cache.copy() + + plugins = PluginRegistry._load_plugins_uncached() + + # Cache the result - no need to copy before caching + PluginRegistry._entry_point_plugins_cache = plugins + return plugins + + @classmethod + def clear_caches(cls) -> None: + """Clear all caches. Useful for testing or when plugins are dynamically installed. + + Thread-safe: Acquires all locks to ensure no other thread is accessing caches. + """ + with cls._cache_lock: + cls._entry_point_plugins_cache = None + cls._entry_point_connection_managers_cache = None + cls._module_cache.clear() + cls._entry_points_cache.clear() From 890cfb01103db90627020890b59e9d1f22df4e40 Mon Sep 17 00:00:00 2001 From: graepaul_amdeng Date: Thu, 25 Jun 2026 11:44:44 -0700 Subject: [PATCH 2/4] Adding tests --- nodescraper/models/analyzerargs.py | 26 +- nodescraper/pluginrecipe/discovery.py | 28 +- test/functional/test_discovery.py | 406 ++++++++++++++++++++++++ test/functional/test_plugin_registry.py | 100 ++++++ 4 files changed, 537 insertions(+), 23 deletions(-) create mode 100644 test/functional/test_discovery.py diff --git a/nodescraper/models/analyzerargs.py b/nodescraper/models/analyzerargs.py index b73ee7b8..11744e66 100644 --- a/nodescraper/models/analyzerargs.py +++ b/nodescraper/models/analyzerargs.py @@ -25,7 +25,13 @@ ############################################################################### from typing import Any -from pydantic import BaseModel, ConfigDict, model_serializer, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + SerializerFunctionWrapHandler, + model_serializer, + model_validator, +) class AnalyzerArgs(BaseModel): @@ -39,14 +45,16 @@ class AnalyzerArgs(BaseModel): model_config = ConfigDict(extra="forbid") - @model_serializer - def serialize_exclude_none(self) -> dict: - """Serialize the model to a dictionary, excluding None values. - - Returns: - A dictionary representation of the model with None values excluded. - """ - return self.model_dump(exclude_none=True) + @model_serializer(mode="wrap") + def serialize_model(self, handler: SerializerFunctionWrapHandler) -> dict[str, object]: + serialized = handler(self) + remove_keys = [] + for key, value in serialized.items(): + if value is None: + remove_keys.append(key) + for key in remove_keys: + del serialized[key] + return serialized @model_validator(mode="before") @classmethod diff --git a/nodescraper/pluginrecipe/discovery.py b/nodescraper/pluginrecipe/discovery.py index 0abcd1f6..337a0720 100644 --- a/nodescraper/pluginrecipe/discovery.py +++ b/nodescraper/pluginrecipe/discovery.py @@ -60,12 +60,12 @@ def load_plugin_class(self, plugin_name: str) -> type | None: if not self._use_cache: return PluginRegistry().plugins.get(plugin_name) - if self._plugin_cache is None: - with self._cache_lock: - if self._plugin_cache is None: - self._plugin_cache = PluginRegistry().plugins + if PluginDiscovery._plugin_cache is None: + with PluginDiscovery._cache_lock: + if PluginDiscovery._plugin_cache is None: + PluginDiscovery._plugin_cache = PluginRegistry().plugins - return self._plugin_cache.get(plugin_name) + return PluginDiscovery._plugin_cache.get(plugin_name) def plugin_has_collector(self, plugin_name: str) -> bool: """Check if a plugin has a COLLECTOR attribute. @@ -119,13 +119,14 @@ def plugins_with_analyzer(self, plugin_names: Iterable[str]) -> tuple[str, ...]: """ return tuple(sorted(name for name in plugin_names if self.plugin_has_analyzer(name))) - def clear_cache(self) -> None: + @staticmethod + def clear_cache() -> None: """Clears the plugin cache, forcing future lookups to query the PluginRegistry again. Thread-safe: Acquires the cache lock to ensure no other thread is accessing the cache. """ - with self._cache_lock: - self._plugin_cache = None + with PluginDiscovery._cache_lock: + PluginDiscovery._plugin_cache = None def registered_plugin_names(self) -> tuple[str, ...]: """Return all plugin names known to :class:`~nodescraper.pluginregistry.PluginRegistry`. @@ -136,12 +137,11 @@ def registered_plugin_names(self) -> tuple[str, ...]: if not self._use_cache: return tuple(sorted(PluginRegistry().plugins.keys())) - if self._plugin_cache is None: - with self._cache_lock: - if self._plugin_cache is None: - self._plugin_cache = PluginRegistry().plugins - - return tuple(sorted(self._plugin_cache.keys())) + if PluginDiscovery._plugin_cache is None: + with PluginDiscovery._cache_lock: + if PluginDiscovery._plugin_cache is None: + PluginDiscovery._plugin_cache = PluginRegistry().plugins + return tuple(sorted(PluginDiscovery._plugin_cache.keys())) def plugin_names_matching(self, names: Iterable[str]) -> tuple[str, ...]: """Return plugin names from ``names`` that are registered at runtime. diff --git a/test/functional/test_discovery.py b/test/functional/test_discovery.py new file mode 100644 index 00000000..33f58278 --- /dev/null +++ b/test/functional/test_discovery.py @@ -0,0 +1,406 @@ +############################################################################### +# +# MIT License +# +# Copyright (c) 2025 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +############################################################################### +"""Functional tests for plugin discovery functionality.""" + +import threading + +from nodescraper.pluginrecipe.discovery import PluginDiscovery +from nodescraper.pluginregistry import PluginRegistry + + +def test_plugin_discovery_initialization(): + """Test that PluginDiscovery can be initialized with and without cache.""" + # With cache (default) + discovery_cached = PluginDiscovery() + assert discovery_cached._use_cache is True + + # Without cache + discovery_no_cache = PluginDiscovery(use_cache=False) + assert discovery_no_cache._use_cache is False + + +def test_registered_plugin_names(): + """Test that registered_plugin_names returns plugin names.""" + discovery = PluginDiscovery() + plugin_names = discovery.registered_plugin_names() + + assert isinstance(plugin_names, tuple) + assert len(plugin_names) > 0 + # Should be sorted + assert plugin_names == tuple(sorted(plugin_names)) + # All should be strings + assert all(isinstance(name, str) for name in plugin_names) + + +def test_registered_plugin_names_matches_registry(): + """Test that registered_plugin_names matches the PluginRegistry.""" + discovery = PluginDiscovery() + registry = PluginRegistry() + + discovery_names = set(discovery.registered_plugin_names()) + registry_names = set(registry.plugins.keys()) + + assert discovery_names == registry_names + + +def test_load_plugin_class_existing(): + """Test loading an existing plugin class.""" + discovery = PluginDiscovery() + plugin_names = discovery.registered_plugin_names() + + if len(plugin_names) == 0: + return # Skip if no plugins + + # Load first plugin + plugin_name = plugin_names[0] + plugin_class = discovery.load_plugin_class(plugin_name) + + assert plugin_class is not None + assert hasattr(plugin_class, "run") + + +def test_load_plugin_class_nonexistent(): + """Test loading a non-existent plugin returns None.""" + discovery = PluginDiscovery() + plugin_class = discovery.load_plugin_class("NonExistentPlugin12345") + + assert plugin_class is None + + +def test_plugin_has_collector(): + """Test checking if plugins have COLLECTOR attribute.""" + discovery = PluginDiscovery() + plugin_names = discovery.registered_plugin_names() + + # Test with all plugins + for plugin_name in plugin_names: + has_collector = discovery.plugin_has_collector(plugin_name) + assert isinstance(has_collector, bool) + + # Verify against actual plugin class + plugin_class = discovery.load_plugin_class(plugin_name) + expected = getattr(plugin_class, "COLLECTOR", None) is not None + assert has_collector == expected + + +def test_plugin_has_collector_nonexistent(): + """Test that non-existent plugins return False for has_collector.""" + discovery = PluginDiscovery() + assert discovery.plugin_has_collector("NonExistentPlugin12345") is False + + +def test_plugin_has_analyzer(): + """Test checking if plugins have ANALYZER attribute.""" + discovery = PluginDiscovery() + plugin_names = discovery.registered_plugin_names() + + # Test with all plugins + for plugin_name in plugin_names: + has_analyzer = discovery.plugin_has_analyzer(plugin_name) + assert isinstance(has_analyzer, bool) + + # Verify against actual plugin class + plugin_class = discovery.load_plugin_class(plugin_name) + expected = getattr(plugin_class, "ANALYZER", None) is not None + assert has_analyzer == expected + + +def test_plugin_has_analyzer_nonexistent(): + """Test that non-existent plugins return False for has_analyzer.""" + discovery = PluginDiscovery() + assert discovery.plugin_has_analyzer("NonExistentPlugin12345") is False + + +def test_plugins_with_collector(): + """Test filtering plugins that have COLLECTOR attribute.""" + discovery = PluginDiscovery() + all_plugins = discovery.registered_plugin_names() + + plugins_with_collector = discovery.plugins_with_collector(all_plugins) + + assert isinstance(plugins_with_collector, tuple) + # Should be sorted + assert plugins_with_collector == tuple(sorted(plugins_with_collector)) + + # Verify each one actually has COLLECTOR + for plugin_name in plugins_with_collector: + assert discovery.plugin_has_collector(plugin_name) + + +def test_plugins_with_collector_empty(): + """Test plugins_with_collector with empty input.""" + discovery = PluginDiscovery() + result = discovery.plugins_with_collector([]) + assert result == () + + +def test_plugins_with_collector_mixed(): + """Test plugins_with_collector with mix of valid and invalid names.""" + discovery = PluginDiscovery() + all_plugins = discovery.registered_plugin_names() + + if len(all_plugins) == 0: + return + + # Mix real and fake plugin names + test_names = list(all_plugins[:3]) + ["FakePlugin1", "FakePlugin2"] + result = discovery.plugins_with_collector(test_names) + + # Should only contain valid plugins with collector + for name in result: + assert name in all_plugins + assert discovery.plugin_has_collector(name) + + +def test_plugins_with_analyzer(): + """Test filtering plugins that have ANALYZER attribute.""" + discovery = PluginDiscovery() + all_plugins = discovery.registered_plugin_names() + + plugins_with_analyzer = discovery.plugins_with_analyzer(all_plugins) + + assert isinstance(plugins_with_analyzer, tuple) + # Should be sorted + assert plugins_with_analyzer == tuple(sorted(plugins_with_analyzer)) + + # Verify each one actually has ANALYZER + for plugin_name in plugins_with_analyzer: + assert discovery.plugin_has_analyzer(plugin_name) + + +def test_plugins_with_analyzer_empty(): + """Test plugins_with_analyzer with empty input.""" + discovery = PluginDiscovery() + result = discovery.plugins_with_analyzer([]) + assert result == () + + +def test_plugins_with_analyzer_mixed(): + """Test plugins_with_analyzer with mix of valid and invalid names.""" + discovery = PluginDiscovery() + all_plugins = discovery.registered_plugin_names() + + if len(all_plugins) == 0: + return + + # Mix real and fake plugin names + test_names = list(all_plugins[:3]) + ["FakePlugin1", "FakePlugin2"] + result = discovery.plugins_with_analyzer(test_names) + + # Should only contain valid plugins with analyzer + for name in result: + assert name in all_plugins + assert discovery.plugin_has_analyzer(name) + + +def test_plugin_names_matching(): + """Test plugin_names_matching returns only registered plugins.""" + discovery = PluginDiscovery() + all_plugins = discovery.registered_plugin_names() + + if len(all_plugins) == 0: + return + + # Test with mix of valid and invalid names + test_names = [ + all_plugins[0] if len(all_plugins) > 0 else "ValidPlugin", + "NonExistentPlugin1", + all_plugins[1] if len(all_plugins) > 1 else "AnotherPlugin", + "NonExistentPlugin2", + ] + + matched = discovery.plugin_names_matching(test_names) + + assert isinstance(matched, tuple) + # Should be sorted + assert matched == tuple(sorted(matched)) + # Should only contain plugins that exist + for name in matched: + assert name in all_plugins + + +def test_plugin_names_matching_empty(): + """Test plugin_names_matching with empty input.""" + discovery = PluginDiscovery() + matched = discovery.plugin_names_matching([]) + assert matched == () + + +def test_plugin_names_matching_all_invalid(): + """Test plugin_names_matching with all non-existent plugins.""" + discovery = PluginDiscovery() + matched = discovery.plugin_names_matching(["Fake1", "Fake2", "NotReal"]) + assert matched == () + + +def test_plugin_names_matching_all_valid(): + """Test plugin_names_matching with all valid plugins.""" + discovery = PluginDiscovery() + all_plugins = discovery.registered_plugin_names() + + if len(all_plugins) == 0: + return + + # Take first 3 plugins + test_plugins = all_plugins[: min(3, len(all_plugins))] + matched = discovery.plugin_names_matching(test_plugins) + + # Should return all of them + assert set(matched) == set(test_plugins) + # Should be sorted + assert matched == tuple(sorted(matched)) + + +def test_cache_behavior(): + """Test that caching works correctly.""" + # Clear any existing cache + PluginDiscovery._plugin_cache = None + + discovery = PluginDiscovery(use_cache=True) + + # First call should populate class-level cache + assert PluginDiscovery._plugin_cache is None + plugins1 = discovery.registered_plugin_names() + assert PluginDiscovery._plugin_cache is not None + + # Second call should use cache and return same results + plugins2 = discovery.registered_plugin_names() + assert plugins1 == plugins2 + + +def test_no_cache_behavior(): + """Test that use_cache=False bypasses cache.""" + # Clear cache + PluginDiscovery._plugin_cache = None + + discovery = PluginDiscovery(use_cache=False) + + # Get plugins without caching + plugins = discovery.registered_plugin_names() + # Should still get valid results + assert len(plugins) > 0 + # Cache should remain None + assert PluginDiscovery._plugin_cache is None + + +def test_clear_cache(): + """Test that clear_cache properly clears the cache.""" + discovery = PluginDiscovery(use_cache=True) + + # Populate cache by calling registered_plugin_names + discovery.registered_plugin_names() + assert PluginDiscovery._plugin_cache is not None + + # Clear cache + discovery.clear_cache() + assert PluginDiscovery._plugin_cache is None + + +def test_concurrent_access_thread_safe(): + """Test that concurrent cache access is thread-safe.""" + # Clear cache before test + PluginDiscovery._plugin_cache = None + + results = [] + errors = [] + + def load_plugins_worker(): + try: + discovery = PluginDiscovery(use_cache=True) + plugins = discovery.registered_plugin_names() + results.append(plugins) + except Exception as e: + errors.append(e) + + # Create 10 threads + threads = [threading.Thread(target=load_plugins_worker) for _ in range(10)] + + # Start all threads + for thread in threads: + thread.start() + + # Wait for completion + for thread in threads: + thread.join() + + # No errors should occur + assert len(errors) == 0, f"Thread safety errors: {errors}" + + # All threads should get consistent results + assert len(results) == 10 + first_result = set(results[0]) + for result in results[1:]: + assert set(result) == first_result + + +def test_class_attributes(): + """Test that class attributes have expected values.""" + assert PluginDiscovery.COLLECTOR_ATTRIBUTE == "COLLECTOR" + assert PluginDiscovery.ANALYZER_ATTRIBUTE == "ANALYZER" + + +def test_load_plugin_class_with_cache(): + """Test load_plugin_class uses cache when enabled.""" + # Clear cache + PluginDiscovery._plugin_cache = None + + discovery = PluginDiscovery(use_cache=True) + plugin_names = discovery.registered_plugin_names() + + if len(plugin_names) == 0: + return + + plugin_name = plugin_names[0] + + # First load should populate cache + assert PluginDiscovery._plugin_cache is None or len(PluginDiscovery._plugin_cache) > 0 + plugin1 = discovery.load_plugin_class(plugin_name) + assert plugin1 is not None + assert PluginDiscovery._plugin_cache is not None + + # Second load should use cache and return same class + plugin2 = discovery.load_plugin_class(plugin_name) + assert plugin1 is plugin2 + + +def test_load_plugin_class_without_cache(): + """Test load_plugin_class bypasses cache when disabled.""" + # Clear cache + PluginDiscovery._plugin_cache = None + + discovery = PluginDiscovery(use_cache=False) + plugin_names = discovery.registered_plugin_names() + + if len(plugin_names) == 0: + return + + plugin_name = plugin_names[0] + + # Load without cache + plugin = discovery.load_plugin_class(plugin_name) + assert plugin is not None + # Cache should remain None + assert PluginDiscovery._plugin_cache is None diff --git a/test/functional/test_plugin_registry.py b/test/functional/test_plugin_registry.py index 77d352f7..a47c92bf 100644 --- a/test/functional/test_plugin_registry.py +++ b/test/functional/test_plugin_registry.py @@ -26,6 +26,7 @@ """Functional tests for plugin registry and plugin loading.""" import inspect +import threading from nodescraper.pluginregistry import PluginRegistry @@ -73,3 +74,102 @@ def test_plugin_registry_get_plugin(): assert plugin is not None assert hasattr(plugin, "run") + + +# ============================================================================ +# CACHING TESTS +# ============================================================================ + + +def test_entry_point_plugins_are_cached(): + """Test that entry point plugins are cached and subsequent calls use the cache.""" + # Clear cache to start fresh + PluginRegistry.clear_caches() + assert PluginRegistry._entry_point_plugins_cache is None + + # First call - should populate cache + plugins1 = PluginRegistry.load_plugins_from_entry_points() + assert PluginRegistry._entry_point_plugins_cache is not None + + # Second call - should return from cache (but as a copy) + plugins2 = PluginRegistry.load_plugins_from_entry_points() + + # Verify it's a copy (different object but same content) + assert plugins1 is not plugins2, "Should return copy, not same reference" + assert plugins1 == plugins2, "Content should be identical" + + +def test_cache_returns_copy_prevents_corruption(): + """Test that cache returns a copy to prevent caller modifications from corrupting cache.""" + PluginRegistry.clear_caches() + + # Get plugins from cache + plugins1 = PluginRegistry.load_plugins_from_entry_points() + plugins2 = PluginRegistry.load_plugins_from_entry_points() + + # Modify first copy + if plugins1: + test_key = list(plugins1.keys())[0] + plugins1.pop(test_key) + assert test_key not in plugins1 + + # Second copy should be unaffected + if plugins2: + test_key = list(plugins2.keys())[0] + assert test_key in plugins2, "Cache was corrupted by caller modification" + + +def test_concurrent_cache_access_thread_safe(): + """Test that concurrent cache access is thread-safe with no race conditions.""" + PluginRegistry.clear_caches() + results = [] + errors = [] + + def load_plugins_worker(): + try: + plugins = PluginRegistry.load_plugins_from_entry_points() + results.append(plugins) + except Exception as e: + errors.append(e) + + # Create 10 threads that simultaneously try to load plugins + threads = [threading.Thread(target=load_plugins_worker) for _ in range(10)] + + # Start all threads at once + for thread in threads: + thread.start() + + # Wait for completion + for thread in threads: + thread.join() + + # No errors should occur + assert len(errors) == 0, f"Thread safety errors: {errors}" + + # All threads should get consistent results + assert len(results) == 10 + first_keys = set(results[0].keys()) + for result in results[1:]: + assert set(result.keys()) == first_keys, "Inconsistent results across threads" + + +def test_clear_caches_resets_all_caches(): + """Test that clear_caches properly clears all cache storage.""" + # Populate all caches + PluginRegistry.load_plugins_from_entry_points() + PluginRegistry.load_connection_managers_from_entry_points() + PluginRegistry.load_entry_points("nodescraper.plugins") + + # Verify caches are populated + assert PluginRegistry._entry_point_plugins_cache is not None + assert PluginRegistry._entry_point_connection_managers_cache is not None + assert len(PluginRegistry._entry_points_cache) > 0 + + # Clear all caches + PluginRegistry.clear_caches() + + # Verify all caches are cleared + assert PluginRegistry._entry_point_plugins_cache is None + assert PluginRegistry._entry_point_connection_managers_cache is None + assert len(PluginRegistry._entry_points_cache) == 0 + assert len(PluginRegistry._module_cache) == 0 From 95662cd580631923c73083ff0625bb3b6701e70e Mon Sep 17 00:00:00 2001 From: graepaul_amdeng Date: Thu, 25 Jun 2026 15:41:25 -0700 Subject: [PATCH 3/4] Updating to work with 3.9 --- nodescraper/pluginregistry.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/nodescraper/pluginregistry.py b/nodescraper/pluginregistry.py index 32b1c5a3..8bc6dc18 100644 --- a/nodescraper/pluginregistry.py +++ b/nodescraper/pluginregistry.py @@ -29,9 +29,14 @@ import pkgutil import threading import types -from importlib.metadata import EntryPoints from typing import Iterable, Optional +# Python 3.9 compatibility: EntryPoints type was added in 3.10 +try: + from importlib.metadata import EntryPoints +except ImportError: + EntryPoints = Iterable # type: ignore[misc, assignment] + import nodescraper.connection as internal_connections import nodescraper.plugins as internal_plugins import nodescraper.resultcollators as internal_collators From 290f842859d30ec49af16d9fb8c48c0dfbd6c071 Mon Sep 17 00:00:00 2001 From: graepaul_amdeng Date: Thu, 25 Jun 2026 16:13:21 -0700 Subject: [PATCH 4/4] Update Tests --- nodescraper/pluginregistry.py | 2 +- .../test_connection_manager_entrypoints.py | 15 ++++++++++++--- test/unit/framework/test_pluginrecipe.py | 13 ++++++++++++- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/nodescraper/pluginregistry.py b/nodescraper/pluginregistry.py index 8bc6dc18..c5bf1269 100644 --- a/nodescraper/pluginregistry.py +++ b/nodescraper/pluginregistry.py @@ -33,7 +33,7 @@ # Python 3.9 compatibility: EntryPoints type was added in 3.10 try: - from importlib.metadata import EntryPoints + from importlib.metadata import EntryPoints # type: ignore[attr-defined] except ImportError: EntryPoints = Iterable # type: ignore[misc, assignment] diff --git a/test/unit/framework/test_connection_manager_entrypoints.py b/test/unit/framework/test_connection_manager_entrypoints.py index 16721196..a11d5b50 100644 --- a/test/unit/framework/test_connection_manager_entrypoints.py +++ b/test/unit/framework/test_connection_manager_entrypoints.py @@ -8,7 +8,7 @@ # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is +# copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all @@ -25,10 +25,20 @@ ############################################################################### from unittest.mock import MagicMock, patch +import pytest + from nodescraper.connection.inband.inbandmanager import InBandConnectionManager +from nodescraper.pluginrecipe.discovery import PluginDiscovery from nodescraper.pluginregistry import PluginRegistry +@pytest.fixture(autouse=True) +def clear_cache(): + yield + PluginDiscovery.clear_cache() + PluginRegistry().clear_caches() + + def _entry_points_side_effect_cm_only(mock_ep, *args, **kwargs): group = kwargs.get("group") if group == "nodescraper.connection_managers": @@ -40,11 +50,10 @@ def test_load_connection_managers_from_entry_points_registers_class_and_alias(): mock_ep = MagicMock() mock_ep.name = "AliasInBand" mock_ep.load.return_value = InBandConnectionManager - + PluginRegistry.clear_caches() with patch("nodescraper.pluginregistry.importlib.metadata.entry_points") as mock_eps: mock_eps.side_effect = lambda *a, **k: _entry_points_side_effect_cm_only(mock_ep, *a, **k) found = PluginRegistry.load_connection_managers_from_entry_points() - assert found["InBandConnectionManager"] is InBandConnectionManager assert found["AliasInBand"] is InBandConnectionManager diff --git a/test/unit/framework/test_pluginrecipe.py b/test/unit/framework/test_pluginrecipe.py index 396fbaf1..55b65450 100644 --- a/test/unit/framework/test_pluginrecipe.py +++ b/test/unit/framework/test_pluginrecipe.py @@ -10,8 +10,11 @@ from unittest.mock import patch +import pytest + from nodescraper.models import PluginConfig from nodescraper.pluginrecipe.all_plugins import AllPlugins +from nodescraper.pluginrecipe.discovery import PluginDiscovery from nodescraper.pluginrecipe.node_status import NodeStatus from nodescraper.pluginrecipe.pluginrecipe import ( ANALYZE_ONLY, @@ -24,6 +27,13 @@ from nodescraper.pluginregistry import PluginRegistry +@pytest.fixture(autouse=True) +def clear_cache(): + yield + PluginDiscovery.clear_cache() + PluginRegistry().clear_caches() + + class _CollectorOnlyPlugin: COLLECTOR = object() @@ -97,8 +107,9 @@ def test_analyzer_only_recipe_sets_collection_false() -> None: assert config.plugins["DmesgPlugin"] == ANALYZE_ONLY.as_config() -@patch("nodescraper.pluginrecipe.discovery.load_plugin_class") +@patch("nodescraper.pluginrecipe.discovery.PluginDiscovery.load_plugin_class") def test_filter_plugin_names_by_task_type(mock_load_plugin_class) -> None: + mock_load_plugin_class.side_effect = lambda name: { "CollectorPlugin": _CollectorOnlyPlugin, "AnalyzerPlugin": _AnalyzerOnlyPlugin,