Transformers Energized: Energy-Based Transformers (EBTs) use gradient descent to gradually predict the next token
A new type of transformer can check its work. Instead of guessing the next output token in one shot like a typical transformer, it starts with a rough version of the token and improves it step by step.
What’s new: Alexi Gladstone and colleagues at University of Virginia, University of Illinois Urbana-Champaign, Amazon, Stanford, and Harvard proposed the Energy-Based Transformer (EBT). Early experiments show that it scales more efficiently than transformers at relatively small sizes.
Energy-based model basics: For a given input context paired with a candidate response (for example, a prompt and potential next token), an energy-based model produces a number called “energy” that indicates how well the candidate fits the context: the lower the energy, the more likely the potential next token is to follow the prompt. During training, the model learns to assign low energy if a context/potential-response pair is very likely and high energy if it’s not.
Key insight: A typical transformer is trained to predict the next token directly, while an energy-based model learns how to score an input text. How would a researcher use an energy-based model to predict the next text token? A naive way would be to measure the energy of an input prompt with a random token, randomly modify the token a number of times, and select the prompt-token combination with the lowest energy. Instead of random modification, a model can use gradient descent repeatedly to compute the change needed to decrease the token’s energy. This process enables the model to refine the token over several steps, ultimately producing a token with low energy (and a high likelihood of following the previous text).
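The contrast between random modification and gradient descent can be sketched in plain Python under strong simplifying assumptions: the candidate token is a small continuous vector, and a hand-written quadratic energy with a known minimum (`target`) stands in for a learned model. All names here (`energy`, `energy_grad`, `random_search`, `gradient_refine`) are hypothetical illustrations, not code from the paper.

```python
import random

def energy(candidate, target):
    """Toy stand-in for a learned energy model (an assumption):
    half the squared distance between the candidate and a fixed target."""
    return 0.5 * sum((c - t) ** 2 for c, t in zip(candidate, target))

def energy_grad(candidate, target):
    """Analytic gradient of the toy energy with respect to the candidate."""
    return [c - t for c, t in zip(candidate, target)]

def random_search(target, steps, rng, scale=0.5):
    """Naive approach: perturb the candidate at random, keep the lowest energy."""
    best = [rng.uniform(-1.0, 1.0) for _ in target]
    best_e = energy(best, target)
    for _ in range(steps):
        trial = [b + rng.gauss(0.0, scale) for b in best]
        trial_e = energy(trial, target)
        if trial_e < best_e:
            best, best_e = trial, trial_e
    return best, best_e

def gradient_refine(target, steps, rng, step_size=0.5):
    """Gradient-descent approach: move the candidate downhill in energy."""
    cand = [rng.uniform(-1.0, 1.0) for _ in target]
    for _ in range(steps):
        grad = energy_grad(cand, target)
        cand = [c - step_size * g for c, g in zip(cand, grad)]
    return cand, energy(cand, target)

rng = random.Random(0)
target = [0.2, -0.7, 0.9, 0.0]  # made-up minimum of the toy energy
print("random search energy:   ", random_search(target, 10, rng)[1])
print("gradient descent energy:", gradient_refine(target, 10, rng)[1])
```

In this toy setting each gradient step halves the distance to the minimum, whereas random search only improves when a perturbation happens to land closer. A real EBT would obtain the gradient by backpropagating through the network rather than from a closed-form expression.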
How it works: Among other models, the authors trained a 44 million-parameter autoregressive EBT on the RedPajama-Data-v2 dataset of 32 billion text tokens scraped from the web. As input, EBT received a sequence of tokens and a probability vector (for the next token). It learned to output an energy score that measured the likelihood that the predicted next token would follow the context.
- During training, given a text prompt and a random guess for the probability vector, the model computed the energy. It refined the vector (leaving the model weights unchanged) by backpropagating to compute the change in the vector needed to decrease the predicted energy, and then it updated the vector. It repeated this process for a fixed number of steps, producing a predicted probability vector.
- The loss function encouraged the model to minimize the difference between the predicted probability vector and the ground-truth vector (1 for the right token, 0 for all others).
- At inference, given an input, the model predicted the next token by starting with a random probability vector and refining it through a fixed number of steps.
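The refine-then-score loop described in these steps can be sketched as follows. This is a toy illustration, not the authors' implementation: the learned network is replaced by a hand-written energy (`toy_energy`) whose minimum is a fixed `preferred` distribution, the vector is parameterized as logits passed through a softmax so it stays a valid probability distribution (an assumption), and every name is hypothetical. The weight update driven by the loss is omitted, since it would require automatic differentiation.

```python
import math
import random

def softmax(logits):
    """Convert logits to a probability vector (numerically stable)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def toy_energy(logits, preferred):
    """Toy stand-in for the model's energy: cross-entropy between the
    distribution the 'model' prefers and the candidate distribution."""
    probs = softmax(logits)
    return -sum(q * math.log(p) for q, p in zip(preferred, probs))

def toy_energy_grad(logits, preferred):
    """Gradient of toy_energy with respect to the logits: softmax(z) - q."""
    probs = softmax(logits)
    return [p - q for p, q in zip(probs, preferred)]

def refine(preferred, num_steps, step_size, rng):
    """Start from a random guess and take a fixed number of gradient
    steps on the candidate vector; the 'model' itself is left unchanged."""
    logits = [rng.gauss(0.0, 1.0) for _ in preferred]
    for _ in range(num_steps):
        grad = toy_energy_grad(logits, preferred)
        logits = [z - step_size * g for z, g in zip(logits, grad)]
    return softmax(logits)

def cross_entropy(probs, target_index):
    """Loss against a one-hot ground truth (1 for the right token, 0 otherwise)."""
    return -math.log(probs[target_index])

rng = random.Random(0)
preferred = [0.05, 0.8, 0.1, 0.05]  # made-up preference over a 4-token vocabulary
predicted = refine(preferred, num_steps=200, step_size=1.0, rng=rng)
print("predicted vector:", [round(p, 3) for p in predicted])
print("loss if token 1 is correct:", cross_entropy(predicted, 1))
```

At inference, the same `refine` loop would run without the loss, and the next token would be read off the refined vector, for example by taking its largest entry.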
Results: The authors compared EBTs and transformers of the same sizes, trained on the same numbers of tokens, by measuring perplexity (a measure of how well a model predicts the next token; lower is better) on several benchmarks including math problems, question answering, and reading comprehension. Overall, EBT proved to be better at generalization but worse at generating text that followed the training data’s distribution. EBTs in the sizes tested proved to be significantly less compute-efficient than transformers, but they scaled better, and larger versions may be more efficient than transformers.
- On three out of four popular benchmarks, the EBT achieved better perplexity than a vanilla transformer of the same size and trained on the same number of tokens. The EBT beat the transformer on GSM8K math problems (43.3 to 49.6), BIG-bench Elementary Math QA (72.6 to 79.8), and BIG-bench Dyck Languages, which tests closing brackets or parentheses accurately (125.3 to 131.5). On the SQuAD test of reading comprehension, EBT underperformed the transformer (53.1 to 52.3).
- On a held-out portion of the dataset, the EBT achieved slightly worse perplexity than the transformer (33.43 to 31.36).
- The authors trained several EBTs and transformers using model sizes and training-step counts dictated by transformer scaling laws and trained the models using roughly 10^16 to 10^20 FLOPs. The EBTs required about 10 times more FLOPs than transformers to reach the same perplexity. However, per additional FLOP, the EBTs’ perplexity improved 3 percent faster than the transformers’, so larger EBTs trained on more data for more steps may achieve lower perplexity using fewer FLOPs.
- The authors built autoregressive video models and vision transformers with similarly promising results.
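Perplexity, the metric behind the comparisons above, is commonly computed as the exponential of the average negative log-probability a model assigns to the correct tokens. A minimal sketch, in which `perplexity` is a hypothetical helper and the probabilities are made-up illustration values rather than results from the paper:

```python
import math

def perplexity(token_probs):
    """Exponential of the mean negative log-probability of the correct tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

# A model that gives every correct token probability 0.25 is as uncertain
# as a uniform guess among 4 options, so its perplexity is approximately 4.
print(perplexity([0.25, 0.25, 0.25, 0.25]))

# Higher probabilities on the correct tokens yield lower perplexity.
print(perplexity([0.9, 0.8, 0.95]))
```

This is why lower numbers indicate better predictions in the benchmark figures.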
Why it matters: This work offers intriguing possibilities for higher performance at larger scales. A typical transformer learns to predict the next token directly, but that locks it into a single forward pass per output token and provides no built-in measure of whether the prediction is good. In contrast, EBT learns to assign a score that it uses both to generate tokens (by iteratively lowering their energy) and to verify them (by checking if the energy is high). Work remains to learn whether larger EBTs can be more compute-efficient.
We’re thinking: When it comes to energy, AI research is a renewable resource!