2121import uuid
2222import re
2323import copy
24- from typing import Union , Optional , Tuple , List
24+ from typing import overload , Union , Optional , Tuple , List , FrozenSet
2525from functools import lru_cache
2626
2727
@@ -118,12 +118,12 @@ class Expression(object, metaclass=XSlotted, abstract=True):
118118
119119 __xslots__ : Tuple [str , ...] = ("_taint" ,)
120120
121- def __init__ (self , * , taint : Union [ tuple , frozenset ] = ()):
121+ def __init__ (self , * , taint : FrozenSet [ str ] = frozenset ()):
122122 """
123123 An abstract Unmutable Taintable Expression
124- :param taint: A frozenzset
124+ :param taint: A frozenzset of taints. Normally strings
125125 """
126- self ._taint = frozenset ( taint )
126+ self ._taint = taint
127127 super ().__init__ ()
128128
129129 def __repr__ (self ):
@@ -132,15 +132,15 @@ def __repr__(self):
132132 )
133133
134134 @property
135- def is_tainted (self ):
136- return len (self ._taint ) != 0
135+ def is_tainted (self ) -> bool :
136+ return bool (self ._taint )
137137
138138 @property
139- def taint (self ):
139+ def taint (self ) -> FrozenSet [ str ] :
140140 return self ._taint
141141
142142 @property
143- def operands (self ):
143+ def operands (self ) -> Tuple [ "Expression" ] :
144144 """ Hack so we can use any Expression as a node """
145145 return ()
146146
@@ -165,7 +165,7 @@ def __init__(self, *, name: str, **kwargs):
165165 self ._name = name
166166
167167 @property
168- def name (self ):
168+ def name (self ) -> str :
169169 return self ._name
170170
171171 def __repr__ (self ):
@@ -229,7 +229,7 @@ def cast(self, value: Union["Bool", int, bool], **kwargs) -> Union["BoolConstant
229229 """ Cast any type into a Bool or fail """
230230 if isinstance (value , Bool ):
231231 return value
232- return BoolConstant (bool (value ), ** kwargs )
232+ return BoolConstant (value = bool (value ), ** kwargs )
233233
234234 def __cmp__ (self , * args ):
235235 raise NotImplementedError ("CMP for Bool" )
@@ -263,7 +263,7 @@ def __ror__(self, other):
263263 def __rxor__ (self , other ):
264264 return BoolXor (self .cast (other ), self )
265265
266- def __dbool__ (self ):
266+ def __bool__ (self ):
267267 raise ExpressionError (
268268 "You tried to use a Bool Expression as a boolean constant. Expressions could represent a set of concrete values."
269269 )
@@ -274,26 +274,26 @@ class BoolVariable(Bool, Variable):
274274
275275
276276class BoolConstant (Bool , Constant ):
277- def __init__ (self , value : bool , ** kwargs ):
277+ def __init__ (self , * , value : bool , ** kwargs ):
278278 super ().__init__ (value = bool (value ), ** kwargs )
279279
280- def __bool__ (self ):
280+ def __bool__ (self ) -> bool :
281281 return self ._value
282282
283283
284284class BoolOperation (Bool , Operation , abstract = True ):
285285 """ It's an operation that results in a Bool """
286+ pass
287+ #def __init__(self, *args, **kwargs):
288+ # super().__init__(*args, **kwargs)
286289
287- def __init__ (self , * args , ** kwargs ):
288- super ().__init__ (* args , ** kwargs )
289-
290- def __xbool__ (self ):
291- # FIXME: TODO: re-think is we want to be this forgiving every use of
292- # local_simplify looks hacky
293- simplified = self # local_simplify(self)
294- if isinstance (simplified , Constant ):
295- return simplified .value
296- raise ExpressionError ("BoolOperation can not be reduced to a constant" )
290+ #def __xbool__(self):
291+ # # FIXME: TODO: re-think is we want to be this forgiving every use of
292+ # # local_simplify looks hacky
293+ # simplified = self # local_simplify(self)
294+ # if isinstance(simplified, Constant):
295+ # return simplified.value
296+ # raise ExpressionError("BoolOperation can not be reduced to a constant")
297297
298298
299299class BoolNot (BoolOperation ):
@@ -335,15 +335,15 @@ def __init__(self, size: int, **kwargs):
335335 self ._size = size
336336
337337 @property
338- def size (self ):
338+ def size (self ) -> int :
339339 return self ._size
340340
341341 @property
342- def mask (self ):
342+ def mask (self ) -> int :
343343 return (1 << self .size ) - 1
344344
345345 @property
346- def signmask (self ):
346+ def signmask (self ) -> int :
347347 return 1 << (self .size - 1 )
348348
349349 def cast (self , value : Union ["BitVec" , str , int , bytes ], ** kwargs ) -> "BitVec" :
@@ -545,21 +545,24 @@ def __init__(self, size: int, value: int, **kwargs):
545545 value &= (1 << size ) - 1 # Can not use self.mask yet
546546 super ().__init__ (size = size , value = value , ** kwargs )
547547
548- def __bool__ (self ):
549- return self .value != 0
548+ def __bool__ (self ) -> bool :
549+ return bool ( self .value )
550550
551- def __int__ (self ):
551+ def __int__ (self ) -> int :
552552 return self .value
553553
554554 @property
555- def signed_value (self ):
555+ def signed_value (self ) -> int :
556556 """ Gives signed python int representation """
557557 if self ._value & self .signmask :
558558 return self ._value - (1 << self .size )
559559 else :
560560 return self ._value
561561
562562 def __eq__ (self , other ):
563+ """ If tainted keep a tainted symbolic value"""
564+ if self .taint :
565+ return BoolEqual (self , self .cast (other ))
563566 # Ignore the taint for eq comparison
564567 return self ._value == other
565568
@@ -569,12 +572,13 @@ def __repr__(self):
569572
570573class BitVecOperation (BitVec , Operation , abstract = True ):
571574 """ Operations that result in a BitVec """
572-
573- pass
575+ def __init__ ( self , * , operands : Tuple [ BitVec , ...], ** kwargs ):
576+ super (). __init__ ( operands = operands , ** kwargs )
574577
575578
576579class BitVecAdd (BitVecOperation ):
577- def __init__ (self , operanda , operandb , ** kwargs ):
580+ def __init__ (self , operanda :BitVec , operandb :BitVec , ** kwargs ):
581+ assert operanda .size == operandb .size
578582 super ().__init__ (size = operanda .size , operands = (operanda , operandb ), ** kwargs )
579583
580584
@@ -670,9 +674,17 @@ def __init__(self, operanda: BitVec, operandb: BitVec, **kwargs):
670674
671675
672676class BoolEqual (BoolOperation ):
677+ @overload
673678 def __init__ (self , operanda : BitVec , operandb : BitVec , ** kwargs ):
674- assert isinstance (operanda , Expression )
675- assert isinstance (operandb , Expression )
679+ ...
680+ @overload
681+ def __init__ (self , operanda : Bool , operandb : Bool , ** kwargs ):
682+ ...
683+ @overload
684+ def __init__ (self , operanda : "Array" , operandb : "Array" , ** kwargs ):
685+ ...
686+
687+ def __init__ (self , operanda , operandb , ** kwargs ):
676688 super ().__init__ (operands = (operanda , operandb ), ** kwargs )
677689
678690
@@ -707,7 +719,7 @@ def __init__(self, operanda, operandb, **kwargs):
707719 operands = (operanda , operandb ), ** kwargs
708720 )
709721
710-
722+ Array = "Array"
711723class Array (Expression , abstract = True ):
712724 """An Array expression is an unmutable mapping from bitvector to bitvector
713725
@@ -719,30 +731,30 @@ class Array(Expression, abstract=True):
719731
720732 @property
721733 @abstractmethod
722- def index_size (self ):
734+ def index_size (self ) -> int :
723735 """ The bit size of the index part. Must be overloaded by a more specific class"""
724736 ...
725737
726738 @property
727- def value_size (self ):
739+ def value_size (self ) -> int :
728740 """ The bit size of the value part. Must be overloaded by a more specific class"""
729741 raise NotImplementedError
730742
731743 @property
732- def length (self ):
744+ def length (self ) -> int :
733745 """ Number of defined items. Must be overloaded by a more specific class"""
734746 raise NotImplementedError
735747
736- def select (self , index ):
748+ def select (self , index ) -> Union [ BitVec , int ] :
737749 """ Gets a bitvector element from the Array que la"""
738750 raise NotImplementedError
739751
740- def store (self , index , value ):
752+ def store (self , index , value ) -> "Array" :
741753 """ Create a new array that contains the updated value"""
742754 raise NotImplementedError
743755
744756 @property
745- def default (self ):
757+ def default (self ) -> Optional [ Union [ BitVec , int ]] :
746758 """If defined, reading from an uninitialized index return the default value.
747759 Otherwise, reading from an uninitialized index gives a symbol (normal Array behavior)
748760 """
@@ -817,18 +829,24 @@ def cast_value(self, value: Union[BitVec, bytes, int]) -> BitVec:
817829 value = int (value )
818830 return BitVecConstant (self .value_size , value )
819831
820- def write (self , offset , buf ) :
832+ def write (self , offset : Union [ BitVec , int ], buf : Union [ "Array" , bytes ]) -> "Array" :
821833 """Builds a new unmutable Array instance on top of current array by
822834 writing buf at offset"""
823835 array = self
824836 for i , value in enumerate (buf ):
825837 array = array .store (offset + i , value )
826838 return array
827839
828- def read (self , offset , size ) :
840+ def read (self , offset : Union [ BitVec , int ], size : int ) -> "Array" :
829841 """ A projection of the current array. """
830842 return ArraySlice (self , offset = offset , size = size )
831843
844+ @overload
845+ def __getitem__ (self , index : Union [BitVec , int ]) -> Union [BitVec , int ]:
846+ ...
847+ @overload
848+ def __getitem__ (self , index : slice ) -> "Array" :
849+ ...
832850 def __getitem__ (self , index ):
833851 """__getitem__ allows for pythonic access
834852 A = ArrayVariable(index_size=32, value_size=8)
@@ -846,15 +864,15 @@ def __iter__(self):
846864 yield self [i ]
847865
848866 @staticmethod
849- def _compare_buffers (a , b ) :
867+ def _compare_buffers (a : "Array" , b : "Array" ) -> Union [ Bool , bool ] :
850868 """ Builds an expression that represents equality between the two arrays."""
851869 if a .length != b .length :
852- return BoolConstant (False )
853- cond = BoolConstant (True )
870+ return BoolConstant (value = False )
871+ cond = BoolConstant (value = True )
854872 for i in range (a .length ):
855873 cond = BoolAnd (cond .cast (a [i ] == b [i ]), cond )
856- if cond is BoolConstant (False ):
857- return BoolConstant (False )
874+ if cond is BoolConstant (value = False ):
875+ return BoolConstant (value = False )
858876 return cond
859877
860878 def __eq__ (self , other ):
@@ -881,14 +899,14 @@ def _fix_slice(self, index: slice):
881899 raise ExpressionError ("Size could not be simplified to a constant in a slice operation" )
882900 return start , stop , size .value
883901
884- def _concatenate (self , array_a , array_b ):
902+ def _concatenate (self , array_a :"Array" , array_b :"Array" ) -> "Array" :
903+ """Build a new array from the concatenation of the operands"""
885904 new_arr = ArrayVariable (
886905 index_size = self .index_size ,
887906 length = len (array_a ) + len (array_b ),
888907 value_size = self .value_size ,
889908 name = "concatenation" ,
890909 )
891-
892910 for index in range (len (array_a )):
893911 new_arr = new_arr .store (index , local_simplify (array_a [index ]))
894912 for index in range (len (array_b )):
@@ -902,22 +920,22 @@ def __radd__(self, other):
902920 return self ._concatenate (other , self )
903921
904922 @lru_cache (maxsize = 128 , typed = True )
905- def read_BE (self , address , size ) :
923+ def read_BE (self , address : Union [ int , BitVec ], size : int ) -> Union [ BitVec , int ] :
906924 address = self .cast_index (address )
907925 bytes = []
908926 for offset in range (size ):
909927 bytes .append (self .cast_value (self [address + offset ]))
910928 return BitVecConcat (operands = tuple (bytes ))
911929
912930 @lru_cache (maxsize = 128 , typed = True )
913- def read_LE (self , address , size ) :
931+ def read_LE (self , address : Union [ int , BitVec ], size : int ) -> Union [ BitVec , int ] :
914932 address = self .cast_index (address )
915933 bytes = []
916934 for offset in range (size ):
917935 bytes .append (self .get (address + offset , self ._default ))
918936 return BitVecConcat (operands = reversed (bytes ))
919937
920- def write_BE (self , address , value , size ) :
938+ def write_BE (self , address : Union [ int , BitVec ], value : Union [ int , BitVec ], size : int ) -> Array :
921939 address = self .cast_index (address )
922940 value = BitVecConstant (size = size * self .value_size , value = 0 ).cast (value )
923941 array = self
@@ -928,7 +946,7 @@ def write_BE(self, address, value, size):
928946 )
929947 return array
930948
931- def write_LE (self , address , value , size ) :
949+ def write_LE (self , address : Union [ int , BitVec ], value : Union [ int , BitVec ], size : int ) -> Array :
932950 address = self .cast_index (address )
933951 value = BitVec (size * self .value_size ).cast (value )
934952 array = self
@@ -951,15 +969,15 @@ def __init__(
951969 super ().__init__ (** kwargs )
952970
953971 @property
954- def index_size (self ):
972+ def index_size (self ) -> int :
955973 return self ._index_size
956974
957975 @property
958- def value_size (self ):
976+ def value_size (self ) -> int :
959977 return self ._value_size
960978
961979 @property
962- def length (self ):
980+ def length (self ) -> int :
963981 return len (self .value )
964982
965983 def select (self , index ):
@@ -1018,7 +1036,7 @@ def length(self):
10181036 return self ._length
10191037
10201038 def __init__ (
1021- self ,
1039+ self , * ,
10221040 index_size : int ,
10231041 value_size : int ,
10241042 length : Optional [int ] = None ,
@@ -1157,14 +1175,14 @@ def written(self):
11571175
11581176 def is_known (self , index ):
11591177 if isinstance (index , Constant ) and index .value in self .concrete_cache :
1160- return BoolConstant (True )
1178+ return BoolConstant (value = True )
11611179
1162- is_known_index = BoolConstant (False )
1180+ is_known_index = BoolConstant (value = False )
11631181 written = self .written
11641182 for known_index in written :
11651183 if isinstance (index , Constant ) and isinstance (known_index , Constant ):
11661184 if known_index .value == index .value :
1167- return BoolConstant (True )
1185+ return BoolConstant (value = True )
11681186 is_known_index = BoolOr (is_known_index .cast (index == known_index ), is_known_index )
11691187 return is_known_index
11701188
0 commit comments