diff --git a/simplexmq.cabal b/simplexmq.cabal index 3f9d1f61d..13759a05a 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -275,7 +275,6 @@ library Simplex.Messaging.Notifications.Server.Store.Migrations Simplex.Messaging.Notifications.Server.Store.Postgres Simplex.Messaging.Notifications.Server.Store.Types - Simplex.Messaging.Notifications.Server.StoreLog Simplex.Messaging.Server.MsgStore.Postgres Simplex.Messaging.Server.QueueStore.Postgres Simplex.Messaging.Server.QueueStore.Postgres.Migrations diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index 62f06b7d3..a425138e5 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -11,6 +11,7 @@ module Simplex.FileTransfer.Client where +import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad import Control.Monad.Except @@ -264,7 +265,7 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec { where errors = [ Handler $ \(e :: H.HTTP2Error) -> pure $ Left $ PCENetworkError $ NEConnectError $ displayException e, - Handler $ \(e :: IOException) -> pure $ Left $ PCEIOError e, + Handler $ \(e :: IOException) -> pure $ Left $ PCEIOError $ E.displayException e, Handler $ \(e :: SomeException) -> pure $ Left $ PCENetworkError $ toNetworkError e ] download cbState = diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 9bf1afd8d..4fd9eb175 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -751,8 +751,8 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c updateClientService service smp pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs} - updateClientService service smp = case (service, smpClientService smp) of - (Just (_, serviceId_), Just THClientService {serviceId}) -> withStore' c $ \db -> do + updateClientService service smp = case (service, smpClientServiceId smp) of + (Just (_, serviceId_), Just serviceId) -> withStore' c $ \db -> do setClientServiceId db userId srv serviceId forM_ serviceId_ $ \sId -> when (sId /= serviceId) $ removeRcvServiceAssocs db userId srv (Just _, Nothing) -> withStore' c $ \db -> deleteClientService db userId srv -- e.g., server version downgrade @@ -1255,7 +1255,7 @@ protocolClientError protocolError_ host = \case PCETransportError e -> BROKER host $ TRANSPORT e e@PCECryptoError {} -> INTERNAL $ show e PCEServiceUnavailable {} -> BROKER host NO_SERVICE - PCEIOError e -> BROKER host $ NETWORK $ NEConnectError $ E.displayException e + PCEIOError e -> BROKER host $ NETWORK $ NEConnectError e -- it is consistent with smpClientServiceError clientServiceError :: AgentErrorType -> Bool @@ -1546,6 +1546,7 @@ processSubResults c tSess@(userId, srv, _) sessId serviceId_ rs = do Left e -> case smpErrorClientNotice e of Just notice_ -> (failed', subscribed, (rq, notice_) : notices, ignored) where + -- TODO [certs rcv] not used? notices' = if isJust notice_ || isJust clientNoticeId then (rq, notice_) : notices else notices Nothing | temporaryClientError e -> acc @@ -1678,7 +1679,7 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c (active, (serviceQs, notices)) <- atomically $ do r@(_, (_, notices)) <- ifM (activeClientSession c tSess sessId) - ((True,) <$> processSubResults c tSess sessId smpServiceId rs) + ((True,) <$> processSubResults c tSess sessId (smpClientServiceId smp) rs) ((False, ([], [])) <$ incSMPServerStat' c userId srv connSubIgnored (length rs)) unless (null notices) $ takeTMVar $ clientNoticesLock c pure r @@ -1704,7 +1705,6 @@ subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c where tSess = transportSession' smp sessId = sessionId $ thParams smp - smpServiceId = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp processRcvServiceAssocs :: SMPQueue q => AgentClient -> [q] -> AM' () processRcvServiceAssocs _ [] = pure () @@ -1752,7 +1752,7 @@ subscribeClientService c withEvent userId srv (ServiceSub _ n idsHash) = withServiceClient :: AgentClient -> SMPTransportSession -> (SMPClient -> ServiceId -> ExceptT SMPClientError IO a) -> AM a withServiceClient c tSess subscribe = withLogClient c NRMBackground tSess B.empty "SUBS" $ \(SMPConnectedClient smp _) -> - case (\THClientService {serviceId} -> serviceId) <$> smpClientService smp of + case smpClientServiceId smp of Just smpServiceId -> subscribe smp smpServiceId Nothing -> throwE PCEServiceUnavailable diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index 853a76908..2dcb76327 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -472,15 +472,21 @@ toServerService (host, port, kh, serviceId, n, Binary idsHash) = (SMPServer host port kh, ServiceSub serviceId n (IdsHash idsHash)) setClientServiceId :: DB.Connection -> UserId -> SMPServer -> ServiceId -> IO () -setClientServiceId db userId srv serviceId = +setClientServiceId db userId (SMPServer h p kh) serviceId = DB.execute db [sql| UPDATE client_services SET service_id = ? - WHERE user_id = ? AND host = ? AND port = ? + FROM servers s + WHERE client_services.user_id = ? + AND client_services.host = ? + AND client_services.port = ? + AND s.host = client_services.host + AND s.port = client_services.port + AND COALESCE(client_services.server_key_hash, s.key_hash) = ? |] - (serviceId, userId, host srv, port srv) + (serviceId, userId, h, p, kh) deleteClientService :: DB.Connection -> UserId -> SMPServer -> IO () deleteClientService db userId (SMPServer h p kh) = @@ -2307,7 +2313,7 @@ unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = setRcvServiceAssocs :: SMPQueue q => DB.Connection -> [q] -> IO () setRcvServiceAssocs db rqs = #if defined(dbPostgres) - DB.execute db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id IN " $ Only $ In (map queueId rqs) + DB.execute db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id IN ?" $ Only $ In (map queueId rqs) #else DB.executeMany db "UPDATE rcv_queues SET rcv_service_assoc = 1 WHERE rcv_id = ?" $ map (Only . queueId) rqs #endif diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index ac2dc9a9d..ebc458c0e 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -52,6 +52,7 @@ module Simplex.Messaging.Client subscribeSMPQueuesNtfs, subscribeService, smpClientService, + smpClientServiceId, secureSMPQueue, secureSndSMPQueue, proxySecureSndSMPQueue, @@ -128,7 +129,8 @@ import Control.Applicative ((<|>)) import Control.Concurrent (ThreadId, forkFinally, forkIO, killThread, mkWeakThreadId) import Control.Concurrent.Async import Control.Concurrent.STM -import Control.Exception +import Control.Exception (Exception, SomeException) +import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad import Control.Monad.Except @@ -565,7 +567,7 @@ getProtocolClient g nm transportSession@(_, srv, _) cfg@ProtocolClientConfig {qS case chooseTransportHost networkConfig (host srv) of Right useHost -> (getCurrentTime >>= mkProtocolClient useHost >>= runClient useTransport useHost) - `catch` \(e :: IOException) -> pure . Left $ PCEIOError e + `E.catch` \(e :: SomeException) -> pure $ Left $ PCEIOError $ E.displayException e Left e -> pure $ Left e where NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig @@ -638,7 +640,7 @@ getProtocolClient g nm transportSession@(_, srv, _) cfg@ProtocolClientConfig {qS writeTVar (connected c) True putTMVar cVar $ Right c' raceAny_ ([send c' th, process c', receive c' th] <> [monitor c' | smpPingInterval > 0]) - `finally` disconnected c' + `E.finally` disconnected c' send :: Transport c => ProtocolClient v err msg -> THandle v c 'TClient -> IO () send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= sendPending @@ -765,7 +767,7 @@ data ProtocolClientError err | -- | Error when cryptographically "signing" the command or when initializing crypto_box. PCECryptoError C.CryptoError | -- | IO Error - PCEIOError IOException + PCEIOError String deriving (Eq, Show, Exception) type SMPClientError = ProtocolClientError ErrorType @@ -926,6 +928,10 @@ smpClientService :: SMPClient -> Maybe THClientService smpClientService = thAuth . thParams >=> clientService {-# INLINE smpClientService #-} +smpClientServiceId :: SMPClient -> Maybe ServiceId +smpClientServiceId = fmap (\THClientService {serviceId} -> serviceId) . smpClientService +{-# INLINE smpClientServiceId #-} + enablePings :: SMPClient -> IO () enablePings ProtocolClient {client_ = PClient {sendPings}} = atomically $ writeTVar sendPings True {-# INLINE enablePings #-} diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 45d747d21..9739c19c7 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -15,6 +15,7 @@ module Simplex.Messaging.Client.Agent ( SMPClientAgent (..), SMPClientAgentConfig (..), SMPClientAgentEvent (..), + DBService (..), OwnServer, defaultSMPClientAgentConfig, newSMPClientAgent, @@ -133,6 +134,7 @@ defaultSMPClientAgentConfig = data SMPClientAgent p = SMPClientAgent { agentCfg :: SMPClientAgentConfig, agentParty :: SParty p, + dbService :: Maybe DBService, active :: TVar Bool, startedAt :: UTCTime, msgQ :: TBQueue (ServerTransmissionBatch SMPVersion ErrorType BrokerMsg), @@ -155,8 +157,8 @@ data SMPClientAgent p = SMPClientAgent type OwnServer = Bool -newSMPClientAgent :: SParty p -> SMPClientAgentConfig -> TVar ChaChaDRG -> IO (SMPClientAgent p) -newSMPClientAgent agentParty agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg = do +newSMPClientAgent :: SParty p -> SMPClientAgentConfig -> Maybe DBService -> TVar ChaChaDRG -> IO (SMPClientAgent p) +newSMPClientAgent agentParty agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} dbService randomDrg = do active <- newTVarIO True startedAt <- getCurrentTime msgQ <- newTBQueueIO msgQSize @@ -173,6 +175,7 @@ newSMPClientAgent agentParty agentCfg@SMPClientAgentConfig {msgQSize, agentQSize SMPClientAgent { agentCfg, agentParty, + dbService, active, startedAt, msgQ, @@ -188,6 +191,11 @@ newSMPClientAgent agentParty agentCfg@SMPClientAgentConfig {msgQSize, agentQSize workerSeq } +data DBService = DBService + { getCredentials :: SMPServer -> IO (Either SMPClientError ServiceCredentials), + updateServiceId :: SMPServer -> Maybe ServiceId -> IO (Either SMPClientError ()) + } + -- | Get or create SMP client for SMPServer getSMPServerClient' :: SMPClientAgent p -> SMPServer -> ExceptT SMPClientError IO SMPClient getSMPServerClient' ca srv = snd <$> getSMPServerClient'' ca srv @@ -218,7 +226,7 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke newSMPClient :: SMPClientVar -> IO (Either SMPClientError (OwnServer, SMPClient)) newSMPClient v = do - r <- connectClient ca srv v `E.catch` (pure . Left . PCEIOError) + r <- connectClient ca srv v `E.catch` \(e :: E.SomeException) -> pure $ Left $ PCEIOError $ E.displayException e case r of Right smp -> do logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv @@ -227,8 +235,7 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke atomically $ do putTMVar (sessionVar v) (Right c) TM.insert (sessionId $ thParams smp) c smpSessions - let serviceId_ = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp - notify ca $ CAConnected srv serviceId_ + notify ca $ CAConnected srv $ smpClientServiceId smp pure $ Right c Left e -> do let ei = persistErrorInterval agentCfg @@ -249,9 +256,18 @@ isOwnServer SMPClientAgent {agentCfg} ProtocolServer {host} = -- | Run an SMP client for SMPClientVar connectClient :: SMPClientAgent p -> SMPServer -> SMPClientVar -> IO (Either SMPClientError SMPClient) -connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, randomDrg, startedAt} srv v = - getProtocolClient randomDrg NRMBackground (1, srv, Nothing) (smpCfg agentCfg) [] (Just msgQ) startedAt clientDisconnected +connectClient ca@SMPClientAgent {agentCfg, dbService, smpClients, smpSessions, msgQ, randomDrg, startedAt} srv v = case dbService of + Just dbs -> runExceptT $ do + creds <- ExceptT $ getCredentials dbs srv + smp <- ExceptT $ getClient cfg {serviceCredentials = Just creds} + whenM (atomically $ activeClientSession ca smp srv) $ + ExceptT $ updateServiceId dbs srv $ smpClientServiceId smp + pure smp + Nothing -> getClient cfg where + cfg = smpCfg agentCfg + getClient cfg' = getProtocolClient randomDrg NRMBackground (1, srv, Nothing) cfg' [] (Just msgQ) startedAt clientDisconnected + clientDisconnected :: SMPClient -> IO () clientDisconnected smp = do removeClientAndSubs smp >>= serverDown @@ -435,7 +451,7 @@ smpSubscribeQueues ca smp srv subs = do unless (null notPending) $ removePendingSubs ca srv notPending pure acc sessId = sessionId $ thParams smp - smpServiceId = (\THClientService {serviceId} -> serviceId) <$> smpClientService smp + smpServiceId = smpClientServiceId smp groupSub :: Map QueueId C.APrivateAuthKey -> ((QueueId, C.APrivateAuthKey), Either SMPClientError (Maybe ServiceId)) -> diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index e7c1ca5f9..7d9e36c99 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -588,7 +588,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = logError $ "SMP server service subscription error " <> showService srv serviceSub <> ": " <> tshow e CAServiceUnavailable srv serviceSub -> do logError $ "SMP server service unavailable: " <> showService srv serviceSub - removeServiceAssociation st srv >>= \case + removeServiceAndAssociations st srv >>= \case Right (srvId, updated) -> do logSubStatus srv "removed service association" updated updated void $ subscribeSrvSubs ca st batchSize (srv, srvId, Nothing) diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index b0eafbc63..9ac89a12d 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -4,13 +4,14 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} module Simplex.Messaging.Notifications.Server.Env where import Control.Concurrent (ThreadId) -import Control.Logger.Simple -import Control.Monad +import Control.Monad.Except +import Control.Monad.Trans.Except import Crypto.Random import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) @@ -21,28 +22,26 @@ import qualified Data.X509.Validation as XV import Network.Socket import qualified Network.TLS as TLS import Numeric.Natural -import Simplex.Messaging.Client (ProtocolClientConfig (..)) +import Simplex.Messaging.Client (ProtocolClientError (..), SMPClientError) import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Server.Stats -import Simplex.Messaging.Notifications.Server.Store (newNtfSTMStore) import Simplex.Messaging.Notifications.Server.Store.Postgres import Simplex.Messaging.Notifications.Server.Store.Types -import Simplex.Messaging.Notifications.Server.StoreLog (readWriteNtfSTMStore) import Simplex.Messaging.Notifications.Transport (NTFVersion, VersionRangeNTF) -import Simplex.Messaging.Protocol (BasicAuth, CorrId, Party (..), SMPServer, SParty (..), Transmission) +import Simplex.Messaging.Protocol (BasicAuth, CorrId, Party (..), SMPServer, SParty (..), ServiceId, Transmission) import Simplex.Messaging.Server.Env.STM (StartOptions (..)) import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) -import Simplex.Messaging.Server.StoreLog (closeStoreLog) import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ASrvTransport, SMPServiceRole (..), ServiceCredentials (..), THandleParams, TransportPeer (..)) +import Simplex.Messaging.Transport.Credentials (genCredentials, tlsCredentials) import Simplex.Messaging.Transport.Server (AddHTTP, ServerCredentials, TransportServerConfig, loadFingerprint, loadServerCredential) -import System.Exit (exitFailure) +import Simplex.Messaging.Util (liftEitherWith) import System.Mem.Weak (Weak) import UnliftIO.STM @@ -96,33 +95,35 @@ data NtfEnv = NtfEnv } newNtfServerEnv :: NtfServerConfig -> IO NtfEnv -newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbStoreConfig, ntfCredentials, useServiceCreds, startOptions} = do - when (compactLog startOptions) $ compactDbStoreLog $ dbStoreLogPath dbStoreConfig +newNtfServerEnv config@NtfServerConfig {pushQSize, smpAgentCfg, apnsConfig, dbStoreConfig, ntfCredentials, useServiceCreds} = do random <- C.newRandom store <- newNtfDbStore dbStoreConfig tlsServerCreds <- loadServerCredential ntfCredentials - serviceCertHash@(XV.Fingerprint fp) <- loadFingerprint ntfCredentials - smpAgentCfg' <- - if useServiceCreds - then do - serviceSignKey <- case C.x509ToPrivate' $ snd tlsServerCreds of - Right pk -> pure pk - Left e -> putStrLn ("Server has no valid key: " <> show e) >> exitFailure - let service = ServiceCredentials {serviceRole = SRNotifier, serviceCreds = tlsServerCreds, serviceCertHash, serviceSignKey} - pure smpAgentCfg {smpCfg = (smpCfg smpAgentCfg) {serviceCredentials = Just service}} - else pure smpAgentCfg - subscriber <- newNtfSubscriber smpAgentCfg' random + XV.Fingerprint fp <- loadFingerprint ntfCredentials + let dbService = if useServiceCreds then Just $ mkDbService random store else Nothing + subscriber <- newNtfSubscriber smpAgentCfg dbService random pushServer <- newNtfPushServer pushQSize apnsConfig serverStats <- newNtfServerStats =<< getCurrentTime pure NtfEnv {config, subscriber, pushServer, store, random, tlsServerCreds, serverIdentity = C.KeyHash fp, serverStats} where - compactDbStoreLog = \case - Just f -> do - logNote $ "compacting store log " <> T.pack f - newNtfSTMStore >>= readWriteNtfSTMStore False f >>= closeStoreLog - Nothing -> do - logError "Error: `--compact-log` used without `enable: on` option in STORE_LOG section of INI file" - exitFailure + mkDbService g st = DBService {getCredentials, updateServiceId} + where + getCredentials :: SMPServer -> IO (Either SMPClientError ServiceCredentials) + getCredentials srv = runExceptT $ do + ExceptT (withClientDB "" st $ \db -> getNtfServiceCredentials db srv >>= mapM (mkServiceCreds db)) >>= \case + Just (C.KeyHash kh, serviceCreds) -> do + serviceSignKey <- liftEitherWith PCEIOError $ C.x509ToPrivate' $ snd serviceCreds + pure ServiceCredentials {serviceRole = SRNotifier, serviceCreds, serviceCertHash = XV.Fingerprint kh, serviceSignKey} + Nothing -> throwE PCEServiceUnavailable -- this error cannot happen, as clients never connect to unknown servers + mkServiceCreds db = \case + (_, Just tlsCreds) -> pure tlsCreds + (srvId, Nothing) -> do + cred <- genCredentials g Nothing (25, 24 * 999999) "simplex" + let tlsCreds = tlsCredentials [cred] + setNtfServiceCredentials db srvId tlsCreds + pure tlsCreds + updateServiceId :: SMPServer -> Maybe ServiceId -> IO (Either SMPClientError ()) + updateServiceId srv serviceId_ = withClientDB "" st $ \db -> updateNtfServiceId db srv serviceId_ data NtfSubscriber = NtfSubscriber { smpSubscribers :: TMap SMPServer SMPSubscriberVar, @@ -132,11 +133,11 @@ data NtfSubscriber = NtfSubscriber type SMPSubscriberVar = SessionVar SMPSubscriber -newNtfSubscriber :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO NtfSubscriber -newNtfSubscriber smpAgentCfg random = do +newNtfSubscriber :: SMPClientAgentConfig -> Maybe DBService -> TVar ChaChaDRG -> IO NtfSubscriber +newNtfSubscriber smpAgentCfg dbService random = do smpSubscribers <- TM.emptyIO subscriberSeq <- newTVarIO 0 - smpAgent <- newSMPClientAgent SNotifierService smpAgentCfg random + smpAgent <- newSMPClientAgent SNotifierService smpAgentCfg dbService random pure NtfSubscriber {smpSubscribers, subscriberSeq, smpAgent} data SMPSubscriber = SMPSubscriber diff --git a/src/Simplex/Messaging/Notifications/Server/Main.hs b/src/Simplex/Messaging/Notifications/Server/Main.hs index de12c33f8..e855c84d4 100644 --- a/src/Simplex/Messaging/Notifications/Server/Main.hs +++ b/src/Simplex/Messaging/Notifications/Server/Main.hs @@ -17,42 +17,32 @@ import Data.Functor (($>)) import Data.Ini (lookupValue, readIniFile) import Data.Int (Int64) import Data.Maybe (fromMaybe) -import Data.Set (Set) -import qualified Data.Set as S import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import qualified Data.Text.IO as T import Network.Socket (HostName, ServiceName) import Options.Applicative -import Simplex.Messaging.Agent.Store.Postgres (checkSchemaExists) import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) import Simplex.Messaging.Client (HostMode (..), NetworkConfig (..), ProtocolClientConfig (..), SMPWebPortServers (..), SocksMode (..), defaultNetworkConfig, textToHostMode) import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Notifications.Protocol (NtfTokenId) -import Simplex.Messaging.Notifications.Server (runNtfServer, restoreServerLastNtfs) +import Simplex.Messaging.Notifications.Server (runNtfServer) import Simplex.Messaging.Notifications.Server.Env (NtfServerConfig (..), defaultInactiveClientExpiration) import Simplex.Messaging.Notifications.Server.Push.APNS (defaultAPNSPushClientConfig) -import Simplex.Messaging.Notifications.Server.Store (newNtfSTMStore) -import Simplex.Messaging.Notifications.Server.Store.Postgres (exportNtfDbStore, importNtfSTMStore, newNtfDbStore) -import Simplex.Messaging.Notifications.Server.StoreLog (readWriteNtfSTMStore) import Simplex.Messaging.Notifications.Transport (alpnSupportedNTFHandshakes, supportedServerNTFVRange) import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern NtfServer) import Simplex.Messaging.Server.CLI import Simplex.Messaging.Server.Env.STM (StartOptions (..)) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Server.Main (strParse) import Simplex.Messaging.Server.Main.Init (iniDbOpts) import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) -import Simplex.Messaging.Server.StoreLog (closeStoreLog) import Simplex.Messaging.Transport (ASrvTransport) import Simplex.Messaging.Transport.Client (TransportHost (..)) import Simplex.Messaging.Transport.HTTP2 (httpALPN) import Simplex.Messaging.Transport.Server (AddHTTP, ServerCredentials (..), mkTransportServerConfig) -import Simplex.Messaging.Util (eitherToMaybe, ifM, tshow) -import System.Directory (createDirectoryIfMissing, doesFileExist, renameFile) -import System.Exit (exitFailure) +import Simplex.Messaging.Util (eitherToMaybe, tshow) +import System.Directory (createDirectoryIfMissing, doesFileExist) import System.FilePath (combine) import System.IO (BufferMode (..), hSetBuffering, stderr, stdout) import Text.Read (readMaybe) @@ -73,69 +63,11 @@ ntfServerCLI cfgPath logPath = deleteDirIfExists cfgPath deleteDirIfExists logPath putStrLn "Deleted configuration and log files" - Database cmd dbOpts@DBOpts {connstr, schema} -> withIniFile $ \ini -> do - schemaExists <- checkSchemaExists connstr schema - storeLogExists <- doesFileExist storeLogFilePath - lastNtfsExists <- doesFileExist defaultLastNtfsFile - case cmd of - SCImport skipTokens - | schemaExists && (storeLogExists || lastNtfsExists) -> exitConfigureNtfStore connstr schema - | schemaExists -> do - putStrLn $ "Schema " <> B.unpack schema <> " already exists in PostrgreSQL database: " <> B.unpack connstr - exitFailure - | not storeLogExists -> do - putStrLn $ storeLogFilePath <> " file does not exist." - exitFailure - | not lastNtfsExists -> do - putStrLn $ defaultLastNtfsFile <> " file does not exist." - exitFailure - | otherwise -> do - storeLogFile <- getRequiredStoreLogFile ini - confirmOrExit - ("WARNING: store log file " <> storeLogFile <> " will be compacted and imported to PostrgreSQL database: " <> B.unpack connstr <> ", schema: " <> B.unpack schema) - "Notification server store not imported" - stmStore <- newNtfSTMStore - sl <- readWriteNtfSTMStore True storeLogFile stmStore - closeStoreLog sl - restoreServerLastNtfs stmStore defaultLastNtfsFile - let storeCfg = PostgresStoreCfg {dbOpts = dbOpts {createSchema = True}, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = iniDeletedTTL ini} - ps <- newNtfDbStore storeCfg - (tCnt, sCnt, nCnt, serviceCnt) <- importNtfSTMStore ps stmStore skipTokens - renameFile storeLogFile $ storeLogFile <> ".bak" - putStrLn $ "Import completed: " <> show tCnt <> " tokens, " <> show sCnt <> " subscriptions, " <> show serviceCnt <> " service associations, " <> show nCnt <> " last token notifications." - putStrLn "Configure database options in INI file." - SCExport - | schemaExists && storeLogExists -> exitConfigureNtfStore connstr schema - | not schemaExists -> do - putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr - exitFailure - | storeLogExists -> do - putStrLn $ storeLogFilePath <> " file already exists." - exitFailure - | lastNtfsExists -> do - putStrLn $ defaultLastNtfsFile <> " file already exists." - exitFailure - | otherwise -> do - confirmOrExit - ("WARNING: PostrgreSQL database schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to store log file " <> storeLogFilePath) - "Notification server store not imported" - let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Just storeLogFilePath, confirmMigrations = MCConsole, deletedTTL = iniDeletedTTL ini} - st <- newNtfDbStore storeCfg - (tCnt, sCnt, nCnt) <- exportNtfDbStore st defaultLastNtfsFile - putStrLn $ "Export completed: " <> show tCnt <> " tokens, " <> show sCnt <> " subscriptions, " <> show nCnt <> " last token notifications." where withIniFile a = doesFileExist iniFile >>= \case True -> readIniFile iniFile >>= either exitError a _ -> exitError $ "Error: server is not initialized (" <> iniFile <> " does not exist).\nRun `" <> executableName <> " init`." - getRequiredStoreLogFile ini = do - case enableStoreLog' ini $> storeLogFilePath of - Just storeLogFile -> do - ifM - (doesFileExist storeLogFile) - (pure storeLogFile) - (putStrLn ("Store log file " <> storeLogFile <> " not found") >> exitFailure) - Nothing -> putStrLn "Store log disabled, see `[STORE_LOG] enable`" >> exitFailure iniFile = combine cfgPath "ntf-server.ini" serverVersion = "SMP notifications server v" <> simplexmqVersionCommit defaultServerPort = "443" @@ -289,11 +221,6 @@ ntfServerCLI cfgPath logPath = startOptions } iniDeletedTTL ini = readIniDefault (86400 * defaultDeletedTTL) "STORE_LOG" "db_deleted_ttl" ini - defaultLastNtfsFile = combine logPath "ntf-server-last-notifications.log" - exitConfigureNtfStore connstr schema = do - putStrLn $ "Error: both " <> storeLogFilePath <> " file and " <> B.unpack schema <> " schema are present (database: " <> B.unpack connstr <> ")." - putStrLn "Configure notification server storage." - exitFailure printNtfServerConfig :: [(ServiceName, ASrvTransport, AddHTTP)] -> PostgresStoreCfg -> IO () printNtfServerConfig transports PostgresStoreCfg {dbOpts = DBOpts {connstr, schema}, dbStoreLogPath} = do @@ -305,9 +232,6 @@ data CliCommand | OnlineCert CertOptions | Start StartOptions | Delete - | Database StoreCmd DBOpts - -data StoreCmd = SCImport (Set NtfTokenId) | SCExport data InitOptions = InitOptions { enableStoreLog :: Bool, @@ -338,22 +262,8 @@ cliCommandP cfgPath logPath iniFile = <> command "cert" (info (OnlineCert <$> certOptionsP) (progDesc $ "Generate new online TLS server credentials (configuration: " <> iniFile <> ")")) <> command "start" (info (Start <$> startOptionsP) (progDesc $ "Start server (configuration: " <> iniFile <> ")")) <> command "delete" (info (pure Delete) (progDesc "Delete configuration and log files")) - <> command "database" (info (Database <$> databaseCmdP <*> dbOptsP defaultNtfDBOpts) (progDesc "Import/export notifications server store to/from PostgreSQL database")) ) where - databaseCmdP = - hsubparser - ( command "import" (info (SCImport <$> skipTokensP) (progDesc $ "Import store logs into a new PostgreSQL database schema")) - <> command "export" (info (pure SCExport) (progDesc $ "Export PostgreSQL database schema to store logs")) - ) - skipTokensP :: Parser (Set NtfTokenId) - skipTokensP = - option - strParse - ( long "skip-tokens" - <> help "Skip tokens during import" - <> value S.empty - ) initP :: Parser InitOptions initP = do enableStoreLog <- diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs index 8c0da7c07..87e89ac8d 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs @@ -14,7 +14,8 @@ ntfServerSchemaMigrations :: [(String, Text, Maybe Text)] ntfServerSchemaMigrations = [ ("20250417_initial", m20250417_initial, Nothing), ("20250517_service_cert", m20250517_service_cert, Just down_m20250517_service_cert), - ("20250830_queue_ids_hash", m20250830_queue_ids_hash, Just down_m20250830_queue_ids_hash) + ("20250830_queue_ids_hash", m20250830_queue_ids_hash, Just down_m20250830_queue_ids_hash), + ("20251219_service_cert_per_server", m20251219_service_cert_per_server, Just down_m20251219_service_cert_per_server) ] -- | The list of migrations in ascending order by date @@ -225,3 +226,36 @@ ALTER TABLE smp_servers DROP COLUMN smp_notifier_ids_hash; |] <> dropXorHashFuncs + +m20251219_service_cert_per_server :: Text +m20251219_service_cert_per_server = + [r| +ALTER TABLE smp_servers + ADD COLUMN ntf_service_cert BYTEA, + ADD COLUMN ntf_service_cert_hash BYTEA, + ADD COLUMN ntf_service_priv_key BYTEA; + |] + <> resetNtfServices + +down_m20251219_service_cert_per_server :: Text +down_m20251219_service_cert_per_server = + [r| +ALTER TABLE smp_servers + DROP COLUMN ntf_service_cert, + DROP COLUMN ntf_service_cert_hash, + DROP COLUMN ntf_service_priv_key; + |] + <> resetNtfServices + +resetNtfServices :: Text +resetNtfServices = + [r| +ALTER TABLE subscriptions DISABLE TRIGGER tr_subscriptions_update; +UPDATE subscriptions SET ntf_service_assoc = FALSE; +ALTER TABLE subscriptions ENABLE TRIGGER tr_subscriptions_update; + +UPDATE smp_servers +SET ntf_service_id = NULL, + smp_notifier_count = 0, + smp_notifier_ids_hash = DEFAULT; + |] diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs index 60e81a68b..80ab45ca1 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -18,7 +18,6 @@ module Simplex.Messaging.Notifications.Server.Store.Postgres where -import Control.Concurrent.STM import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad @@ -26,19 +25,13 @@ import Control.Monad.Except import Control.Monad.IO.Class import Control.Monad.Trans.Except import Data.Bitraversable (bimapM) -import qualified Data.ByteString.Base64.URL as B64 import Data.ByteString.Char8 (ByteString) -import qualified Data.ByteString.Char8 as B -import Data.Containers.ListUtils (nubOrd) import Data.Either (fromRight) import Data.Functor (($>)) import Data.Int (Int64) -import Data.List (findIndex, foldl') import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L -import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe, isJust, mapMaybe) -import qualified Data.Set as S import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1, encodeUtf8) @@ -51,31 +44,30 @@ import Database.PostgreSQL.Simple.FromField (FromField (..)) import Database.PostgreSQL.Simple.SqlQQ (sql) import Database.PostgreSQL.Simple.ToField (ToField (..)) import Network.Socket (ServiceName) +import qualified Network.TLS as TLS import Simplex.Messaging.Agent.Store.AgentStore () import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore) import Simplex.Messaging.Agent.Store.Postgres.Common import Simplex.Messaging.Agent.Store.Postgres.DB (fromTextField_) import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..)) +import Simplex.Messaging.Client (ProtocolClientError (..), SMPClientError) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol -import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore (..), NtfSubData (..), NtfTknData (..), TokenNtfMessageRecord (..), ntfSubServer) import Simplex.Messaging.Notifications.Server.Store.Migrations import Simplex.Messaging.Notifications.Server.Store.Types -import Simplex.Messaging.Notifications.Server.StoreLog -import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), IdsHash (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, ServiceSub (..), pattern SMPServer) -import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_) +import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), IdsHash (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, ProtocolServer (..), SMPServer, ServiceId, ServiceSub (..), pattern SMPServer) +import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate) import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) -import Simplex.Messaging.Server.StoreLog (openWriteStoreLog) import Simplex.Messaging.SystemTime import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (anyM, firstRow, maybeFirstRow, toChunks, tshow) +import Simplex.Messaging.Util (firstRow, maybeFirstRow, tshow) import System.Exit (exitFailure) -import System.IO (IOMode (..), hFlush, stdout, withFile) import Text.Hex (decodeHex) #if !defined(dbPostgres) +import qualified Data.X509 as X import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder) import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util (eitherToMaybe) @@ -83,7 +75,6 @@ import Simplex.Messaging.Util (eitherToMaybe) data NtfPostgresStore = NtfPostgresStore { dbStore :: DBStore, - dbStoreLog :: Maybe (StoreLog 'WriteMode), deletedTTL :: Int64 } @@ -99,25 +90,22 @@ data NtfEntityRec (e :: NtfEntity) where NtfSub :: NtfSubRec -> NtfEntityRec 'Subscription newNtfDbStore :: PostgresStoreCfg -> IO NtfPostgresStore -newNtfDbStore PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations, deletedTTL} = do +newNtfDbStore PostgresStoreCfg {dbOpts, confirmMigrations, deletedTTL} = do dbStore <- either err pure =<< createDBStore dbOpts ntfServerMigrations (MigrationConfig confirmMigrations Nothing) - dbStoreLog <- mapM (openWriteStoreLog True) dbStoreLogPath - pure NtfPostgresStore {dbStore, dbStoreLog, deletedTTL} + pure NtfPostgresStore {dbStore, deletedTTL} where err e = do logError $ "STORE: newNtfStore, error opening PostgreSQL database, " <> tshow e exitFailure closeNtfDbStore :: NtfPostgresStore -> IO () -closeNtfDbStore NtfPostgresStore {dbStore, dbStoreLog} = do - closeDBStore dbStore - mapM_ closeStoreLog dbStoreLog +closeNtfDbStore NtfPostgresStore {dbStore} = closeDBStore dbStore addNtfToken :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) addNtfToken st tkn = withFastDB "addNtfToken" st $ \db -> - E.try (DB.execute db insertNtfTknQuery $ ntfTknToRow tkn) - >>= bimapM handleDuplicate (\_ -> withLog "addNtfToken" st (`logCreateToken` tkn)) + E.try (void $ DB.execute db insertNtfTknQuery $ ntfTknToRow tkn) + >>= bimapM handleDuplicate pure insertNtfTknQuery :: Query insertNtfTknQuery = @@ -128,7 +116,7 @@ insertNtfTknQuery = |] replaceNtfToken :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) -replaceNtfToken st NtfTknRec {ntfTknId, token = token@(DeviceToken pp ppToken), tknStatus, tknRegCode = code@(NtfRegCode regCode)} = +replaceNtfToken st NtfTknRec {ntfTknId, token = DeviceToken pp ppToken, tknStatus, tknRegCode = NtfRegCode regCode} = withFastDB "replaceNtfToken" st $ \db -> runExceptT $ do ExceptT $ assertUpdated <$> DB.execute @@ -139,7 +127,6 @@ replaceNtfToken st NtfTknRec {ntfTknId, token = token@(DeviceToken pp ppToken), WHERE token_id = ? |] (pp, Binary ppToken, tknStatus, Binary regCode, ntfTknId) - withLog "replaceNtfToken" st $ \sl -> logUpdateToken sl ntfTknId token code ntfTknToRow :: NtfTknRec -> NtfTknRow ntfTknToRow NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} = @@ -160,15 +147,14 @@ getNtfToken_ :: ToRow q => NtfPostgresStore -> Query -> q -> IO (Either ErrorTyp getNtfToken_ st cond params = withFastDB' "getNtfToken" st $ \db -> do tkn_ <- maybeFirstRow rowToNtfTkn $ DB.query db (ntfTknQuery <> cond) params - mapM_ (updateTokenDate st db) tkn_ + mapM_ (updateTokenDate db) tkn_ pure tkn_ -updateTokenDate :: NtfPostgresStore -> DB.Connection -> NtfTknRec -> IO () -updateTokenDate st db NtfTknRec {ntfTknId, tknUpdatedAt} = do +updateTokenDate :: DB.Connection -> NtfTknRec -> IO () +updateTokenDate db NtfTknRec {ntfTknId, tknUpdatedAt} = do ts <- getSystemDate when (maybe True (ts /=) tknUpdatedAt) $ do void $ DB.execute db "UPDATE tokens SET updated_at = ? WHERE token_id = ?" (ts, ntfTknId) - withLog "updateTokenDate" st $ \sl -> logUpdateTokenTime sl ntfTknId ts type NtfTknRow = (NtfTokenId, PushProvider, Binary ByteString, NtfTknStatus, NtfPublicAuthKey, C.PrivateKeyX25519, C.DhSecretX25519, Binary ByteString, Word16, Maybe SystemDate) @@ -206,7 +192,6 @@ deleteNtfToken st tknId = |] (Only tknId) liftIO $ void $ DB.execute db "DELETE FROM tokens WHERE token_id = ?" (Only tknId) - withLog "deleteNtfToken" st (`logDeleteToken` tknId) pure subs where toServerSubs :: SMPServerRow :. Only Text -> (SMPServer, [NotifierId]) @@ -235,7 +220,6 @@ updateTknCronInterval st tknId cronInt = withFastDB "updateTknCronInterval" st $ \db -> runExceptT $ do ExceptT $ assertUpdated <$> DB.execute db "UPDATE tokens SET cron_interval = ? WHERE token_id = ?" (cronInt, tknId) - withLog "updateTknCronInterval" st $ \sl -> logTokenCron sl tknId 0 -- Reads servers that have subscriptions that need subscribing. -- It is executed on server start, and it is supposed to crash on database error @@ -259,6 +243,73 @@ getUsedSMPServers st = let service_ = (\serviceId -> ServiceSub serviceId n idsHash) <$> serviceId_ in (SMPServer host port kh, srvId, service_) +getNtfServiceCredentials :: DB.Connection -> SMPServer -> IO (Maybe (Int64, Maybe (C.KeyHash, TLS.Credential))) +getNtfServiceCredentials db srv = + maybeFirstRow toService $ + DB.query + db + [sql| + SELECT smp_server_id, ntf_service_cert_hash, ntf_service_cert, ntf_service_priv_key + FROM smp_servers + WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? + FOR UPDATE + |] + (host srv, port srv, keyHash srv) + where + toService (Only srvId :. creds) = (srvId, toCredentials creds) + toCredentials = \case + (Just kh, Just cert, Just pk) -> Just (kh, (cert, pk)) + _ -> Nothing + +setNtfServiceCredentials :: DB.Connection -> Int64 -> (C.KeyHash, TLS.Credential) -> IO () +setNtfServiceCredentials db srvId (kh, (cert, pk)) = + void $ DB.execute + db + [sql| + UPDATE smp_servers + SET ntf_service_cert_hash = ?, ntf_service_cert = ?, ntf_service_priv_key = ? + WHERE smp_server_id = ? + |] + (kh, cert, pk, srvId) + +updateNtfServiceId :: DB.Connection -> SMPServer -> Maybe ServiceId -> IO () +updateNtfServiceId db srv newServiceId_ = do + maybeFirstRow id (getSMPServiceForUpdate_ db srv) >>= mapM_ updateService + where + updateService (srvId, currServiceId_) = unless (currServiceId_ == newServiceId_) $ do + when (isJust currServiceId_) $ do + void $ removeServiceAssociation_ db srvId + logError $ "STORE: service ID for " <> enc (host srv) <> toServiceId <> ", removed sub associations" + void $ case newServiceId_ of + Just newServiceId -> + DB.execute + db + [sql| + UPDATE smp_servers + SET ntf_service_id = ?, + smp_notifier_count = 0, + smp_notifier_ids_hash = DEFAULT + WHERE smp_server_id = ? + |] + (newServiceId, srvId) + Nothing -> + DB.execute + db + [sql| + UPDATE smp_servers + SET ntf_service_id = NULL, + ntf_service_cert = NULL, + ntf_service_cert_hash = NULL, + ntf_service_priv_key = NULL, + smp_notifier_count = 0, + smp_notifier_ids_hash = DEFAULT + WHERE smp_server_id = ? + |] + (Only srvId) + toServiceId = maybe " removed" ((" changed to " <>) . enc) newServiceId_ + enc :: StrEncoding a => a -> Text + enc = decodeLatin1 . strEncode + getServerNtfSubscriptions :: NtfPostgresStore -> Int64 -> Maybe NtfSubscriptionId -> Int -> IO (Either ErrorType [ServerNtfSub]) getServerNtfSubscriptions st srvId afterSubId_ count = withDB' "getServerNtfSubscriptions" st $ \db -> do @@ -297,7 +348,7 @@ findNtfSubscription st tknId q = withFastDB "findNtfSubscription" st $ \db -> runExceptT $ do tkn@NtfTknRec {ntfTknId, tknStatus} <- ExceptT $ getNtfToken st tknId unless (allowNtfSubCommands tknStatus) $ throwE AUTH - liftIO $ updateTokenDate st db tkn + liftIO $ updateTokenDate db tkn sub_ <- liftIO $ maybeFirstRow (rowToNtfSub q) $ DB.query @@ -330,7 +381,7 @@ getNtfSubscription st subId = WHERE s.subscription_id = ? |] (Only subId) - liftIO $ updateTokenDate st db tkn + liftIO $ updateTokenDate db tkn unless (allowNtfSubCommands tknStatus) $ throwE AUTH pure r @@ -352,36 +403,30 @@ mkNtfSubRec ntfSubId (NewNtfSub tokenId smpQueue notifierKey) = updateTknStatus :: NtfPostgresStore -> NtfTknRec -> NtfTknStatus -> IO (Either ErrorType ()) updateTknStatus st tkn status = - withFastDB' "updateTknStatus" st $ \db -> updateTknStatus_ st db tkn status + withFastDB' "updateTknStatus" st $ \db -> updateTknStatus_ db tkn status -updateTknStatus_ :: NtfPostgresStore -> DB.Connection -> NtfTknRec -> NtfTknStatus -> IO () -updateTknStatus_ st db NtfTknRec {ntfTknId} status = do - updated <- DB.execute db "UPDATE tokens SET status = ? WHERE token_id = ? AND status != ?" (status, ntfTknId, status) - when (updated > 0) $ withLog "updateTknStatus" st $ \sl -> logTokenStatus sl ntfTknId status +updateTknStatus_ :: DB.Connection -> NtfTknRec -> NtfTknStatus -> IO () +updateTknStatus_ db NtfTknRec {ntfTknId} status = + void $ DB.execute db "UPDATE tokens SET status = ? WHERE token_id = ? AND status != ?" (status, ntfTknId, status) -- unless it was already active setTknStatusConfirmed :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) setTknStatusConfirmed st NtfTknRec {ntfTknId} = - withFastDB' "updateTknStatus" st $ \db -> do - updated <- DB.execute db "UPDATE tokens SET status = ? WHERE token_id = ? AND status != ? AND status != ?" (NTConfirmed, ntfTknId, NTConfirmed, NTActive) - when (updated > 0) $ withLog "updateTknStatus" st $ \sl -> logTokenStatus sl ntfTknId NTConfirmed + withFastDB' "updateTknStatus" st $ \db -> + void $ DB.execute db "UPDATE tokens SET status = ? WHERE token_id = ? AND status != ? AND status != ?" (NTConfirmed, ntfTknId, NTConfirmed, NTActive) setTokenActive :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) setTokenActive st tkn@NtfTknRec {ntfTknId, token = DeviceToken pp ppToken} = withFastDB' "setTokenActive" st $ \db -> do - updateTknStatus_ st db tkn NTActive + updateTknStatus_ db tkn NTActive -- this removes other instances of the same token, e.g. because of repeated token registration attempts - tknIds <- - liftIO $ map fromOnly <$> - DB.query - db - [sql| - DELETE FROM tokens - WHERE push_provider = ? AND push_provider_token = ? AND token_id != ? - RETURNING token_id - |] - (pp, Binary ppToken, ntfTknId) - withLog "deleteNtfToken" st $ \sl -> mapM_ (logDeleteToken sl) tknIds + void $ DB.execute + db + [sql| + DELETE FROM tokens + WHERE push_provider = ? AND push_provider_token = ? AND token_id != ? + |] + (pp, Binary ppToken, ntfTknId) withPeriodicNtfTokens :: NtfPostgresStore -> Int64 -> (NtfTknRec -> IO ()) -> IO Int withPeriodicNtfTokens st now notify = @@ -399,7 +444,6 @@ addNtfSubscription st sub = withFastDB "addNtfSubscription" st $ \db -> runExceptT $ do srvId :: Int64 <- ExceptT $ upsertServer db $ ntfSubServer' sub n <- liftIO $ DB.execute db insertNtfSubQuery $ ntfSubToRow srvId sub - withLog "addNtfSubscription" st (`logCreateSubscription` sub) pure (srvId, n > 0) where -- It is possible to combine these two statements into one with CTEs, @@ -442,76 +486,66 @@ ntfSubToRow srvId NtfSubRec {ntfSubId, tokenId, smpQueue = SMPQueueNtf _ nId, no deleteNtfSubscription :: NtfPostgresStore -> NtfSubscriptionId -> IO (Either ErrorType ()) deleteNtfSubscription st subId = - withFastDB "deleteNtfSubscription" st $ \db -> runExceptT $ do - ExceptT $ assertUpdated <$> + withFastDB "deleteNtfSubscription" st $ \db -> + assertUpdated <$> DB.execute db "DELETE FROM subscriptions WHERE subscription_id = ?" (Only subId) - withLog "deleteNtfSubscription" st (`logDeleteSubscription` subId) updateSubStatus :: NtfPostgresStore -> Int64 -> NotifierId -> NtfSubStatus -> IO (Either ErrorType ()) updateSubStatus st srvId nId status = withFastDB' "updateSubStatus" st $ \db -> do - sub_ :: Maybe (NtfSubscriptionId, NtfAssociatedService) <- - maybeFirstRow id $ - DB.query - db - [sql| - UPDATE subscriptions SET status = ? - WHERE smp_server_id = ? AND smp_notifier_id = ? AND status != ? - RETURNING subscription_id, ntf_service_assoc - |] - (status, srvId, nId, status) - forM_ sub_ $ \(subId, serviceAssoc) -> - withLog "updateSubStatus" st $ \sl -> logSubscriptionStatus sl (subId, status, serviceAssoc) + void $ + DB.execute + db + [sql| + UPDATE subscriptions SET status = ? + WHERE smp_server_id = ? AND smp_notifier_id = ? AND status != ? + |] + (status, srvId, nId, status) updateSrvSubStatus :: NtfPostgresStore -> SMPQueueNtf -> NtfSubStatus -> IO (Either ErrorType ()) updateSrvSubStatus st q status = - withFastDB' "updateSrvSubStatus" st $ \db -> do - sub_ :: Maybe (NtfSubscriptionId, NtfAssociatedService) <- - maybeFirstRow id $ - DB.query - db - [sql| - UPDATE subscriptions s - SET status = ? - FROM smp_servers p - WHERE p.smp_server_id = s.smp_server_id - AND p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? AND s.smp_notifier_id = ? - AND s.status != ? - RETURNING s.subscription_id, s.ntf_service_assoc - |] - (Only status :. smpQueueToRow q :. Only status) - forM_ sub_ $ \(subId, serviceAssoc) -> - withLog "updateSrvSubStatus" st $ \sl -> logSubscriptionStatus sl (subId, status, serviceAssoc) + withFastDB' "updateSrvSubStatus" st $ \db -> + void $ + DB.execute + db + [sql| + UPDATE subscriptions s + SET status = ? + FROM smp_servers p + WHERE p.smp_server_id = s.smp_server_id + AND p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? AND s.smp_notifier_id = ? + AND s.status != ? + |] + (Only status :. smpQueueToRow q :. Only status) batchUpdateSrvSubStatus :: NtfPostgresStore -> SMPServer -> Maybe ServiceId -> NonEmpty NotifierId -> NtfSubStatus -> IO Int batchUpdateSrvSubStatus st srv newServiceId nIds status = fmap (fromRight (-1)) $ withDB "batchUpdateSrvSubStatus" st $ \db -> runExceptT $ do - (srvId :: Int64, currServiceId) <- ExceptT $ getSMPServerService db + (srvId, currServiceId) <- ExceptT $ firstRow id AUTH $ getSMPServiceForUpdate_ db srv + -- TODO [certs rcv] should this remove associations/credentials when newServiceId is Nothing or different unless (currServiceId == newServiceId) $ liftIO $ void $ DB.execute db "UPDATE smp_servers SET ntf_service_id = ? WHERE smp_server_id = ?" (newServiceId, srvId) let params = L.toList $ L.map (srvId,isJust newServiceId,status,) nIds liftIO $ fromIntegral <$> DB.executeMany db updateSubStatusQuery params - where - getSMPServerService db = - firstRow id AUTH $ - DB.query - db - [sql| - SELECT smp_server_id, ntf_service_id - FROM smp_servers - WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? - FOR UPDATE - |] - (srvToRow srv) + +getSMPServiceForUpdate_ :: DB.Connection -> SMPServer -> IO [(Int64, Maybe ServiceId)] +getSMPServiceForUpdate_ db srv = + DB.query + db + [sql| + SELECT smp_server_id, ntf_service_id + FROM smp_servers + WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? + FOR UPDATE + |] + (srvToRow srv) batchUpdateSrvSubErrors :: NtfPostgresStore -> SMPServer -> NonEmpty (NotifierId, NtfSubStatus) -> IO Int batchUpdateSrvSubErrors st srv subs = fmap (fromRight (-1)) $ withDB "batchUpdateSrvSubErrors" st $ \db -> runExceptT $ do srvId :: Int64 <- ExceptT $ getSMPServerId db let params = map (\(nId, status) -> (srvId, False, status, nId)) $ L.toList subs - subs' <- liftIO $ DB.returning db (updateSubStatusQuery <> " RETURNING s.subscription_id, s.status, s.ntf_service_assoc") params - withLog "batchUpdateStatus_" st $ forM_ subs' . logSubscriptionStatus - pure $ length subs' + liftIO $ fromIntegral <$> DB.executeMany db updateSubStatusQuery params where getSMPServerId db = firstRow fromOnly AUTH $ @@ -535,36 +569,51 @@ updateSubStatusQuery = AND (s.status != upd.status OR s.ntf_service_assoc != upd.ntf_service_assoc) |] -removeServiceAssociation :: NtfPostgresStore -> SMPServer -> IO (Either ErrorType (Int64, Int)) -removeServiceAssociation st srv = do - withDB "removeServiceAssociation" st $ \db -> runExceptT $ do - srvId <- ExceptT $ removeServerService db - subs <- - liftIO $ - DB.query - db - [sql| - UPDATE subscriptions s - SET status = ?, ntf_service_assoc = FALSE - WHERE smp_server_id = ? - AND (s.status != ? OR s.ntf_service_assoc != FALSE) - RETURNING s.subscription_id, s.status, s.ntf_service_assoc - |] - (NSInactive, srvId, NSInactive) - withLog "removeServiceAssociation" st $ forM_ subs . logSubscriptionStatus - pure (srvId, length subs) +removeServiceAssociation_ :: DB.Connection -> Int64 -> IO Int64 +removeServiceAssociation_ db srvId = + DB.execute + db + [sql| + UPDATE subscriptions s + SET status = ?, ntf_service_assoc = FALSE + WHERE smp_server_id = ? + AND (s.status != ? OR s.ntf_service_assoc != FALSE) + |] + (NSInactive, srvId, NSInactive) + +removeServiceAndAssociations :: NtfPostgresStore -> SMPServer -> IO (Either ErrorType (Int64, Int)) +removeServiceAndAssociations st srv = do + withDB "removeServiceAndAssociations" st $ \db -> runExceptT $ do + srvId <- ExceptT $ getServerId db + subsCount <- liftIO $ removeServiceAssociation_ db srvId + liftIO $ removeServerService db srvId + pure (srvId, fromIntegral subsCount) where - removeServerService db = + getServerId db = firstRow fromOnly AUTH $ DB.query db [sql| - UPDATE smp_servers - SET ntf_service_id = NULL + SELECT smp_server_id + FROM smp_servers WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? - RETURNING smp_server_id + FOR UPDATE |] (srvToRow srv) + removeServerService db srvId = + DB.execute + db + [sql| + UPDATE smp_servers + SET ntf_service_id = NULL, + ntf_service_cert = NULL, + ntf_service_cert_hash = NULL, + ntf_service_priv_key = NULL, + smp_notifier_count = 0, + smp_notifier_ids_hash = DEFAULT + WHERE smp_server_id = ? + |] + (Only srvId) addTokenLastNtf :: NtfPostgresStore -> PNMessageData -> IO (Either ErrorType (NtfTknRec, NonEmpty PNMessageData)) addTokenLastNtf st newNtf = @@ -646,216 +695,6 @@ getEntityCounts st = count (Only n : _) = n count [] = 0 -importNtfSTMStore :: NtfPostgresStore -> NtfSTMStore -> S.Set NtfTokenId -> IO (Int64, Int64, Int64, Int64) -importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore skipTokens = do - (tIds, tCnt) <- importTokens - subLookup <- readTVarIO $ subscriptionLookup stmStore - sCnt <- importSubscriptions tIds subLookup - nCnt <- importLastNtfs tIds subLookup - serviceCnt <- importNtfServiceIds - pure (tCnt, sCnt, nCnt, serviceCnt) - where - importTokens = do - allTokens <- M.elems <$> readTVarIO (tokens stmStore) - tokens <- filterTokens allTokens - let skipped = length allTokens - length tokens - when (skipped /= 0) $ putStrLn $ "Total skipped tokens " <> show skipped - -- uncomment this line instead of the next two to import tokens one by one. - -- tCnt <- withConnection s $ \db -> foldM (importTkn db) 0 tokens - -- token interval is reset to 0 to only send notifications to devices with periodic mode, - -- and before clients are upgraded - to all active devices. - tRows <- mapM (fmap (ntfTknToRow . (\t -> t {tknCronInterval = 0} :: NtfTknRec)) . mkTknRec) tokens - tCnt <- withConnection s $ \db -> DB.executeMany db insertNtfTknQuery tRows - let tokenIds = S.fromList $ map (\NtfTknData {ntfTknId} -> ntfTknId) tokens - (tokenIds,) <$> checkCount "token" (length tokens) tCnt - where - filterTokens tokens = do - let deviceTokens = foldl' (\m t -> M.alter (Just . (t :) . fromMaybe []) (tokenKey t) m) M.empty tokens - tokenSubs <- readTVarIO (tokenSubscriptions stmStore) - filterM (keepTokenRegistration deviceTokens tokenSubs) tokens - tokenKey NtfTknData {token, tknVerifyKey} = strEncode token <> ":" <> C.toPubKey C.pubKeyBytes tknVerifyKey - keepTokenRegistration deviceTokens tokenSubs tkn@NtfTknData {ntfTknId, tknStatus} = - case M.lookup (tokenKey tkn) deviceTokens of - Just ts - | length ts < 2 -> pure True - | ntfTknId `S.member` skipTokens -> False <$ putStrLn ("Skipped token " <> enc ntfTknId <> " from --skip-tokens") - | otherwise -> - readTVarIO tknStatus >>= \case - NTConfirmed -> do - hasSubs <- maybe (pure False) (\v -> not . S.null <$> readTVarIO v) $ M.lookup ntfTknId tokenSubs - if hasSubs - then pure True - else do - anyBetterToken <- anyM $ map (\NtfTknData {tknStatus = tknStatus'} -> activeOrInvalid <$> readTVarIO tknStatus') ts - if anyBetterToken - then False <$ putStrLn ("Skipped duplicate inactive token " <> enc ntfTknId) - else case findIndex (\NtfTknData {ntfTknId = tId} -> tId == ntfTknId) ts of - Just 0 -> pure True -- keeping the first token - Just _ -> False <$ putStrLn ("Skipped duplicate inactive token " <> enc ntfTknId <> " (no active token)") - Nothing -> True <$ putStrLn "Error: no device token in the list" - _ -> pure True - Nothing -> True <$ putStrLn "Error: no device token in lookup map" - activeOrInvalid = \case - NTActive -> True - NTInvalid _ -> True - _ -> False - -- importTkn db !n tkn@NtfTknData {ntfTknId} = do - -- tknRow <- ntfTknToRow <$> mkTknRec tkn - -- (DB.execute db insertNtfTknQuery tknRow >>= pure . (n + )) `E.catch` \(e :: E.SomeException) -> - -- putStrLn ("Error inserting token " <> enc ntfTknId <> " " <> show e) $> n - importSubscriptions :: S.Set NtfTokenId -> M.Map SMPQueueNtf NtfSubscriptionId -> IO Int64 - importSubscriptions tIds subLookup = do - subs <- filterSubs . M.elems =<< readTVarIO (subscriptions stmStore) - srvIds <- importServers subs - putStrLn $ "Importing " <> show (length subs) <> " subscriptions..." - -- uncomment this line instead of the next to import subs one by one. - -- (sCnt, errTkns) <- withConnection s $ \db -> foldM (importSub db srvIds) (0, M.empty) subs - sCnt <- foldM (importSubs srvIds) 0 $ toChunks 500000 subs - checkCount "subscription" (length subs) sCnt - where - filterSubs allSubs = do - let subs = filter (\NtfSubData {tokenId} -> S.member tokenId tIds) allSubs - skipped = length allSubs - length subs - when (skipped /= 0) $ putStrLn $ "Skipped " <> show skipped <> " subscriptions of missing tokens" - let (removedSubTokens, removeSubs, dupQueues) = foldl' addSubToken (S.empty, S.empty, S.empty) subs - unless (null removeSubs) $ putStrLn $ "Skipped " <> show (S.size removeSubs) <> " duplicate subscriptions of " <> show (S.size removedSubTokens) <> " tokens for " <> show (S.size dupQueues) <> " queues" - pure $ filter (\NtfSubData {ntfSubId} -> S.notMember ntfSubId removeSubs) subs - where - addSubToken acc@(!stIds, !sIds, !qs) NtfSubData {ntfSubId, smpQueue, tokenId} = - case M.lookup smpQueue subLookup of - Just sId | sId /= ntfSubId -> - (S.insert tokenId stIds, S.insert ntfSubId sIds, S.insert smpQueue qs) - _ -> acc - importSubs srvIds !n subs = do - rows <- mapM (ntfSubRow srvIds) subs - cnt <- withConnection s $ \db -> DB.executeMany db insertNtfSubQuery $ L.toList rows - let n' = n + cnt - putStr $ "Imported " <> show n' <> " subscriptions" <> "\r" - hFlush stdout - pure n' - -- importSub db srvIds (!n, !errTkns) sub@NtfSubData {ntfSubId = sId, tokenId} = do - -- subRow <- ntfSubRow srvIds sub - -- E.try (DB.execute db insertNtfSubQuery subRow) >>= \case - -- Right i -> do - -- let n' = n + i - -- when (n' `mod` 100000 == 0) $ do - -- putStr $ "Imported " <> show n' <> " subscriptions" <> "\r" - -- hFlush stdout - -- pure (n', errTkns) - -- Left (e :: E.SomeException) -> do - -- when (n `mod` 100000 == 0) $ putStrLn "" - -- putStrLn $ "Error inserting subscription " <> enc sId <> " for token " <> enc tokenId <> " " <> show e - -- pure (n, M.alter (Just . maybe [sId] (sId :)) tokenId errTkns) - ntfSubRow srvIds sub = case M.lookup srv srvIds of - Just sId -> ntfSubToRow sId <$> mkSubRec sub - Nothing -> E.throwIO $ userError $ "no matching server ID for server " <> show srv - where - srv = ntfSubServer sub - importServers subs = do - sIds <- withConnection s $ \db -> map fromOnly <$> DB.returning db srvQuery (map srvToRow srvs) - void $ checkCount "server" (length srvs) (length sIds) - pure $ M.fromList $ zip srvs sIds - where - srvQuery = "INSERT INTO smp_servers (smp_host, smp_port, smp_keyhash) VALUES (?, ?, ?) RETURNING smp_server_id" - srvs = nubOrd $ map ntfSubServer subs - importLastNtfs :: S.Set NtfTokenId -> M.Map SMPQueueNtf NtfSubscriptionId -> IO Int64 - importLastNtfs tIds subLookup = do - ntfs <- readTVarIO (tokenLastNtfs stmStore) - ntfRows <- filterLastNtfRows ntfs - nCnt <- withConnection s $ \db -> DB.executeMany db lastNtfQuery ntfRows - checkCount "last notification" (length ntfRows) nCnt - where - lastNtfQuery = "INSERT INTO last_notifications(token_id, subscription_id, sent_at, nmsg_nonce, nmsg_data) VALUES (?,?,?,?,?)" - filterLastNtfRows ntfs = do - (skippedTkns, ntfCnt, (skippedQueues, ntfRows)) <- foldM lastNtfRows (S.empty, 0, (S.empty, [])) $ M.assocs ntfs - let skipped = ntfCnt - length ntfRows - when (skipped /= 0) $ putStrLn $ "Skipped last notifications " <> show skipped <> " for " <> show (S.size skippedTkns) <> " missing tokens and " <> show (S.size skippedQueues) <> " missing subscriptions with token present" - pure ntfRows - lastNtfRows (!stIds, !cnt, !acc) (tId, ntfVar) = do - ntfs <- L.toList <$> readTVarIO ntfVar - let cnt' = cnt + length ntfs - pure $ - if S.member tId tIds - then (stIds, cnt', foldl' ntfRow acc ntfs) - else (S.insert tId stIds, cnt', acc) - where - ntfRow (!qs, !rows) PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} = case M.lookup smpQueue subLookup of - Just ntfSubId -> - let row = (tId, ntfSubId, systemToUTCTime ntfTs, nmsgNonce, Binary encNMsgMeta) - in (qs, row : rows) - Nothing -> (S.insert smpQueue qs, rows) - importNtfServiceIds = do - ss <- M.assocs <$> readTVarIO (ntfServices stmStore) - withConnection s $ \db -> DB.executeMany db serviceQuery $ map serviceToRow ss - where - serviceQuery = - [sql| - INSERT INTO smp_servers (smp_host, smp_port, smp_keyhash, ntf_service_id) - VALUES (?, ?, ?, ?) - ON CONFLICT (smp_host, smp_port, smp_keyhash) - DO UPDATE SET ntf_service_id = EXCLUDED.ntf_service_id - |] - serviceToRow (srv, serviceId) = srvToRow srv :. Only serviceId - checkCount name expected inserted - | fromIntegral expected == inserted = do - putStrLn $ "Imported " <> show inserted <> " " <> name <> "s." - pure inserted - | otherwise = do - putStrLn $ "Incorrect " <> name <> " count: expected " <> show expected <> ", imported " <> show inserted - putStrLn "Import aborted, fix data and repeat" - exitFailure - enc = B.unpack . B64.encode . unEntityId - -exportNtfDbStore :: NtfPostgresStore -> FilePath -> IO (Int, Int, Int) -exportNtfDbStore NtfPostgresStore {dbStoreLog = Nothing} _ = - putStrLn "Internal error: export requires store log" >> exitFailure -exportNtfDbStore NtfPostgresStore {dbStore = s, dbStoreLog = Just sl} lastNtfsFile = - (,,) <$> exportTokens <*> exportSubscriptions <*> exportLastNtfs - where - exportTokens = do - tCnt <- withConnection s $ \db -> DB.fold_ db ntfTknQuery 0 $ \ !i tkn -> - logCreateToken sl (rowToNtfTkn tkn) $> (i + 1) - putStrLn $ "Exported " <> show tCnt <> " tokens" - pure tCnt - exportSubscriptions = do - sCnt <- withConnection s $ \db -> DB.fold_ db ntfSubQuery 0 $ \ !i sub -> do - let i' = i + 1 - logCreateSubscription sl (toNtfSub sub) - when (i' `mod` 500000 == 0) $ do - putStr $ "Exported " <> show i' <> " subscriptions" <> "\r" - hFlush stdout - pure i' - putStrLn $ "Exported " <> show sCnt <> " subscriptions" - pure sCnt - where - ntfSubQuery = - [sql| - SELECT s.token_id, s.subscription_id, s.smp_notifier_key, s.status, s.ntf_service_assoc, - p.smp_host, p.smp_port, p.smp_keyhash, s.smp_notifier_id - FROM subscriptions s - JOIN smp_servers p ON p.smp_server_id = s.smp_server_id - |] - toNtfSub :: Only NtfTokenId :. NtfSubRow :. SMPQueueNtfRow -> NtfSubRec - toNtfSub (Only tokenId :. (ntfSubId, notifierKey, subStatus, ntfServiceAssoc) :. qRow) = - let smpQueue = rowToSMPQueue qRow - in NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus, ntfServiceAssoc} - exportLastNtfs = - withFile lastNtfsFile WriteMode $ \h -> - withConnection s $ \db -> DB.fold_ db lastNtfsQuery 0 $ \ !i (Only tknId :. ntfRow) -> - B.hPutStr h (encodeLastNtf tknId $ toLastNtf ntfRow) $> (i + 1) - where - -- Note that the order here is ascending, to be compatible with how it is imported - lastNtfsQuery = - [sql| - SELECT s.token_id, p.smp_host, p.smp_port, p.smp_keyhash, s.smp_notifier_id, - n.sent_at, n.nmsg_nonce, n.nmsg_data - FROM last_notifications n - JOIN subscriptions s ON s.subscription_id = n.subscription_id - JOIN smp_servers p ON p.smp_server_id = s.smp_server_id - ORDER BY token_ntf_id ASC - |] - encodeLastNtf tknId ntf = strEncode (TNMRv1 tknId ntf) `B.snoc` '\n' - withFastDB' :: Text -> NtfPostgresStore -> (DB.Connection -> IO a) -> IO (Either ErrorType a) withFastDB' op st action = withFastDB op st $ fmap Right . action {-# INLINE withFastDB' #-} @@ -881,9 +720,12 @@ withDB_ op st priority action = where err = op <> ", withDB, " <> tshow e -withLog :: MonadIO m => Text -> NtfPostgresStore -> (StoreLog 'WriteMode -> IO ()) -> m () -withLog op NtfPostgresStore {dbStoreLog} = withLog_ op dbStoreLog -{-# INLINE withLog #-} +withClientDB :: Text -> NtfPostgresStore -> (DB.Connection -> IO a) -> IO (Either SMPClientError a) +withClientDB op st action = + E.uninterruptibleMask_ $ E.try (withTransaction (dbStore st) action) >>= bimapM logErr pure + where + logErr :: E.SomeException -> IO SMPClientError + logErr e = logError ("STORE: " <> op <> ", withDB, " <> tshow e) $> PCEIOError (E.displayException e) assertUpdated :: Int64 -> Either ErrorType () assertUpdated 0 = Left AUTH @@ -921,4 +763,9 @@ instance ToField C.KeyHash where toField = toField . Binary . strEncode instance FromField C.CbNonce where fromField = blobFieldDecoder $ parseAll smpP instance ToField C.CbNonce where toField = toField . Binary . smpEncode + +instance ToField X.PrivKey where toField = toField . Binary . C.encodeASNObj + +instance FromField X.PrivKey where + fromField = blobFieldDecoder $ C.decodeASNKey >=> \case (pk, []) -> Right pk; r -> C.asnKeyError r #endif diff --git a/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql b/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql index b73995684..801208aaa 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql +++ b/src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql @@ -172,7 +172,10 @@ CREATE TABLE ntf_server.smp_servers ( smp_keyhash bytea NOT NULL, ntf_service_id bytea, smp_notifier_count bigint DEFAULT 0 NOT NULL, - smp_notifier_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL + smp_notifier_ids_hash bytea DEFAULT '\x00000000000000000000000000000000'::bytea NOT NULL, + ntf_service_cert bytea, + ntf_service_cert_hash bytea, + ntf_service_priv_key bytea ); diff --git a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs b/src/Simplex/Messaging/Notifications/Server/StoreLog.hs deleted file mode 100644 index 7c71ddb08..000000000 --- a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs +++ /dev/null @@ -1,177 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DuplicateRecordFields #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE StrictData #-} -{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} - -module Simplex.Messaging.Notifications.Server.StoreLog - ( StoreLog, - NtfStoreLogRecord (..), - readWriteNtfSTMStore, - logCreateToken, - logTokenStatus, - logUpdateToken, - logTokenCron, - logDeleteToken, - logUpdateTokenTime, - logCreateSubscription, - logSubscriptionStatus, - logDeleteSubscription, - closeStoreLog, - ) -where - -import Control.Applicative (optional, (<|>)) -import Control.Concurrent.STM -import Control.Monad -import qualified Data.Attoparsec.ByteString.Char8 as A -import qualified Data.ByteString.Base64.URL as B64 -import qualified Data.ByteString.Char8 as B -import Data.Functor (($>)) -import qualified Data.Map.Strict as M -import Data.Maybe (fromMaybe) -import Data.Word (Word16) -import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Protocol -import Simplex.Messaging.Notifications.Server.Store -import Simplex.Messaging.Notifications.Server.Store.Types -import Simplex.Messaging.Protocol (EntityId (..), SMPServer, ServiceId) -import Simplex.Messaging.Server.StoreLog -import Simplex.Messaging.SystemTime -import System.IO - -data NtfStoreLogRecord - = CreateToken NtfTknRec - | TokenStatus NtfTokenId NtfTknStatus - | UpdateToken NtfTokenId DeviceToken NtfRegCode - | TokenCron NtfTokenId Word16 - | DeleteToken NtfTokenId - | UpdateTokenTime NtfTokenId SystemDate - | CreateSubscription NtfSubRec - | SubscriptionStatus NtfSubscriptionId NtfSubStatus NtfAssociatedService - | DeleteSubscription NtfSubscriptionId - | SetNtfService SMPServer (Maybe ServiceId) - deriving (Show) - -instance StrEncoding NtfStoreLogRecord where - strEncode = \case - CreateToken tknRec -> strEncode (Str "TCREATE", tknRec) - TokenStatus tknId tknStatus -> strEncode (Str "TSTATUS", tknId, tknStatus) - UpdateToken tknId token regCode -> strEncode (Str "TUPDATE", tknId, token, regCode) - TokenCron tknId cronInt -> strEncode (Str "TCRON", tknId, cronInt) - DeleteToken tknId -> strEncode (Str "TDELETE", tknId) - UpdateTokenTime tknId ts -> strEncode (Str "TTIME", tknId, ts) - CreateSubscription subRec -> strEncode (Str "SCREATE", subRec) - SubscriptionStatus subId subStatus serviceAssoc -> strEncode (Str "SSTATUS", subId, subStatus) <> serviceStr - where - serviceStr = if serviceAssoc then " service=" <> strEncode True else "" - DeleteSubscription subId -> strEncode (Str "SDELETE", subId) - SetNtfService srv serviceId -> strEncode (Str "SERVICE", srv) <> " service=" <> maybe "off" strEncode serviceId - strP = - A.choice - [ "TCREATE " *> (CreateToken <$> strP), - "TSTATUS " *> (TokenStatus <$> strP_ <*> strP), - "TUPDATE " *> (UpdateToken <$> strP_ <*> strP_ <*> strP), - "TCRON " *> (TokenCron <$> strP_ <*> strP), - "TDELETE " *> (DeleteToken <$> strP), - "TTIME " *> (UpdateTokenTime <$> strP_ <*> strP), - "SCREATE " *> (CreateSubscription <$> strP), - "SSTATUS " *> (SubscriptionStatus <$> strP_ <*> strP <*> (fromMaybe False <$> optional (" service=" *> strP))), - "SDELETE " *> (DeleteSubscription <$> strP), - "SERVICE " *> (SetNtfService <$> strP <* " service=" <*> ("off" $> Nothing <|> strP)) - ] - -logNtfStoreRecord :: StoreLog 'WriteMode -> NtfStoreLogRecord -> IO () -logNtfStoreRecord = writeStoreLogRecord -{-# INLINE logNtfStoreRecord #-} - -logCreateToken :: StoreLog 'WriteMode -> NtfTknRec -> IO () -logCreateToken s = logNtfStoreRecord s . CreateToken - -logTokenStatus :: StoreLog 'WriteMode -> NtfTokenId -> NtfTknStatus -> IO () -logTokenStatus s tknId tknStatus = logNtfStoreRecord s $ TokenStatus tknId tknStatus - -logUpdateToken :: StoreLog 'WriteMode -> NtfTokenId -> DeviceToken -> NtfRegCode -> IO () -logUpdateToken s tknId token regCode = logNtfStoreRecord s $ UpdateToken tknId token regCode - -logTokenCron :: StoreLog 'WriteMode -> NtfTokenId -> Word16 -> IO () -logTokenCron s tknId cronInt = logNtfStoreRecord s $ TokenCron tknId cronInt - -logDeleteToken :: StoreLog 'WriteMode -> NtfTokenId -> IO () -logDeleteToken s tknId = logNtfStoreRecord s $ DeleteToken tknId - -logUpdateTokenTime :: StoreLog 'WriteMode -> NtfTokenId -> SystemDate -> IO () -logUpdateTokenTime s tknId t = logNtfStoreRecord s $ UpdateTokenTime tknId t - -logCreateSubscription :: StoreLog 'WriteMode -> NtfSubRec -> IO () -logCreateSubscription s = logNtfStoreRecord s . CreateSubscription - -logSubscriptionStatus :: StoreLog 'WriteMode -> (NtfSubscriptionId, NtfSubStatus, NtfAssociatedService) -> IO () -logSubscriptionStatus s (subId, subStatus, serviceAssoc) = logNtfStoreRecord s $ SubscriptionStatus subId subStatus serviceAssoc - -logDeleteSubscription :: StoreLog 'WriteMode -> NtfSubscriptionId -> IO () -logDeleteSubscription s subId = logNtfStoreRecord s $ DeleteSubscription subId - -logSetNtfService :: StoreLog 'WriteMode -> SMPServer -> Maybe ServiceId -> IO () -logSetNtfService s srv serviceId = logNtfStoreRecord s $ SetNtfService srv serviceId - -readWriteNtfSTMStore :: Bool -> FilePath -> NtfSTMStore -> IO (StoreLog 'WriteMode) -readWriteNtfSTMStore tty = readWriteStoreLog (readNtfStore tty) writeNtfStore - -readNtfStore :: Bool -> FilePath -> NtfSTMStore -> IO () -readNtfStore tty f st = readLogLines tty f $ \_ -> processLine - where - processLine s = either printError procNtfLogRecord (strDecode s) - where - printError e = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> B.take 100 s - procNtfLogRecord = \case - CreateToken r@NtfTknRec {ntfTknId} -> do - tkn <- mkTknData r - atomically $ stmAddNtfToken st ntfTknId tkn - TokenStatus tknId status -> do - tkn_ <- stmGetNtfTokenIO st tknId - forM_ tkn_ $ \tkn@NtfTknData {tknStatus} -> do - atomically $ writeTVar tknStatus status - when (status == NTActive) $ void $ atomically $ stmRemoveInactiveTokenRegistrations st tkn - UpdateToken tknId token' tknRegCode -> do - stmGetNtfTokenIO st tknId - >>= mapM_ - ( \tkn@NtfTknData {tknStatus} -> do - atomically $ stmRemoveTokenRegistration st tkn - atomically $ writeTVar tknStatus NTRegistered - atomically $ stmAddNtfToken st tknId tkn {token = token', tknRegCode} - ) - TokenCron tknId cronInt -> - stmGetNtfTokenIO st tknId - >>= mapM_ (\NtfTknData {tknCronInterval} -> atomically $ writeTVar tknCronInterval cronInt) - DeleteToken tknId -> - atomically $ void $ stmDeleteNtfToken st tknId - UpdateTokenTime tknId t -> - stmGetNtfTokenIO st tknId - >>= mapM_ (\NtfTknData {tknUpdatedAt} -> atomically $ writeTVar tknUpdatedAt $ Just t) - CreateSubscription r@NtfSubRec {tokenId, ntfSubId} -> do - sub <- mkSubData r - atomically (stmAddNtfSubscription st ntfSubId sub) >>= \case - Just () -> pure () - Nothing -> B.putStrLn $ "Warning: no token " <> enc tokenId <> ", subscription " <> enc ntfSubId - where - enc = B64.encode . unEntityId - SubscriptionStatus subId status serviceAssoc -> do - stmGetNtfSubscriptionIO st subId >>= mapM_ update - where - update NtfSubData {subStatus, ntfServiceAssoc} = atomically $ do - writeTVar subStatus status - writeTVar ntfServiceAssoc serviceAssoc - DeleteSubscription subId -> - atomically $ stmDeleteNtfSubscription st subId - SetNtfService srv serviceId -> - atomically $ stmSetNtfService st srv serviceId - -writeNtfStore :: StoreLog 'WriteMode -> NtfSTMStore -> IO () -writeNtfStore s NtfSTMStore {tokens, subscriptions, ntfServices} = do - mapM_ (logCreateToken s <=< mkTknRec) =<< readTVarIO tokens - mapM_ (logCreateSubscription s <=< mkSubRec) =<< readTVarIO subscriptions - mapM_ (\(srv, serviceId) -> logSetNtfService s srv $ Just serviceId) . M.assocs =<< readTVarIO ntfServices diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 24247e781..21b03f3cf 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -46,6 +46,7 @@ module Simplex.Messaging.Server where import Control.Concurrent.STM (throwSTM) +import qualified Control.Exception as E import Control.Logger.Simple import Control.Monad import Control.Monad.Except @@ -1385,7 +1386,7 @@ client Just r -> Just <$> proxyServerResponse a r Nothing -> forkProxiedCmd $ - liftIO (runExceptT (getSMPServerClient'' a srv) `catch` (pure . Left . PCEIOError)) + liftIO (runExceptT (getSMPServerClient'' a srv) `E.catch` (\(e :: SomeException) -> pure $ Left $ PCEIOError $ E.displayException e)) >>= proxyServerResponse a proxyServerResponse :: SMPClientAgent 'Sender -> Either SMPClientError (OwnServer, SMPClient) -> M s BrokerMsg proxyServerResponse a smp_ = do @@ -1422,7 +1423,7 @@ client inc own pRequests if v >= sendingProxySMPVersion then forkProxiedCmd $ do - liftIO (runExceptT (forwardSMPTransmission smp corrId fwdV pubKey encBlock) `catch` (pure . Left . PCEIOError)) >>= \case + liftIO (runExceptT (forwardSMPTransmission smp corrId fwdV pubKey encBlock) `E.catch` (\(e :: SomeException) -> pure $ Left $ PCEIOError $ E.displayException e)) >>= \case Right r -> PRES r <$ inc own pSuccesses Left e -> ERR (smpProxyError e) <$ case e of PCEProtocolError {} -> inc own pSuccesses diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index e59cd5c0b..574111c15 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -706,7 +706,7 @@ mkJournalStoreConfig queueStoreCfg storePath msgQueueQuota maxJournalMsgCount ma newSMPProxyAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO ProxyAgent newSMPProxyAgent smpAgentCfg random = do - smpAgent <- newSMPClientAgent SSender smpAgentCfg random + smpAgent <- newSMPClientAgent SSender smpAgentCfg Nothing random pure ProxyAgent {smpAgent} readWriteQueueStore :: forall q. StoreQueueClass q => Bool -> (RecipientId -> QueueRec -> IO q) -> FilePath -> STMQueueStore q -> IO (StoreLog 'WriteMode) diff --git a/src/Simplex/Messaging/Transport/HTTP2/Client.hs b/src/Simplex/Messaging/Transport/HTTP2/Client.hs index 91a8bf0e5..e805fa86c 100644 --- a/src/Simplex/Messaging/Transport/HTTP2/Client.hs +++ b/src/Simplex/Messaging/Transport/HTTP2/Client.hs @@ -11,7 +11,6 @@ module Simplex.Messaging.Transport.HTTP2.Client where import Control.Concurrent.Async -import Control.Exception (IOException, try) import qualified Control.Exception as E import Control.Monad import Data.Functor (($>)) @@ -90,7 +89,7 @@ defaultHTTP2ClientConfig = suportedTLSParams = http2TLSParams } -data HTTP2ClientError = HCResponseTimeout | HCNetworkError NetworkError | HCIOError IOException +data HTTP2ClientError = HCResponseTimeout | HCNetworkError NetworkError | HCIOError String deriving (Show) getHTTP2Client :: HostName -> ServiceName -> Maybe XS.CertificateStore -> HTTP2ClientConfig -> IO () -> IO (Either HTTP2ClientError HTTP2Client) @@ -111,7 +110,7 @@ attachHTTP2Client config host port disconnected bufferSize tls = getVerifiedHTTP getVerifiedHTTP2ClientWith :: forall p. TransportPeerI p => HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> ((TLS p -> H.Client HTTP2Response) -> IO HTTP2Response) -> IO (Either HTTP2ClientError HTTP2Client) getVerifiedHTTP2ClientWith config host port disconnected setup = (mkHTTPS2Client >>= runClient) - `E.catch` \(e :: IOException) -> pure . Left $ HCIOError e + `E.catch` \(e :: E.SomeException) -> pure $ Left $ HCIOError $ E.displayException e where mkHTTPS2Client :: IO HClient mkHTTPS2Client = do @@ -177,9 +176,9 @@ sendRequest HTTP2Client {client_ = HClient {config, reqQ}} req reqTimeout_ = do sendRequestDirect :: HTTP2Client -> Request -> Maybe Int -> IO (Either HTTP2ClientError HTTP2Response) sendRequestDirect HTTP2Client {client_ = HClient {config, disconnected}, sendReq} req reqTimeout_ = do let reqTimeout = http2RequestTimeout config reqTimeout_ - reqTimeout `timeout` try (sendReq req process) >>= \case + reqTimeout `timeout` E.try (sendReq req process) >>= \case Just (Right r) -> pure $ Right r - Just (Left e) -> disconnected $> Left (HCIOError e) + Just (Left (e :: E.SomeException)) -> disconnected $> Left (HCIOError $ E.displayException e) Nothing -> pure $ Left HCResponseTimeout where process r = do diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 34448fc10..18cdfd1fa 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -3677,6 +3677,7 @@ testClientServiceConnection ps = do exchangeGreetings service uId user sId pure conns withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> runRight $ do + liftIO $ threadDelay 250000 [(_, Right (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash)))] <- M.toList <$> subscribeClientServices service 1 ("", "", SERVICE_ALL _) <- nGet service subscribeConnection user sId @@ -3684,6 +3685,7 @@ testClientServiceConnection ps = do pure (conns, qIdHash) (uId', sId') <- withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + liftIO $ threadDelay 250000 subscribeAllConnections service False Nothing liftIO $ getInAnyOrder service [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 1 qIdHash')))) -> qIdHash' == qIdHash; _ -> False, @@ -3708,6 +3710,7 @@ testClientServiceConnection ps = do pure conns' withAgentClientsServers2 (agentCfg, initAgentServersClientService) (agentCfg, initAgentServers) $ \service user -> do withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do + liftIO $ threadDelay 250000 subscribeAllConnections service False Nothing liftIO $ getInAnyOrder service [ \case ("", "", AEvt SAENone (SERVICE_UP _ (SMP.ServiceSubResult Nothing (SMP.ServiceSub _ 2 _)))) -> True; _ -> False,