Skip to content

Commit edbdbfb

Browse files
Save XOR weights in .ptd
Differential Revision: D68785514 Pull Request resolved: pytorch#8014
1 parent c5fea7e commit edbdbfb

File tree

3 files changed

+11
-1
lines changed

3 files changed

+11
-1
lines changed

extension/training/examples/XOR/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def define_common_targets():
1717
"//executorch/runtime/executor:program",
1818
"//executorch/extension/data_loader:file_data_loader",
1919
"//executorch/kernels/portable:generated_lib",
20+
"//executorch/extension/flat_tensor/serialize:serialize_cpp"
2021
],
2122
external_deps = ["gflags"],
2223
define_static_target = True,

extension/training/examples/XOR/test/test_export.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
class TestXORExport(unittest.TestCase):
1515
def test(self):
16-
_ = _export_model()
16+
ep = _export_model()
17+
self.assertTrue(ep is not None)
1718
# Expect that we reach this far without an exception being thrown.
1819
self.assertTrue(True)

extension/training/examples/XOR/train.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/extension/data_loader/file_data_loader.h>
10+
#include <executorch/extension/flat_tensor/serialize/serialize.h>
1011
#include <executorch/extension/tensor/tensor.h>
1112
#include <executorch/extension/training/module/training_module.h>
1213
#include <executorch/extension/training/optimizer/sgd.h>
@@ -105,4 +106,11 @@ int main(int argc, char** argv) {
105106
}
106107
optimizer.step(mod.named_gradients("forward").get());
107108
}
109+
std::map<std::string, exec_aten::Tensor> param_map;
110+
for (auto& param : param_res.get()) {
111+
param_map.insert(std::pair<std::string, exec_aten::Tensor>{
112+
std::string(param.first.data()), param.second});
113+
}
114+
115+
executorch::extension::flat_tensor::save_ptd("xor.ptd", param_map, 16);
108116
}

0 commit comments

Comments
 (0)