diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 18bc0afbb..18e9d0465 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -194,7 +194,7 @@ import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Agent.Store.Interface (closeDBStore, execSQL, getCurrentMigrations) import Simplex.Messaging.Agent.Store.Shared (UpMigration (..), upMigration) import qualified Simplex.Messaging.Agent.TSessionSubs as SS -import Simplex.Messaging.Client (NetworkRequestMode (..), SMPClientError, ServerTransmission (..), ServerTransmissionBatch, TransportSessionMode (..), nonBlockingWriteTBQueue, smpErrorClientNotice, temporaryClientError, unexpectedResponse) +import Simplex.Messaging.Client (NetworkRequestMode (..), ProtocolClientError (..), SMPClientError, ServerTransmission (..), ServerTransmissionBatch, TransportSessionMode (..), nonBlockingWriteTBQueue, smpErrorClientNotice, temporaryClientError, unexpectedResponse) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport (..), pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn) @@ -222,6 +222,7 @@ import Simplex.Messaging.Protocol SParty (..), SProtocolType (..), ServiceSub (..), + ServiceSubResult, SndPublicAuthKey, SubscriptionMode (..), UserProtocol, @@ -232,7 +233,7 @@ import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import Simplex.Messaging.SystemTime import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (SMPVersion) +import Simplex.Messaging.Transport (SMPVersion, THClientService' (..), THandleAuth (..), THandleParams (..)) import Simplex.Messaging.Util import Simplex.Messaging.Version import Simplex.RemoteControl.Client @@ -502,7 +503,7 @@ resubscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either Agen resubscribeConnections c = withAgentEnv c . resubscribeConnections' c {-# INLINE resubscribeConnections #-} -subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType ServiceSub)) +subscribeClientServices :: AgentClient -> UserId -> AE (Map SMPServer (Either AgentErrorType ServiceSubResult)) subscribeClientServices c = withAgentEnv c . subscribeClientServices' c {-# INLINE subscribeClientServices #-} @@ -1355,11 +1356,7 @@ toConnResult connId rs = case M.lookup connId rs of Just (Left e) -> throwE e _ -> throwE $ INTERNAL $ "no result for connection " <> B.unpack connId -type QCmdResult a = (QueueStatus, Either AgentErrorType a) - -type QDelResult = QCmdResult () - -type QSubResult = QCmdResult (Maybe SMP.ServiceId) +type QCmdResult = (QueueStatus, Either AgentErrorType ()) subscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType ())) subscribeConnections' _ [] = pure M.empty @@ -1367,16 +1364,15 @@ subscribeConnections' c connIds = subscribeConnections_ c . zip connIds =<< with subscribeConnections_ :: AgentClient -> [(ConnId, Either StoreError SomeConnSub)] -> AM (Map ConnId (Either AgentErrorType ())) subscribeConnections_ c conns = do - -- TODO [certs rcv] - it should exclude connections already associated, and then if some don't deliver any response they may be unassociated let (subRs, cs) = foldr partitionResultsConns ([], []) conns resumeDelivery cs resumeConnCmds c $ map fst cs + -- queue/service association is handled in the client rcvRs <- lift $ connResults <$> subscribeQueues c False (concatMap rcvQueues cs) - rcvRs' <- storeClientServiceAssocs rcvRs ns <- asks ntfSupervisor - lift $ whenM (liftIO $ hasInstantNotifications ns) . void . forkIO . void $ sendNtfCreate ns rcvRs' cs + lift $ whenM (liftIO $ hasInstantNotifications ns) . void . forkIO . void $ sendNtfCreate ns rcvRs cs -- union is left-biased - let rs = rcvRs' `M.union` subRs + let rs = rcvRs `M.union` subRs notifyResultError rs pure rs where @@ -1400,24 +1396,21 @@ subscribeConnections_ c conns = do _ -> Left $ INTERNAL "unexpected queue status" rcvQueues :: (ConnId, SomeConnSub) -> [RcvQueueSub] rcvQueues (_, SomeConn _ conn) = connRcvQueues conn - connResults :: [(RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId))] -> Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) + connResults :: [(RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId))] -> Map ConnId (Either AgentErrorType ()) connResults = M.map snd . foldl' addResult M.empty where -- collects results by connection ID - addResult :: Map ConnId QSubResult -> (RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId)) -> Map ConnId QSubResult - addResult rs (RcvQueueSub {connId, status}, r) = M.alter (combineRes (status, r)) connId rs + addResult :: Map ConnId QCmdResult -> (RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId)) -> Map ConnId QCmdResult + addResult rs (RcvQueueSub {connId, status}, r) = M.alter (combineRes (status, () <$ r)) connId rs -- combines two results for one connection, by using only Active queues (if there is at least one Active queue) - combineRes :: QSubResult -> Maybe QSubResult -> Maybe QSubResult + combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult combineRes r' (Just r) = Just $ if order r <= order r' then r else r' combineRes r' _ = Just r' - order :: QSubResult -> Int + order :: QCmdResult -> Int order (Active, Right _) = 1 order (Active, _) = 2 order (_, Right _) = 3 order _ = 4 - -- TODO [certs rcv] store associations of queues with client service ID - storeClientServiceAssocs :: Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) -> AM (Map ConnId (Either AgentErrorType ())) - storeClientServiceAssocs = pure . M.map (() <$) sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> [(ConnId, SomeConnSub)] -> AM' () sendNtfCreate ns rcvRs cs = do let oks = M.keysSet $ M.filter (either temporaryAgentError $ const True) rcvRs @@ -1522,14 +1515,14 @@ resubscribeConnections' c connIds = do rqs' -> anyM $ map (atomically . hasActiveSubscription c) rqs' -- TODO [certs rcv] compare hash. possibly, it should return both expected and returned counts -subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType ServiceSub)) +subscribeClientServices' :: AgentClient -> UserId -> AM (Map SMPServer (Either AgentErrorType ServiceSubResult)) subscribeClientServices' c userId = ifM useService subscribe $ throwError $ CMD PROHIBITED "no user service allowed" where useService = liftIO $ (Just True ==) <$> TM.lookupIO userId (useClientServices c) subscribe = do srvs <- withStore' c (`getClientServiceServers` userId) - lift $ M.fromList <$> mapConcurrently (\(srv, ServiceSub _ n idsHash) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c userId srv n idsHash) srvs + lift $ M.fromList <$> mapConcurrently (\(srv, ServiceSub _ n idsHash) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c False userId srv n idsHash) srvs -- requesting messages sequentially, to reduce memory usage getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta))) @@ -2383,13 +2376,13 @@ deleteConnQueues c nm waitDelivery ntf rqs = do connResults = M.map snd . foldl' addResult M.empty where -- collects results by connection ID - addResult :: Map ConnId QDelResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QDelResult + addResult :: Map ConnId QCmdResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QCmdResult addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs -- combines two results for one connection, by prioritizing errors in Active queues - combineRes :: QDelResult -> Maybe QDelResult -> Maybe QDelResult + combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult combineRes r' (Just r) = Just $ if order r <= order r' then r else r' combineRes r' _ = Just r' - order :: QDelResult -> Int + order :: QCmdResult -> Int order (Active, Left _) = 1 order (_, Left _) = 2 order _ = 3 @@ -2840,11 +2833,17 @@ data ACKd = ACKd | ACKPending -- It cannot be finally, as sometimes it needs to be ACK+DEL, -- and sometimes ACK has to be sent from the consumer. processSMPTransmissions :: AgentClient -> ServerTransmissionBatch SMPVersion ErrorType BrokerMsg -> AM' () -processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId, ts) = do +processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), THandleParams {thAuth, sessionId = sessId}, ts) = do upConnIds <- newTVarIO [] + serviceRQs <- newTVarIO ([] :: [RcvQueue]) forM_ ts $ \(entId, t) -> case t of STEvent msgOrErr - | entId == SMP.NoEntity -> pure () -- TODO [certs rcv] process SALL + | entId == SMP.NoEntity -> case msgOrErr of + Right msg -> case msg of + SMP.ALLS -> notifySub c $ SERVICE_ALL srv + SMP.ERR e -> notifyErr "" $ PCEProtocolError e + _ -> logError $ "unexpected event: " <> tshow msg + Left e -> notifyErr "" e | otherwise -> withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of Right msg -> runProcessSMP rq conn (toConnData conn) msg Left e -> lift $ do @@ -2853,11 +2852,10 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId STResponse (Cmd SRecipient cmd) respOrErr -> withRcvConn entId $ \rq conn -> case cmd of SMP.SUB -> case respOrErr of - Right SMP.OK -> liftIO $ processSubOk rq upConnIds - -- TODO [certs rcv] associate queue with the service - Right (SMP.SOK _serviceId_) -> liftIO $ processSubOk rq upConnIds + Right SMP.OK -> liftIO $ processSubOk rq upConnIds serviceRQs Nothing + Right (SMP.SOK serviceId_) -> liftIO $ processSubOk rq upConnIds serviceRQs serviceId_ Right msg@SMP.MSG {} -> do - liftIO $ processSubOk rq upConnIds -- the connection is UP even when processing this particular message fails + liftIO $ processSubOk rq upConnIds serviceRQs Nothing -- the connection is UP even when processing this particular message fails runProcessSMP rq conn (toConnData conn) msg Right r -> lift $ processSubErr rq $ unexpectedResponse r Left e -> lift $ unless (temporaryClientError e) $ processSubErr rq e -- timeout/network was already reported @@ -2873,6 +2871,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId unless (null connIds) $ do notify' "" $ UP srv connIds atomically $ incSMPServerStat' c userId srv connSubscribed $ length connIds + readTVarIO serviceRQs >>= processRcvServiceAssocs c where withRcvConn :: SMP.RecipientId -> (forall c. RcvQueue -> Connection c -> AM ()) -> AM' () withRcvConn rId a = do @@ -2882,11 +2881,13 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId tryAllErrors' (a rq conn) >>= \case Left e -> notify' connId (ERR e) Right () -> pure () - processSubOk :: RcvQueue -> TVar [ConnId] -> IO () - processSubOk rq@RcvQueue {connId} upConnIds = + processSubOk :: RcvQueue -> TVar [ConnId] -> TVar [RcvQueue] -> Maybe SMP.ServiceId -> IO () + processSubOk rq@RcvQueue {connId} upConnIds serviceRQs serviceId_ = atomically . whenM (isPendingSub rq) $ do SS.addActiveSub tSess sessId rq $ currentSubs c modifyTVar' upConnIds (connId :) + when (isJust serviceId_ && serviceId_ == clientServiceId_) $ modifyTVar' serviceRQs (rq :) + clientServiceId_ = (\THClientService {serviceId} -> serviceId) <$> (clientService =<< thAuth) processSubErr :: RcvQueue -> SMPClientError -> AM' () processSubErr rq@RcvQueue {connId} e = do atomically . whenM (isPendingSub rq) $ diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index e4324e088..77d73027d 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -50,6 +50,7 @@ module Simplex.Messaging.Agent.Client subscribeQueues, subscribeUserServerQueues, subscribeClientService, + processRcvServiceAssocs, processClientNotices, getQueueMessage, decryptSMPMessage, @@ -280,6 +281,7 @@ import Simplex.Messaging.Protocol SMPMsgMeta (..), SProtocolType (..), ServiceSub (..), + ServiceSubResult (..), SndPublicAuthKey, SubscriptionMode (..), NewNtfCreds (..), @@ -292,6 +294,7 @@ import Simplex.Messaging.Protocol XFTPServerWithAuth, pattern NoEntity, senderCanSecure, + serviceSubResult, ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Protocol.Types @@ -785,6 +788,7 @@ smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess serverDown (qs, conns, serviceSub_) = whenM (readTVarIO active) $ do notifySub c $ hostEvent' DISCONNECT client unless (null conns) $ notifySub c $ DOWN srv conns + mapM_ (notifySub c . SERVICE_DOWN srv) serviceSub_ unless (null qs && isNothing serviceSub_) $ do releaseGetLocksIO c qs mode <- getSessionModeIO c @@ -1514,7 +1518,7 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl newErr = throwE . BROKER (B.unpack $ strEncode srv) . UNEXPECTED . ("Create queue: " <>) processSubResults :: AgentClient -> SMPTransportSession -> SessionId -> Maybe ServiceId -> NonEmpty (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> STM ([RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)]) -processSubResults c tSess@(userId, srv, _) sessId smpServiceId rs = do +processSubResults c tSess@(userId, srv, _) sessId serviceId_ rs = do pending <- SS.getPendingSubs tSess $ currentSubs c let (failed, subscribed@(qs, sQs), notices, ignored) = foldr (partitionResults pending) (M.empty, ([], []), [], 0) rs unless (M.null failed) $ do @@ -1541,10 +1545,10 @@ processSubResults c tSess@(userId, srv, _) sessId smpServiceId rs = do | otherwise -> (failed', subscribed, notices, ignored) where failed' = M.insert rcvId e failed - Right serviceId_ + Right serviceId_' | rcvId `M.member` pendingSubs -> - let subscribed' = case (smpServiceId, serviceId_, pendingSS) of - (Just sId, Just sId', Just ServiceSub {serviceId}) | sId == sId' && sId == serviceId -> (qs, rq : sQs) + let subscribed' = case (serviceId_, serviceId_', pendingSS) of + (Just sId, Just sId', Just ServiceSub {smpServiceId}) | sId == sId' && sId == smpServiceId -> (qs, rq : sQs) _ -> (rq : qs, sQs) in (failed, subscribed', notices', ignored) | otherwise -> (failed, subscribed, notices', ignored + 1) @@ -1692,7 +1696,8 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c sessId = sessionId $ thParams smp smpServiceId = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp -processRcvServiceAssocs :: AgentClient -> [RcvQueueSub] -> AM' () +processRcvServiceAssocs :: SMPQueue q => AgentClient -> [q] -> AM' () +processRcvServiceAssocs _ [] = pure () processRcvServiceAssocs c serviceQs = withStore' c (`setRcvServiceAssocs` serviceQs) `catchAllErrors'` \e -> do logError $ "processClientNotices error: " <> tshow e @@ -1709,17 +1714,16 @@ processClientNotices c@AgentClient {presetServers} tSess notices = do logError $ "processClientNotices error: " <> tshow e notifySub' c "" $ ERR e -resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub -> AM ServiceSub -resubscribeClientService c tSess (ServiceSub _ n idsHash) = - withServiceClient c tSess $ \smp _ -> do - subscribeClientService_ c tSess smp n idsHash +resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub -> AM ServiceSubResult +resubscribeClientService c tSess serviceSub = + withServiceClient c tSess $ \smp _ -> subscribeClientService_ c True tSess smp serviceSub -subscribeClientService :: AgentClient -> UserId -> SMPServer -> Int64 -> IdsHash -> AM ServiceSub -subscribeClientService c userId srv n idsHash = +subscribeClientService :: AgentClient -> Bool -> UserId -> SMPServer -> Int64 -> IdsHash -> AM ServiceSubResult +subscribeClientService c withEvent userId srv n idsHash = withServiceClient c tSess $ \smp smpServiceId -> do let serviceSub = ServiceSub smpServiceId n idsHash atomically $ SS.setPendingServiceSub tSess serviceSub $ currentSubs c - subscribeClientService_ c tSess smp n idsHash + subscribeClientService_ c withEvent tSess smp serviceSub where tSess = (userId, srv, Nothing) @@ -1730,14 +1734,15 @@ withServiceClient c tSess action = Just smpServiceId -> action smp smpServiceId Nothing -> throwE PCEServiceUnavailable -subscribeClientService_ :: AgentClient -> SMPTransportSession -> SMPClient -> Int64 -> IdsHash -> ExceptT SMPClientError IO ServiceSub -subscribeClientService_ c tSess smp n idsHash = do - -- TODO [certs rcv] handle error - serviceSub' <- subscribeService smp SMP.SRecipientService n idsHash +subscribeClientService_ :: AgentClient -> Bool -> SMPTransportSession -> SMPClient -> ServiceSub -> ExceptT SMPClientError IO ServiceSubResult +subscribeClientService_ c withEvent tSess@(_, srv, _) smp expected@(ServiceSub _ n idsHash) = do + subscribed <- subscribeService smp SMP.SRecipientService n idsHash let sessId = sessionId $ thParams smp + r = serviceSubResult expected subscribed atomically $ whenM (activeClientSession c tSess sessId) $ - SS.setActiveServiceSub tSess sessId serviceSub' $ currentSubs c - pure serviceSub' + SS.setActiveServiceSub tSess sessId subscribed $ currentSubs c + when withEvent $ notifySub c $ SERVICE_UP srv r + pure r activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClients c) diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 15d51aed9..d5b35611b 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -234,6 +234,8 @@ import Simplex.Messaging.Protocol NMsgMeta, ProtocolServer (..), QueueMode (..), + ServiceSub, + ServiceSubResult, SMPClientVersion, SMPServer, SMPServerWithAuth, @@ -388,6 +390,9 @@ data AEvent (e :: AEntity) where DISCONNECT :: AProtocolType -> TransportHost -> AEvent AENone DOWN :: SMPServer -> [ConnId] -> AEvent AENone UP :: SMPServer -> [ConnId] -> AEvent AENone + SERVICE_ALL :: SMPServer -> AEvent AENone -- all service messages are delivered + SERVICE_DOWN :: SMPServer -> ServiceSub -> AEvent AENone + SERVICE_UP :: SMPServer -> ServiceSubResult -> AEvent AENone SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> AEvent AEConn RSYNC :: RatchetSyncState -> Maybe AgentCryptoError -> ConnectionStats -> AEvent AEConn SENT :: AgentMsgId -> Maybe SMPServer -> AEvent AEConn @@ -459,6 +464,9 @@ data AEventTag (e :: AEntity) where DISCONNECT_ :: AEventTag AENone DOWN_ :: AEventTag AENone UP_ :: AEventTag AENone + SERVICE_ALL_ :: AEventTag AENone + SERVICE_DOWN_ :: AEventTag AENone + SERVICE_UP_ :: AEventTag AENone SWITCH_ :: AEventTag AEConn RSYNC_ :: AEventTag AEConn SENT_ :: AEventTag AEConn @@ -514,6 +522,9 @@ aEventTag = \case DISCONNECT {} -> DISCONNECT_ DOWN {} -> DOWN_ UP {} -> UP_ + SERVICE_ALL _ -> SERVICE_ALL_ + SERVICE_DOWN {} -> SERVICE_DOWN_ + SERVICE_UP {} -> SERVICE_UP_ SWITCH {} -> SWITCH_ RSYNC {} -> RSYNC_ SENT {} -> SENT_ diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index 6e42aac9d..a732d28d4 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -419,18 +419,19 @@ createClientService db userId srv (kh, (cert, pk)) = do |] (userId, host srv, port srv, serverKeyHash_, kh, cert, pk) --- TODO [certs rcv] get correct service based on key hash of the server getClientService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId)) getClientService db userId srv = maybeFirstRow toService $ DB.query db [sql| - SELECT service_cert_hash, service_cert, service_priv_key, service_id - FROM client_services - WHERE user_id = ? AND host = ? AND port = ? + SELECT c.service_cert_hash, c.service_cert, c.service_priv_key, c.service_id + FROM client_services c + JOIN servers s ON c.host = s.host AND c.port = s.port + WHERE c.user_id = ? AND c.host = ? AND c.port = ? + AND COALESCE(c.server_key_hash, s.key_hash) = ? |] - (userId, host srv, port srv) + (userId, host srv, port srv, keyHash srv) where toService (kh, cert, pk, serviceId_) = ((kh, (cert, pk)), serviceId_) @@ -2250,12 +2251,12 @@ getUserServerRcvQueueSubs db userId srv onlyNeeded = unsetQueuesToSubscribe :: DB.Connection -> IO () unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = 0 WHERE to_subscribe = 1" -setRcvServiceAssocs :: DB.Connection -> [RcvQueueSub] -> IO () +setRcvServiceAssocs :: SMPQueue q => DB.Connection -> [q] -> IO () setRcvServiceAssocs db rqs = #if defined(dbPostgres) DB.execute db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id IN " $ Only $ In (map queueId rqs) #else - DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = " $ map (Only . queueId) rqs + DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = ?" $ map (Only . queueId) rqs #endif -- * getConn helpers diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 4d4086cfd..81e9820a2 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -251,7 +251,7 @@ type ClientCommand msg = (EntityId, Maybe C.APrivateAuthKey, ProtoCommand msg) -- | Type synonym for transmission from SPM servers. -- Batch response is presented as a single `ServerTransmissionBatch` tuple. -type ServerTransmissionBatch v err msg = (TransportSession msg, Version v, SessionId, NonEmpty (EntityId, ServerTransmission err msg)) +type ServerTransmissionBatch v err msg = (TransportSession msg, THandleParams v 'TClient, NonEmpty (EntityId, ServerTransmission err msg)) data ServerTransmission err msg = STEvent (Either (ProtocolClientError err) msg) @@ -864,8 +864,7 @@ writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO () writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c [(rId, STEvent (Right msg))]) (msgQ $ client_ c) serverTransmission :: ProtocolClient v err msg -> NonEmpty (RecipientId, ServerTransmission err msg) -> ServerTransmissionBatch v err msg -serverTransmission ProtocolClient {thParams = THandleParams {thVersion, sessionId}, client_ = PClient {transportSession}} ts = - (transportSession, thVersion, sessionId, ts) +serverTransmission ProtocolClient {thParams, client_ = PClient {transportSession}} ts = (transportSession, thParams, ts) -- | Get message from SMP queue. The server returns ERR PROHIBITED if a client uses SUB and GET via the same transport connection for the same queue -- diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 143d417c6..67ed89d71 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -524,7 +524,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = NtfPushServer {pushQ} <- asks pushServer stats <- asks serverStats liftIO $ forever $ do - ((_, srv@(SMPServer (h :| _) _ _), _), _thVersion, sessionId, ts) <- atomically $ readTBQueue msgQ + ((_, srv@(SMPServer (h :| _) _ _), _), THandleParams {sessionId}, ts) <- atomically $ readTBQueue msgQ forM ts $ \(ntfId, t) -> case t of STUnexpectedError e -> logError $ "SMP client unexpected error: " <> tshow e -- uncorrelated response, should not happen STResponse {} -> pure () -- it was already reported as timeout error diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index c00899e1c..a5f94960e 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -142,6 +142,8 @@ module Simplex.Messaging.Protocol MsgBody, IdsHash (..), ServiceSub (..), + ServiceSubResult (..), + serviceSubResult, queueIdsHash, queueIdHash, MaxMessageLen, @@ -712,7 +714,7 @@ data BrokerMsg where -- v2: MsgId -> SystemTime -> MsgFlags -> MsgBody -> BrokerMsg MSG :: RcvMessage -> BrokerMsg -- sent once delivering messages to SUBS command is complete - SALL :: BrokerMsg + ALLS :: BrokerMsg NID :: NotifierId -> RcvNtfPublicDhKey -> BrokerMsg NMSG :: C.CbNonce -> EncNMsgMeta -> BrokerMsg -- Should include certificate chain @@ -949,7 +951,7 @@ data BrokerMsgTag | SOK_ | SOKS_ | MSG_ - | SALL_ + | ALLS_ | NID_ | NMSG_ | PKEY_ @@ -1042,7 +1044,7 @@ instance Encoding BrokerMsgTag where SOK_ -> "SOK" SOKS_ -> "SOKS" MSG_ -> "MSG" - SALL_ -> "SALL" + ALLS_ -> "ALLS" NID_ -> "NID" NMSG_ -> "NMSG" PKEY_ -> "PKEY" @@ -1064,7 +1066,7 @@ instance ProtocolMsgTag BrokerMsgTag where "SOK" -> Just SOK_ "SOKS" -> Just SOKS_ "MSG" -> Just MSG_ - "SALL" -> Just SALL_ + "ALLS" -> Just ALLS_ "NID" -> Just NID_ "NMSG" -> Just NMSG_ "PKEY" -> Just PKEY_ @@ -1468,10 +1470,29 @@ type MsgId = ByteString type MsgBody = ByteString data ServiceSub = ServiceSub - { serviceId :: ServiceId, + { smpServiceId :: ServiceId, smpQueueCount :: Int64, smpQueueIdsHash :: IdsHash } + deriving (Eq, Show) + +data ServiceSubResult = ServiceSubResult (Maybe ServiceSubError) ServiceSub + deriving (Eq, Show) + +data ServiceSubError + = SSErrorServiceId {expectedServiceId :: ServiceId, subscribedServiceId :: ServiceId} + | SSErrorQueueCount {expectedQueueCount :: Int64, subscribedQueueCount :: Int64} + | SSErrorQueueIdsHash {expectedQueueIdsHash :: IdsHash, subscribedQueueIdsHash :: IdsHash} + deriving (Eq, Show) + +serviceSubResult :: ServiceSub -> ServiceSub -> ServiceSubResult +serviceSubResult s s' = ServiceSubResult subError_ s' + where + subError_ + | smpServiceId s /= smpServiceId s' = Just $ SSErrorServiceId (smpServiceId s) (smpServiceId s') + | smpQueueCount s /= smpQueueCount s' = Just $ SSErrorQueueCount (smpQueueCount s) (smpQueueCount s') + | smpQueueIdsHash s /= smpQueueIdsHash s' = Just $ SSErrorQueueIdsHash (smpQueueIdsHash s) (smpQueueIdsHash s') + | otherwise = Nothing newtype IdsHash = IdsHash {unIdsHash :: BS.ByteString} deriving (Eq, Show) @@ -1897,7 +1918,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where | otherwise -> e (SOKS_, ' ', n) MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -> e (MSG_, ' ', msgId, Tail body) - SALL -> e SALL_ + ALLS -> e ALLS_ NID nId srvNtfDh -> e (NID_, ' ', nId, srvNtfDh) NMSG nmsgNonce encNMsgMeta -> e (NMSG_, ' ', nmsgNonce, encNMsgMeta) PKEY sid vr certKey -> e (PKEY_, ' ', sid, vr, certKey) @@ -1928,7 +1949,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where MSG . RcvMessage msgId <$> bodyP where bodyP = EncRcvMsgBody . unTail <$> smpP - SALL_ -> pure SALL + ALLS_ -> pure ALLS IDS_ | v >= newNtfCredsSMPVersion -> ids smpP smpP smpP smpP | v >= serviceCertsSMPVersion -> ids smpP smpP smpP nothing @@ -1981,7 +2002,7 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where PONG -> noEntityMsg PKEY {} -> noEntityMsg RRES _ -> noEntityMsg - SALL -> noEntityMsg + ALLS -> noEntityMsg -- other broker responses must have queue ID _ | B.null entId -> Left $ CMD NO_ENTITY diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 0598f3c53..0fc15b3e3 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -1806,7 +1806,7 @@ client where deliverServiceMessages expectedCnt = do (qCnt, _msgCnt, _dupCnt, _errCnt) <- foldRcvServiceMessages ms serviceId deliverQueueMsg (0, 0, 0, 0) - atomically $ writeTBQueue msgQ [(NoCorrId, NoEntity, SALL)] + atomically $ writeTBQueue msgQ [(NoCorrId, NoEntity, ALLS)] -- TODO [certs rcv] compare with expected logNote $ "Service subscriptions for " <> tshow serviceId <> " (" <> tshow qCnt <> " queues)" deliverQueueMsg :: (Int, Int, Int, Int) -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO (Int, Int, Int, Int) diff --git a/tests/AgentTests/EqInstances.hs b/tests/AgentTests/EqInstances.hs index e142c6177..63c493861 100644 --- a/tests/AgentTests/EqInstances.hs +++ b/tests/AgentTests/EqInstances.hs @@ -8,7 +8,6 @@ import Data.Type.Equality import Simplex.Messaging.Agent.Protocol (ConnLinkData (..), OwnerAuth (..), UserContactData (..), UserLinkData (..)) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Client (ProxiedRelay (..)) -import Simplex.Messaging.Protocol (ServiceSub (..)) instance (Eq rq, Eq sq) => Eq (SomeConn' rq sq) where SomeConn d c == SomeConn d' c' = case testEquality d d' of @@ -48,7 +47,3 @@ deriving instance Eq OwnerAuth deriving instance Show ProxiedRelay deriving instance Eq ProxiedRelay - -deriving instance Show ServiceSub - -deriving instance Eq ServiceSub diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index f3f7e817c..cb74bc0b6 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -3668,27 +3668,35 @@ testTwoUsers = withAgentClients2 $ \a b -> do testClientServiceConnection :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testClientServiceConnection ps = do - (sId, uId) <- withSmpServerStoreLogOn ps testPort $ \_ -> do + ((sId, uId), qIdHash) <- withSmpServerStoreLogOn ps testPort $ \_ -> do conns@(sId, uId) <- withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> runRight $ do conns@(sId, uId) <- makeConnection service user exchangeGreetings service uId user sId pure conns withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> runRight $ do - subscribeClientServices service 1 + [(_, Right (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash)))] <- M.toList <$> subscribeClientServices service 1 + ("", "", SERVICE_ALL _) <- nGet service subscribeConnection user sId exchangeGreetingsMsgId 4 service uId user sId - pure conns + pure (conns, qIdHash) withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do - subscribeClientServices service 1 + [(_, Right (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash')))] <- M.toList <$> subscribeClientServices service 1 + ("", "", SERVICE_ALL _) <- nGet service + liftIO $ qIdHash' `shouldBe` qIdHash subscribeConnection user sId exchangeGreetingsMsgId 6 service uId user sId ("", "", DOWN _ [_]) <- nGet user + ("", "", SERVICE_DOWN _ (SMP.ServiceSub _ 1 qIdHash')) <- nGet service + qIdHash' `shouldBe` qIdHash -- TODO [certs rcv] how to integrate service counts into stats -- r <- nGet service -- TODO [certs rcv] some event when service disconnects with count -- print r withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do ("", "", UP _ [_]) <- nGet user + ("", "", SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash''))) <- nGet service + ("", "", SERVICE_ALL _) <- nGet service + liftIO $ qIdHash'' `shouldBe` qIdHash -- r <- nGet service -- TODO [certs rcv] some event when service reconnects with count exchangeGreetingsMsgId 8 service uId user sId diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index 09f20c1dd..0d8ccdf89 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -188,7 +188,7 @@ deliverMessagesViaProxy proxyServ relayServ alg unsecuredMsgs securedMsgs = do runExceptT' (proxySMPMessage pc NRMInteractive sess Nothing sndId noMsgFlags msg) `shouldReturn` Right () runExceptT' (proxySMPMessage pc NRMInteractive sess {prSessionId = "bad session"} Nothing sndId noMsgFlags msg) `shouldReturn` Left (ProxyProtocolError $ SMP.PROXY SMP.NO_SESSION) -- receive 1 - (_tSess, _v, _sid, [(_entId, STEvent (Right (SMP.MSG RcvMessage {msgId, msgBody = EncRcvMsgBody encBody})))]) <- atomically $ readTBQueue msgQ + (_tSess, _, [(_entId, STEvent (Right (SMP.MSG RcvMessage {msgId, msgBody = EncRcvMsgBody encBody})))]) <- atomically $ readTBQueue msgQ dec msgId encBody `shouldBe` Right msg runExceptT' $ ackSMPMessage rc rPriv rcvId msgId -- secure queue @@ -200,7 +200,7 @@ deliverMessagesViaProxy proxyServ relayServ alg unsecuredMsgs securedMsgs = do runExceptT' (proxySMPMessage pc NRMInteractive sess (Just sPriv) sndId noMsgFlags msg') `shouldReturn` Right () ) ( forM_ securedMsgs $ \msg' -> do - (_tSess, _v, _sid, [(_entId, STEvent (Right (SMP.MSG RcvMessage {msgId = msgId', msgBody = EncRcvMsgBody encBody'})))]) <- atomically $ readTBQueue msgQ + (_tSess, _, [(_entId, STEvent (Right (SMP.MSG RcvMessage {msgId = msgId', msgBody = EncRcvMsgBody encBody'})))]) <- atomically $ readTBQueue msgQ dec msgId' encBody' `shouldBe` Right msg' runExceptT' $ ackSMPMessage rc rPriv rcvId msgId' ) diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index dd97781c2..82a39af39 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -733,7 +733,7 @@ testServiceDeliverSubscribe = pure $ Just $ Just mId3 _ -> pure Nothing ] - Resp "" NoEntity SALL <- tGet1 sh + Resp "" NoEntity ALLS <- tGet1 sh Resp "12" _ OK <- signSendRecv sh rKey ("12", rId, ACK mId3) Resp "14" _ OK <- signSendRecv h sKey ("14", sId, _SEND "hello 4") Resp "" _ (Msg mId4 msg4) <- tGet1 sh @@ -831,7 +831,7 @@ testServiceUpgradeAndDowngrade = pure $ Just $ Just (rKey2, rId2, mId3) _ -> pure Nothing ] - Resp "" NoEntity SALL <- tGet1 sh + Resp "" NoEntity ALLS <- tGet1 sh Resp "15" _ OK <- signSendRecv sh rKey3_1 ("15", rId3_1, ACK mId3_1) Resp "16" _ OK <- signSendRecv sh rKey3_2 ("16", rId3_2, ACK mId3_2) pure ()