EGGROLL stands for Evolution Guided General Optimization via Low-rank Learning. The method is illustrated in the topmost figure.
To explain each part of the acronym, we will first give a brief overview of how Evolution Strategies work; for a more complete description of their history and variants, we highly recommend Lilian Weng's blog post on ES.
Evolution Strategies directly optimize the parameters of a neural network: they sample random perturbations of the parameters and shift the parameters towards the perturbations that give the best fitness. Mathematically, OpenAI's ES formulation is:
\[\nabla_{\theta}\mathbb{E}_{\epsilon\sim N(0,I)} F(\theta+\sigma\epsilon) = \frac{1}{\sigma}\mathbb{E}_{\epsilon\sim N(0,I)}\{F(\theta+\sigma\epsilon)\epsilon\}\]
where \(F\) is the fitness function, which measures how well a given set of parameters performs on the task at hand (similar to the reward function in RL), \(\theta\) are the parameters being optimized, and \(\sigma\) is the standard deviation of the noise added to the parameters.
In OpenAI's Evolution Strategies, the noise for each parameter is sampled independently from a normal distribution. In JAX, this can be represented as follows for a standard matrix multiplication, where thread_id is the index of the population member being evaluated:
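In practice the expectation on the right-hand side is approximated by averaging over a finite population. The sketch below is a minimal Monte Carlo estimator of that formula in JAX; the name `es_gradient`, the toy quadratic fitness, and the sample count are illustrative assumptions, not part of EGGROLL. For the fitness \(F(\theta) = -\lVert\theta - t\rVert^2\), the smoothed objective's gradient is exactly \(-2(\theta - t)\), so at \(\theta = 0\), \(t = \mathbf{1}\) the estimate should land near 2 in every coordinate.

```python
import jax
import jax.numpy as jnp


def es_gradient(fitness, theta, sigma, key, num_samples):
    """Monte Carlo estimate of (1/sigma) * E[F(theta + sigma*eps) * eps], eps ~ N(0, I)."""
    # One Gaussian perturbation per population member, same shape as theta.
    eps = jax.random.normal(key, (num_samples,) + theta.shape)
    # Evaluate every perturbed parameter vector in parallel.
    scores = jax.vmap(lambda e: fitness(theta + sigma * e))(eps)
    # Sum F_i * eps_i over the population, then average and divide by sigma.
    return jnp.tensordot(scores, eps, axes=1) / (num_samples * sigma)


# Usage with a hypothetical toy fitness: maximize -||p - target||^2.
target = jnp.ones(3)


def fitness(p):
    return -jnp.sum((p - target) ** 2)


grad = es_gradient(fitness, jnp.zeros(3), sigma=0.1,
                   key=jax.random.PRNGKey(0), num_samples=200_000)
```

The estimate is noisy (its per-coordinate standard error here is roughly 30 divided by the square root of the sample count), which is why ES relies on large populations.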
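```python
import jax


def forward(base_perturbation_key, sigma, parameter, x, thread_id):
    # Derive a unique, reproducible key for this population member.
    key = jax.random.fold_in(base_perturbation_key, thread_id)
    # Full-rank Gaussian perturbation with the same shape as the weight matrix.
    perturbation = jax.random.normal(key, parameter.shape) * sigma
    return x @ (parameter + perturbation).T


# Vectorize over the population: each member gets its own input row and thread_id.
batch_forward = jax.vmap(forward, in_axes=(None, None, None, 0, 0))
```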
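Note that the standard matrix multiplication has now turned into a batched matrix multiplication: every population member needs its own full-size perturbed weight matrix, which must be generated and read from memory. This is extremely inefficient on GPUs for large matrices and large populations.2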
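A back-of-the-envelope count of how much perturbation data each approach has to hold for one forward pass makes the cost concrete. The helper below is a hypothetical illustration (the name `perturbation_bytes` and the float32 default are our assumptions); it compares one dense \(a \times b\) perturbation per member against storing only rank-\(r\) factors of shapes \(a \times r\) and \(b \times r\).

```python
def perturbation_bytes(pop_size, a, b, rank=None, bytes_per_elem=4):
    """Bytes needed to hold one perturbation per population member.

    rank=None -> dense a x b perturbation per member (standard ES).
    rank=r    -> low-rank factors of shapes (a, r) and (b, r) per member.
    """
    per_member = a * b if rank is None else (a + b) * rank
    return pop_size * per_member * bytes_per_elem


dense = perturbation_bytes(1024, 4096, 4096)            # 2**36 bytes = 64 GiB
low_rank = perturbation_bytes(1024, 4096, 4096, rank=1)  # 2**25 bytes = 32 MiB
```

For a population of 1024 and a single 4096 × 4096 float32 layer, the dense perturbations alone occupy 64 GiB, while rank-1 factors need 32 MiB.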
Our approach instead structures the perturbations to be explicitly low-rank. This lets us do one large standard matrix multiplication shared by the whole population, plus, at rank 1, a batched vector-vector multiplication and a batched scalar-vector multiplication (each member's input is dotted with its factor \(B\), and the resulting scalar scales its factor \(A\)). This is extremely fast and scalable, giving the throughput curves in the headline image (less than 10% slower than pure non-LoRA inference). In JAX, this is represented as:
```python
import jax


def forward(base_perturbation_key, sigma, parameter, x, thread_id, rank=1):
    key = jax.random.fold_in(base_perturbation_key, thread_id)
    a, b = parameter.shape  # parameter is (out, in) = (a, b)
    # Sample both low-rank factors at once, stacked along the first axis.
    perturbation = jax.random.normal(key, (a + b, rank))
    B = perturbation[:b]  # b x r
    A = perturbation[b:]  # a x r
    # Equals x @ (parameter + sigma * A @ B.T).T without forming the a x b perturbation.
    return x @ parameter.T + x @ B @ A.T * sigma


batch_forward = jax.vmap(forward, in_axes=(None, None, None, 0, 0))
```
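The vmapped version hides the structure of the computation. As a hedged sketch, the rank-1 case can also be written for a whole population without `vmap`, which makes the three operations described above explicit; it is checked against materializing every dense perturbed matrix. The function names and the factor layout (`A` as N × a, `B` as N × b) are our own assumptions, not code from EGGROLL.

```python
import jax
import jax.numpy as jnp


def lowrank_population_forward(W, X, A, B, sigma):
    """Rank-1 perturbed forward pass for a whole population.

    W: (a, b) shared weights; X: (N, b), one input row per member;
    A: (N, a), B: (N, b), so member i uses W + sigma * outer(A[i], B[i]).
    """
    base = X @ W.T                            # one large standard matmul, (N, a)
    coeff = jnp.einsum("nb,nb->n", X, B)      # batched vector-vector dot products, (N,)
    return base + sigma * coeff[:, None] * A  # batched scalar-vector products, (N, a)


def dense_population_forward(W, X, A, B, sigma):
    """Reference: materialize each member's full perturbed matrix (standard-ES style)."""
    W_pert = W[None] + sigma * jnp.einsum("na,nb->nab", A, B)  # (N, a, b)
    return jnp.einsum("nab,nb->na", W_pert, X)
```

The low-rank path never allocates the (N, a, b) tensor that the dense reference builds.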
In the limit as rank increases,3 the "gradient" estimate is just standard ES with the structured matrix:
\[ \nabla_{\theta}\mathbb{E}_{\epsilon_1, \epsilon_2 \sim N(0,I_{d})} F(\theta+\sigma\epsilon_2 \epsilon_1^T) = \frac{1}{\sigma}\mathbb{E}_{\epsilon_1,\epsilon_2\sim N(0,I_{d})}\{F(\theta+\sigma\epsilon_2 \epsilon_1^T)\epsilon_2 \epsilon_1^T\} \]
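With a finite population of \(N\) members, the right-hand side becomes \(\frac{1}{N\sigma}\sum_i F_i\,\epsilon_{2,i}\epsilon_{1,i}^T\). A minimal sketch, assuming the factors are stacked row-wise into `E2` (N × a) and `E1` (N × b), shows that this sum collapses into a single (a × N) by (N × b) matrix multiplication; the name `lowrank_es_update` is hypothetical.

```python
import jax
import jax.numpy as jnp


def lowrank_es_update(scores, E2, E1, sigma):
    """Monte Carlo estimate of (1/sigma) E[F * eps2 eps1^T] from N rank-1 samples.

    scores: (N,) fitness values F_i; E2: (N, a) rows eps2_i; E1: (N, b) rows eps1_i.
    Returns an (a, b) matrix computed as one (a x N) @ (N x b) matmul.
    """
    n = scores.shape[0]
    # Weight each eps2_i by its fitness, then contract over the population axis.
    return (E2 * scores[:, None]).T @ E1 / (n * sigma)
```

An ascent step would then be `theta + lr * lowrank_es_update(...)` for some learning rate `lr` (again an illustrative assumption).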
Note that although each individual perturbation is low-rank, the expression on the right-hand side is actually high-rank: a fitness-weighted sum of \(N\) rank-1 matrices can have rank up to \(\min(N, a, b)\) for an \(a \times b\) parameter matrix. We fuse this high-rank update directly into the parameters at each update step. In comparison, using ES to directly optimize LoRA matrices still restricts the update to low rank regardless of the population size, which may be enough for LLM reasoning but not sufficient for pretraining or supervised fine-tuning.
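The rank argument can be checked numerically. In this sketch the fitness values are random stand-ins (a hypothetical choice; any generic nonzero scores behave the same), and the shapes are arbitrary. A single member's perturbation has rank 1, while the weighted sum over 64 members reaches the full rank of a 48 × 32 matrix.

```python
import jax
import jax.numpy as jnp

a, b, pop_size = 48, 32, 64
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
E2 = jax.random.normal(k1, (pop_size, a))    # rows eps2_i
E1 = jax.random.normal(k2, (pop_size, b))    # rows eps1_i
scores = jax.random.normal(k3, (pop_size,))  # stand-in fitness values (hypothetical)

single = jnp.outer(E2[0], E1[0])                    # one member's perturbation
update = (E2 * scores[:, None]).T @ E1 / pop_size   # fitness-weighted sum over members

single_rank = int(jnp.linalg.matrix_rank(single))   # rank 1 by construction
update_rank = int(jnp.linalg.matrix_rank(update))   # generically min(pop_size, a, b) = 32
```

A LoRA update \(A_{\text{LoRA}} B_{\text{LoRA}}^T\) with inner dimension \(r\), by contrast, can never exceed rank \(r\), however many population members contribute to it.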