diff --git a/simplexmq.cabal b/simplexmq.cabal index 81f5ee808..84a4cc927 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -154,6 +154,9 @@ library Simplex.Messaging.Agent.Store.Postgres.DB Simplex.Messaging.Agent.Store.Postgres.Migrations Simplex.Messaging.Agent.Store.Postgres.Migrations.M20241210_initial + if !flag(client_library) + exposed-modules: + Simplex.Messaging.Agent.Store.Postgres.Util else exposed-modules: Simplex.Messaging.Agent.Store.SQLite @@ -260,7 +263,6 @@ library , crypton-x509-validation ==1.6.* , cryptostore ==0.3.* , data-default ==0.7.* - , direct-sqlcipher ==2.3.* , directory ==1.3.* , filepath ==1.4.* , hourglass ==0.2.* @@ -280,7 +282,6 @@ library , random >=1.1 && <1.3 , simple-logger ==0.1.* , socks ==0.6.* - , sqlcipher-simple ==0.4.* , stm ==2.5.* , temporary ==1.3.* , time ==1.12.* @@ -301,9 +302,14 @@ library , hashable ==1.4.* if flag(client_postgres) build-depends: - postgresql-simple ==0.6.* + postgresql-libpq >=0.10.0.0 + , postgresql-simple ==0.7.* , raw-strings-qq ==1.1.* cpp-options: -DdbPostgres + else + build-depends: + direct-sqlcipher ==2.3.* + , sqlcipher-simple ==0.4.* if impl(ghc >= 9.6.2) build-depends: bytestring ==0.11.* @@ -406,10 +412,7 @@ test-suite simplexmq-test AgentTests.EqInstances AgentTests.FunctionalAPITests AgentTests.MigrationTests - AgentTests.NotificationTests - AgentTests.SchemaDump AgentTests.ServerChoice - AgentTests.SQLiteTests CLITests CoreTests.BatchingTests CoreTests.CryptoFileTests @@ -423,6 +426,7 @@ test-suite simplexmq-test CoreTests.UtilTests CoreTests.VersionRangeTests FileDescriptionTests + Fixtures NtfClient NtfServerTests RemoteControl @@ -438,6 +442,11 @@ test-suite simplexmq-test Static Static.Embedded Paths_simplexmq + if !flag(client_postgres) + other-modules: + AgentTests.NotificationTests + AgentTests.SchemaDump + AgentTests.SQLiteTests hs-source-dirs: tests apps/smp-server/web @@ -478,7 +487,6 @@ test-suite simplexmq-test , silently ==1.2.* , simple-logger , simplexmq - , sqlcipher-simple , stm , text , time @@ -495,6 +503,10 @@ test-suite simplexmq-test default-language: Haskell2010 if flag(client_postgres) build-depends: - postgresql-simple ==0.6.* + postgresql-libpq >=0.10.0.0 + , postgresql-simple ==0.7.* , raw-strings-qq ==1.1.* cpp-options: -DdbPostgres + else + build-depends: + sqlcipher-simple diff --git a/src/Simplex/FileTransfer/Description.hs b/src/Simplex/FileTransfer/Description.hs index 8cb98fd32..11ca98edc 100644 --- a/src/Simplex/FileTransfer/Description.hs +++ b/src/Simplex/FileTransfer/Description.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DerivingStrategies #-} @@ -9,6 +10,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} @@ -62,17 +64,23 @@ import Data.Text (Text) import Data.Text.Encoding (encodeUtf8) import Data.Word (Word32) import qualified Data.Yaml as Y -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.FileTransfer.Chunks import Simplex.FileTransfer.Protocol import Simplex.Messaging.Agent.QueryString +import Simplex.Messaging.Agent.Store.DB (Binary (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (defaultJSON, parseAll) import Simplex.Messaging.Protocol (XFTPServer) import 
Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import Simplex.Messaging.Util (bshow, safeDecodeUtf8, (<$?>)) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif data FileDescription (p :: FileParty) = FileDescription { party :: SFileParty p, @@ -109,6 +117,9 @@ fdSeparator = "################################\n" newtype FileDigest = FileDigest {unFileDigest :: ByteString} deriving (Eq, Show) + deriving newtype (FromField) + +instance ToField FileDigest where toField (FileDigest s) = toField $ Binary s instance StrEncoding FileDigest where strEncode (FileDigest fd) = strEncode fd @@ -122,10 +133,6 @@ instance ToJSON FileDigest where toJSON = strToJSON toEncoding = strToJEncoding -instance FromField FileDigest where fromField f = FileDigest <$> fromField f - -instance ToField FileDigest where toField (FileDigest s) = toField s - data FileChunk = FileChunk { chunkNo :: Int, chunkSize :: FileSize Word32, @@ -288,9 +295,9 @@ instance (Integral a, Show a) => StrEncoding (FileSize a) where instance (Integral a, Show a) => IsString (FileSize a) where fromString = either error id . strDecode . B.pack -instance FromField a => FromField (FileSize a) where fromField f = FileSize <$> fromField f +deriving newtype instance FromField a => FromField (FileSize a) -instance ToField a => ToField (FileSize a) where toField (FileSize s) = toField s +deriving newtype instance ToField a => ToField (FileSize a) groupReplicasByServer :: FileSize Word32 -> [FileChunk] -> [NonEmpty FileServerReplica] groupReplicasByServer defChunkSize = diff --git a/src/Simplex/FileTransfer/Types.hs b/src/Simplex/FileTransfer/Types.hs index 8569bdd12..f45ac75f8 100644 --- a/src/Simplex/FileTransfer/Types.hs +++ b/src/Simplex/FileTransfer/Types.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -13,8 +14,6 @@ import Data.Int (Int64) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import Data.Word (Word32) -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.FileTransfer.Client (XFTPChunkSpec (..)) import Simplex.FileTransfer.Description import qualified Simplex.Messaging.Crypto as C @@ -24,6 +23,13 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol (XFTPServer) import System.FilePath (()) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif type RcvFileId = ByteString -- Agent entity ID diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 8068f171f..5d5e6fcaf 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -1,5 +1,6 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} @@ -205,7 +206,6 @@ import Data.Text.Encoding import Data.Time (UTCTime, addUTCTime, defaultTimeLocale, formatTime, getCurrentTime) import Data.Time.Clock.System 
(getSystemTime) import Data.Word (Word16) -import qualified Database.SQLite.Simple as SQL import Network.Socket (HostName) import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError) import qualified Simplex.FileTransfer.Client as X @@ -282,6 +282,9 @@ import UnliftIO.Concurrent (forkIO, mkWeakThreadId) import UnliftIO.Directory (doesFileExist, getTemporaryDirectory, removeFile) import qualified UnliftIO.Exception as E import UnliftIO.STM +#if !defined(dbPostgres) +import qualified Database.SQLite.Simple as SQL +#endif type ClientVar msg = SessionVar (Either (AgentErrorType, Maybe UTCTime) (Client msg)) @@ -1989,6 +1992,13 @@ withStore c action = do withExceptT storeError . ExceptT . liftIO . agentOperationBracket c AODatabase (\_ -> pure ()) $ withTransaction st action `E.catches` handleDBErrors where +#if defined(dbPostgres) + -- TODO [postgres] postgres specific error handling + handleDBErrors :: [E.Handler IO (Either StoreError a)] + handleDBErrors = + [ E.Handler $ \(E.SomeException e) -> pure . Left $ SEInternal $ bshow e + ] +#else handleDBErrors :: [E.Handler IO (Either StoreError a)] handleDBErrors = [ E.Handler $ \(e :: SQL.SQLError) -> @@ -1997,6 +2007,7 @@ withStore c action = do in pure . Left . (if busy then SEDatabaseBusy else SEInternal) $ bshow se, E.Handler $ \(E.SomeException e) -> pure . Left $ SEInternal $ bshow e ] +#endif withStoreBatch :: Traversable t => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> AM' (t (Either AgentErrorType a)) withStoreBatch c actions = do diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index b6a3830c7..80a307efa 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -88,6 +88,7 @@ import System.Mem.Weak (Weak) import System.Random (StdGen, newStdGen) import UnliftIO.STM #if defined(dbPostgres) +import Database.PostgreSQL.Simple (ConnectInfo (..)) #else import Data.ByteArray (ScrubbedBytes) #endif @@ -277,8 +278,7 @@ newSMPAgentEnv config store = do pure Env {config, store, random, randomServer, ntfSupervisor, xftpAgent, multicastSubscribers} #if defined(dbPostgres) --- TODO [postgres] pass db name / ConnectInfo? 
-createAgentStore :: MigrationConfirmation -> IO (Either MigrationError DBStore) +createAgentStore :: ConnectInfo -> String -> MigrationConfirmation -> IO (Either MigrationError DBStore) createAgentStore = createStore #else createAgentStore :: FilePath -> ScrubbedBytes -> Bool -> MigrationConfirmation -> IO (Either MigrationError DBStore) diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 08e8add24..b87f87f18 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -167,13 +168,12 @@ import Data.Time.Clock.System (SystemTime) import Data.Type.Equality import Data.Typeable () import Data.Word (Word16, Word32) -import Database.SQLite.Simple.FromField -import Database.SQLite.Simple.ToField import Simplex.FileTransfer.Description import Simplex.FileTransfer.Protocol (FileParty (..)) import Simplex.FileTransfer.Transport (XFTPErrorType) import Simplex.FileTransfer.Types (FileErrorType) import Simplex.Messaging.Agent.QueryString +import Simplex.Messaging.Agent.Store.DB (Binary (..)) import Simplex.Messaging.Client (ProxyClientError) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet @@ -224,6 +224,13 @@ import Simplex.Messaging.Version import Simplex.Messaging.Version.Internal import Simplex.RemoteControl.Types import UnliftIO.Exception (Exception) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif -- SMP agent protocol version history: -- 1 - binary protocol encoding (1/1/2022) @@ -644,7 +651,7 @@ instance ToJSON NotificationsMode where instance FromJSON NotificationsMode where parseJSON = strParseJSON "NotificationsMode" -instance ToField NotificationsMode where toField = toField . strEncode +instance ToField NotificationsMode where toField = toField . Binary . 
strEncode instance FromField NotificationsMode where fromField = blobFieldDecoder $ parseAll strP diff --git a/src/Simplex/Messaging/Agent/Stats.hs b/src/Simplex/Messaging/Agent/Stats.hs index d4663bfb1..1d174622e 100644 --- a/src/Simplex/Messaging/Agent/Stats.hs +++ b/src/Simplex/Messaging/Agent/Stats.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE NamedFieldPuns #-} @@ -10,13 +11,18 @@ import qualified Data.Aeson.TH as J import Data.Int (Int64) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.Protocol (UserId) import Simplex.Messaging.Parsers (defaultJSON, fromTextField_) -import Simplex.Messaging.Protocol (SMPServer, XFTPServer, NtfServer) +import Simplex.Messaging.Protocol (NtfServer, SMPServer, XFTPServer) import Simplex.Messaging.Util (decodeJSON, encodeJSON) import UnliftIO.STM +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif data AgentSMPServerStats = AgentSMPServerStats { sentDirect :: TVar Int, -- successfully sent messages diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 1f21c7e71..c199e480b 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -55,15 +55,16 @@ import Simplex.Messaging.Protocol ) import qualified Simplex.Messaging.Protocol as SMP #if defined(dbPostgres) +import Database.PostgreSQL.Simple (ConnectInfo (..)) import qualified Simplex.Messaging.Agent.Store.Postgres as StoreFunctions #else -import qualified Simplex.Messaging.Agent.Store.SQLite as StoreFunctions import Data.ByteArray (ScrubbedBytes) +import qualified Simplex.Messaging.Agent.Store.SQLite as StoreFunctions #endif #if defined(dbPostgres) -createStore :: MigrationConfirmation -> IO (Either MigrationError DBStore) -createStore = StoreFunctions.createDBStore Migrations.app +createStore :: ConnectInfo -> String -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createStore connectInfo schema = StoreFunctions.createDBStore connectInfo schema Migrations.app #else createStore :: FilePath -> ScrubbedBytes -> Bool -> MigrationConfirmation -> IO (Either MigrationError DBStore) createStore dbFilePath dbKey keepKey = StoreFunctions.createDBStore dbFilePath dbKey keepKey Migrations.app diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index a2ecab6ea..c339b7a01 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} @@ -18,7 +19,6 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -249,11 +249,6 @@ import Data.Ord (Down (..)) import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) import Data.Word (Word32) -import Database.SQLite.Simple 
(FromRow (..), NamedParam (..), Only (..), Query (..), SQLError, ToRow (..), field, (:.) (..)) -import qualified Database.SQLite.Simple as SQL -import Database.SQLite.Simple.FromField -import Database.SQLite.Simple.QQ (sql) -import Database.SQLite.Simple.ToField (ToField (..)) import Network.Socket (ServiceName) import Simplex.FileTransfer.Client (XFTPChunkSpec (..)) import Simplex.FileTransfer.Description @@ -265,6 +260,7 @@ import Simplex.Messaging.Agent.Stats import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.Common import qualified Simplex.Messaging.Agent.Store.DB as DB +import Simplex.Messaging.Agent.Store.DB (Binary (..), BoolInt (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..)) import Simplex.Messaging.Crypto.Ratchet (PQEncryption (..), PQSupport (..), RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys) @@ -281,14 +277,34 @@ import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, ifM, tshow, import Simplex.Messaging.Version.Internal import qualified UnliftIO.Exception as E import UnliftIO.STM +#if defined(dbPostgres) +import Database.PostgreSQL.Simple (Only (..), Query, SqlError, (:.) (..)) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.Errors (constraintViolation) +import Database.PostgreSQL.Simple.SqlQQ (sql) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple (FromRow (..), Only (..), Query (..), SQLError, ToRow (..), field, (:.) (..)) +import qualified Database.SQLite.Simple as SQL +import Database.SQLite.Simple.FromField +import Database.SQLite.Simple.QQ (sql) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif checkConstraint :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a) checkConstraint err action = action `E.catch` (pure . Left . handleSQLError err) +#if defined(dbPostgres) +handleSQLError :: StoreError -> SqlError -> StoreError +handleSQLError err e = case constraintViolation e of + Just _ -> err + Nothing -> SEInternal $ bshow e +#else handleSQLError :: StoreError -> SQLError -> StoreError handleSQLError err e | SQL.sqlError e == SQL.ErrorConstraint = err | otherwise = SEInternal $ bshow e +#endif createUserRecord :: DB.Connection -> IO UserId createUserRecord db = do @@ -298,7 +314,7 @@ createUserRecord db = do checkUser :: DB.Connection -> UserId -> IO (Either StoreError ()) checkUser db userId = firstRow (\(_ :: Only Int64) -> ()) SEUserNotFound $ - DB.query db "SELECT user_id FROM users WHERE user_id = ? AND deleted = ?" (userId, False) + DB.query db "SELECT user_id FROM users WHERE user_id = ? AND deleted = ?" (userId, BI False) deleteUserRecord :: DB.Connection -> UserId -> IO (Either StoreError ()) deleteUserRecord db userId = runExceptT $ do @@ -309,7 +325,7 @@ setUserDeleted :: DB.Connection -> UserId -> IO (Either StoreError [ConnId]) setUserDeleted db userId = runExceptT $ do ExceptT $ checkUser db userId liftIO $ do - DB.execute db "UPDATE users SET deleted = ? WHERE user_id = ?" (True, userId) + DB.execute db "UPDATE users SET deleted = ? WHERE user_id = ?" (BI True, userId) map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE user_id = ?" (Only userId) deleteUserWithoutConns :: DB.Connection -> UserId -> IO Bool @@ -324,7 +340,7 @@ deleteUserWithoutConns db userId = do AND u.deleted = ? 
AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id) |] - (userId, True) + (userId, BI True) case userId_ of Just _ -> DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId) $> True _ -> pure False @@ -340,7 +356,7 @@ deleteUsersWithoutConns db = do WHERE u.deleted = ? AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id) |] - (Only True) + (Only (BI True)) forM_ userIds $ DB.execute db "DELETE FROM users WHERE user_id = ?" . Only pure userIds @@ -394,7 +410,7 @@ createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs, pqSup INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, pq_support, duplex_handshake) VALUES (?,?,?,?,?,?,?) |] - (userId, connId, cMode, connAgentVersion, enableNtfs, pqSupport, True) + (userId, connId, cMode, connAgentVersion, BI enableNtfs, pqSupport, BI True) checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do @@ -481,14 +497,14 @@ addConnSndQueue_ db connId sq@SndQueue {server} = do setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO () setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status = -- ? return error if queue does not exist? - DB.executeNamed + DB.execute db [sql| UPDATE rcv_queues - SET status = :status - WHERE host = :host AND port = :port AND rcv_id = :rcv_id; + SET status = ? + WHERE host = ? AND port = ? AND rcv_id = ? |] - [":status" := status, ":host" := host, ":port" := port, ":rcv_id" := rcvId] + (status, host, port, rcvId) setRcvSwitchStatus :: DB.Connection -> RcvQueue -> Maybe RcvSwitchStatus -> IO RcvQueue setRcvSwitchStatus db rq@RcvQueue {rcvId, server = ProtocolServer {host, port}} rcvSwchStatus = do @@ -515,34 +531,28 @@ setRcvQueueDeleted db RcvQueue {rcvId, server = ProtocolServer {host, port}} = d setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> VersionSMPC -> IO () setRcvQueueConfirmedE2E db RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret smpClientVersion = - DB.executeNamed + DB.execute db [sql| UPDATE rcv_queues - SET e2e_dh_secret = :e2e_dh_secret, - status = :status, - smp_client_version = :smp_client_version - WHERE host = :host AND port = :port AND rcv_id = :rcv_id + SET e2e_dh_secret = ?, + status = ?, + smp_client_version = ? + WHERE host = ? AND port = ? AND rcv_id = ? |] - [ ":status" := Confirmed, - ":e2e_dh_secret" := e2eDhSecret, - ":smp_client_version" := smpClientVersion, - ":host" := host, - ":port" := port, - ":rcv_id" := rcvId - ] + (e2eDhSecret, Confirmed, smpClientVersion, host, port, rcvId) setSndQueueStatus :: DB.Connection -> SndQueue -> QueueStatus -> IO () setSndQueueStatus db SndQueue {sndId, server = ProtocolServer {host, port}} status = -- ? return error if queue does not exist? - DB.executeNamed + DB.execute db [sql| UPDATE snd_queues - SET status = :status - WHERE host = :host AND port = :port AND snd_id = :snd_id; + SET status = ? + WHERE host = ? AND port = ? AND snd_id = ? 
|] - [":status" := status, ":host" := host, ":port" := port, ":snd_id" := sndId] + (status, host, port, sndId) setSndSwitchStatus :: DB.Connection -> SndQueue -> Maybe SndSwitchStatus -> IO SndQueue setSndSwitchStatus db sq@SndQueue {sndId, server = ProtocolServer {host, port}} sndSwchStatus = do @@ -558,19 +568,19 @@ setSndSwitchStatus db sq@SndQueue {sndId, server = ProtocolServer {host, port}} setRcvQueuePrimary :: DB.Connection -> ConnId -> RcvQueue -> IO () setRcvQueuePrimary db connId RcvQueue {dbQueueId} = do - DB.execute db "UPDATE rcv_queues SET rcv_primary = ? WHERE conn_id = ?" (False, connId) + DB.execute db "UPDATE rcv_queues SET rcv_primary = ? WHERE conn_id = ?" (BI False, connId) DB.execute db "UPDATE rcv_queues SET rcv_primary = ?, replace_rcv_queue_id = ? WHERE conn_id = ? AND rcv_queue_id = ?" - (True, Nothing :: Maybe Int64, connId, dbQueueId) + (BI True, Nothing :: Maybe Int64, connId, dbQueueId) setSndQueuePrimary :: DB.Connection -> ConnId -> SndQueue -> IO () setSndQueuePrimary db connId SndQueue {dbQueueId} = do - DB.execute db "UPDATE snd_queues SET snd_primary = ? WHERE conn_id = ?" (False, connId) + DB.execute db "UPDATE snd_queues SET snd_primary = ? WHERE conn_id = ?" (BI False, connId) DB.execute db "UPDATE snd_queues SET snd_primary = ?, replace_snd_queue_id = ? WHERE conn_id = ? AND snd_queue_id = ?" - (True, Nothing :: Maybe Int64, connId, dbQueueId) + (BI True, Nothing :: Maybe Int64, connId, dbQueueId) incRcvDeleteErrors :: DB.Connection -> RcvQueue -> IO () incRcvDeleteErrors db RcvQueue {connId, dbQueueId} = @@ -592,12 +602,12 @@ getPrimaryRcvQueue db connId = getRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue) getRcvQueue db connId (SMPServer host port _) rcvId = firstRow toRcvQueue SEConnNotFound $ - DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (connId, host, port, rcvId) + DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (connId, host, port, rcvId) getDeletedRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue) getDeletedRcvQueue db connId (SMPServer host port _) rcvId = firstRow toRcvQueue SEConnNotFound $ - DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 1") (connId, host, port, rcvId) + DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ? 
AND q.deleted = 1") (connId, host, port, rcvId) setRcvQueueNtfCreds :: DB.Connection -> ConnId -> Maybe ClientNtfCreds -> IO () setRcvQueueNtfCreds db connId clientNtfCreds = @@ -635,21 +645,19 @@ createConfirmation db gVar NewConfirmation {connId, senderConf = SMPConfirmation INSERT INTO conn_confirmations (confirmation_id, conn_id, sender_key, e2e_snd_pub_key, ratchet_state, sender_conn_info, smp_reply_queues, smp_client_version, accepted) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0); |] - (confirmationId, connId, senderKey, e2ePubKey, ratchetState, connInfo, smpReplyQueues, smpClientVersion) + (Binary confirmationId, connId, senderKey, e2ePubKey, ratchetState, Binary connInfo, smpReplyQueues, smpClientVersion) acceptConfirmation :: DB.Connection -> ConfirmationId -> ConnInfo -> IO (Either StoreError AcceptedConfirmation) acceptConfirmation db confirmationId ownConnInfo = do - DB.executeNamed + DB.execute db [sql| UPDATE conn_confirmations SET accepted = 1, - own_conn_info = :own_conn_info - WHERE confirmation_id = :confirmation_id; + own_conn_info = ? + WHERE confirmation_id = ? |] - [ ":own_conn_info" := ownConnInfo, - ":confirmation_id" := confirmationId - ] + (Binary ownConnInfo, Binary confirmationId) firstRow confirmation SEConfirmationNotFound $ DB.query db @@ -658,7 +666,7 @@ acceptConfirmation db confirmationId ownConnInfo = do FROM conn_confirmations WHERE confirmation_id = ?; |] - (Only confirmationId) + (Only (Binary confirmationId)) where confirmation ((connId, ratchetState) :. confRow) = AcceptedConfirmation @@ -692,13 +700,13 @@ getAcceptedConfirmation db connId = removeConfirmations :: DB.Connection -> ConnId -> IO () removeConfirmations db connId = - DB.executeNamed + DB.execute db [sql| DELETE FROM conn_confirmations - WHERE conn_id = :conn_id; + WHERE conn_id = ? |] - [":conn_id" := connId] + (Only connId) createInvitation :: DB.Connection -> TVar ChaChaDRG -> NewInvitation -> IO (Either StoreError InvitationId) createInvitation db gVar NewInvitation {contactConnId, connReq, recipientConnInfo} = @@ -709,7 +717,7 @@ createInvitation db gVar NewInvitation {contactConnId, connReq, recipientConnInf INSERT INTO conn_invitations (invitation_id, contact_conn_id, cr_invitation, recipient_conn_info, accepted) VALUES (?, ?, ?, ?, 0); |] - (invitationId, contactConnId, connReq, recipientConnInfo) + (Binary invitationId, contactConnId, connReq, Binary recipientConnInfo) getInvitation :: DB.Connection -> String -> InvitationId -> IO (Either StoreError Invitation) getInvitation db cxt invitationId = @@ -722,34 +730,32 @@ getInvitation db cxt invitationId = WHERE invitation_id = ? AND accepted = 0 |] - (Only invitationId) + (Only (Binary invitationId)) where - invitation (contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted) = + invitation (contactConnId, connReq, recipientConnInfo, ownConnInfo, BI accepted) = Invitation {invitationId, contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted} acceptInvitation :: DB.Connection -> InvitationId -> ConnInfo -> IO () acceptInvitation db invitationId ownConnInfo = - DB.executeNamed + DB.execute db [sql| UPDATE conn_invitations SET accepted = 1, - own_conn_info = :own_conn_info - WHERE invitation_id = :invitation_id + own_conn_info = ? + WHERE invitation_id = ? 
|] - [ ":own_conn_info" := ownConnInfo, - ":invitation_id" := invitationId - ] + (Binary ownConnInfo, Binary invitationId) unacceptInvitation :: DB.Connection -> InvitationId -> IO () unacceptInvitation db invitationId = - DB.execute db "UPDATE conn_invitations SET accepted = 0, own_conn_info = NULL WHERE invitation_id = ?" (Only invitationId) + DB.execute db "UPDATE conn_invitations SET accepted = 0, own_conn_info = NULL WHERE invitation_id = ?" (Only (Binary invitationId)) deleteInvitation :: DB.Connection -> ConnId -> InvitationId -> IO (Either StoreError ()) deleteInvitation db contactConnId invId = getConn db contactConnId $>>= \case SomeConn SCContact _ -> - Right <$> DB.execute db "DELETE FROM conn_invitations WHERE contact_conn_id = ? AND invitation_id = ?" (contactConnId, invId) + Right <$> DB.execute db "DELETE FROM conn_invitations WHERE contact_conn_id = ? AND invitation_id = ?" (contactConnId, Binary invId) _ -> pure $ Left SEConnNotFound updateRcvIds :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) @@ -919,7 +925,7 @@ setMsgUserAck db connId agentMsgId = runExceptT $ do ExceptT . firstRow id SEMsgNotFound $ DB.query db "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId) rq <- ExceptT $ getRcvQueueById db connId dbRcvId - liftIO $ DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (True, connId, agentMsgId) + liftIO $ DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (BI True, connId, agentMsgId) pure (rq, srvMsgId) getRcvMsg :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError RcvMsg) @@ -953,10 +959,10 @@ getLastMsg db connId msgId = LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id WHERE r.conn_id = ? AND r.broker_id = ? |] - (connId, msgId) + (connId, Binary msgId) -toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg -toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, internalHash, msgType, msgBody, pqEncryption, rcptInternalId_, rcptStatus_, userAck)) = +toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, BoolInt) -> RcvMsg +toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, internalHash, msgType, msgBody, pqEncryption, rcptInternalId_, rcptStatus_, BI userAck)) = let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity, pqEncryption} msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_ in RcvMsg {internalId = InternalId agentMsgId, msgMeta, msgType, msgBody, internalHash, msgReceipt, userAck} @@ -969,13 +975,13 @@ checkRcvMsgHashExists db connId hash = do ( DB.query db "SELECT 1 FROM encrypted_rcv_message_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" - (connId, hash) + (connId, Binary hash) ) getRcvMsgBrokerTs :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Either StoreError BrokerTs) getRcvMsgBrokerTs db connId msgId = firstRow fromOnly SEMsgNotFound $ - DB.query db "SELECT broker_ts FROM rcv_messages WHERE conn_id = ? AND broker_id = ?" (connId, msgId) + DB.query db "SELECT broker_ts FROM rcv_messages WHERE conn_id = ? AND broker_id = ?" 
(connId, Binary msgId) deleteMsg :: DB.Connection -> ConnId -> InternalId -> IO () deleteMsg db connId msgId = @@ -983,7 +989,11 @@ deleteMsg db connId msgId = deleteMsgContent :: DB.Connection -> ConnId -> InternalId -> IO () deleteMsgContent db connId msgId = - DB.execute db "UPDATE messages SET msg_body = x'' WHERE conn_id = ? AND internal_id = ?;" (connId, msgId) +#if defined(dbPostgres) + DB.execute db "UPDATE messages SET msg_body = ''::BYTEA WHERE conn_id = ? AND internal_id = ?" (connId, msgId) +#else + DB.execute db "UPDATE messages SET msg_body = x'' WHERE conn_id = ? AND internal_id = ?" (connId, msgId) +#endif deleteDeliveredSndMsg :: DB.Connection -> ConnId -> InternalId -> IO () deleteDeliveredSndMsg db connId msgId = do @@ -1052,20 +1062,20 @@ setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem = -- TODO remove the columns for public keys in v5.7. createRatchet :: DB.Connection -> ConnId -> RatchetX448 -> IO () createRatchet db connId rc = - DB.executeNamed + DB.execute db [sql| INSERT INTO ratchets (conn_id, ratchet_state) - VALUES (:conn_id, :ratchet_state) + VALUES (?, ?) ON CONFLICT (conn_id) DO UPDATE SET - ratchet_state = :ratchet_state, + ratchet_state = ?, x3dh_priv_key_1 = NULL, x3dh_priv_key_2 = NULL, x3dh_pub_key_1 = NULL, x3dh_pub_key_2 = NULL, pq_priv_kem = NULL |] - [":conn_id" := connId, ":ratchet_state" := rc] + (connId, rc, rc) deleteRatchet :: DB.Connection -> ConnId -> IO () deleteRatchet db connId = @@ -1106,12 +1116,18 @@ createCommand db corrId connId srv_ cmd = runExceptT $ do DB.execute db "INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command, server_key_hash, created_at) VALUES (?,?,?,?,?,?,?,?)" - (host_, port_, corrId, connId, cmdTag, cmd, serverKeyHash_, createdAt) + (host_, port_, Binary corrId, connId, cmdTag, cmd, serverKeyHash_, createdAt) where cmdTag = agentCommandTag cmd +#if defined(dbPostgres) + handleErr e = case constraintViolation e of + Just _ -> logError $ "tried to create command " <> tshow cmdTag <> " for deleted connection" + Nothing -> E.throwIO e +#else handleErr e | SQL.sqlError e == SQL.ErrorConstraint = logError $ "tried to create command " <> tshow cmdTag <> " for deleted connection" | otherwise = E.throwIO e +#endif serverFields :: ExceptT StoreError IO (Maybe (NonEmpty TransportHost), Maybe ServiceName, Maybe C.KeyHash) serverFields = case srv_ of Just srv@(SMPServer host port _) -> @@ -1119,7 +1135,13 @@ createCommand db corrId connId srv_ cmd = runExceptT $ do Nothing -> pure (Nothing, Nothing, Nothing) insertedRowId :: DB.Connection -> IO Int64 -insertedRowId db = fromOnly . head <$> DB.query_ db "SELECT last_insert_rowid()" +insertedRowId db = fromOnly . head <$> DB.query_ db q + where +#if defined(dbPostgres) + q = "SELECT lastval()" +#else + q = "SELECT last_insert_rowid()" +#endif getPendingCommandServers :: DB.Connection -> ConnId -> IO [Maybe SMPServer] getPendingCommandServers db connId = do @@ -1408,7 +1430,7 @@ supervisorUpdateNtfSub db NtfSubscription {connId, smpServer = (SMPServer smpHos WHERE conn_id = ? |] ( (smpHost, smpPort, ntfQueueId, ntfHost, ntfPort, ntfSubId) - :. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, ts, True, ts, connId) + :. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, ts, BI True, ts, connId) ) where (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action @@ -1423,13 +1445,13 @@ supervisorUpdateNtfAction db connId action = do SET ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? 
WHERE conn_id = ? |] - (ntfSubAction, ntfSubSMPAction, ts, True, ts, connId) + (ntfSubAction, ntfSubSMPAction, ts, BI True, ts, connId) where (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action updateNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> NtfActionTs -> IO () updateNtfSubscription db NtfSubscription {connId, ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} action actionTs = do - r <- maybeFirstRow fromOnly $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) + r <- maybeFirstRow fromOnlyBI $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) forM_ r $ \updatedBySupervisor -> do updatedAt <- getCurrentTime if updatedBySupervisor @@ -1441,7 +1463,7 @@ updateNtfSubscription db NtfSubscription {connId, ntfQueueId, ntfServer = (NtfSe SET smp_ntf_id = ?, ntf_sub_id = ?, ntf_sub_status = ?, updated_by_supervisor = ?, updated_at = ? WHERE conn_id = ? |] - (ntfQueueId, ntfSubId, ntfSubStatus, False, updatedAt, connId) + (ntfQueueId, ntfSubId, ntfSubStatus, BI False, updatedAt, connId) else DB.execute db @@ -1450,13 +1472,13 @@ updateNtfSubscription db NtfSubscription {connId, ntfQueueId, ntfServer = (NtfSe SET smp_ntf_id = ?, ntf_host = ?, ntf_port = ?, ntf_sub_id = ?, ntf_sub_status = ?, ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? WHERE conn_id = ? |] - ((ntfQueueId, ntfHost, ntfPort, ntfSubId) :. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs, False, updatedAt, connId)) + ((ntfQueueId, ntfHost, ntfPort, ntfSubId) :. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs, BI False, updatedAt, connId)) where (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action setNullNtfSubscriptionAction :: DB.Connection -> ConnId -> IO () setNullNtfSubscriptionAction db connId = do - r <- maybeFirstRow fromOnly $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) + r <- maybeFirstRow fromOnlyBI $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) forM_ r $ \updatedBySupervisor -> unless updatedBySupervisor $ do updatedAt <- getCurrentTime @@ -1467,11 +1489,11 @@ setNullNtfSubscriptionAction db connId = do SET ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? WHERE conn_id = ? |] - (Nothing :: Maybe NtfSubNTFAction, Nothing :: Maybe NtfSubSMPAction, Nothing :: Maybe UTCTime, False, updatedAt, connId) + (Nothing :: Maybe NtfSubNTFAction, Nothing :: Maybe NtfSubSMPAction, Nothing :: Maybe UTCTime, BI False, updatedAt, connId) deleteNtfSubscription :: DB.Connection -> ConnId -> IO () deleteNtfSubscription db connId = do - r <- maybeFirstRow fromOnly $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) + r <- maybeFirstRow fromOnlyBI $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) forM_ r $ \updatedBySupervisor -> do updatedAt <- getCurrentTime if updatedBySupervisor @@ -1483,7 +1505,7 @@ deleteNtfSubscription db connId = do SET smp_ntf_id = ?, ntf_sub_id = ?, ntf_sub_status = ?, updated_by_supervisor = ?, updated_at = ? WHERE conn_id = ? 
|] - (Nothing :: Maybe SMP.NotifierId, Nothing :: Maybe NtfSubscriptionId, NASDeleted, False, updatedAt, connId) + (Nothing :: Maybe SMP.NotifierId, Nothing :: Maybe NtfSubscriptionId, NASDeleted, BI False, updatedAt, connId) else deleteNtfSubscription' db connId deleteNtfSubscription' :: DB.Connection -> ConnId -> IO () @@ -1620,7 +1642,7 @@ getNtfRcvQueue db SMPQueueNtf {smpServer = (SMPServer host port _), notifierId} setConnectionNtfs :: DB.Connection -> ConnId -> Bool -> IO () setConnectionNtfs db connId enableNtfs = - DB.execute db "UPDATE connections SET enable_ntfs = ? WHERE conn_id = ?" (enableNtfs, connId) + DB.execute db "UPDATE connections SET enable_ntfs = ? WHERE conn_id = ?" (BI enableNtfs, connId) -- * Auxiliary helpers @@ -1630,37 +1652,42 @@ instance FromField QueueStatus where fromField = fromTextField_ queueStatusT instance ToField (DBQueueId 'QSStored) where toField (DBQueueId qId) = toField qId -instance FromField (DBQueueId 'QSStored) where fromField x = DBQueueId <$> fromField x +instance FromField (DBQueueId 'QSStored) where +#if defined(dbPostgres) + fromField x dat = DBQueueId <$> fromField x dat +#else + fromField x = DBQueueId <$> fromField x +#endif instance ToField InternalRcvId where toField (InternalRcvId x) = toField x -instance FromField InternalRcvId where fromField x = InternalRcvId <$> fromField x +deriving newtype instance FromField InternalRcvId instance ToField InternalSndId where toField (InternalSndId x) = toField x -instance FromField InternalSndId where fromField x = InternalSndId <$> fromField x +deriving newtype instance FromField InternalSndId instance ToField InternalId where toField (InternalId x) = toField x -instance FromField InternalId where fromField x = InternalId <$> fromField x +deriving newtype instance FromField InternalId -instance ToField AgentMessageType where toField = toField . smpEncode +instance ToField AgentMessageType where toField = toField . Binary . smpEncode instance FromField AgentMessageType where fromField = blobFieldParser smpP -instance ToField MsgIntegrity where toField = toField . strEncode +instance ToField MsgIntegrity where toField = toField . Binary . strEncode instance FromField MsgIntegrity where fromField = blobFieldParser strP -instance ToField SMPQueueUri where toField = toField . strEncode +instance ToField SMPQueueUri where toField = toField . Binary . strEncode instance FromField SMPQueueUri where fromField = blobFieldParser strP -instance ToField AConnectionRequestUri where toField = toField . strEncode +instance ToField AConnectionRequestUri where toField = toField . Binary . strEncode instance FromField AConnectionRequestUri where fromField = blobFieldParser strP -instance ConnectionModeI c => ToField (ConnectionRequestUri c) where toField = toField . strEncode +instance ConnectionModeI c => ToField (ConnectionRequestUri c) where toField = toField . Binary . strEncode instance (E.Typeable c, ConnectionModeI c) => FromField (ConnectionRequestUri c) where fromField = blobFieldParser strP @@ -1676,7 +1703,7 @@ instance ToField MsgFlags where toField = toField . decodeLatin1 . smpEncode instance FromField MsgFlags where fromField = fromTextField_ $ eitherToMaybe . smpDecode . encodeUtf8 -instance ToField [SMPQueueInfo] where toField = toField . smpEncodeList +instance ToField [SMPQueueInfo] where toField = toField . Binary . 
smpEncodeList instance FromField [SMPQueueInfo] where fromField = blobFieldParser smpListP @@ -1684,11 +1711,11 @@ instance ToField (NonEmpty TransportHost) where toField = toField . decodeLatin1 instance FromField (NonEmpty TransportHost) where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8 -instance ToField AgentCommand where toField = toField . strEncode +instance ToField AgentCommand where toField = toField . Binary . strEncode instance FromField AgentCommand where fromField = blobFieldParser strP -instance ToField AgentCommandTag where toField = toField . strEncode +instance ToField AgentCommandTag where toField = toField . Binary . strEncode instance FromField AgentCommandTag where fromField = blobFieldParser strP @@ -1698,9 +1725,9 @@ instance FromField MsgReceiptStatus where fromField = fromTextField_ $ eitherToM instance ToField (Version v) where toField (Version v) = toField v -instance FromField (Version v) where fromField f = Version <$> fromField f +deriving newtype instance FromField (Version v) -deriving newtype instance ToField EntityId +instance ToField EntityId where toField (EntityId s) = toField $ Binary s deriving newtype instance FromField EntityId @@ -1718,9 +1745,14 @@ firstRow f e a = second f . listToEither e <$> a maybeFirstRow :: Functor f => (a -> b) -> f [a] -> f (Maybe b) maybeFirstRow f q = fmap f . listToMaybe <$> q +fromOnlyBI :: Only BoolInt -> Bool +fromOnlyBI (Only (BI b)) = b +{-# INLINE fromOnlyBI #-} + firstRow' :: (a -> Either e b) -> e -> IO [a] -> IO (Either e b) firstRow' f e a = (f <=< listToEither e) <$> a +#if !defined(dbPostgres) {- ORMOLU_DISABLE -} -- SQLite.Simple only has these up to 10 fields, which is insufficient for some of our queries instance (FromField a, FromField b, FromField c, FromField d, FromField e, @@ -1748,6 +1780,7 @@ instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f, ] {- ORMOLU_ENABLE -} +#endif -- * Server helper @@ -1771,16 +1804,16 @@ getServerKeyHash_ db ProtocolServer {host, port, keyHash} = do upsertNtfServer_ :: DB.Connection -> NtfServer -> IO () upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do - DB.executeNamed + DB.execute db [sql| - INSERT INTO ntf_servers (ntf_host, ntf_port, ntf_key_hash) VALUES (:host,:port,:key_hash) + INSERT INTO ntf_servers (ntf_host, ntf_port, ntf_key_hash) VALUES (?,?,?) ON CONFLICT (ntf_host, ntf_port) DO UPDATE SET ntf_host=excluded.ntf_host, ntf_port=excluded.ntf_port, ntf_key_hash=excluded.ntf_key_hash; |] - [":host" := host, ":port" := port, ":key_hash" := keyHash] + (host, port, keyHash) -- * createRcvConn helpers @@ -1796,7 +1829,7 @@ insertRcvQueue_ db connId' rq@RcvQueue {..} serverKeyHash_ = do INSERT INTO rcv_queues (host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, snd_secure, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); |] - ((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, sndSecure, status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)) + ((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. 
(sndId, BI sndSecure, status, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)) pure (rq :: NewRcvQueue) {connId = connId', dbQueueId = qId} -- * createSndConn helpers @@ -1810,10 +1843,29 @@ insertSndQueue_ db connId' sq@SndQueue {..} serverKeyHash_ = do DB.execute db [sql| - INSERT OR REPLACE INTO snd_queues - (host, port, snd_id, snd_secure, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); + INSERT INTO snd_queues + (host, port, snd_id, snd_secure, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, + status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version, server_key_hash) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) + ON CONFLICT (host, port, snd_id) DO UPDATE SET + host=EXCLUDED.host, + port=EXCLUDED.port, + snd_id=EXCLUDED.snd_id, + snd_secure=EXCLUDED.snd_secure, + conn_id=EXCLUDED.conn_id, + snd_public_key=EXCLUDED.snd_public_key, + snd_private_key=EXCLUDED.snd_private_key, + e2e_pub_key=EXCLUDED.e2e_pub_key, + e2e_dh_secret=EXCLUDED.e2e_dh_secret, + status=EXCLUDED.status, + snd_queue_id=EXCLUDED.snd_queue_id, + snd_primary=EXCLUDED.snd_primary, + replace_snd_queue_id=EXCLUDED.replace_snd_queue_id, + smp_client_version=EXCLUDED.smp_client_version, + server_key_hash=EXCLUDED.server_key_hash |] - ((host server, port server, sndId, sndSecure, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)) + ((host server, port server, sndId, BI sndSecure, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) + :. (status, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)) pure (sq :: NewSndQueue) {connId = connId', dbQueueId = qId} newQueueId_ :: [Only Int64] -> DBQueueId 'QSStored @@ -1875,8 +1927,8 @@ getConnData db connId' = |] (Only connId') where - cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport) = - (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) + cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) = + (ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO () setConnDeleted db waitDelivery connId @@ -1884,7 +1936,7 @@ setConnDeleted db waitDelivery connId currentTs <- getCurrentTime DB.execute db "UPDATE connections SET deleted_at_wait_delivery = ? WHERE conn_id = ?" (currentTs, connId) | otherwise = - DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId) + DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (BI True, connId) setConnUserId :: DB.Connection -> UserId -> ConnId -> UserId -> IO () setConnUserId db oldUserId connId newUserId = @@ -1899,7 +1951,7 @@ setConnPQSupport db connId pqSupport = DB.execute db "UPDATE connections SET pq_support = ? WHERE conn_id = ?" (pqSupport, connId) getDeletedConnIds :: DB.Connection -> IO [ConnId] -getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" 
(Only True) +getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only (BI True)) getDeletedWaitingDeliveryConnIds :: DB.Connection -> IO [ConnId] getDeletedWaitingDeliveryConnIds db = @@ -1911,7 +1963,7 @@ setConnRatchetSync db connId ratchetSyncState = addProcessedRatchetKeyHash :: DB.Connection -> ConnId -> ByteString -> IO () addProcessedRatchetKeyHash db connId hash = - DB.execute db "INSERT INTO processed_ratchet_key_hashes (conn_id, hash) VALUES (?,?)" (connId, hash) + DB.execute db "INSERT INTO processed_ratchet_key_hashes (conn_id, hash) VALUES (?,?)" (connId, Binary hash) checkRatchetKeyHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool checkRatchetKeyHashExists db connId hash = do @@ -1921,7 +1973,7 @@ checkRatchetKeyHashExists db connId hash = do ( DB.query db "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" - (connId, hash) + (connId, Binary hash) ) deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO () @@ -1933,7 +1985,7 @@ deleteRatchetKeyHashesExpired db ttl = do getRcvQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueue)) getRcvQueuesByConnId_ db connId = L.nonEmpty . sortBy primaryFirst . map toRcvQueue - <$> DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ? AND q.deleted = 0") (Only connId) + <$> DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.deleted = 0") (Only connId) where primaryFirst RcvQueue {primary = p, dbReplaceQueueId = i} RcvQueue {primary = p', dbReplaceQueueId = i'} = -- the current primary queue is ordered first, the next primary - second @@ -1952,11 +2004,11 @@ rcvQueueQuery = |] toRcvQueue :: - (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, SenderCanSecure) - :. (QueueStatus, DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int) + (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, BoolInt) + :. (QueueStatus, DBQueueId 'QSStored, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int) :. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) -> RcvQueue -toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndSecure) :. (status, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = +toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, BI sndSecure) :. (status, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = let server = SMPServer host port keyHash smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of @@ -1973,7 +2025,7 @@ getRcvQueueById db connId dbRcvId = getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue)) getSndQueuesByConnId_ dbConn connId = L.nonEmpty . sortBy primaryFirst . 
map toSndQueue - <$> DB.query dbConn (sndQueueQuery <> "WHERE q.conn_id = ?") (Only connId) + <$> DB.query dbConn (sndQueueQuery <> " WHERE q.conn_id = ?") (Only connId) where primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} = -- the current primary queue is ordered first, the next primary - second @@ -1992,14 +2044,14 @@ sndQueueQuery = |] toSndQueue :: - (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId, SenderCanSecure) + (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId, BoolInt) :. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus) - :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) -> + :. (DBQueueId 'QSStored, BoolInt, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) -> SndQueue toSndQueue - ( (userId, keyHash, connId, host, port, sndId, sndSecure) + ( (userId, keyHash, connId, host, port, sndId, BI sndSecure) :. (sndPubKey, sndPrivateKey@(C.APrivateAuthKey a pk), e2ePubKey, e2eDhSecret, status) - :. (dbQueueId, primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion) + :. (dbQueueId, BI primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion) ) = let server = SMPServer host port keyHash sndPublicKey = fromMaybe (C.APublicAuthKey a (C.publicKey pk)) sndPubKey @@ -2015,30 +2067,27 @@ getSndQueueById db connId dbSndId = retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) retrieveLastIdsAndHashRcv_ dbConn connId = do [(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <- - DB.queryNamed + DB.query dbConn [sql| SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash FROM connections - WHERE conn_id = :conn_id; + WHERE conn_id = ? |] - [":conn_id" := connId] + (Only connId) return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) updateLastIdsRcv_ :: DB.Connection -> ConnId -> InternalId -> InternalRcvId -> IO () updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId = - DB.executeNamed + DB.execute dbConn [sql| UPDATE connections - SET last_internal_msg_id = :last_internal_msg_id, - last_internal_rcv_msg_id = :last_internal_rcv_msg_id - WHERE conn_id = :conn_id; + SET last_internal_msg_id = ?, + last_internal_rcv_msg_id = ? + WHERE conn_id = ? 
|] - [ ":last_internal_msg_id" := newInternalId, - ":last_internal_rcv_msg_id" := newInternalRcvId, - ":conn_id" := connId - ] + (newInternalId, newInternalRcvId, connId) -- * createRcvMsg helpers @@ -2052,12 +2101,12 @@ insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody, (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption) VALUES (?,?,?,?,?,?,?,?,?); |] - (connId, internalId, internalTs, internalRcvId, Nothing :: Maybe Int64, msgType, msgFlags, msgBody, pqEncryption) + (connId, internalId, internalTs, internalRcvId, Nothing :: Maybe Int64, msgType, msgFlags, Binary msgBody, pqEncryption) insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO () insertRcvMsgDetails_ db connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash, encryptedMsgHash} = do let MsgMeta {integrity, recipient, broker, sndMsgId} = msgMeta - DB.executeNamed + DB.execute db [sql| INSERT INTO rcv_messages @@ -2065,69 +2114,50 @@ insertRcvMsgDetails_ db connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, interna broker_id, broker_ts, internal_hash, external_prev_snd_hash, integrity) VALUES - (:conn_id,:rcv_queue_id,:internal_rcv_id,:internal_id,:external_snd_id, - :broker_id,:broker_ts, - :internal_hash,:external_prev_snd_hash,:integrity); + (?,?,?,?,?,?,?,?,?,?) |] - [ ":conn_id" := connId, - ":rcv_queue_id" := dbQueueId, - ":internal_rcv_id" := internalRcvId, - ":internal_id" := fst recipient, - ":external_snd_id" := sndMsgId, - ":broker_id" := fst broker, - ":broker_ts" := snd broker, - ":internal_hash" := internalHash, - ":external_prev_snd_hash" := externalPrevSndHash, - ":integrity" := integrity - ] - DB.execute db "INSERT INTO encrypted_rcv_message_hashes (conn_id, hash) VALUES (?,?)" (connId, encryptedMsgHash) + (connId, dbQueueId, internalRcvId, fst recipient, sndMsgId, Binary (fst broker), snd broker, Binary internalHash, Binary externalPrevSndHash, integrity) + DB.execute db "INSERT INTO encrypted_rcv_message_hashes (conn_id, hash) VALUES (?,?)" (connId, Binary encryptedMsgHash) updateRcvMsgHash :: DB.Connection -> ConnId -> AgentMsgId -> InternalRcvId -> MsgHash -> IO () updateRcvMsgHash db connId sndMsgId internalRcvId internalHash = - DB.executeNamed + DB.execute db -- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved [sql| UPDATE connections - SET last_external_snd_msg_id = :last_external_snd_msg_id, - last_rcv_msg_hash = :last_rcv_msg_hash - WHERE conn_id = :conn_id - AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id; + SET last_external_snd_msg_id = ?, + last_rcv_msg_hash = ? + WHERE conn_id = ? + AND last_internal_rcv_msg_id = ? |] - [ ":last_external_snd_msg_id" := sndMsgId, - ":last_rcv_msg_hash" := internalHash, - ":conn_id" := connId, - ":last_internal_rcv_msg_id" := internalRcvId - ] + (sndMsgId, Binary internalHash, connId, internalRcvId) -- * updateSndIds helpers retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnId -> IO (Either StoreError (InternalId, InternalSndId, PrevSndMsgHash)) retrieveLastIdsAndHashSnd_ dbConn connId = do firstRow id SEConnNotFound $ - DB.queryNamed + DB.query dbConn [sql| SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash FROM connections - WHERE conn_id = :conn_id; + WHERE conn_id = ? 
|] - [":conn_id" := connId] + (Only connId) updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO () updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId = - DB.executeNamed + DB.execute dbConn [sql| UPDATE connections - SET last_internal_msg_id = :last_internal_msg_id, - last_internal_snd_msg_id = :last_internal_snd_msg_id - WHERE conn_id = :conn_id; + SET last_internal_msg_id = ?, + last_internal_snd_msg_id = ? + WHERE conn_id = ? |] - [ ":last_internal_msg_id" := newInternalId, - ":last_internal_snd_msg_id" := newInternalSndId, - ":conn_id" := connId - ] + (newInternalId, newInternalSndId, connId) -- * createSndMsg helpers @@ -2141,40 +2171,32 @@ insertSndMsgBase_ db connId SndMsgData {internalId, internalTs, internalSndId, m VALUES (?,?,?,?,?,?,?,?,?); |] - (connId, internalId, internalTs, Nothing :: Maybe Int64, internalSndId, msgType, msgFlags, msgBody, pqEncryption) + (connId, internalId, internalTs, Nothing :: Maybe Int64, internalSndId, msgType, msgFlags, Binary msgBody, pqEncryption) insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO () insertSndMsgDetails_ dbConn connId SndMsgData {..} = - DB.executeNamed + DB.execute dbConn [sql| INSERT INTO snd_messages ( conn_id, internal_snd_id, internal_id, internal_hash, previous_msg_hash) VALUES - (:conn_id,:internal_snd_id,:internal_id,:internal_hash,:previous_msg_hash); + (?,?,?,?,?) |] - [ ":conn_id" := connId, - ":internal_snd_id" := internalSndId, - ":internal_id" := internalId, - ":internal_hash" := internalHash, - ":previous_msg_hash" := prevMsgHash - ] + (connId, internalSndId, internalId, Binary internalHash, Binary prevMsgHash) updateSndMsgHash :: DB.Connection -> ConnId -> InternalSndId -> MsgHash -> IO () updateSndMsgHash db connId internalSndId internalHash = - DB.executeNamed + DB.execute db -- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved [sql| UPDATE connections - SET last_snd_msg_hash = :last_snd_msg_hash - WHERE conn_id = :conn_id - AND last_internal_snd_msg_id = :last_internal_snd_msg_id; + SET last_snd_msg_hash = ? + WHERE conn_id = ? + AND last_internal_snd_msg_id = ?; |] - [ ":last_snd_msg_hash" := internalHash, - ":conn_id" := connId, - ":last_internal_snd_msg_id" := internalSndId - ] + (Binary internalHash, connId, internalSndId) -- create record with a random ID createWithRandomId :: TVar ChaChaDRG -> (ByteString -> IO ()) -> IO (Either StoreError ByteString) @@ -2189,9 +2211,16 @@ createWithRandomId' gVar create = tryCreate 3 id' <- randomId gVar 12 E.try (create id') >>= \case Right r -> pure $ Right (id', r) - Left e - | SQL.sqlError e == SQL.ErrorConstraint -> tryCreate (n - 1) - | otherwise -> pure . Left . SEInternal $ bshow e + Left e -> handleErr n e +#if defined(dbPostgres) + handleErr n e = case constraintViolation e of + Just _ -> tryCreate (n - 1) + Nothing -> pure . Left . SEInternal $ bshow e +#else + handleErr n e + | SQL.sqlError e == SQL.ErrorConstraint = tryCreate (n - 1) + | otherwise = pure . Left . 
SEInternal $ bshow e +#endif randomId :: TVar ChaChaDRG -> Int -> IO ByteString randomId gVar n = atomically $ U.encode <$> C.randomBytes n gVar @@ -2258,7 +2287,7 @@ insertRcvFile db gVar userId FileDescription {size, digest, key, nonce, chunkSiz DB.execute db "INSERT INTO rcv_files (rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, redirect_id, redirect_entity_id, redirect_digest, redirect_size, approved_relays) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" - ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, fileKey <$> cfArgs, fileNonce <$> cfArgs, RFSReceiving, redirectId_, redirectEntityId_, redirectDigest_, redirectSize_, approvedRelays)) + ((Binary rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, fileKey <$> cfArgs, fileNonce <$> cfArgs, RFSReceiving, redirectId_, Binary <$> redirectEntityId_, redirectDigest_, redirectSize_, BI approvedRelays)) rcvFileId <- liftIO $ insertedRowId db pure (rcvFileEntityId, rcvFileId) @@ -2286,7 +2315,7 @@ getRcvFileByEntityId db rcvFileEntityId = runExceptT $ do getRcvFileIdByEntityId_ :: DB.Connection -> RcvFileId -> IO (Either StoreError DBRcvFileId) getRcvFileIdByEntityId_ db rcvFileEntityId = firstRow fromOnly SEFileNotFound $ - DB.query db "SELECT rcv_file_id FROM rcv_files WHERE rcv_file_entity_id = ?" (Only rcvFileEntityId) + DB.query db "SELECT rcv_file_id FROM rcv_files WHERE rcv_file_entity_id = ?" (Only (Binary rcvFileEntityId)) getRcvFileRedirects :: DB.Connection -> DBRcvFileId -> IO [RcvFile] getRcvFileRedirects db rcvFileId = do @@ -2311,8 +2340,8 @@ getRcvFile db rcvFileId = runExceptT $ do |] (Only rcvFileId) where - toFile :: (RcvFileId, UserId, FileSize Int64, FileDigest, C.SbKey, C.CbNonce, FileSize Word32, FilePath, Maybe FilePath) :. (FilePath, Maybe C.SbKey, Maybe C.CbNonce, RcvFileStatus, Bool, Maybe DBRcvFileId, Maybe RcvFileId, Maybe (FileSize Int64), Maybe FileDigest) -> RcvFile - toFile ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, saveKey_, saveNonce_, status, deleted, redirectDbId, redirectEntityId, redirectSize_, redirectDigest_)) = + toFile :: (RcvFileId, UserId, FileSize Int64, FileDigest, C.SbKey, C.CbNonce, FileSize Word32, FilePath, Maybe FilePath) :. (FilePath, Maybe C.SbKey, Maybe C.CbNonce, RcvFileStatus, BoolInt, Maybe DBRcvFileId, Maybe RcvFileId, Maybe (FileSize Int64), Maybe FileDigest) -> RcvFile + toFile ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. 
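createWithRandomId' needs backend-specific detection of a key collision: the SQLite branch checks for ErrorConstraint, while the PostgreSQL branch inspects the SqlError with constraintViolation from Database.PostgreSQL.Simple.Errors. A rough standalone sketch of a PostgreSQL-side retry loop under that assumption (names are illustrative, not part of the patch):

{-# LANGUAGE LambdaCase #-}

import Control.Exception (try)
import Database.PostgreSQL.Simple (SqlError)
import Database.PostgreSQL.Simple.Errors (constraintViolation)

-- Retry a bounded number of times when the action fails with a constraint
-- violation (e.g. a randomly generated id collided with an existing row);
-- any other SqlError is returned as an error string.
retryOnConstraintViolation :: Int -> IO a -> IO (Either String a)
retryOnConstraintViolation 0 _ = pure $ Left "out of retries"
retryOnConstraintViolation n action =
  try action >>= \case
    Right r -> pure $ Right r
    Left e -> case constraintViolation e of
      Just _ -> retryOnConstraintViolation (n - 1) action
      Nothing -> pure . Left $ show (e :: SqlError)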
(savePath, saveKey_, saveNonce_, status, BI deleted, redirectDbId, redirectEntityId, redirectSize_, redirectDigest_)) = let cfArgs = CFArgs <$> saveKey_ <*> saveNonce_ saveFile = CryptoFile savePath cfArgs redirect = @@ -2355,8 +2384,8 @@ getRcvFile db rcvFileId = runExceptT $ do |] (Only chunkId) where - toReplica :: (Int64, ChunkReplicaId, C.APrivateAuthKey, Bool, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> RcvFileChunkReplica - toReplica (rcvChunkReplicaId, replicaId, replicaKey, received, delay, retries, host, port, keyHash) = + toReplica :: (Int64, ChunkReplicaId, C.APrivateAuthKey, BoolInt, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> RcvFileChunkReplica + toReplica (rcvChunkReplicaId, replicaId, replicaKey, BI received, delay, retries, host, port, keyHash) = let server = XFTPServer host port keyHash in RcvFileChunkReplica {rcvChunkReplicaId, server, replicaId, replicaKey, received, delay, retries} @@ -2450,8 +2479,8 @@ getNextRcvChunkToDownload db server@ProtocolServer {host, port, keyHash} ttl = d |] (Only rcvFileChunkReplicaId) where - toChunk :: ((DBRcvFileId, RcvFileId, UserId, Int64, Int, FileSize Word32, FileDigest, FilePath, Maybe FilePath) :. (Int64, ChunkReplicaId, C.APrivateAuthKey, Bool, Maybe Int64, Int) :. (Bool, Maybe RcvFileId)) -> (RcvFileChunk, Bool, Maybe RcvFileId) - toChunk ((rcvFileId, rcvFileEntityId, userId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath, chunkTmpPath) :. (rcvChunkReplicaId, replicaId, replicaKey, received, delay, retries) :. (approvedRelays, redirectEntityId_)) = + toChunk :: ((DBRcvFileId, RcvFileId, UserId, Int64, Int, FileSize Word32, FileDigest, FilePath, Maybe FilePath) :. (Int64, ChunkReplicaId, C.APrivateAuthKey, BoolInt, Maybe Int64, Int) :. (BoolInt, Maybe RcvFileId)) -> (RcvFileChunk, Bool, Maybe RcvFileId) + toChunk ((rcvFileId, rcvFileEntityId, userId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath, chunkTmpPath) :. (rcvChunkReplicaId, replicaId, replicaKey, BI received, delay, retries) :. (BI approvedRelays, redirectEntityId_)) = ( RcvFileChunk { rcvFileId, rcvFileEntityId, @@ -2551,7 +2580,7 @@ createSndFile db gVar userId (CryptoFile path cfArgs) numRecipients prefixPath k DB.execute db "INSERT INTO snd_files (snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, prefix_path, key, nonce, status, redirect_size, redirect_digest) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)" - ((sndFileEntityId, userId, path, fileKey <$> cfArgs, fileNonce <$> cfArgs, numRecipients) :. (prefixPath, key, nonce, SFSNew, redirectSize_, redirectDigest_)) + ((Binary sndFileEntityId, userId, path, fileKey <$> cfArgs, fileNonce <$> cfArgs, numRecipients) :. (prefixPath, key, nonce, SFSNew, redirectSize_, redirectDigest_)) where (redirectSize_, redirectDigest_) = case redirect_ of @@ -2566,7 +2595,7 @@ getSndFileByEntityId db sndFileEntityId = runExceptT $ do getSndFileIdByEntityId_ :: DB.Connection -> SndFileId -> IO (Either StoreError DBSndFileId) getSndFileIdByEntityId_ db sndFileEntityId = firstRow fromOnly SEFileNotFound $ - DB.query db "SELECT snd_file_id FROM snd_files WHERE snd_file_entity_id = ?" (Only sndFileEntityId) + DB.query db "SELECT snd_file_id FROM snd_files WHERE snd_file_entity_id = ?" 
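Row decoders such as toFile and toReplica above unwrap the BoolInt newtype (BI deleted, BI received) because both backends keep booleans in small integer columns; writes pack the flag the same way, and reads can also go through fromOnlyBI as in getSndFileDeleted below. A hedged sketch of the roundtrip against rcv_file_chunk_replicas.received, assuming Data.Int, Data.Maybe (listToMaybe) and Only are in scope (helper names are illustrative):

-- Illustrative BoolInt roundtrip: booleans are written as 0/1 and read back
-- through the BI wrapper.
setReplicaReceived :: DB.Connection -> Int64 -> Bool -> IO ()
setReplicaReceived db replicaId received =
  DB.execute db "UPDATE rcv_file_chunk_replicas SET received = ? WHERE rcv_file_chunk_replica_id = ?" (BI received, replicaId)

getReplicaReceived :: DB.Connection -> Int64 -> IO (Maybe Bool)
getReplicaReceived db replicaId =
  fmap (\(Only (BI b)) -> b) . listToMaybe
    <$> DB.query db "SELECT received FROM rcv_file_chunk_replicas WHERE rcv_file_chunk_replica_id = ?" (Only replicaId)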
(Only (Binary sndFileEntityId)) getSndFile :: DB.Connection -> DBSndFileId -> IO (Either StoreError SndFile) getSndFile db sndFileId = runExceptT $ do @@ -2586,8 +2615,8 @@ getSndFile db sndFileId = runExceptT $ do |] (Only sndFileId) where - toFile :: (SndFileId, UserId, FilePath, Maybe C.SbKey, Maybe C.CbNonce, Int, Maybe FileDigest, Maybe FilePath, C.SbKey, C.CbNonce) :. (SndFileStatus, Bool, Maybe (FileSize Int64), Maybe FileDigest) -> SndFile - toFile ((sndFileEntityId, userId, srcPath, srcKey_, srcNonce_, numRecipients, digest, prefixPath, key, nonce) :. (status, deleted, redirectSize_, redirectDigest_)) = + toFile :: (SndFileId, UserId, FilePath, Maybe C.SbKey, Maybe C.CbNonce, Int, Maybe FileDigest, Maybe FilePath, C.SbKey, C.CbNonce) :. (SndFileStatus, BoolInt, Maybe (FileSize Int64), Maybe FileDigest) -> SndFile + toFile ((sndFileEntityId, userId, srcPath, srcKey_, srcNonce_, numRecipients, digest, prefixPath, key, nonce) :. (status, BI deleted, redirectSize_, redirectDigest_)) = let cfArgs = CFArgs <$> srcKey_ <*> srcNonce_ srcFile = CryptoFile srcPath cfArgs redirect = RedirectFileInfo <$> redirectSize_ <*> redirectDigest_ @@ -2709,7 +2738,7 @@ deleteSndFile' db sndFileId = getSndFileDeleted :: DB.Connection -> DBSndFileId -> IO Bool getSndFileDeleted db sndFileId = fromMaybe True - <$> maybeFirstRow fromOnly (DB.query db "SELECT deleted FROM snd_files WHERE snd_file_id = ?" (Only sndFileId)) + <$> maybeFirstRow fromOnlyBI (DB.query db "SELECT deleted FROM snd_files WHERE snd_file_id = ?" (Only sndFileId)) createSndFileReplica :: DB.Connection -> SndFileChunk -> NewSndChunkReplica -> IO () createSndFileReplica db SndFileChunk {sndChunkId} = createSndFileReplica_ db sndChunkId diff --git a/src/Simplex/Messaging/Agent/Store/Postgres.hs b/src/Simplex/Messaging/Agent/Store/Postgres.hs index 037ac6bb9..a4c8a52bb 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres.hs @@ -1,42 +1,94 @@ +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} + module Simplex.Messaging.Agent.Store.Postgres ( createDBStore, + defaultSimplexConnectInfo, closeDBStore, - execSQL, + execSQL ) where +import Control.Exception (throwIO) +import Control.Monad (unless, void) +import Data.Functor (($>)) +import Data.String (fromString) import Data.Text (Text) +import Database.PostgreSQL.Simple (ConnectInfo (..), Only (..), defaultConnectInfo) import qualified Database.PostgreSQL.Simple as PSQL +import Database.PostgreSQL.Simple.SqlQQ (sql) +import Simplex.Messaging.Agent.Store.Migrations (migrateSchema) import Simplex.Messaging.Agent.Store.Postgres.Common +import qualified Simplex.Messaging.Agent.Store.Postgres.DB as DB +import Simplex.Messaging.Agent.Store.Postgres.Util (createDBAndUserIfNotExists) import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationError (..)) +import Simplex.Messaging.Util (ifM) +import UnliftIO.Exception (onException) +import UnliftIO.MVar +import UnliftIO.STM + +defaultSimplexConnectInfo :: ConnectInfo +defaultSimplexConnectInfo = + defaultConnectInfo + { connectUser = "simplex", + connectDatabase = "simplex_v6_3_client_db" + } + +-- | Create a new Postgres DBStore with the given connection info, schema name and migrations. +-- This function creates the user and/or database passed in connectInfo if they do not exist +-- (expects the default 'postgres' user and 'postgres' db to exist). 
+-- If passed schema does not exist in connectInfo database, it will be created. +-- Applies necessary migrations to schema. +-- TODO [postgres] authentication / user password, db encryption (?) +createDBStore :: ConnectInfo -> String -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createDBStore connectInfo schema migrations confirmMigrations = do + createDBAndUserIfNotExists connectInfo + st <- connectPostgresStore connectInfo schema + r <- migrateSchema st migrations confirmMigrations `onException` closeDBStore st + case r of + Right () -> pure $ Right st + Left e -> closeDBStore st $> Left e + +connectPostgresStore :: ConnectInfo -> String -> IO DBStore +connectPostgresStore dbConnectInfo schema = do + (dbConn, dbNew) <- connectDB dbConnectInfo schema -- TODO [postgres] analogue for dbBusyLoop? + dbConnection <- newMVar dbConn + dbClosed <- newTVarIO False + pure DBStore {dbConnectInfo, dbConnection, dbNew, dbClosed} --- TODO [postgres] pass db name / ConnectInfo? -createDBStore :: [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore) -createDBStore = undefined +connectDB :: ConnectInfo -> String -> IO (DB.Connection, Bool) +connectDB dbConnectInfo schema = do + db <- PSQL.connect dbConnectInfo + schemaExists <- prepare db `onException` PSQL.close db + let dbNew = not schemaExists + pure (db, dbNew) + where + prepare db = do + void $ PSQL.execute_ db "SET client_min_messages TO WARNING" + [Only schemaExists] <- + PSQL.query + db + [sql| + SELECT EXISTS ( + SELECT 1 FROM pg_catalog.pg_namespace + WHERE nspname = ? + ) + |] + (Only schema) + unless schemaExists $ void $ PSQL.execute_ db (fromString $ "CREATE SCHEMA " <> schema) + void $ PSQL.execute_ db (fromString $ "SET search_path TO " <> schema) + pure schemaExists +-- can share with SQLite closeDBStore :: DBStore -> IO () -closeDBStore = undefined +closeDBStore st@DBStore {dbClosed} = + ifM (readTVarIO dbClosed) (putStrLn "closeDBStore: already closed") $ + withConnection st $ \conn -> do + DB.close conn + atomically $ writeTVar dbClosed True +-- TODO [postgres] not necessary for postgres (used for ExecAgentStoreSQL, ExecChatStoreSQL) execSQL :: PSQL.Connection -> Text -> IO [Text] -execSQL = undefined - --- createDatabaseIfNotExists :: ConnectInfo -> String -> IO () --- createDatabaseIfNotExists defaultConnectInfo targetDbName = do --- -- Connect to the default maintenance database (e.g., postgres) --- bracket (connect defaultConnectInfo) close $ \conn -> do --- -- Check if the database already exists --- [Only dbExists] <- query conn --- [sql| --- SELECT EXISTS ( --- SELECT 1 FROM pg_catalog.pg_database --- WHERE datname = ? --- ) --- |] (Only targetDbName) - --- -- If it doesn't exist, create the database --- if not dbExists --- then do --- putStrLn $ "Creating database: " ++ targetDbName --- execute_ conn (Query $ "CREATE DATABASE " <> targetDbName) --- putStrLn "Database created." --- else putStrLn "Database already exists." +execSQL _db _query = throwIO (userError "not implemented") diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs index bdcb63bbd..b23dcf9c8 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs @@ -22,7 +22,7 @@ data DBStore = DBStore dbNew :: Bool } --- TODO [postgres] semaphore / connection pool? 
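Taken together, createDBAndUserIfNotExists, connectDB and migrateSchema give the Postgres store the same one-call setup the SQLite store has. A sketch of how a client might open a store with this API; the schema name is illustrative, and MCYesUp is assumed to be the auto-confirm value of MigrationConfirmation:

import Simplex.Messaging.Agent.Store.Postgres (createDBStore, defaultSimplexConnectInfo)
import Simplex.Messaging.Agent.Store.Postgres.Common (DBStore)
import qualified Simplex.Messaging.Agent.Store.Postgres.Migrations as Migrations
import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError)

-- Create the role/database if needed, connect, create the schema and run the
-- application migrations, confirming pending "up" migrations automatically.
openClientStore :: IO (Either MigrationError DBStore)
openClientStore =
  createDBStore defaultSimplexConnectInfo "smp_agent_schema" Migrations.app MCYesUp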
+-- TODO [postgres] connection pool withConnectionPriority :: DBStore -> Bool -> (PSQL.Connection -> IO a) -> IO a withConnectionPriority DBStore {dbConnection} _priority action = withMVar dbConnection action diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs b/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs index 3fb8f593e..9e597aef7 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs @@ -1,5 +1,9 @@ +{-# LANGUAGE ScopedTypeVariables #-} + module Simplex.Messaging.Agent.Store.Postgres.DB - ( PSQL.Connection, + ( BoolInt (..), + PSQL.Binary (..), + PSQL.Connection, PSQL.connect, PSQL.close, execute, @@ -11,7 +15,22 @@ module Simplex.Messaging.Agent.Store.Postgres.DB where import Control.Monad (void) +import Data.Int (Int32, Int64) +import Data.Word (Word16, Word32) +import Database.PostgreSQL.Simple (ResultError (..)) import qualified Database.PostgreSQL.Simple as PSQL +import Database.PostgreSQL.Simple.FromField (FromField (..), returnError) +import Database.PostgreSQL.Simple.ToField (ToField (..)) + +newtype BoolInt = BI {unBI :: Bool} + +instance FromField BoolInt where + fromField field dat = BI . (/= (0 :: Int)) <$> fromField field dat + {-# INLINE fromField #-} + +instance ToField BoolInt where + toField (BI b) = toField ((if b then 1 else 0) :: Int) + {-# INLINE toField #-} execute :: PSQL.ToRow q => PSQL.Connection -> PSQL.Query -> q -> IO () execute db q qs = void $ PSQL.execute db q qs @@ -24,3 +43,21 @@ execute_ db q = void $ PSQL.execute_ db q executeMany :: PSQL.ToRow q => PSQL.Connection -> PSQL.Query -> [q] -> IO () executeMany db q qs = void $ PSQL.executeMany db q qs {-# INLINE executeMany #-} + +-- orphan instances + +-- used in FileSize +instance FromField Word32 where + fromField field dat = do + i <- fromField field dat + if i >= (0 :: Int64) + then pure (fromIntegral i :: Word32) + else returnError ConversionFailed field "Negative value can't be converted to Word32" + +-- used in Version +instance FromField Word16 where + fromField field dat = do + i <- fromField field dat + if i >= (0 :: Int32) + then pure (fromIntegral i :: Word16) + else returnError ConversionFailed field "Negative value can't be converted to Word16" diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs index ed44c09b6..bf8d56caa 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs @@ -1,5 +1,8 @@ +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TupleSections #-} module Simplex.Messaging.Agent.Store.Postgres.Migrations ( app, @@ -9,13 +12,21 @@ module Simplex.Messaging.Agent.Store.Postgres.Migrations ) where +import Control.Monad (void) import Data.List (sortOn) import Data.Text (Text) import qualified Data.Text as T +import qualified Data.Text.Encoding as TE +import Data.Time.Clock (getCurrentTime) +import qualified Database.PostgreSQL.LibPQ as LibPQ +import Database.PostgreSQL.Simple (Only (..)) import qualified Database.PostgreSQL.Simple as PSQL +import Database.PostgreSQL.Simple.Internal (Connection (..)) +import Database.PostgreSQL.Simple.SqlQQ (sql) import Simplex.Messaging.Agent.Store.Postgres.Common import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20241210_initial import Simplex.Messaging.Agent.Store.Shared +import UnliftIO.MVar schemaMigrations :: [(String, Text, 
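The orphan FromField instances above let Word32 and Word16 fields (used by FileSize and Version, per the comments) be decoded from Postgres integer columns, rejecting negative values instead of silently wrapping. A hedged sketch of a read that relies on the Word32 instance (helper name and query are illustrative):

import Data.Int (Int64)
import Data.Maybe (listToMaybe)
import Data.Word (Word32)
import Database.PostgreSQL.Simple (Only (..))
import qualified Database.PostgreSQL.Simple as PSQL

-- chunk_size is a BIGINT column; the orphan instance decodes it as Int64,
-- checks it is non-negative, and narrows it to Word32.
getChunkSize :: PSQL.Connection -> Int64 -> IO (Maybe Word32)
getChunkSize db chunkId =
  fmap fromOnly . listToMaybe
    <$> PSQL.query db "SELECT chunk_size FROM rcv_file_chunks WHERE rcv_file_chunk_id = ?" (Only chunkId)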
Maybe Text)] schemaMigrations = @@ -28,13 +39,38 @@ app = sortOn name $ map migration schemaMigrations where migration (name, up, down) = Migration {name, up, down = down} --- TODO [postgres] initialize initialize :: DBStore -> IO () -initialize st = undefined +initialize st = withTransaction' st $ \db -> + void $ + PSQL.execute_ + db + [sql| + CREATE TABLE IF NOT EXISTS migrations ( + name TEXT NOT NULL, + ts TIMESTAMP NOT NULL, + down TEXT, + PRIMARY KEY (name) + ) + |] --- TODO [postgres] run run :: DBStore -> MigrationsToRun -> IO () -run st = undefined +run st = \case + MTRUp [] -> pure () + MTRUp ms -> mapM_ runUp ms + MTRDown ms -> mapM_ runDown $ reverse ms + MTRNone -> pure () + where + runUp Migration {name, up, down} = withTransaction' st $ \db -> do + insert db + execSQL db up + where + insert db = void $ PSQL.execute db "INSERT INTO migrations (name, down, ts) VALUES (?,?,?)" . (name,down,) =<< getCurrentTime + runDown DownMigration {downName, downQuery} = withTransaction' st $ \db -> do + execSQL db downQuery + void $ PSQL.execute db "DELETE FROM migrations WHERE name = ?" (Only downName) + execSQL db query = + withMVar (connectionHandle db) $ \pqConn -> + void $ LibPQ.exec pqConn (TE.encodeUtf8 query) getCurrent :: PSQL.Connection -> IO [Migration] getCurrent db = map toMigration <$> PSQL.query_ db "SELECT name, down FROM migrations ORDER BY name ASC;" diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs index 3fab63768..a68144f1f 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs @@ -10,19 +10,10 @@ m20241210_initial :: Text m20241210_initial = T.pack [r| -REVOKE CREATE ON SCHEMA public FROM PUBLIC; - --- TODO [postgres] remove -DROP SCHEMA IF EXISTS agent_schema CASCADE; -CREATE SCHEMA agent_schema; - -SET search_path TO agent_schema; - CREATE TABLE users( - user_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - deleted INTEGER NOT NULL DEFAULT 0 + user_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + deleted SMALLINT NOT NULL DEFAULT 0 ); -INSERT INTO users (user_id) OVERRIDING SYSTEM VALUE VALUES (1); CREATE TABLE servers( host TEXT NOT NULL, port TEXT NOT NULL, @@ -32,21 +23,20 @@ CREATE TABLE servers( CREATE TABLE connections( conn_id BYTEA NOT NULL PRIMARY KEY, conn_mode TEXT NOT NULL, - last_internal_msg_id INTEGER NOT NULL DEFAULT 0, - last_internal_rcv_msg_id INTEGER NOT NULL DEFAULT 0, - last_internal_snd_msg_id INTEGER NOT NULL DEFAULT 0, - last_external_snd_msg_id INTEGER NOT NULL DEFAULT 0, - last_rcv_msg_hash BYTEA NOT NULL DEFAULT E'\\x', - last_snd_msg_hash BYTEA NOT NULL DEFAULT E'\\x', + last_internal_msg_id BIGINT NOT NULL DEFAULT 0, + last_internal_rcv_msg_id BIGINT NOT NULL DEFAULT 0, + last_internal_snd_msg_id BIGINT NOT NULL DEFAULT 0, + last_external_snd_msg_id BIGINT NOT NULL DEFAULT 0, + last_rcv_msg_hash BYTEA NOT NULL DEFAULT ''::BYTEA, + last_snd_msg_hash BYTEA NOT NULL DEFAULT ''::BYTEA, smp_agent_version INTEGER NOT NULL DEFAULT 1, - duplex_handshake INTEGER NULL DEFAULT 0, - enable_ntfs INTEGER, - deleted INTEGER NOT NULL DEFAULT 0, - user_id INTEGER NOT NULL DEFAULT 1 - REFERENCES users ON DELETE CASCADE, + duplex_handshake SMALLINT NULL DEFAULT 0, + enable_ntfs SMALLINT, + deleted SMALLINT NOT NULL DEFAULT 0, + user_id BIGINT NOT NULL REFERENCES users ON DELETE CASCADE, ratchet_sync_state TEXT NOT NULL 
DEFAULT 'ok', - deleted_at_wait_delivery TIMESTAMP, - pq_support INTEGER NOT NULL DEFAULT 0 + deleted_at_wait_delivery TIMESTAMPTZ, + pq_support SMALLINT NOT NULL DEFAULT 0 ); CREATE TABLE rcv_queues( host TEXT NOT NULL, @@ -66,15 +56,15 @@ CREATE TABLE rcv_queues( ntf_private_key BYTEA, ntf_id BYTEA, rcv_ntf_dh_secret BYTEA, - rcv_queue_id INTEGER NOT NULL, - rcv_primary INTEGER NOT NULL, - replace_rcv_queue_id INTEGER NULL, - delete_errors INTEGER NOT NULL DEFAULT 0, + rcv_queue_id BIGINT NOT NULL, + rcv_primary SMALLINT NOT NULL, + replace_rcv_queue_id BIGINT NULL, + delete_errors BIGINT NOT NULL DEFAULT 0, server_key_hash BYTEA, switch_status TEXT, - deleted INTEGER NOT NULL DEFAULT 0, - snd_secure INTEGER NOT NULL DEFAULT 0, - last_broker_ts TIMESTAMP, + deleted SMALLINT NOT NULL DEFAULT 0, + snd_secure SMALLINT NOT NULL DEFAULT 0, + last_broker_ts TIMESTAMPTZ, PRIMARY KEY(host, port, rcv_id), FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT ON UPDATE CASCADE, @@ -92,12 +82,12 @@ CREATE TABLE snd_queues( smp_client_version INTEGER NOT NULL DEFAULT 1, snd_public_key BYTEA, e2e_pub_key BYTEA, - snd_queue_id INTEGER NOT NULL, - snd_primary INTEGER NOT NULL, - replace_snd_queue_id INTEGER NULL, + snd_queue_id BIGINT NOT NULL, + snd_primary SMALLINT NOT NULL, + replace_snd_queue_id BIGINT NULL, server_key_hash BYTEA, switch_status TEXT, - snd_secure INTEGER NOT NULL DEFAULT 0, + snd_secure SMALLINT NOT NULL DEFAULT 0, PRIMARY KEY(host, port, snd_id), FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT ON UPDATE CASCADE @@ -105,28 +95,28 @@ CREATE TABLE snd_queues( CREATE TABLE messages( conn_id BYTEA NOT NULL REFERENCES connections(conn_id) ON DELETE CASCADE, - internal_id INTEGER NOT NULL, - internal_ts TIMESTAMP NOT NULL, - internal_rcv_id INTEGER, - internal_snd_id INTEGER, + internal_id BIGINT NOT NULL, + internal_ts TIMESTAMPTZ NOT NULL, + internal_rcv_id BIGINT, + internal_snd_id BIGINT, msg_type BYTEA NOT NULL, - msg_body BYTEA NOT NULL DEFAULT E'\\x', + msg_body BYTEA NOT NULL DEFAULT ''::BYTEA, msg_flags TEXT NULL, - pq_encryption INTEGER NOT NULL DEFAULT 0, + pq_encryption SMALLINT NOT NULL DEFAULT 0, PRIMARY KEY(conn_id, internal_id) ); CREATE TABLE rcv_messages( conn_id BYTEA NOT NULL, - internal_rcv_id INTEGER NOT NULL, - internal_id INTEGER NOT NULL, - external_snd_id INTEGER NOT NULL, + internal_rcv_id BIGINT NOT NULL, + internal_id BIGINT NOT NULL, + external_snd_id BIGINT NOT NULL, broker_id BYTEA NOT NULL, - broker_ts TIMESTAMP NOT NULL, + broker_ts TIMESTAMPTZ NOT NULL, internal_hash BYTEA NOT NULL, external_prev_snd_hash BYTEA NOT NULL, integrity BYTEA NOT NULL, - user_ack INTEGER NULL DEFAULT 0, - rcv_queue_id INTEGER NOT NULL, + user_ack SMALLINT NULL DEFAULT 0, + rcv_queue_id BIGINT NOT NULL, PRIMARY KEY(conn_id, internal_rcv_id), FOREIGN KEY(conn_id, internal_id) REFERENCES messages ON DELETE CASCADE @@ -137,13 +127,13 @@ ADD CONSTRAINT fk_messages_rcv_messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED; CREATE TABLE snd_messages( conn_id BYTEA NOT NULL, - internal_snd_id INTEGER NOT NULL, - internal_id INTEGER NOT NULL, + internal_snd_id BIGINT NOT NULL, + internal_id BIGINT NOT NULL, internal_hash BYTEA NOT NULL, - previous_msg_hash BYTEA NOT NULL DEFAULT E'\\x', - retry_int_slow INTEGER, - retry_int_fast INTEGER, - rcpt_internal_id INTEGER, + previous_msg_hash BYTEA NOT NULL DEFAULT ''::BYTEA, + retry_int_slow BIGINT, + retry_int_fast BIGINT, + rcpt_internal_id BIGINT, rcpt_status TEXT, PRIMARY KEY(conn_id, internal_snd_id), FOREIGN 
KEY(conn_id, internal_id) REFERENCES messages @@ -160,9 +150,9 @@ CREATE TABLE conn_confirmations( sender_key BYTEA, ratchet_state BYTEA NOT NULL, sender_conn_info BYTEA NOT NULL, - accepted INTEGER NOT NULL, + accepted SMALLINT NOT NULL, own_conn_info BYTEA, - created_at TIMESTAMP NOT NULL DEFAULT (now()), + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), smp_reply_queues BYTEA NULL, smp_client_version INTEGER ); @@ -171,9 +161,9 @@ CREATE TABLE conn_invitations( contact_conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, cr_invitation BYTEA NOT NULL, recipient_conn_info BYTEA NOT NULL, - accepted INTEGER NOT NULL DEFAULT 0, + accepted SMALLINT NOT NULL DEFAULT 0, own_conn_info BYTEA, - created_at TIMESTAMP NOT NULL DEFAULT (now()) + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()) ); CREATE TABLE ratchets( conn_id BYTEA NOT NULL PRIMARY KEY REFERENCES connections @@ -187,19 +177,19 @@ CREATE TABLE ratchets( pq_priv_kem BYTEA ); CREATE TABLE skipped_messages( - skipped_message_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + skipped_message_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, conn_id BYTEA NOT NULL REFERENCES ratchets ON DELETE CASCADE, header_key BYTEA NOT NULL, - msg_n INTEGER NOT NULL, + msg_n BIGINT NOT NULL, msg_key BYTEA NOT NULL ); CREATE TABLE ntf_servers( ntf_host TEXT NOT NULL, ntf_port TEXT NOT NULL, ntf_key_hash BYTEA NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()), + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), PRIMARY KEY(ntf_host, ntf_port) ); CREATE TABLE ntf_tokens( @@ -209,14 +199,14 @@ CREATE TABLE ntf_tokens( ntf_port TEXT NOT NULL, tkn_id BYTEA, tkn_pub_key BYTEA NOT NULL, -tkn_priv_key BYTEA NOT NULL, -tkn_pub_dh_key BYTEA NOT NULL, -tkn_priv_dh_key BYTEA NOT NULL, -tkn_dh_secret BYTEA, + tkn_priv_key BYTEA NOT NULL, + tkn_pub_dh_key BYTEA NOT NULL, + tkn_priv_dh_key BYTEA NOT NULL, + tkn_dh_secret BYTEA, tkn_status TEXT NOT NULL, tkn_action BYTEA, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()), + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), ntf_mode TEXT NULL, PRIMARY KEY(provider, device_token, ntf_host, ntf_port), FOREIGN KEY(ntf_host, ntf_port) REFERENCES ntf_servers @@ -233,13 +223,13 @@ CREATE TABLE ntf_subscriptions( ntf_sub_status TEXT NOT NULL, ntf_sub_action TEXT, ntf_sub_smp_action TEXT, - ntf_sub_action_ts TIMESTAMP, - updated_by_supervisor INTEGER NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()), + ntf_sub_action_ts TIMESTAMPTZ, + updated_by_supervisor SMALLINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), smp_server_key_hash BYTEA, - ntf_failed INTEGER DEFAULT 0, - smp_failed INTEGER DEFAULT 0, + ntf_failed SMALLINT DEFAULT 0, + smp_failed SMALLINT DEFAULT 0, PRIMARY KEY(conn_id), FOREIGN KEY(smp_host, smp_port) REFERENCES servers(host, port) ON DELETE SET NULL ON UPDATE CASCADE, @@ -247,7 +237,7 @@ CREATE TABLE ntf_subscriptions( ON DELETE RESTRICT ON UPDATE CASCADE ); CREATE TABLE commands( - command_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + command_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, host TEXT, port TEXT, @@ -256,173 +246,174 @@ CREATE TABLE commands( command BYTEA 
NOT NULL, agent_version INTEGER NOT NULL DEFAULT 1, server_key_hash BYTEA, - created_at TIMESTAMP NOT NULL DEFAULT '1970-01-01 00:00:00', - failed INTEGER DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT '1970-01-01 00:00:00', + failed SMALLINT DEFAULT 0, FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT ON UPDATE CASCADE ); CREATE TABLE snd_message_deliveries( - snd_message_delivery_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + snd_message_delivery_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, - snd_queue_id INTEGER NOT NULL, - internal_id INTEGER NOT NULL, - failed INTEGER DEFAULT 0, + snd_queue_id BIGINT NOT NULL, + internal_id BIGINT NOT NULL, + failed SMALLINT DEFAULT 0, FOREIGN KEY(conn_id, internal_id) REFERENCES messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED ); CREATE TABLE xftp_servers( - xftp_server_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + xftp_server_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, xftp_host TEXT NOT NULL, xftp_port TEXT NOT NULL, xftp_key_hash BYTEA NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()), + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), UNIQUE(xftp_host, xftp_port, xftp_key_hash) ); CREATE TABLE rcv_files( - rcv_file_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + rcv_file_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, rcv_file_entity_id BYTEA NOT NULL, - user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE, - size INTEGER NOT NULL, + user_id BIGINT NOT NULL REFERENCES users ON DELETE CASCADE, + size BIGINT NOT NULL, digest BYTEA NOT NULL, key BYTEA NOT NULL, nonce BYTEA NOT NULL, - chunk_size INTEGER NOT NULL, + chunk_size BIGINT NOT NULL, prefix_path TEXT NOT NULL, tmp_path TEXT, save_path TEXT NOT NULL, status TEXT NOT NULL, - deleted INTEGER NOT NULL DEFAULT 0, + deleted SMALLINT NOT NULL DEFAULT 0, error TEXT, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()), + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), save_file_key BYTEA, save_file_nonce BYTEA, - failed INTEGER DEFAULT 0, - redirect_id INTEGER REFERENCES rcv_files ON DELETE SET NULL, + failed SMALLINT DEFAULT 0, + redirect_id BIGINT REFERENCES rcv_files ON DELETE SET NULL, redirect_entity_id BYTEA, - redirect_size INTEGER, + redirect_size BIGINT, redirect_digest BYTEA, - approved_relays INTEGER NOT NULL DEFAULT 0, + approved_relays SMALLINT NOT NULL DEFAULT 0, UNIQUE(rcv_file_entity_id) ); CREATE TABLE rcv_file_chunks( - rcv_file_chunk_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - rcv_file_id INTEGER NOT NULL REFERENCES rcv_files ON DELETE CASCADE, - chunk_no INTEGER NOT NULL, - chunk_size INTEGER NOT NULL, + rcv_file_chunk_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + rcv_file_id BIGINT NOT NULL REFERENCES rcv_files ON DELETE CASCADE, + chunk_no BIGINT NOT NULL, + chunk_size BIGINT NOT NULL, digest BYTEA NOT NULL, tmp_path TEXT, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()) + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) ); CREATE TABLE rcv_file_chunk_replicas( - rcv_file_chunk_replica_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - rcv_file_chunk_id INTEGER NOT NULL REFERENCES rcv_file_chunks ON DELETE CASCADE, - 
replica_number INTEGER NOT NULL, - xftp_server_id INTEGER NOT NULL REFERENCES xftp_servers ON DELETE CASCADE, + rcv_file_chunk_replica_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + rcv_file_chunk_id BIGINT NOT NULL REFERENCES rcv_file_chunks ON DELETE CASCADE, + replica_number BIGINT NOT NULL, + xftp_server_id BIGINT NOT NULL REFERENCES xftp_servers ON DELETE CASCADE, replica_id BYTEA NOT NULL, replica_key BYTEA NOT NULL, - received INTEGER NOT NULL DEFAULT 0, - delay INTEGER, - retries INTEGER NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()) + received SMALLINT NOT NULL DEFAULT 0, + delay BIGINT, + retries BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) ); CREATE TABLE snd_files( - snd_file_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + snd_file_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, snd_file_entity_id BYTEA NOT NULL, - user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE, - num_recipients INTEGER NOT NULL, + user_id BIGINT NOT NULL REFERENCES users ON DELETE CASCADE, + num_recipients BIGINT NOT NULL, digest BYTEA, key BYTEA NOT NUll, nonce BYTEA NOT NUll, path TEXT NOT NULL, prefix_path TEXT, status TEXT NOT NULL, - deleted INTEGER NOT NULL DEFAULT 0, + deleted SMALLINT NOT NULL DEFAULT 0, error TEXT, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()), + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), src_file_key BYTEA, src_file_nonce BYTEA, - failed INTEGER DEFAULT 0, - redirect_size INTEGER, + failed SMALLINT DEFAULT 0, + redirect_size BIGINT, redirect_digest BYTEA ); CREATE TABLE snd_file_chunks( - snd_file_chunk_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - snd_file_id INTEGER NOT NULL REFERENCES snd_files ON DELETE CASCADE, - chunk_no INTEGER NOT NULL, - chunk_offset INTEGER NOT NULL, - chunk_size INTEGER NOT NULL, + snd_file_chunk_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + snd_file_id BIGINT NOT NULL REFERENCES snd_files ON DELETE CASCADE, + chunk_no BIGINT NOT NULL, + chunk_offset BIGINT NOT NULL, + chunk_size BIGINT NOT NULL, digest BYTEA NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()) + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) ); CREATE TABLE snd_file_chunk_replicas( - snd_file_chunk_replica_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - snd_file_chunk_id INTEGER NOT NULL REFERENCES snd_file_chunks ON DELETE CASCADE, - replica_number INTEGER NOT NULL, - xftp_server_id INTEGER NOT NULL REFERENCES xftp_servers ON DELETE CASCADE, + snd_file_chunk_replica_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + snd_file_chunk_id BIGINT NOT NULL REFERENCES snd_file_chunks ON DELETE CASCADE, + replica_number BIGINT NOT NULL, + xftp_server_id BIGINT NOT NULL REFERENCES xftp_servers ON DELETE CASCADE, replica_id BYTEA NOT NULL, replica_key BYTEA NOT NULL, replica_status TEXT NOT NULL, - delay INTEGER, - retries INTEGER NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()) + delay BIGINT, + retries BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) ); CREATE TABLE snd_file_chunk_replica_recipients( - 
snd_file_chunk_replica_recipient_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - snd_file_chunk_replica_id INTEGER NOT NULL REFERENCES snd_file_chunk_replicas ON DELETE CASCADE, + snd_file_chunk_replica_recipient_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + snd_file_chunk_replica_id BIGINT NOT NULL REFERENCES snd_file_chunk_replicas ON DELETE CASCADE, rcv_replica_id BYTEA NOT NULL, rcv_replica_key BYTEA NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()) + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) ); CREATE TABLE deleted_snd_chunk_replicas( - deleted_snd_chunk_replica_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - user_id INTEGER NOT NULL REFERENCES users ON DELETE CASCADE, - xftp_server_id INTEGER NOT NULL REFERENCES xftp_servers ON DELETE CASCADE, + deleted_snd_chunk_replica_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + user_id BIGINT NOT NULL REFERENCES users ON DELETE CASCADE, + xftp_server_id BIGINT NOT NULL REFERENCES xftp_servers ON DELETE CASCADE, replica_id BYTEA NOT NULL, replica_key BYTEA NOT NULL, chunk_digest BYTEA NOT NULL, - delay INTEGER, - retries INTEGER NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()), - failed INTEGER DEFAULT 0 + delay BIGINT, + retries BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + failed SMALLINT DEFAULT 0 ); CREATE TABLE encrypted_rcv_message_hashes( - encrypted_rcv_message_hash_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + encrypted_rcv_message_hash_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, hash BYTEA NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()) + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) ); CREATE TABLE processed_ratchet_key_hashes( - processed_ratchet_key_hash_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + processed_ratchet_key_hash_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, hash BYTEA NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()) + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) ); CREATE TABLE servers_stats( - servers_stats_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + servers_stats_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, servers_stats TEXT, - started_at TIMESTAMP NOT NULL DEFAULT (now()), - created_at TIMESTAMP NOT NULL DEFAULT (now()), - updated_at TIMESTAMP NOT NULL DEFAULT (now()) + started_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) ); +INSERT INTO servers_stats DEFAULT VALUES; CREATE TABLE ntf_tokens_to_delete( - ntf_token_to_delete_id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + ntf_token_to_delete_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, ntf_host TEXT NOT NULL, ntf_port TEXT NOT NULL, ntf_key_hash BYTEA NOT NULL, tkn_id BYTEA NOT NULL, tkn_priv_key BYTEA NOT NULL, -del_failed INTEGER DEFAULT 0, -created_at TIMESTAMP NOT NULL DEFAULT (now()) + del_failed SMALLINT DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()) 
); CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues(host, port, ntf_id); CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues(conn_id, rcv_queue_id); diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Util.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Util.hs new file mode 100644 index 000000000..35aa84d86 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Util.hs @@ -0,0 +1,111 @@ +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Simplex.Messaging.Agent.Store.Postgres.Util + ( createDBAndUserIfNotExists, + -- for tests + dropSchema, + dropAllSchemasExceptSystem, + dropDatabaseAndUser, + ) +where + +import Control.Exception (bracket, throwIO) +import Control.Monad (forM_, unless, void, when) +import Data.Functor (($>)) +import Data.String (fromString) +import Data.Text (Text) +import Database.PostgreSQL.Simple (ConnectInfo (..), Only (..), defaultConnectInfo) +import qualified Database.PostgreSQL.Simple as PSQL +import Database.PostgreSQL.Simple.SqlQQ (sql) +import Simplex.Messaging.Agent.Store.Migrations (migrateSchema) +import Simplex.Messaging.Agent.Store.Postgres.Common +import qualified Simplex.Messaging.Agent.Store.Postgres.DB as DB +import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationError (..)) +import Simplex.Messaging.Util (ifM) +import UnliftIO.Exception (onException) +import UnliftIO.MVar +import UnliftIO.STM + +createDBAndUserIfNotExists :: ConnectInfo -> IO () +createDBAndUserIfNotExists ConnectInfo {connectUser = user, connectDatabase = dbName} = do + -- connect to the default "postgres" maintenance database + bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ + \postgresDB -> do + void $ PSQL.execute_ postgresDB "SET client_min_messages TO WARNING" + -- check if the user exists, create if not + [Only userExists] <- + PSQL.query + postgresDB + [sql| + SELECT EXISTS ( + SELECT 1 FROM pg_catalog.pg_roles + WHERE rolname = ? + ) + |] + (Only user) + unless userExists $ void $ PSQL.execute_ postgresDB (fromString $ "CREATE USER " <> user) + -- check if the database exists, create if not + dbExists <- checkDBExists postgresDB dbName + unless dbExists $ void $ PSQL.execute_ postgresDB (fromString $ "CREATE DATABASE " <> dbName <> " OWNER " <> user) + +checkDBExists :: PSQL.Connection -> String -> IO Bool +checkDBExists postgresDB dbName = do + [Only dbExists] <- + PSQL.query + postgresDB + [sql| + SELECT EXISTS ( + SELECT 1 FROM pg_catalog.pg_database + WHERE datname = ? 
+ ) + |] + (Only dbName) + pure dbExists + +dropSchema :: ConnectInfo -> String -> IO () +dropSchema connectInfo schema = + bracket (PSQL.connect connectInfo) PSQL.close $ + \db -> do + void $ PSQL.execute_ db "SET client_min_messages TO WARNING" + void $ PSQL.execute_ db (fromString $ "DROP SCHEMA IF EXISTS " <> schema <> " CASCADE") + +dropAllSchemasExceptSystem :: ConnectInfo -> IO () +dropAllSchemasExceptSystem connectInfo = + bracket (PSQL.connect connectInfo) PSQL.close $ + \db -> do + void $ PSQL.execute_ db "SET client_min_messages TO WARNING" + schemaNames :: [Only String] <- + PSQL.query_ + db + [sql| + SELECT schema_name + FROM information_schema.schemata + WHERE schema_name NOT IN ('public', 'pg_catalog', 'information_schema') + |] + forM_ schemaNames $ \(Only schema) -> + PSQL.execute_ db (fromString $ "DROP SCHEMA " <> schema <> " CASCADE") + +dropDatabaseAndUser :: ConnectInfo -> IO () +dropDatabaseAndUser ConnectInfo {connectUser = user, connectDatabase = dbName} = + bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ + \postgresDB -> do + void $ PSQL.execute_ postgresDB "SET client_min_messages TO WARNING" + dbExists <- checkDBExists postgresDB dbName + when dbExists $ do + void $ PSQL.execute_ postgresDB (fromString $ "ALTER DATABASE " <> dbName <> " WITH ALLOW_CONNECTIONS false") + -- terminate all connections to the database + _r :: [Only Bool] <- + PSQL.query + postgresDB + [sql| + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE datname = ? + AND pid <> pg_backend_pid() + |] + (Only dbName) + void $ PSQL.execute_ postgresDB (fromString $ "DROP DATABASE " <> dbName) + void $ PSQL.execute_ postgresDB (fromString $ "DROP USER IF EXISTS " <> user) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index b0c53dee2..816968208 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -26,14 +26,17 @@ module Simplex.Messaging.Agent.Store.SQLite ( createDBStore, - connectSQLiteStore, closeDBStore, - openSQLiteStore, - reopenSQLiteStore, + execSQL, + -- used in Simplex.Chat.Archive sqlString, keyString, storeKey, - execSQL, + -- used in Simplex.Chat.Mobile and tests + reopenSQLiteStore, + -- used in tests + connectSQLiteStore, + openSQLiteStore, ) where diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs index a5d5189d0..7e8406d5c 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs @@ -1,21 +1,23 @@ {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE TemplateHaskell #-} module Simplex.Messaging.Agent.Store.SQLite.DB - ( Connection (..), + ( BoolInt (..), + Binary (..), + Connection (..), SlowQueryStats (..), open, close, execute, execute_, - executeNamed, executeMany, query, query_, - queryNamed, ) where @@ -23,18 +25,27 @@ import Control.Concurrent.STM import Control.Exception import Control.Monad (when) import qualified Data.Aeson.TH as J +import Data.ByteString (ByteString) import Data.Int (Int64) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Text (Text) import Data.Time (diffUTCTime, getCurrentTime) -import Database.SQLite.Simple (FromRow, NamedParam, Query, ToRow) +import 
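The Util module gives tests full lifecycle control: createDBAndUserIfNotExists before a run, dropAllSchemasExceptSystem between tests, and dropDatabaseAndUser (which first blocks and terminates remaining connections) at the end. A hedged sketch of a test wrapper built on these, with an illustrative role and database name:

import Control.Exception (finally)
import Database.PostgreSQL.Simple (ConnectInfo (..), defaultConnectInfo)
import Simplex.Messaging.Agent.Store.Postgres.Util (createDBAndUserIfNotExists, dropDatabaseAndUser)

-- Create a throwaway role and database for the test run and remove both
-- afterwards, even if the tests throw.
withTestPostgresDB :: IO () -> IO ()
withTestPostgresDB run = do
  createDBAndUserIfNotExists testConnectInfo
  run `finally` dropDatabaseAndUser testConnectInfo
  where
    testConnectInfo = defaultConnectInfo {connectUser = "test_agent_user", connectDatabase = "test_agent_db"}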
Database.SQLite.Simple (FromRow, Query, ToRow) import qualified Database.SQLite.Simple as SQL +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Parsers (defaultJSON) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (diffToMilliseconds, tshow) +newtype BoolInt = BI {unBI :: Bool} + deriving newtype (FromField, ToField) + +newtype Binary = Binary {fromBinary :: ByteString} + deriving newtype (FromField, ToField) + data Connection = Connection { conn :: SQL.Connection, slow :: TMap Query SlowQueryStats @@ -92,11 +103,6 @@ execute_ :: Connection -> Query -> IO () execute_ Connection {conn, slow} sql = timeIt slow sql $ SQL.execute_ conn sql {-# INLINE execute_ #-} --- TODO [postgres] remove -executeNamed :: Connection -> Query -> [NamedParam] -> IO () -executeNamed Connection {conn, slow} sql = timeIt slow sql . SQL.executeNamed conn sql -{-# INLINE executeNamed #-} - executeMany :: ToRow q => Connection -> Query -> [q] -> IO () executeMany Connection {conn, slow} sql = timeIt slow sql . SQL.executeMany conn sql {-# INLINE executeMany #-} @@ -109,9 +115,4 @@ query_ :: FromRow r => Connection -> Query -> IO [r] query_ Connection {conn, slow} sql = timeIt slow sql $ SQL.query_ conn sql {-# INLINE query_ #-} --- TODO [postgres] remove -queryNamed :: FromRow r => Connection -> Query -> [NamedParam] -> IO [r] -queryNamed Connection {conn, slow} sql = timeIt slow sql . SQL.queryNamed conn sql -{-# INLINE queryNamed #-} - $(J.deriveJSON defaultJSON ''SlowQueryStats) diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 05ba861bc..a955d0d8a 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -1,11 +1,14 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} @@ -233,14 +236,20 @@ import Data.Typeable (Proxy (Proxy), Typeable) import Data.Word (Word32) import Data.X509 import Data.X509.Validation (Fingerprint (..), getFingerprint) -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+)) import Network.Transport.Internal (decodeWord16, encodeWord16) +import Simplex.Messaging.Agent.Store.DB (Binary (..)) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldDecoder, parseAll, parseString) import Simplex.Messaging.Util ((<$?>)) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif -- | Cryptographic algorithms. data Algorithm = Ed25519 | Ed448 | X25519 | X448 @@ -721,23 +730,23 @@ generateKeyPair_ = case sAlgorithm @a of let k = X448.toPublic pk in pure (PublicKeyX448 k, PrivateKeyX448 pk k) -instance ToField APrivateSignKey where toField = toField . 
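With BoolInt and Binary now defined for the SQLite backend as well (transparent newtypes over sqlite-simple's own Bool and blob handling), shared modules can write a single pair of field instances per type. The convention used for Key, CbNonce and SbKey later in this diff is: derive FromField via newtype, and route ToField through Binary so the PostgreSQL backend binds a bytea. An illustrative newtype following that convention, assuming the CPP-selected FromField/ToField imports of the surrounding modules (not part of the patch):

{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- Illustrative ByteString wrapper stored in a BYTEA/BLOB column.
newtype ExampleSecret = ExampleSecret ByteString
  deriving newtype (FromField)

instance ToField ExampleSecret where toField (ExampleSecret s) = toField $ Binary s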
encodePrivKey +instance ToField APrivateSignKey where toField = toField . Binary . encodePrivKey -instance ToField APublicVerifyKey where toField = toField . encodePubKey +instance ToField APublicVerifyKey where toField = toField . Binary . encodePubKey -instance ToField APrivateAuthKey where toField = toField . encodePrivKey +instance ToField APrivateAuthKey where toField = toField . Binary . encodePrivKey -instance ToField APublicAuthKey where toField = toField . encodePubKey +instance ToField APublicAuthKey where toField = toField . Binary . encodePubKey -instance ToField APrivateDhKey where toField = toField . encodePrivKey +instance ToField APrivateDhKey where toField = toField . Binary . encodePrivKey -instance ToField APublicDhKey where toField = toField . encodePubKey +instance ToField APublicDhKey where toField = toField . Binary . encodePubKey -instance AlgorithmI a => ToField (PrivateKey a) where toField = toField . encodePrivKey +instance AlgorithmI a => ToField (PrivateKey a) where toField = toField . Binary . encodePrivKey -instance AlgorithmI a => ToField (PublicKey a) where toField = toField . encodePubKey +instance AlgorithmI a => ToField (PublicKey a) where toField = toField . Binary . encodePubKey -instance ToField (DhSecret a) where toField = toField . dhBytes' +instance ToField (DhSecret a) where toField = toField . Binary . dhBytes' instance FromField APrivateSignKey where fromField = blobFieldDecoder decodePrivKey @@ -888,10 +897,9 @@ validSignatureSize n = -- | AES key newtype. newtype Key = Key {unKey :: ByteString} deriving (Eq, Ord, Show) + deriving newtype (FromField) -instance ToField Key where toField = toField . unKey - -instance FromField Key where fromField f = Key <$> fromField f +instance ToField Key where toField (Key s) = toField $ Binary s instance ToJSON Key where toJSON = strToJSON . unKey @@ -952,7 +960,7 @@ instance FromJSON KeyHash where instance IsString KeyHash where fromString = parseString $ parseAll strP -instance ToField KeyHash where toField = toField . strEncode +instance ToField KeyHash where toField = toField . Binary . strEncode instance FromField KeyHash where fromField = blobFieldDecoder $ parseAll strP @@ -1162,10 +1170,14 @@ instance SignatureAlgorithmX509 pk => SignatureAlgorithmX509 (a, pk) where newtype SignedObject a = SignedObject {getSignedExact :: SignedExact a} instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a) where +#if defined(dbPostgres) + fromField f dat = SignedObject <$> blobFieldDecoder decodeSignedObject f dat +#else fromField = fmap SignedObject . blobFieldDecoder decodeSignedObject +#endif instance (Eq a, Show a, ASN1Object a) => ToField (SignedObject a) where - toField (SignedObject s) = toField $ encodeSignedObject s + toField (SignedObject s) = toField . Binary $ encodeSignedObject s instance (Eq a, Show a, ASN1Object a) => Encoding (SignedObject a) where smpEncode (SignedObject exact) = smpEncode . 
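The SignedObject instance above shows why some instances need a CPP body and not just CPP imports: postgresql-simple's fromField takes the Field and the raw Maybe ByteString, whereas sqlite-simple's takes only the Field, so point-free definitions need the extra argument on the Postgres side. A minimal illustration of the split (the newtype is illustrative, not from the patch):

-- fromField arity differs between the two libraries, hence the CPP split.
newtype ExampleTag = ExampleTag ByteString

instance FromField ExampleTag where
#if defined(dbPostgres)
  fromField f dat = ExampleTag <$> fromField f dat
#else
  fromField f = ExampleTag <$> fromField f
#endif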
Large $ encodeSignedObject exact @@ -1265,6 +1277,9 @@ cbVerify k pk nonce (CbAuthenticator s) authorized = cbDecryptNoPad (dh' k pk) n newtype CbNonce = CryptoBoxNonce {unCbNonce :: ByteString} deriving (Eq, Show) + deriving newtype (FromField) + +instance ToField CbNonce where toField (CryptoBoxNonce s) = toField $ Binary s pattern CbNonce :: ByteString -> CbNonce pattern CbNonce s <- CryptoBoxNonce s @@ -1282,10 +1297,6 @@ instance ToJSON CbNonce where instance FromJSON CbNonce where parseJSON = strParseJSON "CbNonce" -instance FromField CbNonce where fromField f = CryptoBoxNonce <$> fromField f - -instance ToField CbNonce where toField (CryptoBoxNonce s) = toField s - cbNonce :: ByteString -> CbNonce cbNonce s | len == 24 = CryptoBoxNonce s @@ -1309,6 +1320,9 @@ instance Encoding CbNonce where newtype SbKey = SecretBoxKey {unSbKey :: ByteString} deriving (Eq, Show) + deriving newtype (FromField) + +instance ToField SbKey where toField (SecretBoxKey s) = toField $ Binary s pattern SbKey :: ByteString -> SbKey pattern SbKey s <- SecretBoxKey s @@ -1326,10 +1340,6 @@ instance ToJSON SbKey where instance FromJSON SbKey where parseJSON = strParseJSON "SbKey" -instance FromField SbKey where fromField f = SecretBoxKey <$> fromField f - -instance ToField SbKey where toField (SecretBoxKey s) = toField s - sbKey :: ByteString -> Either String SbKey sbKey s | B.length s == 32 = Right $ SecretBoxKey s diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 148d931a9..c7ead660a 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -109,9 +110,8 @@ import Data.Maybe (fromMaybe, isJust) import Data.Type.Equality import Data.Typeable (Typeable) import Data.Word (Word16, Word32) -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.QueryString +import Simplex.Messaging.Agent.Store.DB (Binary (..), BoolInt (..)) import Simplex.Messaging.Crypto import Simplex.Messaging.Crypto.SNTRUP761.Bindings import Simplex.Messaging.Encoding @@ -121,6 +121,13 @@ import Simplex.Messaging.Util (($>>=), (<$?>)) import Simplex.Messaging.Version import Simplex.Messaging.Version.Internal import UnliftIO.STM +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif -- e2e encryption headers version history: -- 1 - binary protocol encoding (1/1/2022) @@ -359,7 +366,7 @@ instance Encoding APrivRKEMParams where 'A' -> APRKP SRKSAccepted .:. PrivateRKParamsAccepted <$> smpP <*> smpP <*> smpP _ -> fail "bad APrivRKEMParams" -instance RatchetKEMStateI s => ToField (PrivRKEMParams s) where toField = toField . smpEncode +instance RatchetKEMStateI s => ToField (PrivRKEMParams s) where toField = toField . Binary . smpEncode instance (Typeable s, RatchetKEMStateI s) => FromField (PrivRKEMParams s) where fromField = blobFieldDecoder smpDecode @@ -576,7 +583,7 @@ instance ToJSON RatchetKey where instance FromJSON RatchetKey where parseJSON = fmap RatchetKey . strParseJSON "Key" -instance ToField MessageKey where toField = toField . smpEncode +instance ToField MessageKey where toField = toField . Binary . 
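Types persisted via their SMP binary encoding (MessageKey and PrivRKEMParams here, NtfTknAction and NtfSubNTFAction later in this diff) follow one pattern: write with toField . Binary . smpEncode, read back with blobFieldDecoder smpDecode, which works unchanged for both backends. A sketch with a made-up type, assuming the usual Encoding, Parsers and field-class imports:

-- Illustrative type stored as its SMP binary encoding in a BYTEA/BLOB column.
newtype ExampleBlob = ExampleBlob ByteString

instance Encoding ExampleBlob where
  smpEncode (ExampleBlob s) = smpEncode s
  smpP = ExampleBlob <$> smpP

instance ToField ExampleBlob where toField = toField . Binary . smpEncode

instance FromField ExampleBlob where fromField = blobFieldDecoder smpDecode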
diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs
index 148d931a9..c7ead660a 100644
--- a/src/Simplex/Messaging/Crypto/Ratchet.hs
+++ b/src/Simplex/Messaging/Crypto/Ratchet.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE DeriveAnyClass #-}
 {-# LANGUAGE DuplicateRecordFields #-}
@@ -109,9 +110,8 @@ import Data.Maybe (fromMaybe, isJust)
 import Data.Type.Equality
 import Data.Typeable (Typeable)
 import Data.Word (Word16, Word32)
-import Database.SQLite.Simple.FromField (FromField (..))
-import Database.SQLite.Simple.ToField (ToField (..))
 import Simplex.Messaging.Agent.QueryString
+import Simplex.Messaging.Agent.Store.DB (Binary (..), BoolInt (..))
 import Simplex.Messaging.Crypto
 import Simplex.Messaging.Crypto.SNTRUP761.Bindings
 import Simplex.Messaging.Encoding
@@ -121,6 +121,13 @@ import Simplex.Messaging.Util (($>>=), (<$?>))
 import Simplex.Messaging.Version
 import Simplex.Messaging.Version.Internal
 import UnliftIO.STM
+#if defined(dbPostgres)
+import Database.PostgreSQL.Simple.FromField (FromField (..))
+import Database.PostgreSQL.Simple.ToField (ToField (..))
+#else
+import Database.SQLite.Simple.FromField (FromField (..))
+import Database.SQLite.Simple.ToField (ToField (..))
+#endif

 -- e2e encryption headers version history:
 -- 1 - binary protocol encoding (1/1/2022)
@@ -359,7 +366,7 @@ instance Encoding APrivRKEMParams where
     'A' -> APRKP SRKSAccepted .:. PrivateRKParamsAccepted <$> smpP <*> smpP <*> smpP
     _ -> fail "bad APrivRKEMParams"

-instance RatchetKEMStateI s => ToField (PrivRKEMParams s) where toField = toField . smpEncode
+instance RatchetKEMStateI s => ToField (PrivRKEMParams s) where toField = toField . Binary . smpEncode

 instance (Typeable s, RatchetKEMStateI s) => FromField (PrivRKEMParams s) where fromField = blobFieldDecoder smpDecode
@@ -576,7 +583,7 @@ instance ToJSON RatchetKey where
 instance FromJSON RatchetKey where
   parseJSON = fmap RatchetKey . strParseJSON "Key"

-instance ToField MessageKey where toField = toField . smpEncode
+instance ToField MessageKey where toField = toField . Binary . smpEncode

 instance FromField MessageKey where fromField = blobFieldDecoder smpDecode
@@ -1120,14 +1127,24 @@ instance AlgorithmI a => ToJSON (Ratchet a) where
 instance AlgorithmI a => FromJSON (Ratchet a) where
   parseJSON = $(JQ.mkParseJSON defaultJSON ''Ratchet)

-instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode
+instance AlgorithmI a => ToField (Ratchet a) where toField = toField . Binary . LB.toStrict . J.encode

 instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder J.eitherDecodeStrict'

-instance ToField PQEncryption where toField (PQEncryption pqEnc) = toField pqEnc
+instance ToField PQEncryption where toField (PQEncryption pqEnc) = toField (BI pqEnc)

-instance FromField PQEncryption where fromField f = PQEncryption <$> fromField f
+instance FromField PQEncryption where
+#if defined(dbPostgres)
+  fromField f dat = PQEncryption . unBI <$> fromField f dat
+#else
+  fromField f = PQEncryption . unBI <$> fromField f
+#endif

-instance ToField PQSupport where toField (PQSupport pqEnc) = toField pqEnc
+instance ToField PQSupport where toField (PQSupport pqEnc) = toField (BI pqEnc)

-instance FromField PQSupport where fromField f = PQSupport <$> fromField f
+instance FromField PQSupport where
+#if defined(dbPostgres)
+  fromField f dat = PQSupport . unBI <$> fromField f dat
+#else
+  fromField f = PQSupport . unBI <$> fromField f
+#endif
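The BI/unBI calls in the PQEncryption and PQSupport instances come from a BoolInt wrapper that is also imported from Simplex.Messaging.Agent.Store.DB and not shown here. Judging by the call sites, its job is to keep Bool columns stored as 0/1 integers on both backends; a rough sketch under that assumption (not the actual definition):

{-# LANGUAGE CPP #-}
#if defined(dbPostgres)
import Database.PostgreSQL.Simple.FromField (FromField (..))
import Database.PostgreSQL.Simple.ToField (ToField (..))
#else
import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..))
#endif

newtype BoolInt = BI {unBI :: Bool}

#if defined(dbPostgres)
-- assuming SMALLINT/INTEGER columns carried over from the SQLite schema
instance ToField BoolInt where toField (BI b) = toField (if b then 1 else 0 :: Int)

instance FromField BoolInt where fromField f dat = BI . (/= (0 :: Int)) <$> fromField f dat
#else
-- sqlite-simple maps Bool to INTEGER 0/1 already
instance ToField BoolInt where toField (BI b) = toField b

instance FromField BoolInt where fromField f = BI <$> fromField f
#endif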
diff --git a/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs b/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs
index 3b2238086..35e46e3de 100644
--- a/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs
+++ b/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
 {-# LANGUAGE TypeApplications #-}

 module Simplex.Messaging.Crypto.SNTRUP761.Bindings where
@@ -9,14 +10,19 @@ import Data.Bifunctor (bimap)
 import Data.ByteArray (ScrubbedBytes)
 import qualified Data.ByteArray as BA
 import Data.ByteString (ByteString)
-import Database.SQLite.Simple.FromField
-import Database.SQLite.Simple.ToField
 import Foreign (nullPtr)
 import Simplex.Messaging.Crypto.SNTRUP761.Bindings.Defines
 import Simplex.Messaging.Crypto.SNTRUP761.Bindings.FFI
 import Simplex.Messaging.Crypto.SNTRUP761.Bindings.RNG (withDRG)
 import Simplex.Messaging.Encoding
 import Simplex.Messaging.Encoding.String
+#if defined(dbPostgres)
+import Database.PostgreSQL.Simple.FromField
+import Database.PostgreSQL.Simple.ToField
+#else
+import Database.SQLite.Simple.FromField
+import Database.SQLite.Simple.ToField
+#endif

 newtype KEMPublicKey = KEMPublicKey ByteString
   deriving (Eq, Show)
@@ -121,7 +127,11 @@ instance ToField KEMSharedKey where
   toField (KEMSharedKey k) = toField (BA.convert k :: ByteString)

 instance FromField KEMSharedKey where
+#if defined(dbPostgres)
+  fromField f dat = KEMSharedKey . BA.convert @ByteString <$> fromField f dat
+#else
   fromField f = KEMSharedKey . BA.convert @ByteString <$> fromField f
+#endif

 instance ToJSON KEMSharedKey where
   toJSON = strToJSON
diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs
index af8987dcc..96f8b337e 100644
--- a/src/Simplex/Messaging/Notifications/Protocol.hs
+++ b/src/Simplex/Messaging/Notifications/Protocol.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE LambdaCase #-}
@@ -27,8 +28,6 @@ import Data.Text.Encoding (decodeLatin1, encodeUtf8)
 import Data.Time.Clock.System
 import Data.Type.Equality
 import Data.Word (Word16)
-import Database.SQLite.Simple.FromField (FromField (..))
-import Database.SQLite.Simple.ToField (ToField (..))
 import Simplex.Messaging.Agent.Protocol (updateSMPServerHosts)
 import qualified Simplex.Messaging.Crypto as C
 import Simplex.Messaging.Encoding
@@ -37,6 +36,13 @@ import Simplex.Messaging.Notifications.Transport (NTFVersion, ntfClientHandshake
 import Simplex.Messaging.Parsers (fromTextField_)
 import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..))
 import Simplex.Messaging.Util (eitherToMaybe, (<$?>))
+#if defined(dbPostgres)
+import Database.PostgreSQL.Simple.FromField (FromField (..))
+import Database.PostgreSQL.Simple.ToField (ToField (..))
+#else
+import Database.SQLite.Simple.FromField (FromField (..))
+import Database.SQLite.Simple.ToField (ToField (..))
+#endif

 data NtfEntity = Token | Subscription
   deriving (Show)
diff --git a/src/Simplex/Messaging/Notifications/Types.hs b/src/Simplex/Messaging/Notifications/Types.hs
index 774f354bb..dd6e99733 100644
--- a/src/Simplex/Messaging/Notifications/Types.hs
+++ b/src/Simplex/Messaging/Notifications/Types.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE DuplicateRecordFields #-}
 {-# LANGUAGE LambdaCase #-}
@@ -9,14 +10,20 @@ module Simplex.Messaging.Notifications.Types where
 import qualified Data.Attoparsec.ByteString.Char8 as A
 import Data.Text.Encoding (decodeLatin1, encodeUtf8)
 import Data.Time (UTCTime)
-import Database.SQLite.Simple.FromField (FromField (..))
-import Database.SQLite.Simple.ToField (ToField (..))
 import Simplex.Messaging.Agent.Protocol (ConnId, NotificationsMode (..), UserId)
+import Simplex.Messaging.Agent.Store.DB (Binary (..))
 import qualified Simplex.Messaging.Crypto as C
 import Simplex.Messaging.Encoding
 import Simplex.Messaging.Notifications.Protocol
 import Simplex.Messaging.Parsers (blobFieldDecoder, fromTextField_)
 import Simplex.Messaging.Protocol (NotifierId, NtfServer, SMPServer)
+#if defined(dbPostgres)
+import Database.PostgreSQL.Simple.FromField (FromField (..))
+import Database.PostgreSQL.Simple.ToField (ToField (..))
+#else
+import Database.SQLite.Simple.FromField (FromField (..))
+import Database.SQLite.Simple.ToField (ToField (..))
+#endif

 data NtfTknAction
   = NTARegister
@@ -41,7 +48,7 @@ instance Encoding NtfTknAction where
 instance FromField NtfTknAction where fromField = blobFieldDecoder smpDecode

-instance ToField NtfTknAction where toField = toField . smpEncode
+instance ToField NtfTknAction where toField = toField . Binary . smpEncode

 data NtfToken = NtfToken
   { deviceToken :: DeviceToken,
@@ -119,7 +126,7 @@ instance Encoding NtfSubNTFAction where
 instance FromField NtfSubNTFAction where fromField = blobFieldDecoder smpDecode

-instance ToField NtfSubNTFAction where toField = toField . smpEncode
+instance ToField NtfSubNTFAction where toField = toField . Binary . smpEncode

 data NtfSubSMPAction
   = NSASmpKey
@@ -138,7 +145,7 @@ instance Encoding NtfSubSMPAction where
 instance FromField NtfSubSMPAction where fromField = blobFieldDecoder smpDecode

-instance ToField NtfSubSMPAction where toField = toField . smpEncode
+instance ToField NtfSubSMPAction where toField = toField . Binary . smpEncode

 data NtfAgentSubStatus
   = -- | subscription started
diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs
index 6ad9f867d..a75efe0ee 100644
--- a/src/Simplex/Messaging/Parsers.hs
+++ b/src/Simplex/Messaging/Parsers.hs
@@ -1,5 +1,6 @@
 {-# LANGUAGE CPP #-}
 {-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE NamedFieldPuns #-}
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE PatternSynonyms #-}
@@ -20,12 +21,19 @@ import qualified Data.Text as T
 import Data.Time.Clock (UTCTime)
 import Data.Time.ISO8601 (parseISO8601)
 import Data.Typeable (Typeable)
+import Simplex.Messaging.Util (safeDecodeUtf8, (<$?>))
+import Text.Read (readMaybe)
+#if defined(dbPostgres)
+import Database.PostgreSQL.Simple (ResultError (..))
+import Database.PostgreSQL.Simple.FromField (FromField(..), FieldParser, returnError, Field (..))
+import Database.PostgreSQL.Simple.TypeInfo.Static (textOid, varcharOid)
+import qualified Data.Text.Encoding as TE
+#else
 import Database.SQLite.Simple (ResultError (..), SQLData (..))
 import Database.SQLite.Simple.FromField (FieldParser, returnError)
 import Database.SQLite.Simple.Internal (Field (..))
 import Database.SQLite.Simple.Ok (Ok (Ok))
-import Simplex.Messaging.Util (safeDecodeUtf8, (<$?>))
-import Text.Read (readMaybe)
+#endif

 base64P :: Parser ByteString
 base64P = decode <$?> paddedBase64 rawBase64P
@@ -77,6 +85,14 @@ parseString p = either error id . p . B.pack
 blobFieldParser :: Typeable k => Parser k -> FieldParser k
 blobFieldParser = blobFieldDecoder . parseAll

+#if defined(dbPostgres)
+blobFieldDecoder :: Typeable k => (ByteString -> Either String k) -> FieldParser k
+blobFieldDecoder dec f val = do
+  x <- fromField f val
+  case dec x of
+    Right k -> pure k
+    Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e)
+#else
 blobFieldDecoder :: Typeable k => (ByteString -> Either String k) -> FieldParser k
 blobFieldDecoder dec = \case
   f@(Field (SQLBlob b) _) ->
@@ -84,7 +100,20 @@ blobFieldDecoder dec = \case
       Right k -> Ok k
       Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e)
   f -> returnError ConversionFailed f "expecting SQLBlob column type"
+#endif

+-- TODO [postgres] review
+#if defined(dbPostgres)
+fromTextField_ :: Typeable a => (Text -> Maybe a) -> FieldParser a
+fromTextField_ fromText f val =
+  if typeOid f `elem` [textOid, varcharOid]
+    then case val of
+      Just t -> case fromText (TE.decodeUtf8 t) of
+        Just x -> pure x
+        _ -> returnError ConversionFailed f "invalid text value"
+      Nothing -> returnError UnexpectedNull f "NULL value found for non-NULL field"
+    else returnError Incompatible f "expecting TEXT or VARCHAR column type"
+#else
 fromTextField_ :: Typeable a => (Text -> Maybe a) -> Field -> Ok a
 fromTextField_ fromText = \case
   f@(Field (SQLText t) _) ->
@@ -92,6 +121,7 @@ fromTextField_ fromText = \case
       Just x -> Ok x
       _ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
   f -> returnError ConversionFailed f "expecting SQLText column type"
+#endif

 fstToLower :: String -> String
 fstToLower "" = ""
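A note on why some fromField definitions in this change need CPP while others do not: the two libraries expose different FieldParser shapes, so point-free definitions that simply are a FieldParser compile against either, whereas definitions that post-process the parsed value must thread the extra argument under dbPostgres.

-- reference only; types as published by sqlite-simple 0.4.x and postgresql-simple 0.7.x,
-- worth re-checking against the versions pinned in the cabal file
--   sqlite-simple:     type FieldParser a = Field -> Ok a
--   postgresql-simple: type FieldParser a = Field -> Maybe ByteString -> Conversion a
--
-- hence `fromField = blobFieldDecoder smpDecode` works unchanged for both backends,
-- while wrappers such as `PQEncryption . unBI <$> fromField ...` need `f dat` under dbPostgres.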
diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs
index 56a7fef1f..dff6cd4b0 100644
--- a/tests/AgentTests.hs
+++ b/tests/AgentTests.hs
@@ -1,12 +1,10 @@
+{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
 {-# LANGUAGE OverloadedStrings #-}
-{-# LANGUAGE PatternSynonyms #-}
 {-# LANGUAGE PostfixOperators #-}
 {-# LANGUAGE RankNTypes #-}
 {-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}

 module AgentTests (agentTests) where
@@ -14,18 +12,30 @@ import AgentTests.ConnectionRequestTests
 import AgentTests.DoubleRatchetTests (doubleRatchetTests)
 import AgentTests.FunctionalAPITests (functionalAPITests)
 import AgentTests.MigrationTests (migrationTests)
-import AgentTests.NotificationTests (notificationTests)
-import AgentTests.SQLiteTests (storeTests)
 import AgentTests.ServerChoice (serverChoiceTests)
 import Simplex.Messaging.Transport (ATransport (..))
 import Test.Hspec
+#if defined(dbPostgres)
+import Fixtures
+import Simplex.Messaging.Agent.Store.Postgres.Util (dropAllSchemasExceptSystem)
+#else
+import AgentTests.NotificationTests (notificationTests)
+import AgentTests.SQLiteTests (storeTests)
+#endif

 agentTests :: ATransport -> Spec
 agentTests (ATransport t) = do
+  describe "Migration tests" migrationTests
   describe "Connection request" connectionRequestTests
   describe "Double ratchet tests" doubleRatchetTests
+#if defined(dbPostgres)
+  after_ (dropAllSchemasExceptSystem testDBConnectInfo) $ do
+    describe "Functional API" $ functionalAPITests (ATransport t)
+    describe "Chosen servers" serverChoiceTests
+#else
   describe "Functional API" $ functionalAPITests (ATransport t)
+  describe "Chosen servers" serverChoiceTests
+  -- notifications aren't tested with postgres, as we don't plan to use the iOS client with it
   describe "Notification tests" $ notificationTests (ATransport t)
   describe "SQLite store" storeTests
-  describe "Chosen servers" serverChoiceTests
-  describe "Migration tests" migrationTests
+#endif
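dropAllSchemasExceptSystem is used here (and again in SMPProxyTests and XFTPAgent below) as an after_ hook so every test starts from an empty database; its implementation lives in the new Simplex.Messaging.Agent.Store.Postgres.Util module and is not part of this diff. A hypothetical sketch of such a cleanup, assuming it enumerates non-system schemas and drops them with CASCADE (the real helper may differ):

{-# LANGUAGE OverloadedStrings #-}
import Control.Exception (bracket)
import Control.Monad (forM_)
import Database.PostgreSQL.Simple (ConnectInfo, Only (..), close, connect, execute, query_)
import Database.PostgreSQL.Simple.Types (Identifier (..))

-- hypothetical sketch, not the actual implementation
dropAllSchemasExceptSystemSketch :: ConnectInfo -> IO ()
dropAllSchemasExceptSystemSketch ci = bracket (connect ci) close $ \db -> do
  schemas <- query_ db "SELECT nspname::text FROM pg_catalog.pg_namespace WHERE nspname NOT LIKE 'pg_%' AND nspname NOT IN ('information_schema', 'public')"
  forM_ schemas $ \(Only s) -> execute db "DROP SCHEMA ? CASCADE" (Only (Identifier s))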
diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs
index 71000b60a..9c3c5a972 100644
--- a/tests/AgentTests/FunctionalAPITests.hs
+++ b/tests/AgentTests/FunctionalAPITests.hs
@@ -75,7 +75,6 @@ import Data.Time.Clock (diffUTCTime, getCurrentTime)
 import Data.Time.Clock.System (SystemTime (..), getSystemTime)
 import Data.Type.Equality (testEquality, (:~:) (Refl))
 import Data.Word (Word16)
-import qualified Database.SQLite.Simple as SQL
 import GHC.Stack (withFrozenCallStack)
 import SMPAgentClient
 import SMPClient (cfg, prevRange, prevVersion, testPort, testPort2, testStoreLogFile2, testStoreMsgsDir2, withSmpServer, withSmpServerConfigOn, withSmpServerProxy, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn)
@@ -85,8 +84,9 @@ import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestSte
 import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..), createAgentStore)
 import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, REQ, SENT)
 import qualified Simplex.Messaging.Agent.Protocol as A
-import Simplex.Messaging.Agent.Store.Common (DBStore (..), withTransaction')
-import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..))
+import Simplex.Messaging.Agent.Store.Common (DBStore (..), withTransaction)
+import qualified Simplex.Messaging.Agent.Store.DB as DB
+import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError (..))
 import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), SMPProxyFallback (..), SMPProxyMode (..), TransportSessionMode (..), defaultClientConfig)
 import qualified Simplex.Messaging.Crypto as C
 import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOff, pattern IKPQOn, pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn)
@@ -109,6 +109,9 @@ import Test.Hspec
 import UnliftIO
 import Util
 import XFTPClient (testXFTPServer)
+#if defined(dbPostgres)
+import Fixtures
+#endif

 type AEntityTransmission e = (ACorrId, ConnId, AEvent e)
@@ -259,7 +262,6 @@ sendMessage c connId msgFlags msgBody = do
   liftIO $ pqEnc `shouldBe` PQEncOn
   pure msgId

--- TODO [postgres] run with postgres
 functionalAPITests :: ATransport -> Spec
 functionalAPITests t = do
   describe "Establishing duplex connection" $ do
@@ -327,6 +329,8 @@ functionalAPITests t = do
     it "should expire multiple messages" $ testExpireManyMessages t
     it "should expire one message if quota is exceeded" $ testExpireMessageQuota t
     it "should expire multiple messages if quota is exceeded" $ testExpireManyMessagesQuota t
+#if !defined(dbPostgres)
+  -- TODO [postgres] restore from outdated db backup (we use copyFile/renameFile for sqlite)
   describe "Ratchet synchronization" $ do
     it "should report ratchet de-synchronization, synchronize ratchets" $
       testRatchetSync t
@@ -338,6 +342,7 @@ functionalAPITests t = do
       testRatchetSyncSuspendForeground t
     it "should synchronize ratchets when clients start synchronization simultaneously" $
      testRatchetSyncSimultaneous t
+#endif
   describe "Subscription mode OnlyCreate" $ do
     it "messages delivered only when polled (v8 - slow handshake)" $
       withSmpServer t testOnlyCreatePullSlowHandshake
@@ -2563,7 +2568,7 @@ testSwitchAsync servers = do
     withB :: (AgentClient -> IO a) -> IO a
     withB = withAgent 2 agentCfg servers testDB2

-withAgent :: HasCallStack => Int -> AgentConfig -> InitialAgentServers -> FilePath -> (HasCallStack => AgentClient -> IO a) -> IO a
+withAgent :: HasCallStack => Int -> AgentConfig -> InitialAgentServers -> String -> (HasCallStack => AgentClient -> IO a) -> IO a
 withAgent clientId cfg' servers dbPath = bracket (getSMPAgentClient' clientId cfg' servers dbPath) (\a -> disposeAgentClient a >> threadDelay 100000)

 sessionSubscribe :: (forall a. (AgentClient -> IO a) -> IO a) -> [ConnId] -> (AgentClient -> ExceptT AgentErrorType IO ()) -> IO ()
@@ -3093,13 +3098,27 @@ testTwoUsers = withAgentClients2 $ \a b -> do
 hasClients :: HasCallStack => AgentClient -> Int -> ExceptT AgentErrorType IO ()
 hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n

-getSMPAgentClient' :: Int -> AgentConfig -> InitialAgentServers -> FilePath -> IO AgentClient
+getSMPAgentClient' :: Int -> AgentConfig -> InitialAgentServers -> String -> IO AgentClient
 getSMPAgentClient' clientId cfg' initServers dbPath = do
-  Right st <- liftIO $ createAgentStore dbPath "" False MCError
+  Right st <- liftIO $ createStore dbPath
   c <- getSMPAgentClient_ clientId cfg' initServers st False
-  when (dbNew st) $ withTransaction' st (`SQL.execute_` "INSERT INTO users (user_id) VALUES (1)")
+  when (dbNew st) $ insertUser st
   pure c

+#if defined(dbPostgres)
+createStore :: String -> IO (Either MigrationError DBStore)
+createStore schema = createAgentStore testDBConnectInfo schema MCError
+
+insertUser :: DBStore -> IO ()
+insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users DEFAULT VALUES")
+#else
+createStore :: String -> IO (Either MigrationError DBStore)
+createStore dbPath = createAgentStore dbPath "" False MCError
+
+insertUser :: DBStore -> IO ()
+insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users (user_id) VALUES (1)")
+#endif
+
 testServerMultipleIdentities :: HasCallStack => IO ()
 testServerMultipleIdentities = withAgentClients2 $ \alice bob -> runRight_ $ do
diff --git a/tests/AgentTests/MigrationTests.hs b/tests/AgentTests/MigrationTests.hs
index 31ec79bf9..fb8550a7d 100644
--- a/tests/AgentTests/MigrationTests.hs
+++ b/tests/AgentTests/MigrationTests.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
 {-# LANGUAGE OverloadedStrings #-}

 module AgentTests.MigrationTests (migrationTests) where
@@ -5,17 +6,24 @@ module AgentTests.MigrationTests (migrationTests) where
 import Control.Monad
 import Data.Maybe (fromJust)
 import Data.Word (Word32)
-import Database.SQLite.Simple (fromOnly)
 import Simplex.Messaging.Agent.Store.Common (DBStore, withTransaction)
 import Simplex.Messaging.Agent.Store.Migrations (migrationsToRun)
-import Simplex.Messaging.Agent.Store.SQLite (closeDBStore, createDBStore)
-import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
 import Simplex.Messaging.Agent.Store.Shared
-import System.Directory (removeFile)
 import System.Random (randomIO)
 import Test.Hspec
+#if defined(dbPostgres)
+import Database.PostgreSQL.Simple (fromOnly)
+import Fixtures
+import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore)
+import Simplex.Messaging.Agent.Store.Postgres.Util (dropSchema)
+import qualified Simplex.Messaging.Agent.Store.Postgres.DB as DB
+#else
+import Database.SQLite.Simple (fromOnly)
+import Simplex.Messaging.Agent.Store.SQLite (closeDBStore, createDBStore)
+import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
+import System.Directory (removeFile)
+#endif

--- TODO [postgres] run with postgres
 migrationTests :: Spec
 migrationTests = do
   it "should determine migrations to run" testMigrationsToRun
@@ -98,9 +106,6 @@ migrationTests = do
       ([m1, m2, m3, m4], [t1, t2, t3, t4])
       ([m1, m2, m4], [MCYesUp, MCYesUpDown, MCError], Left . MigrationError $ MTREDifferent (name m4) (name m3))

-testDB :: FilePath
-testDB = "tests/tmp/test_migrations.db"
-
 m1 :: Migration
 m1 = Migration "20230301-migration1" "create table test1 (id1 integer primary key);" Nothing
@@ -180,21 +185,46 @@ testMigration :: IO ()
 testMigration (initMs, initTables) (finalMs, confirmModes, tablesOrError) =
   forM_ confirmModes $ \confirmMode -> do
     r <- randomIO :: IO Word32
-    let dpPath = testDB <> show r
-    Right st <- createDBStore dpPath "" False initMs MCError
+    Right st <- createStore r initMs MCError
     st `shouldHaveTables` initTables
     closeDBStore st
     case tablesOrError of
      Right tables -> do
-        Right st' <- createDBStore dpPath "" False finalMs confirmMode
+        Right st' <- createStore r finalMs confirmMode
        st' `shouldHaveTables` tables
        closeDBStore st'
      Left e -> do
-        Left e' <- createDBStore dpPath "" False finalMs confirmMode
+        Left e' <- createStore r finalMs confirmMode
        e `shouldBe` e'
-    removeFile dpPath
-  where
-    shouldHaveTables :: DBStore -> [String] -> IO ()
-    st `shouldHaveTables` expected = do
-      tables <- map fromOnly <$> withTransaction st (`DB.query_` "SELECT name FROM sqlite_schema WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY 1;")
-      tables `shouldBe` "migrations" : expected
+    cleanup r
+
+#if defined(dbPostgres)
+testSchema :: Word32 -> String
+testSchema randSuffix = "test_migrations_schema" <> show randSuffix
+
+createStore :: Word32 -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore)
+createStore randSuffix migrations confirmMigrations =
+  createDBStore testDBConnectInfo (testSchema randSuffix) migrations confirmMigrations
+
+cleanup :: Word32 -> IO ()
+cleanup randSuffix = dropSchema testDBConnectInfo (testSchema randSuffix)
+
+shouldHaveTables :: DBStore -> [String] -> IO ()
+st `shouldHaveTables` expected = do
+  tables <- map fromOnly <$> withTransaction st (`DB.query_` "SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema() AND table_type = 'BASE TABLE' ORDER BY 1")
+  tables `shouldBe` "migrations" : expected
+#else
+testDB :: Word32 -> FilePath
+testDB randSuffix = "tests/tmp/test_migrations.db" <> show randSuffix
+
+createStore :: Word32 -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore)
+createStore randSuffix = createDBStore (testDB randSuffix) "" False
+
+cleanup :: Word32 -> IO ()
+cleanup randSuffix = removeFile (testDB randSuffix)
+
+shouldHaveTables :: DBStore -> [String] -> IO ()
+st `shouldHaveTables` expected = do
+  tables <- map fromOnly <$> withTransaction st (`DB.query_` "SELECT name FROM sqlite_schema WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY 1")
+  tables `shouldBe` "migrations" : expected
+#endif
diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs
index ce07d580a..33e15792e 100644
--- a/tests/AgentTests/NotificationTests.hs
+++ b/tests/AgentTests/NotificationTests.hs
@@ -76,15 +76,9 @@ import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgFlags (MsgFlags), NtfSer
 import qualified Simplex.Messaging.Protocol as SMP
 import Simplex.Messaging.Server.Env.STM (ServerConfig (..))
 import Simplex.Messaging.Transport (ATransport)
-import System.Directory (doesFileExist, removeFile)
 import Test.Hspec
 import UnliftIO

-removeFileIfExists :: FilePath -> IO ()
-removeFileIfExists filePath = do
-  fileExists <- doesFileExist filePath
-  when fileExists $ removeFile filePath
-
 notificationTests :: ATransport -> Spec
 notificationTests t = do
   describe "Managing notification tokens" $ do
diff --git a/tests/AgentTests/SchemaDump.hs b/tests/AgentTests/SchemaDump.hs
index ba344a4b1..b7fcce8ee 100644
--- a/tests/AgentTests/SchemaDump.hs
+++ b/tests/AgentTests/SchemaDump.hs
@@ -37,7 +37,6 @@ appLint = "src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_lint.sql"
 testSchema :: FilePath
 testSchema = "tests/tmp/test_agent_schema.sql"

--- TODO [postgres] run with postgres
 schemaDumpTest :: Spec
 schemaDumpTest = do
   it "verify and overwrite schema dump" testVerifySchemaDump
diff --git a/tests/CoreTests/StoreLogTests.hs b/tests/CoreTests/StoreLogTests.hs
index e24f9f1ea..90bea0192 100644
--- a/tests/CoreTests/StoreLogTests.hs
+++ b/tests/CoreTests/StoreLogTests.hs
@@ -10,13 +10,12 @@ module CoreTests.StoreLogTests where
 import Control.Concurrent.STM
 import Control.Monad
+import CoreTests.MsgStoreTests
 import Crypto.Random (ChaChaDRG)
 import qualified Data.ByteString.Char8 as B
 import Data.Either (partitionEithers)
 import qualified Data.Map.Strict as M
 import SMPClient
-import AgentTests.SQLiteTests
-import CoreTests.MsgStoreTests
 import qualified Simplex.Messaging.Crypto as C
 import Simplex.Messaging.Encoding.String
 import Simplex.Messaging.Protocol
@@ -27,6 +26,9 @@ import Simplex.Messaging.Server.QueueStore
 import Simplex.Messaging.Server.StoreLog
 import Test.Hspec

+testPublicAuthKey :: C.APublicAuthKey
+testPublicAuthKey = C.APublicAuthKey C.SEd25519 (C.publicKey "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe")
+
 testNtfCreds :: TVar ChaChaDRG -> IO NtfCreds
 testNtfCreds g = do
   (notifierKey, _) <- atomically $ C.generateAuthKeyPair C.SX25519 g
@@ -54,7 +56,8 @@ storeLogTests =
     ((rId, qr), ntfCreds, date) <- runIO $ do
       g <- C.newRandom
       (,,) <$> testNewQueueRec g sndSecure <*> testNtfCreds g <*> getSystemDate
-    testSMPStoreLog ("SMP server store log, sndSecure = " <> show sndSecure)
+    testSMPStoreLog
+      ("SMP server store log, sndSecure = " <> show sndSecure)
      [ SLTC
          { name = "create new queue",
            saved = [CreateQueue qr],
@@ -66,7 +69,7 @@ storeLogTests =
            saved = [CreateQueue qr, SecureQueue rId testPublicAuthKey],
            compacted = [CreateQueue qr {senderKey = Just testPublicAuthKey}],
            state = M.fromList [(rId, qr {senderKey = Just testPublicAuthKey})]
-         },
+          },
        SLTC
          { name = "create and delete queue",
            saved = [CreateQueue qr, DeleteQueue rId],
@@ -90,7 +93,7 @@ storeLogTests =
            saved = [CreateQueue qr, UpdateTime rId date],
            compacted = [CreateQueue qr {updatedAt = Just date}],
            state = M.fromList [(rId, qr {updatedAt = Just date})]
-         }
+          }
      ]

 testSMPStoreLog :: String -> [SMPStoreLogTestCase] -> Spec
diff --git a/tests/Fixtures.hs b/tests/Fixtures.hs
new file mode 100644
index 000000000..a8e2542ec
--- /dev/null
+++ b/tests/Fixtures.hs
@@ -0,0 +1,16 @@
+{-# LANGUAGE CPP #-}
+
+module Fixtures where
+
+#if defined(dbPostgres)
+import Database.PostgreSQL.Simple (ConnectInfo (..), defaultConnectInfo)
+#endif
+
+#if defined(dbPostgres)
+testDBConnectInfo :: ConnectInfo
+testDBConnectInfo =
+  defaultConnectInfo {
+    connectUser = "test_user",
+    connectDatabase = "test_agent_db"
+  }
+#endif
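Fixtures only exports testDBConnectInfo when the dbPostgres flag is set, and the database and user it points at are created by the beforeAll_ hook added to tests/Test.hs further below. A small usage sketch, not part of the diff, showing how a test could open a throwaway connection from this fixture (connect, close and query_ are the ordinary postgresql-simple functions):

{-# LANGUAGE OverloadedStrings #-}
import Control.Exception (bracket)
import Database.PostgreSQL.Simple (Connection, Only (..), close, connect, query_)
import Fixtures (testDBConnectInfo)

-- run an action against the test database, closing the connection afterwards
withTestDB :: (Connection -> IO a) -> IO a
withTestDB = bracket (connect testDBConnectInfo) close

-- trivial connectivity check
connectivitySmokeTest :: IO ()
connectivitySmokeTest = withTestDB $ \db -> do
  [Only n] <- query_ db "SELECT 1"
  print (n :: Int)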
diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs
index 5e5f91b09..fc66f2ab1 100644
--- a/tests/SMPAgentClient.hs
+++ b/tests/SMPAgentClient.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
 {-# LANGUAGE ConstraintKinds #-}
 {-# LANGUAGE DuplicateRecordFields #-}
 {-# LANGUAGE GADTs #-}
@@ -25,6 +26,17 @@ import Simplex.Messaging.Protocol (NtfServer, ProtoServerWithAuth (..), Protocol
 import Simplex.Messaging.Transport
 import XFTPClient (testXFTPServer)

+-- name fixtures are reused, but they are used as schema names instead of database file paths
+#if defined(dbPostgres)
+testDB :: String
+testDB = "smp_agent_test_protocol_schema"
+
+testDB2 :: String
+testDB2 = "smp_agent2_test_protocol_schema"
+
+testDB3 :: String
+testDB3 = "smp_agent3_test_protocol_schema"
+#else
 testDB :: FilePath
 testDB = "tests/tmp/smp-agent.test.protocol.db"
@@ -33,6 +45,7 @@ testDB2 = "tests/tmp/smp-agent2.test.protocol.db"

 testDB3 :: FilePath
 testDB3 = "tests/tmp/smp-agent3.test.protocol.db"
+#endif

 testSMPServer :: SMPServer
 testSMPServer = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"
diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs
index b827edda2..cbdc7a3f5 100644
--- a/tests/SMPProxyTests.hs
+++ b/tests/SMPProxyTests.hs
@@ -1,4 +1,5 @@
 {-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE DuplicateRecordFields #-}
 {-# LANGUAGE GADTs #-}
@@ -46,6 +47,10 @@ import System.Random (randomRIO)
 import Test.Hspec
 import UnliftIO
 import Util
+#if defined(dbPostgres)
+import Fixtures
+import Simplex.Messaging.Agent.Store.Postgres.Util (dropAllSchemasExceptSystem)
+#endif

 smpProxyTests :: Spec
 smpProxyTests = do
@@ -101,7 +106,11 @@ smpProxyTests = do
       it "100x100 N4 C16" . twoServersMoreConc $ withNumCapabilities 4 $ 100 `inParrallel` deliver 100
       it "100x100 N" . twoServersFirstProxy $ withNCPUCapabilities $ 100 `inParrallel` deliver 100
       it "500x20" . twoServersFirstProxy $ 500 `inParrallel` deliver 20
+#if defined(dbPostgres)
+  after_ (dropAllSchemasExceptSystem testDBConnectInfo) . describe "agent API" $ do
+#else
   describe "agent API" $ do
+#endif
     describe "one server" $ do
       it "always via proxy" . oneServer $ agentDeliverMessageViaProxy ([srv1], SPMAlways, True) ([srv1], SPMAlways, True) C.SEd448 "hello 1" "hello 2" 1
diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs
index 744ceb437..f9c9dcce1 100644
--- a/tests/ServerTests.hs
+++ b/tests/ServerTests.hs
@@ -15,13 +15,12 @@
 module ServerTests where

-import AgentTests.NotificationTests (removeFileIfExists)
-import CoreTests.MsgStoreTests (testJournalStoreCfg)
 import Control.Concurrent (ThreadId, killThread, threadDelay)
 import Control.Concurrent.STM
 import Control.Exception (SomeException, try)
 import Control.Monad
 import Control.Monad.IO.Class
+import CoreTests.MsgStoreTests (testJournalStoreCfg)
 import Data.Bifunctor (first)
 import Data.ByteString.Base64
 import Data.ByteString.Char8 (ByteString)
@@ -51,9 +50,10 @@ import System.TimeIt (timeItT)
 import System.Timeout
 import Test.HUnit
 import Test.Hspec
+import Util (removeFileIfExists)

 serverTests :: SpecWith (ATransport, AMSType)
-serverTests = do
+serverTests = do
   describe "SMP queues" $ do
     describe "NEW and KEY commands, SEND messages" testCreateSecure
     describe "NEW and SKEY commands" $ do
diff --git a/tests/Test.hs b/tests/Test.hs
index f8505b133..09fb856fd 100644
--- a/tests/Test.hs
+++ b/tests/Test.hs
@@ -1,8 +1,8 @@
+{-# LANGUAGE CPP #-}
 {-# LANGUAGE NamedFieldPuns #-}
 {-# LANGUAGE TypeApplications #-}

 import AgentTests (agentTests)
-import AgentTests.SchemaDump (schemaDumpTest)
 import CLITests
 import Control.Concurrent (threadDelay)
 import qualified Control.Exception as E
@@ -34,6 +34,12 @@ import Test.Hspec
 import XFTPAgent
 import XFTPCLI
 import XFTPServerTests (xftpServerTests)
+#if defined(dbPostgres)
+import Fixtures
+import Simplex.Messaging.Agent.Store.Postgres.Util (createDBAndUserIfNotExists, dropDatabaseAndUser)
+#else
+import AgentTests.SchemaDump (schemaDumpTest)
+#endif

 logCfg :: LogConfig
 logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
@@ -45,10 +51,17 @@ main = do
   setEnv "APNS_KEY_ID" "H82WD9K9AQ"
   setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8"
   hspec
+#if defined(dbPostgres)
+    . beforeAll_ (dropDatabaseAndUser testDBConnectInfo >> createDBAndUserIfNotExists testDBConnectInfo)
+    . afterAll_ (dropDatabaseAndUser testDBConnectInfo)
+#endif
     . before_ (createDirectoryIfMissing False "tests/tmp")
     . after_ (eventuallyRemove "tests/tmp" 3)
     $ do
+-- TODO [postgres] schema dump for postgres
+#if !defined(dbPostgres)
       describe "Agent SQLite schema dump" schemaDumpTest
+#endif
       describe "Core tests" $ do
         describe "Batching tests" batchingTests
         describe "Encoding tests" encodingTests
diff --git a/tests/Util.hs b/tests/Util.hs
index 6ad6d054f..0ad371b69 100644
--- a/tests/Util.hs
+++ b/tests/Util.hs
@@ -1,9 +1,10 @@
 module Util where

-import Control.Monad (replicateM)
+import Control.Monad (replicateM, when)
 import Data.Either (partitionEithers)
 import Data.List (tails)
 import GHC.Conc (getNumCapabilities, getNumProcessors, setNumCapabilities)
+import System.Directory (doesFileExist, removeFile)
 import Test.Hspec
 import UnliftIO
@@ -26,3 +27,8 @@ inParrallel n action = do
 combinations :: Int -> [a] -> [[a]]
 combinations 0 _ = [[]]
 combinations k xs = [y : ys | y : xs' <- tails xs, ys <- combinations (k - 1) xs']
+
+removeFileIfExists :: FilePath -> IO ()
+removeFileIfExists filePath = do
+  fileExists <- doesFileExist filePath
+  when fileExists $ removeFile filePath
diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs
index 6d6446959..f7e880083 100644
--- a/tests/XFTPAgent.hs
+++ b/tests/XFTPAgent.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE LambdaCase #-}
@@ -46,40 +47,49 @@ import UnliftIO
 import UnliftIO.Concurrent
 import XFTPCLI
 import XFTPClient
+#if defined(dbPostgres)
+import Fixtures
+import Simplex.Messaging.Agent.Store.Postgres.Util (dropAllSchemasExceptSystem)
+#endif

 xftpAgentTests :: Spec
-xftpAgentTests = around_ testBracket . describe "agent XFTP API" $ do
-  it "should send and receive file" $ withXFTPServer testXFTPAgentSendReceive
-  -- uncomment CPP option slow_servers and run hpack to run this test
-  xit "should send and receive file with slow server responses" $
-    withXFTPServerCfg testXFTPServerConfig {responseDelay = 500000} $
-      \_ -> testXFTPAgentSendReceive
-  it "should send and receive with encrypted local files" testXFTPAgentSendReceiveEncrypted
-  it "should send and receive large file with a redirect" testXFTPAgentSendReceiveRedirect
-  it "should send and receive small file without a redirect" testXFTPAgentSendReceiveNoRedirect
-  describe "sending and receiving with version negotiation" testXFTPAgentSendReceiveMatrix
-  it "should resume receiving file after restart" testXFTPAgentReceiveRestore
-  it "should cleanup rcv tmp path after permanent error" testXFTPAgentReceiveCleanup
-  it "should resume sending file after restart" testXFTPAgentSendRestore
-  xit'' "should cleanup snd prefix path after permanent error" testXFTPAgentSendCleanup
-  it "should delete sent file on server" testXFTPAgentDelete
-  it "should resume deleting file after restart" testXFTPAgentDeleteRestore
-  -- TODO when server is fixed to correctly send AUTH error, this test has to be modified to expect AUTH error
-  it "if file is deleted on server, should limit retries and continue receiving next file" testXFTPAgentDeleteOnServer
-  it "if file is expired on server, should report error and continue receiving next file" testXFTPAgentExpiredOnServer
-  it "should request additional recipient IDs when number of recipients exceeds maximum per request" testXFTPAgentRequestAdditionalRecipientIDs
-  describe "XFTP server test via agent API" $ do
-    it "should pass without basic auth" $ testXFTPServerTest Nothing (noAuthSrv testXFTPServer2) `shouldReturn` Nothing
-    let srv1 = testXFTPServer2 {keyHash = "1234"}
-    it "should fail with incorrect fingerprint" $ do
-      testXFTPServerTest Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) NETWORK)
-    describe "server with password" $ do
-      let auth = Just "abcd"
-          srv = ProtoServerWithAuth testXFTPServer2
-          authErr = Just (ProtocolTestFailure TSCreateFile $ XFTP (B.unpack $ strEncode testXFTPServer2) AUTH)
-      it "should pass with correct password" $ testXFTPServerTest auth (srv auth) `shouldReturn` Nothing
-      it "should fail without password" $ testXFTPServerTest auth (srv Nothing) `shouldReturn` authErr
-      it "should fail with incorrect password" $ testXFTPServerTest auth (srv $ Just "wrong") `shouldReturn` authErr
+xftpAgentTests =
+  around_ testBracket
+#if defined(dbPostgres)
+    . after_ (dropAllSchemasExceptSystem testDBConnectInfo)
+#endif
+    . describe "agent XFTP API" $ do
+      it "should send and receive file" $ withXFTPServer testXFTPAgentSendReceive
+      -- uncomment CPP option slow_servers and run hpack to run this test
+      xit "should send and receive file with slow server responses" $
+        withXFTPServerCfg testXFTPServerConfig {responseDelay = 500000} $
+          \_ -> testXFTPAgentSendReceive
+      it "should send and receive with encrypted local files" testXFTPAgentSendReceiveEncrypted
+      it "should send and receive large file with a redirect" testXFTPAgentSendReceiveRedirect
+      it "should send and receive small file without a redirect" testXFTPAgentSendReceiveNoRedirect
+      describe "sending and receiving with version negotiation" testXFTPAgentSendReceiveMatrix
+      it "should resume receiving file after restart" testXFTPAgentReceiveRestore
+      it "should cleanup rcv tmp path after permanent error" testXFTPAgentReceiveCleanup
+      it "should resume sending file after restart" testXFTPAgentSendRestore
+      xit'' "should cleanup snd prefix path after permanent error" testXFTPAgentSendCleanup
+      it "should delete sent file on server" testXFTPAgentDelete
+      it "should resume deleting file after restart" testXFTPAgentDeleteRestore
+      -- TODO when server is fixed to correctly send AUTH error, this test has to be modified to expect AUTH error
+      it "if file is deleted on server, should limit retries and continue receiving next file" testXFTPAgentDeleteOnServer
+      it "if file is expired on server, should report error and continue receiving next file" testXFTPAgentExpiredOnServer
+      it "should request additional recipient IDs when number of recipients exceeds maximum per request" testXFTPAgentRequestAdditionalRecipientIDs
+      describe "XFTP server test via agent API" $ do
+        it "should pass without basic auth" $ testXFTPServerTest Nothing (noAuthSrv testXFTPServer2) `shouldReturn` Nothing
+        let srv1 = testXFTPServer2 {keyHash = "1234"}
+        it "should fail with incorrect fingerprint" $ do
+          testXFTPServerTest Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) NETWORK)
+        describe "server with password" $ do
+          let auth = Just "abcd"
+              srv = ProtoServerWithAuth testXFTPServer2
+              authErr = Just (ProtocolTestFailure TSCreateFile $ XFTP (B.unpack $ strEncode testXFTPServer2) AUTH)
+          it "should pass with correct password" $ testXFTPServerTest auth (srv auth) `shouldReturn` Nothing
+          it "should fail without password" $ testXFTPServerTest auth (srv Nothing) `shouldReturn` authErr
+          it "should fail with incorrect password" $ testXFTPServerTest auth (srv $ Just "wrong") `shouldReturn` authErr

 rfProgress :: forall m. (HasCallStack, MonadIO m, MonadFail m) => AgentClient -> Int64 -> m ()
 rfProgress c expected = loop 0