Skip to content

Commit e9a5999

Browse files
lu-wang-gtflite-support-robot
authored andcommitted
Cherry pick the Swift test from #709
PiperOrigin-RevId: 411642699
1 parent 03c2c84 commit e9a5999

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed

tensorflow_lite_support/ios/test/task/vision/image_classifier/BUILD

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
12
load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION")
23
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
34
load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner")
@@ -7,6 +8,30 @@ package(
78
licenses = ["notice"], # Apache 2.0
89
)
910

11+
swift_library(
12+
name = "TFLImageClassifierSwiftTestLibrary",
13+
testonly = 1,
14+
srcs = ["TFLImageClassifierTests.swift"],
15+
data = [
16+
"//tensorflow_lite_support/cc/test/testdata/task/vision:test_images",
17+
"//tensorflow_lite_support/cc/test/testdata/task/vision:test_models",
18+
],
19+
tags = TFL_DEFAULT_TAGS,
20+
deps = [
21+
"//tensorflow_lite_support/ios/task/vision:TFLImageClassifier",
22+
],
23+
)
24+
25+
ios_unit_test(
26+
name = "TFLImageClassifierSwiftTest",
27+
minimum_os_version = TFL_MINIMUM_OS_VERSION,
28+
runner = tflite_ios_lab_runner("IOS_LATEST"),
29+
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
30+
deps = [
31+
":TFLImageClassifierSwiftTestLibrary",
32+
],
33+
)
34+
1035
objc_library(
1136
name = "TFLImageClassifierObjcTestLibrary",
1237
testonly = 1,
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
import XCTest
16+
17+
@testable import TFLImageClassifier
18+
19+
class TFLImageClassifierTests: XCTestCase {
20+
21+
static let bundle = Bundle(for: TFLImageClassifierTests.self)
22+
static let modelPath = bundle.path(
23+
forResource: "mobilenet_v2_1.0_224",
24+
ofType: "tflite")!
25+
26+
func testSuccessfullInferenceOnMLImageWithUIImage() throws {
27+
28+
let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath)
29+
30+
let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath)
31+
XCTAssertNotNil(imageClassifierOptions)
32+
33+
let imageClassifier =
34+
try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!)
35+
36+
let gmlImage = try gmlImage(withName: "burger", ofType: "jpg")
37+
let classificationResults: TFLClassificationResult =
38+
try imageClassifier.classify(gmlImage: gmlImage)
39+
40+
XCTAssertNotNil(classificationResults)
41+
XCTAssertEqual(classificationResults.classifications.count, 1)
42+
XCTAssertGreaterThan(classificationResults.classifications[0].categories.count, 0)
43+
// TODO: match the score as image_classifier_test.cc
44+
let category = classificationResults.classifications[0].categories[0]
45+
XCTAssertEqual(category.label, "cheeseburger")
46+
XCTAssertEqual(category.score, 0.748976, accuracy: 0.001)
47+
}
48+
49+
func testModelOptionsWithMaxResults() throws {
50+
51+
let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath)
52+
53+
let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath)
54+
XCTAssertNotNil(imageClassifierOptions)
55+
56+
let maxResults = 3
57+
imageClassifierOptions!.classificationOptions.maxResults = maxResults
58+
59+
let imageClassifier =
60+
try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!)
61+
62+
let gmlImage = try gmlImage(withName: "burger", ofType: "jpg")
63+
64+
let classificationResults: TFLClassificationResult = try imageClassifier.classify(
65+
gmlImage: gmlImage)
66+
67+
XCTAssertNotNil(classificationResults)
68+
XCTAssertEqual(classificationResults.classifications.count, 1)
69+
XCTAssertLessThanOrEqual(classificationResults.classifications[0].categories.count, maxResults)
70+
71+
// TODO: match the score as image_classifier_test.cc
72+
let category = classificationResults.classifications[0].categories[0]
73+
XCTAssertEqual(category.label, "cheeseburger")
74+
XCTAssertEqual(category.score, 0.748976, accuracy: 0.001)
75+
}
76+
77+
func testInferenceWithBoundingBox() throws {
78+
79+
let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath)
80+
81+
let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath)
82+
XCTAssertNotNil(imageClassifierOptions)
83+
84+
let imageClassifier =
85+
try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!)
86+
87+
let gmlImage = try gmlImage(withName: "multi_objects", ofType: "jpg")
88+
89+
let roi = CGRect(x: 406, y: 110, width: 148, height: 153)
90+
let classificationResults =
91+
try imageClassifier.classify(gmlImage: gmlImage, regionOfInterest: roi)
92+
93+
XCTAssertNotNil(classificationResults)
94+
XCTAssertEqual(classificationResults.classifications.count, 1)
95+
XCTAssertGreaterThan(classificationResults.classifications[0].categories.count, 0)
96+
97+
// TODO: match the label and score as image_classifier_test.cc
98+
// let category = classificationResults.classifications[0].categories[0]
99+
// XCTAssertEqual(category.label, "soccer ball")
100+
// XCTAssertEqual(category.score, 0.256512, accuracy:0.001);
101+
}
102+
103+
func testInferenceWithRGBAImage() throws {
104+
105+
let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath)
106+
107+
let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath)
108+
XCTAssertNotNil(imageClassifierOptions)
109+
110+
let imageClassifier =
111+
try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!)
112+
113+
let gmlImage = try gmlImage(withName: "sparrow", ofType: "png")
114+
115+
let classificationResults =
116+
try imageClassifier.classify(gmlImage: gmlImage)
117+
118+
XCTAssertNotNil(classificationResults)
119+
XCTAssertEqual(classificationResults.classifications.count, 1)
120+
XCTAssertGreaterThan(classificationResults.classifications[0].categories.count, 0)
121+
122+
let category = classificationResults.classifications[0].categories[0]
123+
XCTAssertEqual(category.label, "junco")
124+
XCTAssertEqual(category.score, 0.253016, accuracy: 0.001)
125+
}
126+
127+
private func gmlImage(withName name: String, ofType type: String) throws -> MLImage {
128+
let imagePath =
129+
try XCTUnwrap(TFLImageClassifierTests.bundle.path(forResource: name, ofType: type))
130+
let image = UIImage(contentsOfFile: imagePath)
131+
let imageForInference = try XCTUnwrap(image)
132+
let gmlImage = try XCTUnwrap(MLImage(image: imageForInference))
133+
134+
return gmlImage
135+
}
136+
}

0 commit comments

Comments
 (0)