File tree Expand file tree Collapse file tree 3 files changed +11
-1
lines changed
extension/training/examples/XOR Expand file tree Collapse file tree 3 files changed +11
-1
lines changed Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ def define_common_targets():
1717 "//executorch/runtime/executor:program" ,
1818 "//executorch/extension/data_loader:file_data_loader" ,
1919 "//executorch/kernels/portable:generated_lib" ,
20+ "//executorch/extension/flat_tensor/serialize:serialize_cpp"
2021 ],
2122 external_deps = ["gflags" ],
2223 define_static_target = True ,
Original file line number Diff line number Diff line change 1313
1414class TestXORExport (unittest .TestCase ):
1515 def test (self ):
16- _ = _export_model ()
16+ ep = _export_model ()
17+ self .assertTrue (ep is not None )
1718 # Expect that we reach this far without an exception being thrown.
1819 self .assertTrue (True )
Original file line number Diff line number Diff line change 77 */
88
99#include < executorch/extension/data_loader/file_data_loader.h>
10+ #include < executorch/extension/flat_tensor/serialize/serialize.h>
1011#include < executorch/extension/tensor/tensor.h>
1112#include < executorch/extension/training/module/training_module.h>
1213#include < executorch/extension/training/optimizer/sgd.h>
@@ -105,4 +106,11 @@ int main(int argc, char** argv) {
105106 }
106107 optimizer.step (mod.named_gradients (" forward" ).get ());
107108 }
109+ std::map<std::string, exec_aten::Tensor> param_map;
110+ for (auto & param : param_res.get ()) {
111+ param_map.insert (std::pair<std::string, exec_aten::Tensor>{
112+ std::string (param.first .data ()), param.second });
113+ }
114+
115+ executorch::extension::flat_tensor::save_ptd (" xor.ptd" , param_map, 16 );
108116}
You can’t perform that action at this time.
0 commit comments