@@ -187,10 +187,13 @@ async fn perform_computation(task: Task) -> Result<Task, Box<dyn Error>> {
187187 }
188188 #[ cfg( not( feature = "tch" ) ) ]
189189 {
190+ let input: Vec < f32 > = serde_json:: from_str ( & task. data ) ?;
191+ let grad: Vec < f32 > = input. into_iter ( ) . map ( |v| v * 2.0 ) . collect ( ) ;
192+ let data = serde_json:: to_string ( & grad) ?;
190193 Ok ( Task {
191194 task_id : task. task_id ,
192195 task_type : TaskType :: GradientUpdate ,
193- data : format ! ( "Processed data: {}" , task . data ) ,
196+ data,
194197 } )
195198 }
196199 }
@@ -223,6 +226,8 @@ async fn send_result(result: Task, channel: &Channel) -> Result<(), Box<dyn Erro
223226mod tests {
224227 use super :: * ;
225228 use crate :: common:: { Task , TaskType } ;
229+ #[ cfg( not( feature = "tch" ) ) ]
230+ use crate :: an_node:: AnNodeState ;
226231
227232 #[ test]
228233 fn test_update_model_parameters ( ) {
@@ -249,6 +254,34 @@ mod tests {
249254 assert_eq ! ( result. task_id, task. task_id) ;
250255 }
251256
257+ #[ cfg( not( feature = "tch" ) ) ]
258+ #[ tokio:: test]
259+ async fn test_perform_computation_gradient_update_non_tch ( ) {
260+ let task = Task {
261+ task_id : uuid:: Uuid :: new_v4 ( ) ,
262+ task_type : TaskType :: GradientUpdate ,
263+ data : serde_json:: to_string ( & vec ! [ 1.0_f32 , -0.5_f32 ] ) . unwrap ( ) ,
264+ } ;
265+ let result = perform_computation ( task. clone ( ) ) . await . unwrap ( ) ;
266+ assert_eq ! ( result. task_id, task. task_id) ;
267+ assert_eq ! ( result. task_type, TaskType :: GradientUpdate ) ;
268+ let gradient: Vec < f32 > = serde_json:: from_str ( & result. data ) . unwrap ( ) ;
269+ assert_eq ! ( gradient, vec![ 2.0_f32 , -1.0_f32 ] ) ;
270+ }
271+
272+ #[ cfg( not( feature = "tch" ) ) ]
273+ #[ tokio:: test]
274+ async fn test_an_node_processes_ki_payload ( ) {
275+ let task = Task {
276+ task_id : uuid:: Uuid :: new_v4 ( ) ,
277+ task_type : TaskType :: GradientUpdate ,
278+ data : serde_json:: to_string ( & vec ! [ 0.25_f32 , 0.5_f32 ] ) . unwrap ( ) ,
279+ } ;
280+ let result = perform_computation ( task) . await . unwrap ( ) ;
281+ let mut state = AnNodeState :: new ( ) ;
282+ assert ! ( state. process_task( result, None , 1 ) . await . is_ok( ) ) ;
283+ }
284+
252285 #[ cfg( feature = "integration-tests" ) ]
253286 #[ tokio:: test]
254287 async fn test_send_result_publishes_message ( ) {
0 commit comments