Skip to content

Commit 8e7e5ba

Browse files
mikaylagawarecki authored and pytorchmergebot committed
Add sparse tensors constructed via legacy constructor to _sparse_tensors_to_validate (pytorch#147759)
This is a redo of pytorch#147408, which added validation at the end of the legacy constructor calls. The reason I didn't land that was that in `legacy_load`, the constructor would be called before the storages of indices/values are set, so the tensor would not actually be validated. Technically, torch.sparse.{Foo}Tensor should not even be called by our rebuild process, since afaict the first PR that added support for sparse tensor serialization (pytorch#27062) already uses `_rebuild_sparse_tensor` (which would add the rebuilt tensor to the list to validate), but torch.sparse.FooTensor is allowlisted. This PR adds tensors constructed as such to the list to validate at the end of torch.load. Pull Request resolved: pytorch#147759 Approved by: https://github.com/albanD
1 parent c82c141 commit 8e7e5ba

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

test/test_serialization.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,40 @@ def __reduce_ex__(self, proto):
443443
"size is inconsistent with indices"):
444444
y = torch.load(f, weights_only=weights_only)
445445

446+
def test_serialization_sparse_invalid_legacy_ctor(self):
447+
# This is set in test class setup but would not be checked when running user code
448+
prev_invariant_check_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
449+
try:
450+
torch.sparse.check_sparse_tensor_invariants.disable()
451+
x = torch.zeros(3, 3)
452+
x[1][1] = 1
453+
x = x.to_sparse()
454+
x_legacy_ctor = torch.sparse.FloatTensor(x.indices(), x.values())
455+
456+
# technically legacy ctor will still always be rebuilt with _rebuild_sparse_tensor
457+
# this is to test that legacy ctor in data.pkl will be validated by weights_only unpickler
458+
class LegacyCtorSerializationSpoofer:
459+
def __init__(self, tensor):
460+
self.tensor = tensor
461+
462+
def __reduce_ex__(self, proto):
463+
indices = self.tensor._indices()
464+
indices[0][0] = 3
465+
return (torch.sparse.FloatTensor, (indices, self.tensor._values(), self.tensor.size()))
466+
467+
with tempfile.NamedTemporaryFile() as f:
468+
sd = {"spoofed_legacy_ctor": LegacyCtorSerializationSpoofer(x_legacy_ctor)}
469+
torch.save(sd, f)
470+
for weights_only in (True,):
471+
f.seek(0)
472+
with self.assertRaisesRegex(
473+
RuntimeError,
474+
"size is inconsistent with indices"):
475+
y = torch.load(f, weights_only=weights_only)
476+
finally:
477+
if prev_invariant_check_enabled:
478+
torch.sparse.check_sparse_tensor_invariants.enable()
479+
446480
def _test_serialization_sparse_compressed_invalid(self,
447481
conversion,
448482
get_compressed_indices,

torch/_weights_only_unpickler.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
from typing import Any, Callable, Union
7272

7373
import torch
74-
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
74+
from torch._utils import _sparse_tensors_to_validate, IMPORT_MAPPING, NAME_MAPPING
7575

7676

7777
# modules in this list are never allowed, even if the user attempts to allowlist
@@ -387,7 +387,10 @@ def load(self):
387387
cls in _get_user_allowed_globals().values()
388388
or cls in _get_allowed_globals().values()
389389
):
390-
self.append(cls.__new__(cls, *args))
390+
result = cls.__new__(cls, *args)
391+
if cls in torch._tensor_classes and "sparse" in cls.__module__:
392+
_sparse_tensors_to_validate.append(result)
393+
self.append(result)
391394
else:
392395
raise UnpicklingError(
393396
"Can only create new object for nn.Parameter or classes allowlisted "
@@ -403,7 +406,10 @@ def load(self):
403406
raise UnpicklingError(
404407
f"Trying to call reduce for unrecognized function {func}"
405408
)
406-
self.stack[-1] = func(*args)
409+
result = func(*args)
410+
if func in torch._tensor_classes and "sparse" in func.__module__:
411+
_sparse_tensors_to_validate.append(result)
412+
self.stack[-1] = result
407413
elif key[0] == BUILD[0]:
408414
state = self.stack.pop()
409415
inst = self.stack[-1]

0 commit comments

Comments
 (0)