Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions src/Simplex/Messaging/Agent.hs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ import Data.Bifunctor (bimap, first)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Composition
import Data.Either (isRight, partitionEithers, rights)
import Data.Either (fromRight, isRight, partitionEithers, rights)
import Data.Foldable (foldl', toList)
import Data.Functor (($>))
import Data.Functor.Identity
Expand Down Expand Up @@ -221,7 +221,6 @@ import Simplex.Messaging.Protocol
SMPMsgMeta,
SParty (..),
SProtocolType (..),
ServiceSub (..),
ServiceSubResult,
SndPublicAuthKey,
SubscriptionMode (..),
Expand Down Expand Up @@ -1451,7 +1450,23 @@ subscribeAllConnections' c onlyNeeded activeUserId_ = handleErr $ do
let userSrvs' = case activeUserId_ of
Just activeUserId -> sortOn (\(uId, _) -> if uId == activeUserId then 0 else 1 :: Int) userSrvs
Nothing -> userSrvs
rs <- lift $ mapConcurrently (subscribeUserServer maxPending currPending) userSrvs'
useServices <- readTVarIO $ useClientServices c
-- These options are possible below:
-- 1) services fully disabled:
-- No service subscriptions will be attempted, and existing services and association will remain in in the database,
-- but they will be ignored because of hasService parameter set to False.
-- This approach preserves performance for all clients that do not use services.
-- 2) at least one user ID has services enabled:
-- Service will be loaded for all user/server combinations:
-- a) service is enabled for and service record exists: subscription will be attempted,
-- b) service is disabled and record exists: service record and all associations will be removed,
-- c) service is disabled or no record: no subscription attempt.
-- On successful service subscription, only unassociated queues will be subscribed.
userSrvs'' <-
if any id useServices
then lift $ mapConcurrently (subscribeService useServices) userSrvs'
else pure $ map (,False) userSrvs'
rs <- lift $ mapConcurrently (subscribeUserServer maxPending currPending) userSrvs''
let (errs, oks) = partitionEithers rs
logInfo $ "subscribed " <> tshow (sum oks) <> " queues"
forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map ("",)
Expand All @@ -1460,16 +1475,25 @@ subscribeAllConnections' c onlyNeeded activeUserId_ = handleErr $ do
resumeAllCommands c
where
handleErr = (`catchAllErrors` \e -> notifySub' c "" (ERR e) >> throwE e)
subscribeUserServer :: Int -> TVar Int -> (UserId, SMPServer) -> AM' (Either AgentErrorType Int)
subscribeUserServer maxPending currPending (userId, srv) = do
subscribeService :: Map UserId Bool -> (UserId, SMPServer) -> AM' ((UserId, SMPServer), ServiceAssoc)
subscribeService useServices us@(userId, srv) = fmap ((us,) . fromRight False) $ tryAllErrors' $ do
withStore' c (\db -> getSubscriptionService db userId srv) >>= \case
Just serviceSub -> case M.lookup userId useServices of
-- TODO [certs rcv] improve logic to differentiate between permanent and temporary service subscription errors,
-- as the current logic would fall back to per-queue subscriptions on ANY service subscription error (e.g., network connection error).
Just True -> isRight <$> tryAllErrors (subscribeClientService c True userId srv serviceSub)
_ -> False <$ withStore' c (\db -> unassocUserServerRcvQueueSubs db userId srv)
_ -> pure False
subscribeUserServer :: Int -> TVar Int -> ((UserId, SMPServer), ServiceAssoc) -> AM' (Either AgentErrorType Int)
subscribeUserServer maxPending currPending ((userId, srv), hasService) = do
atomically $ whenM ((maxPending <=) <$> readTVar currPending) retry
tryAllErrors' $ do
qs <- withStore' c $ \db -> do
qs <- getUserServerRcvQueueSubs db userId srv onlyNeeded
atomically $ modifyTVar' currPending (+ length qs) -- update before leaving transaction
qs <- getUserServerRcvQueueSubs db userId srv onlyNeeded hasService
unless (null qs) $ atomically $ modifyTVar' currPending (+ length qs) -- update before leaving transaction
pure qs
let n = length qs
lift $ subscribe qs `E.finally` atomically (modifyTVar' currPending $ subtract n)
unless (null qs) $ lift $ subscribe qs `E.finally` atomically (modifyTVar' currPending $ subtract n)
pure n
where
subscribe qs = do
Expand Down Expand Up @@ -1522,7 +1546,7 @@ subscribeClientServices' c userId =
useService = liftIO $ (Just True ==) <$> TM.lookupIO userId (useClientServices c)
subscribe = do
srvs <- withStore' c (`getClientServiceServers` userId)
lift $ M.fromList <$> mapConcurrently (\(srv, ServiceSub _ n idsHash) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c False userId srv n idsHash) srvs
lift $ M.fromList <$> mapConcurrently (\(srv, serviceSub) -> fmap (srv,) $ tryAllErrors' $ subscribeClientService c False userId srv serviceSub) srvs

-- requesting messages sequentially, to reduce memory usage
getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta)))
Expand Down
47 changes: 31 additions & 16 deletions src/Simplex/Messaging/Agent/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ import Simplex.Messaging.Session
import Simplex.Messaging.SystemTime
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (SMPServiceRole (..), SMPVersion, ServiceCredentials (..), SessionId, THClientService' (..), THandleParams (sessionId, thVersion), TransportError (..), TransportPeer (..), sndAuthKeySMPVersion, shortLinksSMPVersion, newNtfCredsSMPVersion)
import Simplex.Messaging.Transport (HandshakeError (..), SMPServiceRole (..), SMPVersion, ServiceCredentials (..), SessionId, THClientService' (..), THandleAuth (..), THandleParams (sessionId, thAuth, thVersion), TransportError (..), TransportPeer (..), sndAuthKeySMPVersion, shortLinksSMPVersion, newNtfCredsSMPVersion)
import Simplex.Messaging.Transport.Client (TransportHost (..))
import Simplex.Messaging.Transport.Credentials
import Simplex.Messaging.Util
Expand Down Expand Up @@ -619,7 +619,7 @@ getServiceCredentials c userId srv =
let g = agentDRG c
((C.KeyHash kh, serviceCreds), serviceId_) <-
withStore' c $ \db ->
getClientService db userId srv >>= \case
getClientServiceCredentials db userId srv >>= \case
Just service -> pure service
Nothing -> do
cred <- genCredentials g Nothing (25, 24 * 999999) "simplex"
Expand Down Expand Up @@ -747,15 +747,13 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm
smp <- liftError (protocolClientError SMP $ B.unpack $ strEncode srv) $ do
ts <- readTVarIO proxySessTs
ExceptT $ getProtocolClient g nm tSess cfg' presetDomains (Just msgQ) ts $ smpClientDisconnected c tSess env v' prs
-- TODO [certs rcv] add service to SS, possibly combine with SS.setSessionId
atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c
updateClientService service smp
pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs}
-- TODO [certs rcv] this should differentiate between service ID just set and service ID changed, and in the latter case disassociate the queues
updateClientService service smp = case (service, smpClientService smp) of
(Just (_, serviceId_), Just THClientService {serviceId})
| serviceId_ /= Just serviceId -> withStore' c $ \db -> setClientServiceId db userId srv serviceId
| otherwise -> pure ()
(Just (_, serviceId_), Just THClientService {serviceId}) -> withStore' c $ \db -> do
setClientServiceId db userId srv serviceId
forM_ serviceId_ $ \sId -> when (sId /= serviceId) $ removeRcvServiceAssocs db userId srv
(Just _, Nothing) -> withStore' c $ \db -> deleteClientService db userId srv -- e.g., server version downgrade
(Nothing, Just _) -> logError "server returned serviceId without service credentials in request"
(Nothing, Nothing) -> pure ()
Expand Down Expand Up @@ -1258,6 +1256,15 @@ protocolClientError protocolError_ host = \case
PCEServiceUnavailable {} -> BROKER host NO_SERVICE
PCEIOError e -> BROKER host $ NETWORK $ NEConnectError $ E.displayException e

-- it is consistent with smpClientServiceError
clientServiceError :: AgentErrorType -> Bool
clientServiceError = \case
BROKER _ NO_SERVICE -> True
BROKER _ (TRANSPORT (TEHandshake BAD_SERVICE)) -> True -- TODO [certs rcv] this error may be temporary, so we should possibly resubscribe.
SMP _ SMP.SERVICE -> True
SMP _ (SMP.PROXY (SMP.BROKER NO_SERVICE)) -> True -- for completeness, it cannot happen.
_ -> False

data ProtocolTestStep
= TSConnect
| TSDisconnect
Expand Down Expand Up @@ -1446,8 +1453,8 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl
withClient c nm tSess $ \(SMPConnectedClient smp _) -> do
(ntfKeys, ntfCreds) <- liftIO $ mkNtfCreds a g smp
(thParams smp,ntfKeys,) <$> createSMPQueue smp nm nonce_ rKeys dhKey auth subMode (queueReqData cqrd) ntfCreds
-- TODO [certs rcv] validate that serviceId is the same as in the client session, fail otherwise
-- possibly, it should allow returning Nothing - it would indicate incorrect old version
let sessServiceId = (\THClientService {serviceId = sId} -> sId) <$> (clientService =<< thAuth thParams')
when (isJust serviceId && serviceId /= sessServiceId) $ logError "incorrect service ID in NEW response"
liftIO . logServer "<--" c srv NoEntity $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
shortLink <- mkShortLinkCreds thParams' qik
let rq =
Expand All @@ -1463,7 +1470,7 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl
sndId,
queueMode,
shortLink,
rcvServiceAssoc = isJust serviceId,
rcvServiceAssoc = isJust serviceId && serviceId == sessServiceId,
status = New,
enableNtfs,
clientNoticeId = Nothing,
Expand Down Expand Up @@ -1718,8 +1725,9 @@ resubscribeClientService :: AgentClient -> SMPTransportSession -> ServiceSub ->
resubscribeClientService c tSess serviceSub =
withServiceClient c tSess $ \smp _ -> subscribeClientService_ c True tSess smp serviceSub

subscribeClientService :: AgentClient -> Bool -> UserId -> SMPServer -> Int64 -> IdsHash -> AM ServiceSubResult
subscribeClientService c withEvent userId srv n idsHash =
-- TODO [certs rcv] update service in the database if it has different ID and re-associate queues, and send event
subscribeClientService :: AgentClient -> Bool -> UserId -> SMPServer -> ServiceSub -> AM ServiceSubResult
subscribeClientService c withEvent userId srv (ServiceSub _ n idsHash) =
withServiceClient c tSess $ \smp smpServiceId -> do
let serviceSub = ServiceSub smpServiceId n idsHash
atomically $ SS.setPendingServiceSub tSess serviceSub $ currentSubs c
Expand All @@ -1728,14 +1736,21 @@ subscribeClientService c withEvent userId srv n idsHash =
tSess = (userId, srv, Nothing)

withServiceClient :: AgentClient -> SMPTransportSession -> (SMPClient -> ServiceId -> ExceptT SMPClientError IO a) -> AM a
withServiceClient c tSess action =
withLogClient c NRMBackground tSess B.empty "SUBS" $ \(SMPConnectedClient smp _) ->
withServiceClient c tSess@(userId, srv, _) subscribe =
unassocOnError $ withLogClient c NRMBackground tSess B.empty "SUBS" $ \(SMPConnectedClient smp _) ->
case (\THClientService {serviceId} -> serviceId) <$> smpClientService smp of
Just smpServiceId -> action smp smpServiceId
Just smpServiceId -> subscribe smp smpServiceId
Nothing -> throwE PCEServiceUnavailable
where
unassocOnError a = a `catchE` \e -> do
when (clientServiceError e) $ do
qs <- withStore' c $ \db -> unassocUserServerRcvQueueSubs db userId srv
void $ lift $ subscribeUserServerQueues c userId srv qs
throwE e

-- TODO [certs rcv] send subscription error event?
subscribeClientService_ :: AgentClient -> Bool -> SMPTransportSession -> SMPClient -> ServiceSub -> ExceptT SMPClientError IO ServiceSubResult
subscribeClientService_ c withEvent tSess@(_, srv, _) smp expected@(ServiceSub _ n idsHash) = do
subscribeClientService_ c withEvent tSess@(userId, srv, _) smp expected@(ServiceSub _ n idsHash) = do
subscribed <- subscribeService smp SMP.SRecipientService n idsHash
let sessId = sessionId $ thParams smp
r = serviceSubResult expected subscribed
Expand Down
102 changes: 85 additions & 17 deletions src/Simplex/Messaging/Agent/Store/AgentStore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ module Simplex.Messaging.Agent.Store.AgentStore

-- * Client services
createClientService,
getClientService,
getClientServiceCredentials,
getSubscriptionServices,
getSubscriptionService,
getClientServiceServers,
setClientServiceId,
deleteClientService,
Expand All @@ -52,8 +54,10 @@ module Simplex.Messaging.Agent.Store.AgentStore
updateClientNotices,
getSubscriptionServers,
getUserServerRcvQueueSubs,
unassocUserServerRcvQueueSubs,
unsetQueuesToSubscribe,
setRcvServiceAssocs,
removeRcvServiceAssocs,
getConnIds,
getConn,
getDeletedConn,
Expand Down Expand Up @@ -419,8 +423,8 @@ createClientService db userId srv (kh, (cert, pk)) = do
|]
(userId, host srv, port srv, serverKeyHash_, kh, cert, pk)

getClientService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId))
getClientService db userId srv =
getClientServiceCredentials :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ((C.KeyHash, TLS.Credential), Maybe ServiceId))
getClientServiceCredentials db userId srv =
maybeFirstRow toService $
DB.query
db
Expand All @@ -435,21 +439,41 @@ getClientService db userId srv =
where
toService (kh, cert, pk, serviceId_) = ((kh, (cert, pk)), serviceId_)

getClientServiceServers :: DB.Connection -> UserId -> IO [(SMPServer, ServiceSub)]
getClientServiceServers db userId =
map toServer
<$> DB.query
getSubscriptionServices :: DB.Connection -> IO [(UserId, (SMPServer, ServiceSub))]
getSubscriptionServices db = map toUserService <$> DB.query_ db clientServiceQuery
where
toUserService (Only userId :. serviceRow) = (userId, toServerService serviceRow)

getSubscriptionService :: DB.Connection -> UserId -> SMPServer -> IO (Maybe ServiceSub)
getSubscriptionService db userId (SMPServer h p kh) =
maybeFirstRow toService $
DB.query
db
[sql|
SELECT c.host, c.port, s.key_hash, c.service_id, c.service_queue_count, c.service_queue_ids_hash
SELECT c.service_id, c.service_queue_count, c.service_queue_ids_hash
FROM client_services c
JOIN servers s ON s.host = c.host AND s.port = c.port
WHERE c.user_id = ?
WHERE c.user_id = ? AND c.host = ? AND c.port = ? AND COALESCE(c.server_key_hash, s.key_hash) = ?
|]
(Only userId)
(userId, h, p, kh)
where
toServer (host, port, kh, serviceId, n, Binary idsHash) =
(SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash))
toService (serviceId, qCnt, idsHash) = ServiceSub serviceId qCnt idsHash

getClientServiceServers :: DB.Connection -> UserId -> IO [(SMPServer, ServiceSub)]
getClientServiceServers db userId =
map toServerService <$> DB.query db (clientServiceQuery <> " WHERE c.user_id = ?") (Only userId)

clientServiceQuery :: Query
clientServiceQuery =
[sql|
SELECT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash), c.service_id, c.service_queue_count, c.service_queue_ids_hash
FROM client_services c
JOIN servers s ON s.host = c.host AND s.port = c.port
|]

toServerService :: (NonEmpty TransportHost, ServiceName, C.KeyHash, ServiceId, Int64, Binary ByteString) -> (ProtocolServer 'PSMP, ServiceSub)
toServerService (host, port, kh, serviceId, n, Binary idsHash) =
(SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash))

setClientServiceId :: DB.Connection -> UserId -> SMPServer -> ServiceId -> IO ()
setClientServiceId db userId srv serviceId =
Expand All @@ -473,7 +497,9 @@ deleteClientService db userId srv =
(userId, host srv, port srv)

deleteClientServices :: DB.Connection -> UserId -> IO ()
deleteClientServices db userId = DB.execute db "DELETE FROM client_services WHERE user_id = ?" (Only userId)
deleteClientServices db userId = do
DB.execute db "DELETE FROM client_services WHERE user_id = ?" (Only userId)
removeUserRcvServiceAssocs db userId

createConn_ ::
TVar ChaChaDRG ->
Expand Down Expand Up @@ -2236,17 +2262,24 @@ getSubscriptionServers db onlyNeeded =
toUserServer :: (UserId, NonEmpty TransportHost, ServiceName, C.KeyHash) -> (UserId, SMPServer)
toUserServer (userId, host, port, keyHash) = (userId, SMPServer host port keyHash)

getUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> Bool -> IO [RcvQueueSub]
getUserServerRcvQueueSubs db userId srv onlyNeeded =
-- TODO [certs rcv] check index for getting queues with service present
getUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> Bool -> ServiceAssoc -> IO [RcvQueueSub]
getUserServerRcvQueueSubs db userId srv onlyNeeded hasService =
map toRcvQueueSub
<$> DB.query
db
(rcvQueueSubQuery <> toSubscribe <> " c.deleted = 0 AND q.deleted = 0 AND c.user_id = ? AND q.host = ? AND q.port = ?")
(rcvQueueSubQuery <> toSubscribe <> " c.deleted = 0 AND q.deleted = 0 AND c.user_id = ? AND q.host = ? AND q.port = ?" <> serviceCond)
(userId, host srv, port srv)
where
toSubscribe
| onlyNeeded = " WHERE q.to_subscribe = 1 AND "
| otherwise = " WHERE "
serviceCond
| hasService = " AND q.rcv_service_assoc = 0"
| otherwise = ""

unassocUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> IO [RcvQueueSub]
unassocUserServerRcvQueueSubs db userId srv = undefined

unsetQueuesToSubscribe :: DB.Connection -> IO ()
unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = 0 WHERE to_subscribe = 1"
Expand All @@ -2256,9 +2289,44 @@ setRcvServiceAssocs db rqs =
#if defined(dbPostgres)
DB.execute db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id IN " $ Only $ In (map queueId rqs)
#else
DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = ?" $ map (Only . queueId) rqs
DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_isd = ?" $ map (Only . queueId) rqs
#endif

removeRcvServiceAssocs :: DB.Connection -> UserId -> SMPServer -> IO ()
removeRcvServiceAssocs db userId (SMPServer h p kh) =
DB.execute
db
[sql|
UPDATE rcv_queues
SET rcv_service_assoc = 0
WHERE EXISTS (
SELECT 1
FROM connections c
JOIN servers s ON rcv_queues.host = s.host AND rcv_queues.port = s.port
WHERE c.conn_id = rcv_queues.conn_id
AND c.user_id = ?
AND rcv_queues.host = ?
AND rcv_queues.port = ?
AND COALESCE(rcv_queues.server_key_hash, s.key_hash) = ?
)
|]
(userId, h, p, kh)

removeUserRcvServiceAssocs :: DB.Connection -> UserId -> IO ()
removeUserRcvServiceAssocs db userId =
DB.execute
db
[sql|
UPDATE rcv_queues
SET rcv_service_assoc = 0
WHERE EXISTS (
SELECT 1
FROM connections c
WHERE c.conn_id = rcv_queues.conn_id AND c.user_id = ?
)
|]
(Only userId)

-- * getConn helpers

getConnIds :: DB.Connection -> IO [ConnId]
Expand Down
1 change: 1 addition & 0 deletions src/Simplex/Messaging/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,7 @@ temporaryClientError = \case
_ -> False
{-# INLINE temporaryClientError #-}

-- it is consistent with clientServiceError
smpClientServiceError :: SMPClientError -> Bool
smpClientServiceError = \case
PCEServiceUnavailable -> True
Expand Down
Loading
Loading