Deploying Machine Learning Models on Mobile with TensorFlow Lite and Firebase ML Kit

emmarex

Oluwafemi Tairu

Posted on April 2, 2020

In my last article, I shared how to deploy machine learning models via an API.

In this article, I will show you how to deploy models in mobile apps using TensorFlow Lite and Firebase ML Kit.

Deploying models via an API is fine, but there are several reasons why that might not suit your needs or those of your organisation. Some of them include:

  1. Poor internet connection or no access to an internet connection for users of your app.
  2. Data Privacy.
  3. Need for faster inference.

Whatever your reason might be, moving your model from an API to on-device inference won't cost you much time. Your existing code can stay the same, with only a few extra lines added.

Convert your previously saved model

import tensorflow as tf

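# Convert the SavedModel stored in MODEL_DIR (the directory your model was saved to).
converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
tflite_model = converter.convert()

# Write the converted model to disk; the with-block closes the file for us.
with open("plant_ai.tflite", "wb") as f:
    f.write(tflite_model)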

Convert new models

import tensorflow as tf

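# Save your trained Keras model (here, `model` is the model you just trained).
model.save('plant_ai_model.h5')

# Convert the in-memory Keras model to TensorFlow Lite.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Write the converted model to disk; the with-block closes the file for us.
with open("plant_ai.tflite", "wb") as f:
    f.write(tflite_model)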

For this article, I will be using a deep learning model for plant disease detection. I wrote about it here.
The modified code can be found here

Integrating with a mobile app

Our model is now ready for deployment to a mobile app. Bundling a model (in .tflite format) with the app means you can run predictions without an internet connection, which is good for some solutions. However, bundling your .tflite model with your app will increase the app size.
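Since a bundled model adds to the app size, it can help to check how large the .tflite file is before shipping it. Below is a minimal sketch using only the Python standard library. The helper name model_size_mb is my own, and the temporary file is just a stand-in so the snippet runs anywhere; point it at your own model file instead.

```python
import os
import tempfile


def model_size_mb(path):
    """Return the size of the file at `path` in megabytes (1 MB = 1024 * 1024 bytes)."""
    return os.path.getsize(path) / (1024 * 1024)


# Stand-in file (an assumption for illustration): 2 MB of dummy bytes.
# Replace demo_path with the path to your own .tflite file.
with tempfile.NamedTemporaryFile(suffix=".tflite", delete=False) as tmp:
    tmp.write(b"\x00" * (2 * 1024 * 1024))
    demo_path = tmp.name

demo_size = model_size_mb(demo_path)
print(f"{demo_size:.2f} MB")

# Clean up the stand-in file.
os.remove(demo_path)
```

Roughly this much is added to your app bundle before any compression applied by the app packaging.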

For this article, I will be using Flutter, but you can do the same in Java or Kotlin, or in Swift or Objective-C, by following Google's documentation for each platform. In Flutter's own words, it is "Google's UI toolkit for building beautiful, natively compiled applications for mobile, web, and desktop from a single codebase".

If you are new to Flutter, you can follow this link to get started with your first Flutter application. Your Flutter application structure should look like this after creation.

Flutter app folder structure

The next step is to install the necessary packages. In Flutter, these packages can be installed simply by listing them in the pubspec.yaml file.

pubspec.yaml

The following packages were used for this app:

  • mlkit: A Flutter plugin to use the Firebase ML Kit.
  • image_picker: A Flutter plugin for iOS and Android for picking images from the image library, and taking new pictures with the camera.
  • toast (Optional): A Flutter Toast plugin.

Next, you will have to add Firebase to your Android or iOS project, following the guide here for Android and the guide here for iOS. Make sure you add the "google-services.json" file to your project folder.

Next, we create a folder called assets in our project's root directory and copy our .tflite model and labels text file into it.
We then need to make these files accessible in our Flutter project by modifying our pubspec.yaml file as shown below.

pubspec.yaml file showing the inclusion of assets folder

I will skip the interface design aspect, as you can use any interface for your app.

Let's import the necessary packages (which we installed earlier)

import 'package:flutter/material.dart';
import 'dart:io';
import 'dart:typed_data';
import 'package:flutter/services.dart';
import 'package:image_picker/image_picker.dart';
import 'package:mlkit/mlkit.dart';
import 'package:toast/toast.dart';
import 'package:image/image.dart' as img;

We will declare some variables and constants first at the beginning of our code.

FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.instance;
FirebaseModelManager manager = FirebaseModelManager.instance;

File selectedImageFile;
List<String> modelLabels = [];
Map<String, double> predictions = Map<String, double>();
int imageDim = 256;
List<int> inputDims = [1, 256, 256, 3];
List<int> outputDims = [1, 15];

The first two lines get the FirebaseModelInterpreter and FirebaseModelManager instances. imageDim (256, for a 256x256 image) is the height and width of the image required by the model.
inputDims is the required input shape of our model, while outputDims is its output shape. These values may vary depending on the model used. If someone else built the model, you can easily check them by running the Python code below.

Load TFLite model and allocate tensors.

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='plant_ai_lite_model.tflite')
# OR interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

Get input and output tensors.

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

The value for input_details in this case is:

[{
    'name': 'conv2d_input',
    'index': 1,
    'shape': array([1, 256, 256, 3], dtype = int32),
    'dtype': <class 'numpy.float32'> ,
    'quantization': (0.0, 0)
}]

while that of the output_details is:

[{
    'name': 'Identity',
    'index': 0,
    'shape': array([1, 15], dtype = int32),
    'dtype': <class 'numpy.float32'> ,
    'quantization': (0.0, 0)
}]

Do take note of the value of 'shape' in the two results above. That is what determines the value for 'inputDims' and 'outputDims' above.
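To make that mapping concrete, here is a small sketch that pulls the shape out of each result and also works out how many bytes a float32 tensor of that shape occupies. The helper name dims_and_byte_size is my own, and the dictionaries are hand-written stand-ins for the two results above (with plain lists in place of NumPy arrays), so nothing here needs TensorFlow to run.

```python
from math import prod


def dims_and_byte_size(details, bytes_per_value=4):
    """Return (dims, size_in_bytes) for the first tensor described in `details`.

    `details` mirrors the list returned by get_input_details() or
    get_output_details(). bytes_per_value=4 assumes float32 tensors.
    """
    dims = [int(d) for d in details[0]["shape"]]
    return dims, prod(dims) * bytes_per_value


# Hand-written stand-ins for the two results shown above.
input_details = [{"name": "conv2d_input", "shape": [1, 256, 256, 3]}]
output_details = [{"name": "Identity", "shape": [1, 15]}]

input_dims, input_bytes = dims_and_byte_size(input_details)
output_dims, output_bytes = dims_and_byte_size(output_details)

print(input_dims, input_bytes)    # [1, 256, 256, 3] 786432
print(output_dims, output_bytes)  # [1, 15] 60
```

The first value of each pair is what goes into inputDims and outputDims. The byte count is a handy sanity check: for this model, a float32 input would be expected to be 1 x 256 x 256 x 3 x 4 = 786432 bytes long.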

Next, we load the model and labels using the FirebaseModelManager. Firebase ML Kit gives us the option of either loading a local model on the device or using one hosted in the cloud. You can find samples and more information about this here, or you can check my GitHub repo for a simple example that uses the Firebase Vision API to recognise text.

For this example, I have used an on-device model loaded using FirebaseLocalModelSource.

manager.registerLocalModelSource(
    FirebaseLocalModelSource(
        assetFilePath: "assets/models/plant_ai_lite_model.tflite",
        modelName: "plant_ai_model"
    )
);
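// Load the labels file from the app's assets, one label per line.
rootBundle.loadString('assets/models/labels_plant_ai.txt').then((string) {
    var _labels = string.split('\n');
    // Drop the empty entry left by the trailing newline at the end of the file.
    _labels.removeLast();
    modelLabels = _labels;
});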
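The removeLast() call above relies on the labels file ending with a newline, which leaves one empty entry after splitting. The Python sketch below shows why that entry appears, along with a slightly more forgiving way to parse the file. The function name parse_labels and the label names are made up for illustration.

```python
def parse_labels(text):
    """Split label text into one label per line, skipping blank lines."""
    return [line.strip() for line in text.splitlines() if line.strip()]


# Made-up labels file content, ending with a newline.
sample = "label_a\nlabel_b\nlabel_c\n"

# A plain split keeps an empty string after the final newline.
naive = sample.split("\n")
print(naive)   # ['label_a', 'label_b', 'label_c', '']

# Skipping blank lines works whether or not the file ends with a newline.
labels = parse_labels(sample)
print(labels)  # ['label_a', 'label_b', 'label_c']
```

If your labels file has no trailing newline, removing the last entry would drop a real label, so filtering out blank lines may be the safer choice.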

Let's pick a picture!
Using the image_picker Flutter package, the function below allows us to pick a picture from the phone's gallery.

_pickPicture() async{
    var image = await ImagePicker.pickImage(source: ImageSource.gallery);
    if(image != null){
      setState(() {
        selectedImageFile = image;
      });
      predict(selectedImageFile);
    }
}

The predict function is where we convert our image file into a list of bytes and run it through the interpreter using the input and output options. The result will be the same kind of prediction you would get when running the model in your Jupyter notebook or on any other platform.

var bytes = await imageToByteListFloat(imageFile, imageDim);
var results = await interpreter.run(
    localModelName: "plant_ai_model",
    inputOutputOptions: FirebaseModelInputOutputOptions(
        [
            FirebaseModelIOOption(
                FirebaseModelDataType.FLOAT32,
                inputDims
            )
        ],
        [
            FirebaseModelIOOption(
                FirebaseModelDataType.FLOAT32,
                outputDims
            )
        ]
    ),
    inputBytes: bytes
);
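The imageToByteListFloat helper is not shown in this article, so here is a rough Python sketch of the general idea: flatten the pixels, turn each channel into a float, and pack the floats as 4-byte float32 values. Scaling by 255 and using little-endian byte order are assumptions on my part. Use whatever preprocessing your model was trained with.

```python
import struct


def pixels_to_float32_bytes(pixels, scale=255.0):
    """Pack rows of (r, g, b) tuples as little-endian float32 values.

    Each channel is divided by `scale` first; scale=255.0 maps 0..255 to 0..1,
    which is an assumption about how the model expects its input.
    """
    values = [channel / scale for row in pixels for pixel in row for channel in pixel]
    return struct.pack(f"<{len(values)}f", *values)


# A dummy 2x2 RGB image standing in for a decoded and resized photo.
image = [
    [(255, 0, 0), (0, 255, 0)],
    [(0, 0, 255), (255, 255, 255)],
]

data = pixels_to_float32_bytes(image)
print(len(data))  # 48, i.e. 2 * 2 pixels * 3 channels * 4 bytes
```

For a real 256x256 image the same arithmetic gives 256 x 256 x 3 x 4 bytes, which should match the input shape declared in inputDims.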

The results of the prediction can then be converted to label text, as specified in the labels file.

if(results != null && results.length > 0){
    for (var i = 0; i < results[0][0].length; i++) {
        if (results[0][0][i] > 0) {
            var confidenceLevel = results[0][0][i] / 2.55 * 100;
            if(confidenceLevel > 0){
              predictions[modelLabels[i]] = confidenceLevel;
            }
        }
    }
    var predictionKeys = predictions.entries.toList();
    predictionKeys.sort((b,a)=>a.value.compareTo(b.value));
    predictions = Map<String, double>.fromEntries(predictionKeys);
}
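The same post-processing idea, pairing each score with its label and sorting from most to least likely, can be sketched in a few lines of Python. The function name rank_predictions, the labels and the scores are made up for illustration, and I have left out the confidence scaling step used in the Dart code.

```python
def rank_predictions(scores, labels):
    """Pair each positive score with its label and sort from highest to lowest score."""
    pairs = [(label, score) for label, score in zip(labels, scores) if score > 0]
    return dict(sorted(pairs, key=lambda pair: pair[1], reverse=True))


# Made-up labels and model outputs.
labels = ["label_a", "label_b", "label_c"]
scores = [0.05, 0.9, 0.0]

ranked = rank_predictions(scores, labels)
print(ranked)  # {'label_b': 0.9, 'label_a': 0.05}
```

The first entry of the result is the most likely class, which is usually what you would display to the user first.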

The prediction results are now available for further use and/or display in your app, all without an internet connection and without sending data from the client to a server.

The code for this article can be found here.

Thanks for reading. 🤝