The Training Example Lie Bracket

pbement.com

33 points by pb1729 3 months ago · 17 comments

Majromax 3 months ago

Wait a second, they define the induced vector field (and consequently Lie bracket) in terms of batch-size 1 SGD:

> In particular, if x is a training example and L(x) is the per-example loss for the training example x, then this vector field is: v^(x)(θ) = -∇_θ L(x). In other words, for a specific training example, the arrows of the resulting vector field point in the direction that the parameters should be updated.

but for the MXResNet example:

> The optimizer is Adam, with the following parameters: lr = 5e-3, betas = (0.8, 0.999)

This changes the direction of the updates, such that I'm not completely sure the intuitive equivalence holds.

If it were just SGD with momentum, then the measured update directions would be a combination of the momentum vector and v1/v2, so {M + v1, M + v2} = {v1, M} + {M, v2} + {v1, v2} (the {M, M} term vanishes). The Lie bracket is no longer "just" a function of the model parameters and the training examples; it's now inherently path dependent.

For Adam, the parameter-wise normalization by the running second-moment estimate (the β2 term) will also slightly change the directions of the updates in a nonlinear way.

The interpretation is also strained with fancier optimizers like Muon; this uses both momentum and (approximate) SVD normalization, so I'm really not sure what to expect.
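To make the point about Adam's per-parameter normalization concrete, here is a minimal sketch of the very first Adam step in plain Python. It uses the lr and betas quoted above; eps = 1e-8 and the two-component gradient are assumptions chosen for illustration, and this is not the article's training code. On the first step the bias-corrected update is approximately lr * sign(g), so its direction differs from the raw gradient.

```python
import math

def adam_first_step(grad, lr=5e-3, betas=(0.8, 0.999), eps=1e-8):
    """Parameter update from the first Adam step, starting from zero state.

    eps=1e-8 is an assumed value; lr and betas are the ones quoted in the thread.
    """
    b1, b2 = betas
    update = []
    for g in grad:
        m = (1 - b1) * g          # first-moment estimate after one step
        v = (1 - b2) * g * g      # second-moment estimate after one step
        m_hat = m / (1 - b1)      # bias correction for t = 1
        v_hat = v / (1 - b2)
        update.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return update

# Hypothetical gradient with components of very different size.
grad = [1.0, 0.01]
sgd_update = [-5e-3 * g for g in grad]   # plain SGD step, same lr
adam_update = adam_first_step(grad)

print("SGD update: ", sgd_update)    # components differ by a factor of 100
print("Adam update:", adam_update)   # components are nearly equal in size
```

With a nonzero β1 and a longer history the update is no longer exactly sign-like, but the normalization is still applied per parameter, which is why the update direction is not simply -∇L.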

  • pb1729OP 3 months ago

    Yeah, this is a good point. IIRC, I wasn't able to get the network to train very well at all with standard SGD. I don't think I thought to try Adam with β1 = 0, I will try it (& recompute brackets) if I get some time.

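    If we have built up a momentum M, then the two orderings are:

    Ordering 1 (v1 first, then v2):

    M' = M + εv1

    θ' = θ + M' = θ + M + εv1

    M'' = M' + εv2(θ') = M + εv1 + ε(v2 + (M + εv1)⋅∇v2)

    Ordering 2 (v2 first, then v1):

    M' = M + εv2

    θ' = θ + M' = θ + M + εv2

    M'' = M' + εv1(θ') = M + εv2 + ε(v1 + (M + εv2)⋅∇v1)

    Then the resulting difference in the final momenta M'' (ordering 1 minus ordering 2, to first order in the Taylor expansion of v1 and v2) is: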

    ε^2*[v1, v2] + ε(M⋅∇)(v2 - v1)

    So there is an extra term which is not actually a Lie bracket itself. I think the bracket can still be informative on its own, but it's definitely no longer the sole component of what happens when order is swapped.

    One other inconsistency that is a little less bad is BatchNorm. Since it needs a whole batch to work, and we're just comparing individual examples, I computed the Lie brackets with the BatchNorm layers in eval mode, not train mode.

    I don't know if there is any relevance of this to Muon, even if so, it would likely be very messy to compute.
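The momentum calculation above can be checked numerically. The sketch below uses two hypothetical linear vector fields v_i(θ) = A_i θ + b_i (an assumption made so that the first-order Taylor expansion is exact), applies the two orderings of the update rule as written in the comment, and compares the difference in final momenta with ε²[v1, v2] + ε(M⋅∇)(v2 - v1), where [v1, v2] = (v1⋅∇)v2 - (v2⋅∇)v1.

```python
def matvec(A, x):
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]

def add(*vs):
    return [sum(c) for c in zip(*vs)]

def sub(u, v):
    return [a - b for a, b in zip(u, v)]

def scale(s, v):
    return [s * c for c in v]

# Hypothetical linear fields v_i(theta) = A_i theta + b_i (illustrative values).
A1, b1 = [[0.0, 1.0], [-1.0, 0.0]], [0.5, 0.0]
A2, b2 = [[2.0, 0.0], [0.0, -1.0]], [0.0, 0.3]

def v1(theta):
    return add(matvec(A1, theta), b1)

def v2(theta):
    return add(matvec(A2, theta), b2)

def two_steps(theta, M, first, second, eps):
    """M' = M + eps*first(theta); theta' = theta + M'; M'' = M' + eps*second(theta')."""
    M1 = add(M, scale(eps, first(theta)))
    theta1 = add(theta, M1)
    return add(M1, scale(eps, second(theta1)))

theta, M, eps = [1.0, 2.0], [0.1, -0.2], 1e-2

# Ordering 1 (v1 then v2) minus ordering 2 (v2 then v1).
diff = sub(two_steps(theta, M, v1, v2, eps), two_steps(theta, M, v2, v1, eps))

# For linear fields: (u . grad) v = A_v u, so [v1, v2] = A2 v1 - A1 v2
# and (M . grad)(v2 - v1) = (A2 - A1) M.
bracket = sub(matvec(A2, v1(theta)), matvec(A1, v2(theta)))
extra = sub(matvec(A2, M), matvec(A1, M))
predicted = add(scale(eps ** 2, bracket), scale(eps, extra))

print("measured: ", diff)
print("predicted:", predicted)
```

For nonlinear fields the two would agree only up to higher-order terms in ε and M.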

    • Majromax 3 months ago

      Well, the "vector field defined by the update attributable to this training sample" is a well-defined thing (even if it's not just the gradient of loss with respect to parameters), so that part translates.

      However, what's harder to interpret is how this field transports with respect to θ, since the momentum vector and θ are themselves inextricably linked. If you somehow arrived at a different θ, then you'd have a different momentum. (On the gripping hand, the bracket is a construct of infinitesimals, maybe that doesn't matter.)

thaumasiotes 3 months ago

> An ideal machine learning model would not care what order training examples appeared in its training process. From a Bayesian perspective, the training dataset is unordered data and all updates based on seeing one additional example should commute with each other.

One of Andrew Gelman's favorite points to make about science 'as practiced' is that researchers fail to behave this way. There's a gigantic bias in favor of whatever information is published first.

  • Ifkaluva 3 months ago

    I think most ML models don’t have this property. Usually it’s assumed that the training samples are “independent and identically distributed”.

    This is the key insight that causes the DQN algorithm to maintain a replay buffer, and randomly sample from that buffer, rather than feed in the training examples as they come, since they would have strong temporal correlation and destabilize learning.

    An easy way to wreck most ML models is to feed the examples in a way that they are correlated. For example, in a vision system to distinguish cats and dogs, first plan to feed in all the cats. Even worse, order the cats so there are minimal changes from one to the next: all the white cats first, and every time finding the most similar cat to the previous one. That model will fail.
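As a small illustration of order dependence in plain SGD, the sketch below trains a one-parameter model on two made-up examples in both orders. The loss 0.5 * (θx - y)², the data, and the learning rate are all assumptions chosen so the arithmetic is easy to follow. For this loss the gap between the two orderings equals lr² times the Lie bracket of the two per-example fields v(θ) = -dL/dθ.

```python
def sgd_step(theta, x, y, lr):
    """One SGD step on the per-example loss 0.5 * (theta * x - y) ** 2."""
    return theta - lr * x * (theta * x - y)

def field(theta, x, y):
    """Per-example vector field v(theta) = -dL/dtheta."""
    return -x * (theta * x - y)

def dfield(x):
    """Derivative of v with respect to theta (constant for this loss)."""
    return -x * x

theta, lr = 0.5, 0.1
ex1, ex2 = (1.0, 2.0), (3.0, 1.0)   # hypothetical (x, y) pairs

one_then_two = sgd_step(sgd_step(theta, *ex1, lr), *ex2, lr)
two_then_one = sgd_step(sgd_step(theta, *ex2, lr), *ex1, lr)
gap = one_then_two - two_then_one

# [v1, v2] = v1 * dv2/dtheta - v2 * dv1/dtheta in one dimension.
bracket = field(theta, *ex1) * dfield(ex2[0]) - field(theta, *ex2) * dfield(ex1[0])

print("gap between orderings:", gap)
print("lr^2 * bracket:       ", lr ** 2 * bracket)
```

The gap is second order in the learning rate, which is why a single swap is usually harmless while a long, strongly correlated ordering can still push the model somewhere quite different.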

measurablefunc 3 months ago

Eventually ML folks will discover fiber bundles.

  • Y_Y 3 months ago

    But what bastard "new" name will they give them?

  • esafak 3 months ago

    Sooner if you explain why.

    • senderista 3 months ago

      Off the top of my head, connections on fiber bundles (which define a notion of "parallel transport" of points in the total space, allowing you to "lift" curves from the base space to the total space) are more general than Riemannian metrics, so maybe there are some ML concepts that can be naturally represented by a connection on a principal bundle but not by a metric on a Riemannian manifold? At least this approach has been useful in gauge theory; there must be enough theoretical physicists working in ML that someone would have tried to apply fiber bundle concepts.

eden-u4 3 months ago

I don't understand the RMS table, shouldn't it be non-commutative? E.g. "example 0 vs 1"'s RMS != "example 1 vs 0"'s RMS? Which doesn't seem to be the case for the checkpoints I checked.

  • pb1729OP 3 months ago

    If I understand your question correctly, the answer is that not only are the Lie brackets non-commutative, they're anti-commutative (swapping the order negates the bracket). But this ironically means they end up having the same RMS, because the squaring part of the RMS gets rid of the sign.
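A tiny sketch of why the RMS table is symmetric: negating every component of a vector leaves its RMS unchanged. The bracket values here are made up for illustration.

```python
import math

def rms(v):
    """Root mean square of the components of v."""
    return math.sqrt(sum(c * c for c in v) / len(v))

# Made-up components standing in for the bracket of example 0 with example 1.
bracket_01 = [0.3, -1.2, 0.0, 2.5]
# Anti-commutativity: swapping the order negates every component.
bracket_10 = [-c for c in bracket_01]

print(rms(bracket_01) == rms(bracket_10))
```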

avmich 3 months ago

https://en.wikipedia.org/wiki/Leigh_Brackett has nothing to do with it...

willrshansen 3 months ago

Was hoping for a tournament bracket of best lies found in training data :(

E-Reverance 3 months ago

Could this be used for batch filtering?

  • measurablefunc 3 months ago

    Lie brackets are bilinear, so whatever you do per example automatically carries over to sums: the bracket for the batch is just the sum of the pairwise brackets for elements in the batch, i.e. {a + b + c, d} = {a, d} + {b, d} + {c, d}. Similarly for the second component.

    • thaumasiotes 3 months ago

      > Similarly for the second component.

      Hmm.

      {a + b, c + d} = {a, c + d} + {b, c + d} = {a, c} + {a, d} + {b, c} + {b, d}.

      {a + b + c, x + y + z} = {a, x + y + z} + {b, x + y + z} + {c, x + y + z} = (a sum of nine direct brackets).

      This doesn't look like it will scale well.
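The bilinearity argument can be checked on a toy case. The sketch below assumes linear vector fields u(θ) = Aθ, for which the bracket [Aθ, Bθ] = (u⋅∇)v - (v⋅∇)u is the linear field (BA - AB)θ; the matrices are arbitrary illustrative values. It confirms that the bracket of two summed batches of three equals the sum of the nine pairwise brackets, so the batch-level bracket can be computed once on the summed fields, and the pairwise expansion is only needed if per-pair information is wanted.

```python
from itertools import product

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

def madd(A, B):
    return [[A[i][j] + B[i][j] for j in range(2)] for i in range(2)]

def msub(A, B):
    return [[A[i][j] - B[i][j] for j in range(2)] for i in range(2)]

def msum(matrices):
    total = [[0.0, 0.0], [0.0, 0.0]]
    for M in matrices:
        total = madd(total, M)
    return total

def bracket(A, B):
    """Matrix of the field [A theta, B theta] = (BA - AB) theta."""
    return msub(matmul(B, A), matmul(A, B))

# Two hypothetical batches of three linear fields each.
batch_1 = [[[1.0, 2.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]], [[2.0, 0.0], [1.0, 3.0]]]
batch_2 = [[[0.0, 1.0], [2.0, 0.0]], [[1.0, 1.0], [0.0, 2.0]], [[3.0, 0.0], [1.0, 1.0]]]

pairs = list(product(batch_1, batch_2))
pairwise_sum = msum(bracket(A, X) for A, X in pairs)       # nine brackets
batch_bracket = bracket(msum(batch_1), msum(batch_2))      # one bracket

print(len(pairs), batch_bracket == pairwise_sum)
```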
