@@ -22,10 +22,36 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
2222
2323 // Load a tensor that precisely matches the graph input tensor (see
2424 // `fixture/frozen_inference_graph.xml`).
25- print ! ( "Set input tensor ..." ) ;
25+ print ! ( "Load input tensor ..." ) ;
2626 let input_dims = vec ! [ 1 , 3 , 224 , 224 ] ;
2727 let tensor_data = fs:: read ( tensor_name) . unwrap ( ) ;
28- context. set_input ( 0 , TensorType :: F32 , & input_dims, tensor_data) ?;
28+ println ! ( "done" ) ;
29+
30+ print ! ( "Transpose input tensor ..." ) ;
31+ // Transpose from [height, width, 3] to [3, height, width]
32+ // For the historical reasons, the input tensor is in the format of [height, width, channels],
33+ // but the graph expects it in the format of [channels, height, width].
34+ // The input tensor is 224x224x3, so we need to transpose it
35+ // to 3x224x224.
36+ let height = 224 ;
37+ let width = 224 ;
38+ println ! ( "tensor_data.len() = {}" , tensor_data. len( ) ) ;
39+ let mut transposed: Vec < u8 > = vec ! [ 0 ; tensor_data. len( ) ] ;
40+ for ch in 0 ..3 {
41+ for y in 0 ..height {
42+ for x in 0 ..width {
43+ let loc = y * height + x;
44+ for b in 0 ..4 {
45+ transposed[ ( ch * width * height + loc) * 4 + b as usize ] =
46+ tensor_data[ ( loc * 3 + ch) * 4 + b as usize ] ;
47+ }
48+ }
49+ }
50+ }
51+ println ! ( "done" ) ;
52+
53+ print ! ( "Set input tensor ..." ) ;
54+ context. set_input ( 0 , TensorType :: F32 , & input_dims, transposed) ?;
2955 println ! ( "done" ) ;
3056
3157 print ! ( "Perform graph inference ..." ) ;
0 commit comments