diff --git a/llm/cli.py b/llm/cli.py index 2e11e2c8..b10c27c5 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -922,6 +922,189 @@ async def inner():
         # At this point ALL forms should have a log_to_db() method that works:
         response.log_to_db(db)
+
+
+@cli.command(name="prompt_multi")
+@click.argument("prompt_file", required=True, type=click.Path(exists=True))
+@click.option(
+    "-c",
+    "--prompt-column",
+    default="prompt",
+    help="Column in the prompt file containing the prompt text, defaults to 'prompt'",
+)
+@click.option(
+    "-r",
+    "--result-file",
+    help="File to write results to, supports CSV and Parquet, defaults to stdout",
+)
+@click.option("-s", "--system", help="System prompt to use")
+@click.option("model_id", "-m", "--model", help="Model to use", envvar="LLM_MODEL")
+@click.option(
+    "-d",
+    "--database",
+    type=click.Path(readable=True, dir_okay=False),
+    help="Path to log database",
+)
+@click.option(
+    "queries",
+    "-q",
+    "--query",
+    multiple=True,
+    help="Use first model matching these strings",
+)
+@click.option(
+    "options",
+    "-o",
+    "--option",
+    type=(str, str),
+    multiple=True,
+    help="key/value options for the model",
+)
+@schema_option
+@click.option(
+    "--schema-multi",
+    help="JSON schema to use for multiple results",
+)
+@click.option("-n", "--no-log", is_flag=True, help="Don't log to database")
+@click.option("--log", is_flag=True, help="Log prompt and response to the database")
+@click.option("--key", help="API key to use")
+def prompt_multi(
+    prompt_file,
+    prompt_column,
+    result_file,
+    system,
+    model_id,
+    database,
+    queries,
+    options,
+    schema_input,
+    schema_multi,
+    no_log,
+    log,
+    key,
+):
+    """
+    Execute multiple prompts from prompt_file.
+
+    prompt_file supports CSV and Parquet.
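+
+    Example (illustrative file and model names):
+
+    \b
+        llm prompt_multi prompts.csv -m gpt-4o-mini -r results.csv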
+    """
+    try:
+        import bodo.pandas as pd
+
+        # Silence warnings emitted by bodo.pandas
+        warnings.filterwarnings("ignore", module="bodo.pandas")
+    except ImportError:
+        try:
+            import pandas as pd
+
+            warnings.warn(
+                "For better performance with large files, install bodo: pip install bodo"
+            )
+        except ImportError:
+            raise click.ClickException(
+                "This command requires a Pandas-compatible dataframe library "
+                "such as pandas or bodo to be installed: pip install pandas"
+            )
+
+    if log and no_log:
+        raise click.ClickException("--log and --no-log are mutually exclusive")
+
+    log_path = pathlib.Path(database) if database else logs_db_path()
+    (log_path.parent).mkdir(parents=True, exist_ok=True)
+    db = sqlite_utils.Database(log_path)
+    migrate(db)
+
+    if queries and not model_id:
+        # Use -q options to find the model with the shortest matching model_id
+        matches = []
+        for model_with_aliases in get_models_with_aliases():
+            if all(model_with_aliases.matches(q) for q in queries):
+                matches.append(model_with_aliases.model.model_id)
+        if not matches:
+            raise click.ClickException(
+                "No model found matching queries {}".format(", ".join(queries))
+            )
+        model_id = min(matches, key=len)
+
+    if schema_multi:
+        schema_input = schema_multi
+
+    schema = resolve_schema_input(db, schema_input, load_template)
+
+    if schema_multi:
+        # Convert that schema into multiple "items" of the same schema
+        schema = multi_schema(schema)
+
+    model_id = model_id or get_default_model()
+    async_model = False
+    # Get the model, preferring async if available
+    try:
+        model = get_async_model(model_id)
+        async_model = True
+    except UnknownModelError:
+        try:
+            model = get_model(model_id)
+        except UnknownModelError as ex:
+            raise click.ClickException(str(ex))
+
+    # Validate options with pydantic
+    validated_options = {}
+    if options:
+        try:
+            validated_options = dict(
+                (key, value)
+                for key, value in model.Options(**dict(options))
+                if value is not None
+            )
+        except pydantic.ValidationError as ex:
+            raise click.ClickException(render_errors(ex.errors()))
+
+    # Add on any default model options
+    default_options = get_model_options(model.model_id)
+    for key_, value in default_options.items():
+        if key_ not in validated_options:
+            validated_options[key_] = value
+
+    # Start from the validated options so they are passed to each prompt() call
+    kwargs = dict(validated_options)
+    # We don't care about streaming for multi prompts
+    kwargs["stream"] = False
+    if isinstance(model, (KeyModel, AsyncKeyModel)):
+        kwargs["key"] = key
+
+    # Validate file extensions
+    if not prompt_file.endswith((".csv", ".parquet", ".pq")):
+        raise click.ClickException("Prompt file must be a CSV or Parquet file")
+    if result_file and not result_file.endswith((".csv", ".parquet", ".pq")):
+        raise click.ClickException("Output file must be a CSV or Parquet file")
+    prompts_df = (
+        pd.read_csv(prompt_file)
+        if prompt_file.endswith(".csv")
+        else pd.read_parquet(prompt_file)
+    )
+    if prompt_column not in prompts_df.columns and len(prompts_df.columns) != 1:
+        raise click.ClickException(
+            f"Prompt column '{prompt_column}' not found in prompt file"
+        )
+    # A single-column file is used as-is, otherwise pick the named prompt column
+    prompts = (
+        prompts_df[prompts_df.columns[0]]
+        if len(prompts_df.columns) == 1
+        else prompts_df[prompt_column]
+    )
+
+    def process_row_prompt(p):
+        response = model.prompt(system=system, prompt=p, schema=schema, **kwargs)
+        if async_model:
+            # Resolve the AsyncResponse up front so that .text() and
+            # .log_to_db() below work whether or not logging is enabled
+            response = asyncio.run(response.to_sync_response())
+        # Log responses to the database
+        if (logs_on() or log) and not no_log:
+            db = sqlite_utils.Database(log_path)
+            # Could be Response or ChainResponse - both have log_to_db():
+            response.log_to_db(db)
+        return response.text()
+
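+    # Note: with bodo.pandas, map() may run process_row_prompt in parallel
+    # across rows (bodo's dataframe library distributes the work); plain
+    # pandas applies it serially, one row at a time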
+    responses = prompts.map(process_row_prompt)
+
+    results = pd.DataFrame({"prompt": prompts, "response": responses})
+
+    if result_file:
+        if result_file.endswith(".csv"):
+            results.to_csv(result_file, index=False)
+        else:
+            results.to_parquet(result_file, index=False)
+    else:
+        click.echo(results.to_string(index=False))
+
+
 @cli.command()
 @click.option("-s", "--system", help="System prompt to use")
diff --git a/pyproject.toml b/pyproject.toml index a78b7c53..205c8c98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ test = [
     "types-PyYAML",
     "types-setuptools",
     "llm-echo==0.3a3",
+    "bodo",
 ]
 
 [build-system]
diff --git a/tests/test_llm.py b/tests/test_llm.py index e4a1c23c..0c67823a 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -9,6 +9,8 @@ import pytest
 import sqlite_utils
 from unittest import mock
 
+import bodo
+import bodo.pandas as pd
 
 
 def test_version():
@@ -851,3 +853,42 @@ def test_llm_prompt_continue_with_database(
     assert (user_path / "logs.db").exists()
     db_path = str(user_path / "logs.db")
     assert sqlite_utils.Database(db_path)["responses"].count == 2
+
+
+@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
+def test_llm_prompt_multi(tmpdir, async_mock_model):
+    # Disable bodo's parallel processing for this test so the mock
+    # model's queued responses are consumed in order
+    bodo.dataframe_library_run_parallel = False
+
+    test_csv_path = str(pathlib.Path(tmpdir, "test.csv"))
+    result_path = str(pathlib.Path(tmpdir, "results.pq"))
+    prompts = pd.DataFrame(
+        {
+            "promptx": ["test1", "test2", "test3"],
+        }
+    )
+    prompts.to_csv(test_csv_path, index=False)
+    # bodo runs the mapped function once on a sample row to infer the
+    # output type, so an extra response is needed for that first call
+    async_mock_model.enqueue(["test1_resp"])
+
+    async_mock_model.enqueue(["test1_resp"])
+    async_mock_model.enqueue(["test2_resp"])
+    async_mock_model.enqueue(["test3_resp"])
+
+    runner = CliRunner()
+    result = runner.invoke(
+        cli,
+        [
+            "prompt_multi",
+            test_csv_path,
+            "-m",
+            "mock",
+            "-s",
+            "You are a helpful assistant.",
+            "-c",
+            "promptx",
+            "-r",
+            result_path,
+        ],
+        catch_exceptions=False,
+    )
+    assert result.exit_code == 0
+
+    results = pd.read_parquet(result_path)
+    assert list(results["prompt"]) == ["test1", "test2", "test3"]
+    assert list(results["response"]) == ["test1_resp", "test2_resp", "test3_resp"]