diff --git a/packages/components/src/modelLoader.test.ts b/packages/components/src/modelLoader.test.ts new file mode 100644 index 00000000000..ed54d15a20a --- /dev/null +++ b/packages/components/src/modelLoader.test.ts @@ -0,0 +1,60 @@ +import axios from 'axios' +import { getModelConfigByModelName, MODEL_TYPE } from './modelLoader' + +jest.mock('axios') + +const mockedAxios = axios as jest.Mocked + +describe('modelLoader', () => { + const originalModelListConfigJson = process.env.MODEL_LIST_CONFIG_JSON + + afterEach(() => { + jest.resetAllMocks() + if (originalModelListConfigJson === undefined) { + delete process.env.MODEL_LIST_CONFIG_JSON + } else { + process.env.MODEL_LIST_CONFIG_JSON = originalModelListConfigJson + } + delete process.env.MODEL_LIST_FETCH_TIMEOUT_MS + }) + + it('uses a bounded timeout when loading remote model config before falling back locally', async () => { + process.env.MODEL_LIST_CONFIG_JSON = 'https://example.com/models.json' + mockedAxios.get.mockRejectedValueOnce(new Error('timeout')) + + const modelConfig = await getModelConfigByModelName(MODEL_TYPE.CHAT, 'awsChatBedrock', 'ai21.jamba-1-5-large-v1:0') + + expect(mockedAxios.get).toHaveBeenCalledWith('https://example.com/models.json', { timeout: 5000 }) + expect(modelConfig?.name).toBe('ai21.jamba-1-5-large-v1:0') + }) + + it('allows configuring the remote model list timeout', async () => { + process.env.MODEL_LIST_CONFIG_JSON = 'https://example.com/custom-timeout-models.json' + process.env.MODEL_LIST_FETCH_TIMEOUT_MS = '1500' + mockedAxios.get.mockRejectedValueOnce(new Error('timeout')) + + await getModelConfigByModelName(MODEL_TYPE.CHAT, 'awsChatBedrock', 'ai21.jamba-1-5-large-v1:0') + + expect(mockedAxios.get).toHaveBeenCalledWith('https://example.com/custom-timeout-models.json', { timeout: 1500 }) + }) + + it('caches model config after the first load', async () => { + process.env.MODEL_LIST_CONFIG_JSON = 'https://example.com/cached-models.json' + mockedAxios.get.mockResolvedValueOnce({ + status: 200, + data: { + [MODEL_TYPE.CHAT]: [ + { + name: 'Test Provider', + models: [{ name: 'test-model' }] + } + ] + } + }) + + await getModelConfigByModelName(MODEL_TYPE.CHAT, 'Test Provider', 'test-model') + await getModelConfigByModelName(MODEL_TYPE.CHAT, 'Test Provider', 'test-model') + + expect(mockedAxios.get).toHaveBeenCalledTimes(1) + }) +}) diff --git a/packages/components/src/modelLoader.ts b/packages/components/src/modelLoader.ts index dc728634b74..b4144078c0f 100644 --- a/packages/components/src/modelLoader.ts +++ b/packages/components/src/modelLoader.ts @@ -3,6 +3,9 @@ import * as fs from 'fs' import * as path from 'path' import { INodeOptionsValue } from './Interface' +const DEFAULT_MODEL_LIST_FETCH_TIMEOUT_MS = 5000 +let rawModelFileCache: { modelFile: string; data: any } | undefined + export enum MODEL_TYPE { CHAT = 'chat', LLM = 'llm', @@ -29,6 +32,11 @@ const isValidUrl = (urlString: string) => { return url.protocol === 'http:' || url.protocol === 'https:' } +const getModelListFetchTimeoutMs = () => { + const timeout = Number(process.env.MODEL_LIST_FETCH_TIMEOUT_MS) + return Number.isFinite(timeout) && timeout > 0 ? timeout : DEFAULT_MODEL_LIST_FETCH_TIMEOUT_MS +} + /** * Load the raw model file from either a URL or a local file * If any of the loading fails, fallback to the default models.json file on disk @@ -36,28 +44,38 @@ const isValidUrl = (urlString: string) => { const getRawModelFile = async () => { const modelFile = process.env.MODEL_LIST_CONFIG_JSON ?? 'https://raw.githubusercontent.com/FlowiseAI/Flowise/main/packages/components/models.json' + if (rawModelFileCache?.modelFile === modelFile) { + return rawModelFileCache.data + } + + let rawModelFile try { if (isValidUrl(modelFile)) { - const resp = await axios.get(modelFile) + const resp = await axios.get(modelFile, { timeout: getModelListFetchTimeoutMs() }) if (resp.status === 200 && resp.data) { - return resp.data + rawModelFile = resp.data } else { throw new Error('Error fetching model list') } } else if (fs.existsSync(modelFile)) { const models = await fs.promises.readFile(modelFile, 'utf8') if (models) { - return JSON.parse(models) + rawModelFile = JSON.parse(models) } } - throw new Error('Model file does not exist or is empty') + if (!rawModelFile) { + throw new Error('Model file does not exist or is empty') + } } catch (e) { const models = await fs.promises.readFile(getModelsJSONPath(), 'utf8') if (models) { - return JSON.parse(models) + rawModelFile = JSON.parse(models) + } else { + rawModelFile = {} } - return {} } + rawModelFileCache = { modelFile, data: rawModelFile } + return rawModelFile } const getModelConfig = async (category: MODEL_TYPE, name: string) => {