Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 67f7011

Browse files
committed
Fix AOTI runner
1 parent 15cf464 commit 67f7011

File tree

2 files changed

+9
-15
lines changed

2 files changed

+9
-15
lines changed

.github/workflows/pull.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,11 @@ jobs:
927927
run: |
928928
echo "Intalling pip3 packages"
929929
./install/install_requirements.sh
930+
931+
# Install ET
932+
source ./torchchat/utils/scripts/install_utils.sh
933+
install_executorch_python_libs
934+
930935
pip3 list
931936
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
932937
- name: Set ET git sha

runner/run.cpp

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@ LICENSE file in the root directory of this source tree.
3131
#endif
3232

3333
#ifdef __AOTI_MODEL__
34-
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
35-
#ifdef USE_CUDA
36-
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
37-
#endif
34+
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
3835
torch::Device aoti_device(torch::kCPU);
3936

4037
#else // __ET_MODEL__
@@ -94,7 +91,7 @@ typedef struct {
9491
RunState state; // buffers for the "wave" of activations in the forward pass
9592

9693
#ifdef __AOTI_MODEL__
97-
torch::inductor::AOTIModelContainerRunner* runner;
94+
torch::inductor::AOTIModelPackageLoader* runner;
9895
#else // __ET_MODEL__
9996
Module* runner;
10097
#endif
@@ -144,16 +141,8 @@ void build_transformer(
144141
malloc_run_state(&t->state, &t->config);
145142

146143
#ifdef __AOTI_MODEL__
147-
#ifdef USE_CUDA
148-
if (aoti_device.type() == torch::kCUDA) {
149-
t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path);
150-
aoti_device = torch::Device(torch::kCUDA);
151-
} else {
152-
#else
153-
{
154-
#endif
155-
t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
156-
}
144+
t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
145+
aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
157146
#else //__ET_MODEL__
158147
t->runner = new Module(
159148
/* path to PTE model */ model_path,

0 commit comments

Comments
 (0)