Skip to content

Commit 3851556

Browse files
authored
Merge pull request #5593 from unisonweb/syncv2/sorted
Run downloading, unpacking, and saving in parallel for serialized syncs
2 parents a775871 + f14e7c7 commit 3851556

File tree

4 files changed

+142
-64
lines changed

4 files changed

+142
-64
lines changed

unison-cli/package.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ library:
3737
- containers >= 0.6.3
3838
- conduit
3939
- conduit-extra
40+
- stm-chans
4041
- cryptonite
4142
- either
4243
- errors

unison-cli/src/Unison/Cli/DownloadUtils.hs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,8 @@ downloadProjectBranchFromShare useSquashed branch =
7777
Cli.respond (Output.DownloadedEntities numDownloaded)
7878
SyncV2 -> do
7979
let branchRef = SyncV2.BranchRef (into @Text (ProjectAndBranch branch.projectName remoteProjectBranchName))
80-
let downloadedCallback = \_ -> pure ()
8180
let shouldValidate = not $ Codeserver.isCustomCodeserver Codeserver.defaultCodeserver
82-
result <- SyncV2.syncFromCodeserver shouldValidate Share.hardCodedBaseUrl branchRef causalHashJwt downloadedCallback
81+
result <- SyncV2.syncFromCodeserver shouldValidate Share.hardCodedBaseUrl branchRef causalHashJwt
8382
result & onLeft \err0 -> do
8483
done case err0 of
8584
Share.SyncError pullErr ->

unison-cli/src/Unison/Share/SyncV2.hs

Lines changed: 139 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ where
1212
import Codec.Serialise qualified as CBOR
1313
import Conduit (ConduitT)
1414
import Conduit qualified as C
15+
import Control.Concurrent.STM.TBMQueue qualified as STM
1516
import Control.Lens
1617
import Control.Monad.Except
1718
import Control.Monad.Reader (ask)
@@ -30,6 +31,7 @@ import Data.Graph qualified as Graph
3031
import Data.Map qualified as Map
3132
import Data.Proxy
3233
import Data.Set qualified as Set
34+
import Data.Text qualified as Text
3335
import Data.Text.IO qualified as Text
3436
import Data.Text.Lazy qualified as Text.Lazy
3537
import Data.Text.Lazy.Encoding qualified as Text.Lazy
@@ -65,9 +67,12 @@ import Unison.SyncV2.API (Routes (downloadEntitiesStream))
6567
import Unison.SyncV2.API qualified as SyncV2
6668
import Unison.SyncV2.Types (CBORBytes, CBORStream, DependencyType (..))
6769
import Unison.SyncV2.Types qualified as SyncV2
70+
import Unison.Util.Monoid qualified as Monoid
6871
import Unison.Util.Servant.CBOR qualified as CBOR
6972
import Unison.Util.Timing qualified as Timing
7073
import UnliftIO qualified as IO
74+
import UnliftIO.Async qualified as Async
75+
import UnliftIO.STM qualified as STM
7176

7277
type Stream i o = ConduitT i o StreamM ()
7378

@@ -76,12 +81,13 @@ type SyncErr = SyncError SyncV2.PullError
7681
-- The base monad we use within the conduit pipeline.
7782
type StreamM = (ExceptT SyncErr (C.ResourceT IO))
7883

79-
-- | The number of entities to process in a single transaction.
80-
--
81-
-- SQLite transactions have some fixed overhead, so setting this too low can really slow things down,
82-
-- but going too high here means we may be waiting on the network to get a full batch when we could be starting work.
83-
batchSize :: Int
84-
batchSize = 5000
84+
-- | The set of hooks the sync pipeline uses to report its progress.
--
-- All hooks run in 'IO'. The counters are cumulative: each call adds the given
-- amount to a running total rather than replacing it.
data ProgressCallbacks = ProgressCallbacks
  { -- | Record the total number of entities expected for this sync.
    setTotal :: Int -> IO (),
    -- | Add the given number of entities to the running "downloaded" count.
    downloadCounter :: Int -> IO (),
    -- | Mark the download side of the sync as finished.
    doneDownloading :: IO (),
    -- | Add the given number of entities to the running "imported" count.
    importCounter :: Int -> IO ()
  }
8591

8692
------------------------------------------------------------------------------------------------------------------------
8793
-- Main methods
@@ -101,7 +107,7 @@ syncToFile codebase rootHash mayBranchRef destFilePath = do
101107
liftIO $ Codebase.withConnection codebase \conn -> do
102108
C.runResourceT $
103109
withCodebaseEntityStream conn rootHash mayBranchRef \mayTotal stream -> do
104-
withStreamProgressCallback (Just mayTotal) \countC -> runExceptT do
110+
syncToFileProgress (Just mayTotal) \countC -> runExceptT do
105111
C.runConduit $
106112
stream
107113
C..| countC
@@ -120,15 +126,16 @@ syncFromFile shouldValidate syncFilePath = do
120126
-- Every insert into SQLite checks the temp entity tables, but syncv2 doesn't actually use them, so it's faster
121127
-- if we clear them out before starting a sync.
122128
Cli.runTransaction Q.clearTempEntityTables
123-
runExceptT do
124-
mapExceptT liftIO $ Timing.time "File Sync" $ do
125-
header <- mapExceptT C.runResourceT $ do
126-
let stream = C.sourceFile syncFilePath C..| C.ungzip C..| decodeUnframedEntities
127-
(header, rest) <- initializeStream stream
128-
streamIntoCodebase shouldValidate codebase header rest
129-
pure header
130-
afterSyncChecks codebase (SyncV2.rootCausalHash header)
131-
pure . hash32ToCausalHash $ SyncV2.rootCausalHash header
129+
liftIO $ withStreamProgress False \progressCounters -> do
130+
runExceptT do
131+
mapExceptT liftIO $ Timing.time "File Sync" $ do
132+
header <- mapExceptT C.runResourceT $ do
133+
let stream = C.sourceFile syncFilePath C..| C.ungzip C..| decodeUnframedEntities
134+
(header, rest) <- initializeStream (setTotal progressCounters) stream
135+
streamIntoCodebase progressCounters shouldValidate codebase header rest
136+
pure header
137+
afterSyncChecks codebase (SyncV2.rootCausalHash header)
138+
pure . hash32ToCausalHash $ SyncV2.rootCausalHash header
132139

133140
syncFromCodebase ::
134141
Bool ->
@@ -142,10 +149,11 @@ syncFromCodebase shouldValidate srcConn destCodebase causalHash = do
142149
-- Every insert into SQLite checks the temp entity tables, but syncv2 doesn't actually use them, so it's faster
143150
-- if we clear them out before starting a sync.
144151
Sqlite.runTransaction srcConn Q.clearTempEntityTables
145-
liftIO . C.runResourceT . runExceptT $ withCodebaseEntityStream srcConn causalHash Nothing \_total entityStream -> do
146-
(header, rest) <- initializeStream entityStream
147-
streamIntoCodebase shouldValidate destCodebase header rest
148-
mapExceptT liftIO (afterSyncChecks destCodebase (causalHashToHash32 causalHash))
152+
withStreamProgress False \progressCounters -> do
153+
liftIO . C.runResourceT . runExceptT $ withCodebaseEntityStream srcConn causalHash Nothing \_total entityStream -> do
154+
(header, rest) <- initializeStream (setTotal progressCounters) entityStream
155+
streamIntoCodebase progressCounters shouldValidate destCodebase header rest
156+
mapExceptT liftIO (afterSyncChecks destCodebase (causalHashToHash32 causalHash))
149157

150158
syncFromCodeserver ::
151159
Bool ->
@@ -155,10 +163,8 @@ syncFromCodeserver ::
155163
SyncV2.BranchRef ->
156164
-- | The hash to download.
157165
Share.HashJWT ->
158-
-- | Callback that's given a number of entities we just downloaded.
159-
(Int -> IO ()) ->
160166
Cli (Either (SyncError SyncV2.PullError) ())
161-
syncFromCodeserver shouldValidate unisonShareUrl branchRef hashJwt _downloadedCallback = do
167+
syncFromCodeserver shouldValidate unisonShareUrl branchRef hashJwt = do
162168
Cli.Env {authHTTPClient, codebase} <- ask
163169
-- Every insert into SQLite checks the temp entity tables, but syncv2 doesn't actually use them, so it's faster
164170
-- if we clear them out before starting a sync.
@@ -169,14 +175,16 @@ syncFromCodeserver shouldValidate unisonShareUrl branchRef hashJwt _downloadedCa
169175
ExceptT $ do
170176
(Cli.runTransaction (Q.entityLocation hash)) >>= \case
171177
Just Q.EntityInMainStorage -> pure $ Right ()
172-
_ -> do
178+
_ -> liftIO $ withStreamProgress True \progressCallbacks -> do
173179
Timing.time "Entity Download" $ do
174180
liftIO . C.runResourceT . runExceptT $ httpStreamEntities
181+
(setTotal progressCallbacks)
175182
authHTTPClient
176183
unisonShareUrl
177184
SyncV2.DownloadEntitiesRequest {branchRef, causalHash = hashJwt, knownHashes}
178185
\header stream -> do
179-
streamIntoCodebase shouldValidate codebase header stream
186+
whenJust (SyncV2.numEntities header) (liftIO . setTotal progressCallbacks . fromIntegral)
187+
streamIntoCodebase progressCallbacks shouldValidate codebase header stream
180188
mapExceptT liftIO (afterSyncChecks codebase hash)
181189

182190
------------------------------------------------------------------------------------------------------------------------
@@ -220,46 +228,74 @@ batchValidateEntities entities = do
220228

221229
-- | Syncs a stream which could send entities in any order.
222230
syncUnsortedStream ::
231+
ProgressCallbacks ->
223232
Bool ->
224233
(Codebase.Codebase IO v a) ->
225234
Stream () SyncV2.EntityChunk ->
226235
StreamM ()
227-
syncUnsortedStream shouldValidate codebase stream = do
236+
syncUnsortedStream (ProgressCallbacks {setTotal, downloadCounter, doneDownloading, importCounter}) shouldValidate codebase stream = do
228237
allEntities <-
229238
C.runConduit $
230239
stream
240+
C..| C.iterM (\_ -> liftIO $ downloadCounter 1)
231241
C..| CL.chunksOf batchSize
232242
C..| unpackChunks codebase
233243
C..| validateBatch
234244
C..| C.concat
235245
C..| C.sinkVector @Vector
246+
liftIO doneDownloading
247+
liftIO $ setTotal (Vector.length allEntities)
236248
let sortedEntities = sortDependencyFirst allEntities
237-
liftIO $ withEntitySavingCallback (Just $ Vector.length allEntities) \countC -> do
238-
Codebase.runTransaction codebase $ for_ sortedEntities \(hash, entity) -> do
239-
r <- Q.saveTempEntityInMain v2HashHandle hash entity
240-
Sqlite.unsafeIO $ countC 1
241-
pure r
249+
liftIO $ Codebase.runTransaction codebase $ for_ sortedEntities \(hash, entity) -> do
250+
r <- Q.saveTempEntityInMain v2HashHandle hash entity
251+
Sqlite.unsafeIO $ importCounter 1
252+
pure r
242253
where
254+
-- The number of entities to process in a single transaction.
255+
--
256+
-- SQLite transactions have some fixed overhead, so setting this too low can really slow things down,
257+
-- but going too high here means we may be waiting on the network to get a full batch when we could be starting work.
258+
batchSize :: Int
259+
batchSize = 5000
260+
243261
validateBatch :: Stream (Vector (Hash32, TempEntity)) (Vector (Hash32, TempEntity))
244262
validateBatch = C.iterM \entities -> do
245263
when shouldValidate (mapExceptT lift $ batchValidateEntities entities)
246264

247265
-- | Syncs a stream which sends entities which are already sorted in dependency order.
248266
-- This allows us to stream them directly into the codebase as they're received.
249267
syncSortedStream ::
268+
ProgressCallbacks ->
250269
Bool ->
251270
(Codebase.Codebase IO v a) ->
252271
Stream () SyncV2.EntityChunk ->
253272
StreamM ()
254-
syncSortedStream shouldValidate codebase stream = do
273+
syncSortedStream (ProgressCallbacks {downloadCounter, doneDownloading, importCounter}) shouldValidate codebase stream = do
274+
(downloaderSink, downloaderSource) <- parallelSinkAndSource (3 * batchSize) -- Allow downloading up to triple our current batch size in advance
275+
(unpackerSink, unpackerSource) <- parallelSinkAndSource 5 -- Buffer of up to 5 batches.
255276
let handler :: Stream (Vector (Hash32, TempEntity)) o
256277
handler = C.mapM_C \entityBatch -> do
257278
validateAndSave shouldValidate codebase entityBatch
258-
C.runConduit $
259-
stream
260-
C..| CL.chunksOf batchSize
261-
C..| unpackChunks codebase
262-
C..| handler
279+
liftIO $ importCounter (length entityBatch)
280+
let downloadC = stream C..| downloaderSink
281+
let saverC =
282+
downloaderSource
283+
C..| CL.chunksOf batchSize
284+
C..| unpackChunks codebase
285+
C..| C.iterM (liftIO . downloadCounter . length)
286+
C..| (unpackerSink *> liftIO doneDownloading)
287+
let handlerC =
288+
unpackerSource
289+
C..| handler
290+
291+
-- Run the three conduits concurrently, and wait for them all to finish, fail if any of them fail.
292+
ExceptT . Async.runConc $ do
293+
a <- Async.conc . runExceptT $ C.runConduit downloadC
294+
b <- Async.conc . runExceptT $ C.runConduit saverC
295+
c <- Async.conc . runExceptT $ C.runConduit handlerC
296+
pure (a >> b >> c)
297+
where
298+
batchSize = 1000
263299

264300
-- | Topologically sort entities based on their dependencies, returning a list in dependency-first order.
265301
sortDependencyFirst :: (Foldable f, Functor f) => f (Hash32, TempEntity) -> [(Hash32, TempEntity)]
@@ -292,23 +328,21 @@ unpackChunks codebase = C.mapM \xs -> ExceptT . lift . Codebase.runTransactionEx
292328

293329
-- | Stream entities from one codebase into another.
294330
streamIntoCodebase ::
331+
ProgressCallbacks ->
295332
-- | Whether to validate entities as they're imported.
296333
Bool ->
297334
Codebase.Codebase IO v a ->
298335
SyncV2.StreamInitInfo ->
299336
Stream () SyncV2.EntityChunk ->
300337
StreamM ()
301-
streamIntoCodebase shouldValidate codebase SyncV2.StreamInitInfo {version, entitySorting, numEntities = numEntities} stream = ExceptT do
302-
withStreamProgressCallback (fromIntegral <$> numEntities) \countC -> runExceptT do
303-
-- Add a counter to the stream to track how many entities we've processed.
304-
let stream' = stream C..| countC
305-
case version of
306-
(SyncV2.Version 1) -> pure ()
307-
v -> throwError . SyncError . SyncV2.PullError'Sync $ SyncV2.SyncErrorUnsupportedVersion v
308-
309-
case entitySorting of
310-
SyncV2.DependenciesFirst -> syncSortedStream shouldValidate codebase stream'
311-
SyncV2.Unsorted -> syncUnsortedStream shouldValidate codebase stream'
338+
streamIntoCodebase progressCounters shouldValidate codebase SyncV2.StreamInitInfo {version, entitySorting} stream = do
339+
case version of
340+
(SyncV2.Version 1) -> pure ()
341+
v -> throwError . SyncError . SyncV2.PullError'Sync $ SyncV2.SyncErrorUnsupportedVersion v
342+
343+
case entitySorting of
344+
SyncV2.DependenciesFirst -> syncSortedStream progressCounters shouldValidate codebase stream
345+
SyncV2.Unsorted -> syncUnsortedStream progressCounters shouldValidate codebase stream
312346

313347
-- | A sanity-check to verify that the hash we expected to import from the stream was successfully loaded into the codebase.
314348
afterSyncChecks :: Codebase.Codebase IO v a -> Hash32 -> ExceptT (SyncError SyncV2.PullError) IO ()
@@ -492,12 +526,13 @@ handleClientError clientEnv err =
492526

493527
-- | Stream entities from the codeserver.
494528
httpStreamEntities ::
529+
(Int -> IO ()) ->
495530
Auth.AuthenticatedHttpClient ->
496531
Servant.BaseUrl ->
497532
SyncV2.DownloadEntitiesRequest ->
498533
(SyncV2.StreamInitInfo -> Stream () SyncV2.EntityChunk -> StreamM ()) ->
499534
StreamM ()
500-
httpStreamEntities (Auth.AuthenticatedHttpClient httpClient) unisonShareUrl req callback = do
535+
httpStreamEntities setTotal (Auth.AuthenticatedHttpClient httpClient) unisonShareUrl req callback = do
501536
let clientEnv =
502537
(Servant.mkClientEnv httpClient unisonShareUrl)
503538
{ Servant.makeClientRequest = \url request ->
@@ -509,19 +544,20 @@ httpStreamEntities (Auth.AuthenticatedHttpClient httpClient) unisonShareUrl req
509544
}
510545
}
511546
(downloadEntitiesStreamClientM req) & withConduit clientEnv \stream -> do
512-
(init, entityStream) <- initializeStream stream
547+
(init, entityStream) <- initializeStream setTotal stream
513548
callback init entityStream
514549

515550
-- | Peel the header off the stream and parse the remaining entity chunks into EntityChunks
516-
initializeStream :: Stream () SyncV2.DownloadEntitiesChunk -> StreamM (SyncV2.StreamInitInfo, Stream () SyncV2.EntityChunk)
517-
initializeStream stream = do
551+
initializeStream :: (Int -> IO ()) -> Stream () SyncV2.DownloadEntitiesChunk -> StreamM (SyncV2.StreamInitInfo, Stream () SyncV2.EntityChunk)
552+
initializeStream setTotal stream = do
518553
(streamRemainder, init) <- stream C.$$+ C.headC
519554
case init of
520555
Nothing -> throwError . SyncError . SyncV2.PullError'Sync $ SyncV2.SyncErrorMissingInitialChunk
521556
Just chunk -> do
522557
case chunk of
523558
SyncV2.InitialC info -> do
524559
let entityStream = C.unsealConduitT streamRemainder C..| C.mapM parseEntity
560+
for (SyncV2.numEntities info) \t -> liftIO $ setTotal (fromIntegral t)
525561
pure $ (info, entityStream)
526562
SyncV2.EntityC _ -> do
527563
throwError . SyncError . SyncV2.PullError'Sync $ SyncV2.SyncErrorMissingInitialChunk
@@ -628,20 +664,61 @@ counterProgress msgBuilder action = do
628664
liftIO $ IO.atomically (IO.modifyTVar' counterVar (+ i))
629665

630666
-- | Track how many entities have been exported to a sync file using a counter stream.
631-
withStreamProgressCallback :: (MonadIO m, MonadUnliftIO n) => Maybe Int -> (ConduitT i i m () -> n a) -> n a
632-
withStreamProgressCallback total action = do
633-
let msg n = "\n 📦 Unpacked " <> tShow n <> maybe "" (\total -> " / " <> tShow total) total <> " entities...\n\n"
667+
syncToFileProgress :: (MonadIO m, MonadUnliftIO n) => Maybe Int -> (ConduitT i i m () -> n a) -> n a
syncToFileProgress total action =
  counterProgress exportedMsg runWithCounter
  where
    -- Progress line; the " / <total>" suffix is only shown when a total is known.
    exportedMsg n = "\n Exported " <> tShow n <> maybe "" (\t -> " / " <> tShow t) total <> " entities 📦 \n\n"

    -- Hand the caller a pass-through conduit that bumps the counter by one for
    -- every element flowing through it.
    runWithCounter bump = action (C.iterM (const (bump 1)))
636672

637-
-- | Track how many entities have been saved.
638-
withEntitySavingCallback :: (MonadUnliftIO m) => Maybe Int -> ((Int -> m ()) -> m a) -> m a
639-
withEntitySavingCallback total action = do
640-
let msg n = "\n 💾 Saved " <> tShow n <> maybe "" (\total -> " / " <> tShow total) total <> " new entities...\n\n"
641-
counterProgress msg action
642-
643673
-- | Track how many entities have been loaded.
644674
withEntityLoadingCallback :: (MonadUnliftIO m) => ((Int -> m ()) -> m a) -> m a
645675
withEntityLoadingCallback action = do
646-
let msg n = "\n 📦 Unpacked " <> tShow n <> " entities...\n\n"
676+
let msg n = "\n Loading entities from codebase: " <> tShow n <> " 📦\n\n"
647677
counterProgress msg action
678+
679+
-- | Run an action while showing a live progress display in a console region.
--
-- The action receives a 'ProgressCallbacks' whose hooks update shared 'TVar's;
-- the region's content is an STM computation over those same variables.
--
-- The 'Bool' controls whether the "Downloaded" line is rendered; the
-- "Imported" line is always rendered.
withStreamProgress :: (MonadUnliftIO n) => Bool -> (ProgressCallbacks -> n a) -> n a
withStreamProgress hasDownload action = do
  downloadedVar <- IO.newTVarIO 0
  downloadFinishedVar <- IO.newTVarIO False
  importedVar <- IO.newTVarIO (0 :: Int)
  totalVar <- IO.newTVarIO Nothing
  let -- Add to one of the running counters.
      bump var n = IO.atomically (IO.modifyTVar' var (+ n))

      callbacks =
        ProgressCallbacks
          { setTotal = IO.atomically . IO.writeTVar totalVar . Just,
            downloadCounter = bump downloadedVar,
            doneDownloading = IO.atomically (IO.writeTVar downloadFinishedVar True),
            importCounter = bump importedVar
          }

      -- STM computation that produces the text shown in the console region.
      renderProgress = do
        downloaded <- IO.readTVar downloadedVar
        downloadFinished <- IO.readTVar downloadFinishedVar
        imported <- IO.readTVar importedVar
        mayTotal <- IO.readTVar totalVar
        let totalSuffix = maybe "" (\t -> " / " <> tShow @Int t) mayTotal
            downloadLine =
              "\n Downloaded: " <> tShow @Int downloaded <> totalSuffix <> Monoid.whenM downloadFinished " 🏁"
            importLine = " Imported: " <> tShow @Int imported
        pure (Text.unlines [Monoid.whenM hasDownload downloadLine, importLine])
  IO.withRunInIO \runInIO ->
    Console.Regions.displayConsoleRegions $
      Console.Regions.withConsoleRegion Console.Regions.Linear \region -> do
        Console.Regions.setConsoleRegion region renderProgress
        runInIO (action callbacks)
706+
707+
-- * Conduit helpers
708+
709+
-- | Create a connected sink/source pair backed by a bounded, closeable queue,
-- so that a producer conduit and a consumer conduit can run on different threads.
--
-- The 'Int' is the capacity of the queue. The sink writes every element it
-- receives into the queue and closes the queue once its upstream is exhausted.
-- The source yields elements read from the queue and finishes once the queue
-- reports that it is closed and empty.
--
-- NOTE(review): the queue is only closed when the sink sees the end of its
-- upstream. If the producing pipeline stops any other way, nothing here closes
-- the queue — confirm that callers cancel or otherwise unblock the consumer in
-- that case.
parallelSinkAndSource :: (MonadIO m) => Int -> m (ConduitT i void1 m (), ConduitT void2 i m ())
parallelSinkAndSource bufferSize = do
  queue <- liftIO (STM.newTBMQueueIO bufferSize)
  pure (feedQueue queue, drainQueue queue)
  where
    -- Push every upstream element into the queue, then close it.
    feedQueue queue = do
      C.mapM_C (STM.atomically . STM.writeTBMQueue queue)
      STM.atomically (STM.closeTBMQueue queue)

    -- Yield queued elements downstream until the queue is closed and drained.
    drainQueue queue =
      STM.atomically (STM.readTBMQueue queue) >>= \case
        Nothing -> pure ()
        Just item -> C.yield item *> drainQueue queue

unison-cli/unison-cli.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ library
251251
, servant-client
252252
, servant-conduit
253253
, stm
254+
, stm-chans
254255
, temporary
255256
, text
256257
, text-ansi

0 commit comments

Comments
 (0)