Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 45 additions & 37 deletions symbolic-base/src/ZkFold/ArithmeticCircuit/Context.hs
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE NoStarIsType #-}

module ZkFold.ArithmeticCircuit.Context where

import Control.Applicative (liftA2, pure)
import Control.Applicative (liftA2, pure, (<*>))
import Control.DeepSeq (NFData, NFData1, liftRnf, rnf, rwhnf)
import Control.Monad.State (State, modify, runState, state)
import Data.Aeson ((.:), (.=))
import qualified Data.Aeson.Types as Aeson
import Data.Binary (Binary)
import Data.Bool (Bool (..), otherwise, (&&))
import Data.Bool (Bool (..), (&&))
import Data.ByteString (ByteString)
import Data.Either (Either (..))
import Data.Eq ((==))
import Data.Eq (Eq, (==))
import Data.Foldable (Foldable, fold, foldl', for_, toList)
import Data.Function (flip, ($), (.))
import Data.Functor (Functor, fmap, (<$>), (<&>))
import Data.Functor.Classes (Show1, liftShowList, liftShowsPrec)
import Data.Functor.Rep
import Data.Kind (Type)
import Data.List.Infinite (Infinite)
import qualified Data.List.Infinite as I
import Data.Map (Map)
Expand All @@ -38,23 +38,21 @@ import qualified Data.Set as S
import Data.Traversable (Traversable, traverse)
import Data.Tuple (fst, snd, uncurry)
import Data.Type.Equality (type (~))
import Data.Typeable (Typeable)
import GHC.Generics (Generic, Par1 (..), U1 (..), (:*:) (..))
import Optics (over, set, zoom)
import Text.Show
import qualified Type.Reflection as R
import Prelude (error, seq)

import ZkFold.Algebra.Class
import ZkFold.Algebra.Number
import ZkFold.Algebra.Polynomial.Multivariate (Poly, var)
import ZkFold.ArithmeticCircuit.Lookup (FunctionId (..), LookupType (..))
import ZkFold.ArithmeticCircuit.MerkleHash (MerkleHash (..), merkleHash, runHash)
import ZkFold.ArithmeticCircuit.Var
import ZkFold.ArithmeticCircuit.Witness (WitnessF (..))
import ZkFold.ArithmeticCircuit.WitnessEstimation (Partial (..), UVar (..))
import ZkFold.Control.HApplicative (HApplicative, hliftA2, hpure)
import ZkFold.Data.Binary (fromByteString, toByteString)
import ZkFold.Data.FromList (FromList, fromList)
import ZkFold.Data.HFunctor (HFunctor, hmap)
import ZkFold.Data.HFunctor.Classes
import ZkFold.Data.Package (Package, packWith, unpackWith)
Expand Down Expand Up @@ -90,10 +88,8 @@ data CircuitFold a
instance NFData a => NFData (CircuitFold a) where
rnf CircuitFold {..} = rnf foldCount `seq` liftRnf rnf foldSeed

data LookupFunction a
= forall f g.
(Representable f, Traversable g, Typeable f, Typeable g, Binary (Rep f)) =>
LookupFunction (forall x. PrimeField x => f x -> g x)
newtype LookupFunction a = LookupFunction
{runLookupFunction :: forall x. (PrimeField x, Algebra a x) => [x] -> [x]}

instance NFData (LookupFunction a) where
rnf = rwhnf
Expand All @@ -102,31 +98,34 @@ type FunctionRegistry a = Map ByteString (LookupFunction a)

appendFunction
:: forall f g a
. (Representable f, Typeable f, Binary (Rep f))
=> (Traversable g, Typeable g, Arithmetic a)
=> (forall x. PrimeField x => f x -> g x)
. (Representable f, FromList f, Binary (Rep f))
=> (Foldable g, Arithmetic a, Binary a)
=> (forall x. (PrimeField x, Algebra a x) => f x -> g x)
-> FunctionRegistry a
-> (FunctionId (f a -> g a), FunctionRegistry a)
-> (ByteString, FunctionRegistry a)
appendFunction f r =
let functionId = runHash @(Just (Order a)) $ sum (f $ tabulate merkleHash)
in (FunctionId functionId, M.insert functionId (LookupFunction f) r)

lookupFunction
:: forall f g (a :: Type)
. (Typeable f, Typeable g)
=> FunctionRegistry a
-> FunctionId (f a -> g a)
-> (forall x. PrimeField x => f x -> g x)
lookupFunction m (FunctionId i) = case m M.! i of
LookupFunction f -> cast1 . f . cast1
where
cast1 :: forall h k b. (Typeable h, Typeable k) => h b -> k b
cast1 x
| Just R.HRefl <- th `R.eqTypeRep` tk = x
| otherwise = error "types are not equal"
where
th = R.typeRep :: R.TypeRep h
tk = R.typeRep :: R.TypeRep k
in (functionId, M.insert functionId (LookupFunction (toList . \x -> f $ fromList x)) r)

data LookupType a
= LTRanges (Set (a, a))
| LTProduct (LookupType a) (LookupType a)
| LTPlot ByteString (LookupType a)
deriving
( Aeson.FromJSON
, Aeson.FromJSONKey
, Aeson.ToJSON
, Aeson.ToJSONKey
, Eq
, Generic
, NFData
, Ord
, Show
)

asRange :: LookupType a -> Maybe (Set (a, a))
asRange (LTRanges rs) = Just rs
asRange _ = Nothing

-- | Circuit context in the form of a system of polynomial constraints.
data CircuitContext a o = CircuitContext
Expand Down Expand Up @@ -352,10 +351,10 @@ instance
zoom #acSystem . modify $
M.insert (witToVar (p at)) (p $ evalVar var)

lookupConstraint vars lt = do
lookupConstraint vars ltable = do
vs <- traverse prepare (toList vars)
zoom #acLookup . modify $
MM.insertWith S.union (LookupType lt) (S.singleton vs)
lt <- lookupType ltable
zoom #acLookup . modify $ MM.insertWith S.union lt (S.singleton vs)
pure ()
where
prepare (LinVar k x b) | k == one && b == zero = pure x
Expand All @@ -367,7 +366,16 @@ instance
constraint (($ toVar v) - ($ src))
pure v

registerFunction f = zoom #acLookupFunction $ state (appendFunction f)
-- | Translates a lookup table into a lookup type,
-- storing all lookup functions in a circuit.
lookupType
:: (Arithmetic a, Binary a)
=> LookupTable a f -> State (CircuitContext a o) (LookupType a)
lookupType (Ranges rs) = pure (LTRanges rs)
lookupType (Product t u) = LTProduct <$> lookupType t <*> lookupType u
lookupType (Plot f t) = do
funcId <- zoom #acLookupFunction $ state (appendFunction f)
LTPlot funcId <$> lookupType t

-- | Generates new variable index given a witness for it.
--
Expand Down
3 changes: 1 addition & 2 deletions symbolic-base/src/ZkFold/ArithmeticCircuit/Desugaring.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ import Data.Tuple (fst, uncurry)
import Prelude (error)

import ZkFold.Algebra.Class
import ZkFold.ArithmeticCircuit.Context (CircuitContext, acLookup)
import ZkFold.ArithmeticCircuit.Lookup (asRange)
import ZkFold.ArithmeticCircuit.Context (CircuitContext, acLookup, asRange)
import ZkFold.ArithmeticCircuit.Var (toVar)
import ZkFold.Prelude (assert, length)
import ZkFold.Symbolic.Class (Arithmetic)
Expand Down
23 changes: 5 additions & 18 deletions symbolic-base/src/ZkFold/ArithmeticCircuit/Experimental.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@ module ZkFold.ArithmeticCircuit.Experimental where
import Control.Applicative (pure)
import Control.DeepSeq (NFData (..), NFData1, liftRnf, rwhnf)
import Control.Monad (unless)
import Control.Monad.State (State, gets, modify', runState, state)
import Control.Monad.State (State, gets, modify', runState)
import Data.Binary (Binary)
import Data.ByteString (ByteString)
import Data.Eq (Eq (..))
import Data.Foldable (Foldable (..), any, for_)
import Data.Function (flip, on, ($), (.))
import Data.Functor (Functor, fmap)
import Data.Functor.Rep (Rep, Representable)
import Data.Map (Map)
import qualified Data.Map as M
import qualified Data.Map.Monoidal as MM
import Data.Maybe (Maybe (..))
Expand All @@ -40,16 +38,14 @@ import ZkFold.ArithmeticCircuit (ArithmeticCircuit, optimize, solder)
import ZkFold.ArithmeticCircuit.Children (children)
import ZkFold.ArithmeticCircuit.Context (
CircuitContext,
LookupFunction (LookupFunction),
acLookup,
acSystem,
acWitness,
appendFunction,
crown,
emptyContext,
lookupType,
witToVar,
)
import ZkFold.ArithmeticCircuit.Lookup (LookupTable, LookupType (..))
import ZkFold.ArithmeticCircuit.Var (NewVar (..), Var)
import ZkFold.ArithmeticCircuit.Witness (WitnessF (..))
import ZkFold.Control.HApplicative (HApplicative (..))
Expand All @@ -70,7 +66,7 @@ import ZkFold.Symbolic.Data.Class (
restore,
)
import ZkFold.Symbolic.Data.Input (isValid)
import ZkFold.Symbolic.MonadCircuit (MonadCircuit (..), Witness (..))
import ZkFold.Symbolic.MonadCircuit (LookupTable, MonadCircuit (..), Witness (..))

---------------------- Efficient "list" concatenation --------------------------

Expand Down Expand Up @@ -105,12 +101,8 @@ data LookupEntry a v

------------- Box of constraints supporting efficient concatenation ------------

-- After #573, can be made even more declarative
-- by getting rid of 'cbLkpFuns' field.
-- Can then be used for new public Symbolic API (see 'constrain' below)!
data ConstraintBox a v = MkCBox
{ cbPolyCon :: AppList (Polynomial a v)
, cbLkpFuns :: Map ByteString (LookupFunction a)
, cbLookups :: AppList (LookupEntry a v)
}
deriving (Generic, NFData)
Expand Down Expand Up @@ -192,9 +184,6 @@ instance
where
unconstrained = pure . fromConstant
constraint c = modify' \cb -> cb {cbPolyCon = MkPolynomial c `app` cbPolyCon cb}
registerFunction f = state \(!cb) ->
let (i, r') = appendFunction f (cbLkpFuns cb)
in (i, cb {cbLkpFuns = r'})
lookupConstraint c t = modify' \cb -> cb {cbLookups = LEntry c t `app` cbLookups cb}

------------------------- Optimized compilation function -----------------------
Expand Down Expand Up @@ -240,14 +229,12 @@ compile =
unless isDone' do
constraint (\x -> runPolynomial c (x . pure . elHash))
for_ (children asWitness) work
for_ cbLkpFuns \(LookupFunction f) -> do
_ <- registerFunction f
pure ()
for_ cbLookups \(LEntry l t) -> do
lt <- lookupType t
isDone' <-
gets
( any (S.member $ toList $ fmap elHash l)
. (MM.!? LookupType t)
. (MM.!? lt)
. acLookup
)
unless isDone' do
Expand Down
97 changes: 0 additions & 97 deletions symbolic-base/src/ZkFold/ArithmeticCircuit/Lookup.hs

This file was deleted.

3 changes: 2 additions & 1 deletion symbolic-base/src/ZkFold/ArithmeticCircuit/Optimization.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ import ZkFold.ArithmeticCircuit.Context (
CircuitContext (..),
CircuitFold (..),
Constraint,
LookupType,
asRange,
witToVar,
)
import ZkFold.ArithmeticCircuit.Lookup (LookupType, asRange)
import ZkFold.ArithmeticCircuit.Var (NewVar (..))
import ZkFold.Data.Binary (fromByteString)
import ZkFold.Symbolic.Class (Arithmetic)
Expand Down
1 change: 1 addition & 0 deletions symbolic-base/src/ZkFold/ArithmeticCircuit/Var.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

module ZkFold.ArithmeticCircuit.Var where

import ByteString.Aeson.Orphans ()
import Control.Applicative (Applicative, pure, (<*>))
import Control.DeepSeq (NFData)
import Control.Monad (Monad, ap, (>>=))
Expand Down
25 changes: 25 additions & 0 deletions symbolic-base/src/ZkFold/Data/FromList.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE TypeOperators #-}

module ZkFold.Data.FromList where

import Control.Applicative (liftA2)
import Control.Monad.State (State, runState, state)
import GHC.Err (error)
import GHC.Generics (Par1 (..), (:*:) (..))

class FromList f where
parseList :: State [a] (f a)

instance FromList Par1 where
parseList = state \case
[] -> error "parseList @Par1: empty list"
(x : xs) -> (Par1 x, xs)

instance (FromList f, FromList g) => FromList (f :*: g) where
parseList = liftA2 (:*:) parseList parseList

fromList :: FromList f => [a] -> f a
fromList input = case runState parseList input of
(result, []) -> result
_ -> error "fromList: unconsumed elements"
Loading