Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/adaptive/src/acg_component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ pub(crate) fn create_acg_llm_request_intercept(
provider: String,
plugin: Arc<dyn ProviderPlugin>,
) -> LlmRequestInterceptFn {
Box::new(move |_name: &str, request: LlmRequest, annotated| {
Arc::new(move |_name: &str, request: LlmRequest, annotated| {
let translated =
translate_request(&request, &agent_id, &provider, plugin.as_ref(), &hot_cache)
.unwrap_or(request);
Expand Down
2 changes: 1 addition & 1 deletion crates/adaptive/src/adaptive_hints_intercept.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ impl AdaptiveHintsIntercept {
/// transformed request.
pub fn into_request_fn(self) -> LlmRequestInterceptFn {
let this = Arc::new(self);
Box::new(
Arc::new(
move |_name: &str, mut request: LlmRequest, annotated: Option<AnnotatedLlmRequest>| {
let scope_path = extract_scope_path();
let manual_ls = read_manual_latency_sensitivity();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ impl Plugin for HeaderPlugin {
"header_plugin",
priority,
false,
Box::new(|_name, mut request, annotated| {
Arc::new(|_name, mut request, annotated| {
request.headers.insert("x-plugin".into(), json!("set"));
Ok((request, annotated))
}),
Expand All @@ -727,7 +727,7 @@ impl Plugin for HeaderPlugin {
"tool_request_plugin",
priority,
false,
Box::new(|_name, mut args| {
Arc::new(|_name, mut args| {
if let Json::Object(ref mut map) = args {
map.insert("x-tool-plugin".into(), json!(true));
}
Expand Down
210 changes: 120 additions & 90 deletions crates/adaptive/tests/unit/runtime_features_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@ use std::sync::Arc;

use nemo_flow::api::llm::LlmRequest;
use nemo_flow::api::llm::llm_request_intercepts;
use nemo_flow::api::registry::{
deregister_llm_execution_intercept, deregister_llm_request_intercept,
deregister_llm_stream_execution_intercept, deregister_tool_execution_intercept,
register_llm_execution_intercept, register_llm_request_intercept,
register_llm_stream_execution_intercept, register_tool_execution_intercept,
};
use nemo_flow::api::runtime::NemoFlowContextState;
use nemo_flow::api::runtime::ToolExecutionNextFn;
use nemo_flow::api::runtime::global_context;
use nemo_flow::api::subscriber::{deregister_subscriber, register_subscriber};
use nemo_flow::api::tool::tool_call_execute;
use nemo_flow::error::FlowError;
use nemo_flow::plugin::{ConfigPolicy, UnsupportedBehavior};
use nemo_flow::plugin::{clear_plugin_configuration, rollback_registrations};
use serde_json::json;
Expand Down Expand Up @@ -55,6 +63,100 @@ fn sample_plan(agent_id: &str) -> ExecutionPlan {
}
}

/// Asserts that a probe registration was rejected because `name` is already taken.
///
/// The check passes only when `result` is `Err(FlowError::AlreadyExists(message))`
/// and `message` contains `name`.
///
/// # Panics
///
/// Panics when `result` is any other value (including `Ok(())`), or when the
/// `AlreadyExists` message does not contain `name`. Because of `#[track_caller]`
/// the reported panic location is the call site of this helper rather than a
/// line inside it.
#[track_caller]
fn assert_already_registered(result: nemo_flow::error::Result<()>, name: &str) {
    match result {
        Err(FlowError::AlreadyExists(message)) => assert!(
            message.contains(name),
            "AlreadyExists error {message:?} does not mention {name:?}"
        ),
        other => panic!("expected {name} to be registered, got {other:?}"),
    }
}

/// Checks that a subscriber called `name` is currently registered.
///
/// Tries to register a no-op subscriber under the same name and hands the
/// outcome to `assert_already_registered`, which expects the attempt to have
/// been rejected with `FlowError::AlreadyExists`.
fn assert_subscriber_registered(name: &str) {
    let probe = register_subscriber(name, Arc::new(|_event| {}));
    assert_already_registered(probe, name);
}

/// Asserts that no subscriber called `name` is currently registered.
///
/// Registers a no-op subscriber under `name`, which must succeed, and then
/// deregisters that probe again.
///
/// # Panics
///
/// Panics if either the probe registration or the follow-up deregistration
/// returns an error. `#[track_caller]` makes the panic point at the call site
/// of this helper.
#[track_caller]
fn assert_subscriber_absent(name: &str) {
    register_subscriber(name, Arc::new(|_event| {}))
        .expect("subscriber should be absent, but the probe registration was rejected");
    // Remove the probe that was just added.
    deregister_subscriber(name).expect("failed to deregister the probe subscriber");
}

/// Checks that an LLM request intercept called `name` is currently registered.
///
/// Tries to register a pass-through intercept (it returns the request and
/// annotation unchanged) under the same name and hands the outcome to
/// `assert_already_registered`, which expects `FlowError::AlreadyExists`.
fn assert_llm_request_intercept_registered(name: &str) {
    let probe = register_llm_request_intercept(
        name,
        i32::MAX,
        false,
        Arc::new(|_name, request, annotated| Ok((request, annotated))),
    );
    assert_already_registered(probe, name);
}

/// Asserts that no LLM request intercept called `name` is currently registered.
///
/// Registers a pass-through intercept (it returns the request and annotation
/// unchanged) under `name`, which must succeed, and then deregisters that
/// probe again.
///
/// # Panics
///
/// Panics if either the probe registration or the follow-up deregistration
/// returns an error. `#[track_caller]` makes the panic point at the call site
/// of this helper.
#[track_caller]
fn assert_llm_request_intercept_absent(name: &str) {
    register_llm_request_intercept(
        name,
        i32::MAX,
        false,
        Arc::new(|_name, request, annotated| Ok((request, annotated))),
    )
    .expect("LLM request intercept should be absent, but the probe registration was rejected");
    // Remove the probe that was just added.
    deregister_llm_request_intercept(name)
        .expect("failed to deregister the probe LLM request intercept");
}

/// Checks that an LLM execution intercept called `name` is currently registered.
///
/// Tries to register a pass-through intercept (it forwards the request to
/// `next`) under the same name and hands the outcome to
/// `assert_already_registered`, which expects `FlowError::AlreadyExists`.
fn assert_llm_execution_intercept_registered(name: &str) {
    let probe = register_llm_execution_intercept(
        name,
        i32::MAX,
        Arc::new(|_name, request, next| next(request)),
    );
    assert_already_registered(probe, name);
}

/// Asserts that no LLM execution intercept called `name` is currently registered.
///
/// Registers a pass-through intercept (it forwards the request to `next`)
/// under `name`, which must succeed, and then deregisters that probe again.
///
/// # Panics
///
/// Panics if either the probe registration or the follow-up deregistration
/// returns an error. `#[track_caller]` makes the panic point at the call site
/// of this helper.
#[track_caller]
fn assert_llm_execution_intercept_absent(name: &str) {
    register_llm_execution_intercept(
        name,
        i32::MAX,
        Arc::new(|_name, request, next| next(request)),
    )
    .expect("LLM execution intercept should be absent, but the probe registration was rejected");
    // Remove the probe that was just added.
    deregister_llm_execution_intercept(name)
        .expect("failed to deregister the probe LLM execution intercept");
}

/// Checks that an LLM stream execution intercept called `name` is currently
/// registered.
///
/// Tries to register a pass-through intercept (it forwards the request to
/// `next`) under the same name and hands the outcome to
/// `assert_already_registered`, which expects `FlowError::AlreadyExists`.
fn assert_llm_stream_execution_intercept_registered(name: &str) {
    let probe = register_llm_stream_execution_intercept(
        name,
        i32::MAX,
        Arc::new(|_name, request, next| next(request)),
    );
    assert_already_registered(probe, name);
}

/// Asserts that no LLM stream execution intercept called `name` is currently
/// registered.
///
/// Registers a pass-through intercept (it forwards the request to `next`)
/// under `name`, which must succeed, and then deregisters that probe again.
///
/// # Panics
///
/// Panics if either the probe registration or the follow-up deregistration
/// returns an error. `#[track_caller]` makes the panic point at the call site
/// of this helper.
#[track_caller]
fn assert_llm_stream_execution_intercept_absent(name: &str) {
    register_llm_stream_execution_intercept(
        name,
        i32::MAX,
        Arc::new(|_name, request, next| next(request)),
    )
    .expect(
        "LLM stream execution intercept should be absent, but the probe registration was rejected",
    );
    // Remove the probe that was just added.
    deregister_llm_stream_execution_intercept(name)
        .expect("failed to deregister the probe LLM stream execution intercept");
}

/// Checks that a tool execution intercept called `name` is currently registered.
///
/// Tries to register a pass-through intercept (it forwards the arguments to
/// `next`) under the same name and hands the outcome to
/// `assert_already_registered`, which expects `FlowError::AlreadyExists`.
fn assert_tool_execution_intercept_registered(name: &str) {
    let probe = register_tool_execution_intercept(
        name,
        i32::MAX,
        Arc::new(|_name, args, next| next(args)),
    );
    assert_already_registered(probe, name);
}

/// Asserts that no tool execution intercept called `name` is currently registered.
///
/// Registers a pass-through intercept (it forwards the arguments to `next`)
/// under `name`, which must succeed, and then deregisters that probe again.
///
/// # Panics
///
/// Panics if either the probe registration or the follow-up deregistration
/// returns an error. `#[track_caller]` makes the panic point at the call site
/// of this helper.
#[track_caller]
fn assert_tool_execution_intercept_absent(name: &str) {
    register_tool_execution_intercept(name, i32::MAX, Arc::new(|_name, args, next| next(args)))
        .expect("tool execution intercept should be absent, but the probe registration was rejected");
    // Remove the probe that was just added.
    deregister_tool_execution_intercept(name)
        .expect("failed to deregister the probe tool execution intercept");
}

struct SeedFailBackend;

impl StorageBackendDyn for SeedFailBackend {
Expand Down Expand Up @@ -201,22 +303,10 @@ async fn telemetry_feature_registers_subscriber_and_starts_drain_task() {
ctx.finish()
};
assert!(runtime.drain_handle.is_some());
assert!(
global_context()
.read()
.unwrap()
.event_subscribers
.contains_key(&name)
);
assert_subscriber_registered(&name);

rollback_registrations(&mut registrations);
assert!(
!global_context()
.read()
.unwrap()
.event_subscribers
.contains_key(&name)
);
assert_subscriber_absent(&name);

if let Some(handle) = runtime.drain_handle.take() {
handle.abort();
Expand Down Expand Up @@ -284,13 +374,7 @@ async fn adaptive_hints_feature_registers_request_intercept() {

let mut ctx = RegistrationContext::new(&mut runtime);
feature.register(&mut ctx).await.unwrap();
assert!(
global_context()
.read()
.unwrap()
.llm_request_intercepts
.contains(&name)
);
assert_llm_request_intercept_registered(&name);

let request = llm_request_intercepts(
"model",
Expand All @@ -304,13 +388,7 @@ async fn adaptive_hints_feature_registers_request_intercept() {

let mut registrations = ctx.finish();
rollback_registrations(&mut registrations);
assert!(
!global_context()
.read()
.unwrap()
.llm_request_intercepts
.contains(&name)
);
assert_llm_request_intercept_absent(&name);
}

#[tokio::test(flavor = "current_thread")]
Expand Down Expand Up @@ -343,13 +421,7 @@ async fn tool_parallelism_feature_registers_execution_intercept() {

let mut ctx = RegistrationContext::new(&mut runtime);
feature.register(&mut ctx).await.unwrap();
assert!(
global_context()
.read()
.unwrap()
.tool_execution_intercepts
.contains(&name)
);
assert_tool_execution_intercept_registered(&name);

let next: ToolExecutionNextFn = Arc::new(|args| Box::pin(async move { Ok(args) }));
let result = tool_call_execute(
Expand All @@ -365,13 +437,7 @@ async fn tool_parallelism_feature_registers_execution_intercept() {

let mut registrations = ctx.finish();
rollback_registrations(&mut registrations);
assert!(
!global_context()
.read()
.unwrap()
.tool_execution_intercepts
.contains(&name)
);
assert_tool_execution_intercept_absent(&name);
}

#[tokio::test(flavor = "current_thread")]
Expand Down Expand Up @@ -481,7 +547,7 @@ async fn registration_context_registers_all_supported_callback_types() {
"adaptive_test_request",
5,
false,
Box::new(|_name, request, annotated| Ok((request, annotated))),
Arc::new(|_name, request, annotated| Ok((request, annotated))),
)
.unwrap();
ctx.register_llm_execution_intercept(
Expand Down Expand Up @@ -515,34 +581,11 @@ async fn registration_context_registers_all_supported_callback_types() {
.unwrap();

let mut registrations = ctx.finish();
let global = global_context();
let state = global.read().unwrap();
assert!(
state
.event_subscribers
.contains_key("adaptive_test_subscriber")
);
assert!(
state
.llm_request_intercepts
.contains("adaptive_test_request")
);
assert!(
state
.llm_execution_intercepts
.contains("adaptive_test_execution")
);
assert!(
state
.llm_stream_execution_intercepts
.contains("adaptive_test_stream")
);
assert!(
state
.tool_execution_intercepts
.contains("adaptive_test_tool")
);
drop(state);
assert_subscriber_registered("adaptive_test_subscriber");
assert_llm_request_intercept_registered("adaptive_test_request");
assert_llm_execution_intercept_registered("adaptive_test_execution");
assert_llm_stream_execution_intercept_registered("adaptive_test_stream");
assert_tool_execution_intercept_registered("adaptive_test_tool");

rollback_registrations(&mut registrations);
}
Comment thread
willkill07 marked this conversation as resolved.
Expand Down Expand Up @@ -622,14 +665,13 @@ async fn acg_feature_registers_execution_and_stream_intercepts() {
let mut ctx = RegistrationContext::new(&mut runtime);
feature.register(&mut ctx).await.unwrap();

let global = global_context();
let state = global.read().unwrap();
assert!(state.llm_execution_intercepts.contains(&execution_name));
assert!(state.llm_stream_execution_intercepts.contains(&stream_name));
drop(state);
assert_llm_execution_intercept_registered(&execution_name);
assert_llm_stream_execution_intercept_registered(&stream_name);

let mut registrations = ctx.finish();
rollback_registrations(&mut registrations);
assert_llm_execution_intercept_absent(&execution_name);
assert_llm_stream_execution_intercept_absent(&stream_name);
}

#[tokio::test(flavor = "current_thread")]
Expand Down Expand Up @@ -658,20 +700,8 @@ async fn adaptive_runtime_register_feature_rolls_back_partial_registrations_and_
assert!(!runtime.registered);
assert!(runtime.drain_handle.is_none());
assert!(runtime.registrations.is_empty());
assert!(
!global_context()
.read()
.unwrap()
.event_subscribers
.contains_key("existing_feature")
);
assert!(
!global_context()
.read()
.unwrap()
.event_subscribers
.contains_key("partial_feature")
);
assert_subscriber_absent("existing_feature");
assert_subscriber_absent("partial_feature");
}

#[tokio::test(flavor = "current_thread")]
Expand Down
2 changes: 1 addition & 1 deletion crates/cli/tests/coverage/server_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ async fn pre_tool_hook_rejects_when_conditional_guardrail_blocks() {
register_tool_conditional_execution_guardrail(
"cli-pre-tool-blocker",
1,
Box::new(|name, _args| {
Arc::new(|name, _args| {
Ok((name == BLOCKED_TEST_TOOL).then(|| "blocked by policy".to_string()))
}),
)
Expand Down
6 changes: 3 additions & 3 deletions crates/core/src/api/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result<Json> {
};
if let Some(error) = NemoFlowContextState::llm_conditional_execution_snapshot_chain(
&request,
entries,
&entries,
&subscribers,
parent_uuid,
guardrail_metadata,
Expand Down Expand Up @@ -680,7 +680,7 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu
};
if let Some(error) = NemoFlowContextState::llm_conditional_execution_snapshot_chain(
&request,
entries,
&entries,
&subscribers,
parent_uuid,
guardrail_metadata,
Expand Down Expand Up @@ -822,7 +822,7 @@ pub fn llm_conditional_execution(request: &LlmRequest) -> Result<()> {
};
if let Some(error) = NemoFlowContextState::llm_conditional_execution_snapshot_chain(
request,
entries,
&entries,
&subscribers,
parent_uuid,
None,
Expand Down
Loading
Loading