Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/Simplex/Messaging/Agent/Store/SQLite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ connectDB path functions key track = do
pure db
where
prepare db = do
let db' = SQL.connectionHandle $ DB.conn db
unless (BA.null key) . SQLite3.exec db' $ "PRAGMA key = " <> keyString key <> ";"
SQLite3.exec db' . fromQuery $
[sql|
Expand All @@ -120,9 +119,13 @@ connectDB path functions key track = do
PRAGMA secure_delete = ON;
PRAGMA auto_vacuum = FULL;
|]
forM_ functions $ \SQLiteFuncDef {funcName, argCount, deterministic, funcPtr} ->
createStaticFunction db' funcName argCount deterministic funcPtr
>>= either (throwIO . userError . show) pure
mapM_ addFunction functions
where
db' = SQL.connectionHandle $ DB.conn db
addFunction SQLiteFuncDef {funcName, argCount, funcPtrs} =
either (throwIO . userError . show) pure =<< case funcPtrs of
SQLiteFuncPtr isDet funcPtr -> createStaticFunction db' funcName argCount isDet funcPtr
SQLiteAggrPtrs stepPtr finalPtr -> createStaticAggregate db' funcName argCount stepPtr finalPtr

closeDBStore :: DBStore -> IO ()
closeDBStore st@DBStore {dbClosed} =
Expand Down
11 changes: 8 additions & 3 deletions src/Simplex/Messaging/Agent/Store/SQLite/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Simplex.Messaging.Agent.Store.SQLite.Common
( DBStore (..),
DBOpts (..),
SQLiteFuncDef (..),
SQLiteFuncPtrs (..),
withConnection,
withConnection',
withTransaction,
Expand Down Expand Up @@ -55,14 +56,18 @@ data DBOpts = DBOpts
track :: DB.TrackQueries
}

-- e.g. `SQLiteFuncDef "name" 2 True f`
-- e.g. `SQLiteFuncDef "func_name" 2 (SQLiteFuncPtr True func)`
-- or `SQLiteFuncDef "aggr_name" 3 (SQLiteAggrPtrs step final)`
data SQLiteFuncDef = SQLiteFuncDef
{ funcName :: ByteString,
argCount :: CArgCount,
deterministic :: Bool,
funcPtr :: FunPtr SQLiteFunc
funcPtrs :: SQLiteFuncPtrs
}

data SQLiteFuncPtrs
= SQLiteFuncPtr {deterministic :: Bool, funcPtr :: FunPtr SQLiteFunc}
| SQLiteAggrPtrs {stepPtr :: FunPtr SQLiteFunc, finalPtr :: FunPtr SQLiteFuncFinal}

withConnectionPriority :: DBStore -> Bool -> (DB.Connection -> IO a) -> IO a
withConnectionPriority DBStore {dbSem, dbConnection} priority action
| priority = E.bracket_ signal release $ withMVar dbConnection action
Expand Down
48 changes: 48 additions & 0 deletions src/Simplex/Messaging/Agent/Store/SQLite/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@ module Simplex.Messaging.Agent.Store.SQLite.Util where
import Control.Exception (SomeException, catch, mask_)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.IORef
import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..))
import Database.SQLite3.Bindings
import Foreign.C.String
import Foreign.Ptr
import Foreign.StablePtr
import Foreign.Storable

data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal)

type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO ()

type SQLiteFuncFinal = Ptr CContext -> IO ()

mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc
mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals)
{-# INLINE mkSQLiteFunc #-}
Expand All @@ -25,6 +29,50 @@ createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do
B.useAsCString name $ \namePtr ->
toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr

mkSQLiteAggStep :: a -> (FuncContext -> FuncArgs -> a -> IO a) -> SQLiteFunc
mkSQLiteAggStep initSt xStep cxt nArgs cvals = catchAsResultError cxt $ do
-- we store the aggregate state in the buffer returned by
-- c_sqlite3_aggregate_context as a StablePtr pointing to an IORef that
-- contains the actual aggregate state
aggCtx <- getAggregateContext cxt
aggStPtr <- peek aggCtx
aggStRef <-
if castStablePtrToPtr aggStPtr /= nullPtr
then deRefStablePtr aggStPtr
else do
aggStRef <- newIORef initSt
aggStPtr' <- newStablePtr aggStRef
poke aggCtx aggStPtr'
return aggStRef
aggSt <- readIORef aggStRef
aggSt' <- xStep (FuncContext cxt) (FuncArgs nArgs cvals) aggSt
writeIORef aggStRef aggSt'

mkSQLiteAggFinal :: a -> (FuncContext -> a -> IO ()) -> SQLiteFuncFinal
mkSQLiteAggFinal initSt xFinal cxt = do
aggCtx <- getAggregateContext cxt
aggStPtr <- peek aggCtx
if castStablePtrToPtr aggStPtr == nullPtr
then catchAsResultError cxt $ xFinal (FuncContext cxt) initSt
else do
catchAsResultError cxt $ do
aggStRef <- deRefStablePtr aggStPtr
aggSt <- readIORef aggStRef
xFinal (FuncContext cxt) aggSt
freeStablePtr aggStPtr

getAggregateContext :: Ptr CContext -> IO (Ptr a)
getAggregateContext cxt = c_sqlite3_aggregate_context cxt stPtrSize
where
stPtrSize = fromIntegral $ sizeOf (undefined :: StablePtr ())

-- Based on createAggregate from Database.SQLite3.Direct, but uses static function pointers to avoid dynamic wrappers that trigger DCL.
createStaticAggregate :: Database -> ByteString -> CArgCount -> FunPtr SQLiteFunc -> FunPtr SQLiteFuncFinal -> IO (Either Error ())
createStaticAggregate (Database db) name nArgs stepPtr finalPtr = mask_ $ do
u <- newStablePtr $ CFuncPtrs nullFunPtr stepPtr finalPtr
B.useAsCString name $ \namePtr ->
toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs 0 (castStablePtrToPtr u) nullFunPtr stepPtr finalPtr nullFunPtr

-- Convert a 'CError' to a 'Either Error', in the common case where
-- SQLITE_OK signals success and anything else signals an error.
--
Expand Down
Loading