10 comments

  • cs702 9 hours ago
    The authors factorize every weight matrix with an attention mechanism:

      weight = attention(token_query, weight_keys, weight_values).
    
    In other words, they query weight_keys to fetch the weight_values, and mix them to compute each weight on the spot.

    Increasing model size becomes a matter of adding more weight_keys and weight_values, and incrementally training them.

    Simple, clever, and it seems to work well. Beautiful.
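
    A minimal NumPy sketch of that formula, for anyone who wants to poke at it. It uses a plain scaled softmax, which is an assumption (the paper may normalise the scores differently), and the function name and shapes are made up for illustration:

```python
import numpy as np

def softmax(scores):
    # Numerically stable softmax over the last axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

def attention_as_linear_layer(tokens, weight_keys, weight_values):
    # tokens: (seq_len, d_in)
    # weight_keys: (n_params, d_in), weight_values: (n_params, d_out)
    scores = tokens @ weight_keys.T / np.sqrt(tokens.shape[-1])
    # Each output row is a softmax-weighted mix of the weight_values rows.
    return softmax(scores) @ weight_values  # (seq_len, d_out)

rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 8))
weight_keys = rng.normal(size=(16, 8))
weight_values = rng.normal(size=(16, 32))
out_small = attention_as_linear_layer(tokens, weight_keys, weight_values)

# "Scaling up": append more key/value rows; input and output widths are unchanged.
bigger_keys = np.vstack([weight_keys, rng.normal(size=(16, 8))])
bigger_values = np.vstack([weight_values, rng.normal(size=(16, 32))])
out_big = attention_as_linear_layer(tokens, bigger_keys, bigger_values)
print(out_small.shape, out_big.shape)
```

    Both outputs have shape (4, 32): only the number of parameter rows changed, not the layer's interface.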

    • szcs 8 hours ago
      There is a particularly nice geometric interpretation of attention I just realised recently in a flash of enlightenment, best explained with an interactive Desmos plot (black dot is draggable):

      https://www.desmos.com/calculator/3rtqsyapxo

      The above assumes the columns of K are normalised but bear with me. K and V together form a vector database. V are the payloads, each row containing a vector of data. K describes the position of these points in space, on the surface of a hypersphere. The query vector describes the query into the database: the vector direction describes the point in space that's being queried, the vector magnitude describes the radius of the query. The result is the weighted average of vectors from V, weighted by their distance from the query vector scaled by the query radius (which has a smooth Gaussian falloff). A recent paper from Nvidia I recommend, which derives a significant speedup by normalising vectors to a hypersphere: https://arxiv.org/abs/2410.01131v1
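
      A quick numerical check of that picture, assuming unit-norm keys stored as rows of K and an unscaled softmax over K·q (both assumptions for this sketch). Writing q = r*u with |u| = 1, each logit is q·k = r*(1 - |u - k|^2 / 2), so the softmax weights match a Gaussian in the distance between the query direction u and each key, and the falloff gets sharper as r grows:

```python
import numpy as np

rng = np.random.default_rng(0)

# Keys on the unit sphere (one key per row; an assumption of this sketch).
K = rng.normal(size=(6, 3))
K /= np.linalg.norm(K, axis=1, keepdims=True)

q = rng.normal(size=3)
r = np.linalg.norm(q)   # query magnitude
u = q / r               # query direction

# Ordinary softmax attention weights over the dot products.
logits = K @ q
w_soft = np.exp(logits - logits.max())
w_soft /= w_soft.sum()

# Gaussian weights in the squared distance between u and each key.
d2 = ((K - u) ** 2).sum(axis=1)
w_gauss = np.exp(-r * d2 / 2)
w_gauss /= w_gauss.sum()

weights_match = np.allclose(w_soft, w_gauss)
```

      In this sketch, lengthening q while keeping its direction concentrates the weights on the nearest keys.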

      • jdthedisciple 2 hours ago
        It looks fascinating, but I don't understand it. I haven't yet gone deeply into the theory of attention networks.

        Can you explain the desmos plot in simple terms?

      • liuliu 7 hours ago
        Yeah, I believe this intuition was first introduced by the Neural Turing Machine line of work and later simplified in the AIAYN paper (NTM maintains "external memory", a.k.a. weight_keys, weight_values here).

        Disclaimer: these are from my memory, which can be wrong entirely.

    • anon291 8 hours ago
      I believe there have been studies showing that the attention mechanism allows estimation of gradients for one-shot learning (i.e., based on what you tell the model you want in the input, it will use attention to 'update' the weights of the linear layers to 'learn' new information). This seems to be taking that one step further and just using attention for the weight estimations itself. The key insight here is that by adding more tokens to the weight estimation calculation, you can get more degrees of freedom.

      Total aside, but imagining how many levels of functions are present in the calculation of each activation here, and thinking about how regular old differentiation and gradient descent actually work to train these nested parameters, is truly amazing, in my opinion.

      • cs702 8 hours ago
        Yeah. This thing is "assembling a different transformer" on the spot for each token.

        If one thinks about it for more than a moment, it's kind of incredible that it works.

        • 0-_-0 6 hours ago
          I think the same about regular neural networks
  • goldenshale 4 hours ago
    This is a great idea. Being able to dynamically scale up model sizes as datasets and use cases expand without needing to retrain from scratch could enable a Cambrian explosion of interesting stuff building on top of a Llama type model trained in this way.
  • valine 5 hours ago
    I would like to see a comparison for the inference time compute between a regular transformer and this. I’m assuming token/s is lower since you need to compute the weights of the model for each token prior to the actual attention calculations for the sequence position.
    • logicchains 4 hours ago
      Isn't that figure 5 in the paper? It's for training, not inference, but presumably if training is faster then inference would be too, because they don't increase the dimension of the text tokens when scaling up, which reduces the compute needed for attention. That potentially limits how well the text token attention can keep track of things, though, because it has less space for passing things along.
  • davesque 6 hours ago
    Seems like a big deal. I feel like this could enable a new level of modularity and compatibility between publicly available weight sets, assuming they use similar channel dimensions. Maybe it also provides a nice formalism for thinking about fine tuning, where you could adopt certain heuristics for adding/removing key-value pairs from the Pattention layers.

    One interesting thing to note: sounds like model scaling happens on the fly by adding key-value pairs as rows in the K and V matrices on the Pattention layer. That suggests that weights represented by tokens in the first rows may be more important than weights in later rows. There may be a lot you could do with that ordering of weights in terms of pruning and such.

    • valine 5 hours ago
      Unless I’m reading it wrong I don’t think rows matter. Attention doesn’t take into account sequence position natively, that’s why positional encodings exist.
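
      A small check of the order-invariance point, using plain softmax attention over parameter rows (an assumption; the variable names are illustrative). Shuffling the rows of K and V together leaves the output unchanged, so row position by itself carries no information:

```python
import numpy as np

def softmax(scores):
    # Numerically stable softmax over the last axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(1)
x = rng.normal(size=(3, 8))    # three input tokens
K = rng.normal(size=(10, 8))   # ten key rows
V = rng.normal(size=(10, 5))   # ten matching value rows

# Apply the same random permutation to the rows of K and V.
perm = rng.permutation(10)
out = softmax(x @ K.T) @ V
out_shuffled = softmax(x @ K[perm].T) @ V[perm]

same = np.allclose(out, out_shuffled)
```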
      • davesque 4 hours ago
        I'm talking about the rows in the new K and V matrices introduced by the paper, not rows in the input sequence. The ordering of rows in the new K and V matrices does matter in the sense that rows that appear further down were added later in the training process to add new parameter tokens during scaling. So those newer parameters may represent knowledge that is less fundamental and more about fine tuning on the training set.
  • c0g 6 hours ago
  • mentalically 3 hours ago
    Eventually people will figure out how to nest neural networks in the nodes and edges of an arbitrary graph.
  • davesque 4 hours ago
    Seems like a lot of existing models could be converted to this token parameter representation.
  • ml_thoughts 4 hours ago
    This seems closely related to the "Mixtral" approach of a mixture-of-experts transformer [1]... I'm not claiming the approach is not original, it just helped me understand what was going on.

    Consider a case of two "experts" or two "value parameter tokens."

    The mixture of experts has a "router" network that provides a weight to each expert (through a softmax) conditional on an input. The output is a (sparse) weighted sum of the outputs of the experts.

    The TokenFormer has an "attention" layer that combines the token and a key value to provide a weight to each "value parameter" token. A(B+C) = AB + AC definitionally, so this is like applying a weighted sum of distinct transformations.
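
    To make the A(B+C) = AB + AC remark concrete, here is a sketch with two hypothetical linear "experts" W1 and W2 and fixed mixing weights g (all names and values are illustrative, and any activation function is ignored). Mixing the expert outputs gives the same result as applying one mixed weight matrix:

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.normal(size=(1, 8))    # one input token
W1 = rng.normal(size=(8, 4))   # hypothetical expert 1
W2 = rng.normal(size=(8, 4))   # hypothetical expert 2
g = np.array([0.3, 0.7])       # mixing weights, e.g. from a softmax

# Weighted sum of the two experts' outputs.
mix_of_outputs = g[0] * (x @ W1) + g[1] * (x @ W2)

# One pass through the weighted sum of the two weight matrices.
output_of_mix = x @ (g[0] * W1 + g[1] * W2)

identical = np.allclose(mix_of_outputs, output_of_mix)
```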

    I think the differences are: a) where the non-linearity hits (the above description doesn't consider an activation function), b) this attention softmax is not (necessarily) sparse, c) that "mixtral" networks only replace the feed-forward components of the layer, and d) that extending a "mixtral" approach would require re-training the "router" layers.

    It seems like (d) is maybe the nicest feature here... my intuition would think (a) doesn't matter much, (b) is debatable (how close a sparse-MoE can approximate a dense-MoE), (c) has probably been tried (guessing the ffwd limitation was just "more-bang-for-buck-given-parameters" not an oversight)...

    ... I wonder, though, if there might be diminishing returns here (I believe that Mixture-of-Experts tends to struggle with imbalanced "winner-take-all" dynamics, since "early" winners get more gradient signal to improve their weights) and how different this would have been from going from 3x7B to an 8x7B to a 24x7B training approach (with a "retrain routing networks" step).

    [1] https://arxiv.org/abs/2401.04088

  • logicchains 4 hours ago
    Seems this would naturally translate into a mixture of experts by using a "hard" attention function so that only a fixed number of weight tokens get included in the calculation.
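
    One way such a "hard" attention could look, as a sketch only: keep the k highest-scoring weight tokens and softmax over just those. The top-k rule and the function name are assumptions for illustration, not something taken from the paper:

```python
import numpy as np

def top_k_attention(x, keys, values, k):
    # x: (d_in,), keys: (n, d_in), values: (n, d_out)
    scores = keys @ x
    # Indices of the k largest scores (in no particular order).
    top = np.argpartition(scores, -k)[-k:]
    # Softmax restricted to the selected weight tokens.
    s = scores[top] - scores[top].max()
    w = np.exp(s)
    w /= w.sum()
    return w @ values[top], top

rng = np.random.default_rng(3)
x = rng.normal(size=8)
keys = rng.normal(size=(16, 8))
values = rng.normal(size=(16, 4))

out, chosen = top_k_attention(x, keys, values, k=2)
```

    Only the two selected rows of values contribute to the output here, much like a router picking two experts.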
  • a_wild_dandan 6 hours ago
    This could be revolutionary. The PPL/compute graphs are damning. If the Transformer is a function, then the TokenFormer feels like a higher-order function. Perhaps this approach is a natural direction for producing System Two reasoning? There's so much to digest here...