Skip to content

Commit f61b46d

Browse files
authored
Merge pull request #705 from zkFold/TurtlePU/lookup-tables
Split internal and external representation of `LookupTable`s
2 parents 35c75bd + 92063b8 commit f61b46d

File tree

12 files changed

+132
-198
lines changed

12 files changed

+132
-198
lines changed

symbolic-base/src/ZkFold/ArithmeticCircuit/Context.hs

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
{-# LANGUAGE BlockArguments #-}
2+
{-# LANGUAGE DeriveAnyClass #-}
23
{-# LANGUAGE DerivingVia #-}
34
{-# LANGUAGE TypeOperators #-}
45
{-# LANGUAGE NoStarIsType #-}
56

67
module ZkFold.ArithmeticCircuit.Context where
78

8-
import Control.Applicative (liftA2, pure)
9+
import Control.Applicative (liftA2, pure, (<*>))
910
import Control.DeepSeq (NFData, NFData1, liftRnf, rnf, rwhnf)
1011
import Control.Monad.State (State, modify, runState, state)
1112
import Data.Aeson ((.:), (.=))
1213
import qualified Data.Aeson.Types as Aeson
1314
import Data.Binary (Binary)
14-
import Data.Bool (Bool (..), otherwise, (&&))
15+
import Data.Bool (Bool (..), (&&))
1516
import Data.ByteString (ByteString)
1617
import Data.Either (Either (..))
17-
import Data.Eq ((==))
18+
import Data.Eq (Eq, (==))
1819
import Data.Foldable (Foldable, fold, foldl', for_, toList)
1920
import Data.Function (flip, ($), (.))
2021
import Data.Functor (Functor, fmap, (<$>), (<&>))
2122
import Data.Functor.Classes (Show1, liftShowList, liftShowsPrec)
2223
import Data.Functor.Rep
23-
import Data.Kind (Type)
2424
import Data.List.Infinite (Infinite)
2525
import qualified Data.List.Infinite as I
2626
import Data.Map (Map)
@@ -38,23 +38,21 @@ import qualified Data.Set as S
3838
import Data.Traversable (Traversable, traverse)
3939
import Data.Tuple (fst, snd, uncurry)
4040
import Data.Type.Equality (type (~))
41-
import Data.Typeable (Typeable)
4241
import GHC.Generics (Generic, Par1 (..), U1 (..), (:*:) (..))
4342
import Optics (over, set, zoom)
4443
import Text.Show
45-
import qualified Type.Reflection as R
4644
import Prelude (error, seq)
4745

4846
import ZkFold.Algebra.Class
4947
import ZkFold.Algebra.Number
5048
import ZkFold.Algebra.Polynomial.Multivariate (Poly, var)
51-
import ZkFold.ArithmeticCircuit.Lookup (FunctionId (..), LookupType (..))
5249
import ZkFold.ArithmeticCircuit.MerkleHash (MerkleHash (..), merkleHash, runHash)
5350
import ZkFold.ArithmeticCircuit.Var
5451
import ZkFold.ArithmeticCircuit.Witness (WitnessF (..))
5552
import ZkFold.ArithmeticCircuit.WitnessEstimation (Partial (..), UVar (..))
5653
import ZkFold.Control.HApplicative (HApplicative, hliftA2, hpure)
5754
import ZkFold.Data.Binary (fromByteString, toByteString)
55+
import ZkFold.Data.FromList (FromList, fromList)
5856
import ZkFold.Data.HFunctor (HFunctor, hmap)
5957
import ZkFold.Data.HFunctor.Classes
6058
import ZkFold.Data.Package (Package, packWith, unpackWith)
@@ -90,10 +88,8 @@ data CircuitFold a
9088
instance NFData a => NFData (CircuitFold a) where
9189
rnf CircuitFold {..} = rnf foldCount `seq` liftRnf rnf foldSeed
9290

93-
data LookupFunction a
94-
= forall f g.
95-
(Representable f, Traversable g, Typeable f, Typeable g, Binary (Rep f)) =>
96-
LookupFunction (forall x. PrimeField x => f x -> g x)
91+
newtype LookupFunction a = LookupFunction
92+
{runLookupFunction :: forall x. (PrimeField x, Algebra a x) => [x] -> [x]}
9793

9894
instance NFData (LookupFunction a) where
9995
rnf = rwhnf
@@ -102,31 +98,34 @@ type FunctionRegistry a = Map ByteString (LookupFunction a)
10298

10399
appendFunction
104100
:: forall f g a
105-
. (Representable f, Typeable f, Binary (Rep f))
106-
=> (Traversable g, Typeable g, Arithmetic a)
107-
=> (forall x. PrimeField x => f x -> g x)
101+
. (Representable f, FromList f, Binary (Rep f))
102+
=> (Foldable g, Arithmetic a, Binary a)
103+
=> (forall x. (PrimeField x, Algebra a x) => f x -> g x)
108104
-> FunctionRegistry a
109-
-> (FunctionId (f a -> g a), FunctionRegistry a)
105+
-> (ByteString, FunctionRegistry a)
110106
appendFunction f r =
111107
let functionId = runHash @(Just (Order a)) $ sum (f $ tabulate merkleHash)
112-
in (FunctionId functionId, M.insert functionId (LookupFunction f) r)
113-
114-
lookupFunction
115-
:: forall f g (a :: Type)
116-
. (Typeable f, Typeable g)
117-
=> FunctionRegistry a
118-
-> FunctionId (f a -> g a)
119-
-> (forall x. PrimeField x => f x -> g x)
120-
lookupFunction m (FunctionId i) = case m M.! i of
121-
LookupFunction f -> cast1 . f . cast1
122-
where
123-
cast1 :: forall h k b. (Typeable h, Typeable k) => h b -> k b
124-
cast1 x
125-
| Just R.HRefl <- th `R.eqTypeRep` tk = x
126-
| otherwise = error "types are not equal"
127-
where
128-
th = R.typeRep :: R.TypeRep h
129-
tk = R.typeRep :: R.TypeRep k
108+
in (functionId, M.insert functionId (LookupFunction (toList . \x -> f $ fromList x)) r)
109+
110+
data LookupType a
111+
= LTRanges (Set (a, a))
112+
| LTProduct (LookupType a) (LookupType a)
113+
| LTPlot ByteString (LookupType a)
114+
deriving
115+
( Aeson.FromJSON
116+
, Aeson.FromJSONKey
117+
, Aeson.ToJSON
118+
, Aeson.ToJSONKey
119+
, Eq
120+
, Generic
121+
, NFData
122+
, Ord
123+
, Show
124+
)
125+
126+
asRange :: LookupType a -> Maybe (Set (a, a))
127+
asRange (LTRanges rs) = Just rs
128+
asRange _ = Nothing
130129

131130
-- | Circuit context in the form of a system of polynomial constraints.
132131
data CircuitContext a o = CircuitContext
@@ -352,10 +351,10 @@ instance
352351
zoom #acSystem . modify $
353352
M.insert (witToVar (p at)) (p $ evalVar var)
354353

355-
lookupConstraint vars lt = do
354+
lookupConstraint vars ltable = do
356355
vs <- traverse prepare (toList vars)
357-
zoom #acLookup . modify $
358-
MM.insertWith S.union (LookupType lt) (S.singleton vs)
356+
lt <- lookupType ltable
357+
zoom #acLookup . modify $ MM.insertWith S.union lt (S.singleton vs)
359358
pure ()
360359
where
361360
prepare (LinVar k x b) | k == one && b == zero = pure x
@@ -367,7 +366,16 @@ instance
367366
constraint (($ toVar v) - ($ src))
368367
pure v
369368

370-
registerFunction f = zoom #acLookupFunction $ state (appendFunction f)
369+
-- | Translates a lookup table into a lookup type,
370+
-- storing all lookup functions in a circuit.
371+
lookupType
372+
:: (Arithmetic a, Binary a)
373+
=> LookupTable a f -> State (CircuitContext a o) (LookupType a)
374+
lookupType (Ranges rs) = pure (LTRanges rs)
375+
lookupType (Product t u) = LTProduct <$> lookupType t <*> lookupType u
376+
lookupType (Plot f t) = do
377+
funcId <- zoom #acLookupFunction $ state (appendFunction f)
378+
LTPlot funcId <$> lookupType t
371379

372380
-- | Generates new variable index given a witness for it.
373381
--

symbolic-base/src/ZkFold/ArithmeticCircuit/Desugaring.hs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ import Data.Tuple (fst, uncurry)
1616
import Prelude (error)
1717

1818
import ZkFold.Algebra.Class
19-
import ZkFold.ArithmeticCircuit.Context (CircuitContext, acLookup)
20-
import ZkFold.ArithmeticCircuit.Lookup (asRange)
19+
import ZkFold.ArithmeticCircuit.Context (CircuitContext, acLookup, asRange)
2120
import ZkFold.ArithmeticCircuit.Var (toVar)
2221
import ZkFold.Prelude (assert, length)
2322
import ZkFold.Symbolic.Class (Arithmetic)

symbolic-base/src/ZkFold/ArithmeticCircuit/Experimental.hs

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,13 @@ module ZkFold.ArithmeticCircuit.Experimental where
1010
import Control.Applicative (pure)
1111
import Control.DeepSeq (NFData (..), NFData1, liftRnf, rwhnf)
1212
import Control.Monad (unless)
13-
import Control.Monad.State (State, gets, modify', runState, state)
13+
import Control.Monad.State (State, gets, modify', runState)
1414
import Data.Binary (Binary)
15-
import Data.ByteString (ByteString)
1615
import Data.Eq (Eq (..))
1716
import Data.Foldable (Foldable (..), any, for_)
1817
import Data.Function (flip, on, ($), (.))
1918
import Data.Functor (Functor, fmap)
2019
import Data.Functor.Rep (Rep, Representable)
21-
import Data.Map (Map)
2220
import qualified Data.Map as M
2321
import qualified Data.Map.Monoidal as MM
2422
import Data.Maybe (Maybe (..))
@@ -40,16 +38,14 @@ import ZkFold.ArithmeticCircuit (ArithmeticCircuit, optimize, solder)
4038
import ZkFold.ArithmeticCircuit.Children (children)
4139
import ZkFold.ArithmeticCircuit.Context (
4240
CircuitContext,
43-
LookupFunction (LookupFunction),
4441
acLookup,
4542
acSystem,
4643
acWitness,
47-
appendFunction,
4844
crown,
4945
emptyContext,
46+
lookupType,
5047
witToVar,
5148
)
52-
import ZkFold.ArithmeticCircuit.Lookup (LookupTable, LookupType (..))
5349
import ZkFold.ArithmeticCircuit.Var (NewVar (..), Var)
5450
import ZkFold.ArithmeticCircuit.Witness (WitnessF (..))
5551
import ZkFold.Control.HApplicative (HApplicative (..))
@@ -70,7 +66,7 @@ import ZkFold.Symbolic.Data.Class (
7066
restore,
7167
)
7268
import ZkFold.Symbolic.Data.Input (isValid)
73-
import ZkFold.Symbolic.MonadCircuit (MonadCircuit (..), Witness (..))
69+
import ZkFold.Symbolic.MonadCircuit (LookupTable, MonadCircuit (..), Witness (..))
7470

7571
---------------------- Efficient "list" concatenation --------------------------
7672

@@ -105,12 +101,8 @@ data LookupEntry a v
105101

106102
------------- Box of constraints supporting efficient concatenation ------------
107103

108-
-- After #573, can be made even more declarative
109-
-- by getting rid of 'cbLkpFuns' field.
110-
-- Can then be used for new public Symbolic API (see 'constrain' below)!
111104
data ConstraintBox a v = MkCBox
112105
{ cbPolyCon :: AppList (Polynomial a v)
113-
, cbLkpFuns :: Map ByteString (LookupFunction a)
114106
, cbLookups :: AppList (LookupEntry a v)
115107
}
116108
deriving (Generic, NFData)
@@ -192,9 +184,6 @@ instance
192184
where
193185
unconstrained = pure . fromConstant
194186
constraint c = modify' \cb -> cb {cbPolyCon = MkPolynomial c `app` cbPolyCon cb}
195-
registerFunction f = state \(!cb) ->
196-
let (i, r') = appendFunction f (cbLkpFuns cb)
197-
in (i, cb {cbLkpFuns = r'})
198187
lookupConstraint c t = modify' \cb -> cb {cbLookups = LEntry c t `app` cbLookups cb}
199188

200189
------------------------- Optimized compilation function -----------------------
@@ -240,14 +229,12 @@ compile =
240229
unless isDone' do
241230
constraint (\x -> runPolynomial c (x . pure . elHash))
242231
for_ (children asWitness) work
243-
for_ cbLkpFuns \(LookupFunction f) -> do
244-
_ <- registerFunction f
245-
pure ()
246232
for_ cbLookups \(LEntry l t) -> do
233+
lt <- lookupType t
247234
isDone' <-
248235
gets
249236
( any (S.member $ toList $ fmap elHash l)
250-
. (MM.!? LookupType t)
237+
. (MM.!? lt)
251238
. acLookup
252239
)
253240
unless isDone' do

symbolic-base/src/ZkFold/ArithmeticCircuit/Lookup.hs

Lines changed: 0 additions & 97 deletions
This file was deleted.

symbolic-base/src/ZkFold/ArithmeticCircuit/Optimization.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ import ZkFold.ArithmeticCircuit.Context (
3232
CircuitContext (..),
3333
CircuitFold (..),
3434
Constraint,
35+
LookupType,
36+
asRange,
3537
witToVar,
3638
)
37-
import ZkFold.ArithmeticCircuit.Lookup (LookupType, asRange)
3839
import ZkFold.ArithmeticCircuit.Var (NewVar (..))
3940
import ZkFold.Data.Binary (fromByteString)
4041
import ZkFold.Symbolic.Class (Arithmetic)

symbolic-base/src/ZkFold/ArithmeticCircuit/Var.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
module ZkFold.ArithmeticCircuit.Var where
66

7+
import ByteString.Aeson.Orphans ()
78
import Control.Applicative (Applicative, pure, (<*>))
89
import Control.DeepSeq (NFData)
910
import Control.Monad (Monad, ap, (>>=))
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{-# LANGUAGE BlockArguments #-}
2+
{-# LANGUAGE TypeOperators #-}
3+
4+
module ZkFold.Data.FromList where
5+
6+
import Control.Applicative (liftA2)
7+
import Control.Monad.State (State, runState, state)
8+
import GHC.Err (error)
9+
import GHC.Generics (Par1 (..), (:*:) (..))
10+
11+
class FromList f where
12+
parseList :: State [a] (f a)
13+
14+
instance FromList Par1 where
15+
parseList = state \case
16+
[] -> error "parseList @Par1: empty list"
17+
(x : xs) -> (Par1 x, xs)
18+
19+
instance (FromList f, FromList g) => FromList (f :*: g) where
20+
parseList = liftA2 (:*:) parseList parseList
21+
22+
fromList :: FromList f => [a] -> f a
23+
fromList input = case runState parseList input of
24+
(result, []) -> result
25+
_ -> error "fromList: unconsumed elements"

0 commit comments

Comments
 (0)