|
5 | 5 | import io |
6 | 6 | import json |
7 | 7 | import os |
| 8 | + |
| 9 | +from rich.console import Console |
| 10 | +from rich.live import Live |
| 11 | +from rich.markdown import Markdown |
| 12 | + |
8 | 13 | from llm import ( |
9 | 14 | Attachment, |
10 | 15 | AsyncConversation, |
|
41 | 46 | remove_alias, |
42 | 47 | ) |
43 | 48 | from llm.models import _BaseConversation, ChainResponse |
44 | | -from .cli_utils import prompt_output |
45 | 49 |
|
46 | 50 | from .migrations import migrate |
47 | 51 | from .plugins import pm, load_plugins |
@@ -833,41 +837,76 @@ def read_prompt(): |
833 | 837 | # Merge in options for the .prompt() methods |
834 | 838 | kwargs.update(validated_options) |
835 | 839 |
|
| 840 | + console = Console() |
| 841 | + |
836 | 842 | try: |
837 | 843 | if async_: |
838 | 844 |
|
839 | 845 | async def inner(): |
840 | | - return prompt_output( |
841 | | - prompt_method, |
| 846 | + response = prompt_method( |
842 | 847 | prompt, |
843 | | - should_stream, |
844 | | - render, |
845 | | - extract, |
846 | | - extract_last, |
847 | | - resolved_fragments, |
848 | | - resolved_attachments, |
849 | | - system, |
850 | | - schema, |
851 | | - resolved_system_fragments, |
852 | | - kwargs, |
| 848 | + attachments=resolved_attachments, |
| 849 | + system=system, |
| 850 | + schema=schema, |
| 851 | + fragments=resolved_fragments, |
| 852 | + system_fragments=resolved_system_fragments, |
| 853 | + **kwargs, |
853 | 854 | ) |
854 | 855 |
|
| 856 | + if should_stream: |
| 857 | + accumulated_text = "" |
| 858 | + with Live(accumulated_text, console=console, refresh_per_second=10) as live: |
| 859 | + async for chunk in response: |
| 860 | + accumulated_text += chunk |
| 861 | + |
| 862 | + if render: |
| 863 | + display_content = Markdown(accumulated_text) |
| 864 | + else: |
| 865 | + display_content = accumulated_text |
| 866 | + |
| 867 | + live.update(display_content) |
| 868 | + else: |
| 869 | + text = await response.text() |
| 870 | + if extract or extract_last: |
| 871 | + text = extract_fenced_code_block(text, last=extract_last) or text |
| 872 | + if render: |
| 873 | + text = Markdown(text) |
| 874 | + console.print(text) |
| 875 | + |
| 876 | + |
855 | 877 | response = asyncio.run(inner()) |
856 | 878 | else: |
857 | | - response = prompt_output( |
858 | | - prompt_method, |
| 879 | + response = prompt_method( |
859 | 880 | prompt, |
860 | | - should_stream, |
861 | | - render, |
862 | | - extract, |
863 | | - extract_last, |
864 | | - resolved_fragments, |
865 | | - resolved_attachments, |
866 | | - system, |
867 | | - schema, |
868 | | - resolved_system_fragments, |
869 | | - kwargs, |
| 881 | + fragments=resolved_fragments, |
| 882 | + attachments=resolved_attachments, |
| 883 | + system=system, |
| 884 | + schema=schema, |
| 885 | + system_fragments=resolved_system_fragments, |
| 886 | + **kwargs, |
870 | 887 | ) |
| 888 | + |
| 889 | + |
| 890 | + |
| 891 | + if should_stream: |
| 892 | + accumulated_text = "" |
| 893 | + with Live(accumulated_text, console=console, refresh_per_second=10) as live: |
| 894 | + for chunk in response: |
| 895 | + accumulated_text += chunk |
| 896 | + |
| 897 | + if render: |
| 898 | + display_content = Markdown(accumulated_text) |
| 899 | + else: |
| 900 | + display_content = accumulated_text |
| 901 | + |
| 902 | + live.update(display_content) |
| 903 | + else: |
| 904 | + text = response.text() |
| 905 | + if extract or extract_last: |
| 906 | + text = extract_fenced_code_block(text, last=extract_last) or text |
| 907 | + if render: |
| 908 | + text = Markdown(text) |
| 909 | + console.print(text) |
871 | 910 | # List of exceptions that should never be raised in pytest: |
872 | 911 | except (ValueError, NotImplementedError) as ex: |
873 | 912 | raise click.ClickException(str(ex)) |
|
0 commit comments