1
1
{-# LANGUAGE BlockArguments #-}
2
+ {-# LANGUAGE DeriveAnyClass #-}
2
3
{-# LANGUAGE DerivingVia #-}
3
4
{-# LANGUAGE TypeOperators #-}
4
5
{-# LANGUAGE NoStarIsType #-}
5
6
6
7
module ZkFold.ArithmeticCircuit.Context where
7
8
8
- import Control.Applicative (liftA2 , pure )
9
+ import Control.Applicative (liftA2 , pure , (<*>) )
9
10
import Control.DeepSeq (NFData , NFData1 , liftRnf , rnf , rwhnf )
10
11
import Control.Monad.State (State , modify , runState , state )
11
12
import Data.Aeson ((.:) , (.=) )
12
13
import qualified Data.Aeson.Types as Aeson
13
14
import Data.Binary (Binary )
14
- import Data.Bool (Bool (.. ), otherwise , (&&) )
15
+ import Data.Bool (Bool (.. ), (&&) )
15
16
import Data.ByteString (ByteString )
16
17
import Data.Either (Either (.. ))
17
- import Data.Eq ((==) )
18
+ import Data.Eq (Eq , (==) )
18
19
import Data.Foldable (Foldable , fold , foldl' , for_ , toList )
19
20
import Data.Function (flip , ($) , (.) )
20
21
import Data.Functor (Functor , fmap , (<$>) , (<&>) )
21
22
import Data.Functor.Classes (Show1 , liftShowList , liftShowsPrec )
22
23
import Data.Functor.Rep
23
- import Data.Kind (Type )
24
24
import Data.List.Infinite (Infinite )
25
25
import qualified Data.List.Infinite as I
26
26
import Data.Map (Map )
@@ -38,23 +38,21 @@ import qualified Data.Set as S
38
38
import Data.Traversable (Traversable , traverse )
39
39
import Data.Tuple (fst , snd , uncurry )
40
40
import Data.Type.Equality (type (~ ))
41
- import Data.Typeable (Typeable )
42
41
import GHC.Generics (Generic , Par1 (.. ), U1 (.. ), (:*:) (.. ))
43
42
import Optics (over , set , zoom )
44
43
import Text.Show
45
- import qualified Type.Reflection as R
46
44
import Prelude (error , seq )
47
45
48
46
import ZkFold.Algebra.Class
49
47
import ZkFold.Algebra.Number
50
48
import ZkFold.Algebra.Polynomial.Multivariate (Poly , var )
51
- import ZkFold.ArithmeticCircuit.Lookup (FunctionId (.. ), LookupType (.. ))
52
49
import ZkFold.ArithmeticCircuit.MerkleHash (MerkleHash (.. ), merkleHash , runHash )
53
50
import ZkFold.ArithmeticCircuit.Var
54
51
import ZkFold.ArithmeticCircuit.Witness (WitnessF (.. ))
55
52
import ZkFold.ArithmeticCircuit.WitnessEstimation (Partial (.. ), UVar (.. ))
56
53
import ZkFold.Control.HApplicative (HApplicative , hliftA2 , hpure )
57
54
import ZkFold.Data.Binary (fromByteString , toByteString )
55
+ import ZkFold.Data.FromList (FromList , fromList )
58
56
import ZkFold.Data.HFunctor (HFunctor , hmap )
59
57
import ZkFold.Data.HFunctor.Classes
60
58
import ZkFold.Data.Package (Package , packWith , unpackWith )
@@ -90,10 +88,8 @@ data CircuitFold a
90
88
instance NFData a => NFData (CircuitFold a ) where
91
89
rnf CircuitFold {.. } = rnf foldCount `seq` liftRnf rnf foldSeed
92
90
93
- data LookupFunction a
94
- = forall f g .
95
- (Representable f , Traversable g , Typeable f , Typeable g , Binary (Rep f )) =>
96
- LookupFunction (forall x . PrimeField x => f x -> g x )
91
+ newtype LookupFunction a = LookupFunction
92
+ { runLookupFunction :: forall x . (PrimeField x , Algebra a x ) => [x ] -> [x ]}
97
93
98
94
instance NFData (LookupFunction a ) where
99
95
rnf = rwhnf
@@ -102,31 +98,34 @@ type FunctionRegistry a = Map ByteString (LookupFunction a)
102
98
103
99
appendFunction
104
100
:: forall f g a
105
- . (Representable f , Typeable f , Binary (Rep f ))
106
- => (Traversable g , Typeable g , Arithmetic a )
107
- => (forall x . PrimeField x => f x -> g x )
101
+ . (Representable f , FromList f , Binary (Rep f ))
102
+ => (Foldable g , Arithmetic a , Binary a )
103
+ => (forall x . ( PrimeField x , Algebra a x ) => f x -> g x )
108
104
-> FunctionRegistry a
109
- -> (FunctionId ( f a -> g a ) , FunctionRegistry a )
105
+ -> (ByteString , FunctionRegistry a )
110
106
appendFunction f r =
111
107
let functionId = runHash @ (Just (Order a )) $ sum (f $ tabulate merkleHash)
112
- in (FunctionId functionId, M. insert functionId (LookupFunction f) r)
113
-
114
- lookupFunction
115
- :: forall f g (a :: Type )
116
- . (Typeable f , Typeable g )
117
- => FunctionRegistry a
118
- -> FunctionId (f a -> g a )
119
- -> (forall x . PrimeField x => f x -> g x )
120
- lookupFunction m (FunctionId i) = case m M. ! i of
121
- LookupFunction f -> cast1 . f . cast1
122
- where
123
- cast1 :: forall h k b . (Typeable h , Typeable k ) => h b -> k b
124
- cast1 x
125
- | Just R. HRefl <- th `R.eqTypeRep` tk = x
126
- | otherwise = error " types are not equal"
127
- where
128
- th = R. typeRep :: R. TypeRep h
129
- tk = R. typeRep :: R. TypeRep k
108
+ in (functionId, M. insert functionId (LookupFunction (toList . \ x -> f $ fromList x)) r)
109
+
110
+ data LookupType a
111
+ = LTRanges (Set (a , a ))
112
+ | LTProduct (LookupType a ) (LookupType a )
113
+ | LTPlot ByteString (LookupType a )
114
+ deriving
115
+ ( Aeson. FromJSON
116
+ , Aeson. FromJSONKey
117
+ , Aeson. ToJSON
118
+ , Aeson. ToJSONKey
119
+ , Eq
120
+ , Generic
121
+ , NFData
122
+ , Ord
123
+ , Show
124
+ )
125
+
126
+ asRange :: LookupType a -> Maybe (Set (a , a ))
127
+ asRange (LTRanges rs) = Just rs
128
+ asRange _ = Nothing
130
129
131
130
-- | Circuit context in the form of a system of polynomial constraints.
132
131
data CircuitContext a o = CircuitContext
@@ -352,10 +351,10 @@ instance
352
351
zoom # acSystem . modify $
353
352
M. insert (witToVar (p at)) (p $ evalVar var)
354
353
355
- lookupConstraint vars lt = do
354
+ lookupConstraint vars ltable = do
356
355
vs <- traverse prepare (toList vars)
357
- zoom # acLookup . modify $
358
- MM. insertWith S. union ( LookupType lt) (S. singleton vs)
356
+ lt <- lookupType ltable
357
+ zoom # acLookup . modify $ MM. insertWith S. union lt (S. singleton vs)
359
358
pure ()
360
359
where
361
360
prepare (LinVar k x b) | k == one && b == zero = pure x
@@ -367,7 +366,16 @@ instance
367
366
constraint (($ toVar v) - ($ src))
368
367
pure v
369
368
370
- registerFunction f = zoom # acLookupFunction $ state (appendFunction f)
369
+ -- | Translates a lookup table into a lookup type,
370
+ -- storing all lookup functions in a circuit.
371
+ lookupType
372
+ :: (Arithmetic a , Binary a )
373
+ => LookupTable a f -> State (CircuitContext a o ) (LookupType a )
374
+ lookupType (Ranges rs) = pure (LTRanges rs)
375
+ lookupType (Product t u) = LTProduct <$> lookupType t <*> lookupType u
376
+ lookupType (Plot f t) = do
377
+ funcId <- zoom # acLookupFunction $ state (appendFunction f)
378
+ LTPlot funcId <$> lookupType t
371
379
372
380
-- | Generates new variable index given a witness for it.
373
381
--
0 commit comments