Skip to content

Commit 1499364

Browse files
authored
feat: add object detection (android) (#52)
## Description <!-- Provide a concise and descriptive summary of the changes implemented in this PR. --> ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [ ] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 962b3df commit 1499364

File tree

13 files changed

+610
-13
lines changed

13 files changed

+610
-13
lines changed

android/build.gradle

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
buildscript {
2+
ext {
3+
agp_version = '8.4.2'
4+
}
5+
26
// Buildscript is evaluated before everything else so we can't use getExtOrDefault
37
def kotlin_version = rootProject.ext.has("kotlinVersion") ? rootProject.ext.get("kotlinVersion") : project.properties["RnExecutorch_kotlinVersion"]
48

@@ -9,7 +13,7 @@ buildscript {
913
}
1014

1115
dependencies {
12-
classpath "com.android.tools.build:gradle:7.2.1"
16+
classpath "com.android.tools.build:gradle:$agp_version"
1317
// noinspection DifferentKotlinGradleVersion
1418
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
1519
}
@@ -95,7 +99,8 @@ dependencies {
9599
// For < 0.71, this will be from the local maven repo
96100
// For > 0.71, this will be replaced by `com.facebook.react:react-android:$version` by react gradle plugin
97101
//noinspection GradleDynamicVersion
98-
implementation "com.facebook.react:react-native:+"
102+
implementation "com.facebook.react:react-android:+"
103+
implementation 'org.opencv:opencv:4.10.0'
99104
implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
100105
implementation 'com.github.software-mansion:react-native-executorch:main-SNAPSHOT'
101106
implementation 'org.opencv:opencv:4.10.0'
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package com.swmansion.rnexecutorch
2+
3+
import android.util.Log
4+
import com.facebook.react.bridge.Arguments
5+
import com.facebook.react.bridge.Promise
6+
import com.facebook.react.bridge.ReactApplicationContext
7+
import com.facebook.react.bridge.WritableArray
8+
import com.swmansion.rnexecutorch.models.BaseModel
9+
import com.swmansion.rnexecutorch.utils.ETError
10+
import com.swmansion.rnexecutorch.utils.ImageProcessor
11+
import org.opencv.android.OpenCVLoader
12+
import com.swmansion.rnexecutorch.models.objectdetection.SSDLiteLargeModel
13+
import org.opencv.core.Mat
14+
15+
/**
 * React Native module that exposes object detection backed by [SSDLiteLargeModel].
 *
 * OpenCV is initialised when the module is constructed; the model is created and loaded only
 * when [loadModule] is called, so [forward] must not be used before a successful [loadModule].
 */
class ObjectDetection(reactContext: ReactApplicationContext) :
  NativeObjectDetectionSpec(reactContext) {

  private lateinit var ssdLiteLarge: SSDLiteLargeModel

  init {
    // Only logs the outcome; a failed OpenCV load is not treated as fatal here.
    if (!OpenCVLoader.initLocal()) {
      Log.d("rn_executorch", "OpenCV not loaded")
    } else {
      Log.d("rn_executorch", "OpenCV loaded")
    }
  }

  /**
   * Creates the SSDLite model and loads it from [modelSource].
   *
   * Resolves [promise] with `0` on success. On failure the promise is rejected with the
   * exception text as the first argument and [ETError.InvalidModelPath] as the second
   * (argument order kept exactly as before so the JS side sees the same values).
   */
  override fun loadModule(modelSource: String, promise: Promise) {
    try {
      ssdLiteLarge = SSDLiteLargeModel(reactApplicationContext)
      ssdLiteLarge.loadModel(modelSource)
      promise.resolve(0)
    } catch (e: Exception) {
      // e.message may be null (e.g. a bare NullPointerException); `!!` here would throw
      // from inside the catch block and leave the promise unsettled.
      promise.reject(e.message ?: e.toString(), ETError.InvalidModelPath.toString())
    }
  }

  /**
   * Reads the image referenced by [input], runs detection on it and resolves [promise] with an
   * array containing one map per detection (see `Detection.toWritableMap`).
   *
   * Any exception, including calling this before [loadModule] initialised the model, rejects
   * the promise with the exception text.
   */
  override fun forward(input: String, promise: Promise) {
    try {
      val inputImage = ImageProcessor.readImage(input)
      val output = ssdLiteLarge.runModel(inputImage)
      val outputWritableArray: WritableArray = Arguments.createArray()
      output.forEach { detection ->
        outputWritableArray.pushMap(detection.toWritableMap())
      }
      promise.resolve(outputWritableArray)
    } catch (e: Exception) {
      // Null-safe for the same reason as in loadModule.
      val reason = e.message ?: e.toString()
      promise.reject(reason, reason)
    }
  }

  override fun getName(): String {
    return NAME
  }

  companion object {
    const val NAME = "ObjectDetection"
  }
}

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ class RnExecutorchPackage : TurboReactPackage() {
2121
StyleTransfer(reactContext)
2222
} else if (name == Classification.NAME) {
2323
Classification(reactContext)
24-
}
24+
} else if (name == ObjectDetection.NAME) {
25+
ObjectDetection(reactContext)
26+
}
2527
else {
2628
null
2729
}
@@ -63,6 +65,15 @@ class RnExecutorchPackage : TurboReactPackage() {
6365
false, // isCxxModule
6466
true
6567
)
68+
69+
moduleInfos[ObjectDetection.NAME] = ReactModuleInfo(
70+
ObjectDetection.NAME,
71+
ObjectDetection.NAME,
72+
false, // canOverrideExistingModule
73+
false, // needsEagerInit
74+
false, // isCxxModule
75+
true
76+
)
6677
moduleInfos
6778
}
6879
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package com.swmansion.rnexecutorch.models.objectdetection
2+
3+
import com.facebook.react.bridge.ReactApplicationContext
4+
import com.swmansion.rnexecutorch.utils.ImageProcessor
5+
import org.opencv.core.Mat
6+
import org.opencv.core.Size
7+
import org.opencv.imgproc.Imgproc
8+
import com.swmansion.rnexecutorch.models.BaseModel
9+
import com.swmansion.rnexecutorch.utils.Bbox
10+
import com.swmansion.rnexecutorch.utils.CocoLabel
11+
import com.swmansion.rnexecutorch.utils.Detection
12+
import com.swmansion.rnexecutorch.utils.nms
13+
import org.pytorch.executorch.EValue
14+
15+
// Detections scoring below this value are discarded before NMS.
const val detectionScoreThreshold = .7f

// IoU above which a lower-scored detection of the same label is suppressed by NMS.
const val iouThreshold = .55f

/**
 * SSDLite object detection model: takes an OpenCV [Mat] image and returns the detections that
 * survive score thresholding and non-maximum suppression.
 *
 * The width/height ratios between the original image and the model input are recorded during
 * [preprocess] and applied in [postprocess], so the two must be called as a pair for the same
 * image (as [runModel] does).
 */
class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) :
  BaseModel<Mat, Array<Detection>>(reactApplicationContext) {
  private var heightRatio: Float = 1.0f
  private var widthRatio: Float = 1.0f

  /**
   * Returns the model input size, taking width from the last and height from the second-to-last
   * dimension of the first input's shape.
   */
  private fun getModelImageSize(): Size {
    val inputShape = module.getInputShape(0)
    val width = inputShape[inputShape.lastIndex]
    val height = inputShape[inputShape.lastIndex - 1]

    // OpenCV's Size constructor takes (width, height); the arguments were previously swapped,
    // which only went unnoticed for square model inputs.
    return Size(width.toDouble(), height.toDouble())
  }

  /**
   * Records the scale factors between [input] and the model input size, resizes [input]
   * in place to the model size and converts it to an [EValue].
   */
  override fun preprocess(input: Mat): EValue {
    val modelSize = getModelImageSize()
    val originalSize = input.size()
    this.widthRatio = (originalSize.width / modelSize.width).toFloat()
    this.heightRatio = (originalSize.height / modelSize.height).toFloat()
    // Note: the caller's Mat is overwritten with the resized image.
    Imgproc.resize(input, input, modelSize)
    return ImageProcessor.matToEValue(input, module.getInputShape(0))
  }

  /** Runs preprocessing, the forward pass and postprocessing for a single image. */
  override fun runModel(input: Mat): Array<Detection> {
    val modelInput = preprocess(input)
    val modelOutput = forward(modelInput)
    return postprocess(modelOutput)
  }

  /**
   * Converts raw model output into detections.
   *
   * Reads boxes from `output[0]` (four floats per detection, used as x1, y1, x2, y2), scores
   * from `output[1]` and label ids from `output[2]`. Boxes are scaled back to the original
   * image using the ratios stored by [preprocess]. Detections below [detectionScoreThreshold]
   * or with a label id that has no [CocoLabel] entry are dropped, then NMS is applied.
   */
  override fun postprocess(output: Array<EValue>): Array<Detection> {
    val scoresTensor = output[1].toTensor()
    val numel = scoresTensor.numel()
    val bboxes = output[0].toTensor().dataAsFloatArray
    val scores = scoresTensor.dataAsFloatArray
    val labels = output[2].toTensor().dataAsFloatArray

    val detections: MutableList<Detection> = mutableListOf()
    for (idx in 0 until numel.toInt()) {
      val score = scores[idx]
      if (score < detectionScoreThreshold) {
        continue
      }
      // An id without an enum entry cannot be reported by name; skip it instead of crashing
      // the whole inference with a NullPointerException.
      val label = CocoLabel.fromId(labels[idx].toInt()) ?: continue
      val bbox = Bbox(
        bboxes[idx * 4 + 0] * this.widthRatio,
        bboxes[idx * 4 + 1] * this.heightRatio,
        bboxes[idx * 4 + 2] * this.widthRatio,
        bboxes[idx * 4 + 3] * this.heightRatio
      )
      detections.add(Detection(bbox, score, label))
    }

    val detectionsPostNms = nms(detections, iouThreshold)
    return detectionsPostNms.toTypedArray()
  }
}
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package com.swmansion.rnexecutorch.utils
2+
3+
import com.facebook.react.bridge.Arguments
4+
import com.facebook.react.bridge.WritableMap
5+
6+
/**
 * Per-label greedy non-maximum suppression.
 *
 * Detections are ordered by label and, within a label, by descending score. A detection is kept
 * unless an already kept detection of the same label overlaps it with an IoU strictly greater
 * than [iouThreshold].
 *
 * @param detections candidates to filter; the list itself is not modified
 * @param iouThreshold overlap above which the lower-scored detection is suppressed
 * @return the kept detections, grouped by label and ordered by descending score within a label
 */
fun nms(
  detections: MutableList<Detection>,
  iouThreshold: Float
): List<Detection> {
  if (detections.isEmpty()) {
    return emptyList()
  }

  val ordered = detections.sortedWith(compareBy({ it.label }, { -it.score }))

  val survivors = mutableListOf<Detection>()
  // Detections already kept for the label currently being processed.
  val keptForLabel = mutableListOf<Detection>()
  var activeLabel: CocoLabel? = null

  for (candidate in ordered) {
    // Labels are contiguous after sorting, so a label change starts a fresh group.
    if (candidate.label != activeLabel) {
      activeLabel = candidate.label
      keptForLabel.clear()
    }

    val suppressed = keptForLabel.any { kept ->
      calculateIoU(kept.bbox, candidate.bbox) > iouThreshold
    }
    if (!suppressed) {
      keptForLabel.add(candidate)
      survivors.add(candidate)
    }
  }

  return survivors
}
53+
54+
/**
 * Computes intersection-over-union of two boxes given by their (x1, y1) and (x2, y2) corners.
 *
 * @return intersection area divided by union area, or `0f` when the union area is exactly zero
 */
fun calculateIoU(bbox1: Bbox, bbox2: Bbox): Float {
  // Corners of the overlapping rectangle (may be inverted when the boxes do not overlap).
  val overlapX1 = maxOf(bbox1.x1, bbox2.x1)
  val overlapY1 = maxOf(bbox1.y1, bbox2.y1)
  val overlapX2 = minOf(bbox1.x2, bbox2.x2)
  val overlapY2 = minOf(bbox1.y2, bbox2.y2)

  // Clamp each side at zero so disjoint boxes yield an empty intersection.
  val overlapWidth = maxOf(0f, overlapX2 - overlapX1)
  val overlapHeight = maxOf(0f, overlapY2 - overlapY1)
  val overlapArea = overlapWidth * overlapHeight

  val firstArea = (bbox1.x2 - bbox1.x1) * (bbox1.y2 - bbox1.y1)
  val secondArea = (bbox2.x2 - bbox2.x1) * (bbox2.y2 - bbox2.y1)
  val combinedArea = firstArea + secondArea - overlapArea

  if (combinedArea == 0f) {
    return 0f
  }
  return overlapArea / combinedArea
}
67+
68+
69+
/** Axis-aligned box described by two corners, (x1, y1) and (x2, y2). */
data class Bbox(
  val x1: Float,
  val y1: Float,
  val x2: Float,
  val y2: Float
) {
  /** Builds a React Native map with the four coordinates stored as doubles. */
  fun toWritableMap(): WritableMap =
    Arguments.createMap().apply {
      putDouble("x1", x1.toDouble())
      putDouble("x2", x2.toDouble())
      putDouble("y1", y1.toDouble())
      putDouble("y2", y2.toDouble())
    }
}
84+
85+
86+
/** A single detected object: its box, confidence score and class label. */
data class Detection(
  val bbox: Bbox,
  val score: Float,
  val label: CocoLabel,
) {
  /**
   * Builds a React Native map with the nested `bbox` map, the `score` as a double and the
   * `label` as the enum constant's name.
   */
  fun toWritableMap(): WritableMap =
    Arguments.createMap().apply {
      putMap("bbox", bbox.toWritableMap())
      putDouble("score", score.toDouble())
      putString("label", label.name)
    }
}
99+
100+
/**
 * Object classes keyed by the numeric id the detection model emits.
 *
 * Ids run from 1 to 91 without gaps; there is no entry for id 0, so [fromId] returns `null`
 * for it and for any other id outside that range.
 * NOTE(review): names appear to be shortened forms of the COCO 91-category list
 * (e.g. SPORTS = "sports ball", BASEBALL = "baseball bat") — confirm against the model's labels.
 */
enum class CocoLabel(val id: Int) {
  PERSON(1),
  BICYCLE(2),
  CAR(3),
  MOTORCYCLE(4),
  AIRPLANE(5),
  BUS(6),
  TRAIN(7),
  TRUCK(8),
  BOAT(9),
  TRAFFIC_LIGHT(10),
  FIRE_HYDRANT(11),
  STREET_SIGN(12),
  STOP_SIGN(13),
  PARKING(14),
  BENCH(15),
  BIRD(16),
  CAT(17),
  DOG(18),
  HORSE(19),
  SHEEP(20),
  COW(21),
  ELEPHANT(22),
  BEAR(23),
  ZEBRA(24),
  GIRAFFE(25),
  HAT(26),
  BACKPACK(27),
  UMBRELLA(28),
  SHOE(29),
  EYE(30),
  HANDBAG(31),
  TIE(32),
  SUITCASE(33),
  FRISBEE(34),
  SKIS(35),
  SNOWBOARD(36),
  SPORTS(37),
  KITE(38),
  BASEBALL(39),
  // Id 40 was missing, so a detection with this class could not be mapped to a label.
  BASEBALL_GLOVE(40),
  SKATEBOARD(41),
  SURFBOARD(42),
  TENNIS_RACKET(43),
  BOTTLE(44),
  PLATE(45),
  WINE_GLASS(46),
  CUP(47),
  FORK(48),
  KNIFE(49),
  SPOON(50),
  BOWL(51),
  BANANA(52),
  APPLE(53),
  SANDWICH(54),
  ORANGE(55),
  BROCCOLI(56),
  CARROT(57),
  HOT_DOG(58),
  PIZZA(59),
  DONUT(60),
  CAKE(61),
  CHAIR(62),
  COUCH(63),
  POTTED_PLANT(64),
  BED(65),
  MIRROR(66),
  DINING_TABLE(67),
  WINDOW(68),
  DESK(69),
  TOILET(70),
  DOOR(71),
  TV(72),
  LAPTOP(73),
  MOUSE(74),
  REMOTE(75),
  KEYBOARD(76),
  CELL_PHONE(77),
  MICROWAVE(78),
  OVEN(79),
  TOASTER(80),
  SINK(81),
  REFRIGERATOR(82),
  BLENDER(83),
  BOOK(84),
  CLOCK(85),
  VASE(86),
  SCISSORS(87),
  TEDDY_BEAR(88),
  HAIR_DRIER(89),
  TOOTHBRUSH(90),
  HAIR_BRUSH(91);

  companion object {
    // Built once so lookups by id do not scan all constants.
    private val idToLabelMap = values().associateBy(CocoLabel::id)

    /** Returns the label with the given [id], or `null` when no constant has that id. */
    fun fromId(id: Int): CocoLabel? = idToLabelMap[id]
  }
}

0 commit comments

Comments
 (0)