|
12 | 12 | */ |
13 | 13 | package com.ibm.watson.developer_cloud.visual_recognition.v3; |
14 | 14 |
|
| 15 | +import com.google.common.io.Files; |
15 | 16 | import com.google.gson.Gson; |
16 | 17 | import com.google.gson.JsonObject; |
17 | 18 | import com.ibm.watson.developer_cloud.WatsonServiceUnitTest; |
| 19 | +import com.ibm.watson.developer_cloud.http.HttpMediaType; |
18 | 20 | import com.ibm.watson.developer_cloud.visual_recognition.v3.model.ClassifiedImages; |
19 | 21 | import com.ibm.watson.developer_cloud.visual_recognition.v3.model.Classifier; |
20 | 22 | import com.ibm.watson.developer_cloud.visual_recognition.v3.model.ClassifyOptions; |
|
23 | 25 | import com.ibm.watson.developer_cloud.visual_recognition.v3.model.DetectFacesOptions; |
24 | 26 | import com.ibm.watson.developer_cloud.visual_recognition.v3.model.DetectedFaces; |
25 | 27 | import com.ibm.watson.developer_cloud.visual_recognition.v3.model.GetClassifierOptions; |
| 28 | +import com.ibm.watson.developer_cloud.visual_recognition.v3.model.GetCoreMlModelOptions; |
26 | 29 | import com.ibm.watson.developer_cloud.visual_recognition.v3.model.ListClassifiersOptions; |
27 | 30 | import com.ibm.watson.developer_cloud.visual_recognition.v3.model.UpdateClassifierOptions; |
28 | 31 | import okhttp3.mockwebserver.MockResponse; |
29 | 32 | import okhttp3.mockwebserver.RecordedRequest; |
| 33 | +import okio.Buffer; |
30 | 34 | import org.junit.Before; |
31 | 35 | import org.junit.Test; |
32 | 36 |
|
@@ -58,6 +62,7 @@ public class VisualRecognitionTest extends WatsonServiceUnitTest { |
58 | 62 | private static final String PATH_CLASSIFIERS = "/v3/classifiers"; |
59 | 63 | private static final String PATH_CLASSIFIER = "/v3/classifiers/%s"; |
60 | 64 | private static final String PATH_DETECT_FACES = "/v3/detect_faces"; |
| 65 | + private static final String PATH_CORE_ML = "/v3/classifiers/%s/core_ml_model"; |
61 | 66 |
|
62 | 67 | private VisualRecognition service; |
63 | 68 |
|
@@ -328,4 +333,26 @@ public void testGetClassifiers() throws InterruptedException, IOException { |
328 | 333 | assertEquals("GET", request.getMethod()); |
329 | 334 | assertEquals(serviceResponse, classifiers); |
330 | 335 | } |
| 336 | + |
| 337 | + @Test |
| 338 | + public void testGetCoreMlModel() throws IOException, InterruptedException { |
| 339 | + final File model = new File("src/test/resources/visual_recognition/custom_model.mlmodel"); |
| 340 | + final Buffer buffer = new Buffer().write(Files.toByteArray(model)); |
| 341 | + |
| 342 | + server.enqueue(new MockResponse().addHeader(CONTENT_TYPE, HttpMediaType.APPLICATION_OCTET_STREAM).setBody(buffer)); |
| 343 | + |
| 344 | + String classifierId = "classifier_id"; |
| 345 | + GetCoreMlModelOptions options = new GetCoreMlModelOptions.Builder() |
| 346 | + .classifierId(classifierId) |
| 347 | + .build(); |
| 348 | + |
| 349 | + InputStream modelFile = service.getCoreMlModel(options).execute(); |
| 350 | + |
| 351 | + RecordedRequest request = server.takeRequest(); |
| 352 | + String path = String.format(PATH_CORE_ML, classifierId) + "?" + VERSION_DATE + "=2016-05-20&api_key=" + API_KEY; |
| 353 | + |
| 354 | + assertEquals(path, request.getPath()); |
| 355 | + assertEquals("GET", request.getMethod()); |
| 356 | + writeInputStreamToFile(modelFile, new File("build/model_result.mlmodel")); |
| 357 | + } |
331 | 358 | } |
0 commit comments