@@ -3,16 +3,20 @@ module Simplex.Messaging.Agent.Store.SQLite.Util where
33import Control.Exception (SomeException , catch , mask_ )
44import Data.ByteString (ByteString )
55import qualified Data.ByteString as B
6+ import Data.IORef
67import Database.SQLite3.Direct (Database (.. ), FuncArgs (.. ), FuncContext (.. ))
78import Database.SQLite3.Bindings
89import Foreign.C.String
910import Foreign.Ptr
1011import Foreign.StablePtr
12+ import Foreign.Storable
1113
1214data CFuncPtrs = CFuncPtrs (FunPtr CFunc ) (FunPtr CFunc ) (FunPtr CFuncFinal )
1315
1416type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue ) -> IO ()
1517
18+ type SQLiteFuncFinal = Ptr CContext -> IO ()
19+
1620mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO () ) -> SQLiteFunc
1721mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals)
1822{-# INLINE mkSQLiteFunc #-}
@@ -25,6 +29,50 @@ createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do
2529 B. useAsCString name $ \ namePtr ->
2630 toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr
2731
32+ mkSQLiteAggStep :: a -> (FuncContext -> FuncArgs -> a -> IO a ) -> SQLiteFunc
33+ mkSQLiteAggStep initSt xStep cxt nArgs cvals = catchAsResultError cxt $ do
34+ -- we store the aggregate state in the buffer returned by
35+ -- c_sqlite3_aggregate_context as a StablePtr pointing to an IORef that
36+ -- contains the actual aggregate state
37+ aggCtx <- getAggregateContext cxt
38+ aggStPtr <- peek aggCtx
39+ aggStRef <-
40+ if castStablePtrToPtr aggStPtr /= nullPtr
41+ then deRefStablePtr aggStPtr
42+ else do
43+ aggStRef <- newIORef initSt
44+ aggStPtr' <- newStablePtr aggStRef
45+ poke aggCtx aggStPtr'
46+ return aggStRef
47+ aggSt <- readIORef aggStRef
48+ aggSt' <- xStep (FuncContext cxt) (FuncArgs nArgs cvals) aggSt
49+ writeIORef aggStRef aggSt'
50+
51+ mkSQLiteAggFinal :: a -> (FuncContext -> a -> IO () ) -> SQLiteFuncFinal
52+ mkSQLiteAggFinal initSt xFinal cxt = do
53+ aggCtx <- getAggregateContext cxt
54+ aggStPtr <- peek aggCtx
55+ if castStablePtrToPtr aggStPtr == nullPtr
56+ then catchAsResultError cxt $ xFinal (FuncContext cxt) initSt
57+ else do
58+ catchAsResultError cxt $ do
59+ aggStRef <- deRefStablePtr aggStPtr
60+ aggSt <- readIORef aggStRef
61+ xFinal (FuncContext cxt) aggSt
62+ freeStablePtr aggStPtr
63+
64+ getAggregateContext :: Ptr CContext -> IO (Ptr a )
65+ getAggregateContext cxt = c_sqlite3_aggregate_context cxt stPtrSize
66+ where
67+ stPtrSize = fromIntegral $ sizeOf (undefined :: StablePtr () )
68+
69+ -- Based on createAggregate from Database.SQLite3.Direct, but uses static function pointers to avoid dynamic wrappers that trigger DCL.
70+ createStaticAggregate :: Database -> ByteString -> CArgCount -> FunPtr SQLiteFunc -> FunPtr SQLiteFuncFinal -> IO (Either Error () )
71+ createStaticAggregate (Database db) name nArgs stepPtr finalPtr = mask_ $ do
72+ u <- newStablePtr $ CFuncPtrs nullFunPtr stepPtr finalPtr
73+ B. useAsCString name $ \ namePtr ->
74+ toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs 0 (castStablePtrToPtr u) nullFunPtr stepPtr finalPtr nullFunPtr
75+
2876-- Convert a 'CError' to a 'Either Error', in the common case where
2977-- SQLITE_OK signals success and anything else signals an error.
3078--
0 commit comments