diff --git a/lib/instructor.ex b/lib/instructor.ex index 48d6d3d..6dd3151 100644 --- a/lib/instructor.ex +++ b/lib/instructor.ex @@ -260,6 +260,61 @@ defmodule Instructor do changeset end + @spec prepare_prompt(Keyword.t(), any()) :: map() + def prepare_prompt(params, config \\ nil) do + response_model = Keyword.fetch!(params, :response_model) + mode = Keyword.get(params, :mode, :tools) + params = params_for_mode(mode, response_model, params) + + adapter(config).prompt(params) + end + + @spec consume_response(any(), Keyword.t()) :: + {:ok, map()} | {:error, String.t()} | {:error, Ecto.Changeset.t(), Keyword.t()} + def consume_response(response, params) do + validation_context = Keyword.get(params, :validation_context, %{}) + response_model = Keyword.fetch!(params, :response_model) + mode = Keyword.get(params, :mode, :tools) + + model = + if is_ecto_schema(response_model) do + response_model.__struct__() + else + {%{}, response_model} + end + + with {:valid_json, {:ok, params}} <- {:valid_json, parse_response_for_mode(mode, response)}, + changeset <- cast_all(model, params), + {:validation, %Ecto.Changeset{valid?: true} = changeset, _response} <- + {:validation, call_validate(response_model, changeset, validation_context), response} do + {:ok, changeset |> Ecto.Changeset.apply_changes()} + else + {:valid_json, {:error, error}} -> + {:error, "Invalid JSON returned from LLM: #{inspect(error)}"} + + {:validation, changeset, response} -> + errors = Instructor.ErrorFormatter.format_errors(changeset) + + params = + Keyword.update(params, :messages, [], fn messages -> + messages ++ + echo_response(response) ++ + [ + %{ + role: "system", + content: """ + The response did not pass validation. 
Please try again and fix the following validation errors:\n + + #{errors} + """ + } + ] + end) + + {:error, changeset, params} + end + end + defp do_streaming_partial_array_chat_completion(response_model, params, config) do wrapped_model = %{ value: @@ -270,7 +325,7 @@ defmodule Instructor do params = Keyword.put(params, :response_model, wrapped_model) validation_context = Keyword.get(params, :validation_context, %{}) mode = Keyword.get(params, :mode, :tools) - params = params_for_mode(mode, wrapped_model, params) + prompt = prepare_prompt(params, config) model = if is_ecto_schema(response_model) do @@ -279,7 +334,7 @@ defmodule Instructor do {%{}, response_model} end - adapter(config).chat_completion(params, config) + adapter(config).chat_completion(prompt, params, config) |> Stream.map(&parse_stream_chunk_for_mode(mode, &1)) |> Instructor.JSONStreamParser.parse() |> Stream.transform( @@ -341,9 +396,9 @@ defmodule Instructor do params = Keyword.put(params, :response_model, wrapped_model) validation_context = Keyword.get(params, :validation_context, %{}) mode = Keyword.get(params, :mode, :tools) - params = params_for_mode(mode, wrapped_model, params) + prompt = prepare_prompt(params, config) - adapter(config).chat_completion(params, config) + adapter(config).chat_completion(prompt, params, config) |> Stream.map(&parse_stream_chunk_for_mode(mode, &1)) |> Instructor.JSONStreamParser.parse() |> Stream.transform( @@ -389,9 +444,10 @@ defmodule Instructor do params = Keyword.put(params, :response_model, wrapped_model) validation_context = Keyword.get(params, :validation_context, %{}) mode = Keyword.get(params, :mode, :tools) - params = params_for_mode(mode, wrapped_model, params) - adapter(config).chat_completion(params, config) + prompt = prepare_prompt(params, config) + + adapter(config).chat_completion(prompt, params, config) |> Stream.map(&parse_stream_chunk_for_mode(mode, &1)) |> Jaxon.Stream.from_enumerable() |> Jaxon.Stream.query([:root, "value", :all]) @@ 
-416,32 +472,18 @@ defmodule Instructor do end defp do_chat_completion(response_model, params, config) do - validation_context = Keyword.get(params, :validation_context, %{}) max_retries = Keyword.get(params, :max_retries) - mode = Keyword.get(params, :mode, :tools) - params = params_for_mode(mode, response_model, params) + prompt = prepare_prompt(params, config) - model = - if is_ecto_schema(response_model) do - response_model.__struct__() - else - {%{}, response_model} - end - - with {:llm, {:ok, response}} <- {:llm, adapter(config).chat_completion(params, config)}, - {:valid_json, {:ok, params}} <- {:valid_json, parse_response_for_mode(mode, response)}, - changeset <- cast_all(model, params), - {:validation, %Ecto.Changeset{valid?: true} = changeset, _response} <- - {:validation, call_validate(response_model, changeset, validation_context), response} do - {:ok, changeset |> Ecto.Changeset.apply_changes()} + with {:llm, {:ok, response}} <- + {:llm, adapter(config).chat_completion(prompt, params, config)}, + {:ok, result} <- consume_response(response, params) do + {:ok, result} else {:llm, {:error, error}} -> {:error, "LLM Adapter Error: #{inspect(error)}"} - {:valid_json, {:error, error}} -> - {:error, "Invalid JSON returned from LLM: #{inspect(error)}"} - - {:validation, changeset, response} -> + {:error, changeset, new_params} -> if max_retries > 0 do errors = Instructor.ErrorFormatter.format_errors(changeset) @@ -449,23 +491,7 @@ defmodule Instructor do errors: errors ) - params = - params - |> Keyword.put(:max_retries, max_retries - 1) - |> Keyword.update(:messages, [], fn messages -> - messages ++ - echo_response(response) ++ - [ - %{ - role: "system", - content: """ - The response did not pass validation. 
Please try again and fix the following validation errors:\n - - #{errors} - """ - } - ] - end) + params = Keyword.put(new_params, :max_retries, max_retries - 1) do_chat_completion(response_model, params, config) else diff --git a/lib/instructor/adapter.ex b/lib/instructor/adapter.ex index 5a00249..b4ec338 100644 --- a/lib/instructor/adapter.ex +++ b/lib/instructor/adapter.ex @@ -2,5 +2,6 @@ defmodule Instructor.Adapter do @moduledoc """ Behavior for `Instructor.Adapter`. """ - @callback chat_completion([Keyword.t()], any()) :: any() + @callback chat_completion(map(), [Keyword.t()], any()) :: any() + @callback prompt(Keyword.t()) :: map() end diff --git a/lib/instructor/adapters/llamacpp.ex b/lib/instructor/adapters/llamacpp.ex index 8c30214..c7e9434 100644 --- a/lib/instructor/adapters/llamacpp.ex +++ b/lib/instructor/adapters/llamacpp.ex @@ -31,34 +31,39 @@ defmodule Instructor.Adapters.Llamacpp do ...> ) """ @impl true - def chat_completion(params, _config \\ nil) do + def chat_completion(prompt, params, _config \\ nil) do + stream = Keyword.get(params, :stream, false) + + if stream do + do_streaming_chat_completion(prompt) + else + do_chat_completion(prompt) + end + end + + @impl true + def prompt(params) do {response_model, _} = Keyword.pop!(params, :response_model) {messages, _} = Keyword.pop!(params, :messages) json_schema = JSONSchema.from_ecto_schema(response_model) grammar = GBNF.from_json_schema(json_schema) prompt = apply_chat_template(chat_template(), messages) - stream = Keyword.get(params, :stream, false) - if stream do - do_streaming_chat_completion(prompt, grammar) - else - do_chat_completion(prompt, grammar) - end + %{ + grammar: grammar, + prompt: prompt + } end - defp do_streaming_chat_completion(prompt, grammar) do + defp do_streaming_chat_completion(prompt) do pid = self() Stream.resource( fn -> Task.async(fn -> Req.post(url(), - json: %{ - grammar: grammar, - prompt: prompt, - stream: true - }, + json: Map.put(prompt, :stream, true), 
receive_timeout: 60_000, into: fn {:data, data}, {req, resp} -> send(pid, data) @@ -94,13 +99,10 @@ defmodule Instructor.Adapters.Llamacpp do } end - defp do_chat_completion(prompt, grammar) do + defp do_chat_completion(prompt) do response = Req.post(url(), - json: %{ - grammar: grammar, - prompt: prompt - }, + json: prompt, receive_timeout: 60_000 ) diff --git a/lib/instructor/adapters/openai.ex b/lib/instructor/adapters/openai.ex index c752e6d..3c77a75 100644 --- a/lib/instructor/adapters/openai.ex +++ b/lib/instructor/adapters/openai.ex @@ -5,25 +5,30 @@ defmodule Instructor.Adapters.OpenAI do @behaviour Instructor.Adapter @impl true - def chat_completion(params, config) do + def chat_completion(prompt, params, config) do config = if config, do: config, else: config() + stream = Keyword.get(params, :stream, false) + + if stream do + do_streaming_chat_completion(prompt, config) + else + do_chat_completion(prompt, config) + end + end + + @impl true + def prompt(params) do # Peel off instructor only parameters {_, params} = Keyword.pop(params, :response_model) {_, params} = Keyword.pop(params, :validation_context) {_, params} = Keyword.pop(params, :max_retries) {_, params} = Keyword.pop(params, :mode) - stream = Keyword.get(params, :stream, false) - params = Enum.into(params, %{}) - if stream do - do_streaming_chat_completion(params, config) - else - do_chat_completion(params, config) - end + Enum.into(params, %{}) end - defp do_streaming_chat_completion(params, config) do + defp do_streaming_chat_completion(prompt, config) do pid = self() options = http_options(config) @@ -32,7 +37,7 @@ defmodule Instructor.Adapters.OpenAI do Task.async(fn -> options = Keyword.merge(options, - json: params, + json: prompt, auth: {:bearer, api_key(config)}, into: fn {:data, data}, {req, resp} -> chunks = @@ -75,8 +80,8 @@ defmodule Instructor.Adapters.OpenAI do ) end - defp do_chat_completion(params, config) do - options = Keyword.merge(http_options(config), json: params, auth: 
{:bearer, api_key(config)}) + defp do_chat_completion(prompt, config) do + options = Keyword.merge(http_options(config), json: prompt, auth: {:bearer, api_key(config)}) case Req.post(url(config), options) do {:ok, %{status: 200, body: body}} -> {:ok, body} diff --git a/test/instructor_test.exs b/test/instructor_test.exs index 8b44060..cec4b53 100644 --- a/test/instructor_test.exs +++ b/test/instructor_test.exs @@ -21,6 +21,9 @@ defmodule InstructorTest do :openai_mock -> Application.put_env(:instructor, :adapter, InstructorTest.MockOpenAI) + + _ -> + :ok end end @@ -419,4 +422,136 @@ defmodule InstructorTest do result |> Enum.to_list() end end + + describe "prepare_prompt" do + @tag adapter: :openai_mock + test "calls adapter's prompt/1 callback" do + expect(InstructorTest.MockOpenAI, :prompt, fn params -> + assert params[:tool_choice] == %{function: %{name: "Schema"}, type: "function"} + + assert params[:tools] == [ + %{ + function: %{ + "description" => + "Correctly extracted `Schema` with all the required parameters with correct types", + "name" => "Schema", + "parameters" => %{ + "properties" => %{ + "birth_date" => %{"format" => "date", "type" => "string"}, + "name" => %{"type" => "string"} + }, + "required" => ["birth_date", "name"], + "title" => "root", + "type" => "object" + } + }, + type: "function" + } + ] + + assert params[:model] == "gpt-3.5-turbo" + assert params[:response_model] == %{name: :string, birth_date: :date} + + assert params[:messages] == [ + %{role: "user", content: "Who was the first president of the USA?"} + ] + + %{foo: "bar"} + end) + + assert %{foo: "bar"} == + Instructor.prepare_prompt( + model: "gpt-3.5-turbo", + response_model: %{name: :string, birth_date: :date}, + messages: [ + %{role: "user", content: "Who was the first president of the USA?"} + ] + ) + end + end + + describe "consume_response" do + @tag adapter: :openai_mock + test "returns data if response is valid" do + response = + TestHelpers.example_openai_response(:tools, 
%{ + name: "George Washington", + birth_date: ~D[1732-02-22] + }) + + assert {:ok, %{name: "George Washington", birth_date: ~D[1732-02-22]}} = + Instructor.consume_response(response, + response_model: %{name: :string, birth_date: :date} + ) + end + + test "errors on invalid json" do + assert {:error, + "Invalid JSON returned from LLM: %Jason.DecodeError{position: 0, token: nil, data: \"I'm sorry Dave, I'm afraid I can't do this\"}"} = + Instructor.consume_response( + %{ + "choices" => [ + %{ + "message" => %{ + "tool_calls" => [ + %{ + "function" => %{ + "arguments" => "I'm sorry Dave, I'm afraid I can't do this" + } + } + ] + } + } + ] + }, + response_model: %{name: :string, birth_date: :date} + ) + end + + test "returns new params on failed cast" do + response = + TestHelpers.example_openai_response(:tools, %{ + name: 123, + birth_date: false + }) + + assert {:error, + %Ecto.Changeset{errors: [name: {"is invalid", _}, birth_date: {"is invalid", _}]}, + [ + response_model: %{name: :string, birth_date: :date}, + messages: [ + %{ + role: "assistant", + content: + "{\"function\":{\"arguments\":\"{\\\"name\\\":123,\\\"birth_date\\\":false}\",\"name\":\"schema\"},\"id\":\"call_DT9fBvVCHWGSf9IeFZnlarIY\",\"type\":\"function\"}", + tool_calls: [ + %{ + "function" => %{ + "arguments" => "{\"name\":123,\"birth_date\":false}", + "name" => "schema" + }, + "id" => "call_DT9fBvVCHWGSf9IeFZnlarIY", + "type" => "function" + } + ] + }, + %{ + name: "schema", + role: "tool", + content: "{\"name\":123,\"birth_date\":false}", + tool_call_id: "call_DT9fBvVCHWGSf9IeFZnlarIY" + }, + %{ + role: "system", + content: + "The response did not pass validation. 
Please try again and fix the following validation errors:\n\n\nname - is invalid\nbirth_date - is invalid\n" + } + ] + ]} = + Instructor.consume_response(response, + response_model: %{name: :string, birth_date: :date}, + messages: [] + ) + end + end end diff --git a/test/support/test_helpers.ex b/test/support/test_helpers.ex index 44848c8..29927f9 100644 --- a/test/support/test_helpers.ex +++ b/test/support/test_helpers.ex @@ -1,75 +1,73 @@ defmodule Instructor.TestHelpers do import Mox - def mock_openai_response(:tools, result) do - InstructorTest.MockOpenAI - |> expect(:chat_completion, fn _params, _config -> - {:ok, - %{ - "id" => "chatcmpl-8e9AVo9NHfvBG5cdtAEiJMm7q4Htz", - "usage" => %{ - "completion_tokens" => 23, - "prompt_tokens" => 136, - "total_tokens" => 159 - }, - "choices" => [ - %{ - "finish_reason" => "stop", - "index" => 0, - "logprobs" => nil, - "message" => %{ - "content" => nil, - "role" => "assistant", - "tool_calls" => [ - %{ - "function" => %{ - "arguments" => Jason.encode!(result), - "name" => "schema" - }, - "id" => "call_DT9fBvVCHWGSf9IeFZnlarIY", - "type" => "function" - } - ] - } - } - ], - "model" => "gpt-3.5-turbo-0613", - "object" => "chat.completion", - "created" => 1_704_579_055, - "system_fingerprint" => nil - }} - end) + def example_openai_response(:tools, result) do + %{ + "id" => "chatcmpl-8e9AVo9NHfvBG5cdtAEiJMm7q4Htz", + "usage" => %{ + "completion_tokens" => 23, + "prompt_tokens" => 136, + "total_tokens" => 159 + }, + "choices" => [ + %{ + "finish_reason" => "stop", + "index" => 0, + "logprobs" => nil, + "message" => %{ + "content" => nil, + "role" => "assistant", + "tool_calls" => [ + %{ + "function" => %{ + "arguments" => Jason.encode!(result), + "name" => "schema" + }, + "id" => "call_DT9fBvVCHWGSf9IeFZnlarIY", + "type" => "function" + } + ] + } + } + ], + "model" => "gpt-3.5-turbo-0613", + "object" => "chat.completion", + "created" => 1_704_579_055, + "system_fingerprint" => nil + } end - def mock_openai_response(mode, 
result) when mode in [:json, :md_json] do - InstructorTest.MockOpenAI - |> expect(:chat_completion, fn _params, _config -> - { - :ok, + def example_openai_response(mode, result) when mode in [:json, :md_json] do + %{ + "id" => "chatcmpl-8e9AVo9NHfvBG5cdtAEiJMm7q4Htz", + "usage" => %{ + "completion_tokens" => 23, + "prompt_tokens" => 136, + "total_tokens" => 159 + }, + "choices" => [ %{ - "id" => "chatcmpl-8e9AVo9NHfvBG5cdtAEiJMm7q4Htz", - "usage" => %{ - "completion_tokens" => 23, - "prompt_tokens" => 136, - "total_tokens" => 159 - }, - "choices" => [ - %{ - "finish_reason" => "stop", - "index" => 0, - "logprobs" => nil, - "message" => %{ - "content" => Jason.encode!(result), - "role" => "assistant" - } - } - ], - "model" => "gpt-3.5-turbo-0613", - "object" => "chat.completion", - "created" => 1_704_579_055, - "system_fingerprint" => nil + "finish_reason" => "stop", + "index" => 0, + "logprobs" => nil, + "message" => %{ + "content" => Jason.encode!(result), + "role" => "assistant" + } } - } + ], + "model" => "gpt-3.5-turbo-0613", + "object" => "chat.completion", + "created" => 1_704_579_055, + "system_fingerprint" => nil + } + end + + def mock_openai_response(mode, result) do + InstructorTest.MockOpenAI + |> expect(:prompt, &Instructor.Adapters.OpenAI.prompt/1) + |> expect(:chat_completion, fn _prompt, _params, _config -> + {:ok, example_openai_response(mode, result)} end) end @@ -114,7 +112,8 @@ defmodule Instructor.TestHelpers do ] InstructorTest.MockOpenAI - |> expect(:chat_completion, fn _params, _config -> + |> expect(:prompt, &Instructor.Adapters.OpenAI.prompt/1) + |> expect(:chat_completion, fn _prompt, _params, _config -> chunks end) end @@ -160,7 +159,8 @@ defmodule Instructor.TestHelpers do ] InstructorTest.MockOpenAI - |> expect(:chat_completion, fn _params, _config -> + |> expect(:prompt, &Instructor.Adapters.OpenAI.prompt/1) + |> expect(:chat_completion, fn _prompt, _params, _config -> chunks end) end