Implementation of a TensorFlow Lite Model on Android
Z. QIU
Posted on February 18, 2021
Recently, in an interview, I was asked about my experience deploying trained TensorFlow models on Android. I had already tried an Android project cloned from GitHub that embedded a TFLite model, but I had never deployed a model of my own in an Android application. So today I did exactly that exercise, and I successfully got my CNN model running on my Redmi Note 8 Pro.
CNN model
Below is the code for training a CNN on the MNIST dataset. The trained model is then converted to a TFLite model, which is deployed in an Android application that recognizes handwritten digits.
```python
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"  # hide TensorFlow INFO and WARNING messages

import tensorflow as tf
from tensorflow.keras import layers, models

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train[0, :, :])  # raw pixel values (0-255) of the first training image
## x_train.shape => (60000, 28, 28), y_train.shape => (60000,)
## x_test.shape => (10000, 28, 28), y_test.shape => (10000,)

# add a channel dimension: (N, 28, 28) -> (N, 28, 28, 1), as Conv2D expects
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)
# labels are already 1-D; squeeze is only a safeguard
y_train = tf.squeeze(y_train)
y_test = tf.squeeze(y_test)
print("Dataset info: ", x_train.shape, y_train.shape, x_test.shape, y_test.shape)

batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(10000).batch(batch_size)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.batch(batch_size)

# inspect one batch: (128, 28, 28, 1) images and (128,) labels
train_iter = iter(train_db)
sample = next(train_iter)
print(sample[0].shape, sample[1].shape)

## build a standard cnn model
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))  # outputs raw logits (no softmax)
model.summary()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
train_history = model.fit(train_db, epochs=10, validation_data=test_db)

## once the model has been trained, convert it to a tflite model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('qiu_mnist_model.tflite', 'wb') as f:
    f.write(tflite_model)
```
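The shapes and parameter counts reported by `model.summary()` can be checked by hand. The stdlib-only sketch below (the helper names are my own, not Keras APIs) recomputes each layer's output size and parameter count, assuming the Keras defaults used above: stride 1 with 'valid' padding for `Conv2D`, and non-overlapping 2x2 `MaxPooling2D`.

```python
# Recompute output sizes and parameter counts of the CNN defined above,
# assuming stride-1 'valid' 3x3 convolutions and non-overlapping 2x2 pooling.

def conv_out(size, kernel=3):
    """Spatial size after a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1

def pool_out(size, pool=2):
    """Spatial size after non-overlapping max pooling (floor division)."""
    return size // pool

def conv_params(in_ch, out_ch, kernel=3):
    """kernel*kernel*in_ch weights per filter, plus one bias per filter."""
    return out_ch * (kernel * kernel * in_ch + 1)

def dense_params(n_in, n_out):
    """n_in weights plus one bias per output unit."""
    return n_out * (n_in + 1)

size = 28
size = pool_out(conv_out(size))  # Conv2D(32): 28 -> 26, MaxPooling: 26 -> 13
size = pool_out(conv_out(size))  # Conv2D(64): 13 -> 11, MaxPooling: 11 -> 5
size = conv_out(size)            # Conv2D(64): 5 -> 3
flat = size * size * 64          # Flatten: 3 * 3 * 64 values

params = [
    conv_params(1, 32),      # first Conv2D
    conv_params(32, 64),     # second Conv2D
    conv_params(64, 64),     # third Conv2D
    dense_params(flat, 64),  # Dense(64)
    dense_params(64, 10),    # Dense(10)
]
print(flat, params, sum(params))
```

This prints `576 [320, 18496, 36928, 36928, 650] 93322`. The 576-element flatten is why the first `Dense` layer holds as many parameters as the last convolution; the whole model stays under 100k parameters, which is small enough for a phone.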
Implementation in Android app
I referred to this post to obtain the original Android project and imported its Kotlin version into Android Studio. However, there were some bugs at first when I loaded my own model into it.
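One detail the app code has to match is the model's output format. Because the last layer is `Dense(10)` without an activation (the loss uses `from_logits=True`), the converted model returns 10 raw logits per image, not probabilities. The predicted digit is simply the index of the largest value; if the app also displays a confidence, it has to apply softmax itself. The sketch below shows that post-processing in Python for clarity; the function name and the sample logits are made up for illustration, and the app would do the equivalent on the 10 floats it reads from the model output.

```python
import math

def logits_to_prediction(logits):
    """Return (digit, probability) from the 10 raw logits of the model.

    Softmax is applied here because the model has no softmax layer.
    Subtracting max(logits) first keeps math.exp() from overflowing
    without changing the resulting probabilities.
    """
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # argmax: index of the highest probability (same as argmax of the logits)
    digit = max(range(len(probs)), key=probs.__getitem__)
    return digit, probs[digit]

# Hypothetical output vector for one input image
logits = [-2.1, 0.3, 1.0, 7.5, -0.4, 2.2, -1.3, 0.8, 1.9, 0.1]
digit, prob = logits_to_prediction(logits)
print(digit, round(prob, 3))
```

If only the digit is needed, the softmax step can be skipped entirely, since it never changes which index is largest.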
My own model is placed in the project's assets directory.
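Before copying the converted file into the assets directory, a cheap sanity check is that it really is a TFLite model: `.tflite` files are FlatBuffers that carry the four-byte identifier `TFL3` at byte offset 4. The snippet below (helper names are my own) reads only that header; it does not validate the rest of the model.

```python
def looks_like_tflite(data: bytes) -> bool:
    """True if the FlatBuffer file identifier 'TFL3' sits at bytes 4..7."""
    return len(data) >= 8 and data[4:8] == b"TFL3"

def check_model_file(path: str) -> bool:
    """Read just the 8-byte header of a file and check the identifier."""
    with open(path, "rb") as f:
        header = f.read(8)
    return looks_like_tflite(header)

# Usage, with the file name written by the training script:
# print(check_model_file("qiu_mnist_model.tflite"))
```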
The most important thing for this work is the following Gradle setting:
After about 15 minutes of debugging and code modifications, I got my model working.
Check out the video (there is still an accuracy issue):
I will upload the Android project source code to my GitHub repo once I have finished cleaning up the code and improving its performance.
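Regarding the remaining accuracy issue, one possible cause (not yet confirmed here) is a mismatch between how images are prepared at training time and on the device. The training script above does not rescale pixels, so the model expects 28x28x1 float inputs in the 0-255 range, with light strokes on a dark background as in MNIST. A drawing canvas with dark strokes on a white background, or inputs scaled to 0-1, would not match. The following Python sketch illustrates the conversion; the function name, the 280x280 canvas size and the plain block-averaging resize are my own assumptions, not the app's actual code.

```python
def canvas_to_mnist_input(canvas, size=28, invert=True):
    """Convert a square grayscale canvas (rows of 0-255 ints) into a
    size x size grid of floats in the 0-255 range, like the raw MNIST
    pixels the model was trained on.

    Assumes the canvas side is a multiple of `size`; each output pixel is
    the mean of one block of canvas pixels. With invert=True, dark strokes
    on a light background become light strokes on a dark background.
    """
    n = len(canvas)
    if n % size != 0 or any(len(row) != n for row in canvas):
        raise ValueError("canvas must be square, with a side that is a multiple of size")
    block = n // size
    out = []
    for by in range(size):
        row = []
        for bx in range(size):
            total = 0
            for y in range(by * block, (by + 1) * block):
                for x in range(bx * block, (bx + 1) * block):
                    total += canvas[y][x]
            mean = total / (block * block)
            row.append(255.0 - mean if invert else float(mean))
        out.append(row)
    return out

# Example: a white 280x280 canvas with a black vertical stroke in the middle
canvas = [[0 if 130 <= x < 150 else 255 for x in range(280)] for _ in range(280)]
img = canvas_to_mnist_input(canvas)
print(len(img), len(img[0]), img[0][0], img[0][14])
```

Here the stroke ends up as 255.0 and the background as 0.0, matching the MNIST convention. MNIST digits are also centered within the 28x28 frame, so cropping and centering the drawn digit before resizing may help further.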
Reference