diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 18e9d0465..f155ce77b 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -153,7 +153,7 @@ import Data.Bifunctor (bimap, first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition -import Data.Either (isRight, partitionEithers, rights) +import Data.Either (fromRight, isRight, partitionEithers, rights) import Data.Foldable (foldl', toList) import Data.Functor (($>)) import Data.Functor.Identity @@ -221,7 +221,6 @@ import Simplex.Messaging.Protocol SMPMsgMeta, SParty (..), SProtocolType (..), - ServiceSub (..), ServiceSubResult, SndPublicAuthKey, SubscriptionMode (..), @@ -1451,7 +1450,23 @@ subscribeAllConnections' c onlyNeeded activeUserId_ = handleErr $ do let userSrvs' = case activeUserId_ of Just activeUserId -> sortOn (\(uId, _) -> if uId == activeUserId then 0 else 1 :: Int) userSrvs Nothing -> userSrvs - rs <- lift $ mapConcurrently (subscribeUserServer maxPending currPending) userSrvs' + useServices <- readTVarIO $ useClientServices c + -- These options are possible below: + -- 1) services fully disabled: + -- No service subscriptions will be attempted, and existing services and association will remain in in the database, + -- but they will be ignored because of hasService parameter set to False. + -- This approach preserves performance for all clients that do not use services. + -- 2) at least one user ID has services enabled: + -- Service will be loaded for all user/server combinations: + -- a) service is enabled for user ID and service record exists: subscription will be attempted, + -- b) service is disabled and record exists: service record and all associations will be removed, + -- c) service is disabled or no record: no subscription attempt. + -- On successful service subscription, only unassociated queues will be subscribed. + userSrvs'' <- + if any id useServices + then lift $ mapConcurrently (subscribeService useServices) userSrvs' + else pure $ map (,False) userSrvs' + rs <- lift $ mapConcurrently (subscribeUserServer maxPending currPending) userSrvs'' let (errs, oks) = partitionEithers rs logInfo $ "subscribed " <> tshow (sum oks) <> " queues" forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map ("",) @@ -1460,21 +1475,31 @@ subscribeAllConnections' c onlyNeeded activeUserId_ = handleErr $ do resumeAllCommands c where handleErr = (`catchAllErrors` \e -> notifySub' c "" (ERR e) >> throwE e) - subscribeUserServer :: Int -> TVar Int -> (UserId, SMPServer) -> AM' (Either AgentErrorType Int) - subscribeUserServer maxPending currPending (userId, srv) = do + subscribeService :: Map UserId Bool -> (UserId, SMPServer) -> AM' ((UserId, SMPServer), ServiceAssoc) + subscribeService useServices us@(userId, srv) = fmap ((us,) . fromRight False) $ tryAllErrors' $ do + withStore' c (\db -> getSubscriptionService db userId srv) >>= \case + Just serviceSub -> case M.lookup userId useServices of + Just True -> tryAllErrors (subscribeClientService c True userId srv serviceSub) >>= \case + Left e | clientServiceError e -> unassocQueues $> False + _ -> pure True + _ -> unassocQueues $> False + where + unassocQueues = withStore' c $ \db -> unassocUserServerRcvQueueSubs db userId srv + _ -> pure False + subscribeUserServer :: Int -> TVar Int -> ((UserId, SMPServer), ServiceAssoc) -> AM' (Either AgentErrorType Int) + subscribeUserServer maxPending currPending ((userId, srv), hasService) = do atomically $ whenM ((maxPending <=) <$> readTVar currPending) retry tryAllErrors' $ do qs <- withStore' c $ \db -> do - qs <- getUserServerRcvQueueSubs db userId srv onlyNeeded - atomically $ modifyTVar' currPending (+ length qs) -- update before leaving transaction + qs <- getUserServerRcvQueueSubs db userId srv onlyNeeded hasService + unless (null qs) $ atomically $ modifyTVar' currPending (+ length qs) -- update before leaving transaction pure qs let n = length qs - lift $ subscribe qs `E.finally` atomically (modifyTVar' currPending $ subtract n) + unless (null qs) $ lift $ subscribe qs `E.finally` atomically (modifyTVar' currPending $ subtract n) pure n where subscribe qs = do rs <- subscribeUserServerQueues c userId srv qs - -- TODO [certs rcv] storeClientServiceAssocs store associations of queues with client service ID ns <- asks ntfSupervisor whenM (liftIO $ hasInstantNotifications ns) $ sendNtfCreate ns rs sendNtfCreate :: NtfSupervisor -> [(RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId))] -> AM' () @@ -1522,7 +1547,7 @@ subscribeClientServices' c userId = useService = liftIO $ (Just True ==) <$> TM.lookupIO userId (useClientServices c) subscribe = do srvs <- withStore' c (`getClientServiceServers` userId) - lift $ M.fromList <$> mapConcurrently (\(srv, ServiceSub _ n idsHash) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c False userId srv n idsHash) srvs + lift $ M.fromList <$> mapConcurrently (\(srv, serviceSub) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c False userId srv serviceSub) srvs -- requesting messages sequentially, to reduce memory usage getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta))) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 77d73027d..7acfb0b49 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -120,6 +120,7 @@ module Simplex.Messaging.Agent.Client getAgentSubscriptions, slowNetworkConfig, protocolClientError, + clientServiceError, Worker (..), SessionVar (..), SubscriptionsInfo (..), @@ -303,7 +304,7 @@ import Simplex.Messaging.Session import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (SMPServiceRole (..), SMPVersion, ServiceCredentials (..), SessionId, THClientService' (..), THandleParams (sessionId, thVersion), TransportError (..), TransportPeer (..), sndAuthKeySMPVersion, shortLinksSMPVersion, newNtfCredsSMPVersion) +import Simplex.Messaging.Transport (HandshakeError (..), SMPServiceRole (..), SMPVersion, ServiceCredentials (..), SessionId, THClientService' (..), THandleAuth (..), THandleParams (sessionId, thAuth, thVersion), TransportError (..), TransportPeer (..), sndAuthKeySMPVersion, shortLinksSMPVersion, newNtfCredsSMPVersion) import Simplex.Messaging.Transport.Client (TransportHost (..)) import Simplex.Messaging.Transport.Credentials import Simplex.Messaging.Util @@ -619,7 +620,7 @@ getServiceCredentials c userId srv = let g = agentDRG c ((C.KeyHash kh, serviceCreds), serviceId_) <- withStore' c $ \db -> - getClientService db userId srv >>= \case + getClientServiceCredentials db userId srv >>= \case Just service -> pure service Nothing -> do cred <- genCredentials g Nothing (25, 24 * 999999) "simplex" @@ -747,15 +748,13 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm smp <- liftError (protocolClientError SMP $ B.unpack $ strEncode srv) $ do ts <- readTVarIO proxySessTs ExceptT $ getProtocolClient g nm tSess cfg' presetDomains (Just msgQ) ts $ smpClientDisconnected c tSess env v' prs - -- TODO [certs rcv] add service to SS, possibly combine with SS.setSessionId atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c updateClientService service smp pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs} - -- TODO [certs rcv] this should differentiate between service ID just set and service ID changed, and in the latter case disassociate the queues updateClientService service smp = case (service, smpClientService smp) of - (Just (_, serviceId_), Just THClientService {serviceId}) - | serviceId_ /= Just serviceId -> withStore' c $ \db -> setClientServiceId db userId srv serviceId - | otherwise -> pure () + (Just (_, serviceId_), Just THClientService {serviceId}) -> withStore' c $ \db -> do + setClientServiceId db userId srv serviceId + forM_ serviceId_ $ \sId -> when (sId /= serviceId) $ removeRcvServiceAssocs db userId srv (Just _, Nothing) -> withStore' c $ \db -> deleteClientService db userId srv -- e.g., server version downgrade (Nothing, Just _) -> logError "server returned serviceId without service credentials in request" (Nothing, Nothing) -> pure () @@ -1258,6 +1257,14 @@ protocolClientError protocolError_ host = \case PCEServiceUnavailable {} -> BROKER host NO_SERVICE PCEIOError e -> BROKER host $ NETWORK $ NEConnectError $ E.displayException e +-- it is consistent with smpClientServiceError +clientServiceError :: AgentErrorType -> Bool +clientServiceError = \case + BROKER _ NO_SERVICE -> True + SMP _ SMP.SERVICE -> True + SMP _ (SMP.PROXY (SMP.BROKER NO_SERVICE)) -> True -- for completeness, it cannot happen. + _ -> False + data ProtocolTestStep = TSConnect | TSDisconnect @@ -1446,8 +1453,8 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl withClient c nm tSess $ \(SMPConnectedClient smp _) -> do (ntfKeys, ntfCreds) <- liftIO $ mkNtfCreds a g smp (thParams smp,ntfKeys,) <$> createSMPQueue smp nm nonce_ rKeys dhKey auth subMode (queueReqData cqrd) ntfCreds - -- TODO [certs rcv] validate that serviceId is the same as in the client session, fail otherwise - -- possibly, it should allow returning Nothing - it would indicate incorrect old version + let sessServiceId = (\THClientService {serviceId = sId} -> sId) <$> (clientService =<< thAuth thParams') + when (isJust serviceId && serviceId /= sessServiceId) $ logError "incorrect service ID in NEW response" liftIO . logServer "<--" c srv NoEntity $ B.unwords ["IDS", logSecret rcvId, logSecret sndId] shortLink <- mkShortLinkCreds thParams' qik let rq = @@ -1463,7 +1470,7 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl sndId, queueMode, shortLink, - rcvServiceAssoc = isJust serviceId, + rcvServiceAssoc = isJust serviceId && serviceId == sessServiceId, status = New, enableNtfs, clientNoticeId = Nothing, @@ -1559,6 +1566,8 @@ temporaryAgentError :: AgentErrorType -> Bool temporaryAgentError = \case BROKER _ e -> tempBrokerError e SMP _ (SMP.PROXY (SMP.BROKER e)) -> tempBrokerError e + SMP _ (SMP.STORE _) -> True + NTF _ (SMP.STORE _) -> True XFTP _ XFTP.TIMEOUT -> True PROXY _ _ (ProxyProtocolError (SMP.PROXY (SMP.BROKER e))) -> tempBrokerError e PROXY _ _ (ProxyProtocolError (SMP.PROXY SMP.NO_SESSION)) -> True @@ -1569,6 +1578,7 @@ temporaryAgentError = \case tempBrokerError = \case NETWORK _ -> True TIMEOUT -> True + TRANSPORT (TEHandshake BAD_SERVICE) -> True -- this error is considered temporary because it is DB error _ -> False temporaryOrHostError :: AgentErrorType -> Bool @@ -1715,11 +1725,16 @@ processClientNotices c@AgentClient {presetServers} tSess notices = do notifySub' c "" $ ERR e resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub -> AM ServiceSubResult -resubscribeClientService c tSess serviceSub = - withServiceClient c tSess $ \smp _ -> subscribeClientService_ c True tSess smp serviceSub - -subscribeClientService :: AgentClient -> Bool -> UserId -> SMPServer -> Int64 -> IdsHash -> AM ServiceSubResult -subscribeClientService c withEvent userId srv n idsHash = +resubscribeClientService c tSess@(userId, srv, _) serviceSub = + withServiceClient c tSess (\smp _ -> subscribeClientService_ c True tSess smp serviceSub) `catchE` \e -> do + when (clientServiceError e) $ do + qs <- withStore' c $ \db -> unassocUserServerRcvQueueSubs db userId srv + void $ lift $ subscribeUserServerQueues c userId srv qs + throwE e + +-- TODO [certs rcv] update service in the database if it has different ID and re-associate queues, and send event +subscribeClientService :: AgentClient -> Bool -> UserId -> SMPServer -> ServiceSub -> AM ServiceSubResult +subscribeClientService c withEvent userId srv (ServiceSub _ n idsHash) = withServiceClient c tSess $ \smp smpServiceId -> do let serviceSub = ServiceSub smpServiceId n idsHash atomically $ SS.setPendingServiceSub tSess serviceSub $ currentSubs c @@ -1728,14 +1743,15 @@ subscribeClientService c withEvent userId srv n idsHash = tSess = (userId, srv, Nothing) withServiceClient :: AgentClient -> SMPTransportSession -> (SMPClient -> ServiceId -> ExceptT SMPClientError IO a) -> AM a -withServiceClient c tSess action = +withServiceClient c tSess subscribe = withLogClient c NRMBackground tSess B.empty "SUBS" $ \(SMPConnectedClient smp _) -> case (\THClientService {serviceId} -> serviceId) <$> smpClientService smp of - Just smpServiceId -> action smp smpServiceId + Just smpServiceId -> subscribe smp smpServiceId Nothing -> throwE PCEServiceUnavailable +-- TODO [certs rcv] send subscription error event? subscribeClientService_ :: AgentClient -> Bool -> SMPTransportSession -> SMPClient -> ServiceSub -> ExceptT SMPClientError IO ServiceSubResult -subscribeClientService_ c withEvent tSess@(_, srv, _) smp expected@(ServiceSub _ n idsHash) = do +subscribeClientService_ c withEvent tSess@(userId, srv, _) smp expected@(ServiceSub _ n idsHash) = do subscribed <- subscribeService smp SMP.SRecipientService n idsHash let sessId = sessionId $ thParams smp r = serviceSubResult expected subscribed diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index a732d28d4..0d0b2af70 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -37,7 +37,9 @@ module Simplex.Messaging.Agent.Store.AgentStore -- * Client services createClientService, - getClientService, + getClientServiceCredentials, + getSubscriptionServices, + getSubscriptionService, getClientServiceServers, setClientServiceId, deleteClientService, @@ -52,8 +54,10 @@ module Simplex.Messaging.Agent.Store.AgentStore updateClientNotices, getSubscriptionServers, getUserServerRcvQueueSubs, + unassocUserServerRcvQueueSubs, unsetQueuesToSubscribe, setRcvServiceAssocs, + removeRcvServiceAssocs, getConnIds, getConn, getDeletedConn, @@ -419,8 +423,8 @@ createClientService db userId srv (kh, (cert, pk)) = do |] (userId, host srv, port srv, serverKeyHash_, kh, cert, pk) -getClientService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId)) -getClientService db userId srv = +getClientServiceCredentials :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId)) +getClientServiceCredentials db userId srv = maybeFirstRow toService $ DB.query db @@ -435,21 +439,41 @@ getClientService db userId srv = where toService (kh, cert, pk, serviceId_) = ((kh, (cert, pk)), serviceId_) -getClientServiceServers :: DB.Connection -> UserId -> IO [(SMPServer, ServiceSub)] -getClientServiceServers db userId = - map toServer - <$> DB.query +getSubscriptionServices :: DB.Connection -> IO [(UserId, (SMPServer, ServiceSub))] +getSubscriptionServices db = map toUserService <$> DB.query_ db clientServiceQuery + where + toUserService (Only userId :. serviceRow) = (userId, toServerService serviceRow) + +getSubscriptionService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ServiceSub) +getSubscriptionService db userId (SMPServer h p kh) = + maybeFirstRow toService $ + DB.query db [sql| - SELECT c.host, c.port, s.key_hash, c.service_id, c.service_queue_count, c.service_queue_ids_hash + SELECT c.service_id, c.service_queue_count, c.service_queue_ids_hash FROM client_services c JOIN servers s ON s.host = c.host AND s.port = c.port - WHERE c.user_id = ? + WHERE c.user_id = ? AND c.host = ? AND c.port = ? AND COALESCE(c.server_key_hash, s.key_hash) = ? |] - (Only userId) + (userId, h, p, kh) where - toServer (host, port, kh, serviceId, n, Binary idsHash) = - (SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash)) + toService (serviceId, qCnt, idsHash) = ServiceSub serviceId qCnt idsHash + +getClientServiceServers :: DB.Connection -> UserId -> IO [(SMPServer, ServiceSub)] +getClientServiceServers db userId = + map toServerService <$> DB.query db (clientServiceQuery <> " WHERE c.user_id = ?") (Only userId) + +clientServiceQuery :: Query +clientServiceQuery = + [sql| + SELECT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash), c.service_id, c.service_queue_count, c.service_queue_ids_hash + FROM client_services c + JOIN servers s ON s.host = c.host AND s.port = c.port + |] + +toServerService :: (NonEmpty TransportHost, ServiceName, C.KeyHash, ServiceId, Int64, Binary ByteString) -> (ProtocolServer 'PSMP, ServiceSub) +toServerService (host, port, kh, serviceId, n, Binary idsHash) = + (SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash)) setClientServiceId :: DB.Connection -> UserId -> SMPServer -> ServiceId -> IO () setClientServiceId db userId srv serviceId = @@ -473,7 +497,9 @@ deleteClientService db userId srv = (userId, host srv, port srv) deleteClientServices :: DB.Connection -> UserId -> IO () -deleteClientServices db userId = DB.execute db "DELETE FROM client_services WHERE user_id = ?" (Only userId) +deleteClientServices db userId = do + DB.execute db "DELETE FROM client_services WHERE user_id = ?" (Only userId) + removeUserRcvServiceAssocs db userId createConn_ :: TVar ChaChaDRG -> @@ -2236,17 +2262,36 @@ getSubscriptionServers db onlyNeeded = toUserServer :: (UserId, NonEmpty TransportHost, ServiceName, C.KeyHash) -> (UserId, SMPServer) toUserServer (userId, host, port, keyHash) = (userId, SMPServer host port keyHash) -getUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> Bool -> IO [RcvQueueSub] -getUserServerRcvQueueSubs db userId srv onlyNeeded = +-- TODO [certs rcv] check index for getting queues with service present +getUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> Bool -> ServiceAssoc -> IO [RcvQueueSub] +getUserServerRcvQueueSubs db userId srv onlyNeeded hasService = map toRcvQueueSub <$> DB.query db - (rcvQueueSubQuery <> toSubscribe <> " c.deleted = 0 AND q.deleted = 0 AND c.user_id = ? AND q.host = ? AND q.port = ?") + (rcvQueueSubQuery <> toSubscribe <> " c.deleted = 0 AND q.deleted = 0 AND c.user_id = ? AND q.host = ? AND q.port = ?" <> serviceCond) (userId, host srv, port srv) where toSubscribe | onlyNeeded = " WHERE q.to_subscribe = 1 AND " | otherwise = " WHERE " + serviceCond + | hasService = " AND q.rcv_service_assoc = 0" + | otherwise = "" + +unassocUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> IO [RcvQueueSub] +unassocUserServerRcvQueueSubs db userId (SMPServer h p kh) = + map toRcvQueueSub + <$> DB.query + db + (removeRcvAssocsQuery <> " " <> returningColums) + (h, p, userId, kh) + where + returningColums = + [sql| + RETURNING c.user_id, rcv_queues.conn_id, rcv_queues.host, rcv_queues.port, COALESCE(rcv_queues.server_key_hash, s.key_hash), + rcv_queues.rcv_id, rcv_queues.rcv_private_key, rcv_queues.status, c.enable_ntfs, rcv_queues.client_notice_id, + rcv_queues.rcv_queue_id, rcv_queues.rcv_primary, rcv_queues.replace_rcv_queue_id + |] unsetQueuesToSubscribe :: DB.Connection -> IO () unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = 0 WHERE to_subscribe = 1" @@ -2259,6 +2304,36 @@ setRcvServiceAssocs db rqs = DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = ?" $ map (Only . queueId) rqs #endif +removeRcvServiceAssocs :: DB.Connection -> UserId -> SMPServer -> IO () +removeRcvServiceAssocs db userId (SMPServer h p kh) = DB.execute db removeRcvAssocsQuery (h, p, userId, kh) + +removeRcvAssocsQuery :: Query +removeRcvAssocsQuery = + [sql| + UPDATE rcv_queues + SET rcv_service_assoc = 0 + FROM connections c, servers s + WHERE rcv_queues.host = ? + AND rcv_queues.port = ? + AND c.conn_id = rcv_queues.conn_id + AND c.user_id = ? + AND s.host = rcv_queues.host + AND s.port = rcv_queues.port + AND COALESCE(rcv_queues.server_key_hash, s.key_hash) = ? + |] + +removeUserRcvServiceAssocs :: DB.Connection -> UserId -> IO () +removeUserRcvServiceAssocs db userId = + DB.execute + db + [sql| + UPDATE rcv_queues + SET rcv_service_assoc = 0 + FROM connections c + WHERE c.conn_id = rcv_queues.conn_id AND c.user_id = ? + |] + (Only userId) + -- * getConn helpers getConnIds :: DB.Connection -> IO [ConnId] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index d5b8f8290..a670dd3e2 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -67,10 +67,10 @@ import Simplex.Messaging.Agent.Store.Migrations (DBMigrate (..), sharedMigrateSc import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Agent.Store.SQLite.Common import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB -import Simplex.Messaging.Agent.Store.SQLite.Util (SQLiteFunc, createStaticFunction, mkSQLiteFunc) +import Simplex.Messaging.Agent.Store.SQLite.Util import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationError (..)) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Util (ifM, safeDecodeUtf8) +import Simplex.Messaging.Util (ifM, packZipWith, safeDecodeUtf8) import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) import System.FilePath (takeDirectory, takeFileName, ()) @@ -116,9 +116,7 @@ connectDB path functions key track = do -- _printPragmas db path pure db where - functions' = SQLiteFuncDef "simplex_xor_md5_combine" 2 True sqliteXorMd5CombinePtr : functions prepare db = do - let db' = SQL.connectionHandle $ DB.conn db unless (BA.null key) . SQLite3.exec db' $ "PRAGMA key = " <> keyString key <> ";" SQLite3.exec db' . fromQuery $ [sql| @@ -128,9 +126,14 @@ connectDB path functions key track = do PRAGMA secure_delete = ON; PRAGMA auto_vacuum = FULL; |] - forM_ functions' $ \SQLiteFuncDef {funcName, argCount, deterministic, funcPtr} -> - createStaticFunction db' funcName argCount deterministic funcPtr - >>= either (throwIO . userError . show) pure + mapM_ addFunction functions' + where + db' = SQL.connectionHandle $ DB.conn db + functions' = SQLiteFuncDef "simplex_xor_md5_combine" 2 (SQLiteFuncPtr True sqliteXorMd5CombinePtr) : functions + addFunction SQLiteFuncDef {funcName, argCount, funcPtrs} = + either (throwIO . userError . show) pure =<< case funcPtrs of + SQLiteFuncPtr isDet funcPtr -> createStaticFunction db' funcName argCount isDet funcPtr + SQLiteAggrPtrs stepPtr finalPtr -> createStaticAggregate db' funcName argCount stepPtr finalPtr foreign export ccall "simplex_xor_md5_combine" sqliteXorMd5Combine :: SQLiteFunc @@ -143,7 +146,8 @@ sqliteXorMd5Combine = mkSQLiteFunc $ \cxt args -> do SQLite3.funcResultBlob cxt $ xorMd5Combine idsHash rId xorMd5Combine :: ByteString -> ByteString -> ByteString -xorMd5Combine idsHash rId = B.packZipWith xor idsHash $ C.md5Hash rId +xorMd5Combine idsHash rId = packZipWith xor idsHash $ C.md5Hash rId +{-# INLINE xorMd5Combine #-} closeDBStore :: DBStore -> IO () closeDBStore st@DBStore {dbClosed} = diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index 0634360a2..448c885f2 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -7,6 +7,7 @@ module Simplex.Messaging.Agent.Store.SQLite.Common ( DBStore (..), DBOpts (..), SQLiteFuncDef (..), + SQLiteFuncPtrs (..), withConnection, withConnection', withTransaction, @@ -55,14 +56,18 @@ data DBOpts = DBOpts track :: DB.TrackQueries } --- e.g. `SQLiteFuncDef "name" 2 True f` +-- e.g. `SQLiteFuncDef "func_name" 2 (SQLiteFuncPtr True func)` +-- or `SQLiteFuncDef "aggr_name" 3 (SQLiteAggrPtrs step final)` data SQLiteFuncDef = SQLiteFuncDef { funcName :: ByteString, argCount :: CArgCount, - deterministic :: Bool, - funcPtr :: FunPtr SQLiteFunc + funcPtrs :: SQLiteFuncPtrs } +data SQLiteFuncPtrs + = SQLiteFuncPtr {deterministic :: Bool, funcPtr :: FunPtr SQLiteFunc} + | SQLiteAggrPtrs {stepPtr :: FunPtr SQLiteFunc, finalPtr :: FunPtr SQLiteFuncFinal} + withConnectionPriority :: DBStore -> Bool -> (DB.Connection -> IO a) -> IO a withConnectionPriority DBStore {dbSem, dbConnection} priority action | priority = E.bracket_ signal release $ withMVar dbConnection action diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs index a3c3b94ac..2cbd7ecff 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Util.hs @@ -3,16 +3,20 @@ module Simplex.Messaging.Agent.Store.SQLite.Util where import Control.Exception (SomeException, catch, mask_) import Data.ByteString (ByteString) import qualified Data.ByteString as B +import Data.IORef import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..)) import Database.SQLite3.Bindings import Foreign.C.String import Foreign.Ptr import Foreign.StablePtr +import Foreign.Storable data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal) type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO () +type SQLiteFuncFinal = Ptr CContext -> IO () + mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals) {-# INLINE mkSQLiteFunc #-} @@ -25,6 +29,50 @@ createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do B.useAsCString name $ \namePtr -> toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr +mkSQLiteAggStep :: a -> (FuncContext -> FuncArgs -> a -> IO a) -> SQLiteFunc +mkSQLiteAggStep initSt xStep cxt nArgs cvals = catchAsResultError cxt $ do + -- we store the aggregate state in the buffer returned by + -- c_sqlite3_aggregate_context as a StablePtr pointing to an IORef that + -- contains the actual aggregate state + aggCtx <- getAggregateContext cxt + aggStPtr <- peek aggCtx + aggStRef <- + if castStablePtrToPtr aggStPtr /= nullPtr + then deRefStablePtr aggStPtr + else do + aggStRef <- newIORef initSt + aggStPtr' <- newStablePtr aggStRef + poke aggCtx aggStPtr' + return aggStRef + aggSt <- readIORef aggStRef + aggSt' <- xStep (FuncContext cxt) (FuncArgs nArgs cvals) aggSt + writeIORef aggStRef aggSt' + +mkSQLiteAggFinal :: a -> (FuncContext -> a -> IO ()) -> SQLiteFuncFinal +mkSQLiteAggFinal initSt xFinal cxt = do + aggCtx <- getAggregateContext cxt + aggStPtr <- peek aggCtx + if castStablePtrToPtr aggStPtr == nullPtr + then catchAsResultError cxt $ xFinal (FuncContext cxt) initSt + else do + catchAsResultError cxt $ do + aggStRef <- deRefStablePtr aggStPtr + aggSt <- readIORef aggStRef + xFinal (FuncContext cxt) aggSt + freeStablePtr aggStPtr + +getAggregateContext :: Ptr CContext -> IO (Ptr a) +getAggregateContext cxt = c_sqlite3_aggregate_context cxt stPtrSize + where + stPtrSize = fromIntegral $ sizeOf (undefined :: StablePtr ()) + +-- Based on createAggregate from Database.SQLite3.Direct, but uses static function pointers to avoid dynamic wrappers that trigger DCL. +createStaticAggregate :: Database -> ByteString -> CArgCount -> FunPtr SQLiteFunc -> FunPtr SQLiteFuncFinal -> IO (Either Error ()) +createStaticAggregate (Database db) name nArgs stepPtr finalPtr = mask_ $ do + u <- newStablePtr $ CFuncPtrs nullFunPtr stepPtr finalPtr + B.useAsCString name $ \namePtr -> + toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs 0 (castStablePtrToPtr u) nullFunPtr stepPtr finalPtr nullFunPtr + -- Convert a 'CError' to a 'Either Error', in the common case where -- SQLITE_OK signals success and anything else signals an error. -- diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 81e9820a2..ac2dc9a9d 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -778,10 +778,10 @@ temporaryClientError = \case _ -> False {-# INLINE temporaryClientError #-} +-- it is consistent with clientServiceError smpClientServiceError :: SMPClientError -> Bool smpClientServiceError = \case PCEServiceUnavailable -> True - PCETransportError (TEHandshake BAD_SERVICE) -> True -- TODO [certs rcv] this error may be temporary, so we should possibly resubscribe. PCEProtocolError SERVICE -> True PCEProtocolError (PROXY (BROKER NO_SERVICE)) -> True -- for completeness, it cannot happen. _ -> False diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index a5f94960e..6b232f12b 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -143,6 +143,7 @@ module Simplex.Messaging.Protocol IdsHash (..), ServiceSub (..), ServiceSubResult (..), + ServiceSubError (..), serviceSubResult, queueIdsHash, queueIdHash, diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index e9f37b1ae..83a911452 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE MonadComprehensions #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -15,6 +16,7 @@ import qualified Data.Aeson as J import Data.Bifunctor (first, second) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.ByteString.Internal (toForeignPtr, unsafeCreate) import qualified Data.ByteString.Lazy.Char8 as LB import Data.IORef import Data.Int (Int64) @@ -29,6 +31,9 @@ import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8With, encodeUtf8) import Data.Time (NominalDiffTime) import Data.Tuple (swap) +import Data.Word (Word8) +import Foreign.ForeignPtr (withForeignPtr) +import Foreign.Storable (peekByteOff, pokeByteOff) import GHC.Conc (labelThread, myThreadId, threadDelay) import UnliftIO hiding (atomicModifyIORef') import qualified UnliftIO.Exception as UE @@ -156,6 +161,27 @@ mapAccumLM_NonEmpty mapAccumLM_NonEmpty f s (x :| xs) = [(s2, x' :| xs') | (s1, x') <- f s x, (s2, xs') <- mapAccumLM_List f s1 xs] +-- | Optimized from bytestring package for GHC 8.10.7 compatibility +packZipWith :: (Word8 -> Word8 -> Word8) -> ByteString -> ByteString -> ByteString +packZipWith f s1 s2 = + unsafeCreate len $ \r -> + withForeignPtr fp1 $ \p1 -> + withForeignPtr fp2 $ \p2 -> zipWith_ p1 p2 r + where + zipWith_ p1 p2 r = go 0 + where + go :: Int -> IO () + go !n + | n >= len = pure () + | otherwise = do + x <- peekByteOff p1 (off1 + n) + y <- peekByteOff p2 (off2 + n) + pokeByteOff r n (f x y) + go (n + 1) + (fp1, off1, l1) = toForeignPtr s1 + (fp2, off2, l2) = toForeignPtr s2 + len = min l1 l2 + tryWriteTBQueue :: TBQueue a -> a -> STM Bool tryWriteTBQueue q a = do full <- isFullTBQueue q diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 2a62deb45..31967917a 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -66,7 +66,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Either (isRight) import Data.Int (Int64) -import Data.List (find, isSuffixOf, nub) +import Data.List (find, isPrefixOf, isSuffixOf, nub) import Data.List.NonEmpty (NonEmpty) import qualified Data.Map as M import Data.Maybe (isJust, isNothing) @@ -113,7 +113,7 @@ import Simplex.Messaging.Util (bshow, diffToMicroseconds) import Simplex.Messaging.Version (VersionRange (..)) import qualified Simplex.Messaging.Version as V import Simplex.Messaging.Version.Internal (Version (..)) -import System.Directory (copyFile, renameFile) +import System.Directory (copyFile, removeFile, renameFile) import Test.Hspec hiding (fit, it) import UnliftIO import Util @@ -124,12 +124,13 @@ import Fixtures #endif #if defined(dbServerPostgres) import qualified Database.PostgreSQL.Simple as PSQL -import Simplex.Messaging.Agent.Store (Connection' (..), StoredRcvQueue (..), SomeConn' (..)) -import Simplex.Messaging.Agent.Store.AgentStore (getConn) +import qualified Simplex.Messaging.Agent.Store.Postgres as Postgres +import qualified Simplex.Messaging.Agent.Store.Postgres.Common as Postgres import Simplex.Messaging.Server.MsgStore.Journal (JournalQueue) import Simplex.Messaging.Server.MsgStore.Postgres (PostgresQueue) import Simplex.Messaging.Server.MsgStore.Types (QSType (..)) import Simplex.Messaging.Server.QueueStore.Postgres +import Simplex.Messaging.Server.QueueStore.Postgres.Migrations import Simplex.Messaging.Server.QueueStore.Types (QueueStoreClass (..)) #endif @@ -478,6 +479,7 @@ functionalAPITests ps = do withSmpServer ps testTwoUsers describe "Client service certificates" $ do it "should connect, subscribe and reconnect as a service" $ testClientServiceConnection ps + it "should re-subscribe when service ID changed" $ testClientServiceIDChange ps describe "Connection switch" $ do describe "should switch delivery to the new queue" $ testServerMatrix2 ps testSwitchConnection @@ -3679,26 +3681,84 @@ testClientServiceConnection ps = do subscribeConnection user sId exchangeGreetingsMsgId 4 service uId user sId pure (conns, qIdHash) - withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + (uId', sId') <- withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do - [(_, Right (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash')))] <- M.toList <$> subscribeClientServices service 1 - ("", "", SERVICE_ALL _) <- nGet service - liftIO $ qIdHash' `shouldBe` qIdHash + subscribeAllConnections service False Nothing + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash')))) -> qIdHash' == qIdHash; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] subscribeConnection user sId exchangeGreetingsMsgId 6 service uId user sId ("", "", DOWN _ [_]) <- nGet user ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 1 qIdHash')) <- nGet service qIdHash' `shouldBe` qIdHash -- TODO [certs rcv] how to integrate service counts into stats - -- r <- nGet service -- TODO [certs rcv] some event when service disconnects with count - -- print r withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do ("", "", UP _ [_]) <- nGet user - ("", "", SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash''))) <- nGet service - ("", "", SERVICE_ALL _) <- nGet service - liftIO $ qIdHash'' `shouldBe` qIdHash - -- r <- nGet service -- TODO [certs rcv] some event when service reconnects with count + -- Nothing in ServiceSubResult confirms that both counts and IDs hash match + -- SERVICE_ALL may be deliverd before SERVICE_UP event in case there are no messages to deliver + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash'')))) -> qIdHash'' == qIdHash; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] exchangeGreetingsMsgId 8 service uId user sId + conns'@(uId', sId') <- makeConnection user service -- opposite direction + exchangeGreetings user sId' service uId' + pure conns' + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + subscribeAllConnections service False Nothing + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 2 _)))) -> True; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False + ] + -- TODO [certs rcv] test message delivery during subscription + subscribeAllConnections user False Nothing + ("", "", UP _ [_, _]) <- nGet user + exchangeGreetingsMsgId 4 user sId' service uId' + exchangeGreetingsMsgId 10 service uId user sId + +testClientServiceIDChange :: HasCallStack => (ASrvTransport, AStoreType) -> IO () +testClientServiceIDChange ps@(_, ASType qs _) = do + (sId, uId) <- withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + conns@(sId, uId) <- makeConnection service user + exchangeGreetings service uId user sId + pure conns + _ :: () <- case qs of + SQSPostgres -> do +#if defined(dbServerPostgres) + st <- either (error . show) pure =<< Postgres.createDBStore testStoreDBOpts serverMigrations (MigrationConfig MCError Nothing) + void $ Postgres.withTransaction st (`PSQL.execute_` "DELETE FROM services") +#else + pure () +#endif + SQSMemory -> do + s <- readFile testStoreLogFile + removeFile testStoreLogFile + writeFile testStoreLogFile $ unlines $ filter (not . ("NEW_SERVICE" `isPrefixOf`)) $ lines s + withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + subscribeAllConnections service False Nothing + liftIO $ getInAnyOrder service + [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult (Just (SMP.SSErrorQueueCount 1 0)) (SMP.ServiceSub _ 0 _)))) -> True; _ -> False, + \case ("", "", AEvt SAENone (SERVICE_ALL _)) -> True; _ -> False, + \case ("", "", AEvt SAENone (UP _ _)) -> True; _ -> False + ] + subscribeAllConnections user False Nothing + ("", "", UP _ [_]) <- nGet user + exchangeGreetingsMsgId 4 service uId user sId + -- disable service in the client + -- The test uses True for non-existing user to make sure it's removed for user 1, + -- because if no users use services, then it won't be checking them to optimize for most clients. + withAgentClientsServers2 (agentCfg, initAgentServers {useServices = M.fromList [(100, True)]}) (agentCfg, initAgentServers) $ \notService user -> do + withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + subscribeAllConnections notService False Nothing + ("", "", UP _ [_]) <- nGet notService + subscribeAllConnections user False Nothing + ("", "", UP _ [_]) <- nGet user + exchangeGreetingsMsgId 6 notService uId user sId getSMPAgentClient' :: Int -> AgentConfig -> InitialAgentServers -> String -> IO AgentClient getSMPAgentClient' clientId cfg' initServers dbPath = do