Skip to content

Commit ac88a58

Browse files
authored
Merge pull request #112 from seanwevans/codex/update-non-tch-branch-of-perform_computation
Handle JSON output for non-tch Ki gradient updates
2 parents d3b4926 + 9519295 commit ac88a58

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

src/ki_node.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,13 @@ async fn perform_computation(task: Task) -> Result<Task, Box<dyn Error>> {
187187
}
188188
#[cfg(not(feature = "tch"))]
189189
{
190+
let input: Vec<f32> = serde_json::from_str(&task.data)?;
191+
let grad: Vec<f32> = input.into_iter().map(|v| v * 2.0).collect();
192+
let data = serde_json::to_string(&grad)?;
190193
Ok(Task {
191194
task_id: task.task_id,
192195
task_type: TaskType::GradientUpdate,
193-
data: format!("Processed data: {}", task.data),
196+
data,
194197
})
195198
}
196199
}
@@ -223,6 +226,8 @@ async fn send_result(result: Task, channel: &Channel) -> Result<(), Box<dyn Erro
223226
mod tests {
224227
use super::*;
225228
use crate::common::{Task, TaskType};
229+
#[cfg(not(feature = "tch"))]
230+
use crate::an_node::AnNodeState;
226231

227232
#[test]
228233
fn test_update_model_parameters() {
@@ -249,6 +254,34 @@ mod tests {
249254
assert_eq!(result.task_id, task.task_id);
250255
}
251256

257+
#[cfg(not(feature = "tch"))]
258+
#[tokio::test]
259+
async fn test_perform_computation_gradient_update_non_tch() {
260+
let task = Task {
261+
task_id: uuid::Uuid::new_v4(),
262+
task_type: TaskType::GradientUpdate,
263+
data: serde_json::to_string(&vec![1.0_f32, -0.5_f32]).unwrap(),
264+
};
265+
let result = perform_computation(task.clone()).await.unwrap();
266+
assert_eq!(result.task_id, task.task_id);
267+
assert_eq!(result.task_type, TaskType::GradientUpdate);
268+
let gradient: Vec<f32> = serde_json::from_str(&result.data).unwrap();
269+
assert_eq!(gradient, vec![2.0_f32, -1.0_f32]);
270+
}
271+
272+
#[cfg(not(feature = "tch"))]
273+
#[tokio::test]
274+
async fn test_an_node_processes_ki_payload() {
275+
let task = Task {
276+
task_id: uuid::Uuid::new_v4(),
277+
task_type: TaskType::GradientUpdate,
278+
data: serde_json::to_string(&vec![0.25_f32, 0.5_f32]).unwrap(),
279+
};
280+
let result = perform_computation(task).await.unwrap();
281+
let mut state = AnNodeState::new();
282+
assert!(state.process_task(result, None, 1).await.is_ok());
283+
}
284+
252285
#[cfg(feature = "integration-tests")]
253286
#[tokio::test]
254287
async fn test_send_result_publishes_message() {

0 commit comments

Comments
 (0)