Skip to content

Commit 321f67d

Browse files
authored
Merge pull request #9 from theam/run-evaluations-from-ui
Run evaluations from UI and add Auth
2 parents 176532e + 1f40d7a commit 321f67d

File tree

36 files changed

+1126
-129
lines changed

36 files changed

+1126
-129
lines changed

apps/aifindr-evaluations-runner/evaluator.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from opik import Opik
1+
import logging
2+
from opik.api_objects.opik_client import get_client_cached
3+
from opik.config import update_session_config
24
from opik.evaluation import evaluate
35
from opik.evaluation.metrics import (Hallucination, ContextRecall, ContextPrecision)
46
from workflows import run_workflow
57
from metrics.follows_criteria import FollowsCriteria
68
from pydantic import BaseModel
79
from enum import Enum
810

9-
client = Opik()
11+
logger = logging.getLogger(__name__)
1012

1113
class ExperimentStatus(Enum):
1214
RUNNING = "running"
@@ -16,19 +18,30 @@ class ExperimentStatus(Enum):
1618

1719
class EvaluationParams(BaseModel):
    # Parameters for one queued evaluation run, built in main.py's
    # /evaluations/run handler and consumed by execute_evaluation().
    task_id: str  # unique run id (uuid4, generated server-side per request)
    workspace_name: str  # Opik workspace the evaluation client is bound to
    dataset_name: str  # name of the Opik dataset to evaluate against
    experiment_name: str
    # NOTE(review): RunEvaluationsRequest declares project_name as `str | None`,
    # but it is required (non-optional) here — a request without project_name
    # would presumably fail validation when this model is constructed. Confirm.
    project_name: str
    base_prompt_name: str  # must resolve via client.get_prompt(), else ValueError
    workflow: str  # workflow identifier passed through to run_workflow()
    api_key: str  # raw Authorization header value forwarded from the HTTP request

2529
def evaluation_task(dataset_item, workflow: str):
30+
# TODO: validate properly dataset_item so that no field is empty
31+
if not dataset_item['query']:
32+
logger.error("Trying to run workflow with an empty query")
33+
return {
34+
"input": "invalid-query",
35+
"output": "invalid-query",
36+
"context": [],
37+
}
38+
2639
response_content = run_workflow(workflow, dataset_item['query'])
2740

2841
# parsed_response = json.loads(response_content.response)
29-
# print(parsed_response)
30-
# print(parsed_response.keys())
31-
# print(parsed_response['text_response'])
42+
# logger.info("------> Response: ", parsed_response)
43+
# logger.info("------> Response keys: ", parsed_response.keys())
44+
# logger.info("------> Response text_response: ", parsed_response['text_response'])
3245

3346
result = {
3447
"input": dataset_item['query'],
@@ -42,8 +55,10 @@ def build_evaluation_task(params: EvaluationParams):
4255

4356

4457
def execute_evaluation(params: EvaluationParams):
58+
client = build_opik_client(params.workspace_name, params.api_key)
4559
dataset = client.get_dataset(name=params.dataset_name)
4660
base_prompt = client.get_prompt(name=params.base_prompt_name)
61+
4762
if not base_prompt:
4863
raise ValueError(f"No base prompt found with name '{params.base_prompt_name}'")
4964

@@ -63,3 +78,16 @@ def execute_evaluation(params: EvaluationParams):
6378
prompt=base_prompt,
6479
task_threads=20
6580
)
81+
82+
def build_opik_client(workspace_name: str, api_key: str):
    """Return an Opik client bound to the given workspace and API key.

    Normally you would create the client with
    ``Opik(workspace=workspace_name, api_key=api_key)``. However, the internal
    ``evaluate`` method gets its client from ``get_client_cached()``, which
    creates a client with the default workspace and without the API key —
    and that cached client is reused on every request.

    So instead we store the parameters via ``update_session_config`` (the Opik
    constructor picks them up from there), clear the cache to drop any previous
    client built with a different workspace or API key, and return the client
    created with the new parameters. ``evaluate`` will then use this client.
    """
    update_session_config('workspace', workspace_name)
    update_session_config('api_key', api_key)
    # Drop the stale cached client so the next call rebuilds it from the
    # session config we just wrote.
    get_client_cached.cache_clear()
    return get_client_cached()

apps/aifindr-evaluations-runner/main.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from fastapi import FastAPI, HTTPException, Request
2+
from pydantic import BaseModel
13
import logging
24
import asyncio
35
import uuid
@@ -6,9 +8,6 @@
68
from urllib.parse import urljoin, urlparse
79
from typing import Dict
810

9-
from fastapi import FastAPI, HTTPException
10-
from pydantic import BaseModel
11-
1211
from settings import settings
1312
from evaluator import EvaluationParams, ExperimentStatus, execute_evaluation
1413

@@ -25,9 +24,10 @@
2524

2625

2726
class RunEvaluationsRequest(BaseModel):
    # JSON body accepted by POST /evaluations/run.
    workspace_name: str  # Opik workspace to run the evaluation in
    dataset_name: str
    experiment_name: str
    # Optional in the API, but NOTE(review): EvaluationParams.project_name is a
    # required `str`, so a request omitting this field would presumably fail
    # when the handler builds EvaluationParams — confirm intended behavior.
    project_name: str | None = None
    base_prompt_name: str
    workflow: str
3333

@@ -113,23 +113,25 @@ async def health(timeout: int = 5):
113113

114114

115115
@app.post("/evaluations/run", response_model=RunEvaluationsResponse)
116-
async def run_evaluation(request: RunEvaluationsRequest):
116+
async def run_evaluation(input: RunEvaluationsRequest, req: Request):
117117
try:
118118
# Generate task ID
119119
task_id = str(uuid.uuid4())
120120
# Create EvaluationParams with all fields from request plus task_id
121121
evaluation_params = EvaluationParams(
122122
task_id=task_id,
123-
dataset_name=request.dataset_name,
124-
experiment_name=request.experiment_name,
125-
project_name=request.project_name,
126-
base_prompt_name=request.base_prompt_name,
127-
workflow=request.workflow,
123+
workspace_name=input.workspace_name,
124+
dataset_name=input.dataset_name,
125+
experiment_name=input.experiment_name,
126+
project_name=input.project_name,
127+
base_prompt_name=input.base_prompt_name,
128+
workflow=input.workflow,
129+
api_key=req.headers.get("Authorization")
128130
)
129131

130132
try:
131133
TASK_QUEUE.put_nowait(evaluation_params)
132-
logger.info(f"Evaluation task added to queue: {evaluation_params}")
134+
logger.info("Evaluation task added to queue")
133135
except asyncio.QueueFull:
134136
logger.error(
135137
f"Queue is full. Evaluation task not added to the queue: {evaluation_params}"

apps/aifindr-evaluations-runner/metrics/follows_criteria.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ def __init__(self, prompt_template: str, name: str = "Follows criteria", model_n
3535
self.name = name
3636
self.llm_client = models.LiteLLMChatModel(model_name=model_name)
3737
self.prompt_template = f"""
38+
# Instructions
3839
{prompt_template}
39-
-----
40+
41+
# Answer format
4042
Answer with a json with the following format:
4143
4244
{{{{
@@ -58,8 +60,6 @@ def score(self, output: str, criteria: str, **ignored_kwargs: Any):
5860
output=output,
5961
criteria=criteria
6062
)
61-
62-
print("Prompt total: ", prompt)
6363
# Generate and parse the response from the LLM
6464
response = self.llm_client.generate_string(input=prompt, response_format=FollowsCriteriaResult)
6565

apps/aifindr-evaluations-runner/workflows.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pydantic import BaseModel
88
from typing import Optional, List, Any
99

10-
MAX_RETRIES = 3
10+
MAX_RETRIES = 5
1111
RETRIEVAL_EVENT_ID_PREFIX = "similarity_search_by_text"
1212
LLM_EVENT_ID_PREFIX = "llm"
1313

@@ -35,7 +35,7 @@ def run_workflow(workflow: str, query: str) -> WorkflowResponse:
3535
return _make_workflow_request(workflow, query)
3636
except Exception as e:
3737
wait_time = 0.5 * (retry_count + 1) # Increasing delay between retries
38-
print(f"Request failed with error: {e}. Waiting {wait_time}s before retrying... ({retry_count + 1}/{MAX_RETRIES})")
38+
logger.warn(f"Request failed with error: {e}. Waiting {wait_time}s before retrying... ({retry_count + 1}/{MAX_RETRIES})")
3939
retry_count += 1
4040
time.sleep(wait_time)
4141

@@ -85,6 +85,7 @@ def _make_workflow_request(workflow: str, query: str) -> WorkflowResponse:
8585
try:
8686
data = json.loads(event.data)
8787
if event.id.startswith(RETRIEVAL_EVENT_ID_PREFIX):
88+
logger.debug(f"Retrieval response: {data}")
8889
retrieval_response = data['response']['hits']
8990
elif event.id.startswith(LLM_EVENT_ID_PREFIX) and 'content' in data['delta']['message']:
9091
llm_response += data['delta']['message']['content']

apps/opik-backend/config.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,14 @@ authentication:
129129
# Default:
130130
# Description: Configures how to authenticate requests which originates from the sdk
131131
sdk:
132-
url: ${AUTH_SDK_URL:-''}
132+
url: ${AUTH_DOMAIN:-''}/userinfo
133+
# AIFindr: Whether the URL is from auth0 (ending in /userinfo)
134+
isAuth0: ${AUTH_IS_AUTH0:-false}
133135
# Default:
134136
# Description: Configures how to authenticate requests which originates from the ui
135137
ui:
136-
url: ${AUTH_UI_URL:-''}
138+
url: ${AUTH_DOMAIN:-''}/userinfo
139+
isAuth0: false
137140

138141
# https://www.dropwizard.io/en/stable/manual/configuration.html#servers
139142
server:
@@ -285,3 +288,6 @@ clickHouseLogAppender:
285288
# Default: PT0.500S or 500ms
286289
# Description: Time interval after which the log messages are sent to ClickHouse if the batch size is not reached
287290
flushIntervalDuration: ${CLICKHOUSE_LOG_APPENDER_FLUSH_INTERVAL_DURATION:-PT0.500S}
291+
292+
experimentRunner:
293+
url: ${EXPERIMENT_RUNNER_URL:-''}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.comet.opik.api;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
5+
import com.fasterxml.jackson.databind.annotation.JsonNaming;
6+
7+
import jakarta.validation.constraints.NotBlank;
8+
import lombok.Builder;
9+
10+
@Builder(toBuilder = true)
11+
@JsonIgnoreProperties(ignoreUnknown = true)
12+
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
13+
public record ExperimentRunRequest(
14+
@NotBlank String workspaceName,
15+
@NotBlank String datasetName,
16+
@NotBlank String experimentName,
17+
String projectName,
18+
@NotBlank String basePromptName,
19+
@NotBlank String workflow
20+
) {}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package com.comet.opik.api;
2+
3+
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
4+
import com.fasterxml.jackson.databind.annotation.JsonNaming;
5+
6+
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
7+
public record ExperimentRunResponse(
8+
String status,
9+
String taskId
10+
) {}

apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
package com.comet.opik.api.resources.v1.priv;
22

3+
import java.util.Collections;
4+
import java.util.Optional;
5+
import java.util.Set;
6+
import java.util.UUID;
7+
import java.util.stream.Collectors;
8+
9+
import org.glassfish.jersey.server.ChunkedOutput;
10+
311
import com.codahale.metrics.annotation.Timed;
412
import com.comet.opik.api.Experiment;
513
import com.comet.opik.api.ExperimentItem;
614
import com.comet.opik.api.ExperimentItemSearchCriteria;
715
import com.comet.opik.api.ExperimentItemStreamRequest;
816
import com.comet.opik.api.ExperimentItemsBatch;
917
import com.comet.opik.api.ExperimentItemsDelete;
18+
import com.comet.opik.api.ExperimentRunRequest;
19+
import com.comet.opik.api.ExperimentRunResponse;
1020
import com.comet.opik.api.ExperimentSearchCriteria;
1121
import com.comet.opik.api.ExperimentsDelete;
1222
import com.comet.opik.api.FeedbackDefinition;
@@ -15,14 +25,17 @@
1525
import com.comet.opik.api.resources.v1.priv.validate.IdParamsValidator;
1626
import com.comet.opik.domain.ExperimentItemService;
1727
import com.comet.opik.domain.ExperimentService;
28+
import com.comet.opik.domain.FeedbackScoreDAO.EntityType;
1829
import com.comet.opik.domain.FeedbackScoreService;
1930
import com.comet.opik.domain.IdGenerator;
2031
import com.comet.opik.domain.Streamer;
2132
import com.comet.opik.infrastructure.auth.RequestContext;
2233
import com.comet.opik.infrastructure.ratelimit.RateLimited;
2334
import com.comet.opik.utils.AsyncUtils;
35+
import static com.comet.opik.utils.AsyncUtils.setRequestContext;
2436
import com.fasterxml.jackson.annotation.JsonView;
2537
import com.fasterxml.jackson.databind.JsonNode;
38+
2639
import io.dropwizard.jersey.errors.ErrorMessage;
2740
import io.swagger.v3.oas.annotations.Operation;
2841
import io.swagger.v3.oas.annotations.headers.Header;
@@ -40,6 +53,7 @@
4053
import jakarta.ws.rs.Consumes;
4154
import jakarta.ws.rs.DefaultValue;
4255
import jakarta.ws.rs.GET;
56+
import jakarta.ws.rs.HeaderParam;
4357
import jakarta.ws.rs.POST;
4458
import jakarta.ws.rs.Path;
4559
import jakarta.ws.rs.PathParam;
@@ -52,16 +66,6 @@
5266
import lombok.NonNull;
5367
import lombok.RequiredArgsConstructor;
5468
import lombok.extern.slf4j.Slf4j;
55-
import org.glassfish.jersey.server.ChunkedOutput;
56-
57-
import java.util.Collections;
58-
import java.util.Optional;
59-
import java.util.Set;
60-
import java.util.UUID;
61-
import java.util.stream.Collectors;
62-
63-
import static com.comet.opik.domain.FeedbackScoreDAO.EntityType;
64-
import static com.comet.opik.utils.AsyncUtils.setRequestContext;
6569

6670
@Path("/v1/private/experiments")
6771
@Produces(MediaType.APPLICATION_JSON)
@@ -304,4 +308,29 @@ public Response findFeedbackScoreNames(@QueryParam("experiment_ids") String expe
304308

305309
return Response.ok(feedbackScoreNames).build();
306310
}
311+
312+
@POST
@Path("/run")
@Operation(operationId = "runExperiment",
        summary = "Run experiment",
        description = "Run experiment with specified dataset, project and workflow configuration",
        responses = {
                @ApiResponse(responseCode = "200", description = "Experiment run started",
                        content = @Content(schema = @Schema(implementation = ExperimentRunResponse.class)))})
@RateLimited
public Response runExperiment(
        @RequestBody(content = @Content(schema = @Schema(implementation = ExperimentRunRequest.class)))
        @NotNull @Valid ExperimentRunRequest request,
        @HeaderParam("Authorization") String authorization) {

    // Delegates to ExperimentService, forwarding the caller's raw Authorization
    // header so the downstream runner can authenticate on the caller's behalf.
    log.info("Running experiment {}", request);

    // Blocks the request thread until the reactive pipeline completes.
    // NOTE(review): if the Mono completes empty, block() returns null and the
    // response.status() call below would NPE — confirm runExperiment always emits.
    ExperimentRunResponse response = experimentService.runExperiment(request, authorization)
            .contextWrite(ctx -> setRequestContext(ctx, requestContext))
            .block();

    log.info("Experiment started. Status: '{}', Task ID: '{}'",
            response.status(), response.taskId());
    return Response.ok(response).build();
}
307336
}

0 commit comments

Comments
 (0)