Elegy
A High Level API for Deep Learning in JAX
Main Features
- 😀 Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to use for most common tasks.
- 💪 Flexible: Elegy provides a PyTorch Lightning-like low-level API that offers maximum flexibility when needed.
- 🔌 Compatible: Elegy supports various frameworks and data sources including Flax & Haiku Modules, Optax Optimizers, TensorFlow Datasets, PyTorch DataLoaders, and more.
Elegy is built on top of Treex and Treeo and reexports their APIs for convenience.
Getting Started | Examples | Documentation
What is included?
- A Model class with an Estimator-like API.
- A callbacks module with common Keras callbacks.
From Treex
- A Module class.
- An nn module with common layers.
- A losses module with common loss functions.
- A metrics module with common metrics.
Installation
Install using pip:
pip install elegy
For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows natively yet.
Quick Start: High-level API
Elegy's high-level API provides a straightforward interface you can use by implementing the following steps:
1. Define the architecture inside a Module:
import jax
import elegy as eg

class MLP(eg.Module):
    @eg.compact
    def __call__(self, x):
        x = eg.Linear(300)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(10)(x)
        return x
2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:
import optax
import elegy as eg

model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
3. Train the model using the fit method:
model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")],
)
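To make the roles of batch_size, steps_per_epoch, and shuffle in the fit call above more concrete, here is a minimal plain-Python sketch of one epoch of mini-batching. The helper epoch_batches and the toy dataset size are hypothetical and are not part of Elegy; in particular, stopping early when the data runs out is an assumption of this sketch, not a statement about how Elegy behaves.

```python
import random


def epoch_batches(num_samples, batch_size, steps_per_epoch, shuffle, rng):
    """Return the index batches for one epoch (hypothetical helper, not Elegy's API)."""
    indices = list(range(num_samples))
    if shuffle:
        rng.shuffle(indices)  # visit the samples in a random order
    batches = []
    for step in range(steps_per_epoch):
        start = step * batch_size
        batch = indices[start:start + batch_size]
        if not batch:
            break  # assumption: stop once the data is exhausted
        batches.append(batch)
    return batches


rng = random.Random(0)
batches = epoch_batches(
    num_samples=1000,  # toy dataset size, chosen for illustration
    batch_size=64,
    steps_per_epoch=200,
    shuffle=True,
    rng=rng,
)
samples_seen = sum(len(batch) for batch in batches)
print(len(batches), samples_seen, len(batches[-1]))
```

In this sketch, steps_per_epoch acts as an upper bound on the number of optimizer steps per epoch, and the final batch can be smaller than batch_size when the dataset size is not a multiple of it.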
Using Flax
To use Flax with Elegy just create a flax.linen.Module and pass it to Model.
import jax
import elegy as eg
import optax
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x

model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
As shown here, Flax Modules can optionally request a training argument to __call__ which will be provided by Elegy / Treex.
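One way a wrapper could support an optional training argument like this is to inspect the callable's signature and pass the flag only when it is requested. The sketch below uses only the standard library; call_with_optional_training and the two toy forward functions are hypothetical, and this is not necessarily how Elegy or Treex implement the feature.

```python
import inspect


def call_with_optional_training(fn, x, training):
    """Call fn(x), passing `training` only if fn's signature names it."""
    params = inspect.signature(fn).parameters
    if "training" in params:
        return fn(x, training=training)
    return fn(x)


def plain_forward(x):
    # Does not request `training`, so it is never passed.
    return x * 2


def mode_aware_forward(x, training: bool):
    # Toy stand-in for a layer that behaves differently at train time.
    return x * 2 if training else x


result_plain = call_with_optional_training(plain_forward, 3, training=True)
result_train = call_with_optional_training(mode_aware_forward, 3, training=True)
result_eval = call_with_optional_training(mode_aware_forward, 3, training=False)
print(result_plain, result_train, result_eval)
```

The same idea applies to the Haiku forward function: the caller does not need to know in advance whether the user's function cares about the training mode.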
Using Haiku
To use Haiku with Elegy do the following:
- Create a forward function.
- Create a TransformedWithState object by feeding forward to hk.transform_with_state.
- Pass your TransformedWithState to Model.
You can also optionally create your own hk.Module and use it in forward if needed. Putting everything together should look like this:
import jax
import elegy as eg
import optax
import haiku as hk

def forward(x, training: bool):
    x = hk.Linear(300)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(10)(x)
    return x

model = eg.Model(
    module=hk.transform_with_state(forward),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
As shown here, forward can optionally request a training argument which will be provided by Elegy / Treex.
Quick Start: Low-level API
Elegy's low-level API lets you explicitly define what goes on during training, testing, and inference. Let's define our own custom Model to implement a LinearClassifier with pure JAX:
1. Define a custom init_step method:
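import math

import jax
import jax.numpy as jnp
import optax
import elegy as eg

class LinearClassifier(eg.Model):
    # use treex's API to declare parameter nodes
    w: jnp.ndarray = eg.Parameter.node()
    b: jnp.ndarray = eg.Parameter.node()

    def init_step(self, key: jnp.ndarray, inputs: jnp.ndarray):
        # number of features per example after flattening (matches test_step)
        features_in = math.prod(inputs.shape[1:])

        self.w = jax.random.uniform(
            key=key,
            shape=[features_in, 10],
        )
        self.b = jnp.zeros([10])

        self.optimizer = self.optimizer.init(self)

        return self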
Here we declared the parameters w and b using Treex's Parameter.node() for pedagogical reasons. Normally you don't have to do this, since you would typically use a sub-Module instead.
2. Define a custom test_step method:
    def test_step(self, inputs, labels):
        # flatten + scale
        inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255

        # forward
        logits = jnp.dot(inputs, self.w) + self.b

        # crossentropy loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # metrics
        logs = dict(
            acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]),
            loss=loss,
        )

        return loss, logs, self
3. Instantiate our LinearClassifier with an optimizer:
model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)
4. Train the model using the fit method:
model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")],
)
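The loss and accuracy computed in test_step can be checked by hand on a tiny batch. The sketch below mirrors those two quantities in plain Python: a softmax cross-entropy against a one-hot target, averaged over the batch, and the fraction of examples whose argmax matches the label. The helper names and the toy logits are illustrative assumptions, not values or functions from Elegy, JAX, or Optax.

```python
import math


def one_hot(label, num_classes):
    """One-hot encode an integer label as a list of floats."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def softmax_cross_entropy(logits, target):
    """Cross-entropy between softmax(logits) and a one-hot target (single example)."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    log_probs = [z - log_sum_exp for z in logits]
    return -sum(t * lp for t, lp in zip(target, log_probs))


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])


# Toy batch: two examples, three classes (illustrative values only).
batch_logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 3.0]]
batch_labels = [0, 1]

per_example = [
    softmax_cross_entropy(z, one_hot(y, 3))
    for z, y in zip(batch_logits, batch_labels)
]
loss = sum(per_example) / len(per_example)  # mean over the batch
acc = sum(argmax(z) == y for z, y in zip(batch_logits, batch_labels)) / len(batch_labels)
print(loss, acc)
```

Here the first example is classified correctly and the second is not, so the accuracy is 0.5, and the second example contributes most of the loss because its logit for the true class is far below the largest logit.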
Using other JAX Frameworks
It is straightforward to integrate other functional JAX libraries with this low-level API. Here is an example with Flax:
from typing import Any, Mapping

import jax
import jax.numpy as jnp
import optax
import elegy as eg
import flax.linen as nn

class LinearClassifier(eg.Model):
    params: Mapping[str, Any] = eg.Parameter.node()
    batch_stats: Mapping[str, Any] = eg.BatchStat.node()
    next_key: eg.KeySeq

    def __init__(self, module: nn.Module, **kwargs):
        self.flax_module = module
        super().__init__(**kwargs)

    def init_step(self, key, inputs):
        self.next_key = eg.KeySeq(key)

        # initialize the Flax variables from the sample inputs
        variables = self.flax_module.init(
            {"params": self.next_key(), "dropout": self.next_key()}, inputs
        )
        self.params = variables["params"]
        self.batch_stats = variables["batch_stats"]

        self.optimizer = self.optimizer.init(self.parameters())

        return self

    def test_step(self, inputs, labels):
        # forward
        variables = dict(
            params=self.params,
            batch_stats=self.batch_stats,
        )
        logits, variables = self.flax_module.apply(
            variables,
            inputs,
            rngs={"dropout": self.next_key()},
            mutable=True,
        )
        self.batch_stats = variables["batch_stats"]

        # loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # logs
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"])
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )

        return loss, logs, self
Examples
Check out the /examples directory for some inspiration. To run an example, first install some requirements:
pip install -r examples/requirements.txt
And then run it normally with python, e.g.:
python examples/flax/mnist_vae.py
Contributing
If you are interested in helping improve Elegy, check out the Contributing Guide.
Sponsors 💚
- Quansight - paid development time
Citing Elegy
BibTeX
@software{elegy2020repository,
title = {Elegy: A High Level API for Deep Learning in JAX},
author = {PoetsAI},
year = 2021,
url = {https://github.com/poets-ai/elegy},
version = {0.8.1}
}