Running a JAX Program from Dart Using C++ FFI
Nik L.
Posted on November 21, 2024
🚀 Why Combine Dart and JAX for Machine Learning?
When building applications, selecting the right tools is crucial. You want high performance, easy development, and seamless cross-platform deployment. Popular frameworks offer trade-offs:
- C++ provides speed but can slow down development.
- Dart (with Flutter) is slower but simplifies memory management and cross-platform development.
But here’s the catch: most frameworks lack robust native machine learning (ML) support. This gap exists because these frameworks predate the AI boom. The question is:
How can we efficiently integrate ML into applications?
Common solutions like ONNX Runtime let you run exported ML models inside applications, but they aren't always optimized for CPUs and aren't flexible enough for general-purpose algorithms.
Enter JAX, a Python library that:
- Enables writing optimized ML and general-purpose algorithms.
- Offers platform-agnostic execution on CPUs, GPUs, and TPUs.
- Supports cutting-edge features like autograd and JIT compilation.
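For readers new to JAX, autograd and JIT compilation compose directly. `jax.grad` turns a function into its gradient function, and `jax.jit` compiles the result with XLA. Here is a minimal sketch with a toy quadratic loss (`loss` and `grad_loss` are hypothetical names used only for illustration):

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Toy quadratic loss with its minimum at w = 3.
    return jnp.sum((w - 3.0) ** 2)

# jax.grad builds the derivative of `loss`; jax.jit compiles it with XLA.
grad_loss = jax.jit(jax.grad(loss))

# The gradient is 2 * (w - 3), i.e. [-4.0, -2.0] for w = [1.0, 2.0].
print(grad_loss(jnp.array([1.0, 2.0])))
```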
In this article, we’ll show you how to:
- Write JAX programs in Python.
- Generate XLA (HLO) specifications.
- Deploy optimized JAX code in Dart using C++ FFI.
🧠 What is JAX?
JAX is like NumPy on steroids. Developed by Google, it’s a low-level, high-performance library that makes ML accessible yet powerful.
- Platform Agnostic: Code runs on CPUs, GPUs, and TPUs without modification.
- Speed: Powered by the XLA compiler, JAX optimizes and accelerates execution.
- Flexibility: Perfect for ML models and general algorithms alike.
Here’s an example comparing NumPy and JAX:
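```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy version: allocate an array, then fill it in place.
def assign_numpy():
    a = np.empty(1000000)
    a[:] = 1
    return a

# JAX version: JAX arrays are immutable, so .at[].set() returns an updated
# array; under jax.jit, XLA can usually perform the update in place.
@jax.jit
def assign_jax():
    a = jnp.empty(1000000)
    return a.at[:].set(1)
```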
Benchmarking in Google Colab reveals JAX’s performance edge:
- CPU & GPU: the JIT-compiled JAX version runs faster than NumPy.
- TPU: speed-ups only become noticeable for large workloads, because data-transfer costs dominate small ones.
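To reproduce this kind of comparison yourself, a minimal timing sketch might look like the following. It assumes only `jax` and `numpy` are installed. It is not the author's Colab notebook, and the numbers it prints depend entirely on your hardware and backend. Two details matter when timing JAX. First, the first call to a jitted function includes compilation. Second, JAX dispatches work asynchronously, so `block_until_ready()` is needed to time the actual computation.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np

N = 1_000_000

def assign_numpy():
    a = np.empty(N)
    a[:] = 1
    return a

@jax.jit
def assign_jax():
    a = jnp.empty(N)
    return a.at[:].set(1)

def bench(fn, number=100):
    """Average seconds per call of fn() over `number` calls."""
    return timeit.timeit(fn, number=number) / number

# Warm-up: the first call traces and compiles, so keep it out of the timing.
assign_jax().block_until_ready()

t_numpy = bench(assign_numpy)
# Wait for the asynchronous result so the timing covers the real work.
t_jax = bench(lambda: assign_jax().block_until_ready())

print(f"NumPy: {t_numpy * 1e6:.1f} us/call")
print(f"JAX:   {t_jax * 1e6:.1f} us/call (backend: {jax.default_backend()})")
```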
This flexibility and speed make JAX ideal for production environments where performance is key.
🛠️ Bringing JAX into Production
Cloud Microservices vs. Local Deployment
- Cloud: Containerized Python microservices are great for cloud-based compute.
- Local: Shipping a Python interpreter isn’t ideal for local apps.
Solution: Leverage JAX’s XLA Compilation
JAX translates Python code into HLO (High-Level Optimizer) specifications, which can be compiled and executed using C++ XLA libraries. This enables:
- Writing algorithms in Python.
- Running them natively via a C++ library.
- Integrating with Dart via FFI (Foreign Function Interface).
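You can also inspect the lowering step directly from Python. Here is a minimal sketch. It assumes a recent JAX release where `jax.jit(...).lower(...)` returns a lowered object whose `as_text` accepts `dialect="hlo"`. It binds `z` to `2.0` with `functools.partial` to mirror the constant passed to `jax_to_ir.py` below. It also uses `jax.ShapeDtypeStruct`, so no real data is needed. This is only for inspection and does not replace the serialized proto that the C++ side loads.

```python
import functools

import jax
import jax.numpy as jnp

def fn(x, y, z):
    return jnp.dot(x, y) / z

# Fix z as a compile-time constant, analogous to jax_to_ir.py's --constants.
fn_z2 = functools.partial(fn, z=2.0)

# Abstract inputs: only shapes and dtypes are needed to trace and lower.
spec = jax.ShapeDtypeStruct((2, 2), jnp.float32)
lowered = jax.jit(fn_z2).lower(spec, spec)

# Human-readable HLO text (the default dialect would be StableHLO).
hlo_text = lowered.as_text(dialect="hlo")
print(hlo_text)
```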
✍️ Step-by-Step Integration
1. Generate an HLO Proto
Write your JAX function and export its HLO representation. For example:
```python
import jax.numpy as jnp

def fn(x, y, z):
    return jnp.dot(x, y) / z
```
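To generate the HLO, use the jax_to_ir.py script from the JAX repository (jax/tools/jax_to_ir.py). The --fn flag takes the function's import path, so here fn is assumed to live in jax_example/prog.py: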
```shell
python jax_to_ir.py \
  --fn jax_example.prog.fn \
  --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \
  --constants '{"z": 2.0}' \
  --ir_format HLO \
  --ir_human_dest /tmp/fn_hlo.txt \
  --ir_dest /tmp/fn_hlo.pb
```
Place the resulting files (fn_hlo.txt, the human-readable HLO, and fn_hlo.pb, the serialized proto) in your app’s assets directory.
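Before wiring up C++, it helps to record what `fn` should return for a known input, so that the native run can be checked against it later. Here is a minimal sketch with hypothetical 2×2 test matrices that match the exported `f32[2,2]` shapes. It uses the same constant `z = 2.0`.

```python
import jax.numpy as jnp
import numpy as np

def fn(x, y, z):
    return jnp.dot(x, y) / z

# Hypothetical test inputs with the exported shapes and dtype (f32[2,2]).
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
y = np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32)

# dot(x, y) = [[19, 22], [43, 50]]; divided by z = 2.0 gives
# [[9.5, 11.0], [21.5, 25.0]].
expected = np.asarray(fn(x, y, 2.0))
print(expected)
```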
2. Build a C++ Dynamic Library
Modify JAX’s C++ Example Code
Clone the JAX repository and navigate to jax/examples/jax_cpp.
- Add a main.h header file declaring the function to export (bar is a placeholder that you implement in main.cc):
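```cpp
#ifndef MAIN_H
#define MAIN_H

// C linkage keeps the exported symbol name unmangled ("bar"),
// so Dart's FFI can look it up by name.
extern "C" {
int bar(int foo);
}

#endif  // MAIN_H
```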
- Update the BUILD file to create a shared library:
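```starlark
# Assumes :main is a cc_library target (e.g. main.cc with hdrs = ["main.h"]);
# cc_shared_library statically links its cc_library deps.
cc_shared_library(
    name = "jax",
    deps = [":main"],
    visibility = ["//visibility:public"],
)
```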
Compile with Bazel:
```shell
bazel build examples/jax_cpp:jax
```
You’ll find the compiled libjax.dylib (libjax.so on Linux) under bazel-bin/examples/jax_cpp/.
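Before bringing Dart into the picture, you can check that the library really exports an unmangled `bar` symbol by calling it from Python with `ctypes`. Here is a minimal sketch. `load_int32_fn` and `resolve_library` are hypothetical helper names. The `argtypes`/`restype` pair plays the same role as the `Int32 Function(Int32)` typedef on the Dart side.

```python
import ctypes
import ctypes.util

def resolve_library(name_or_path: str) -> str:
    """Use paths as-is; resolve bare names like "c" via ctypes.util.find_library."""
    if "/" in name_or_path or "\\" in name_or_path:
        return name_or_path
    found = ctypes.util.find_library(name_or_path)
    if found is None:
        raise OSError(f"library {name_or_path!r} not found")
    return found

def load_int32_fn(lib_path: str, symbol: str):
    """Load `symbol` as an int32 -> int32 C function, like `int bar(int foo)`."""
    lib = ctypes.CDLL(lib_path)
    func = getattr(lib, symbol)  # works because extern "C" avoids name mangling
    func.argtypes = [ctypes.c_int32]
    func.restype = ctypes.c_int32
    return func
```

For example, `load_int32_fn("bazel-bin/examples/jax_cpp/libjax.dylib", "bar")(42)` should call into the C++ `bar`. That path assumes Bazel's default bazel-bin layout on macOS.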
3. Connect Dart with C++ Using FFI
Use Dart’s built-in dart:ffi library to call into the C++ library. Create a jax.dart file:
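```dart
import 'dart:ffi';

import 'package:dynamic_library/dynamic_library.dart';

// Native signature of `int bar(int foo)` from main.h, and its Dart counterpart.
typedef BarNative = Int32 Function(Int32 foo);
typedef BarDart = int Function(int foo);

class JAX {
  late final DynamicLibrary dylib;

  // Resolve the exported C symbol once and keep its typed Dart signature.
  late final BarDart _bar = dylib.lookupFunction<BarNative, BarDart>('bar');

  JAX() {
    // Load the 'jax' library with package:dynamic_library's loader.
    dylib = loadDynamicLibrary(libraryName: 'jax');
  }

  int bar(int foo) => _bar(foo);
}
```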
Include the dynamic library in your project directory. Test it with:
```dart
final jax = JAX();
print(jax.bar(42));
```
You’ll see the output from the C++ library in your console.
🎯 Next Steps
With this setup, you can:
- Optimize ML models with JAX and XLA.
- Run powerful algorithms locally.
Potential use cases include:
- Search algorithms (e.g., A*).
- Combinatorial optimization (e.g., scheduling).
- Image processing (e.g., edge detection).
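As an illustration of the image-processing case, here is a minimal Sobel edge-detection sketch in JAX. It is not from the original project. It uses `jax.scipy.signal.convolve2d` with zero padding (`mode="same"`), and `edge_magnitude` is a hypothetical name. Because it is a jitted function over fixed-shape arrays, it could in principle be exported to HLO with the same jax_to_ir.py workflow.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

# Sobel kernels for horizontal and vertical intensity gradients.
SOBEL_X = jnp.array([[-1.0, 0.0, 1.0],
                     [-2.0, 0.0, 2.0],
                     [-1.0, 0.0, 1.0]])
SOBEL_Y = SOBEL_X.T

@jax.jit
def edge_magnitude(image):
    # "same" keeps the output shape equal to the input (zero-padded borders).
    gx = convolve2d(image, SOBEL_X, mode="same")
    gy = convolve2d(image, SOBEL_Y, mode="same")
    return jnp.sqrt(gx ** 2 + gy ** 2)

# Example: a vertical step edge between columns 3 and 4 of an 8x8 image.
img = jnp.zeros((8, 8), dtype=jnp.float32).at[:, 4:].set(1.0)
print(edge_magnitude(img))
```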
JAX bridges the gap between Python-based development and production-level performance, letting ML engineers focus on algorithms without worrying about low-level C++ code.