import mlx.core as mx
a = mx.array([1, 2, 3, 4])
print(a)
array([1, 2, 3, 4], dtype=int32)
Vidyasagar Bhargava
December 24, 2023
MLX is available on PyPI. You need an Apple silicon-based computer to install it.
pip install mlx
1. Familiar APIs
MLX has a Python API that closely follows NumPy. MLX also has a fully featured C++ API, which closely mirrors the Python API. MLX has higher-level packages like mlx.nn and mlx.optimizers with APIs that closely follow PyTorch to simplify building more complex models.
2. Lazy computation
Computations in MLX are lazy. That means the outputs of MLX operations are not computed until they are needed.
array([2, 4, 6, 8], dtype=float32)
3. Composable function transformations & dynamic graph construction
MLX supports composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization.
Computation graphs in MLX are constructed dynamically. Changing the shapes of function arguments does not trigger slow compilations, and debugging is simple and intuitive.
array(0, dtype=float32)
4. Unified memory architecture
A notable difference between MLX and other frameworks is the unified memory model. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without transferring data.
Let’s see an example. Both a and b live in unified memory.
In MLX, you don’t need to move arrays between different memory locations for different devices (like CPU or GPU). Instead of moving data, you specify the device (like CPU or GPU) when you perform an operation on the arrays.
array([0.0103115, -1.6365, -0.34433, ..., 0.890102, 0.870465, -1.75593], dtype=float32)
array([0.0103115, -1.6365, -0.34433, ..., 0.890102, 0.870465, -1.75593], dtype=float32)
If you perform operations that don’t depend on each other (like the two additions of a and b in the example), MLX can run them in parallel. The CPU and GPU can both work simultaneously because there are no dependencies between the operations.
If there are dependencies (meaning one operation depends on the result of another), MLX takes care of managing them. For instance, if you add ‘a’ and ‘b’ on the CPU and then perform another addition on the GPU that depends on the result from the CPU, MLX ensures that the GPU operation waits for the CPU operation to finish before it starts.
Example
The first matmul operation is a good fit for the GPU since it is more compute-dense. The second sequence of operations is a better fit for the CPU, since the operations are very small and would probably be overhead-bound on the GPU.
5. Multi-device
Operations can run on any of the supported devices (currently the CPU and the GPU). The framework is intended to be user-friendly, but still efficient to train and deploy models. The design of the framework itself is also conceptually simple.
Let’s implement a simple linear regression example as a starting point.
Initialize the parameters (w and b) and the hyperparameter (learning_rate).
Let’s implement logistic regression now
import time
import mlx.core as mx
num_features = 100
num_examples = 1_000
num_iters = 10_000
lr = 0.1
# True parameters
w_star = mx.random.normal((num_features,))
# Input examples
X = mx.random.normal((num_examples, num_features))
# Labels
y = (X @ w_star) > 0
# Initialize random parameters
w = 1e-2 * mx.random.normal((num_features,))
def loss_fn(w):
    logits = X @ w
    return mx.mean(mx.logaddexp(0.0, logits) - y * logits)
grad_fn = mx.grad(loss_fn)
tic = time.time()
for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad
    mx.eval(w)
toc = time.time()
loss = loss_fn(w)
final_preds = (X @ w) > 0
acc = mx.mean(final_preds == y)
throughput = num_iters / (toc - tic)
print(
    f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} "
    f"Throughput {throughput:.5f} (it/s)"
)
Loss 0.03430, Accuracy 0.99900 Throughput 2802.29167 (it/s)