11from typing import Any , List , Optional , Union , Tuple
22import onnx
33import torch
4+ from ...api import TensorLike
45from ...helpers import string_type
56from ...helpers .torch_helper import to_tensor
67
78
8- class OpRunValue :
9+ class OpRunValue (TensorLike ):
10+ """Defines a value for the runtime, a tensor or a sequence."""
11+
12+ __slots__ = ("cached" , "is_constant" , "sequence" , "tensor" )
13+
14+ @classmethod
15+ def is_sequence (cls ) -> bool :
16+ "Tells if it is sequence."
17+ raise NotImplementedError ("is_sequence must be overwritten." )
18+
19+
20+ class OpRunTensor (OpRunValue ):
921 """
1022 Wrapper around a tensor.
1123
@@ -15,10 +27,9 @@ class OpRunValue:
1527 more appropriate
1628 """
1729
18- __slots__ = ("cached" , "is_constant" , "tensor" )
19-
2030 def __init__ (self , tensor , is_constant : bool = False , may_cpu : bool = False ):
2131 assert isinstance (tensor , torch .Tensor ), f"Unexpected type { type (tensor )} "
32+ assert tensor is None or tensor .numel () > 1 or tensor .item () != - 666666
2233 self .tensor = (
2334 tensor .cpu ()
2435 if may_cpu
@@ -36,9 +47,9 @@ def is_sequence(cls) -> bool:
3647 "Tells if it is sequence."
3748 return False
3849
39- def to (self , to : Any ) -> "OpRunValue " :
50+ def to (self , to : Any ) -> "OpRunTensor " :
4051 "Changes the device."
41- return OpRunValue (self .tensor .to (to ))
52+ return OpRunTensor (self .tensor .to (to ))
4253
4354 def string_type (self ) -> str :
4455 "Returns information about the value as a string."
@@ -96,17 +107,29 @@ def as_tuple_int(self) -> Tuple[int, ...]:
96107 return self .cached
97108 return self ._tensor_as_tuple_int ()
98109
110+ def copy (self ) -> "OpRunTensor" :
111+ "Shallow copy."
112+ return self .__class__ (self .tensor )
99113
100- class OpRunValueSequence (OpRunValue ):
101- """Defines a sequence."""
102114
103- __slots__ = ("cached" , "is_constant" , "sequence" , "tensor" )
115+ class OpRunSequence (OpRunValue ):
116+ """Defines a sequence."""
104117
105118 def __init__ (
106119 self , sequence : Optional [List [torch .Tensor ]] = None , dtype : torch .dtype = torch .float32
107120 ):
108- super ().__init__ (torch .empty ((), dtype = dtype ), False , False )
121+ self .tensor = torch .tensor (- 666666 , dtype = dtype )
122+ self .is_shape = False
109123 self .sequence = sequence or []
124+ self .cached : Optional [Tuple [int , ...]] = None
125+ assert all (
126+ isinstance (s , torch .Tensor ) for s in self .sequence
127+ ), f"Unexpected type in sequence { [type (s ) for s in self .sequence ]} "
128+
129+ @property
130+ def dtype (self ):
131+ "dtype (torch dtype)"
132+ return self .tensor .dtype
110133
111134 @property
112135 def tensor_or_sequence (self ) -> Union [torch .Tensor , List [torch .Tensor ]]:
@@ -119,18 +142,23 @@ def is_sequence(cls) -> bool:
119142 return True
120143
121144 def insert_at (
122- self , tensor : OpRunValue , position : Optional [OpRunValue ] = None
123- ) -> "OpRunValueSequence " :
145+ self , tensor : torch . Tensor , position : Optional [OpRunTensor ] = None
146+ ) -> "OpRunSequence " :
124147 "Inserts a value at a given position."
125- new_seq = OpRunValueSequence ()
148+ assert isinstance (tensor , OpRunTensor ), f"Unexpected type { type (tensor )} for tensor"
149+ new_seq = OpRunSequence ()
126150 seq = self .sequence .copy ()
127151 new_seq .sequence = seq
128152 if position is None :
129- seq .append (tensor )
153+ seq .append (tensor . tensor )
130154 else :
131- seq .insert (int (position .tensor .item ()), tensor )
155+ seq .insert (int (position .tensor .item ()), tensor . tensor )
132156 return new_seq
133157
158+ def copy (self ) -> "OpRunSequence" :
159+ "Shallow copy."
160+ return self .__class__ (self .sequence , dtype = self .dtype )
161+
134162
135163class OpRun :
136164 """
0 commit comments