Deploying Machine Learning Models on Mobile with TensorFlow Lite and Firebase ML Kit
Oluwafemi Tairu
Posted on April 2, 2020
In my last article, I showed how to deploy machine learning models via an API.
In this article, I will show you how to deploy models in mobile apps using TensorFlow Lite and Firebase ML Kit.
Deploying models via an API works, but there are several reasons why it might not suit your needs or those of your organisation, including:
- Poor or no internet access for the users of your app.
- Data privacy.
- The need for faster inference.
Whatever your reason, moving your model from an API to on-device won't cost you much time. Your existing code can stay the same; converting the model takes only a few extra lines.
Convert your previously saved model
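import tensorflow as tf
# MODEL_DIR is the path of the SavedModel directory you exported earlier.
# Convert the model.
converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
tflite_model = converter.convert()
# Save the converted model; "with" closes the file when done.
with open("plant_ai.tflite", "wb") as f:
    f.write(tflite_model)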
Convert new models
import tensorflow as tf
# `model` is your trained tf.keras model.
# Save the Keras model (here in HDF5 format).
model.save('plant_ai_model.h5')
# Convert the in-memory Keras model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Save the converted model.
with open("plant_ai.tflite", "wb") as f:
    f.write(tflite_model)
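Before copying the converted file into an app, a quick sanity check can save some debugging later. The minimal sketch below is plain Python with no TensorFlow dependency; it relies on the fact that a .tflite file is a FlatBuffer whose file identifier, "TFL3", sits at byte offset 4. The helper names looks_like_tflite and describe_model_file are hypothetical, and the reported size is only a rough guide to how much the bundled model will add to your app.

```python
from pathlib import Path

# A .tflite file is a FlatBuffer: a 4-byte root offset followed by the
# 4-byte file identifier "TFL3".
TFLITE_IDENTIFIER = b"TFL3"


def looks_like_tflite(data: bytes) -> bool:
    """Return True if `data` carries the TFLite FlatBuffer identifier."""
    return len(data) >= 8 and data[4:8] == TFLITE_IDENTIFIER


def describe_model_file(path: str) -> str:
    """Summarise a converted model file: whether it looks valid, and its size in MB."""
    data = Path(path).read_bytes()
    size_mb = len(data) / (1024 * 1024)
    status = "looks like a TFLite model" if looks_like_tflite(data) else "NOT a TFLite model"
    return f"{path}: {status}, {size_mb:.2f} MB"
```

For example, describe_model_file("plant_ai.tflite") returns a one-line summary you can print after conversion.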
For this article, I will be using a deep learning model for plant disease detection. I wrote about it here.
The modified code can be found here.
Integrating with a mobile app
Our model is now ready for deployment to a mobile app. Bundling the model (in .tflite format) with the app means predictions run without an internet connection, which suits some solutions well. However, bundling .tflite models with your app will increase its size.
I will be using Flutter for this article, but you can do the same in Java or Kotlin by following Google's documentation for Android, or in Swift or Objective-C by following Google's documentation for iOS. In Flutter's own words, Flutter is "Google’s UI toolkit for building beautiful, natively compiled applications for mobile, web, and desktop from a single codebase."
If you are new to Flutter, you can follow this link to get started with your first Flutter application. After creation, your project will have the standard Flutter structure, including lib/main.dart, pubspec.yaml, and the android and ios folders.
The next step is to install the necessary packages. In Flutter, you install packages by adding them to the dependencies section of the pubspec.yaml file.
The following packages were used for this app:
- mlkit: A Flutter plugin to use the Firebase ML Kit.
- image_picker: A Flutter plugin for iOS and Android for picking images from the image library, and taking new pictures with the camera.
- image: A Dart library for decoding and manipulating images (imported as img in the code below).
- toast (Optional): A Flutter Toast plugin.
Next, add Firebase to your Android or iOS project by following the guide here for Android and the guide here for iOS. On Android, make sure you add the google-services.json file to your project (in a Flutter project, it goes in the android/app folder).
Next, create a folder called assets in the project's root directory and copy the .tflite model and the labels text file into it. The code in this article expects both files in an assets/models subfolder.
To make these files accessible to the app, declare them under the assets key of the flutter section in the pubspec.yaml file.
I will skip the interface design, as you can use any interface you like for your app.
Let's import the necessary packages (which we installed earlier):
import 'package:flutter/material.dart';
import 'dart:io';
import 'dart:typed_data';
import 'package:flutter/services.dart';
import 'package:image_picker/image_picker.dart';
import 'package:mlkit/mlkit.dart';
import 'package:toast/toast.dart';
import 'package:image/image.dart' as img;
Next, we declare some variables and constants at the top of the code:
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.instance;
FirebaseModelManager manager = FirebaseModelManager.instance;
File selectedImageFile;
List<String> modelLabels = [];
Map<String, double> predictions = Map<String, double>();
int imageDim = 256;
List<int> inputDims = [1, 256, 256, 3];
List<int> outputDims = [1, 15];
The first two lines get the FirebaseModelInterpreter and FirebaseModelManager instances. imageDim is the size of the image the model expects (256, i.e. a 256x256 image). inputDims is the input shape of our model (batch size, height, width, colour channels), while outputDims is its output shape (one score for each of the 15 classes). These values vary depending on the model used. If someone else built the model, you can easily check them by running the Python code below.
import tensorflow as tf
# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path='plant_ai_lite_model.tflite')
# OR: interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
# Get the input and output tensor details.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
The value of input_details in this case is:
[{
'name': 'conv2d_input',
'index': 1,
'shape': array([1, 256, 256, 3], dtype=int32),
'dtype': <class 'numpy.float32'>,
'quantization': (0.0, 0)
}]
while that of output_details is:
[{
'name': 'Identity',
'index': 0,
'shape': array([1, 15], dtype=int32),
'dtype': <class 'numpy.float32'>,
'quantization': (0.0, 0)
}]
Take note of the 'shape' value in the two results above: it determines the values of inputDims and outputDims in the Dart code.
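Rather than copying the shapes by hand, you can derive them in code. The minimal sketch below uses hypothetical helper names and needs no TensorFlow to run: it turns a tensor's 'shape' into the plain list used for inputDims or outputDims, and computes how many bytes a float32 tensor of that shape occupies, which is the size of the input buffer the app has to hand to the interpreter. The sample data mirrors the details printed above, with plain lists standing in for NumPy arrays.

```python
from math import prod
from typing import Any, Dict, List

FLOAT32_BYTES = 4  # a float32 value occupies 4 bytes


def dims_from_details(details: List[Dict[str, Any]]) -> List[int]:
    """Return the 'shape' of the first tensor as a plain list of ints.

    Also works on the real output of interpreter.get_input_details() or
    get_output_details(), where 'shape' is a NumPy int32 array.
    """
    return [int(d) for d in details[0]["shape"]]


def float32_buffer_size(dims: List[int]) -> int:
    """Number of bytes needed to hold a float32 tensor with these dims."""
    return prod(dims) * FLOAT32_BYTES


# Values copied from the input_details/output_details printed above.
input_details = [{"name": "conv2d_input", "index": 1, "shape": [1, 256, 256, 3]}]
output_details = [{"name": "Identity", "index": 0, "shape": [1, 15]}]

input_dims = dims_from_details(input_details)
output_dims = dims_from_details(output_details)
print(input_dims, float32_buffer_size(input_dims))    # [1, 256, 256, 3] 786432
print(output_dims, float32_buffer_size(output_dims))  # [1, 15] 60
```

The 786,432 bytes are simply 1 x 256 x 256 x 3 values at 4 bytes each.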
Next, we load the model using the FirebaseModelManager and read the labels from the app's asset bundle. Firebase ML Kit gives us the option of either loading a local model on the device or using one hosted in the cloud. You can find samples and more information about this here, or you can check my GitHub repo for a simple example that uses the Firebase Vision API to recognise text.
For this example, I have used an on-device model loaded using FirebaseLocalModelSource:
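// Register the bundled model. modelName must match the localModelName
// passed to interpreter.run() later on.
manager.registerLocalModelSource(
FirebaseLocalModelSource(
assetFilePath: "assets/models/plant_ai_lite_model.tflite",
modelName: "plant_ai_model"
)
);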
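rootBundle.loadString('assets/models/labels_plant_ai.txt').then((string) {
  // One label per line; trim each line and drop blanks (e.g. after a trailing newline).
  modelLabels = string
      .split('\n')
      .map((label) => label.trim())
      .where((label) => label.isNotEmpty)
      .toList();
});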
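A mismatch between the labels file and the model is an easy mistake to make, because the 15 output scores are matched to labels purely by position. The Python sketch below uses hypothetical helper names and follows the same idea as the Dart code above (split the file on newlines): it ignores blank lines such as the one left by a trailing newline, then checks that the label count equals the last dimension of outputDims. Running it on your labels file before bundling it can catch an off-by-one early.

```python
from typing import List


def parse_labels(text: str) -> List[str]:
    """One label per line; strip whitespace (including Windows line endings) and skip blank lines."""
    return [line.strip() for line in text.splitlines() if line.strip()]


def check_labels(labels: List[str], output_dims: List[int]) -> None:
    """Raise ValueError if the label count differs from the model's class count."""
    num_classes = output_dims[-1]
    if len(labels) != num_classes:
        raise ValueError(
            f"labels file has {len(labels)} entries but the model outputs {num_classes} scores"
        )


# Hypothetical usage with the paths and shapes from this article:
# with open("assets/models/labels_plant_ai.txt") as f:
#     check_labels(parse_labels(f.read()), [1, 15])
```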
Let's pick a picture!
Using the image_picker Flutter package, the function below lets us pick a picture from the phone's gallery and then runs a prediction on it:
_pickPicture() async{
var image = await ImagePicker.pickImage(source: ImageSource.gallery);
if(image != null){
setState(() {
selectedImageFile = image;
});
predict(selectedImageFile);
}
}
The predict function is where we convert the image file into a list of bytes (the imageToByteListFloat helper, not shown here, encodes the pixels as float32 values) and run it through the interpreter using the input and output options. The result is the same prediction you would get when running the model in a Jupyter notebook or on any other platform.
var bytes = await imageToByteListFloat(imageFile, imageDim);
var results = await interpreter.run(
localModelName: "plant_ai_model",
inputOutputOptions: FirebaseModelInputOutputOptions(
[
FirebaseModelIOOption(
FirebaseModelDataType.FLOAT32,
inputDims
)
],
[
FirebaseModelIOOption(
FirebaseModelDataType.FLOAT32,
outputDims
)
]
),
inputBytes: bytes
);
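The imageToByteListFloat helper is not shown in this article. As a rough illustration of what such a helper has to produce, the Python sketch below packs an already-resized RGB image into a float32 byte buffer in height, width, channel order, matching inputDims [1, 256, 256, 3]; for a 256x256 image that is 196,608 floats, or 786,432 bytes. The function name is hypothetical, decoding and resizing the image are out of scope, and dividing by 255 is an assumption: use whatever preprocessing your model was trained with.

```python
import struct
from typing import List, Tuple

Pixel = Tuple[int, int, int]  # (red, green, blue), each 0-255


def image_to_float32_bytes(pixels: List[List[Pixel]], image_dim: int) -> bytes:
    """Flatten an image_dim x image_dim RGB image (a list of pixel rows) into
    little-endian float32 bytes in height, width, channel order.

    ASSUMPTION: the model expects pixel values scaled to [0, 1] (value / 255).
    """
    if len(pixels) != image_dim or any(len(row) != image_dim for row in pixels):
        raise ValueError("resize the image to image_dim x image_dim first")
    values = [channel / 255.0 for row in pixels for pixel in row for channel in pixel]
    # '<' forces little-endian, the byte order of typical Android and iOS devices.
    return struct.pack(f"<{len(values)}f", *values)


# A tiny 2x2 example: red, green / blue, white.
tiny = [[(255, 0, 0), (0, 255, 0)], [(0, 0, 255), (255, 255, 255)]]
data = image_to_float32_bytes(tiny, 2)
print(len(data))  # 2 * 2 pixels * 3 channels * 4 bytes = 48
print(struct.unpack("<12f", data)[:6])  # (1.0, 0.0, 0.0, 0.0, 1.0, 0.0)
```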
The prediction results can then be mapped to the label text specified in the labels file.
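if (results != null && results.isNotEmpty) {
  // Clear scores left over from the previous image.
  predictions.clear();
  for (var i = 0; i < results[0][0].length; i++) {
    // The output tensor is float32 (see output_details above); assuming the
    // model ends in a softmax, each score is a probability between 0 and 1.
    var confidenceLevel = results[0][0][i] * 100;
    if (confidenceLevel > 0) {
      predictions[modelLabels[i]] = confidenceLevel;
    }
  }
  // Sort the labels by confidence, highest first.
  var predictionKeys = predictions.entries.toList();
  predictionKeys.sort((a, b) => b.value.compareTo(a.value));
  predictions = Map<String, double>.fromEntries(predictionKeys);
}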
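The same post-processing can be checked outside the app. The Python sketch below (hypothetical function name) applies the same pairing, filtering and sorting as the Dart code above: each score is matched to its label by position, converted to a percentage, zero scores are dropped, and the labels are ordered from most to least confident. It assumes the model's final layer is a softmax, so each float32 score is a probability between 0 and 1; the sample scores and labels are made up for illustration.

```python
from typing import Dict, List


def top_predictions(scores: List[float], labels: List[str]) -> Dict[str, float]:
    """Return {label: confidence %} for non-zero scores, highest first.

    ASSUMPTION: scores are softmax probabilities in [0, 1], one per label.
    """
    if len(scores) != len(labels):
        raise ValueError("need exactly one label per score")
    percentages = {label: score * 100 for label, score in zip(labels, scores) if score > 0}
    # Python dicts preserve insertion order, so rebuilding from sorted items sorts the dict.
    return dict(sorted(percentages.items(), key=lambda item: item[1], reverse=True))


# Made-up scores for three hypothetical classes.
print(top_predictions([0.15, 0.0, 0.85], ["leaf_spot", "rust", "healthy"]))
```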
The prediction results are now available for further use or display in your app, all without an internet connection and without sending any data from the device to a server.
The code for this article can be found here.
Thanks for reading. 🤝