Skip to content

Commit 260df17

Browse files
authored
Merge pull request #1167 from transformerlab/fix/run-button-compatible-message
Suggest a compatible loader plugin for when one is not installed and a model is selected
2 parents 11d0af3 + 66cfc79 commit 260df17

File tree

3 files changed

+128
-10
lines changed

3 files changed

+128
-10
lines changed

api/transformerlab/routers/plugins.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,84 @@ async def plugin_gallery():
8080
return gallery
8181

8282

83+
@router.get("/suggest_loader", summary="Suggest a compatible loader plugin for a model architecture.")
84+
async def suggest_loader_plugin(model_architecture: str):
85+
"""
86+
Suggest a compatible loader plugin based on model architecture and platform.
87+
Returns the best matching loader plugin that:
88+
1. Supports the model architecture
89+
2. Is compatible with the current platform/hardware
90+
3. Is not already installed
91+
"""
92+
# Import here to avoid circular dependency
93+
import transformerlab.routers.serverinfo as serverinfo_module
94+
95+
device_type = serverinfo_module.system_info.get("device_type", "cpu")
96+
97+
# Map device_type to supported_hardware_architectures
98+
# device_type: nvidia -> cuda, apple_silicon -> mlx, amd -> amd, cpu -> cpu
99+
hardware_arch_map = {
100+
"nvidia": "cuda",
101+
"apple_silicon": "mlx",
102+
"amd": "amd",
103+
"cpu": "cpu",
104+
}
105+
required_hardware = hardware_arch_map.get(device_type, "cpu")
106+
107+
# Get all plugins from gallery
108+
gallery = await plugin_gallery()
109+
110+
# Filter for loader plugins that:
111+
# 1. Are of type "loader"
112+
# 2. Are not installed
113+
# 3. Support the model architecture
114+
# 4. Support the current hardware architecture
115+
compatible_plugins = []
116+
117+
for plugin in gallery:
118+
# Must be a loader plugin
119+
if plugin.get("type") != "loader":
120+
continue
121+
122+
# Must not be installed
123+
if plugin.get("installed", False):
124+
continue
125+
126+
# Must support the model architecture
127+
model_architectures = plugin.get("model_architectures", [])
128+
if not isinstance(model_architectures, list):
129+
continue
130+
131+
architecture_match = False
132+
for arch in model_architectures:
133+
if arch and arch.lower() == model_architecture.lower():
134+
architecture_match = True
135+
break
136+
137+
if not architecture_match:
138+
continue
139+
140+
# Must support the current hardware architecture
141+
supported_hardware = plugin.get("supported_hardware_architectures", [])
142+
if not isinstance(supported_hardware, list):
143+
continue
144+
145+
hardware_match = required_hardware in supported_hardware
146+
147+
if hardware_match:
148+
compatible_plugins.append(plugin)
149+
150+
# If no compatible plugins found, return None
151+
if not compatible_plugins:
152+
return None
153+
154+
# Sort alphabetically by name and return the first one
155+
compatible_plugins.sort(key=lambda p: p.get("name", ""))
156+
157+
# Return the first match
158+
return compatible_plugins[0]
159+
160+
83161
async def copy_plugin_files_to_workspace(plugin_id: str):
84162
plugin_id = secure_filename(plugin_id)
85163

src/renderer/components/Experiment/Foundation/RunModelButton.tsx

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ import {
2626
import InferenceEngineModal from './InferenceEngineModal';
2727
import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
2828
import OneTimePopup from 'renderer/components/Shared/OneTimePopup';
29-
import { useAPI } from 'renderer/lib/transformerlab-api-sdk';
29+
import { useAPI, fetcher } from 'renderer/lib/transformerlab-api-sdk';
30+
import { useSWRWithAuth as useSWR } from 'renderer/lib/authContext';
3031
import React from 'react';
3132

3233
import { Link } from 'react-router-dom';
@@ -83,6 +84,13 @@ export default function RunModelButton({
8384

8485
const archTag = experimentInfo?.config?.foundation_model_architecture ?? '';
8586

87+
// Fetch suggested compatible loader plugin from API (platform-aware)
88+
const { data: suggestedLoaderPlugin, isLoading: suggestedPluginLoading } =
89+
useSWR(
90+
archTag ? chatAPI.Endpoints.Plugins.SuggestLoader(archTag) : null,
91+
fetcher,
92+
);
93+
8694
const supportedEngines = React.useMemo(() => {
8795
if (!data || pipelineTagLoading) return [];
8896

@@ -472,12 +480,30 @@ export default function RunModelButton({
472480
<Alert startDecorator={<TriangleAlertIcon />} color="warning">
473481
<Typography level="body-sm">
474482
None of the installed Engines currently support this model
475-
architecture. You can try a different engine in{' '}
476-
<Link to="/plugins">
477-
<Plug2Icon size="15px" />
478-
Plugins
479-
</Link>{' '}
480-
, or you can try running it with an unsupported Engine by clicking{' '}
483+
architecture.
484+
{suggestedLoaderPlugin ? (
485+
<>
486+
{' '}
487+
<b>{suggestedLoaderPlugin.name}</b> is compatible with this
488+
model architecture. Install it in{' '}
489+
<Link to="/plugins">
490+
<Plug2Icon size="15px" />
491+
Plugins
492+
</Link>
493+
.
494+
</>
495+
) : (
496+
<>
497+
{' '}
498+
You can try a different engine in{' '}
499+
<Link to="/plugins">
500+
<Plug2Icon size="15px" />
501+
Plugins
502+
</Link>
503+
.
504+
</>
505+
)}{' '}
506+
Or you can try running it with an unsupported Engine by clicking{' '}
481507
<b>using Engine</b> below and check{' '}
482508
<b>Show unsupported engines</b>.
483509
</Typography>
@@ -495,9 +521,21 @@ export default function RunModelButton({
495521
<Plug2Icon size="15px" />
496522
Plugins
497523
</Link>{' '}
498-
and install an Inference Engine. <b>FastChat Server</b> is a good
499-
default for systems with a GPU. <b>Apple MLX Server</b> is the best
500-
default for MacOS with Apple Silicon.
524+
and install an Inference Engine.
525+
{suggestedLoaderPlugin ? (
526+
<>
527+
{' '}
528+
<b>{suggestedLoaderPlugin.name}</b> is compatible with this
529+
model architecture.
530+
</>
531+
) : (
532+
<>
533+
{' '}
534+
<b>FastChat Server</b> is a good default for systems with a GPU.{' '}
535+
<b>Apple MLX Server</b> is the best default for MacOS with Apple
536+
Silicon.
537+
</>
538+
)}
501539
</Typography>
502540
</Alert>
503541
)}

src/renderer/lib/api-client/endpoints.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,8 @@ Endpoints.Plugins = {
275275
List: () => `${API_URL()}plugins/list`,
276276
RunPluginInstallScript: (pluginId: string) =>
277277
`${API_URL()}plugins/${pluginId}/run_installer_script`,
278+
SuggestLoader: (modelArchitecture: string) =>
279+
`${API_URL()}plugins/suggest_loader?model_architecture=${encodeURIComponent(modelArchitecture)}`,
278280
};
279281

280282
// Following is no longer needed as it is replaced with useAPI

0 commit comments

Comments
 (0)