Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion llm/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,49 @@
_loaded = False


def _is_truthy(value):
if value is None:
return False
return value.lower() not in ("", "0", "false", "no", "off")


def _load_entrypoint_plugins(plugin_manager, entry_points=None):
strict_plugin_loading = _is_truthy(os.environ.get("LLM_STRICT_PLUGIN_LOADING"))
if entry_points is None:
entry_points = metadata.entry_points(group="llm")
for entry_point in entry_points:
if plugin_manager.get_plugin(entry_point.name) is not None:
continue
if plugin_manager.is_blocked(entry_point.name):
continue
try:
plugin = entry_point.load()
try:
plugin_manager.register(plugin, name=entry_point.name)
except Exception:
# Clean up if our plugin was partially registered
if plugin_manager.get_plugin(entry_point.name) is plugin:
plugin_manager.unregister(name=entry_point.name)
raise
dist = getattr(entry_point, "dist", None)
if dist is not None:
plugin_manager._plugin_distinfo.append((plugin, dist)) # type: ignore
except Exception as ex:
if strict_plugin_loading:
raise
sys.stderr.write(
"Plugin {} failed to load: {}\n".format(entry_point.name, ex)
)


def load_plugins():
global _loaded
if _loaded:
return
_loaded = True
if not hasattr(sys, "_called_from_test") and LLM_LOAD_PLUGINS is None:
# Only load plugins if not running tests
pm.load_setuptools_entrypoints("llm")
_load_entrypoint_plugins(pm)

# Load any plugins specified in LLM_LOAD_PLUGINS")
if LLM_LOAD_PLUGINS is not None:
Expand Down
110 changes: 110 additions & 0 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import importlib
import json
import llm
import pluggy
from llm.tools import llm_version, llm_time
from llm import cli, hookimpl, plugins, get_template_loaders, get_fragment_loaders
import pathlib
Expand Down Expand Up @@ -45,6 +46,115 @@ def hello_world():
assert "HelloWorldPlugin" not in plugin_names()


class FakeEntryPoint:
def __init__(self, name, plugin=None, error=None, dist=None):
self.name = name
self._plugin = plugin
self._error = error
self.dist = dist

def load(self):
if self._error:
raise self._error
return self._plugin


def test_load_entrypoint_plugins_continues_on_failure(monkeypatch, capsys):
plugin_manager = pluggy.PluginManager("llm")
good_plugin = object()
entry_points = [
FakeEntryPoint("broken", error=RuntimeError("boom")),
FakeEntryPoint("good", plugin=good_plugin),
]
monkeypatch.delenv("LLM_STRICT_PLUGIN_LOADING", raising=False)
plugins._load_entrypoint_plugins(plugin_manager, entry_points=entry_points)
captured = capsys.readouterr()
assert "Plugin broken failed to load: boom" in captured.err
assert plugin_manager.get_plugin("broken") is None
assert plugin_manager.get_plugin("good") is good_plugin


def test_load_entrypoint_plugins_cleans_up_partial_registration(monkeypatch, capsys):
"""register() that raises after adding plugin to _name2plugin gets rolled back."""
plugin_manager = pluggy.PluginManager("llm")
good_plugin = object()
partial_plugin = object()

original_register = plugin_manager.register

def failing_register(plugin, name=None):
# Simulate pluggy's non-transactional register: the plugin gets added
# to _name2plugin, then an error occurs during hook processing.
original_register(plugin, name=name)
if name == "partial":
raise RuntimeError("mid-registration failure")

monkeypatch.setattr(plugin_manager, "register", failing_register)

entry_points = [
FakeEntryPoint("partial", plugin=partial_plugin),
FakeEntryPoint("good", plugin=good_plugin),
]
monkeypatch.delenv("LLM_STRICT_PLUGIN_LOADING", raising=False)
plugins._load_entrypoint_plugins(plugin_manager, entry_points=entry_points)
captured = capsys.readouterr()
assert "Plugin partial failed to load" in captured.err
assert plugin_manager.get_plugin("partial") is None
assert plugin_manager.get_plugin("good") is good_plugin


def test_load_entrypoint_plugins_skips_already_registered(monkeypatch):
"""Entry points whose name is already registered are silently skipped."""
plugin_manager = pluggy.PluginManager("llm")
existing_plugin = object()
plugin_manager.register(existing_plugin, name="existing")

loaded = []
original_load = FakeEntryPoint.load

class TrackingEntryPoint(FakeEntryPoint):
def load(self):
loaded.append(self.name)
return original_load(self)

entry_points = [
TrackingEntryPoint("existing", plugin=object()),
]
monkeypatch.delenv("LLM_STRICT_PLUGIN_LOADING", raising=False)
plugins._load_entrypoint_plugins(plugin_manager, entry_points=entry_points)
assert loaded == []
assert plugin_manager.get_plugin("existing") is existing_plugin


def test_load_entrypoint_plugins_skips_blocked(monkeypatch):
"""Entry points whose name is blocked are not imported."""
plugin_manager = pluggy.PluginManager("llm")
plugin_manager.set_blocked("blocked")

loaded = []

class TrackingEntryPoint(FakeEntryPoint):
def load(self):
loaded.append(self.name)
return super().load()

entry_points = [
TrackingEntryPoint("blocked", plugin=object()),
]
monkeypatch.delenv("LLM_STRICT_PLUGIN_LOADING", raising=False)
plugins._load_entrypoint_plugins(plugin_manager, entry_points=entry_points)
assert loaded == []
assert plugin_manager.get_plugin("blocked") is None


def test_load_entrypoint_plugins_strict_mode(monkeypatch):
plugin_manager = pluggy.PluginManager("llm")
entry_points = [FakeEntryPoint("broken", error=RuntimeError("boom"))]
monkeypatch.setenv("LLM_STRICT_PLUGIN_LOADING", "1")
with pytest.raises(RuntimeError, match="boom"):
plugins._load_entrypoint_plugins(plugin_manager, entry_points=entry_points)


def test_register_template_loaders():
assert get_template_loaders() == {}

Expand Down