Commit b021cb1

stamp: creating result types and common usage interface
- Improving TypeScript with conditional output types based on the selected provider
- Defining common properties for LLM providers, like `usage` metrics and a simplified `value`
1 parent 8f64021 commit b021cb1
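
The net effect, from the caller's side: `Session.run` now infers its return type from the provider and the `stream` flag, and every LLM provider returns the same normalized envelope. A minimal sketch of the intended usage (the provider name, model, and constructor option shape are assumptions based on this diff, not a confirmed public API):

```ts
// Hypothetical usage; "ollama" and the option fields are assumptions.
const session = new Session("ollama", {
  model: "llama3",
  baseURL: "http://localhost:11434",
});

// Non-streaming run: resolves to a single provider output.
const out = await session.run("Why is the sky blue?", { stream: false });
console.log(out.value);             // simplified text value
console.log(out.usage.totalTokens); // normalized usage metrics

// Streaming run: resolves to an AsyncGenerator of output chunks.
const stream = await session.run("Tell me a story.", { stream: true });
for await (const chunk of stream) {
  console.log(chunk.value);
}
```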

4 files changed: +88 -29 lines changed

ext/ai/js/ai.ts

Lines changed: 17 additions & 6 deletions
```diff
@@ -34,15 +34,23 @@ export type EmbeddingInputOptions = {
   normalize?: boolean;
 };

-export type SessionInputOptions<T extends SessionType> = T extends
-  LLMProviderName ? LLMInputOptions
-  : EmbeddingInputOptions;
+export type SessionInputOptions<T extends SessionType> = T extends "gte-small"
+  ? EmbeddingInputOptions
+  : T extends LLMProviderName ? LLMInputOptions
+  : never;
+
+export type SessionOutput<T extends SessionType, O> = T extends "gte-small"
+  ? number[]
+  : T extends LLMProviderName
+    ? O extends { stream: true }
+      ? AsyncGenerator<LLMProviderInstance<T>["output"]>
+      : LLMProviderInstance<T>["output"]
+  : never;

 export class Session<T extends SessionType> {
   #model?: string;
   #init?: Promise<void>;

-  // TODO:(kallebysantos) get 'provider' type here and use type checking to suggest Inputs when run
   constructor(
     public readonly type: T,
     public readonly options?: SessionOptions<T>,
@@ -77,7 +85,10 @@ export class Session<T extends SessionType> {
   }

   // /** @param {string | object} prompt Either a String (ollama) or an OpenAI chat completion body object (openaicompatible): https://platform.openai.com/docs/api-reference/chat/create */
-  async run(input: SessionInput<T>, options: SessionInputOptions<T>) {
+  async run<O extends SessionInputOptions<T>>(
+    input: SessionInput<T>,
+    options: O,
+  ): Promise<SessionOutput<T, O>> {
     if (this.isLLMType()) {
       const opts = options as LLMInputOptions;
       const stream = opts.stream ?? false;
@@ -93,7 +104,7 @@ export class Session<T extends SessionType> {
         stream,
         signal: opts.signal,
         timeout: opts.timeout,
-      });
+      }) as SessionOutput<T, typeof options>;
     }

     if (this.#init) {
```
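Note the assertion on the LLM branch: inside `run`, TypeScript cannot relate the provider's concrete return value to the still-unresolved conditional `SessionOutput<T, O>`, so the result is cast with `as SessionOutput<T, typeof options>` rather than proven by inference.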

ext/ai/js/llm/llm_session.ts

Lines changed: 16 additions & 3 deletions
```diff
@@ -34,14 +34,27 @@ export interface ILLMProviderOptions {
 export type ILLMProviderInput<T = string | object> = T extends string ? string
   : T;

+export interface ILLMProviderOutput<T = object> {
+  value?: string;
+  usage: {
+    inputTokens: number;
+    outputTokens: number;
+    totalTokens: number;
+  };
+  inner: T;
+}
+
 export interface ILLMProvider {
   // TODO:(kallebysantos) remove 'any'
   // TODO: (kallebysantos) standardised output format
   getStream(
     input: ILLMProviderInput,
     signal: AbortSignal,
-  ): Promise<AsyncIterable<any>>;
-  getText(input: ILLMProviderInput, signal: AbortSignal): Promise<any>;
+  ): Promise<AsyncIterable<ILLMProviderOutput>>;
+  getText(
+    input: ILLMProviderInput,
+    signal: AbortSignal,
+  ): Promise<ILLMProviderOutput>;
 }

 export const providers = {
@@ -92,7 +105,7 @@ export class LLMSession {
   run(
     input: ILLMProviderInput,
     opts: LLMSessionRunInputOptions,
-  ): Promise<AsyncIterable<any>> | Promise<any> {
+  ): Promise<AsyncIterable<ILLMProviderOutput>> | Promise<ILLMProviderOutput> {
     const isStream = opts.stream ?? false;

     const timeoutSeconds = typeof opts.timeout === "number" ? opts.timeout : 60;
```
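`ILLMProviderOutput` is what makes the providers interchangeable downstream: callers read `value` and `usage` without knowing which backend ran, and `inner` keeps the raw payload for anyone who needs provider-specific fields. A small consumer sketch (the `provider` instance and `signal` are assumed to be in scope):

```ts
// Works against any ILLMProvider, e.g. the Ollama or OpenAI-compatible
// sessions below, thanks to the shared output envelope.
async function runAndReport(provider: ILLMProvider, signal: AbortSignal) {
  const out = await provider.getText("Hello!", signal);

  console.log(out.value ?? "<no text>");   // common simplified value
  console.log(out.usage.inputTokens, out.usage.outputTokens);
  return out.inner;                        // raw provider-specific payload
}
```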

ext/ai/js/llm/providers/ollama.ts

Lines changed: 33 additions & 14 deletions
```diff
@@ -3,11 +3,13 @@ import {
   ILLMProviderInput,
   ILLMProviderMeta,
   ILLMProviderOptions,
-} from '../llm_session.ts';
-import { parseJSON } from '../utils/json_parser.ts';
+  ILLMProviderOutput,
+} from "../llm_session.ts";
+import { parseJSON } from "../utils/json_parser.ts";

 export type OllamaProviderOptions = ILLMProviderOptions;
 export type OllamaProviderInput = ILLMProviderInput<string>;
+export type OllamaProviderOutput = ILLMProviderOutput<OllamaMessage>;

 export type OllamaMessage = {
   model: string;
@@ -25,7 +27,7 @@ export type OllamaMessage = {

 export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta {
   input!: OllamaProviderInput;
-  output!: unknown;
+  output!: OllamaProviderOutput;
   options: OllamaProviderOptions;

   constructor(opts: OllamaProviderOptions) {
@@ -36,31 +38,34 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta {
   async getStream(
     prompt: OllamaProviderInput,
     signal: AbortSignal,
-  ): Promise<AsyncIterable<OllamaMessage>> {
+  ): Promise<AsyncIterable<OllamaProviderOutput>> {
     const generator = await this.generate(
       prompt,
       signal,
       true,
     ) as AsyncGenerator<OllamaMessage>;

+    const parser = this.parse;
+
     const stream = async function* () {
       for await (const message of generator) {
-        if ('error' in message) {
+        if ("error" in message) {
           if (message.error instanceof Error) {
             throw message.error;
           } else {
             throw new Error(message.error as string);
           }
         }

-        yield message;
+        yield parser(message);
+
         if (message.done) {
           return;
         }
       }

       throw new Error(
-        'Did not receive done or success response in stream.',
+        "Did not receive done or success response in stream.",
       );
     };

@@ -70,14 +75,28 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta {
   async getText(
     prompt: OllamaProviderInput,
     signal: AbortSignal,
-  ): Promise<OllamaMessage> {
+  ): Promise<OllamaProviderOutput> {
     const response = await this.generate(prompt, signal) as OllamaMessage;

     if (!response?.done) {
-      throw new Error('Expected a completed response.');
+      throw new Error("Expected a completed response.");
     }

-    return response;
+    return this.parse(response);
+  }
+
+  private parse(message: OllamaMessage): OllamaProviderOutput {
+    const { response, prompt_eval_count, eval_count } = message;
+
+    return {
+      value: response,
+      inner: message,
+      usage: {
+        inputTokens: prompt_eval_count,
+        outputTokens: eval_count,
+        totalTokens: prompt_eval_count + eval_count,
+      },
+    };
   }

   private async generate(
@@ -86,11 +105,11 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta {
     stream: boolean = false,
   ) {
     const res = await fetch(
-      new URL('/api/generate', this.options.baseURL),
+      new URL("/api/generate", this.options.baseURL),
       {
-        method: 'POST',
+        method: "POST",
         headers: {
-          'Content-Type': 'application/json',
+          "Content-Type": "application/json",
         },
         body: JSON.stringify({
          model: this.options.model,
@@ -108,7 +127,7 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta {
     }

     if (!res.body) {
-      throw new Error('Missing body');
+      throw new Error("Missing body");
     }

     if (stream) {
```
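One detail worth noting: `getStream` captures `const parser = this.parse;` before the `async function*` expression, because that function gets its own `this` binding; since `parse` only reads its argument, the detached reference is safe. The mapping itself is straightforward, e.g. for a hypothetical final chunk:

```ts
// Hypothetical final /api/generate chunk (fields abridged for illustration):
const message = {
  model: "llama3",
  response: "Hello there!",
  done: true,
  prompt_eval_count: 12,
  eval_count: 8,
};

// parse(message) would produce:
// {
//   value: "Hello there!",
//   inner: message,
//   usage: { inputTokens: 12, outputTokens: 8, totalTokens: 20 },
// }
```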

ext/ai/js/llm/providers/openai.ts

Lines changed: 22 additions & 6 deletions
```diff
@@ -3,6 +3,7 @@ import {
   ILLMProviderInput,
   ILLMProviderMeta,
   ILLMProviderOptions,
+  ILLMProviderOutput,
 } from "../llm_session.ts";
 import { parseJSONOverEventStream } from "../utils/json_parser.ts";

@@ -97,11 +98,11 @@ export type OpenAIResponse = {
 export type OpenAICompatibleInput = Omit<OpenAIRequest, "stream" | "model">;

 export type OpenAIProviderInput = ILLMProviderInput<OpenAICompatibleInput>;
+export type OpenAIProviderOutput = ILLMProviderOutput<OpenAIResponse>;

 export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta {
   input!: OpenAIProviderInput;
-  // TODO:(kallebysantos) add output types
-  output: unknown;
+  output!: OpenAIProviderOutput;
   options: OpenAIProviderOptions;

   constructor(opts: OpenAIProviderOptions) {
@@ -111,13 +112,14 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta {
   async getStream(
     prompt: OpenAIProviderInput,
     signal: AbortSignal,
-  ): Promise<AsyncIterable<OpenAIResponse>> {
+  ): Promise<AsyncIterable<OpenAIProviderOutput>> {
     const generator = await this.generate(
       prompt,
       signal,
       true,
     ) as AsyncGenerator<any>; // TODO:(kallebysantos) remove any

+    const parser = this.parse;
     const stream = async function* () {
       for await (const message of generator) {
         // TODO:(kallebysantos) Simplify duplicated code for stream error checking
@@ -129,7 +131,7 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta {
         }
       }

-      yield message;
+      yield parser(message);
       const finishReason = message.choices[0].finish_reason;

       if (finishReason) {
@@ -152,7 +154,7 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta {
   async getText(
     prompt: OpenAIProviderInput,
     signal: AbortSignal,
-  ): Promise<OpenAIResponse> {
+  ): Promise<OpenAIProviderOutput> {
     const response = await this.generate(
       prompt,
       signal,
@@ -164,9 +166,23 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta {
       throw new Error("Expected a completed response.");
     }

-    return response;
+    return this.parse(response);
   }

+  private parse(message: OpenAIResponse): OpenAIProviderOutput {
+    const { usage } = message;
+
+    return {
+      value: message.choices.at(0)?.message.content ?? undefined,
+      inner: message,
+      usage: {
+        // Usage may be 'null' while streaming, but the final message will include it
+        inputTokens: usage?.prompt_tokens ?? 0,
+        outputTokens: usage?.completion_tokens ?? 0,
+        totalTokens: usage?.total_tokens ?? 0,
+      },
+    };
+  }
   private async generate(
     input: OpenAICompatibleInput,
     signal: AbortSignal,
```
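Because OpenAI-compatible servers only report usage on the final stream message, the normalized `usage` is zero-filled on intermediate chunks (see the comment in `parse` above). A consumer that wants token counts from a stream can simply keep the last chunk; a sketch, assuming `provider`, `input`, and `signal` are in scope:

```ts
// Collect streamed chunks; only the final one carries real token counts.
let last: OpenAIProviderOutput | undefined;
for await (const chunk of await provider.getStream(input, signal)) {
  last = chunk;
}
console.log("total tokens:", last?.usage.totalTokens ?? 0);
```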
