diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 0471a5cd7..6cc63c066 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -110,7 +110,6 @@ connectDB path functions key track = do pure db where prepare db = do - let db' = SQL.connectionHandle $ DB.conn db unless (BA.null key) . SQLite3.exec db' $ "PRAGMA key = " <> keyString key <> ";" SQLite3.exec db' . fromQuery $ [sql| @@ -120,9 +119,13 @@ connectDB path functions key track = do PRAGMA secure_delete = ON; PRAGMA auto_vacuum = FULL; |] - forM_ functions $ \SQLiteFuncDef {funcName, argCount, deterministic, funcPtr} -> - createStaticFunction db' funcName argCount deterministic funcPtr - >>= either (throwIO . userError . show) pure + mapM_ addFunction functions + where + db' = SQL.connectionHandle $ DB.conn db + addFunction SQLiteFuncDef {funcName, argCount, funcPtrs} = + either (throwIO . userError . show) pure =<< case funcPtrs of + SQLiteFuncPtr isDet funcPtr -> createStaticFunction db' funcName argCount isDet funcPtr + SQLiteAggrPtrs stepPtr finalPtr -> createStaticAggregate db' funcName argCount stepPtr finalPtr closeDBStore :: DBStore -> IO () closeDBStore st@DBStore {dbClosed} = diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index 04e724749..aac5ee37e 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -7,6 +7,7 @@ module Simplex.Messaging.Agent.Store.SQLite.Common ( DBStore (..), DBOpts (..), SQLiteFuncDef (..), + SQLiteFuncPtrs (..), withConnection, withConnection', withTransaction, @@ -55,14 +56,18 @@ data DBOpts = DBOpts track :: DB.TrackQueries } --- e.g. `SQLiteFuncDef "name" 2 True f` +-- e.g. `SQLiteFuncDef "func_name" 2 (SQLiteFuncPtr True func)` +-- or `SQLiteFuncDef "aggr_name" 3 (SQLiteAggrPtrs step final)` data SQLiteFuncDef = SQLiteFuncDef { funcName :: ByteString, argCount :: CArgCount, - deterministic :: Bool, - funcPtr :: FunPtr SQLiteFunc + funcPtrs :: SQLiteFuncPtrs } +data SQLiteFuncPtrs + = SQLiteFuncPtr {deterministic :: Bool, funcPtr :: FunPtr SQLiteFunc} + | SQLiteAggrPtrs {stepPtr :: FunPtr SQLiteFunc, finalPtr :: FunPtr SQLiteFuncFinal} + withConnectionPriority :: DBStore -> Bool -> (DB.Connection -> IO a) -> IO a withConnectionPriority DBStore {dbSem, dbConnection} priority action | priority = E.bracket_ signal release $ withMVar dbConnection action diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs index a3c3b94ac..2cbd7ecff 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs @@ -3,16 +3,20 @@ module Simplex.Messaging.Agent.Store.SQLite.Util where import Control.Exception (SomeException, catch, mask_) import Data.ByteString (ByteString) import qualified Data.ByteString as B +import Data.IORef import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..)) import Database.SQLite3.Bindings import Foreign.C.String import Foreign.Ptr import Foreign.StablePtr +import Foreign.Storable data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal) type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO () +type SQLiteFuncFinal = Ptr CContext -> IO () + mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals) {-# INLINE mkSQLiteFunc #-} @@ -25,6 +29,50 @@ createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do B.useAsCString name $ \namePtr -> toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr +mkSQLiteAggStep :: a -> (FuncContext -> FuncArgs -> a -> IO a) -> SQLiteFunc +mkSQLiteAggStep initSt xStep cxt nArgs cvals = catchAsResultError cxt $ do + -- we store the aggregate state in the buffer returned by + -- c_sqlite3_aggregate_context as a StablePtr pointing to an IORef that + -- contains the actual aggregate state + aggCtx <- getAggregateContext cxt + aggStPtr <- peek aggCtx + aggStRef <- + if castStablePtrToPtr aggStPtr /= nullPtr + then deRefStablePtr aggStPtr + else do + aggStRef <- newIORef initSt + aggStPtr' <- newStablePtr aggStRef + poke aggCtx aggStPtr' + return aggStRef + aggSt <- readIORef aggStRef + aggSt' <- xStep (FuncContext cxt) (FuncArgs nArgs cvals) aggSt + writeIORef aggStRef aggSt' + +mkSQLiteAggFinal :: a -> (FuncContext -> a -> IO ()) -> SQLiteFuncFinal +mkSQLiteAggFinal initSt xFinal cxt = do + aggCtx <- getAggregateContext cxt + aggStPtr <- peek aggCtx + if castStablePtrToPtr aggStPtr == nullPtr + then catchAsResultError cxt $ xFinal (FuncContext cxt) initSt + else do + catchAsResultError cxt $ do + aggStRef <- deRefStablePtr aggStPtr + aggSt <- readIORef aggStRef + xFinal (FuncContext cxt) aggSt + freeStablePtr aggStPtr + +getAggregateContext :: Ptr CContext -> IO (Ptr a) +getAggregateContext cxt = c_sqlite3_aggregate_context cxt stPtrSize + where + stPtrSize = fromIntegral $ sizeOf (undefined :: StablePtr ()) + +-- Based on createAggregate from Database.SQLite3.Direct, but uses static function pointers to avoid dynamic wrappers that trigger DCL. +createStaticAggregate :: Database -> ByteString -> CArgCount -> FunPtr SQLiteFunc -> FunPtr SQLiteFuncFinal -> IO (Either Error ()) +createStaticAggregate (Database db) name nArgs stepPtr finalPtr = mask_ $ do + u <- newStablePtr $ CFuncPtrs nullFunPtr stepPtr finalPtr + B.useAsCString name $ \namePtr -> + toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs 0 (castStablePtrToPtr u) nullFunPtr stepPtr finalPtr nullFunPtr + -- Convert a 'CError' to a 'Either Error', in the common case where -- SQLITE_OK signals success and anything else signals an error. --