An introduction to simple classification: creating a neural network with Brainjs to do it
Nam Phạm
Posted on October 24, 2021
An introduction
I wrote this article especially for my students. Many of them have heard of topics such as machine learning, deep learning and classification, but still haven't figured out how to put them into practice. The ideas, the math, the platforms to run things on, the languages and the library usage all make it hard to learn. Deep learning takes time to learn and is a very broad topic in general. So in this article I want to show you how to do a classification task using a deep learning technique called a neural network, to give you a taste of how it is done in general.
So what is classification? In classification, you are given an input and your job is to tell which of some known types the input belongs to. For example, in this article you are given the measurements of an iris flower (its sepal length, sepal width, petal length and petal width) and you need to tell which variety the flower is (setosa, versicolor or virginica).
The ideas
How can we do that? Basically, we build a function that takes the above measurements and outputs the type of the iris flower. Such a function is hard to write by hand with classical programming techniques, and that is where we turn to neural networks, a deep learning technique. The neural network plays the role of that function. We train it on the measurements of the iris flowers we have collected, each paired with a label giving the flower's variety. After training, the network can classify a measurement it has never seen by interpolating from those examples.
Thus we have the following steps:
- Collect data and the corresponding labels
- Build a neural network
- Train the neural network on the collected dataset
- Verify the results of the neural network
- Use the neural network in practice
This article uses the iris flower dataset at https://www.kaggle.com/arshid/iris-flower-dataset
How do we create such a neural network? There are libraries dedicated to deep learning, such as TensorFlow and PyTorch. However, they are mainly used from Python and larger models can demand capable hardware, so they are less convenient for people whose main programming language is JavaScript. That is why this article uses Brainjs, a library for creating simple neural networks in JavaScript that can use the GPU for training, with the GPU.js library as its foundation.
Before we get into using brainjs to create and train neural networks we need to take a look at our dataset.
sepal_length | sepal_width | petal_length | petal_width | species |
---|---|---|---|---|
5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa |
4.9 | 3 | 1.4 | 0.2 | Iris-setosa |
4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
5 | 3.6 | 1.4 | 0.2 | Iris-setosa |
7 | 3.2 | 4.7 | 1.4 | Iris-versicolor |
6.4 | 3.2 | 4.5 | 1.5 | Iris-versicolor |
6.9 | 3.1 | 4.9 | 1.5 | Iris-versicolor |
5.5 | 2.3 | 4 | 1.3 | Iris-versicolor |
6.5 | 2.8 | 4.6 | 1.5 | Iris-versicolor |
5.7 | 2.8 | 4.5 | 1.3 | Iris-versicolor |
6.3 | 3.3 | 6 | 2.5 | Iris-virginica |
5.8 | 2.7 | 5.1 | 1.9 | Iris-virginica |
7.1 | 3 | 5.9 | 2.1 | Iris-virginica |
6.3 | 2.9 | 5.6 | 1.8 | Iris-virginica |
6.5 | 3 | 5.8 | 2.2 | Iris-virginica |
7.6 | 3 | 6.6 | 2.1 | Iris-virginica |
4.9 | 2.5 | 4.5 | 1.7 | Iris-virginica |
7.3 | 2.9 | 6.3 | 1.8 | Iris-virginica |
As you can see, the recorded tuple (5.1, 3.5, 1.4, 0.2) is labeled Iris-setosa, while (7, 3.2, 4.7, 1.4) is Iris-versicolor and (6.3, 3.3, 6, 2.5) is Iris-virginica. Our function, in this case the neural network, should be able to tell which variety an iris flower is for an arbitrary input tuple.
Before we dive into how to create such a network, we have to understand the form of the input we feed to the network and the output we get from it. The input is clearly a tuple of 4 numbers, but what about the output? We first number the labels Iris-setosa, Iris-versicolor and Iris-virginica as 0, 1 and 2 respectively. You may think that our function should output these numbers, but no. Each number is actually a slot in the output tuple, and the value in that slot indicates the probability that the input belongs to the corresponding variety. So the input (5.1, 3.5, 1.4, 0.2) should be mapped to the output (1, 0, 0), because it is 100% setosa and 0% for the others. So we have to transform our data into something like this:
sepal_length | sepal_width | petal_length | petal_width | Iris-setosa | Iris-versicolor | Iris-virginica |
---|---|---|---|---|---|---|
5.1 | 3.5 | 1.4 | 0.2 | 1 | 0 | 0 |
4.9 | 3 | 1.4 | 0.2 | 1 | 0 | 0 |
4.7 | 3.2 | 1.3 | 0.2 | 1 | 0 | 0 |
4.6 | 3.1 | 1.5 | 0.2 | 1 | 0 | 0 |
5 | 3.6 | 1.4 | 0.2 | 1 | 0 | 0 |
7 | 3.2 | 4.7 | 1.4 | 0 | 1 | 0 |
6.4 | 3.2 | 4.5 | 1.5 | 0 | 1 | 0 |
6.9 | 3.1 | 4.9 | 1.5 | 0 | 1 | 0 |
5.5 | 2.3 | 4 | 1.3 | 0 | 1 | 0 |
6.5 | 2.8 | 4.6 | 1.5 | 0 | 1 | 0 |
5.7 | 2.8 | 4.5 | 1.3 | 0 | 1 | 0 |
6.3 | 3.3 | 6 | 2.5 | 0 | 0 | 1 |
5.8 | 2.7 | 5.1 | 1.9 | 0 | 0 | 1 |
7.1 | 3 | 5.9 | 2.1 | 0 | 0 | 1 |
6.3 | 2.9 | 5.6 | 1.8 | 0 | 0 | 1 |
6.5 | 3 | 5.8 | 2.2 | 0 | 0 | 1 |
7.6 | 3 | 6.6 | 2.1 | 0 | 0 | 1 |
4.9 | 2.5 | 4.5 | 1.7 | 0 | 0 | 1 |
7.3 | 2.9 | 6.3 | 1.8 | 0 | 0 | 1 |
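The transformation above is commonly called one-hot encoding. A minimal sketch in JavaScript, assuming the three label strings from the Kaggle file in the slot order used by the table; `SPECIES`, `oneHot` and `encodeRow` are hypothetical helpers introduced only for illustration:

```javascript
// The three classes, in the slot order used by the table above (hypothetical constant).
const SPECIES = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"];

// Turn a label into a vector with 1 in the label's slot and 0 everywhere else.
function oneHot(label, labels = SPECIES) {
  const index = labels.indexOf(label);
  if (index === -1) {
    throw new Error(`Unknown label: ${label}`);
  }
  return labels.map((_, i) => (i === index ? 1 : 0));
}

// Map a raw row [sepal_length, sepal_width, petal_length, petal_width, species]
// to the 7-column form: the four measurements followed by the one-hot label.
function encodeRow(row, labels = SPECIES) {
  const measurements = row.slice(0, 4);
  return [...measurements, ...oneHot(row[4], labels)];
}

console.log(oneHot("Iris-versicolor")); // → [0, 1, 0]
console.log(encodeRow([5.1, 3.5, 1.4, 0.2, "Iris-setosa"])); // → [5.1, 3.5, 1.4, 0.2, 1, 0, 0]
```

Keeping the label order in a single array means the same list can later turn the network's output slots back into variety names.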
And now we can train our network.
Brainjs
Brainjs is a JS library that allows users to create, train and reuse neural networks. Brainjs can run in the browser, and this article focuses on training a neural network in the browser. You should have Firefox or Google Chrome installed to run the example.
Understand how to work with Brainjs
Prepare the data
The data is a JS array whose elements are the rows of the dataset, and each row must be in the form of
{
input: [inputNumber0, inputNumber1, inputNumber2, ..., inputNumberM],
output: [outputNumber0, outputNumber1, outputNumber2, ..., outputNumberN]
}
For example, the row
sepal_length | sepal_width | petal_length | petal_width | Iris-setosa | Iris-versicolor | Iris-virginica |
---|---|---|---|---|---|---|
5.1 | 3.5 | 1.4 | 0.2 | 1 | 0 | 0 |
will be
{
input: [5.1, 3.5, 1.4, 0.2],
output: [1, 0, 0]
}
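Every row must have the same number of inputs and outputs, so it can help to check the array before training. A minimal sketch, assuming the 4-input/3-output iris layout; `checkTrainingData` is a hypothetical helper, not part of Brainjs:

```javascript
// Check that every item has the { input, output } form with the expected sizes.
// Returns a list of problems; an empty list means the data looks usable.
function checkTrainingData(data, inputSize = 4, outputSize = 3) {
  const problems = [];
  data.forEach((item, i) => {
    if (!Array.isArray(item.input) || item.input.length !== inputSize) {
      problems.push(`row ${i}: input should be an array of ${inputSize} numbers`);
    } else if (!item.input.every(Number.isFinite)) {
      problems.push(`row ${i}: input contains a non-numeric value`);
    }
    if (!Array.isArray(item.output) || item.output.length !== outputSize) {
      problems.push(`row ${i}: output should be an array of ${outputSize} numbers`);
    } else if (!item.output.every(Number.isFinite)) {
      problems.push(`row ${i}: output contains a non-numeric value`);
    }
  });
  return problems;
}

const sample = [
  { input: [5.1, 3.5, 1.4, 0.2], output: [1, 0, 0] },
  { input: ["7", 3.2, 4.7, 1.4], output: [0, 1, 0] }, // "7" is a string, so it is reported
];
console.log(checkTrainingData(sample));
// → ["row 1: input contains a non-numeric value"]
```

String values like "7" are a typical leftover of converting a CSV file, which is why the check uses `Number.isFinite` rather than only looking at array lengths.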
Create a neural network
We create a neural network in Brainjs using the following code
let net = new brain.NeuralNetwork({
binaryThresh: 0.5,
hiddenLayers: [3, 3, 2],
activation: "sigmoid",
});
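Here, the hiddenLayers parameter determines the number of hidden layers and the number of neurons in each of them: [3, 3, 2] means three hidden layers with 3, 3 and 2 neurons. The sizes of the input and output layers are taken from the training data (4 and 3 here).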
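The activation parameter determines the activation function used by the neurons; Brainjs supports sigmoid, relu, leaky-relu and tanh. In Brainjs's NeuralNetwork the chosen function is applied at every hidden layer and at the output layer.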
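To make hiddenLayers and activation more concrete, here is a plain-JavaScript sketch of how a fully connected network with sigmoid activation turns 4 inputs into 3 outputs. It is a conceptual illustration with random weights, not Brainjs's internal code; `layerSizes`, `randomNetwork` and `forward` are hypothetical names. Assuming 4 inputs and 3 outputs as in the iris data, hiddenLayers: [3, 3, 2] gives the layer sizes [4, 3, 3, 2, 3]:

```javascript
// Sigmoid squashes any number into the range (0, 1).
const sigmoid = (x) => 1 / (1 + Math.exp(-x));

// Input size, then the hidden layers, then the output size.
function layerSizes(inputSize, hiddenLayers, outputSize) {
  return [inputSize, ...hiddenLayers, outputSize];
}

// Random weights and a bias for every neuron in every layer after the input layer.
function randomNetwork(sizes) {
  const layers = [];
  for (let l = 1; l < sizes.length; l++) {
    const neurons = [];
    for (let n = 0; n < sizes[l]; n++) {
      neurons.push({
        weights: Array.from({ length: sizes[l - 1] }, () => Math.random() * 0.4 - 0.2),
        bias: Math.random() * 0.4 - 0.2,
      });
    }
    layers.push(neurons);
  }
  return layers;
}

// Each neuron outputs sigmoid(weighted sum of the previous layer + bias);
// this sketch applies sigmoid at every layer, including the output layer.
function forward(layers, input) {
  let activations = input;
  for (const neurons of layers) {
    const previous = activations;
    activations = neurons.map(({ weights, bias }) =>
      sigmoid(weights.reduce((sum, w, i) => sum + w * previous[i], bias))
    );
  }
  return activations;
}

const sizes = layerSizes(4, [3, 3, 2], 3);
console.log(sizes); // → [4, 3, 3, 2, 3]
const output = forward(randomNetwork(sizes), [5.1, 3.5, 1.4, 0.2]);
console.log(output); // three numbers between 0 and 1; they vary because the weights are random
```

In a trained network the weights and biases come from training rather than `Math.random()`; that is what the next step does.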
Train the network
After creating the network, we can train the network using the following code
net.train(trainingData, {
iterations: 1000,
learningRate: 0.3,
});
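The iterations option sets the maximum number of passes over the training data; training can stop earlier once the error falls below the errorThresh option.
The learningRate option determines how large each update to the network's weights is.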
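To get a feel for what iterations and learningRate control, here is a deliberately tiny sketch. It trains a single sigmoid neuron with plain gradient descent (logistic regression) to tell setosa from the other varieties using petal length alone, on rows from the sample table. It is a simplified stand-in, not how Brainjs trains its multi-layer network (Brainjs uses backpropagation through all layers). `trainNeuron` and `data` are hypothetical names:

```javascript
const sigmoid = (x) => 1 / (1 + Math.exp(-x));

// Petal lengths from the sample table, labeled 1 for setosa and 0 otherwise.
const data = [
  { x: 1.4, y: 1 }, { x: 1.3, y: 1 }, { x: 1.5, y: 1 },
  { x: 4.7, y: 0 }, { x: 4.0, y: 0 }, { x: 6.0, y: 0 }, { x: 5.1, y: 0 },
];

// Plain gradient descent: each iteration is one pass over all rows,
// and learningRate scales how far the weight and bias move on each pass.
function trainNeuron(rows, { iterations, learningRate }) {
  let w = 0;
  let b = 0;
  for (let it = 0; it < iterations; it++) {
    let gradW = 0;
    let gradB = 0;
    for (const { x, y } of rows) {
      // For sigmoid + cross-entropy loss, (prediction - target) is the
      // gradient with respect to the weighted sum w * x + b.
      const error = sigmoid(w * x + b) - y;
      gradW += error * x;
      gradB += error;
    }
    w -= (learningRate * gradW) / rows.length;
    b -= (learningRate * gradB) / rows.length;
  }
  return (x) => sigmoid(w * x + b);
}

const predict = trainNeuron(data, { iterations: 1000, learningRate: 0.3 });
console.log(predict(1.4).toFixed(3), predict(5.0).toFixed(3));
```

With learningRate set to 0 the neuron never changes, however many iterations run, while a much larger value can make the updates overshoot.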
Use the trained network to do classification task
You can use the network for classification by calling
net.run([value0, value1, value2, value3]);
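The output contains one value per variety, in the same order as the output slots used for training (Iris-setosa, Iris-versicolor, Iris-virginica). Each value lies between 0 and 1 and indicates how likely the input is to belong to that variety; with sigmoid outputs the values need not sum to 1.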
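Since the output holds one score per variety, a common way to get a single answer is to pick the slot with the highest score (argmax). A minimal sketch; `SPECIES` and `predictedSpecies` are hypothetical names, and the slot order is assumed to match the one used for training:

```javascript
const SPECIES = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"];

// Return the label whose slot has the highest score, together with that score.
// Works for plain arrays and for typed arrays such as Float32Array.
function predictedSpecies(output, labels = SPECIES) {
  let best = 0;
  for (let i = 1; i < output.length; i++) {
    if (output[i] > output[best]) best = i;
  }
  return { label: labels[best], score: output[best] };
}

// Example with made-up scores; in the page you would pass the result of net.run(...).
console.log(predictedSpecies([0.08, 0.91, 0.12]));
// → { label: "Iris-versicolor", score: 0.91 }
```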
Extract the trained network data
After training the network, you can extract the network data by running
let extracted = net.toJSON()
Reload trained network
With the extracted data, you can recreate the network without training it again:
let net = new brain.NeuralNetwork();
net.fromJSON(extracted);
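In the browser, one way to keep the extracted data between page loads is the Web Storage API (localStorage). A minimal sketch, assuming the JSON from net.toJSON() is small enough for localStorage; `saveNetworkJSON`, `loadNetworkJSON` and the key "iris-net" are hypothetical. The storage object is passed in, so the same code works with any object that has setItem/getItem:

```javascript
// Save the object returned by net.toJSON() under a key, as a JSON string.
function saveNetworkJSON(storage, key, networkJSON) {
  storage.setItem(key, JSON.stringify(networkJSON));
}

// Load it back; returns null if nothing was saved under that key.
function loadNetworkJSON(storage, key) {
  const text = storage.getItem(key);
  return text === null ? null : JSON.parse(text);
}

// In the page (browser console), for example:
//   saveNetworkJSON(localStorage, "iris-net", trained.toJSON());
//   const net = new brain.NeuralNetwork();
//   net.fromJSON(loadNetworkJSON(localStorage, "iris-net"));

// A tiny in-memory stand-in with the same two methods, for trying this outside a browser.
const memoryStorage = {
  data: new Map(),
  setItem(key, value) { this.data.set(key, String(value)); },
  getItem(key) { return this.data.has(key) ? this.data.get(key) : null; },
};
// Placeholder object standing in for real net.toJSON() output.
saveNetworkJSON(memoryStorage, "iris-net", { sizes: [4, 3, 3, 2, 3] });
console.log(loadNetworkJSON(memoryStorage, "iris-net"));
```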
Provided example
You should have a tool such as http-server or Vite installed and know how to use it from the command line. I use Vite here since I also use it for other projects.
Steps
Create a directory for the project
Create an empty directory for the project.
Download and convert the csv data to json
Download the data from the Kaggle link mentioned earlier, convert it with a tool such as csv2json at https://csvjson.com/csv2json, and save the result in your directory as data.json.
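If you prefer not to use an online converter, the file is simple enough to convert yourself. A minimal sketch, assuming a comma-separated file with a header row and no quoted fields (check this against your download); `csvToRows` is a hypothetical helper. Numeric fields are converted to numbers so the network receives numbers rather than strings:

```javascript
// Convert simple CSV text (header row, no quoted commas) into an array of objects.
function csvToRows(text) {
  const lines = text.trim().split(/\r?\n/);
  const headers = lines[0].split(",").map((h) => h.trim());
  return lines
    .slice(1)
    .filter((line) => line.trim() !== "")
    .map((line) => {
      const values = line.split(",").map((v) => v.trim());
      const row = {};
      headers.forEach((h, i) => {
        const n = Number(values[i]);
        // Keep text columns such as species as strings.
        row[h] = values[i] !== "" && !Number.isNaN(n) ? n : values[i];
      });
      return row;
    });
}

const csv = `sepal_length,sepal_width,petal_length,petal_width,species
5.1,3.5,1.4,0.2,Iris-setosa
7,3.2,4.7,1.4,Iris-versicolor`;
console.log(csvToRows(csv)[1]);
// → { sepal_length: 7, sepal_width: 3.2, petal_length: 4.7, petal_width: 1.4, species: "Iris-versicolor" }
```

In Node.js you could then write the result to disk with `require("fs").writeFileSync("data.json", JSON.stringify(csvToRows(text)))`, where `text` is the downloaded CSV content.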
Create index.html
In your directory, create an index.html file with the following code:
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, minimum-scale=1, user-scalable=no, viewport-fit=cover" />
<meta name="apple-mobile-web-app-capable" content="yes" />
<title>Kaggle Iris dataset training</title>
<script src="https://unpkg.com/brain.js@2.0.0-beta.2/dist/brain-browser.min.js"></script>
</head>
<body>
<h1>Kaggle Iris dataset training using brainjs</h1>
<div>
<button onclick="handleClick()">Click to train</button>
</div>
<div>
<textarea id="output" rows="40" cols="80" readonly></textarea>
</div>
<script>
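// Prevents a second click from starting another run while one is in progress.
let running = false;
// The last trained network, kept in a global so it can be used from the browser console.
let trained = null;
async function handleClick() {
if (running) return;
running = true;
try {
const net = train(await getTrainingData());
trained = net;
// Show the trained network's data as JSON.
document.getElementById("output").value = JSON.stringify(net.toJSON(), null, 4);
} finally {
running = false;
}
}
// Loads data.json and converts each row to the { input, output } form Brainjs expects.
async function getTrainingData() {
const rows = await (await fetch("data.json")).json();
return rows.map((o) => ({
// Number() in case the CSV-to-JSON converter kept the values as strings.
input: [Number(o.sepal_length), Number(o.sepal_width), Number(o.petal_length), Number(o.petal_width)],
// One slot per variety: 1 for the flower's species, 0 for the others.
output: [o.species === "Iris-setosa" ? 1 : 0, o.species === "Iris-versicolor" ? 1 : 0, o.species === "Iris-virginica" ? 1 : 0],
}));
}
// Creates the network and trains it synchronously (the page is unresponsive while this runs).
function train(trainingData) {
const net = new brain.NeuralNetwork({
binaryThresh: 0.5,
hiddenLayers: [3, 3, 2],
activation: "sigmoid",
});
net.train(trainingData, {
iterations: 1000,
learningRate: 0.3,
});
return net;
}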
</script>
</body>
</html>
Run a web server from your directory
Fire up a web server in your directory using http-server or Vite.
Click run to train
Open your local web server's address in the browser and click the button. The code downloads the data from the data.json file, transforms it into the Brainjs data form, creates a neural network, trains it on the data and finally writes the trained network into the textarea element as JSON.
Sorry for not implementing a UI to run the classification, but the trained network is stored in the global variable trained. You can easily classify a measurement by calling trained.run in the browser console, for example trained.run([5.1, 3.5, 1.4, 0.2]).
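The workflow at the start also lists verifying the results, which the page does not do. A minimal sketch of one way to do it, assuming you hold out some rows before training. `splitData` keeps every fifth row for testing (a simple deterministic choice; a random shuffle is more common). `accuracy` counts how often the highest-scoring slot of net.run matches the 1 in the expected output. Both names are hypothetical, and `net` is anything with a run method, such as the trained variable:

```javascript
// Hold out every k-th row for testing and train on the rest.
function splitData(data, k = 5) {
  const train = [];
  const test = [];
  data.forEach((item, i) => {
    if (i % k === k - 1) test.push(item);
    else train.push(item);
  });
  return { train, test };
}

// Index of the largest value in an array or typed array.
function argmax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Fraction of test rows where the network's top slot matches the expected one.
function accuracy(net, testData) {
  if (testData.length === 0) return 0;
  let correct = 0;
  for (const { input, output } of testData) {
    if (argmax(net.run(input)) === argmax(output)) correct++;
  }
  return correct / testData.length;
}

// In the page, something like:
//   const { train: trainRows, test: testRows } = splitData(await getTrainingData());
//   const net = train(trainRows);
//   console.log(accuracy(net, testRows));

// A stand-in "network" that only looks at petal length, to show the calculation.
const fakeNet = { run: (input) => (input[2] < 2.5 ? [0.9, 0.1, 0.1] : [0.1, 0.8, 0.3]) };
const testRows = [
  { input: [5.1, 3.5, 1.4, 0.2], output: [1, 0, 0] },
  { input: [7, 3.2, 4.7, 1.4], output: [0, 1, 0] },
  { input: [6.3, 3.3, 6, 2.5], output: [0, 0, 1] },
];
console.log(accuracy(fakeNet, testRows)); // → 0.6666666666666666
```

Measuring accuracy on rows the network never saw during training gives a fairer picture than checking it on the training rows themselves.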
This article doesn't cover every aspect of neural networks and deep learning, but I hope you now know what to do with a neural network, especially when you write JS.
Have fun with Brainjs and have a good day.