183 changes: 183 additions & 0 deletions llm/cli.py
@@ -922,6 +922,189 @@ async def inner():
# At this point ALL forms should have a log_to_db() method that works:
response.log_to_db(db)

@cli.command(name="prompt_multi")
@click.argument("prompt_file", required=True, type=click.Path(exists=True))
@click.option("-c", "--prompt-column", default="prompt", help="Column in the prompt file containing the prompt text, defaults to 'prompt'")
@click.option("-r", "--result-file", help="File to write results to, supports CSV and Parquet, defaults to stdout")
@click.option("-s", "--system", help="System prompt to use")
@click.option("model_id", "-m", "--model", help="Model to use", envvar="LLM_MODEL")
@click.option(
"-d",
"--database",
type=click.Path(readable=True, dir_okay=False),
help="Path to log database",
)
@click.option(
"queries",
"-q",
"--query",
multiple=True,
help="Use first model matching these strings",
)
@click.option(
"options",
"-o",
"--option",
type=(str, str),
multiple=True,
help="key/value options for the model",
)
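# Example: -o temperature 0.5 (option names vary by model; values are
# validated against the model's Options class further down)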
@schema_option
@click.option(
"--schema-multi",
help="JSON schema to use for multiple results",
)
@click.option("-n", "--no-log", is_flag=True, help="Don't log to database")
@click.option("--log", is_flag=True, help="Log prompt and response to the database")
@click.option("--key", help="API key to use")
def prompt_multi(
prompt_file,
prompt_column,
result_file,
system,
model_id,
database,
queries,
options,
schema_input,
schema_multi,
no_log,
log,
key,
):
"""
Execute multiple prompts from prompt_file.

    prompt_file supports CSV and Parquet.
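
    Example (hypothetical file and model names):

        llm prompt_multi prompts.csv -m gpt-4o-mini -c question -r results.parquet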
"""
    try:
        import bodo.pandas as pd

        # Silence bodo's import-time warnings
        warnings.filterwarnings("ignore", module="bodo.pandas")
    except ImportError:
        try:
            import pandas as pd

            warnings.warn(
                "For better performance with large files, install bodo: pip install bodo"
            )
        except ImportError:
            raise click.ClickException(
                "This command requires a pandas-compatible dataframe library; "
                "install one with: pip install pandas (or pip install bodo)"
            )

if log and no_log:
raise click.ClickException("--log and --no-log are mutually exclusive")

log_path = pathlib.Path(database) if database else logs_db_path()
    log_path.parent.mkdir(parents=True, exist_ok=True)
db = sqlite_utils.Database(log_path)
migrate(db)

if queries and not model_id:
# Use -q options to find model with shortest model_id
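        # e.g. `-q 4o -q mini` selects the shortest model_id matching both
        # strings, such as "gpt-4o-mini" (illustrative ids)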
matches = []
for model_with_aliases in get_models_with_aliases():
if all(model_with_aliases.matches(q) for q in queries):
matches.append(model_with_aliases.model.model_id)
if not matches:
raise click.ClickException(
"No model found matching queries {}".format(", ".join(queries))
)
model_id = min(matches, key=len)

if schema_multi:
schema_input = schema_multi

schema = resolve_schema_input(db, schema_input, load_template)

if schema_multi:
# Convert that schema into multiple "items" of the same schema
schema = multi_schema(schema)
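        # The wrapped schema is roughly:
        #   {"type": "object", "properties": {"items": {"type": "array", "items": schema}}}
        # (a sketch; see multi_schema for the exact shape it produces)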

model_id = model_id or get_default_model()
async_model = False
# Get the model, preferring async if available
try:
model = get_async_model(model_id)
async_model = True
    except UnknownModelError:
        try:
            model = get_model(model_id)
        except UnknownModelError as ex:
            raise click.ClickException(str(ex))

# Validate options
validated_options = {}
if options:
# Validate with pydantic
try:
validated_options = dict(
(key, value)
for key, value in model.Options(**dict(options))
if value is not None
)
except pydantic.ValidationError as ex:
raise click.ClickException(render_errors(ex.errors()))

# Add on any default model options
default_options = get_model_options(model.model_id)
for key_, value in default_options.items():
if key_ not in validated_options:
validated_options[key_] = value

kwargs = {}
# We don't care about streaming for multi prompts
kwargs["stream"] = False
if isinstance(model, (KeyModel, AsyncKeyModel)):
kwargs["key"] = key

# Validate file extensions
if not prompt_file.endswith(('.csv', '.parquet', '.pq')):
raise click.ClickException("Prompt file must be a CSV or Parquet file")
if result_file and not result_file.endswith(('.csv', '.parquet', '.pq')):
raise click.ClickException("Output file must be a CSV or Parquet file")
    prompts_df = (
        pd.read_csv(prompt_file)
        if prompt_file.endswith(".csv")
        else pd.read_parquet(prompt_file)
    )
    if prompt_column not in prompts_df.columns and len(prompts_df.columns) != 1:
        raise click.ClickException(
            f"Prompt column '{prompt_column}' not found in prompt file"
        )
    # A single-column file is used as-is, whatever the column is named
    prompts = (
        prompts_df[prompts_df.columns[0]]
        if len(prompts_df.columns) == 1
        else prompts_df[prompt_column]
    )

    def process_row_prompt(p):
        response = model.prompt(system=system, prompt=p, schema=schema, **kwargs)
        # Could be Response, AsyncResponse, ChainResponse, AsyncChainResponse.
        # Resolve async responses first so that .text() and log_to_db() below
        # work synchronously whether or not logging is enabled.
        if isinstance(response, AsyncResponse):
            response = asyncio.run(response.to_sync_response())
        # Log responses to the database
        if (logs_on() or log) and not no_log:
            db = sqlite_utils.Database(log_path)
            response.log_to_db(db)
        return response.text()

responses = prompts.map(process_row_prompt)
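    # With bodo installed, the map above may fan rows out across parallel
    # worker processes; with plain pandas it applies the function serially.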

results = pd.DataFrame({
"prompt": prompts,
"response": responses
})

if result_file:
if result_file.endswith('.csv'):
results.to_csv(result_file, index=False)
else:
results.to_parquet(result_file, index=False)
else:
print(results.to_string(index=False))


@cli.command()
@click.option("-s", "--system", help="System prompt to use")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -69,6 +69,7 @@ test = [
"types-PyYAML",
"types-setuptools",
"llm-echo==0.3a3",
"bodo",
]

[build-system]
41 changes: 41 additions & 0 deletions tests/test_llm.py
@@ -9,6 +9,8 @@
import pytest
import sqlite_utils
from unittest import mock
import bodo
import bodo.pandas as pd


def test_version():
@@ -851,3 +853,42 @@ def test_llm_prompt_continue_with_database(
assert (user_path / "logs.db").exists()
db_path = str(user_path / "logs.db")
assert sqlite_utils.Database(db_path)["responses"].count == 2

@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
def test_llm_prompt_multi(tmpdir, async_mock_model):
# Disable parallel processing in bodo for this test
# so we can use mock_model
bodo.dataframe_library_run_parallel = False

test_csv_path = str(pathlib.Path(tmpdir, "test.csv"))
result_path = str(pathlib.Path(tmpdir, "results.pq"))
prompts = pd.DataFrame({
"promptx": ["test1", "test2", "test3"],
})
prompts.to_csv(test_csv_path, index=False)
    # bodo evaluates the mapped function once up front to infer output
    # types, so enqueue one spare response before the real ones
    async_mock_model.enqueue(["test1_resp"])

async_mock_model.enqueue(["test1_resp"])
async_mock_model.enqueue(["test2_resp"])
async_mock_model.enqueue(["test3_resp"])

runner = CliRunner()
    result = runner.invoke(
        cli,
        [
            "prompt_multi",
            test_csv_path,
            "-m",
            "mock",
            "-s",
            "You are a helpful assistant.",
            "-c",
            "promptx",
            "-r",
            result_path,
        ],
        catch_exceptions=True,
    )
assert result.exit_code == 0

results = pd.read_parquet(result_path)
assert list(results["prompt"]) == ["test1", "test2", "test3"]
assert list(results["response"]) == ["test1_resp", "test2_resp", "test3_resp"]