Skip to content

Commit aad1b56

Browse files
committed
fix(embedding): improve dimension detection performance and reliability
- Add dimension detection flags to prevent redundant API calls for custom models - Replace fragile magic number logic with explicit boolean flags for OAPI detection - Make environment variable check case-insensitive for better user experience - Fix setModel method to properly reset detection flags for unknown models - Update test imports to fix TypeScript compilation errors Performance improvements: - Dimension detection now happens only once per model/instance - Eliminates unnecessary API calls for known and already-detected models - Concurrent embed calls no longer trigger duplicate dimension detection Reliability improvements: - OAPI dimension detection no longer relies on magic number comparison - Case-insensitive environment variable parsing (true/TRUE/True all work) - Proper state management for model changes via setModel()
1 parent 5932aeb commit aad1b56

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

packages/core/src/embedding/openai-embedding.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { OpenAI } from 'openai';
1+
import OpenAI from 'openai';
22
import { OpenAIEmbedding } from './openai-embedding';
33
import type { EmbeddingVector } from './base-embedding';
44

@@ -12,7 +12,7 @@ jest.mock('openai', () => {
1212
}));
1313
});
1414

15-
const MockOpenAI = OpenAI as jest.Mock;
15+
const MockOpenAI = OpenAI as unknown as jest.Mock;
1616

1717
describe('OpenAIEmbedding OAPI Forwarding', () => {
1818
const originalEnv = process.env;

packages/core/src/embedding/openai-embedding.ts

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,18 @@ export class OpenAIEmbedding extends Embedding {
1212
private client: OpenAI;
1313
private config: OpenAIEmbeddingConfig;
1414
private dimension: number = 1536; // Default dimension for text-embedding-3-small
15+
private dimensionDetected: boolean = false; // Track if dimension has been detected
1516
protected maxTokens: number = 8192; // Maximum tokens for OpenAI embedding models
1617
private isOllamaViaOAPI: boolean = false; // Whether using Ollama model via OAPI
18+
private isOllamaDimensionDetected: boolean = false; // Track if OAPI dimension has been detected
1719

1820
constructor(config: OpenAIEmbeddingConfig) {
1921
super();
2022
this.config = config;
2123

2224
// Check environment variable for Ollama via OAPI
2325
this.isOllamaViaOAPI = config.useOllamaModel ||
24-
process.env.OPENAI_CUSTOM_BASE_USING_OLLAMA_MODEL === 'true' ||
25-
process.env.OPENAI_CUSTOM_BASE_USING_OLLAMA_MODEL === 'True';
26+
(process.env.OPENAI_CUSTOM_BASE_USING_OLLAMA_MODEL || '').toLowerCase() === 'true';
2627

2728
// Auto-correct baseURL if needed
2829
const correctedBaseURL = this.correctBaseURL(config.baseURL);
@@ -36,6 +37,13 @@ export class OpenAIEmbedding extends Embedding {
3637
console.log(`[OpenAI] Configured for Ollama model ${config.model} via OAPI forwarding`);
3738
// Reset dimension since Ollama models have different dimensions
3839
this.dimension = 768; // Common Ollama embedding dimension
40+
} else {
41+
// Set dimension detection flag for known models
42+
const knownModels = OpenAIEmbedding.getSupportedModels();
43+
if (knownModels[config.model]) {
44+
this.dimension = knownModels[config.model].dimension;
45+
this.dimensionDetected = true;
46+
}
3947
}
4048
}
4149

@@ -143,8 +151,10 @@ export class OpenAIEmbedding extends Embedding {
143151
const knownModels = OpenAIEmbedding.getSupportedModels();
144152
if (knownModels[model] && this.dimension !== knownModels[model].dimension) {
145153
this.dimension = knownModels[model].dimension;
146-
} else if (!knownModels[model]) {
154+
this.dimensionDetected = true;
155+
} else if (!knownModels[model] && !this.dimensionDetected) {
147156
this.dimension = await this.detectDimension();
157+
this.dimensionDetected = true;
148158
}
149159

150160
try {
@@ -179,9 +189,10 @@ export class OpenAIEmbedding extends Embedding {
179189
const processedText = this.preprocessText(text);
180190
const model = this.config.model;
181191

182-
// Detect dimension if using default OpenAI dimension
183-
if (this.dimension === 1536) {
192+
// Detect dimension if not already detected for Ollama
193+
if (!this.isOllamaDimensionDetected) {
184194
this.dimension = await this.detectOllamaDimensionViaOAPI('test', model);
195+
this.isOllamaDimensionDetected = true;
185196
}
186197

187198
try {
@@ -218,8 +229,10 @@ export class OpenAIEmbedding extends Embedding {
218229
const knownModels = OpenAIEmbedding.getSupportedModels();
219230
if (knownModels[model] && this.dimension !== knownModels[model].dimension) {
220231
this.dimension = knownModels[model].dimension;
221-
} else if (!knownModels[model]) {
232+
this.dimensionDetected = true;
233+
} else if (!knownModels[model] && !this.dimensionDetected) {
222234
this.dimension = await this.detectDimension();
235+
this.dimensionDetected = true;
223236
}
224237

225238
try {
@@ -255,9 +268,10 @@ export class OpenAIEmbedding extends Embedding {
255268
const processedTexts = this.preprocessTexts(texts);
256269
const model = this.config.model;
257270

258-
// Detect dimension if using default OpenAI dimension
259-
if (this.dimension === 1536) {
271+
// Detect dimension if not already detected for Ollama
272+
if (!this.isOllamaDimensionDetected) {
260273
this.dimension = await this.detectOllamaDimensionViaOAPI('test', model);
274+
this.isOllamaDimensionDetected = true;
261275
}
262276

263277
try {
@@ -299,8 +313,15 @@ export class OpenAIEmbedding extends Embedding {
299313
const knownModels = OpenAIEmbedding.getSupportedModels();
300314
if (knownModels[model]) {
301315
this.dimension = knownModels[model].dimension;
316+
this.dimensionDetected = true;
302317
} else {
318+
// Reset detection flags for unknown models
319+
this.dimensionDetected = false;
320+
if (this.isOllamaViaOAPI) {
321+
this.isOllamaDimensionDetected = false;
322+
}
303323
this.dimension = await this.detectDimension();
324+
this.dimensionDetected = true;
304325
}
305326
}
306327

0 commit comments

Comments
 (0)