This repository was archived by the owner on Sep 10, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +9
-15
lines changed Expand file tree Collapse file tree 2 files changed +9
-15
lines changed Original file line number Diff line number Diff line change @@ -927,6 +927,11 @@ jobs:
927927 run : |
928928 echo "Intalling pip3 packages"
929929 ./install/install_requirements.sh
930+
931+ # Install ET
932+ source ./torchchat/utils/scripts/install_utils.sh
933+ install_executorch_python_libs
934+
930935 pip3 list
931936 python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
932937 - name : Set ET git sha
Original file line number Diff line number Diff line change @@ -31,10 +31,7 @@ LICENSE file in the root directory of this source tree.
3131#endif
3232
3333#ifdef __AOTI_MODEL__
34- #include < torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
35- #ifdef USE_CUDA
36- #include < torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
37- #endif
34+ #include < torch/csrc/inductor/aoti_package/model_package_loader.h>
3835torch::Device aoti_device (torch::kCPU );
3936
4037#else // __ET_MODEL__
@@ -94,7 +91,7 @@ typedef struct {
9491 RunState state; // buffers for the "wave" of activations in the forward pass
9592
9693#ifdef __AOTI_MODEL__
97- torch::inductor::AOTIModelContainerRunner * runner;
94+ torch::inductor::AOTIModelPackageLoader * runner;
9895#else // __ET_MODEL__
9996 Module* runner;
10097#endif
@@ -144,16 +141,8 @@ void build_transformer(
144141 malloc_run_state (&t->state , &t->config );
145142
146143#ifdef __AOTI_MODEL__
147- #ifdef USE_CUDA
148- if (aoti_device.type () == torch::kCUDA ) {
149- t->runner = new torch::inductor::AOTIModelContainerRunnerCuda (model_path);
150- aoti_device = torch::Device (torch::kCUDA );
151- } else {
152- #else
153- {
154- #endif
155- t->runner = new torch::inductor::AOTIModelContainerRunnerCpu (model_path);
156- }
144+ t->runner = new torch::inductor::AOTIModelPackageLoader (model_path);
145+ aoti_device = t->runner ->get_metadata ()[" AOTI_DEVICE_KEY" ] == " cpu" ? torch::Device (torch::kCPU ) : torch::Device (torch::kCUDA );
157146#else // __ET_MODEL__
158147 t->runner = new Module (
159148 /* path to PTE model */ model_path,
You can’t perform that action at this time.
0 commit comments