diff --git a/model/backbones/vit_pytorch.py b/model/backbones/vit_pytorch.py index 66cf849..16e928a 100644 --- a/model/backbones/vit_pytorch.py +++ b/model/backbones/vit_pytorch.py @@ -27,13 +27,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch._six import container_abcs - +import collections.abc # From PyTorch internals def _ntuple(n): def parse(x): - if isinstance(x, container_abcs.Iterable): + if isinstance(x, collections.abc.Iterable): return x return tuple(repeat(x, n)) return parse