Skip to content

Commit 6078db1

Browse files
committed
feat: download snapshots
1 parent 8cc8580 commit 6078db1

File tree

3 files changed

+114
-2
lines changed

3 files changed

+114
-2
lines changed

pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@
4040
<scope>test</scope>
4141
</dependency>
4242

43+
<dependency>
44+
<groupId>org.apache.httpcomponents</groupId>
45+
<artifactId>httpclient</artifactId>
46+
<version>4.5.13</version>
47+
</dependency>
4348
</dependencies>
4449

4550
<build>

src/main/java/io/qdrant/client/QdrantClient.java

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,25 @@
1212
import io.qdrant.client.grpc.QdrantOuterClass;
1313
import io.qdrant.client.grpc.SnapshotsGrpc;
1414
import io.qdrant.client.grpc.SnapshotsService;
15+
import io.qdrant.client.grpc.SnapshotsService.SnapshotDescription;
1516
import io.qdrant.client.utils.PointUtil;
17+
import java.io.IOException;
1618
import java.net.MalformedURLException;
1719
import java.net.URL;
20+
import java.nio.file.Files;
21+
import java.nio.file.Path;
22+
import java.nio.file.StandardOpenOption;
1823
import java.time.Duration;
1924
import java.util.List;
2025
import java.util.Map;
2126
import java.util.concurrent.TimeUnit;
2227
import javax.annotation.Nullable;
28+
import org.apache.http.HttpEntity;
29+
import org.apache.http.HttpResponse;
30+
import org.apache.http.client.HttpClient;
31+
import org.apache.http.client.methods.HttpGet;
32+
import org.apache.http.impl.client.HttpClients;
33+
import org.apache.http.util.EntityUtils;
2334

2435
/** Client for interfacing with the Qdrant service. */
2536
public class QdrantClient implements AutoCloseable {
@@ -1311,10 +1322,74 @@ public SnapshotsService.DeleteSnapshotResponse deleteFullSnapshot(String snapsho
13111322
return snapshotsStub.deleteFull(request);
13121323
}
13131324

1325+
/**
1326+
* Downloads a snapshot of a collection from the specified REST API URI and saves it to the given
1327+
* output path.
1328+
*
1329+
* @param outPath The path where the snapshot will be saved.
1330+
* @param collectionName The name of the collection.
1331+
* @param snapshotName The name of the snapshot. If null, the latest snapshot will be downloaded.
1332+
* @param restApiUri The URI of the REST API. If null, the default URI "http://localhost:6333"
1333+
* will be used.
1334+
* @throws RuntimeException If an error occurs while downloading the snapshot.
1335+
*/
1336+
public void downloadSnapshot(
1337+
Path outPath,
1338+
String collectionName,
1339+
@Nullable String snapshotName,
1340+
@Nullable String restApiUri) {
1341+
try {
1342+
String resolvedSnapshotName;
1343+
1344+
if (snapshotName != null) {
1345+
resolvedSnapshotName = snapshotName;
1346+
} else {
1347+
// Get the latest(0th) snapshot of the collection
1348+
List<SnapshotDescription> snapshots =
1349+
listSnapshots(collectionName).getSnapshotDescriptionsList();
1350+
if (snapshots.isEmpty()) {
1351+
throw new RuntimeException("No snapshots found");
1352+
}
1353+
resolvedSnapshotName =
1354+
listSnapshots(collectionName).getSnapshotDescriptionsList().get(0).getName();
1355+
}
1356+
1357+
String uri;
1358+
if (restApiUri != null) {
1359+
uri =
1360+
String.format(
1361+
"%s/collections/%s/snapshots/%s", restApiUri, collectionName, resolvedSnapshotName);
1362+
} else {
1363+
uri =
1364+
String.format(
1365+
"http://localhost:6333/collections/%s/snapshots/%s",
1366+
collectionName, resolvedSnapshotName);
1367+
}
1368+
1369+
HttpClient httpClient = HttpClients.createDefault();
1370+
HttpGet httpGet = new HttpGet(uri);
1371+
1372+
HttpResponse response = httpClient.execute(httpGet);
1373+
1374+
if (response.getStatusLine().getStatusCode() == 200) {
1375+
HttpEntity entity = response.getEntity();
1376+
if (entity != null) {
1377+
Files.write(outPath, EntityUtils.toByteArray(entity), StandardOpenOption.WRITE);
1378+
System.out.println("Downloaded successfully");
1379+
} else {
1380+
System.err.println("No response body");
1381+
}
1382+
} else {
1383+
System.err.println(
1384+
"Download failed. HTTP Status Code: " + response.getStatusLine().getStatusCode());
1385+
}
1386+
} catch (IOException e) {
1387+
throw new RuntimeException("Error downloading snapshot " + e.getMessage());
1388+
}
1389+
}
1390+
13141391
@Override
13151392
public void close() throws InterruptedException {
13161393
this.channel.shutdown().awaitTermination(10, TimeUnit.SECONDS);
13171394
}
1318-
1319-
// TODO: Download snapshots REST
13201395
}

src/test/java/io/qdrant/client/QdrantClientSnapshotsTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import io.qdrant.client.grpc.Collections.Distance;
77
import io.qdrant.client.grpc.SnapshotsService;
8+
import java.nio.file.FileSystems;
9+
import java.nio.file.Path;
810
import java.util.UUID;
911
import org.junit.jupiter.api.BeforeAll;
1012
import org.junit.jupiter.api.Test;
@@ -71,4 +73,34 @@ void testFullSnapshots() {
7173
assertEquals(qdrantClient.listFullSnapshots().getSnapshotDescriptionsList().size(), 0);
7274
});
7375
}
76+
77+
@Test
78+
void testDownloadSnapshot() {
79+
String collectionName = UUID.randomUUID().toString();
80+
81+
qdrantClient.createCollection(collectionName, 768, Distance.Cosine);
82+
83+
assertEquals(
84+
qdrantClient.listSnapshots(collectionName).getSnapshotDescriptionsList().size(), 0);
85+
86+
// Test with snapshot name
87+
assertDoesNotThrow(
88+
() -> {
89+
SnapshotsService.CreateSnapshotResponse response =
90+
qdrantClient.createSnapshot(collectionName);
91+
String snapshotName = response.getSnapshotDescription().getName();
92+
93+
Path path = FileSystems.getDefault().getPath("./test.snapshot");
94+
qdrantClient.downloadSnapshot(path, collectionName, snapshotName, null);
95+
});
96+
97+
// Test without snapshot name
98+
assertDoesNotThrow(
99+
() -> {
100+
qdrantClient.createSnapshot(collectionName);
101+
102+
Path path = FileSystems.getDefault().getPath("./test_2.snapshot");
103+
qdrantClient.downloadSnapshot(path, collectionName, null, null);
104+
});
105+
}
74106
}

0 commit comments

Comments
 (0)