diff --git a/.changeset/custom-shiki-themes.md b/.changeset/custom-shiki-themes.md new file mode 100644 index 00000000..971a4624 --- /dev/null +++ b/.changeset/custom-shiki-themes.md @@ -0,0 +1,6 @@ +--- +"streamdown": minor +"streamdown-code": minor +--- + +Support custom Shiki themes in code blocks diff --git a/packages/streamdown-code/__tests__/index.test.ts b/packages/streamdown-code/__tests__/index.test.ts index b786d5c8..e5dff7dc 100644 --- a/packages/streamdown-code/__tests__/index.test.ts +++ b/packages/streamdown-code/__tests__/index.test.ts @@ -1,4 +1,4 @@ -import type { BundledLanguage } from "shiki"; +import type { BundledLanguage, ThemeRegistrationAny } from "shiki"; import { describe, expect, it, vi } from "vitest"; import { code, createCodePlugin } from "../index"; @@ -303,4 +303,87 @@ describe("createCodePlugin", () => { expect(typeof plugin.getSupportedLanguages).toBe("function"); expect(typeof plugin.getThemes).toBe("function"); }); + + it("should create plugin with custom theme objects", () => { + const customLight: ThemeRegistrationAny = { + name: "my-light-theme", + type: "light", + colors: { "editor.background": "#ffffff" }, + tokenColors: [], + }; + const customDark: ThemeRegistrationAny = { + name: "my-dark-theme", + type: "dark", + colors: { "editor.background": "#1e1e1e" }, + tokenColors: [], + }; + const plugin = createCodePlugin({ + themes: [customLight, customDark], + }); + const themes = plugin.getThemes(); + expect(themes[0]).toBe(customLight); + expect(themes[1]).toBe(customDark); + }); + + it("should create plugin with mixed built-in and custom themes", () => { + const customDark: ThemeRegistrationAny = { + name: "my-dark-theme", + type: "dark", + colors: { "editor.background": "#1e1e1e" }, + tokenColors: [], + }; + const plugin = createCodePlugin({ + themes: ["github-light", customDark], + }); + const themes = plugin.getThemes(); + expect(themes[0]).toBe("github-light"); + expect(themes[1]).toBe(customDark); + }); + + it("should highlight code with custom theme objects", async () => { + const customLight: ThemeRegistrationAny = { + name: "custom-light", + type: "light", + colors: { + "editor.background": "#ffffff", + "editor.foreground": "#000000", + }, + tokenColors: [], + }; + const customDark: ThemeRegistrationAny = { + name: "custom-dark", + type: "dark", + colors: { + "editor.background": "#1e1e1e", + "editor.foreground": "#d4d4d4", + }, + tokenColors: [], + }; + const plugin = createCodePlugin({ + themes: [customLight, customDark], + }); + + const callback = vi.fn(); + const result = plugin.highlight( + { + code: "const x = 1;", + language: "javascript", + themes: [customLight, customDark], + }, + callback + ); + + expect(result).toBeNull(); + + await vi.waitFor( + () => { + expect(callback).toHaveBeenCalled(); + }, + { timeout: 5000 } + ); + + const highlightResult = callback.mock.calls[0][0]; + expect(highlightResult.tokens).toBeDefined(); + expect(Array.isArray(highlightResult.tokens)).toBe(true); + }); }); diff --git a/packages/streamdown-code/index.ts b/packages/streamdown-code/index.ts index a459e283..78d7a8c9 100644 --- a/packages/streamdown-code/index.ts +++ b/packages/streamdown-code/index.ts @@ -3,6 +3,7 @@ import { type BundledLanguage, type BundledTheme, + type ThemeRegistrationAny, bundledLanguages, bundledLanguagesInfo, createHighlighter, @@ -14,6 +15,8 @@ import { createJavaScriptRegexEngine } from "shiki/engine/javascript"; const jsEngine = createJavaScriptRegexEngine({ forgiving: true }); +export type ThemeInput = BundledTheme | ThemeRegistrationAny; + /** * Result from code highlighting */ @@ -25,7 +28,7 @@ export type HighlightResult = TokensResult; export interface HighlightOptions { code: string; language: BundledLanguage; - themes: [string, string]; + themes: [ThemeInput, ThemeInput]; } /** @@ -39,7 +42,7 @@ export interface CodeHighlighterPlugin { /** * Get the configured themes */ - getThemes: () => [BundledTheme, BundledTheme]; + getThemes: () => [ThemeInput, ThemeInput]; /** * Highlight code and return tokens * Returns null if highlighting not ready yet (async loading) @@ -65,7 +68,7 @@ export interface CodePluginOptions { * Default themes for syntax highlighting [light, dark] * @default ["github-light", "github-dark"] */ - themes?: [BundledTheme, BundledTheme]; + themes?: [ThemeInput, ThemeInput]; } const languageAliases = Object.fromEntries( @@ -104,10 +107,13 @@ const tokensCache = new Map(); // Subscribers for async token updates const subscribers = new Map void>>(); +const getThemeName = (theme: ThemeInput): string => + typeof theme === "string" ? theme : (theme.name ?? "custom"); + const getHighlighterCacheKey = ( language: BundledLanguage, - themeNames: [string, string] -) => `${language}-${themeNames[0]}-${themeNames[1]}`; + themes: [ThemeInput, ThemeInput] +) => `${language}-${getThemeName(themes[0])}-${getThemeName(themes[1])}`; const getTokensCacheKey = ( code: string, @@ -121,9 +127,9 @@ const getTokensCacheKey = ( const getHighlighter = ( language: BundledLanguage, - themeNames: [string, string] + themes: [ThemeInput, ThemeInput] ): Promise> => { - const cacheKey = getHighlighterCacheKey(language, themeNames); + const cacheKey = getHighlighterCacheKey(language, themes); if (highlighterCache.has(cacheKey)) { return highlighterCache.get(cacheKey) as Promise< @@ -132,7 +138,7 @@ const getHighlighter = ( } const highlighterPromise = createHighlighter({ - themes: themeNames, + themes, langs: [language], engine: jsEngine, }); @@ -147,7 +153,7 @@ const getHighlighter = ( export function createCodePlugin( options: CodePluginOptions = {} ): CodeHighlighterPlugin { - const defaultThemes: [BundledTheme, BundledTheme] = options.themes ?? [ + const defaultThemes: [ThemeInput, ThemeInput] = options.themes ?? [ "github-light", "github-dark", ]; @@ -165,19 +171,23 @@ export function createCodePlugin( return Array.from(languageNames); }, - getThemes(): [BundledTheme, BundledTheme] { + getThemes(): [ThemeInput, ThemeInput] { return defaultThemes; }, highlight( - { code, language, themes: themeNames }: HighlightOptions, + { code, language, themes }: HighlightOptions, callback?: (result: HighlightResult) => void ): HighlightResult | null { const resolvedLanguage = normalizeLanguage(language); + const themeNames: [string, string] = [ + getThemeName(themes[0]), + getThemeName(themes[1]), + ]; const tokensCacheKey = getTokensCacheKey( code, resolvedLanguage, - themeNames as [string, string] + themeNames ); // Return cached result if available @@ -197,10 +207,7 @@ export function createCodePlugin( } // Start highlighting in background - getHighlighter( - resolvedLanguage as BundledLanguage, - themeNames as [string, string] - ) + getHighlighter(resolvedLanguage as BundledLanguage, themes) .then((highlighter) => { const availableLangs = highlighter.getLoadedLanguages(); const langToUse = ( diff --git a/packages/streamdown/index.tsx b/packages/streamdown/index.tsx index eaebfd4c..26e5bf44 100644 --- a/packages/streamdown/index.tsx +++ b/packages/streamdown/index.tsx @@ -17,7 +17,6 @@ import rehypeRaw from "rehype-raw"; import rehypeSanitize, { defaultSchema } from "rehype-sanitize"; import remarkGfm from "remark-gfm"; import remend, { type RemendOptions } from "remend"; -import type { BundledTheme } from "shiki"; import type { Pluggable } from "unified"; import { type AnimateOptions, createAnimatePlugin } from "./lib/animate"; import { BlockIncompleteContext } from "./lib/block-incomplete-context"; @@ -26,7 +25,7 @@ import { hasIncompleteCodeFence, hasTable } from "./lib/incomplete-code-utils"; import { Markdown, type Options } from "./lib/markdown"; import { parseMarkdownIntoBlocks } from "./lib/parse-blocks"; import { PluginContext } from "./lib/plugin-context"; -import type { PluginConfig } from "./lib/plugin-types"; +import type { PluginConfig, ThemeInput } from "./lib/plugin-types"; import { PrefixContext } from "./lib/prefix-context"; import { preprocessCustomTags } from "./lib/preprocess-custom-tags"; import { createCn } from "./lib/utils"; @@ -51,7 +50,9 @@ export type { HighlightOptions, MathPlugin, PluginConfig, + ThemeInput, } from "./lib/plugin-types"; +export type { ThemeRegistrationAny } from "shiki"; export { TableCopyDropdown, type TableCopyDropdownProps, @@ -148,7 +149,7 @@ export type StreamdownProps = Options & { /** Normalize HTML block indentation to prevent 4+ spaces being treated as code blocks. @default false */ normalizeHtmlIndentation?: boolean; className?: string; - shikiTheme?: [BundledTheme, BundledTheme]; + shikiTheme?: [ThemeInput, ThemeInput]; mermaid?: MermaidOptions; controls?: ControlsConfig; isAnimating?: boolean; @@ -210,7 +211,7 @@ export interface StreamdownContextType { linkSafety?: LinkSafetyConfig; mermaid?: MermaidOptions; mode: "static" | "streaming"; - shikiTheme: [BundledTheme, BundledTheme]; + shikiTheme: [ThemeInput, ThemeInput]; } const defaultStreamdownContext: StreamdownContextType = { @@ -311,7 +312,7 @@ export const Block = memo( Block.displayName = "Block"; -const defaultShikiTheme: [BundledTheme, BundledTheme] = [ +const defaultShikiTheme: [ThemeInput, ThemeInput] = [ "github-light", "github-dark", ]; diff --git a/packages/streamdown/lib/plugin-types.ts b/packages/streamdown/lib/plugin-types.ts index e4a293f7..ea124186 100644 --- a/packages/streamdown/lib/plugin-types.ts +++ b/packages/streamdown/lib/plugin-types.ts @@ -1,7 +1,13 @@ import type { MermaidConfig } from "mermaid"; -import type { BundledLanguage, BundledTheme } from "shiki"; +import type { + BundledLanguage, + BundledTheme, + ThemeRegistrationAny, +} from "shiki"; import type { Pluggable } from "unified"; +export type ThemeInput = BundledTheme | ThemeRegistrationAny; + /** * A single token in a highlighted line */ @@ -30,7 +36,7 @@ export interface HighlightResult { export interface HighlightOptions { code: string; language: BundledLanguage; - themes: [string, string]; + themes: [ThemeInput, ThemeInput]; } /** @@ -44,7 +50,7 @@ export interface CodeHighlighterPlugin { /** * Get the configured themes */ - getThemes: () => [BundledTheme, BundledTheme]; + getThemes: () => [ThemeInput, ThemeInput]; /** * Highlight code and return tokens * Returns null if highlighting not ready yet (async loading)