Skip to content

Commit 4dddad3

Browse files
committed
fix
1 parent 6a830c0 commit 4dddad3

File tree

2 files changed

+64
-81
lines changed

2 files changed

+64
-81
lines changed

llm/cli.py

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
import io
66
import json
77
import os
8+
9+
from rich.console import Console
10+
from rich.live import Live
11+
from rich.markdown import Markdown
12+
813
from llm import (
914
Attachment,
1015
AsyncConversation,
@@ -41,7 +46,6 @@
4146
remove_alias,
4247
)
4348
from llm.models import _BaseConversation, ChainResponse
44-
from .cli_utils import prompt_output
4549

4650
from .migrations import migrate
4751
from .plugins import pm, load_plugins
@@ -833,41 +837,76 @@ def read_prompt():
833837
# Merge in options for the .prompt() methods
834838
kwargs.update(validated_options)
835839

840+
console = Console()
841+
836842
try:
837843
if async_:
838844

839845
async def inner():
840-
return prompt_output(
841-
prompt_method,
846+
response = prompt_method(
842847
prompt,
843-
should_stream,
844-
render,
845-
extract,
846-
extract_last,
847-
resolved_fragments,
848-
resolved_attachments,
849-
system,
850-
schema,
851-
resolved_system_fragments,
852-
kwargs,
848+
attachments=resolved_attachments,
849+
system=system,
850+
schema=schema,
851+
fragments=resolved_fragments,
852+
system_fragments=resolved_system_fragments,
853+
**kwargs,
853854
)
854855

856+
if should_stream:
857+
accumulated_text = ""
858+
with Live(accumulated_text, console=console, refresh_per_second=10) as live:
859+
async for chunk in response:
860+
accumulated_text += chunk
861+
862+
if render:
863+
display_content = Markdown(accumulated_text)
864+
else:
865+
display_content = accumulated_text
866+
867+
live.update(display_content)
868+
else:
869+
text = await response.text()
870+
if extract or extract_last:
871+
text = extract_fenced_code_block(text, last=extract_last) or text
872+
if render:
873+
text = Markdown(text)
874+
console.print(text)
875+
876+
855877
response = asyncio.run(inner())
856878
else:
857-
response = prompt_output(
858-
prompt_method,
879+
response = prompt_method(
859880
prompt,
860-
should_stream,
861-
render,
862-
extract,
863-
extract_last,
864-
resolved_fragments,
865-
resolved_attachments,
866-
system,
867-
schema,
868-
resolved_system_fragments,
869-
kwargs,
881+
fragments=resolved_fragments,
882+
attachments=resolved_attachments,
883+
system=system,
884+
schema=schema,
885+
system_fragments=resolved_system_fragments,
886+
**kwargs,
870887
)
888+
889+
890+
891+
if should_stream:
892+
accumulated_text = ""
893+
with Live(accumulated_text, console=console, refresh_per_second=10) as live:
894+
for chunk in response:
895+
accumulated_text += chunk
896+
897+
if render:
898+
display_content = Markdown(accumulated_text)
899+
else:
900+
display_content = accumulated_text
901+
902+
live.update(display_content)
903+
else:
904+
text = response.text()
905+
if extract or extract_last:
906+
text = extract_fenced_code_block(text, last=extract_last) or text
907+
if render:
908+
text = Markdown(text)
909+
console.print(text)
871910
# List of exceptions that should never be raised in pytest:
872911
except (ValueError, NotImplementedError) as ex:
873912
raise click.ClickException(str(ex))

llm/cli_utils.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

0 commit comments

Comments
 (0)