Skip to content

Commit bb23f00

Browse files
committed
frontend/jupyter/llm: generalize cell content context for llm-tool and ai-cell-generator
1 parent 4ebd714 commit bb23f00

File tree

4 files changed

+317
-316
lines changed

4 files changed

+317
-316
lines changed

src/packages/frontend/jupyter/insert-cell/ai-cell-generator.tsx

Lines changed: 72 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import {
77
Dropdown,
88
Flex,
99
Input,
10-
InputNumber,
1110
Popover,
1211
Space,
1312
Switch,
@@ -39,6 +38,7 @@ import LLMSelector, {
3938
} from "@cocalc/frontend/frame-editors/llm/llm-selector";
4039
import { labels } from "@cocalc/frontend/i18n";
4140
import { JupyterActions } from "@cocalc/frontend/jupyter/browser-actions";
41+
import { LLMCellContextSelector } from "@cocalc/frontend/jupyter/llm/cell-context-selector";
4242
import { splitCells } from "@cocalc/frontend/jupyter/llm/split-cells";
4343
import { LLMCostEstimation } from "@cocalc/frontend/misc/llm-cost-estimation";
4444
import { useProjectContext } from "@cocalc/frontend/project/context";
@@ -60,12 +60,13 @@ import {
6060
} from "@cocalc/util/misc";
6161
import { COLORS } from "@cocalc/util/theme";
6262
import NBViewer from "../nbviewer/nbviewer";
63-
import { getPreviousNonemptyCellContents } from "../util/cell-content";
63+
import {
64+
CellContextContent,
65+
getNonemptyCellContents,
66+
} from "../util/cell-content";
6467
import { Position } from "./types";
6568
import { insertCell } from "./util";
6669

67-
type PrevCells = "none" | number | "all above";
68-
6970
type Cell = { cell_type: "markdown" | "code"; source: string[] };
7071
type Cells = Cell[];
7172

@@ -127,8 +128,8 @@ export function AIGenerateCodeCell({
127128
const [model, setModel] = useLanguageModelSetting(project_id);
128129
const [prompt, setPrompt] = useState<string>("");
129130
const [cellTypes, setCellTypes] = useState<"code" | "all">("code");
130-
const [includePreviousCells, setIncludePreviousCells] =
131-
useState<PrevCells>(2);
131+
// Context for the new selector component - default to 2 previous cells, 0 after
132+
const [contextRange, setContextRange] = useState<[number, number]>([-2, 0]);
132133
const [error, setError] = useState<string>();
133134
const [preview, setPreview] = useState<Cells | null>(null);
134135
const [attribute, setAttribute] = useState<boolean>(false);
@@ -141,7 +142,7 @@ export function AIGenerateCodeCell({
141142

142143
const open = showAICellGen != null;
143144

144-
const prevCodeContents = getPrevCodeContents();
145+
const contextContent = getContextContents();
145146

146147
const inputPrompt = getInput({
147148
frameActions,
@@ -150,7 +151,8 @@ export function AIGenerateCodeCell({
150151
kernel_name,
151152
position: showAICellGen,
152153
model,
153-
prevCodeContents,
154+
contextContent,
155+
contextRange,
154156
});
155157

156158
const { input } = inputPrompt;
@@ -193,16 +195,22 @@ export function AIGenerateCodeCell({
193195
}
194196
}, [preview, open]);
195197

196-
function getPrevCodeContents(): string {
197-
if (includePreviousCells === 0 || showAICellGen == null) return "";
198-
return getPreviousNonemptyCellContents(
199-
frameActions.current,
198+
function getContextContents(): CellContextContent {
199+
const prevCount = -contextRange[0]; // contextRange[0] is negative, so -(-2) = 2
200+
const nextCount = contextRange[1]; // contextRange[1] is positive for cells after
201+
202+
if (prevCount === 0 && nextCount === 0) return {};
203+
204+
return getNonemptyCellContents({
205+
actions: frameActions.current,
200206
id,
201-
showAICellGen,
202-
includePreviousCells,
207+
direction: "around",
208+
cellCount: "all", // Use "all" for around direction
203209
cellTypes,
204210
lang,
205-
);
211+
aboveCount: prevCount,
212+
belowCount: nextCount,
213+
});
206214
}
207215

208216
function insertCells() {
@@ -273,8 +281,9 @@ export function AIGenerateCodeCell({
273281
}
274282

275283
async function queryLanguageModel({
276-
prevCodeContents,
277-
includePreviousCells,
284+
contextContent,
285+
}: {
286+
contextContent: CellContextContent;
278287
}) {
279288
if (!prompt.trim()) return;
280289

@@ -285,7 +294,8 @@ export function AIGenerateCodeCell({
285294
model,
286295
position: showAICellGen,
287296
prompt,
288-
prevCodeContents,
297+
contextContent,
298+
contextRange,
289299
});
290300

291301
if (!input) {
@@ -300,7 +310,7 @@ export function AIGenerateCodeCell({
300310
tag,
301311
type: "generate",
302312
model,
303-
prev: includePreviousCells,
313+
contextRange,
304314
});
305315

306316
const stream = await webapp_client.openai_client.queryStream({
@@ -364,16 +374,15 @@ export function AIGenerateCodeCell({
364374
}
365375
}
366376

367-
function doQuery(prevCodeContents: string) {
377+
function doQuery(contextContent: CellContextContent) {
368378
cancel.current = false;
369379
setError("");
370380
setQuerying(true);
371381

372382
if (showAICellGen == null) return;
373383

374384
queryLanguageModel({
375-
prevCodeContents,
376-
includePreviousCells,
385+
contextContent,
377386
});
378387

379388
// we also log this
@@ -427,74 +436,20 @@ export function AIGenerateCodeCell({
427436
}
428437

429438
function renderContext() {
430-
const cellStr = `${cellTypes === "code" ? "code " : ""} cell`;
431439
return (
432440
<>
433441
<Divider orientation="left">
434442
<Text>Context</Text>
435443
</Divider>
436-
<Paragraph>
437-
<Flex dir="horizontal" gap="10px" align="center" justify="center">
438-
<Flex flex={1}>
439-
<div>
440-
Include{" "}
441-
{typeof includePreviousCells === "number" ? (
442-
<>
443-
previous{" "}
444-
<InputNumber
445-
min={0}
446-
max={10}
447-
size={"small"}
448-
value={includePreviousCells}
449-
onChange={(value) => setIncludePreviousCells(value ?? 1)}
450-
/>{" "}
451-
{plural(
452-
includePreviousCells,
453-
`${cellStr}.`,
454-
`${cellStr}s.`,
455-
)}
456-
</>
457-
) : includePreviousCells === "all above" ? (
458-
`all previous ${cellStr}s`
459-
) : (
460-
`no ${cellStr}s`
461-
)}
462-
</div>
463-
</Flex>
464-
<Flex flex={0}>
465-
{["none", 1, 2, 3, 5, 10, "all above"].map((i: PrevCells) => {
466-
const c = getRandomColor(`${i}`);
467-
return (
468-
<Tag
469-
key={i}
470-
color={c}
471-
style={{ cursor: "pointer" }}
472-
onClick={() => setIncludePreviousCells(i)}
473-
>
474-
{i}
475-
</Tag>
476-
);
477-
})}
478-
</Flex>
479-
</Flex>
480-
</Paragraph>
481-
<Paragraph>
482-
<Flex align="center" gap="10px">
483-
<Flex flex={0}>
484-
<Switch
485-
defaultChecked={cellTypes === "all"}
486-
onChange={(val) => setCellTypes(val ? "all" : "code")}
487-
unCheckedChildren={"Code cells"}
488-
checkedChildren={"All Cells"}
489-
/>
490-
</Flex>
491-
<Flex flex={1}>
492-
<Text type="secondary">
493-
Include only code cells, or all types of cells.
494-
</Text>
495-
</Flex>
496-
</Flex>
497-
</Paragraph>
444+
<LLMCellContextSelector
445+
contextRange={contextRange}
446+
onContextRangeChange={setContextRange}
447+
cellTypes={cellTypes}
448+
onCellTypesChange={setCellTypes}
449+
currentCellId={id}
450+
frameActions={frameActions.current}
451+
mode="insert-position"
452+
/>
498453
</>
499454
);
500455
}
@@ -643,7 +598,7 @@ export function AIGenerateCodeCell({
643598
if (!e.shiftKey) return;
644599
e.preventDefault(); // prevent the default action
645600
e.stopPropagation(); // stop event propagation
646-
doQuery(prevCodeContents);
601+
doQuery(contextContent);
647602
}}
648603
autoSize={{ minRows: 2, maxRows: 6 }}
649604
/>
@@ -657,7 +612,7 @@ export function AIGenerateCodeCell({
657612
<LLMQueryDropdownButton
658613
disabled={!prompt.trim()}
659614
loading={querying}
660-
onClick={() => doQuery(prevCodeContents)}
615+
onClick={() => doQuery(contextContent)}
661616
llmTools={llmTools}
662617
task="Generate using"
663618
/>
@@ -724,9 +679,10 @@ interface GetInputProps {
724679
model: LanguageModel;
725680
position: Position;
726681
prompt: string;
727-
prevCodeContents: string;
682+
contextContent: CellContextContent;
728683
lang: string;
729684
kernel_name: string;
685+
contextRange: [number, number];
730686
}
731687

732688
function getInputPrompt(prompt: string): string {
@@ -736,9 +692,10 @@ function getInputPrompt(prompt: string): string {
736692
function getInput({
737693
frameActions,
738694
prompt,
739-
prevCodeContents,
695+
contextContent,
740696
lang,
741697
kernel_name,
698+
contextRange,
742699
}: GetInputProps): {
743700
input: string;
744701
system: string;
@@ -753,9 +710,30 @@ function getInput({
753710
);
754711
return { input: "", system: "", history: [] };
755712
}
756-
const prevCode = prevCodeContents
757-
? `The context after which to insert the cells is:\n\n<context>\n${prevCodeContents}\n\</context>\n\n`
758-
: "";
713+
714+
const prevCount = -contextRange[0]; // cells before insertion point
715+
const afterCount = contextRange[1]; // cells after insertion point
716+
717+
let contextInfo = "";
718+
719+
if (contextContent.before || contextContent.after) {
720+
const beforeCells =
721+
prevCount > 0 ? `${prevCount} cells before` : "no cells before";
722+
const afterCells =
723+
afterCount > 0 ? `${afterCount} cells after` : "no cells after";
724+
contextInfo = `Context: The new cell will be inserted with ${beforeCells} and ${afterCells} the insertion point.\n\n`;
725+
726+
if (contextContent.before) {
727+
contextInfo += `Cells BEFORE insertion point:\n<before>\n${contextContent.before}\n</before>\n\n`;
728+
}
729+
730+
if (contextContent.after) {
731+
contextInfo += `Cells AFTER insertion point:\n<after>\n${contextContent.after}\n</after>\n\n`;
732+
}
733+
} else {
734+
contextInfo =
735+
"Context: The new cell will be inserted at the beginning or end of the notebook.\n\n";
736+
}
759737

760738
const history: Message[] = [
761739
{ role: "user", content: getInputPrompt("Show the value of foo.") },
@@ -766,8 +744,8 @@ function getInput({
766744
];
767745

768746
return {
769-
input: `${prevCode}${getInputPrompt(prompt)}`,
747+
input: `${contextInfo}${getInputPrompt(prompt)}`,
770748
history,
771-
system: `Create one or more code cells in a Jupyter Notebook.\n\nKernel: "${kernel_name}".\n\nProgramming language: "${lang}".\n\nEach code cell must be wrapped in triple backticks. Do not say what the output will be. Be brief.`,
749+
system: `Create one or more code cells in a Jupyter Notebook.\n\nKernel: "${kernel_name}".\n\nProgramming language: "${lang}".\n\nThe new cell(s) will be inserted at a specific position in the notebook. Pay attention to the context provided - cells marked as BEFORE come before the insertion point, cells marked as AFTER come after the insertion point.\n\nEach code cell must be wrapped in triple backticks. Do not say what the output will be. Be brief.`,
772750
};
773751
}

0 commit comments

Comments
 (0)