116 points | by og_kalu | 10 hours ago
weight = attention(token_query, weight_keys, weight_values).
In other words, they query weight_keys to fetch the weight_values, and mix them to compute each weight on the spot. Increasing model size becomes a matter of adding more weight_keys and weight_values, and incrementally training them.
Simple, clever, and it seems to work well. Beautiful.
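For concreteness, here's a minimal numpy sketch of my reading of that (the names are mine, and I use a plain softmax where the paper has its own normalisation, so treat this as an illustration rather than the authors' code):

    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max())
        return e / e.sum()

    d = 8        # token dimension
    n_kv = 16    # number of key/value parameter tokens

    rng = np.random.default_rng(0)
    weight_keys = rng.normal(size=(n_kv, d))    # learnable
    weight_values = rng.normal(size=(n_kv, d))  # learnable

    def pattention(token_query):
        # weight = attention(token_query, weight_keys, weight_values):
        # match the query against every key, then mix the value
        # parameter tokens accordingly.
        scores = softmax(weight_keys @ token_query)  # (n_kv,)
        return scores @ weight_values                # (d,)

    x = rng.normal(size=d)
    weight = pattention(x)  # an on-the-fly weight vector for this token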
https://www.desmos.com/calculator/3rtqsyapxo
The above assumes the columns of K are normalised, but bear with me. K and V together form a vector database. V holds the payloads, each row a vector of data. K describes the positions of these points in space, on the surface of a hypersphere. The query vector describes the query into the database: its direction picks the point in space being queried, and its magnitude sets the radius of the query. The result is the weighted average of the rows of V, weighted by their distance from the queried point relative to the query radius (with a smooth Gaussian falloff).

A recent paper from Nvidia that I recommend, which derives a significant speedup by normalising vectors to a hypersphere: https://arxiv.org/abs/2410.01131v1
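Here's a toy numpy version of that picture (all names mine, a sketch of the interpretation above rather than anyone's actual code). The keys sit on the unit sphere, and since q.k_i = |q| * cos(angle), the query's magnitude acts like an inverse radius: the larger it is, the tighter the falloff around the queried direction:

    import numpy as np

    rng = np.random.default_rng(1)
    d, n = 3, 5

    K = rng.normal(size=(n, d))
    K /= np.linalg.norm(K, axis=1, keepdims=True)  # rows on the unit hypersphere
    V = rng.normal(size=(n, d))                    # payloads

    def query(q):
        # Smooth falloff in angular distance from the queried direction;
        # approximately Gaussian in the angle when the angle is small.
        w = np.exp(K @ q)
        w /= w.sum()
        return w @ V  # weighted average of the payloads

    direction = np.array([1.0, 0.0, 0.0])
    print(query(0.5 * direction))  # wide radius: blends many payloads
    print(query(8.0 * direction))  # tight radius: dominated by the nearest key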
Can you explain the desmos plot in simple terms?
Disclaimer: these are from memory, which may be entirely wrong.
Total aside, but imagining how many levels of nested functions are present in the calculation of each activation here, and thinking about how regular old differentiation and gradient descent actually manage to train these nested parameters, is truly amazing, in my opinion.
One interesting thing to note: it sounds like model scaling happens on the fly by adding key-value pairs as rows in the K and V matrices of the Pattention layer. That suggests the weights represented by tokens in the first rows may be more important than the weights in later rows. There may be a lot you could do with that ordering of weights in terms of pruning and such.
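A sketch of what that growth step might look like (my illustration, not the paper's code; the zero-init of the new rows is a placeholder, since whether the grown model exactly preserves the old one depends on the attention normalisation used):

    import numpy as np

    rng = np.random.default_rng(2)
    d = 8
    weight_keys = rng.normal(size=(16, d))    # already-trained pairs
    weight_values = rng.normal(size=(16, d))

    def grow(keys, values, n_new):
        # Append n_new fresh key/value rows; the original rows are untouched,
        # which is why the earlier rows may carry the "core" weights.
        return (np.vstack([keys, np.zeros((n_new, keys.shape[1]))]),
                np.vstack([values, np.zeros((n_new, values.shape[1]))]))

    weight_keys, weight_values = grow(weight_keys, weight_values, 16)
    print(weight_keys.shape)  # (32, 8): rows 0-15 are the original weights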
Consider a case of two "experts" or two "value parameter tokens."
The mixture of experts has a "router" network that provides a weight to each expert (through a softmax) conditional on an input. The output is a (sparse) weighted sum of the outputs of the experts.
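Concretely, a toy dense version of that router (numpy, my own naming; a sparse MoE would keep only the top-k gates):

    import numpy as np

    rng = np.random.default_rng(3)
    d = 4

    router = rng.normal(size=(2, d))                  # one row per expert
    experts = [rng.normal(size=(d, d)) for _ in range(2)]

    def moe(x):
        gate = np.exp(router @ x)
        gate /= gate.sum()                            # softmax over experts
        # Dense weighted sum of expert outputs, conditioned on the input.
        return sum(g * (E @ x) for g, E in zip(gate, experts))

    print(moe(rng.normal(size=d)))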
The TokenFormer has an "attention" layer that combines the token with a key to produce a weight for each "value parameter" token. A(B+C) = AB + AC definitionally, so this is like applying a weighted sum of distinct transformations.
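The distributive point, checked numerically (toy sketch; w1 and w2 stand in for the attention weights produced by the key match):

    import numpy as np

    rng = np.random.default_rng(4)
    d = 4
    x = rng.normal(size=d)
    V1 = rng.normal(size=(d, d))  # two "value parameter" transformations
    V2 = rng.normal(size=(d, d))
    w1, w2 = 0.3, 0.7             # attention weights

    # Mixing the values then applying them == applying each, then mixing.
    print(np.allclose(x @ (w1 * V1 + w2 * V2),
                      w1 * (x @ V1) + w2 * (x @ V2)))  # True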
I think the differences are: a) where the non-linearity hits (the description above doesn't consider an activation function), b) this attention softmax is not (necessarily) sparse, c) "mixtral"-style networks replace only the feed-forward components of each layer, and d) extending a "mixtral" approach would require re-training the "router" layers.
It seems like (d) is maybe the nicest feature here... my intuition is that (a) doesn't matter much, (b) is debatable (how closely can a sparse MoE approximate a dense one?), and (c) has probably been tried (I'd guess the feed-forward-only restriction was "more bang for the buck given parameters", not an oversight)...
... I wonder, though, if there might be diminishing returns here (I believe Mixture-of-Experts tends to struggle with imbalanced "winner-take-all" dynamics, since "early" winners get more gradient signal to improve their weights), and how different this would really be from going from 3x7B to 8x7B to 24x7B training (with a "retrain the routing networks" step in between).