diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 6344509d8..7c0d74f7a 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -47,10 +47,10 @@ fi # NOTE: If a newly-fetched version of the executorch repo changes the value of # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -PYTORCH_NIGHTLY_VERSION=dev20241002 +PYTORCH_NIGHTLY_VERSION=dev20241028 # Nightly version for torchvision -VISION_NIGHTLY_VERSION=dev20241002 +VISION_NIGHTLY_VERSION=dev20241028 # Nightly version for torchtune TUNE_NIGHTLY_VERSION=dev20241010 diff --git a/torchchat/export.py b/torchchat/export.py index 7c5243b68..9399c2add 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -122,7 +122,7 @@ def export_for_server( from executorch.exir.tracer import Value from torch._export import capture_pre_autograd_graph - from torch.export import export, ExportedProgram + from torch.export import export_for_training, ExportedProgram from torchchat.model import apply_rotary_emb, Attention from torchchat.utils.build_utils import get_precision @@ -238,7 +238,7 @@ def _to_core_aten( raise ValueError( f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}" ) - core_aten_ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes) + core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes) if verbose: logging.info(f"Core ATen graph:\n{core_aten_ep.graph}") return core_aten_ep