Skip to content

Commit 15b9666

Browse files
Merge pull request #311 from khanhlvg:sound_classification_task_lib
PiperOrigin-RevId: 373541129
2 parents 4bf9928 + 001cb02 commit 15b9666

File tree

10 files changed

+127
-482
lines changed

10 files changed

+127
-482
lines changed

lite/examples/sound_classification/android/README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Sound Classifier Android sample.
22

3+
This Android application demonstrates how to classify sound on-device. It uses:
4+
* [TFLite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview)
5+
* [YAMNet](https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1), an audio event classification model.
36

47
## Requirements
58

@@ -36,4 +39,4 @@ Re-installing the app may require you to uninstall the previous installations.
3639
## Resources used:
3740

3841
* [TensorFlow Lite](https://www.tensorflow.org/lite)
39-
* [Teachable Machine Audio Project](https://teachablemachine.withgoogle.com/train/audio)
42+
* [YAMNet audio classification model](https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1)

lite/examples/sound_classification/android/app/build.gradle

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,16 @@ apply from: 'download_model.gradle'
4848
dependencies {
4949
implementation fileTree(dir: "libs", include: ["*.jar"])
5050
implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
51-
implementation "androidx.core:core-ktx:1.3.1"
51+
implementation "androidx.core:core-ktx:1.3.2"
5252
implementation "androidx.appcompat:appcompat:1.2.0"
53-
implementation "androidx.lifecycle:lifecycle-common-java8:2.2.0"
54-
implementation "androidx.constraintlayout:constraintlayout:2.0.1"
55-
implementation "androidx.recyclerview:recyclerview:1.1.0"
56-
implementation "com.google.android.material:material:1.2.1"
53+
implementation "androidx.lifecycle:lifecycle-common-java8:2.3.1"
54+
implementation "androidx.constraintlayout:constraintlayout:2.0.4"
55+
implementation "androidx.recyclerview:recyclerview:1.2.0"
56+
implementation "com.google.android.material:material:1.3.0"
5757

58-
implementation "org.tensorflow:tensorflow-lite:2.3.0"
59-
implementation "org.tensorflow:tensorflow-lite-select-tf-ops:2.3.0"
60-
implementation "org.tensorflow:tensorflow-lite-support:0.1.0"
61-
implementation "org.tensorflow:tensorflow-lite-metadata:0.1.0"
58+
implementation 'org.tensorflow:tensorflow-lite-task-audio:0.2.0-rc2'
6259

63-
testImplementation "junit:junit:4.13"
60+
testImplementation "junit:junit:4.13.2"
6461
androidTestImplementation "androidx.test.ext:junit:1.1.2"
6562
androidTestImplementation "androidx.test.espresso:espresso-core:3.3.0"
6663
}

lite/examples/sound_classification/android/app/download_model.gradle

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
task downloadSoundClassificationModelFile(type: Download) {
2-
src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/sound_classification/snap_clap.tflite'
3-
dest project.ext.ASSET_DIR + '/sound_classifier.tflite'
2+
src 'https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite'
3+
dest project.ext.ASSET_DIR + '/yamnet.tflite'
44
overwrite false
55
}
66

lite/examples/sound_classification/android/app/src/main/assets/labels.txt

Lines changed: 0 additions & 3 deletions
This file was deleted.

lite/examples/sound_classification/android/app/src/main/java/org/tensorflow/lite/examples/soundclassifier/MainActivity.kt

Lines changed: 100 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,84 +18,150 @@ package org.tensorflow.lite.examples.soundclassifier
1818

1919
import android.Manifest
2020
import android.content.pm.PackageManager
21+
import android.media.AudioRecord
2122
import android.os.Build
2223
import android.os.Bundle
24+
import android.os.Handler
25+
import android.os.HandlerThread
2326
import android.util.Log
2427
import android.view.WindowManager
2528
import androidx.annotation.RequiresApi
2629
import androidx.appcompat.app.AppCompatActivity
2730
import androidx.core.content.ContextCompat
31+
import androidx.core.os.HandlerCompat
2832
import org.tensorflow.lite.examples.soundclassifier.databinding.ActivityMainBinding
33+
import org.tensorflow.lite.task.audio.classifier.AudioClassifier
34+
2935

3036
class MainActivity : AppCompatActivity() {
3137
private val probabilitiesAdapter by lazy { ProbabilitiesAdapter() }
3238

33-
private lateinit var soundClassifier: SoundClassifier
39+
private var audioClassifier: AudioClassifier? = null
40+
private var audioRecord: AudioRecord? = null
41+
private var classificationInterval = 500L // how often should classification run in milli-secs
42+
private lateinit var handler: Handler // background thread handler to run classification
3443

3544
override fun onCreate(savedInstanceState: Bundle?) {
3645
super.onCreate(savedInstanceState)
3746

3847
val binding = ActivityMainBinding.inflate(layoutInflater)
3948
setContentView(binding.root)
4049

41-
soundClassifier = SoundClassifier(this, SoundClassifier.Options()).also {
42-
it.lifecycleOwner = this
43-
}
44-
4550
with(binding) {
4651
recyclerView.apply {
47-
setHasFixedSize(true)
48-
adapter = probabilitiesAdapter.apply {
49-
labelList = soundClassifier.labelList
50-
}
52+
setHasFixedSize(false)
53+
adapter = probabilitiesAdapter
5154
}
5255

56+
// Input switch to turn on/off classification
5357
keepScreenOn(inputSwitch.isChecked)
5458
inputSwitch.setOnCheckedChangeListener { _, isChecked ->
55-
soundClassifier.isPaused = !isChecked
59+
if (isChecked) startAudioClassification() else stopAudioClassification()
5660
keepScreenOn(isChecked)
5761
}
5862

59-
overlapFactorSlider.value = soundClassifier.overlapFactor
60-
overlapFactorSlider.addOnChangeListener { _, value, _ ->
61-
soundClassifier.overlapFactor = value
63+
// Slider which controls how often the classification task should run
64+
classificationIntervalSlider.value = classificationInterval.toFloat()
65+
classificationIntervalSlider.setLabelFormatter { value: Float ->
66+
"${value.toInt()} ms"
6267
}
63-
}
64-
65-
soundClassifier.probabilities.observe(this) { resultMap ->
66-
if (resultMap.isEmpty() || resultMap.size > soundClassifier.labelList.size) {
67-
Log.w(TAG, "Invalid size of probability output! (size: ${resultMap.size})")
68-
return@observe
68+
classificationIntervalSlider.addOnChangeListener { _, value, _ ->
69+
classificationInterval = value.toLong()
70+
stopAudioClassification()
71+
startAudioClassification()
6972
}
70-
probabilitiesAdapter.probabilityMap = resultMap
71-
probabilitiesAdapter.notifyDataSetChanged()
7273
}
7374

75+
// Create a handler to run classification in a background thread
76+
val handlerThread = HandlerThread("backgroundThread")
77+
handlerThread.start()
78+
handler = HandlerCompat.createAsync(handlerThread.looper)
79+
80+
// Request microphone permission and start running classification
7481
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
7582
requestMicrophonePermission()
7683
} else {
77-
soundClassifier.start()
84+
startAudioClassification()
7885
}
86+
87+
}
88+
89+
private fun startAudioClassification() {
90+
// If the audio classifier is initialized and running, do nothing.
91+
if (audioClassifier != null) return;
92+
93+
// Initialize the audio classifier
94+
val classifier = AudioClassifier.createFromFile(this, MODEL_FILE)
95+
val audioTensor = classifier.createInputTensorAudio()
96+
97+
// Initialize the audio recorder
98+
val record = classifier.createAudioRecord()
99+
record.startRecording()
100+
101+
// Define the classification runnable
102+
val run = object : Runnable {
103+
override fun run() {
104+
val startTime = System.currentTimeMillis()
105+
106+
// Load the latest audio sample
107+
audioTensor.load(record)
108+
val output = classifier.classify(audioTensor)
109+
110+
// Keep only results above a certain threshold, and sort them in descending order of score
111+
val filteredModelOutput = output[0].categories.filter {
112+
it.score > MINIMUM_DISPLAY_THRESHOLD
113+
}.sortedBy {
114+
-it.score
115+
}
116+
117+
val finishTime = System.currentTimeMillis()
118+
119+
Log.d(TAG, "Latency = ${finishTime - startTime}ms")
120+
121+
// Updating the UI
122+
runOnUiThread {
123+
probabilitiesAdapter.categoryList = filteredModelOutput
124+
probabilitiesAdapter.notifyDataSetChanged()
125+
}
126+
127+
// Rerun the classification after a certain interval
128+
handler.postDelayed(this, classificationInterval)
129+
}
130+
}
131+
132+
// Start the classification process
133+
handler.post(run)
134+
135+
// Save the instances we just created for use later
136+
audioClassifier = classifier
137+
audioRecord = record
138+
}
139+
140+
private fun stopAudioClassification() {
141+
handler.removeCallbacksAndMessages(null)
142+
audioRecord?.stop()
143+
audioRecord = null
144+
audioClassifier = null
79145
}
80146

81147
override fun onTopResumedActivityChanged(isTopResumedActivity: Boolean) {
82148
// Handles "top" resumed event on multi-window environment
83149
if (isTopResumedActivity) {
84-
soundClassifier.start()
150+
startAudioClassification()
85151
} else {
86-
soundClassifier.stop()
152+
stopAudioClassification()
87153
}
88154
}
89155

90156
override fun onRequestPermissionsResult(
91-
requestCode: Int,
92-
permissions: Array<out String>,
93-
grantResults: IntArray
157+
requestCode: Int,
158+
permissions: Array<out String>,
159+
grantResults: IntArray
94160
) {
95161
if (requestCode == REQUEST_RECORD_AUDIO) {
96162
if (grantResults.isNotEmpty() && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
97163
Log.i(TAG, "Audio permission granted :)")
98-
soundClassifier.start()
164+
startAudioClassification()
99165
} else {
100166
Log.e(TAG, "Audio permission not granted :(")
101167
}
@@ -105,11 +171,11 @@ class MainActivity : AppCompatActivity() {
105171
@RequiresApi(Build.VERSION_CODES.M)
106172
private fun requestMicrophonePermission() {
107173
if (ContextCompat.checkSelfPermission(
108-
this,
109-
Manifest.permission.RECORD_AUDIO
110-
) == PackageManager.PERMISSION_GRANTED
174+
this,
175+
Manifest.permission.RECORD_AUDIO
176+
) == PackageManager.PERMISSION_GRANTED
111177
) {
112-
soundClassifier.start()
178+
startAudioClassification()
113179
} else {
114180
requestPermissions(arrayOf(Manifest.permission.RECORD_AUDIO), REQUEST_RECORD_AUDIO)
115181
}
@@ -125,5 +191,7 @@ class MainActivity : AppCompatActivity() {
125191
companion object {
126192
const val REQUEST_RECORD_AUDIO = 1337
127193
private const val TAG = "AudioDemo"
194+
private const val MODEL_FILE = "yamnet.tflite"
195+
private const val MINIMUM_DISPLAY_THRESHOLD: Float = 0.3f
128196
}
129197
}

lite/examples/sound_classification/android/app/src/main/java/org/tensorflow/lite/examples/soundclassifier/ProbabilitiesAdapter.kt

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ import android.view.ViewGroup
2323
import android.view.animation.AccelerateDecelerateInterpolator
2424
import androidx.recyclerview.widget.RecyclerView
2525
import org.tensorflow.lite.examples.soundclassifier.databinding.ItemProbabilityBinding
26+
import org.tensorflow.lite.support.label.Category
2627

2728
internal class ProbabilitiesAdapter : RecyclerView.Adapter<ProbabilitiesAdapter.ViewHolder>() {
28-
var labelList = emptyList<String>()
29-
var probabilityMap = mapOf<String, Float>()
29+
var categoryList: List<Category> = emptyList()
3030

3131
override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ViewHolder {
3232
val binding =
@@ -35,22 +35,21 @@ internal class ProbabilitiesAdapter : RecyclerView.Adapter<ProbabilitiesAdapter.
3535
}
3636

3737
override fun onBindViewHolder(holder: ViewHolder, position: Int) {
38-
val label = labelList[position]
39-
val probability = probabilityMap[label] ?: 0f
40-
holder.bind(position, label, probability)
38+
val category = categoryList[position]
39+
holder.bind(position, category.label, category.score)
4140
}
4241

43-
override fun getItemCount() = labelList.size
42+
override fun getItemCount() = categoryList.size
4443

4544
class ViewHolder(private val binding: ItemProbabilityBinding) :
4645
RecyclerView.ViewHolder(binding.root) {
47-
fun bind(position: Int, label: String, probability: Float) {
46+
fun bind(position: Int, label: String, score: Float) {
4847
with(binding) {
4948
labelTextView.text = label
5049
progressBar.progressBackgroundTintList = progressColorPairList[position % 3].first
5150
progressBar.progressTintList = progressColorPairList[position % 3].second
5251

53-
val newValue = (probability * 100).toInt()
52+
val newValue = (score * 100).toInt()
5453
// If you don't want to animate, you can write like `progressBar.progress = newValue`.
5554
val animation =
5655
ObjectAnimator.ofInt(progressBar, "progress", progressBar.progress, newValue)

0 commit comments

Comments
 (0)