Object detection on Android
Running object detection AI directly on an Android device lets users around the world use object detection without communicating with a server.
YOLOv8 is a popular object detection model.
Android is the mobile OS with the most users in the world.
This article describes how to run YOLOv8 object detection on an Android device.
This article is based on the code in the repository below.
↑ This code performs real-time detection using the camera of an Android device.
I also created a simple sample that contains only the detection part, without the camera functions. Feel free to refer to it.
Step 1: Convert from PyTorch format to TFLite format
YOLOv8 is distributed in PyTorch format. Convert it to TFLite so that it can be used on Android.
Installing YOLOv8
Install the Ultralytics framework, which includes YOLOv8.
pip install ultralytics
Converting to TFLite
Run the following conversion code.
The code below will download the weights of the pre-trained model.
If you have a weight checkpoint file for a model trained on your own custom data, replace yolov8s.pt with the path to that file.
from ultralytics import YOLO
model = YOLO('yolov8s.pt')
model.export(format="tflite")
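A yolov8s_saved_model directory will be generated. It contains yolov8s_float32.tflite and a float16 version, yolov8s_float16.tflite; this article uses yolov8s_float32.tflite.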
If a conversion error occurs
If the following error occurs, it is caused by the TensorFlow version, so install a compatible version.
ImportError: generic_type: cannot initialize type "StatusCode": an object with that name is already defined
For example, install the following version of TensorFlow.
pip install tensorflow==2.13.0
Step 2: Run the TFLite file on Android
From here on, we will run the YOLOv8 TFLite file in an Android Studio project.
Adding the TFLite file to the project
Create an assets directory in the app directory of the Android Studio project (File → New → Folder → Asset Folder) and add the TFLite file (yolov8s_float32.tflite) and labels.txt to it. You can add them by copying and pasting.
labels.txt is a text file that lists the class names of the YOLOv8 model, one per line, in class-index order.
If you trained the model with custom classes, list those classes instead.
For the default pre-trained YOLOv8 model (trained on COCO), the file looks like this:
labels.txt
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
Installing TensorFlow Lite
Add the following to the dependencies block in app/build.gradle.kts to install TensorFlow Lite.
app/build.gradle.kts
implementation("org.tensorflow:tensorflow-lite:2.14.0")
implementation("org.tensorflow:tensorflow-lite-support:0.4.4")
After adding the above, press Sync Now to install.
Importing the required modules
import android.content.Context
import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.graphics.RectF
import android.graphics.Typeface
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.CastOp
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.BufferedReader
import java.io.IOException
import java.io.InputStream
import java.io.InputStreamReader
Required class properties
private val modelPath = "yolov8s_float32.tflite"
private val labelPath = "labels.txt"
private var interpreter: Interpreter? = null
private var tensorWidth = 0
private var tensorHeight = 0
private var numChannel = 0
private var numElements = 0
private var labels = mutableListOf<String>()
private val imageProcessor = ImageProcessor.Builder()
    .add(NormalizeOp(INPUT_MEAN, INPUT_STANDARD_DEVIATION))
    .add(CastOp(INPUT_IMAGE_TYPE))
    .build() // preprocess input
companion object {
    private const val INPUT_MEAN = 0f
    private const val INPUT_STANDARD_DEVIATION = 255f
    private val INPUT_IMAGE_TYPE = DataType.FLOAT32
    private val OUTPUT_IMAGE_TYPE = DataType.FLOAT32
    private const val CONFIDENCE_THRESHOLD = 0.3F
    private const val IOU_THRESHOLD = 0.5F
}
Initializing the model
Initialize the TFLite model: load the model file and pass it to TFLite's Interpreter. Optionally, you can also set the number of threads to use.
If you use this code in a class other than an Activity, you need to pass a Context to that class.
val model = FileUtil.loadMappedFile(context, modelPath)
val options = Interpreter.Options()
options.numThreads = 4
interpreter = Interpreter(model, options)
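Get the YOLOv8s input and output shapes from the Interpreter.
// interpreter is a nullable property, so unwrap it once here
val interp = checkNotNull(interpreter) { "Interpreter is not initialized; check modelPath" }
val inputShape = interp.getInputTensor(0).shape()   // [1, height, width, 3]
val outputShape = interp.getOutputTensor(0).shape() // [1, 84, 8400] for the COCO model
tensorWidth = inputShape[2]
tensorHeight = inputShape[1]
numChannel = outputShape[1]
numElements = outputShape[2]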
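For the default pre-trained model with a 640x640 input, outputShape is [1, 84, 8400]. The 84 rows are 4 box values (cx, cy, w, h) followed by 80 class scores, and 8400 is the number of candidate boxes from the three detection scales (80x80 + 40x40 + 20x20 grid cells). This arithmetic can be sketched as follows, assuming a square input and the standard YOLOv8 detection strides of 8, 16 and 32; expectedOutputShape is a hypothetical helper, not a TFLite API.

```kotlin
// Hypothetical helper: predicts the YOLOv8 detection output shape
// [1, 4 + numClasses, numCandidates] for a square input.
// Assumes the standard YOLOv8 strides 8, 16 and 32.
fun expectedOutputShape(
    inputSize: Int,
    numClasses: Int,
    strides: IntArray = intArrayOf(8, 16, 32)
): IntArray {
    // Each stride produces an (inputSize / stride) x (inputSize / stride) grid,
    // and every grid cell yields one candidate box.
    val numCandidates = strides.sumOf { s -> (inputSize / s) * (inputSize / s) }
    // 4 box values (cx, cy, w, h) followed by one score per class
    return intArrayOf(1, 4 + numClasses, numCandidates)
}
```

If numChannel - 4 does not match the number of lines in labels.txt, the model and the labels file most likely do not belong together.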
Read the class names from the labels.txt file.
The streams must be closed after reading; use {} closes them even if an exception is thrown.
try {
    context.assets.open(labelPath).use { inputStream: InputStream ->
        BufferedReader(InputStreamReader(inputStream)).use { reader ->
            var line: String? = reader.readLine()
            // Stop at the end of the file or at the first empty line
            while (line != null && line != "") {
                labels.add(line)
                line = reader.readLine()
            }
        }
    }
} catch (e: IOException) {
    e.printStackTrace()
}
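To test the label parsing off-device, the same logic can be written as a pure function over any InputStream. This is a minimal sketch; readLabels is a hypothetical name. Like the loop above, it stops at the first empty line, and it additionally trims surrounding whitespace from each name.

```kotlin
import java.io.InputStream

// Hypothetical helper: reads one class name per line, stopping at the
// first empty line. useLines closes the reader (and the underlying stream)
// even if an exception is thrown.
fun readLabels(input: InputStream): List<String> =
    input.bufferedReader().useLines { lines ->
        lines.map { it.trim() }
            .takeWhile { it.isNotEmpty() }
            .toList()
    }
```

On the device, this could be called as labels.addAll(readLabels(context.assets.open(labelPath))).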
Inputting an image and running inference
The input is a Bitmap, and the following preprocessing is performed to match the model's input format.
1. Resize it to the input size of the model
2. Convert it to a tensor
3. Normalize the pixel values by dividing them by 255 (scaling them to the range 0 to 1)
4. Cast it to the input type of the model
5. Get the imageBuffer to pass to the model
val resizedBitmap = Bitmap.createScaledBitmap(bitmap, tensorWidth, tensorHeight, false)
val tensorImage = TensorImage(DataType.FLOAT32)
tensorImage.load(resizedBitmap)
val processedImage = imageProcessor.process(tensorImage)
val imageBuffer = processedImage.buffer
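The data produced by the steps above can be sketched in plain Kotlin, which makes the layout easy to inspect. This sketch assumes the pixels are packed ARGB integers in row-major order (as Bitmap.getPixels returns them) and that the model takes RGB values scaled to 0 to 1 in height x width x channel order, matching a [1, height, width, 3] input; toNormalizedRgb is a hypothetical helper, not part of the TFLite Support Library.

```kotlin
// Hypothetical helper: converts packed ARGB pixels (row-major) into a float
// array in HWC order with values in [0, 1]. The alpha channel is dropped.
fun toNormalizedRgb(pixels: IntArray, width: Int, height: Int): FloatArray {
    require(pixels.size == width * height) { "Expected ${width * height} pixels, got ${pixels.size}" }
    val out = FloatArray(width * height * 3)
    for (i in pixels.indices) {
        val p = pixels[i]
        out[i * 3] = ((p shr 16) and 0xFF) / 255f     // R
        out[i * 3 + 1] = ((p shr 8) and 0xFF) / 255f // G
        out[i * 3 + 2] = (p and 0xFF) / 255f         // B
    }
    return out
}
```

On Android, the pixels could be read from the resized bitmap with bitmap.getPixels(pixels, 0, width, 0, 0, width, height).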
Create an output tensor buffer that matches the output shape of the model, and pass it to the interpreter together with the imageBuffer above to run inference.
val output = TensorBuffer.createFixedSize(intArrayOf(1, numChannel, numElements), OUTPUT_IMAGE_TYPE)
interpreter?.run(imageBuffer, output.buffer)
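The output buffer is a flat FloatArray laid out as [1, numChannel, numElements], so all values of one channel are stored contiguously: the value of channel ch for candidate box c is at index ch * numElements + c. The bestBox function relies on this layout. A small sketch of the indexing, using a hypothetical helper named candidateValues:

```kotlin
// Hypothetical helper: collects all channel values for candidate box c from a
// flattened [1, numChannel, numElements] output.
// Channels 0..3 are cx, cy, w, h; channels 4 and up are class scores.
fun candidateValues(output: FloatArray, numChannel: Int, numElements: Int, c: Int): FloatArray {
    require(output.size == numChannel * numElements) { "Output size does not match the shape" }
    require(c in 0 until numElements) { "Candidate index out of range" }
    return FloatArray(numChannel) { ch -> output[ch * numElements + c] }
}
```

For example, candidateValues(output.floatArray, numChannel, numElements, c) returns cx, cy, w and h followed by the class scores of candidate c (80 scores for the COCO model).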
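Post-processing the output
Each output box is represented by a BoundingBox data class that holds the class, the box, and the confidence:
x1, y1: top-left corner
x2, y2: bottom-right corner
cx, cy: center
w: width
h: height
cnf: confidence
cls, clsName: class index and class name
All coordinates are normalized to the range 0 to 1 relative to the image size.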
data class BoundingBox(
    val x1: Float,
    val y1: Float,
    val x2: Float,
    val y2: Float,
    val cx: Float,
    val cy: Float,
    val w: Float,
    val h: Float,
    val cnf: Float,
    val cls: Int,
    val clsName: String
)
The following process selects the most reliable boxes from the many candidate boxes output by the model.
1. Keep only boxes whose confidence is above the confidence threshold.
2. Among overlapping boxes, keep only the one with the highest confidence (non-maximum suppression, NMS).
private fun bestBox(array: FloatArray): List<BoundingBox>? {
    val boundingBoxes = mutableListOf<BoundingBox>()
    // Each candidate c is one column of the [numChannel, numElements] output:
    // rows 0-3 are cx, cy, w, h and rows 4 and up are class scores.
    for (c in 0 until numElements) {
        // Find the class with the highest score for this candidate
        var maxConf = -1.0f
        var maxIdx = -1
        var j = 4
        var arrayIdx = c + numElements * j
        while (j < numChannel) {
            if (array[arrayIdx] > maxConf) {
                maxConf = array[arrayIdx]
                maxIdx = j - 4
            }
            j++
            arrayIdx += numElements
        }
        if (maxConf > CONFIDENCE_THRESHOLD) {
            val clsName = labels[maxIdx]
            val cx = array[c] // row 0
            val cy = array[c + numElements] // row 1
            val w = array[c + numElements * 2] // row 2
            val h = array[c + numElements * 3] // row 3
            // Convert center and size to corner coordinates
            val x1 = cx - (w / 2F)
            val y1 = cy - (h / 2F)
            val x2 = cx + (w / 2F)
            val y2 = cy + (h / 2F)
            // Discard boxes that extend outside the image
            if (x1 < 0F || x1 > 1F) continue
            if (y1 < 0F || y1 > 1F) continue
            if (x2 < 0F || x2 > 1F) continue
            if (y2 < 0F || y2 > 1F) continue
            boundingBoxes.add(
                BoundingBox(
                    x1 = x1, y1 = y1, x2 = x2, y2 = y2,
                    cx = cx, cy = cy, w = w, h = h,
                    cnf = maxConf, cls = maxIdx, clsName = clsName
                )
            )
        }
    }
    if (boundingBoxes.isEmpty()) return null
    return applyNMS(boundingBoxes)
}
private fun applyNMS(boxes: List<BoundingBox>): MutableList<BoundingBox> {
    val sortedBoxes = boxes.sortedByDescending { it.cnf }.toMutableList()
    val selectedBoxes = mutableListOf<BoundingBox>()
    while (sortedBoxes.isNotEmpty()) {
        val first = sortedBoxes.first()
        selectedBoxes.add(first)
        sortedBoxes.remove(first)
        val iterator = sortedBoxes.iterator()
        while (iterator.hasNext()) {
            val nextBox = iterator.next()
            val iou = calculateIoU(first, nextBox)
            if (iou >= IOU_THRESHOLD) {
                iterator.remove()
            }
        }
    }
    return selectedBoxes
}
private fun calculateIoU(box1: BoundingBox, box2: BoundingBox): Float {
    val x1 = maxOf(box1.x1, box2.x1)
    val y1 = maxOf(box1.y1, box2.y1)
    val x2 = minOf(box1.x2, box2.x2)
    val y2 = minOf(box1.y2, box2.y2)
    val intersectionArea = maxOf(0F, x2 - x1) * maxOf(0F, y2 - y1)
    val box1Area = box1.w * box1.h
    val box2Area = box2.w * box2.h
    return intersectionArea / (box1Area + box2Area - intersectionArea)
}
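calculateIoU computes the intersection over union of two boxes. As a worked example, take two normalized boxes A = (0.0, 0.0)-(0.4, 0.4) and B = (0.2, 0.2)-(0.6, 0.6), with values chosen only for illustration:

```latex
\begin{aligned}
|A \cap B| &= \bigl(\min(0.4, 0.6) - \max(0.0, 0.2)\bigr) \times \bigl(\min(0.4, 0.6) - \max(0.0, 0.2)\bigr) = 0.2 \times 0.2 = 0.04 \\
|A| &= |B| = 0.4 \times 0.4 = 0.16 \\
\mathrm{IoU}(A, B) &= \frac{|A \cap B|}{|A| + |B| - |A \cap B|} = \frac{0.04}{0.16 + 0.16 - 0.04} = \frac{0.04}{0.28} = \frac{1}{7} \approx 0.143
\end{aligned}
```

Since 0.143 is below IOU_THRESHOLD (0.5), applyNMS would keep both boxes.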
At this point, you can get the final YOLOv8 detections.
val bestBoxes = bestBox(output.floatArray)
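applyNMS is class-agnostic: a box suppresses any overlapping lower-confidence box, even one of a different class (for example, a person riding a bicycle). If you want to keep overlapping detections of different classes, a common alternative is to run NMS separately for each class. The sketch below shows that variant; it is not the method used in this article, and the Box type and the iou, nms and perClassNms functions are hypothetical stand-ins so the block runs on its own.

```kotlin
// Minimal stand-in for BoundingBox: normalized corners, confidence, class index.
data class Box(val x1: Float, val y1: Float, val x2: Float, val y2: Float, val cnf: Float, val cls: Int)

fun iou(a: Box, b: Box): Float {
    val w = maxOf(0f, minOf(a.x2, b.x2) - maxOf(a.x1, b.x1))
    val h = maxOf(0f, minOf(a.y2, b.y2) - maxOf(a.y1, b.y1))
    val inter = w * h
    val union = (a.x2 - a.x1) * (a.y2 - a.y1) + (b.x2 - b.x1) * (b.y2 - b.y1) - inter
    return if (union > 0f) inter / union else 0f
}

// Greedy NMS: keep the most confident box, drop every box that overlaps it
// by at least iouThreshold, and repeat with what is left.
fun nms(boxes: List<Box>, iouThreshold: Float): List<Box> {
    val remaining = boxes.sortedByDescending { it.cnf }.toMutableList()
    val selected = mutableListOf<Box>()
    while (remaining.isNotEmpty()) {
        val best = remaining.removeAt(0)
        selected.add(best)
        remaining.removeAll { iou(best, it) >= iouThreshold }
    }
    return selected
}

// Class-aware variant: a box can only suppress boxes of the same class.
fun perClassNms(boxes: List<Box>, iouThreshold: Float): List<Box> =
    boxes.groupBy { it.cls }
        .values
        .flatMap { nms(it, iouThreshold) }
        .sortedByDescending { it.cnf }
```

With two identical boxes of different classes, nms keeps only the more confident one, while perClassNms keeps both.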
Drawing the output boxes on the image
fun drawBoundingBoxes(bitmap: Bitmap, boxes: List<BoundingBox>): Bitmap {
    val mutableBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true)
    val canvas = Canvas(mutableBitmap)
    val paint = Paint().apply {
        color = Color.RED
        style = Paint.Style.STROKE
        strokeWidth = 8f
    }
    val textPaint = Paint().apply {
        color = Color.WHITE
        textSize = 40f
        typeface = Typeface.DEFAULT_BOLD
    }
    for (box in boxes) {
        val rect = RectF(
            box.x1 * mutableBitmap.width,
            box.y1 * mutableBitmap.height,
            box.x2 * mutableBitmap.width,
            box.y2 * mutableBitmap.height
        )
        canvas.drawRect(rect, paint)
        canvas.drawText(box.clsName, rect.left, rect.bottom, textPaint)
    }
    return mutableBitmap
}
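drawBoundingBoxes scales the normalized coordinates by the bitmap size. The same mapping can be sketched without Android types. This version also clamps each coordinate to [0, 1], which would matter only if you chose to clamp boxes instead of discarding them in bestBox; the clamping is an addition not present in the code above, and PixelRect and toPixelRect are hypothetical names.

```kotlin
import kotlin.math.roundToInt

// Hypothetical pixel-space rectangle (left, top, right, bottom).
data class PixelRect(val left: Int, val top: Int, val right: Int, val bottom: Int)

// Maps normalized corners (x1, y1, x2, y2) to pixel coordinates of an image of
// the given size, clamping each coordinate into the image first.
fun toPixelRect(x1: Float, y1: Float, x2: Float, y2: Float, width: Int, height: Int): PixelRect =
    PixelRect(
        left = (x1.coerceIn(0f, 1f) * width).roundToInt(),
        top = (y1.coerceIn(0f, 1f) * height).roundToInt(),
        right = (x2.coerceIn(0f, 1f) * width).roundToInt(),
        bottom = (y2.coerceIn(0f, 1f) * height).roundToInt()
    )
```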
When things don’t go well
A common cause of problems is an incorrect model path, which leaves the interpreter null, so check that first.
🐣
I’m a freelance engineer.
Work consultation
Please feel free to contact me with a brief description of your project.
rockyshikoku@gmail.com
I create applications using machine learning and AR technology.
I share information about machine learning and AR.