Notes on semantic search with Elixir and Bumblebee

ndrean

Posted on December 13, 2023

Suppose you have a bunch of images and you want to find a specific image on a certain theme.
One way to solve this problem is to describe each image with a caption and perform a full-text search: you find captions based on lexical matches.

With Machine Learning models, you can greatly improve this with semantic search: you look for images whose captions are close in meaning to your search.

It remains to express what "close meaning" stands for and how to compute it. These notes describe how this can be done with the Elixir language, Nx, Axon (the Nx-powered Neural Network library) and Bumblebee, which provides pre-trained Neural Network models.

An overview of the process:

Firstly, you upload images into a bucket. You analyse each image with a model to produce a description of it. This is a "captioning" process, or Image-To-Text. You save the URL and caption (a short descriptive text) into a DB.

Then you record an audio clip and run a Speech-To-Text process to produce a text transcription.

Now that we have a target text, we want to find the captions that best approximate it.
This is where embeddings come into play. We map each text into a well-chosen vector space. We then use an approximate nearest-neighbour algorithm to find the closest neighbours.

For this, you can use HNSWLib. You incrementally build an Index struct from your caption embeddings, and then run a knn_query on this index with the embedded audio transcription as input.
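
In a nutshell, assuming 384-dimensional embeddings held in Nx tensors (caption_embedding and transcription_embedding below are placeholders), the HNSWLib workflow looks like this; the full code is detailed further down:

# build an index for 384-dimensional vectors with the cosine metric
{:ok, index} = HNSWLib.Index.new(:cosine, 384, 200)

# add the caption embeddings one by one
:ok = HNSWLib.Index.add_items(index, caption_embedding)

# find the closest caption(s) for a transcription embedding
{:ok, labels, distances} = HNSWLib.Index.knn_query(index, transcription_embedding, k: 1)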

[Diagram: audio processing and semantic search flow]

Description of the models used and semantic search

Elixir already has a rich and growing Machine Learning ecosystem. We will use pre-trained models powered by Bumblebee.

Choosing which models to use is a difficult task. We simply followed recommendations from existing posts on the subject.

The corresponding models used per task are:

  • Image-To-Text
    We used the model "Salesforce/blip-image-captioning-base" with Bumblebee.Vision.image_to_text. This produces a text caption that describes the image.

  • Speech-To-Text
    We used the model "openai/whisper-small" and Bumblebee.Audio.speech_to_text_whisper. To capture the audio, we used the MediaRecorder API. This produces a text that is a transcription of the audio.

  • Compute embeddings
    We will run a symmetric semantic search since we expect the captions and the audio transcriptions to have roughly the same amount of content. This constrains which pre-trained model to choose to compute vector encodings from strings.
    To transform a text into a vector (a so-called embedding), we used the transformer "sentence-transformers/paraphrase-MiniLM-L6-v2" and Bumblebee.Text.TextEmbedding.text_embedding. You compute the embeddings for each image caption (see the similarity sketch after this list).

  • Semantic search
    We used the Elixir binding for HNSWLib.
    All the caption embeddings are used to build the HNSWLib.Index.
    You also compute the embedding of the audio transcription and use it as the input of HNSWLib.Index.knn_query to find its closest neighbour(s) among the set of caption embeddings.
    This process depends on the metric used. The query returns position(s) (indices) within the Index struct. This is why you need to save either this index or the embedding in order to look up the corresponding image(s) later.
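
As an illustration of what "close in meaning" means numerically, here is a minimal sketch (not taken from the project code) that embeds two example sentences and compares them with a cosine similarity; serving is assumed to be the Bumblebee.Text.TextEmbedding serving built further below.

# the two sentences are just placeholders
%{embedding: a} = Nx.Serving.run(serving, "a dog running on the beach")
%{embedding: b} = Nx.Serving.run(serving, "a puppy playing in the sand")

# cosine similarity = dot product of the unit-normalized vectors
a = Nx.divide(a, Nx.LinAlg.norm(a))
b = Nx.divide(b, Nx.LinAlg.norm(b))
similarity = Nx.dot(a, b) |> Nx.to_number()
# a value close to 1 means "close in meaning"; the cosine distance used by
# HNSWLib is 1 - similarity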

Some code

Firstly, the ML servings are started under the supervision tree as follows:

#Application.ex
def start(_type, _args) do
    children = [
      # Nx serving for image captioning (registered as ImageClassifier)
      {Nx.Serving, serving: App.Image2text.serving(), name: ImageClassifier},
      {Nx.Serving, serving: App.Whisper.serving(), name: Whisper},
      App.TextEmbedding,
      ...
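
The App.Image2text and App.Whisper modules referenced above are not detailed in these notes. A minimal sketch of what they could look like, assuming Bumblebee's image_to_text and speech_to_text_whisper builders and the EXLA compiler (the exact options are assumptions):

defmodule App.Image2text do
  @model "Salesforce/blip-image-captioning-base"

  def serving do
    {:ok, model_info} = Bumblebee.load_model({:hf, @model})
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, @model})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, @model})
    {:ok, generation_config} = Bumblebee.load_generation_config({:hf, @model})

    Bumblebee.Vision.image_to_text(model_info, featurizer, tokenizer, generation_config,
      defn_options: [compiler: EXLA]
    )
  end
end

defmodule App.Whisper do
  @model "openai/whisper-small"

  def serving do
    {:ok, model_info} = Bumblebee.load_model({:hf, @model})
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, @model})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, @model})
    {:ok, generation_config} = Bumblebee.load_generation_config({:hf, @model})

    Bumblebee.Audio.speech_to_text_whisper(model_info, featurizer, tokenizer, generation_config,
      defn_options: [compiler: EXLA]
    )
  end
end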

We instantiate the HNSWLib index within a GenServer, and load the model and tokenizer used to produce embeddings. The transformer maps text into a 384-dimensional vector space. Since this transformer was trained with a cosine metric, we equip the space of embeddings with the same distance.

#Embedding computation and HNSWLib.Index instantiation
defmodule App.TextEmbedding do
  use GenServer
  @indexes "indexes.bin"

  def start_link(_) do
    GenServer.start_link(__MODULE__, {}, name: __MODULE__)
  end

  # load an existing index file or create a new one
  def init(_) do
    upload_dir = Application.app_dir(:app, ["priv", "static", "uploads"])
    File.mkdir_p!(upload_dir)
    path = Path.join([upload_dir, @indexes])
    space = :cosine

    {:ok, index} =
      case File.exists?(path) do
        false ->
          HNSWLib.Index.new(_space = space, _dim = 384, _max_elements = 200)

        true ->
          HNSWLib.Index.load_index(space, 384, path)
      end

    model_info = nil
    tokenizer = nil
    {:ok, {model_info, tokenizer, index}, {:continue, :load}}
  end

  def handle_continue(:load, {_, _, index}) do
    transformer = "sentence-transformers/paraphrase-MiniLM-L6-v2"

    {:ok, %{model: _model, params: _params} = model_info} =
      Bumblebee.load_model({:hf, transformer})

    {:ok, tokenizer} =
      Bumblebee.load_tokenizer({:hf, transformer})

    {:noreply, {model_info, tokenizer, index}}
  end

  # called in Liveview `mount`
  def serve() do
    GenServer.call(__MODULE__, :serve)
  end

  def handle_call(:serve, _, {model_info, tokenizer, index} = state) do
    serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer)
    {:reply, {serving, index}, state}
  end
end

In a LiveView page, you can upload pictures and record audio from the browser. For this, use allow_upload twice: this opens a Channel (WebSocket) to upload the data to the server.

def mount(_, _, socket) do
  {serving, index} = App.TextEmbedding.serve()
  ...

  {:ok,
   socket
   |> assign(serve_embedding: serving, index: index)
   |> allow_upload(:image,
     accept: ~w(image/*),
     auto_upload: true,
     progress: &handle_progress/3,
     max_entries: 1,
     chunk_size: 64_000,
     max_file_size: 5_000_000
   )
   |> allow_upload(:speech,
     accept: :any,
     auto_upload: true,
     progress: &handle_progress/3,
     max_entries: 1
   )}
end

In the rendered HTML, you can have a simple form like below for the image:

<form id="upload-form" phx-change="noop" phx-submit="noop" disabled={!@display}>
  <label class="cursor-pointer">
    <.live_file_input upload={@uploads.image} class="hidden" />
    <img src={"/uploads/#{@image.filename}"} />
  </label>
</form>
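
Note that the phx-change="noop" and phx-submit="noop" bindings above need matching (no-op) event handlers in the LiveView, something like:

# assumed no-op handler for the form bindings above
def handle_event("noop", _params, socket), do: {:noreply, socket}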

The data will be sent to the server to run an Image-To-Text captioning process, and also uploaded into a bucket.

For the audio, we use a "hook". The data will be used by a Speech-To-Text process. We use a button to start and stop the recording.

<audio id="audio" controls></audio>
<form phx-change="noop" class="hidden">
  <.live_file_input upload={@uploads.speech} class="hidden" />
</form>

<p>Please record a phrase. The search for matching images will run automatically</p>
<p>
<button
  id="record"
  class="bg-blue-500 hover:bg-blue-700 text-white font-bold px-4 rounded"
  type="button"
  phx-hook="Audio"
  disabled={@micro_off}
>
  <Heroicons.microphone outline class="w-6 h-6 text-white font-bold group-active:animate-pulse"/>
    <span>Record</span>
</button>

The hook captures the audio with the MediaRecorder API and sends it to the server with the LiveView JS hook primitive this.upload. For example:

// micro.js
export default {
  mounted() {
    let mediaRecorder;
    let audioChunks = [];
    const recordButton = document.getElementById("record");
    const audioElement = document.getElementById("audio");
    const blue = ["bg-blue-500", "hover:bg-blue-700"];
    const pulseGreen = ["bg-green-500", "hover:bg-green-700", "animate-pulse"];

    const _this = this;

    recordButton.addEventListener("click", () => {
      if (mediaRecorder && mediaRecorder.state === "recording") {
        mediaRecorder.stop();
        recordButton.textContent = "Record";
      } else {
        navigator.mediaDevices.getUserMedia({ audio: true }).then((stream) => {
          mediaRecorder = new MediaRecorder(stream);
          mediaRecorder.start();
          recordButton.classList.remove(...blue);
          recordButton.classList.add(...pulseGreen);

          recordButton.textContent = "Stop";

          mediaRecorder.addEventListener("dataavailable", (event) => {
            audioChunks.push(event.data);
          });

          mediaRecorder.addEventListener("stop", () => {
            const audioBlob = new Blob(audioChunks);
            console.log(audioBlob);
            audioElement.src = URL.createObjectURL(audioBlob);

            _this.upload("speech", [audioBlob]);
            audioChunks = [];
            recordButton.classList.remove(...pulseGreen);
            recordButton.classList.add(...blue);
          });
        });
      }
    });
  },
};

//app.js
...
import Audio from "./micro.js";
...
let liveSocket = new LiveSocket("/live", Socket, {
  params: { _csrf_token: csrfToken },
  hooks: { Audio },
});

The picture upload and the audio upload to the server are processed by a handle_progress callback.

In these handle_progress callbacks (one for :image and one for :speech):

  • when you receive an audio file, you copy it to a temporary path and run a transcription Task:
def handle_progress(:speech, entry, socket) when entry.done? do
    socket
    |> consume_uploaded_entry(entry, fn %{path: path} ->
      :ok = File.cp!(path, @tmp_wav)
      {:ok, @tmp_wav}
    end)

    audio_task =
      Task.Supervisor.async(
        App.TaskSupervisor,
        fn ->
          Nx.Serving.batched_run(Whisper, {:file, @tmp_wav})
        end
      )

    {:noreply, assign(socket, audio_ref: audio_task.ref, micro_off: true, speech_spin: true)}
end
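
The @tmp_wav module attribute used above (and the @upload_dir attribute used further below) are not shown in the snippets; they are just paths defined in the LiveView module, for example (assumed values):

# assumed definitions; the actual paths in the project may differ
@upload_dir Path.expand("priv/static/uploads")
@tmp_wav Path.join(@upload_dir, "tmp.wav")
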
  • when you upload a picture, you run a Task in parallel to produce a caption from the image. We transform the temporary file that contains the binary of the image into a tensor with the library Vix.Vips (because we also want to convert any received image into the WEBP format to save space; a sketch of this conversion is given after the snippet below).
def handle_progress(:image, entry, socket) when entry.done? do

    %{
      tensor_img: tensor_img,
      upload_task: upload_task,
      filename: filename
    } =
      socket
      |> consume_uploaded_entry(entry, fn %{path: path} ->
        %{client_type: type} = entry

        file_binary =
          File.read!(path)
        ...
        upload_task =
          App.Upload.run(local_path, filename, type)

        {:ok, image} = Vimage.new_from_file(path)

        {:ok, %Vix.Tensor{data: data, shape: shape, names: names, type: tp}} =
          Vimage.write_to_tensor(image)

        tensor_img =
          Nx.from_binary(data, tp)
          |> Nx.reshape(shape, names: names)

        {:ok,
         %{
           tensor_img: tensor_img,
           upload_task: upload_task,
           filename: filename,
           type: type,
           local_path: local_path
         }}
      end)

    i2t_task =
      Task.Supervisor.async(App.TaskSupervisor, fn ->
        Nx.Serving.batched_run(ImageClassifier, tensor_img)
      end)

    {:noreply,
     socket
     |> assign(
       running: true,
       i2t_ref: i2t_task.ref,
       upload_ref: upload_task.ref,
       db_img: nil,
       display: false
     )
     |> update(:image, fn img ->
       Map.merge(img, %{
         filename: filename,
         type: type,
         tensor: tensor_img,
         local_path: local_path
       })
     end)}
  end

def handle_progress(_name, _entry, socket), do: {:noreply, socket}
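
The WEBP conversion mentioned above is elided ("...") in the snippet. A hypothetical sketch with Vix could look like the following, where Vimage aliases Vix.Vips.Image (as in the snippet) and the filename/path handling is illustrative only:

alias Vix.Vips.Image, as: Vimage

# libvips picks the WEBP encoder from the ".webp" extension of the target path
{:ok, image} = Vimage.new_from_file(path)
filename = Path.basename(entry.client_name, Path.extname(entry.client_name)) <> ".webp"
local_path = Path.join(@upload_dir, filename)
:ok = Vimage.write_to_file(image, local_path)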

These tasks will emit a response that is captured in a handle_info callback.

  • when you receive a response from the upload task, you save the URL into the database.
def handle_info({ref, {:ok, msg}}, %{assigns: assigns} = socket)
      when assigns.upload_ref == ref do
    Process.demonitor(ref, [:flush])
    %{body: %{location: location}} = msg
    image = Map.merge(assigns.image, %{location: location})
    %{headers: {h, w}, db_img: db_img} = assigns
    # save to db
    db_img = handle_image({h, w}, image, db_img)
    # clean the file from the server since we provide a remote path
    File.rm!(Path.join([@upload_dir, assigns.image.filename]))

    {:noreply,
     socket
     |> assign(db_img: db_img, upload_ref: nil)
     |> update(:display, &(!&1))
     |> update(:image, fn img -> Map.merge(img, %{location: location, tensor_img: nil}) end)}
end

  • when you receive the response from the Image-To-Text task, you compute an embedding of the caption and add it to the HNSWLib.Index struct. Since we add embeddings one at a time, the current count of the Index gives a unique index for each embedding, which we save to the database. We can later use this field to look up the image corresponding to a given index.
  def handle_info({ref, %{results: [%{text: caption}]}}, %{assigns: assigns} = socket)
      when assigns.i2t_ref == ref do
    Process.demonitor(ref, [:flush])

    %{index: index, serve_embedding: serving, db_img: db_img} = assigns
    saved_index = Path.expand("priv/indexes.bin")

    with %{embedding: data} <- Nx.Serving.run(serving, caption),
         # compute a normed embedding (cosine case only) from the caption text
         normed_data <- Nx.divide(data, Nx.LinAlg.norm(data)),
         :ok <- HNSWLib.Index.add_items(index, normed_data),
         {:ok, idx} <- HNSWLib.Index.get_current_count(index),
         :ok <- HNSWLib.Index.save_index(index, saved_index) do
      db_img =
        case db_img do
          nil ->
            App.Image.insert!(%App.Image{}, %{idx: idx, caption: caption})

          img ->
            App.Image.update!(img, %{idx: idx, caption: caption})
        end

      {:noreply,
       socket
       |> assign(running: false, db_img: db_img, index: index, i2t_ref: nil)
       |> update(:image, fn img -> Map.merge(img, %{caption: caption, tensor_img: nil}) end)}
    end
  end

  • when you receive a response from the Speech-To-Text task, you also compute an embedding, but this time you run a HNSWLib.Index.knn_query with it.
def handle_info({ref, %{chunks: [%{text: text}]} = _result}, %{assigns: assigns} = socket)
      when assigns.audio_ref == ref do
    Process.demonitor(ref, [:flush])
    File.rm!(@tmp_wav)

    %{serve_embedding: serving, index: index} = assigns
    # compute a normed embedding (cosine case only) from the transcribed text
    # and return an %App.Image{} as the result of the knn query
    with %{embedding: data} <- Nx.Serving.run(serving, text),
         normed_data <- Nx.divide(data, Nx.LinAlg.norm(data)),
         %App.Image{} = result <- handle_knn(normed_data, index) do
      {:noreply,
       assign(socket,
         transcription: String.trim(text),
         micro_off: false,
         speech_spin: false,
         search_result: result,
         audio_ref: nil
       )}
    else
      # record without entries
      {:error, "no entries in index"} ->
        {:noreply,
         assign(socket,
           micro_off: false,
           search_result: nil,
           speech_spin: false,
           audio_ref: nil
         )}

      nil ->
        {:noreply,
         assign(socket,
           transcription: String.trim(text),
           micro_off: false,
           search_result: nil,
           speech_spin: false,
           audio_ref: nil
         )}
    end
end

In the handle_knn function below, we first check that the index is populated. The function HNSWLib.Index.knn_query(index, input, k: k) returns a tuple {:ok, labels, distances} where the last two elements are Nx tensors whose number of columns is the value of the parameter k. From the returned label, we look up in the database the image with the corresponding index.

def handle_knn(_, nil), do: {:error, "no index found"}

def handle_knn(data, index) do
  case HNSWLib.Index.get_current_count(index) do
    {:ok, 0} ->
      {:error, "no entries in index"}

    {:ok, _c} ->
      case HNSWLib.Index.knn_query(index, data, k: 1) do
        {:ok, label, _distance} ->
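          # knn_query labels are 0-based, while the idx saved in the DB is the
          # (1-based) count returned by get_current_count, hence the + 1 below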
          label[0]
          |> Nx.to_flat_list()
          |> hd()
          |> then(&App.Repo.get_by(App.Image, %{idx: &1 + 1}))
      end
  end
end

Note that we look for the closest neighbour (with k: 1). This will always give a response, "the closest neighbour", but it might not be a "close" neighbour. We may need some cut-off distance to exclude unwanted responses, as sketched below.
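
A possible refinement (not part of the project code) is to use the distances returned by knn_query and reject matches above a threshold; the 0.7 cut-off below is an arbitrary assumption that would need tuning:

# hypothetical variant of handle_knn with a distance cut-off
def handle_knn_with_cutoff(data, index, max_distance \\ 0.7) do
  case HNSWLib.Index.knn_query(index, data, k: 1) do
    {:ok, labels, distances} ->
      [pos] = Nx.to_flat_list(labels[0])
      [dist] = Nx.to_flat_list(distances[0])

      # only accept the neighbour if its cosine distance is below the cut-off
      if dist <= max_distance,
        do: App.Repo.get_by(App.Image, %{idx: pos + 1}),
        else: nil
  end
end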
