Skip to content

Commit ddb1e97

Browse files
Revert "Support torchbind in OSS proxy executor (pytorch#149747)"
This reverts commit aa70d62. Reverted pytorch#149747 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#149747 (comment)))
1 parent 2f785ab commit ddb1e97

File tree

5 files changed

+83
-359
lines changed

5 files changed

+83
-359
lines changed

test/inductor/test_torchbind.py

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def test_torchbind_hop_schema(self):
128128
schema = CallTorchBind.schema(foo_ir, "add")
129129
self.assertEqual(
130130
str(schema),
131-
"call_torchbind(__torch__.torch.classes._TorchScriptTesting._Foo _0, str method, int _1) -> int _0",
131+
"call_torchbind(__torch__.torch.classes._TorchScriptTesting._Foo obj, str method, int _1) -> int _0",
132132
)
133133

134134
def test_torchbind_config_not_generated(self):
@@ -146,7 +146,7 @@ def test_torchbind_hop_schema_no_input(self):
146146
schema = CallTorchBind.schema(q_ir, "pop")
147147
self.assertEqual(
148148
str(schema),
149-
"call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, str method) -> Tensor _0",
149+
"call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue obj, str method) -> Tensor _0",
150150
)
151151

152152
def test_torchbind_hop_schema_no_output(self):
@@ -155,7 +155,7 @@ def test_torchbind_hop_schema_no_output(self):
155155
schema = CallTorchBind.schema(q_ir, "push")
156156
self.assertEqual(
157157
str(schema),
158-
"call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, str method, Tensor _1) -> NoneType _0",
158+
"call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue obj, str method, Tensor _1) -> NoneType _0",
159159
)
160160

161161
def test_torchbind_aot_compile(self):
@@ -250,7 +250,7 @@ def test_torchbind_aot_compile(self):
250250
"target": "call_torchbind",
251251
"inputs": [
252252
{
253-
"name": "_0",
253+
"name": "obj",
254254
"arg": {
255255
"as_custom_obj": {
256256
"name": "_torchbind_obj0",
@@ -293,20 +293,15 @@ def test_torchbind_aot_compile(self):
293293
self.assertTrue((tmp_path_model / "custom_objs_config.json").exists())
294294
self.assertTrue((tmp_path_constants / "custom_obj_0").exists())
295295

296-
def test_torchbind_aoti(self):
297-
ep, inputs, orig_res, _ = self.get_exported_model()
298-
pt2_path = torch._inductor.aoti_compile_and_package(ep)
299-
optimized = torch._inductor.aoti_load_package(pt2_path)
300-
result = optimized(*inputs)
301-
self.assertEqual(result, orig_res)
296+
# TODO: add accuracy test after we support loading and running compiled models with
297+
# torchbind objects.
302298

303299
@torch._inductor.config.patch("aot_inductor.use_runtime_constant_folding", True)
304300
def test_torchbind_aot_compile_constant_folding(self):
305-
ep, inputs, orig_res, _ = self.get_exported_model()
306-
pt2_path = torch._inductor.aoti_compile_and_package(ep)
307-
optimized = torch._inductor.aoti_load_package(pt2_path)
308-
result = optimized(*inputs)
309-
self.assertEqual(result, orig_res)
301+
ep, inputs, _, _ = self.get_exported_model()
302+
aot_compile(ep.module(), inputs, options={"aot_inductor.package": True})
303+
# TODO: add accuracy test after we support loading and running compiled models with
304+
# torchbind objects.
310305

311306
def test_torchbind_list_return_aot_compile(self):
312307
class M(torch.nn.Module):
@@ -322,48 +317,15 @@ def forward(self, x):
322317

323318
m = M()
324319
inputs = (torch.ones(2, 3),)
325-
orig_res = m(*inputs)
326320

327321
# We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
328322
with enable_torchbind_tracing():
329323
ep = torch.export.export(m, inputs, strict=False)
330324

331-
pt2_path = torch._inductor.aoti_compile_and_package(ep)
332-
optimized = torch._inductor.aoti_load_package(pt2_path)
333-
result = optimized(*inputs)
334-
self.assertEqual(result, orig_res)
335-
336-
def test_torchbind_queue(self):
337-
class Foo(torch.nn.Module):
338-
def __init__(self, tq) -> None:
339-
super().__init__()
340-
self.tq = tq
341-
342-
def forward(self, x):
343-
self.tq.push(x.cos())
344-
self.tq.push(x.sin())
345-
# TODO: int return type in fallback kernel not support yet
346-
x_cos = self.tq.pop() # + self.tq.size()
347-
x_sin = self.tq.pop() # - self.tq.size()
348-
return x_sin, x_cos
349-
350-
inputs = (torch.randn(3, 2),)
351-
352-
q = _empty_tensor_queue()
353-
m = Foo(q)
354-
orig_res = m(*inputs)
355-
356-
q2 = _empty_tensor_queue()
357-
m2 = Foo(q2)
358-
359-
# We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
360-
with enable_torchbind_tracing():
361-
ep = torch.export.export(m2, inputs, strict=False)
325+
aot_compile(ep.module(), inputs, options={"aot_inductor.package": True})
362326

363-
pt2_path = torch._inductor.aoti_compile_and_package(ep)
364-
optimized = torch._inductor.aoti_load_package(pt2_path)
365-
result = optimized(*inputs)
366-
self.assertEqual(result, orig_res)
327+
# TODO: add accuracy test after we support loading and running compiled models with
328+
# torchbind objects.
367329

368330
@requires_gpu()
369331
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)

torch/_higher_order_ops/torchbind.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def schema(obj, method) -> torch.FunctionSchema:
4242
val = obj.get_real_obj()
4343
schema = val._get_method(method).schema
4444
schema_str = str(schema)
45-
new_schema_str = f"call_torchbind({str(schema.arguments[0].real_type)} {schema.arguments[0].name},"
45+
new_schema_str = (
46+
"call_torchbind(" + str(schema.arguments[0].real_type) + " obj,"
47+
)
4648
first_comma_index = schema_str.find(",")
4749
if first_comma_index == -1:
4850
# If no comma is found, find the last closing parenthesis

torch/csrc/inductor/aoti_package/model_package_loader.cpp

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#if !defined(C10_MOBILE) && !defined(ANDROID)
22

33
#include <c10/util/error.h>
4-
#include <c10/util/string_view.h>
54
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
65
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
76
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
@@ -63,6 +62,7 @@ const std::string k_separator = "\\";
6362
#else
6463
const std::string k_separator = "/";
6564
#endif
65+
6666
} // namespace
6767

6868
namespace torch::inductor {
@@ -187,7 +187,7 @@ bool recursive_mkdir(const std::string& dir) {
187187
}
188188

189189
// Find folder separator and check if we are at the top
190-
auto pos = dir.find_last_of(k_separator);
190+
auto pos = dir.find_last_of("/\\");
191191
if (pos == std::string::npos) {
192192
return false;
193193
}
@@ -372,7 +372,6 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
372372
std::string found_filenames; // Saving for bookkeeping
373373
std::string model_directory =
374374
"data" + k_separator + "aotinductor" + k_separator + model_name;
375-
std::string const_directory = "data" + k_separator + "constants";
376375

377376
for (uint32_t i = 0; i < zip_archive.m_total_files; i++) {
378377
uint32_t filename_len =
@@ -390,30 +389,14 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
390389
found_filenames += " ";
391390

392391
// Only compile files in the specified model directory
393-
if (c10::starts_with(filename_str, model_directory) ||
394-
c10::starts_with(filename_str, const_directory)) {
392+
if (filename_str.length() >= model_directory.length() &&
393+
filename_str.substr(0, model_directory.length()) == model_directory) {
395394
std::string output_path_str = temp_dir_;
396-
397-
if (c10::starts_with(filename_str, model_directory)) {
398-
output_path_str += k_separator;
399-
output_path_str += filename_str;
400-
} else { // startsWith(filename_str, const_directory)
401-
// Extract constants to the same directory as the rest of the files
402-
// to be consistent with internal implementation
403-
size_t lastSlash = filename_str.find_last_of(k_separator);
404-
std::string filename = filename_str;
405-
if (lastSlash != std::string::npos) {
406-
filename = filename_str.substr(lastSlash + 1);
407-
}
408-
output_path_str +=
409-
k_separator + model_directory + k_separator + filename;
410-
}
411-
412-
LOG(INFO) << "Extract file: " << filename_str << " to "
413-
<< output_path_str;
395+
output_path_str += k_separator;
396+
output_path_str += filename_str;
414397

415398
// Create the parent directory if it doesn't exist
416-
size_t parent_path_idx = output_path_str.find_last_of(k_separator);
399+
size_t parent_path_idx = output_path_str.find_last_of("/\\");
417400
if (parent_path_idx == std::string::npos) {
418401
throw std::runtime_error(
419402
"Failed to find parent path in " + output_path_str);

0 commit comments

Comments
 (0)