Notes on semantic search with Elixir and Bumblebee
NDREAN
Posted on December 13, 2023
Suppose you have a collection of images and you want to find a specific image on a certain theme.
One way to solve this problem is to describe each image with a caption and perform a full-text search: find captions based on lexical match.
With machine-learning models, you can greatly improve on this with semantic search: you look for images whose captions are close in meaning to your query.
It remains to express what "close meaning" stands for and how to compute it. These notes describe how this can be done with the Elixir language and Nx, Axon (an Nx-powered Neural Network library) and Bumblebee, which provides pre-trained Neural Network models.
An overview of the process:
Firstly, you upload images into a bucket. You analyse each image with a model to produce a description of it. This is a "captioning" process, or Image-To-Text. You save the URL and the caption (a short descriptive text) into a database.
Then you record an audio clip and run a Speech-To-Text process to produce a text transcription.
Now that we have a target text, we want to find the captions that best approximate this text.
This is where embeddings come into play. We map each text to a vector in a vector space designed so that texts with a similar meaning land close together. We then use an approximate nearest-neighbour algorithm to find the closest neighbours.
For this, you can use HNSWLib. You incrementally build an Index struct from the embeddings of your captions, and then run a knn_query on this index with the embedding of the audio transcription as input.
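To make "closest neighbour" concrete, here is a minimal sketch of an exhaustive nearest-neighbour search by cosine similarity, in plain Elixir. The module name Sketch.Cosine, the captions and the 3-dimensional vectors are made up for illustration; real embeddings have hundreds of dimensions, and HNSWLib answers the same question approximately, without comparing the query to every vector.

```elixir
defmodule Sketch.Cosine do
  # Euclidean norm of a list of numbers.
  def norm(v), do: v |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

  # Cosine similarity between two lists of the same length.
  def similarity(a, b) do
    dot = a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
    dot / (norm(a) * norm(b))
  end

  # Exhaustive search: returns the {caption, similarity} pair closest to `query`.
  def nearest(query, captions) do
    captions
    |> Enum.map(fn {caption, vector} -> {caption, similarity(query, vector)} end)
    |> Enum.max_by(fn {_caption, sim} -> sim end)
  end
end

# Toy "embeddings" (hypothetical values).
captions = [
  {"a dog on a beach", [0.9, 0.1, 0.0]},
  {"a plate of pasta", [0.0, 0.2, 0.9]}
]

{caption, _sim} = Sketch.Cosine.nearest([0.8, 0.2, 0.1], captions)
IO.puts(caption)
# prints "a dog on a beach"
```

The exhaustive scan costs one comparison per stored caption, which is why an approximate index becomes attractive as the collection grows.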
Description of the models used and semantic search
Elixir already has a rich and growing ecosystem. We will use pre-trained models powered by Bumblebee.
Choosing which models to use is a difficult task. We simply followed these posts:
- https://dockyard.com/blog/2023/01/11/semantic-search-with-phoenix-axon-bumblebee-and-exfaiss
- https://dockyard.com/blog/2023/03/07/audio-speech-recognition-in-elixir-with-whisper-bumblebee
The corresponding models used per task are:
Image-To-Text
We used the model "Salesforce/blip-image-captioning-base" with Bumblebee.Vision.image_to_text. This produces a text caption that describes the image.
Speech-To-Text
We used the model "openai/whisper-small" with Bumblebee.Audio.speech_to_text_whisper. To capture the audio, we used the MediaRecorder API. This produces a text that is a transcription of the audio.
Compute embeddings
We run a symmetric semantic search, since we expect the captions and the audio transcriptions to have the same amount of content. This guides the choice of the pre-trained model used to compute vector encodings from strings.
To transform a text into a vector (a so-called embedding), we used the transformer "sentence-transformers/paraphrase-MiniLM-L6-v2" with Bumblebee.Text.TextEmbedding.text_embedding. You compute the embedding of each image caption.
Semantic search
We used the Elixir binding for HNSWLib.
All the caption embeddings are used to build the HNSWLib.Index.
You also compute the embedding of the audio transcription and pass it to HNSWLib.Index.knn_query to find the closest neighbour(s) of the audio transcription embedding among the set of caption embeddings.
The result depends on the metric used. The query returns the position(s) (indices) of the neighbours within the Index struct. This is why you need to save either the index position or the embedding alongside each image, to look up the corresponding image(s).
Some code
Firstly, the ML processes are started as follows:
#Application.ex
def start(_type, _args) do
children = [
# Nx servings for image captioning (Image-To-Text) and speech recognition (Whisper)
{Nx.Serving, serving: App.Image2text.serving(), name: ImageClassifier},
{Nx.Serving, serving: App.Whisper.serving(), name: Whisper},
App.TextEmbedding,
...
We instantiate the HNSWLib index in a GenServer, which also loads the embedding model and its tokenizer. The transformer maps a text to a vector in a 384-dimensional space. Since this transformer is trained with a cosine metric, we equip the index with the same distance.
#Embedding computation and HNSWLib.Index instantiation
defmodule App.TextEmbedding do
use GenServer
@indexes "indexes.bin"
def start_link(_) do
GenServer.start_link(__MODULE__, {}, name: __MODULE__)
end
# load the saved index file if it exists, otherwise create a new index
def init(_) do
upload_dir = Application.app_dir(:app, ["priv", "static", "uploads"])
File.mkdir_p!(upload_dir)
path = Path.join([upload_dir, @indexes])
space = :cosine
{:ok, index} =
case File.exists?(path) do
false ->
HNSWLib.Index.new(_space = space, _dim = 384, _max_elements = 200)
true ->
HNSWLib.Index.load_index(space, 384, path)
end
model_info = nil
tokenizer = nil
{:ok, {model_info, tokenizer, index}, {:continue, :load}}
end
def handle_continue(:load, {_, _, index}) do
transformer = "sentence-transformers/paraphrase-MiniLM-L6-v2"
{:ok, %{model: _model, params: _params} = model_info} =
Bumblebee.load_model({:hf, transformer})
{:ok, tokenizer} =
Bumblebee.load_tokenizer({:hf, transformer})
{:noreply, {model_info, tokenizer, index}}
end
# called in Liveview `mount`
def serve() do
GenServer.call(__MODULE__, :serve)
end
def handle_call(:serve, _, {model_info, tokenizer, index} = state) do
serving =
Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer)
{:reply, {serving, index}, state}
end
end
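As a side note on the cosine metric chosen above: the cosine similarity only depends on the direction of the vectors, so it is common to normalise embeddings to unit length, after which a plain dot product gives the cosine similarity. A minimal sketch with plain lists instead of Nx tensors; the module name Sketch.Normalize is made up for illustration, and the distance is taken here as 1 minus the cosine similarity.

```elixir
defmodule Sketch.Normalize do
  # Scale a vector to unit Euclidean length.
  def l2_normalize(v) do
    norm = v |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()
    Enum.map(v, &(&1 / norm))
  end

  def dot(a, b), do: a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()

  # Cosine distance, computed as 1 - cosine similarity of the normalised vectors.
  def cosine_distance(a, b), do: 1.0 - dot(l2_normalize(a), l2_normalize(b))
end

a = [3.0, 4.0]
b = [6.0, 8.0]
c = [-4.0, 3.0]

IO.inspect(Sketch.Normalize.l2_normalize(a))
# prints [0.6, 0.8]

# `b` points in the same direction as `a`: the distance is (numerically) zero.
IO.inspect(abs(Sketch.Normalize.cosine_distance(a, b)) < 1.0e-9)
# prints true

# `c` is orthogonal to `a`: the distance is (numerically) one.
IO.inspect(abs(Sketch.Normalize.cosine_distance(a, c) - 1.0) < 1.0e-9)
# prints true
```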
In a LiveView page, you can upload pictures and record audio from the browser. For this, call allow_upload twice, once per upload name: this will open a Channel (over the WebSocket) to upload data to the server.
def mount(_,_,socket) do
{serving, index} = App.TextEmbedding.serve()
...
socket
|> assign(serve_embedding: serving, index: index)
|> allow_upload(:image,
accept: ~w(image/*),
auto_upload: true,
progress: &handle_progress/3,
max_entries: 1,
chunk_size: 64_000,
max_file_size: 5_000_000
)
|> allow_upload(:speech,
accept: :any,
auto_upload: true,
progress: &handle_progress/3,
max_entries: 1
)
In the rendered HTML, you can have a simple form like below for the image:
<form id="upload-form" phx-change="noop" phx-submit="noop" disabled={!@display}>
<label class="cursor-pointer">
<.live_file_input upload={@uploads.image} class="hidden" />
<img src={"/uploads/#{@image.filename}"} />
</label>
</form>
The data will be sent to the server to run an Image-To-Text captioning process, and also uploaded into a bucket.
For the audio, we use a "hook". The data will be used by a Speech-To-Text process. We use a button to start and stop the recording.
<audio id="audio" controls></audio>
<form phx-change="noop" class="hidden">
<.live_file_input upload={@uploads.speech} class="hidden" />
</form>
<p>Please record a phrase. The search for matching images will run automatically</p>
<p>
<button
id="record"
class="bg-blue-500 hover:bg-blue-700 text-white font-bold px-4 rounded"
type="button"
phx-hook="Audio"
disabled={@micro_off}
>
<Heroicons.microphone outline class="w-6 h-6 text-white font-bold group-active:animate-pulse"/>
<span>Record</span>
</button>
</p>
The hook captures the audio with the MediaRecorder API and sends it to the server with the LiveView JS primitive this.upload. For example:
// micro.js
export default {
  mounted() {
    let mediaRecorder;
    let audioChunks = [];
    const recordButton = document.getElementById("record");
    const audioElement = document.getElementById("audio");
    const blue = ["bg-blue-500", "hover:bg-blue-700"];
    const pulseGreen = ["bg-green-500", "hover:bg-green-700", "animate-pulse"];
    // keep a reference to the hook so that `upload` can be called in the listeners
    const _this = this;

    recordButton.addEventListener("click", () => {
      if (mediaRecorder && mediaRecorder.state === "recording") {
        // second click: stop recording, which triggers the "stop" listener below
        mediaRecorder.stop();
        recordButton.textContent = "Record";
      } else {
        navigator.mediaDevices.getUserMedia({ audio: true }).then((stream) => {
          mediaRecorder = new MediaRecorder(stream);
          mediaRecorder.start();
          // classList.add/remove take individual class names, so spread the arrays
          recordButton.classList.remove(...blue);
          recordButton.classList.add(...pulseGreen);
          recordButton.textContent = "Stop";

          mediaRecorder.addEventListener("dataavailable", (event) => {
            audioChunks.push(event.data);
          });

          mediaRecorder.addEventListener("stop", () => {
            const audioBlob = new Blob(audioChunks);
            console.log(audioBlob);
            audioElement.src = URL.createObjectURL(audioBlob);
            // send the recording to the server through the :speech upload
            _this.upload("speech", [audioBlob]);
            audioChunks = [];
            recordButton.classList.remove(...pulseGreen);
            recordButton.classList.add(...blue);
          });
        });
      }
    });
  },
};
//app.js
...
import Audio from "./micro.js";
...
let liveSocket = new LiveSocket("/live", Socket, {
params: { _csrf_token: csrfToken },
hooks: { Audio },
});
The picture upload and the audio upload are each processed on the server by a handle_progress callback (one clause for :image and one for :speech):
- when you receive an audio file, you copy it to a temporary path and run a transcription Task.
def handle_progress(:speech, entry, socket) when entry.done? do
  # copy the uploaded recording to a known path before the temporary file is removed
  socket
  |> consume_uploaded_entry(entry, fn %{path: path} ->
    :ok = File.cp!(path, @tmp_wav)
    {:ok, @tmp_wav}
  end)

  # run the Whisper serving in a supervised task; the reply arrives in handle_info
  audio_task =
    Task.Supervisor.async(
      App.TaskSupervisor,
      fn ->
        Nx.Serving.batched_run(Whisper, {:file, @tmp_wav})
      end
    )

  {:noreply, assign(socket, audio_ref: audio_task.ref, micro_off: true, speech_spin: true)}
end
- when you upload a picture, you run a Task in parallel to produce a caption from the image. We transform the temporary file that contains the binary of the image into a tensor with the library Vix.Vips (because we also want to convert any received image to the WEBP format to save space).
def handle_progress(:image, entry, socket) when entry.done? do
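%{
tensor_img: tensor_img,
upload_task: upload_task,
filename: filename,
# also bind `type` and `local_path`: both are used in the assigns below
type: type,
local_path: local_path
} =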
socket
|> consume_uploaded_entry(entry, fn %{path: path} ->
%{client_type: type} = entry
file_binary =
File.read!(path)
...
upload_task =
App.Upload.run(local_path, filename, type)
{:ok, image} = Vimage.new_from_file(path)
{:ok, %Vix.Tensor{data: data, shape: shape, names: names, type: tp}} =
Vimage.write_to_tensor(image)
tensor_img =
Nx.from_binary(data, tp)
|> Nx.reshape(shape, names: names)
{:ok,
%{
tensor_img: tensor_img,
upload_task: upload_task,
filename: filename,
type: type,
local_path: local_path
}}
end)
i2t_task =
Task.Supervisor.async(App.TaskSupervisor, fn ->
Nx.Serving.batched_run(ImageClassifier, tensor_img)
end)
{:noreply,
socket
|> assign(
running: true,
i2t_ref: i2t_task.ref,
upload_ref: upload_task.ref,
db_img: nil,
display: false
)
|> update(:image, fn img ->
Map.merge(img, %{
filename: filename,
type: type,
tensor: tensor_img,
local_path: local_path
})
end)}
end
def handle_progress(_name, _entry, socket), do: {:noreply, socket}
Each of these tasks sends its result back to the LiveView process, where it is captured in a handle_info callback.
- when you receive a response from the upload task, you save the URL into the database.
def handle_info({ref, {:ok, msg}}, %{assigns: assigns} = socket)
when assigns.upload_ref == ref do
Process.demonitor(ref, [:flush])
%{body: %{location: location}} = msg
image = Map.merge(assigns.image, %{location: location})
%{headers: {h, w}, db_img: db_img} = assigns
# save to db
db_img = handle_image({h, w}, image, db_img)
# clean the file from the server since we provide a remote path
File.rm!(Path.join([@upload_dir, assigns.image.filename]))
{:noreply,
socket
|> assign(db_img: db_img, upload_ref: nil)
|> update(:display, &(!&1))
|> update(:image, fn img -> Map.merge(img, %{location: location, tensor_img: nil}) end)}
end
- when you receive the response from the Image-To-Text task, you compute an embedding and add it to the HNSWLib.Index struct. In this case, we add one embedding at a time, so the current count of the index read right after the insertion identifies the new entry, and we save it to the database. We can later use this field in the database to find the image that corresponds to a position in the index.
def handle_info({ref, %{results: [%{text: caption}]}}, %{assigns: assigns} = socket)
when assigns.i2t_ref == ref do
Process.demonitor(ref, [:flush])
%{index: index, serve_embedding: serving, db_img: db_img} = assigns
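# save to the same file that App.TextEmbedding.init/1 loads on startup
saved_index =
Path.join(Application.app_dir(:app, ["priv", "static", "uploads"]), "indexes.bin")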
with %{embedding: data} <- Nx.Serving.run(serving, caption),
# compute a normalized embedding (cosine case only) of the caption
normed_data <- Nx.divide(data, Nx.LinAlg.norm(data)),
:ok <- HNSWLib.Index.add_items(index, normed_data),
{:ok, idx} <- HNSWLib.Index.get_current_count(index),
:ok <- HNSWLib.Index.save_index(index, saved_index) do
db_img =
case db_img do
nil ->
App.Image.insert!(%App.Image{}, %{idx: idx, caption: caption})
img ->
App.Image.update!(img, %{idx: idx, caption: caption})
end
{:noreply,
socket
|> assign(running: false, db_img: db_img, index: index, i2t_ref: nil)
|> update(:image, fn img -> Map.merge(img, %{caption: caption, tensor_img: nil}) end)}
end
end
- when you receive a response from the Speech-To-Text process, you also compute an embedding, but then run a HNSWLib.Index.knn_query.
def handle_info({ref, %{chunks: [%{text: text}]} = _result}, %{assigns: assigns} = socket)
when assigns.audio_ref == ref do
Process.demonitor(ref, [:flush])
File.rm!(@tmp_wav)
%{serve_embedding: serving, index: index} = assigns
# compute a normalized embedding (cosine case only) of the transcription
# and return an %App.Image{} as the result of the "knn_query"
with %{embedding: data} <- Nx.Serving.run(serving, text),
normed_data <- Nx.divide(data, Nx.LinAlg.norm(data)),
%App.Image{} = result <- handle_knn(normed_data, index) do
{:noreply,
assign(socket,
transcription: String.trim(text),
micro_off: false,
speech_spin: false,
search_result: result,
audio_ref: nil
)}
else
# no usable index: it is either missing or still empty (both errors from handle_knn)
{:error, _reason} ->
{:noreply,
assign(socket,
micro_off: false,
search_result: nil,
speech_spin: false,
audio_ref: nil
)}
nil ->
{:noreply,
assign(socket,
transcription: String.trim(text),
micro_off: false,
search_result: nil,
speech_spin: false,
audio_ref: nil
)}
end
end
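All the handle_info clauses above rely on the same mechanism: a task started with Task.async (or Task.Supervisor.async) replies to the calling process with a {ref, result} message, where ref is the task's monitor reference. The following stand-alone sketch reproduces this outside of LiveView with a plain receive; the upcased string only stands in for a model result.

```elixir
# Start a task from the current process; it will reply with {task.ref, result}.
task = Task.async(fn -> String.upcase("caption") end)

result =
  receive do
    {ref, reply} when ref == task.ref ->
      # Stop monitoring the finished task and flush its :DOWN message,
      # as done with Process.demonitor/2 in the handle_info clauses.
      Process.demonitor(ref, [:flush])
      reply
  after
    1_000 -> :timeout
  end

IO.inspect(result)
# prints "CAPTION"
```

Storing the ref in the assigns (i2t_ref, upload_ref, audio_ref) is what lets each handle_info clause recognise which task a message belongs to.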
In handle_knn below, we first check that the index is populated. The function HNSWLib.Index.knn_query(index, input, k: k) returns a tuple {:ok, labels, distances} where the last two elements are Nx tensors holding k entries per query vector. The labels are zero-based positions in the index, whereas the idx saved in the database is the element count read after the insertion, hence the + 1 in the lookup. From this label, we look up the image with the corresponding idx in the database.
def handle_knn(_, nil), do: {:error, "no index found"}
def handle_knn(data, index) do
case HNSWLib.Index.get_current_count(index) do
{:ok, 0} ->
{:error, "no entries in index"}
{:ok, _c} ->
case HNSWLib.Index.knn_query(index, data, k: 1) do
{:ok, label, _distance} ->
label[0]
|> Nx.to_flat_list()
|> hd()
|> then(&App.Repo.get_by(App.Image, %{idx: &1 + 1}))
end
end
end
Note that we look for the closest neighbour only (with k=1). This will always give a response, the closest neighbour, but it might not be a "close" neighbour. We may need some cut-off distance to exclude unwanted responses.
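Such a cut-off could look like the sketch below. It assumes that the labels and distances returned by knn_query have already been converted to plain lists (for instance with Nx.to_flat_list), and that the cosine distance is 1 minus the cosine similarity, so 0 means "same direction". The module name Sketch.Cutoff, the sample values and the threshold of 0.5 are arbitrary choices for illustration; a suitable threshold has to be tuned on your own data.

```elixir
defmodule Sketch.Cutoff do
  # Keep only the {label, distance} pairs whose distance does not exceed `max_distance`.
  def within(labels, distances, max_distance) do
    labels
    |> Enum.zip(distances)
    |> Enum.filter(fn {_label, distance} -> distance <= max_distance end)
  end
end

# Hypothetical result of a query with k = 3.
labels = [4, 7, 1]
distances = [0.12, 0.48, 0.91]

IO.inspect(Sketch.Cutoff.within(labels, distances, 0.5))
# prints [{4, 0.12}, {7, 0.48}]

# With a strict threshold, nothing may remain: report "no match" instead of a poor one.
IO.inspect(Sketch.Cutoff.within(labels, distances, 0.05))
# prints []
```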