Introduction to Apple’s Machine Learning Framework: MLX

Apple’s machine learning research team recently released MLX, a NumPy-like array framework designed for efficient and flexible machine learning on Apple silicon.
machine learning framework
deep learning
Author

Vidyasagar Bhargava

Published

December 24, 2023

Installation

MLX is available on PyPI. To install it, you need a computer with Apple silicon.

pip install mlx

Key Features of MLX

1. Familiar APIs

MLX has a Python API that closely follows NumPy. MLX also has a fully featured C++ API, which closely mirrors the Python API. MLX has higher-level packages like mlx.nn and mlx.optimizers with APIs that closely follow PyTorch to simplify building more complex models.

import mlx.core as mx
a = mx.array([1,2,3,4])
print(a)
array([1, 2, 3, 4], dtype=int32)
print(a.dtype)
mlx.core.int32
b = mx.array([1.0, 2.0, 3.0, 4.0])
print(b.dtype)
mlx.core.float32

2. Lazy computation

Computations in MLX are lazy. That means the outputs of MLX operations are not computed until they are needed.

c = a + b   # c not yet evaluated
mx.eval(c)  # evaluates c
c = a + b
print(c)    # also evaluates c
array([2, 4, 6, 8], dtype=float32)

3. Composable function transformations & Dynamic graph construction

MLX supports composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization.

Computation graphs in MLX are constructed dynamically. Changing the shapes of function arguments does not trigger slow compilations, and debugging is simple and intuitive.

# MLX has standard function transformations like grad() and vmap()

x = mx.array(0.0)
mx.sin(x)
array(0, dtype=float32)
mx.grad(mx.sin)(x)
array(1, dtype=float32)
mx.grad(mx.grad(mx.sin))(x)
array(-0, dtype=float32)

4. Unified memory architecture

A notable difference between MLX and other frameworks is the unified memory model. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without transferring data.

Let’s look at an example:

a = mx.random.normal((100,))
b = mx.random.normal((100,))

Both a and b live in unified memory.

In MLX, you don’t need to move arrays between memory locations for different devices. Instead of moving data, you specify the device (for example the CPU or the GPU) when you perform an operation on the arrays.

mx.add(a, b, stream=mx.cpu)
array([0.0103115, -1.6365, -0.34433, ..., 0.890102, 0.870465, -1.75593], dtype=float32)
mx.add(a, b, stream=mx.gpu)
array([0.0103115, -1.6365, -0.34433, ..., 0.890102, 0.870465, -1.75593], dtype=float32)

Operations that don’t depend on each other, like the two additions above, can run in parallel. Because neither result feeds into the other, the CPU and the GPU can work at the same time.

c = mx.add(a, b, stream=mx.cpu)
d = mx.add(a, c, stream=mx.gpu)

If there are dependencies (meaning one operation depends on the result of another), MLX manages them for you. In the example above, c is computed on the CPU, and d, which depends on c, is computed on the GPU. MLX ensures that the GPU addition waits for the CPU addition to finish before it starts.

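Example:

def fun(a, b, d1, d2):
    # One large matmul, scheduled on d1
    x = mx.matmul(a, b, stream=d1)
    # Many tiny element-wise ops, scheduled on d2
    for _ in range(500):
        b = mx.exp(b, stream=d2)
    return x, b

a = mx.random.uniform(shape=(4096, 512))
b = mx.random.uniform(shape=(512, 4))

# Run the matmul on the GPU and the exp loop on the CPU
x, b_out = fun(a, b, mx.gpu, mx.cpu)
mx.eval(x, b_out)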

The matmul operation is a good fit for the GPU since it is more compute dense. The sequence of exp operations is a better fit for the CPU, since each one is very small and would probably be overhead bound on the GPU.

5. Multi-device

Operations can run on any of the supported devices (currently the CPU and the GPU). The framework is intended to be user-friendly, but still efficient to train and deploy models. The design of the framework itself is also conceptually simple.

mx.default_stream(mx.default_device())
Stream(Device(gpu, 0), 0)

Linear Regression implementation

Let’s implement a simple linear regression example as a starting point.

import mlx.core as mx

num_features = 100
num_examples = 1_000
num_iters = 10_000  # iterations of SGD
lr = 0.01  # learning rate for SGD

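Generate a synthetic dataset: sample the true parameters (w_star), a design matrix (X), and noisy labels (y). Then define the loss, take its gradient, and initialize the parameters (w) to be learned.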

# True parameters
w_star = mx.random.normal((num_features,))

# Input examples (design matrix)
X = mx.random.normal((num_examples, num_features))

# Noisy labels
eps = 1e-2 * mx.random.normal((num_examples,))
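y = X @ w_star + eps

# Loss: half the mean squared error over the dataset
def loss_fn(w):
    return 0.5 * mx.mean(mx.square(X @ w - y))

# Function that returns the gradient of the loss with respect to w
grad_fn = mx.grad(loss_fn)

# Initialize the parameters with small random values
w = 1e-2 * mx.random.normal((num_features,))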

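# Training loop; mx.eval(w) forces the lazily built update to run each iteration
for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad
    mx.eval(w)

# Final loss and distance to the true parameters
loss = loss_fn(w)
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5

print(f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}")
Loss 0.00004, |w-w*| = 0.00317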
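For reference, this is the objective that loss_fn encodes and the gradient that mx.grad derives from it automatically. The symbols n (for num_examples) and \eta (for lr) are introduced here for convenience and do not appear in the code.

```latex
\begin{aligned}
L(w) &= \frac{1}{2n} \sum_{i=1}^{n} \left( x_i^\top w - y_i \right)^2
      = \frac{1}{2n} \lVert Xw - y \rVert_2^2, \\
\nabla_w L(w) &= \frac{1}{n} X^\top (Xw - y), \\
w &\leftarrow w - \eta \, \nabla_w L(w).
\end{aligned}
```

The reported |w-w*| is the Euclidean distance between the learned and the true parameters, so a small value means the loop recovered w_star closely.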

Logistic Regression

Now let’s implement logistic regression.

import time

import mlx.core as mx

num_features = 100
num_examples = 1_000
num_iters = 10_000
lr = 0.1

# True parameters
w_star = mx.random.normal((num_features,))

# Input examples
X = mx.random.normal((num_examples, num_features))

# Labels
y = (X @ w_star) > 0


# Initialize random parameters
w = 1e-2 * mx.random.normal((num_features,))


def loss_fn(w):
    logits = X @ w
    return mx.mean(mx.logaddexp(0.0, logits) - y * logits)


grad_fn = mx.grad(loss_fn)

tic = time.time()
for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad
    mx.eval(w)

toc = time.time()

loss = loss_fn(w)
final_preds = (X @ w) > 0
acc = mx.mean(final_preds == y)

throughput = num_iters / (toc - tic)
print(
    f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} "
    f"Throughput {throughput:.5f} (it/s)"
)
Loss 0.03430, Accuracy 0.99900 Throughput 2802.29167 (it/s)
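The loss above may look unfamiliar, but it is the usual binary cross-entropy written in terms of the logits. With z_i denoting the i-th logit, \sigma the sigmoid function, n the number of examples, and labels y_i in {0, 1} (notation introduced here for convenience), the steps are:

```latex
\begin{aligned}
L(w) &= \frac{1}{n} \sum_{i=1}^{n} \left[ \log\left(1 + e^{z_i}\right) - y_i z_i \right],
        \qquad z_i = x_i^\top w, \\
     &= -\frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log \sigma(z_i)
        + (1 - y_i) \log\left(1 - \sigma(z_i)\right) \right], \\
\nabla_w L(w) &= \frac{1}{n} X^\top \left( \sigma(Xw) - y \right).
\end{aligned}
```

The first line corresponds to mx.logaddexp(0.0, logits) - y * logits, since logaddexp(0, z) equals log(1 + e^z). Writing it this way avoids computing the sigmoid and its logarithm separately, which should be better behaved for logits of large magnitude.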