Paper deep dive
Uncovering Mesa-Optimization Algorithms in Transformers
Johannes von Oswald, Maximilian Schlegel, Alexander Meulemans, Seijin Kobayashi, Eyvind Niklasson, Nicolas Zucchet, Nino Scherrer, Nolan Miller, Mark Sandler, Blaise Agüera y Arcas, Max Vladymyrov, Razvan Pascanu, João Sacramento
Models: Hybrid-mesa Transformer, Linear self-attention Transformer, Softmax Transformer
Intelligence
Status: succeeded | Model: google/gemini-3.1-flash-lite-preview | Prompt: intel-v1 | Confidence: 94%
Last extracted: 3/12/2026, 7:10:34 PM
Summary
The paper investigates the phenomenon of in-context learning in Transformer models, proposing that it arises from a 'mesa-optimization' process. The authors demonstrate that training Transformers on next-token prediction tasks implicitly installs a gradient-based optimization algorithm within the model's forward pass, allowing it to adapt to new sequences without parameter updates.
Relation Signals (3)
Transformer → exhibits → In-context learning
confidence 95% · Some autoregressive models exhibit in-context learning capabilities
Next-token prediction → enables → Mesa-optimization
confidence 90% · standard next-token prediction error minimization gives rise to a subsidiary learning algorithm
Mesa-optimization → implements → Gradient-based optimization
confidence 90% · this process corresponds to gradient-based optimization of a principled objective function
Cypher Suggestions (2)
Find all mechanisms that explain in-context learning · confidence 90% · unvalidated
MATCH (m:Mechanism)-[:EXPLAINS]->(p:Phenomenon {name: 'In-context learning'}) RETURN m.name
Identify the relationship between training objectives and learning mechanisms · confidence 85% · unvalidated
MATCH (t:Objective)-[r:ENABLES]->(m:Mechanism) RETURN t.name, m.name
Abstract
Some autoregressive models exhibit in-context learning capabilities: being able to learn as an input sequence is processed, without undergoing any parameter changes, and without being explicitly trained to do so. The origins of this phenomenon are still poorly understood. Here we analyze a series of Transformer models trained to perform synthetic sequence prediction tasks, and discover that standard next-token prediction error minimization gives rise to a subsidiary learning algorithm that adjusts the model as new inputs are revealed. We show that this process corresponds to gradient-based optimization of a principled objective function, which leads to strong generalization performance on unseen sequences. Our findings explain in-context learning as a product of autoregressive loss minimization and inform the design of new optimization-based Transformer layers.
Links
- Source: https://arxiv.org/abs/2309.05858
- Canonical: https://arxiv.org/abs/2309.05858
Full Text
Uncovering mesa-optimization algorithms in Transformers
Johannes von Oswald (Google, Paradigms of Intelligence Team; ETH Zürich; contributed equally to this work), Maximilian Schlegel (Google, Paradigms of Intelligence Team; ETH Zürich; contributed equally to this work), Alexander Meulemans (Google, Paradigms of Intelligence Team; ETH Zürich), Seijin Kobayashi (Google, Paradigms of Intelligence Team; ETH Zürich), Eyvind Niklasson (Google, Paradigms of Intelligence Team), Nicolas Zucchet (ETH Zürich), Nino Scherrer (Google, Paradigms of Intelligence Team), Nolan Miller (Google Research), Mark Sandler (Google Research), Blaise Agüera y Arcas (Google, Paradigms of Intelligence Team), Max Vladymyrov (Google Research), Razvan Pascanu (Google DeepMind), João Sacramento (Google, Paradigms of Intelligence Team; ETH Zürich; contributed equally to this work)
jvoswald@google.com, joaosacramento@google.com
Abstract
Some autoregressive models exhibit in-context learning capabilities: being able to learn as an input sequence is processed, without undergoing any parameter changes, and without being explicitly trained to do so. The origins of this phenomenon are still poorly understood. Here we analyze a series of Transformer models trained to perform synthetic sequence prediction tasks, and discover that standard next-token prediction error minimization gives rise to a subsidiary learning algorithm that adjusts the model as new inputs are revealed. We show that this process corresponds to gradient-based optimization of a principled objective function, which leads to strong generalization performance on unseen sequences. Our findings explain in-context learning as a product of autoregressive loss minimization and inform the design of new optimization-based Transformer layers.
We are currently witnessing a paradigm shift in machine learning. Specialized models trained on large labeled data sets are being replaced by generalist foundation models trained with self-supervision [1].
There is increasing evidence that these models can adapt to a wide range of tasks after a brief period of supervised learning (‘fine-tuning’). Intriguingly, some foundation models are capable of learning directly from contextual input data, without having been explicitly designed or trained to do so. In this way, parameter fine-tuning can often be sidestepped altogether, making on-the-fly adaptation to new tasks possible simply by providing examples in context. This powerful yet puzzling phenomenon, known as in-context learning [2], was first observed in autoregressive large language models (LLMs). A number of recent theoretical studies have begun to shed light on how in-context learning works, and why it arises. A seminal analysis of Transformers, the backbone architecture [3] of the majority of LLMs, identified a two-layer circuit mechanism called ‘induction head’ responsible for in-context learning in shallow networks, and provided evidence for its likely involvement in deeper and more complex networks [4, 5]. A complementary line of work has shown that in-context learning can emerge in small-scale models, as long as the data distribution displays certain properties [6, 7], and that it can vanish under long training times [8]. Building on prior recurrent neural network studies [9, 10, 11, 12], yet another line of investigation has studied the metalearning abilities of Transformers, explicitly training the models to solve supervised learning problems in-context [13, 14, 15, 16]. In such a setup, in-context learning is no longer an emergent phenomenon, but is ‘forced’ by the training regime, simplifying the analysis. Our previous work showing that Transformers solve linear regression tasks by gradient descent [17], later followed by a series of refined studies and mathematical analyses [18, 19, 20, 21, 22, 23, 24, 25, 26], falls under the same category as it also relies on explicit metalearning. 
In this paper, we continue to analyze the in-context learning abilities of Transformers, but shift our focus to autoregressive sequence prediction tasks. Like LLMs, and most sequence models, we train Transformers in a self-supervised manner by minimizing a next-token prediction error objective. Based on our previous results on metalearned Transformers [17], we then investigate whether the prediction algorithm learned by autoregressive Transformers can be interpreted as gradient-based learning on a suitable contextual objective function. We find that this holds true for a range of synthetic sequence modeling tasks. In such a controlled synthetic data setting, we identify a gradient-based learning mechanism spanning multiple Transformer layers. We refer to this mechanism as a ‘mesa-optimizer’ to emphasize that it is acquired through training, as opposed to being inherent to the model [see 27]. The mesa-optimizer adapts the model as new contextual information becomes available, enabling it to improve its predictions with near-optimal sample efficiency. Moreover, the same mechanism enables learning downstream tasks from contextual demonstrations only. Taken together, our results explain, at least in the settings we have considered, the emergence of in-context learning in Transformers trained only to predict the next token.
Results
We study autoregressive sequence modeling tasks where the goal is to causally predict, at every time step $t = 1, \ldots, T-1$, the next element $e_{t+1}$ in a sequence of tokens $e = (e_t)_{t=1}^{T}$, given the past $(e_{t'})_{t'=1}^{t}$ as context.
We examine a range of causally masked Transformer models [3] trained to solve such problems, from simple attention-only models to full-fledged deep Transformers comprising multiple attention layers, layer normalization [LayerNorm; 28], and nonlinear multi-layer perceptron (MLP) blocks, cf. Materials and Methods. The objective of training is to find a set of parameters $\theta$ that minimize the cumulative next-token prediction error
$$\mathcal{L}(\theta) = \mathbb{E}_{e \sim p(e)}\left[\frac{1}{2}\sum_{t=1}^{T-1} \|e_{t+1} - f_t(e_{1:t}, \theta)\|^2\right], \tag{1}$$
where $f_t(e_{1:t}, \theta)$ denotes the Transformer output conditioned on the context $e_{1:t}$, and the expectation is taken over the sequence distribution $p(e)$, which we describe next. We focus on continuous-state problems, with $e_t \in \mathbb{R}^{n_e}$, and take the squared error as the per-time-step loss, the standard objective for autoregressive problems with continuous outputs. Our Transformers are trained on synthetic sequences $(s_t)_{t=1}^{T}$ of observations $s_t \in \mathbb{R}^{n_s}$ generated by discrete-time dynamical systems.
As we detail in Materials and Methods, we consider a range of sequence generators described in state-space representation: (i) linear systems with full observability ($s_t = h_t$) of the internal system state $h_t \in \mathbb{R}^{n_h}$; (ii) partially-observed linear systems, where we only allow access to a low-dimensional state projection, $s_t = C^* h_t$; (iii) nonlinear dynamics, with state transitions governed by nonlinear neural networks. Typically, each token corresponds to one observation, $e_t = s_t$, but we also study tokenization schemes that aggregate several observations within one token. These aggregate token representations play an important role in the theory we develop below.
Next-token prediction by mesa-optimization
In this paper, we hypothesize that training Transformers on next-token prediction tasks as described above installs a gradient-based, in-context optimization algorithm in the forward pass of the model. Following the terminology of Hubinger et al. [27], we refer to this hypothetical acquired process as mesa-optimization, to distinguish it from the base-optimization of Eq. 1, over which we have explicit control. More concretely, we hypothesize that generating the future-token prediction $f_t(e_{1:t}, \theta)$ involves using the current and past tokens $e_{1:t}$ to build a sequence-specific latent model on the fly. We focus on the case where this model is linear in its parameters, which we denote by $\Phi$.
According to our mesa-optimization hypothesis, trained Transformers successively learn a sequence of such parameters $\Phi_t$ as input tokens are gradually revealed, by minimizing an in-context objective function $L_t(e_{1:t}, \Phi)$ using gradient information $\nabla_\Phi L_t(e_{1:t}, \Phi)$. The resulting in-context models are then used to generate the Transformer predictions. It is important to appreciate that these in-context latent models and their learning rules are not explicitly hardwired in the Transformer design, but are instead a by-product of base-optimization. The parameters $\Phi$ may thus be thought of as an implicit type of fast (i.e., sequence-specific) weights [29, 30] which live in the short-term memory of a Transformer model, not in its learned parameters. Before verifying whether our hypothesis holds for trained models, we first show that in theory, autoregressive linear Transformers are capable of optimizing quadratic loss functions in-context. We show this constructively, by providing a set of parameters $\theta$ such that a linear Transformer implements a mesa-optimizer. This construction will then guide our analyses of trained models.
Theory of self-attention mesa-optimizers
Our first theoretical result concerns a single layer of causally-masked self-attention, the architectural component at the heart of an autoregressive Transformer; we will later consider deeper, more complex architectures.
Given an input sequence $(e_t)_{t=1}^{T}$, one such layer with $H$ heads updates each token $e_t \leftarrow e_t + \Delta e_t^{\text{sa}}$ following the rule
$$\Delta e_t^{\text{sa}} = \sum_{h=1}^{H} P_h V_{h,t}\, \alpha(K_{h,t}^\top q_{h,t}), \tag{2}$$
where $q_{h,t} = W_{h,q} e_t \in \mathbb{R}^{n_a}$ is referred to as a query, each column $k_{h,t'} = W_{h,k} e_{t'} \in \mathbb{R}^{n_a}$ of matrix $K_{h,t} \in \mathbb{R}^{n_a \times t}$ as a key, and each column $v_{h,t'} = W_{h,v} e_{t'} \in \mathbb{R}^{n_v}$ of matrix $V_{h,t} \in \mathbb{R}^{n_v \times t}$ as a value. The parameters of this layer are the projection matrices $\{(P_h, W_{h,q}, W_{h,k}, W_{h,v})\}_{h=1}^{H}$ for all heads; we absorb bias terms, and assume here for conciseness that all heads are equally sized. The function $\alpha$ applied to vector $a \in \mathbb{R}^{t}$ returns an attention weight vector. For the theoretical results presented below, we focus on the case where $\alpha$ is the identity function, which yields the linear self-attention layer, the main building block of linear Transformers [e.g., 31, 32, 33, 34, 35].
In our experimental analyses, we also study standard (softmax) self-attention layers, where $\alpha(a)_i = \mathrm{softmax}(a)_i := \left(\sum_{t'=1}^{t} \exp(a_{t'})\right)^{-1} \exp(a_i)$, present in the original and still most popular Transformer architecture [3].
Figure 1: Illustration of mesa-optimization in autoregressive Transformers. The neural dynamics implements an optimization-based in-context learning algorithm, which optimizes the parameters $\Phi$ of a linear model over a series of causally-masked attention layers. Taking as inputs an initial set of parameters $\Phi_0$ and a training set of input-target pairs $\{(s_{t'}, s_{t'+1})\}_{t'=1}^{t-1}$ constructed from context, this process returns a prediction $\hat{\Phi}_t s_t$ obtained by applying the learned model to the current input. Early layers implement a copy operation which binds multiple consecutive tokens together, in agreement with previous in-context learning analyses [4, 5]. This aggregate-token representation enables the implementation of gradient-based optimizers in subsequent attention layers, cf. Propositions 1 and 2.
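To make Eq. 2 concrete, the following is a minimal single-head NumPy sketch of a causally-masked self-attention update, with $\alpha$ chosen as either the identity (linear self-attention) or the softmax. The function names, the toy dimensions, and the random weights are illustrative assumptions and are not taken from the paper's implementation.

```python
import numpy as np

def identity(a):
    """Linear self-attention: alpha is the identity."""
    return a

def softmax(a):
    """Softmax attention weights over the scores seen so far."""
    w = np.exp(a - a.max())  # subtract the max for numerical stability
    return w / w.sum()

def causal_self_attention(E, Wq, Wk, Wv, P, alpha):
    """Single-head version of Eq. 2; E holds one token per row."""
    out = np.array(E, dtype=float)
    for t in range(E.shape[0]):
        q = Wq @ E[t]             # query q_t, shape (n_a,)
        K = Wk @ E[: t + 1].T     # keys K_t, shape (n_a, t+1): tokens up to t only
        V = Wv @ E[: t + 1].T     # values V_t, shape (n_v, t+1)
        out[t] = E[t] + P @ V @ alpha(K.T @ q)
    return out

# Hypothetical toy sizes and random parameters, for illustration only.
rng = np.random.default_rng(0)
T, n_e, n_a, n_v = 6, 4, 3, 5
E = rng.normal(size=(T, n_e))
Wq, Wk = rng.normal(size=(n_a, n_e)), rng.normal(size=(n_a, n_e))
Wv, P = rng.normal(size=(n_v, n_e)), rng.normal(size=(n_e, n_v))

out_linear = causal_self_attention(E, Wq, Wk, Wv, P, identity)
out_softmax = causal_self_attention(E, Wq, Wk, Wv, P, softmax)
```

The explicit loop over $t$ is written for readability rather than speed; it makes the causal mask visible, since the update of token $t$ only ever reads tokens $1, \ldots, t$.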
Consider the cumulative squared-error loss function
$$L_t(\Phi) = \sum_{t'=1}^{t-1} \frac{1}{2}\|s_{t'+1} - \Phi s_{t'}\|^2, \tag{3}$$
where $\Phi \in \mathbb{R}^{n_s \times n_s}$ parametrizes a first-order linear autoregressive model which predicts $s_{t+1}$ from $s_t$. We show that one linear self-attention layer can implicitly represent such a model in its activations, with mesa-parameters $\Phi$ learned by a step of gradient descent on the mesa-objective $L_t(\Phi)$.
Proposition 1 (1-step attention-based gradient descent). Given tokens of the form $e_t = [\Phi_0 s_t, s_t, s_{t-1}]$, for $t = 2, \ldots, T$, if the projection matrices $W_k$, $W_q$, $W_v$, $P$ are such that
$$P W_v = \begin{bmatrix} 0 & \eta I_s & -\eta \Phi_0 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{bmatrix}, \quad W_k^\top W_q = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & I_s & 0 \end{bmatrix},$$
with $I_s$ the identity matrix of size $n_s \times n_s$, then the transformation of every token $e_t$ by one causally-masked linear self-attention head is identical to the gradient-induced update $e_t \leftarrow [(\Phi_0 - \eta \nabla L_t(\Phi_0)) s_t, s_t, s_{t-1}]$.
Proposition 1 (proven in Materials and Methods) is an immediate extension of the main result of von Oswald et al. [17] to the autoregressive sequence modeling setting, where $T$ loss functions $(L_t)_t$ must be optimized in sequence, see Materials and Methods for details. Since $L_t$ is the cumulative squared error up to time $t$, Proposition 1 implements a ‘full-batch’ gradient step. Notably, the self-attention layer executes this step in all $T$ problems in parallel. We remark that our construction assumes a special three-channel tokenization, where a single token encodes the current input $s_t$, the previous input $s_{t-1}$, and an initial prediction $\Phi_0 s_t$. As illustrated in Fig. 1, we will later show that trained Transformers learn to internally produce such encodings when driven by a standard-format ($e_t = s_t$) sequence, but for now we proceed under the assumption that the tokens are structured in such a way.
We now turn to multi-layer, self-attention-only models.
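As a quick numerical check of Proposition 1, the sketch below builds the three-channel tokens, applies one causally-masked linear self-attention head with the stated $P W_v$ and $W_k^\top W_q$, and compares the first channel with an explicit gradient descent step on $L_t$. The random data, the sizes, and the learning rate are arbitrary assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, T, eta = 3, 8, 0.05                  # hypothetical toy sizes and step size
S = rng.normal(size=(T, n))             # S[i] plays the role of s_{i+1}
Phi0 = 0.1 * rng.normal(size=(n, n))    # initial mesa-parameters

# Tokens e_t = [Phi0 s_t, s_t, s_{t-1}] for t = 2..T (array index i = 1..T-1).
E = np.stack([np.concatenate([Phi0 @ S[i], S[i], S[i - 1]]) for i in range(1, T)])

I, Z = np.eye(n), np.zeros((n, n))
PWv = np.block([[Z, eta * I, -eta * Phi0], [Z, Z, Z], [Z, Z, Z]])
WkTWq = np.block([[Z, Z, Z], [Z, Z, Z], [Z, I, Z]])

def linear_self_attention(E, PWv, WkTWq):
    """One causally-masked linear attention head (alpha = identity)."""
    scores = np.triu(E @ WkTWq @ E.T)   # scores[j', j] = k_{j'}^T q_j, zeroed for j' > j
    return E + (PWv @ E.T @ scores).T

def gd_prediction(S, Phi0, eta, i):
    """(Phi0 - eta * grad L(Phi0)) applied to S[i], with L summed over pairs before i."""
    X, Y = S[:i], S[1 : i + 1]          # inputs s_{t'} and targets s_{t'+1}
    grad = -(Y - X @ Phi0.T).T @ X      # gradient of the cumulative squared error
    return (Phi0 - eta * grad) @ S[i]

E_new = linear_self_attention(E, PWv, WkTWq)
expected = np.stack([gd_prediction(S, Phi0, eta, i) for i in range(1, T)])
match = np.allclose(E_new[:, :n], expected)
```

Under these assumptions `match` should be true, and the second and third channels of each token are left unchanged, which is what allows the same tokens to be reused by a subsequent layer.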
Here, we find that causally-masked autoregressive modeling complicates the problem, in the sense that stacking $k$ layers following Proposition 1 yields an unconventional biased algorithm that is expected to be slower than $k$-step gradient descent, as analyzed in [23]. There exists, however, an alternative unbiased mesa-optimizer for multi-layer models, which introduces an additional layerwise operation for improving the preconditioning of mesa-optimization. This algorithm again makes use of self-attention layers, now employed to transform the input data. In the limit of many such layers, a single gradient descent step then yields the optimal (least-squares) mesa-optimization solution.
Proposition 2 (Multi-attention-layer mesa-optimizer). Assume we are given for every time step $t = 2, \ldots, T$ a sequence of suitably-constructed input tokens $(e_{t'})_{t'=1}^{t}$, and a regularized mesa-objective we wish to minimize
$$\bar{L}_t(\Phi) = \sum_{t'=1}^{t-1} \frac{1}{2}\|s_{t'+1} - \Phi s_{t'}\|^2 + \frac{1}{2\lambda}\|\Phi\|_F^2,$$
where $\lambda^{-1} \in \mathbb{R}$ is a regularization hyperparameter and $S_t$ is the data matrix whose columns are $(s_{t'})_{t'=1}^{t}$.
Then, there exists a set of linear Transformer parameters $\theta$ that yield an approximation to the vectors $H_t^* s_t := (S_{t-1} S_{t-1}^\top + 1/\lambda\, I)^{-1} s_t$ in parallel for all $t$ in their forward pass, with approximation error decreasing with the number of linear self-attention layers $k$. As a consequence, in the many-layer limit the Transformer can minimize the regularized mesa-objective. A concrete parameter construction and proof are provided in the Materials and Methods.
Propositions 1 and 2 show that simplified Transformers can, at least in theory, minimize cumulative squared-error objectives in-context, without any actual parameter (‘in-weights’) learning taking place. As we shall see in our experimental section below, these ideal constructions yield solutions to our synthetic tasks, and they generate testable hypotheses that inform our experiments with trained models. Before proceeding to our empirical analyses, we present one last theoretical result motivated by the constructions above: a novel self-attention layer designed for efficient least-squares in-context learning.
An attention layer for optimal least-squares learning
The mesa-optimizers discussed so far require in general many layers to reach a desired error. This observation leads us to develop the mesa-layer, a self-attention layer derived from in-context optimization first principles. More concretely, we show that an appropriately modified attention layer yields autoregressive least-squares solutions in sequence and in a single step, a computation that would otherwise require infinitely many linear self-attention layers under Proposition 2.
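To illustrate why an iterative, layer-by-layer scheme only reaches $(S_{t-1} S_{t-1}^\top + 1/\lambda\, I)^{-1} s_t$ in the limit of many steps, the sketch below uses a truncated Neumann series, with one series term standing in for one layer. This is an illustrative stand-in under stated assumptions: the paper's concrete parameter construction lives in Materials and Methods and is not reproduced here, and the step size, sizes, and function name are hypothetical.

```python
import numpy as np

def neumann_inverse_apply(S_prev, s, lam, num_layers):
    """Approximate (S_prev S_prev^T + I/lam)^{-1} s by a truncated Neumann series.

    Returns the iterates x_0, ..., x_{num_layers}; iterate j stands in for depth j.
    """
    n = S_prev.shape[0]
    A = S_prev @ S_prev.T + np.eye(n) / lam
    c = 1.0 / np.linalg.norm(A, 2)       # 1 / largest eigenvalue, so the series converges
    x = c * s
    iterates = [x]
    for _ in range(num_layers):
        x = x + c * (s - A @ x)          # adds the next term of c * sum_j (I - cA)^j s
        iterates.append(x)
    return iterates

# Hypothetical toy problem: columns of S_prev are s_1..s_{t-1}, s is s_t.
rng = np.random.default_rng(0)
n_s, t, lam = 3, 7, 1.0
S_prev = rng.normal(size=(n_s, t - 1))
s = rng.normal(size=n_s)

exact = np.linalg.solve(S_prev @ S_prev.T + np.eye(n_s) / lam, s)
errors = [float(np.linalg.norm(x - exact)) for x in neumann_inverse_apply(S_prev, s, lam, 50)]
```

In this toy setting the entries of `errors` shrink geometrically with depth but never reach zero after finitely many steps, which is the gap a single closed-form least-squares layer is meant to remove.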
Thus, if the mesa-optimization hypothesis advanced in this paper describes actual trained standard Transformers, it should be possible to improve their performance by introducing such a layer in their architecture. The mesa-layer therefore provides one additional way of verifying the mesa-optimization hypothesis in experiments. The mesa-layer changes a sequence of input tokens according to the update
$$\Delta e_t^{\text{mesa}} = \sum_{h=1}^{H} P_h \hat{\Phi}_{h,t}^{\text{mesa}} q_{h,t}, \tag{4}$$
with
$$\hat{\Phi}_{h,t}^{\text{mesa}} = \operatorname*{arg\,min}_{\Phi} \; \frac{1}{2}\sum_{t'=1}^{t} \|v_{h,t'} - \Phi k_{h,t'}\|^2 + \frac{\|\Phi\|_F^2}{2\lambda_h}. \tag{5}$$
Above, the (learnable) scalar $\lambda_h^{-1} > 0$ controls the strength of a regularizer added to improve generalization, and key, value and query vectors are the usual learned head-specific affine transformations of the tokens, as in Eq. 2. However, through Eq. 5 these vectors are now assigned a precise, interpretable role: value vectors specify targets to which an internal model with parameters $\Phi$ should map training and test inputs, represented by keys and queries, respectively. We note that the minimizer of a regularized squared-error objective can be mapped to Eq. 5 under an appropriate tokenization (such as the one of Proposition 1) by appropriately setting the projection matrices $W_{h,v}$, $W_{h,k}$ and $W_{h,q}$.
At any given time step $t = 1, \ldots, T$, computing $\Delta e_t^{\text{mesa}}$ requires solving a regularized least-squares problem per attention head. To efficiently solve this sequence of $T$ optimization problems, we leverage the recursive dependency of the $T$ solutions, which can be expressed in closed-form as
$$\hat{\Phi}_{h,t}^{\text{mesa}} = V_{h,t} K_{h,t}^\top R_{h,t} = \sum_{t'=1}^{t} v_{h,t'} k_{h,t'}^\top \left(\sum_{t'=1}^{t} k_{h,t'} k_{h,t'}^\top + 1/\lambda_h\, I\right)^{-1}. \tag{6}$$
As $\lambda_h \to 0$, we recover a standard linear self-attention layer. Thus, the mesa-layer strictly generalizes the latter. We now use the Sherman-Morrison formula [36] to obtain the inverse at time $t$ from the inverse at the previous time step $t-1$. This iterative update is possible because we only change the inverse by a rank-one update.
The following solution scheme is known as recursive least squares [37]:
$$R_{h,t} = R_{h,t-1} - \frac{R_{h,t-1} k_{h,t} k_{h,t}^\top R_{h,t-1}}{1 + k_{h,t}^\top R_{h,t-1} k_{h,t}} \tag{7}$$
with $R_{h,0} = \lambda_h\, I$. We can then (causally in time) compute
$$\Delta e_t^{\text{mesa}} = \sum_{h=1}^{H} P_h V_{h,t} K_{h,t}^\top R_{h,t} q_{h,t}, \tag{8}$$
which requires 2 additional vector-matrix and 2 vector-vector multiplications per step compared to standard self-attention. Naive backward gradient computation requires storing matrices of dimension $n_a \times n_a$ in memory across time. However, this memory overhead can be avoided using the Sherman-Morrison formula in reverse during backpropagation, as we show in the SI Appendix, enabling memory-efficient gradient computation of the output of the mesa-layer w.r.t. its inputs. We note that while the implementation described here has a desirable $\mathcal{O}(1)$ inference memory cost, it is not parallelizable across time. This is a disadvantage for training on contemporary hardware shared with nonlinear recurrent neural networks, but not with standard self-attention layers.
The mesa-layer is closely related to the Delta-Net model of Schlag et al. [33], which is hardwired to do one gradient descent step per time point.
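The recursion of Eqs. 6 to 8 can be sketched for a single head as follows. This is a forward-pass illustration only, assuming NumPy, random toy weights, and hypothetical function names; it does not include the memory-efficient backward pass mentioned above, and it checks the recursive result against a direct evaluation of the closed form in Eq. 6.

```python
import numpy as np

def mesa_head_recursive(E, Wq, Wk, Wv, P, lam):
    """Single-head mesa-layer update via recursive least squares (Eqs. 7 and 8)."""
    n_a, n_v = Wk.shape[0], Wv.shape[0]
    R = lam * np.eye(n_a)                  # R_0 = lambda * I
    VK = np.zeros((n_v, n_a))              # running sum of v_{t'} k_{t'}^T
    delta = np.zeros(E.shape)
    for t in range(E.shape[0]):
        q, k, v = Wq @ E[t], Wk @ E[t], Wv @ E[t]
        Rk = R @ k                         # R stays symmetric, so R k k^T R = (Rk)(Rk)^T
        R = R - np.outer(Rk, Rk) / (1.0 + k @ Rk)   # Sherman-Morrison step, Eq. 7
        VK += np.outer(v, k)
        delta[t] = P @ (VK @ (R @ q))      # Eq. 8 for one head
    return delta

def mesa_head_closed_form(E, Wq, Wk, Wv, P, lam):
    """Reference: solve the regularized least-squares problem from scratch at every t."""
    n_a = Wk.shape[0]
    delta = np.zeros(E.shape)
    for t in range(E.shape[0]):
        K, V = Wk @ E[: t + 1].T, Wv @ E[: t + 1].T
        Phi = V @ K.T @ np.linalg.inv(K @ K.T + np.eye(n_a) / lam)   # Eq. 6
        delta[t] = P @ Phi @ (Wq @ E[t])
    return delta

# Hypothetical toy sizes and random parameters, for illustration only.
rng = np.random.default_rng(0)
T, n_e, n_a, n_v, lam = 12, 4, 3, 5, 2.0
E = rng.normal(size=(T, n_e))
Wq, Wk = rng.normal(size=(n_a, n_e)), rng.normal(size=(n_a, n_e))
Wv, P = rng.normal(size=(n_v, n_e)), rng.normal(size=(n_e, n_v))

delta_rls = mesa_head_recursive(E, Wq, Wk, Wv, P, lam)
delta_ref = mesa_head_closed_form(E, Wq, Wk, Wv, P, lam)
```

The recursive version keeps only `R` and `VK` as state, so its memory does not grow with the sequence length, whereas the reference recomputes an inverse over the whole prefix at every step.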
The mesa-layer can also be seen as an adaptation of the intention layer proposed by Garnelo & Czarnecki [38] to the sequential, autoregressive case. The latter corresponds exactly to a non-causally-masked version of Eq. 6. Here, we focus on the autoregressive setting, which leads us to develop recursive forward and backward updates, in order to achieve efficient sequential inference and training.
Aggregate internal token representations develop through training
Figure 2: Early layers of trained autoregressive Transformers (blue lines) produce internal token representations that support mesa-optimization by subsequent layers. Similar results are obtained for standard deep Transformers and new compact, two-layer model variants which feature the mesa-layer (red lines). (A) After training, the past token $e_{t'}$ ($t' = 49$) can be almost perfectly linearly decoded from the current ($t = 50$) output of the first Transformer layer. The decoding horizon $t - t'$ increases when the Transformer is trained to solve partially-observed tasks (dashed lines; notice low probing error for $t' \in \{49, 48, 47, 46\}$). (B) Same analysis, now for the groundtruth hidden state $\mathrm{MLP}^*(e_{t'})$ of a nonlinear sequence generator and for varying layer depth. Current ($t' = t$) and preceding ($t' = t-1$) states can be linearly decoded from early Transformer layers (depicted with lighter color tones) after training on nonlinear tasks.
We begin our empirical analysis of Transformer models trained by autoregressive loss (Eq. 1) minimization by searching for evidence of an internal token binding mechanism. Recall that Propositions 1 and 2 required a non-standard token format, in which consecutive observations were aggregated within a single token $e_t$.
In our first set of experiments, we adopt a standard token format and provide only the current observation $s_t$ as the input $e_t$ to the model. The first prediction of our theory is that training should install a token binding mechanism, responsible for aggregating multi-time-step observation information within a single token. We now show that this indeed occurs in actual trained models.
In Fig. 2, we report the performance of linear decoders [probes; 39] trained to predict previous tokens from the output of the first attention layer of a deep Transformer model. We consider both standard Transformer models featuring seven softmax self-attention layers, MLPs and LayerNorm, as well as a novel compact Transformer which combines one layer of softmax self-attention and one mesa-layer, described in detail in Materials and Methods. We relegate architectures solely built out of mesa-layer models to the SI Appendix, as we found that these were generally outperformed by hybrid softmax-mesa architectures. We see that after training it becomes possible to decode past tokens from the present token (see also Fig. 1A), with decoding horizon increasing for partially-observed problems, for both standard softmax Transformers and the novel hybrid softmax-mesa Transformers introduced in this paper. For the fully observed setting, the probe error increases quickly when predicting more than one step in the past, aligned with our token construction that binds together only consecutive tokens. Moreover, when the models are trained on systems with nonlinear dynamics, the performance of linear probes that decode the hidden state of the sequence generator system from the outputs of MLP layers improves, in particular for early MLP layers.
These results can be explained by analyzing the tasks the Transformers are trained on.
When the input data is generated by a fully-observed linear dynamical system, the maximum likelihood estimator of the groundtruth parameters $W^*$ corresponds to the least-squares solution $\arg\min_{\Phi} \sum_{t'=1}^{t-1} \frac{1}{2} \|s_{t'+1} - \Phi s_{t'}\|^2$. The mesa-optimizers described in Propositions 1 and 2, as well as the mesa-layer, can be readily applied to solve this problem (or a regularized variant, corresponding to maximum a posteriori estimation under a Gaussian prior on $\Phi$), as long as inputs and targets are both encoded within a single token $e_t$. This is what we observe in Fig. 2A. The nonlinear case can be approached similarly, by performing least-squares estimation of $W^*$ after an appropriate nonlinear feature transformation. For a Transformer model, the MLP layers are perfectly placed to implement this transformation. In line with this, we find that early layers develop a set of basis functions that align with those of the nonlinear sequence generator, Fig. 2B, followed by a token binding step (cf. SI Appendix).

We find that there is a longer dependence on the past under partial observability (Fig. 2A), where next-token prediction is complicated due to the presence of latent variables. This behavior can again be explained in the light of our mesa-optimization hypothesis. First, we note that the task our Transformers face is harder than classical Kalman filtering [40], where knowledge of groundtruth system parameters is assumed. Methods such as subspace identification [41] or variational expectation maximization [42] are applicable to this harder setting, but we found these standard methods difficult to map to a Transformer.
We identified, however, a less orthodox algorithm, mathematically related to data-driven control techniques [43, 44], which runs online as a sequence is unveiled, and that is based on least-squares estimation. This algorithm can therefore be implemented by a Transformer through Propositions 1 and 2, or by a mesa-layer. The key step is to encode $k$ past observations $s_{t-k+1}, \ldots, s_t$ in a single augmented variable $z_t \in \mathbb{R}^{k n_s}$ of large enough dimensionality; it can then be shown that maximum likelihood estimation of the next token $s_{t+1}$ can be achieved by solving a least-squares problem involving the augmented variables $\{z_t\}$. We provide a full derivation and analysis in Materials and Methods, where we show that the optimal value of $k$ depends on the compression ratio $n_h/n_s$. According to this theory, we would expect to see a higher-order ($k > 1$) dependency on past inputs for the case of partially-observed dynamics. This corresponds to what we find in trained Transformers, cf. Fig. 2.

Figure 3: Visualization of activations and parameters for trained models. (A) Additional evidence for a token binding mechanism on a 7-layer Transformer complementing Fig. 2, shown by plotting first-layer attention scores averaged over a batch of 2048 sequences. Clear data-independent attention on the previous and current token is shown resp. by high sub-diagonal and main diagonal attention, with zero everywhere else. (B) One trained layer of linear self-attention implements one step of gradient descent, compare with Proposition 1.

We thus conclude that training robustly installs a token binding mechanism in the first Transformer layers across a range of next-token prediction tasks and network architectures.
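The least-squares view with augmented variables described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the helper names `stack_past` and `ridge_next_token`, the zero-padding used before $k$ observations are available, and the regularization strength `lam` are all hypothetical choices. With `k=1` the sketch reduces to the fully-observed case, where the input is simply $s_t$.

```python
import numpy as np

def stack_past(S, k):
    """Build augmented variables z_t = [s_t, s_{t-1}, ..., s_{t-k+1}].

    S has shape (T, n_s). Entries before the start of the sequence are
    zero-padded (an assumption of this sketch). Returns shape (T, k * n_s).
    """
    T, n_s = S.shape
    Z = np.zeros((T, k * n_s))
    for j in range(k):
        # Block j of row t holds the observation from j steps earlier.
        Z[j:, j * n_s:(j + 1) * n_s] = S[:T - j]
    return Z

def ridge_next_token(S, k=1, lam=1e-3):
    """Fit Phi by regularized least squares and predict the next observation.

    Solves min_Phi sum_t 1/2 ||s_{t+1} - Phi z_t||^2 + lam/2 ||Phi||_F^2
    over all observed (z_t, s_{t+1}) pairs, then returns (Phi z_T, Phi).
    """
    Z = stack_past(S, k)
    X, Y = Z[:-1], S[1:]                      # inputs z_t and targets s_{t+1}
    A = X.T @ X + lam * np.eye(X.shape[1])    # regularized Gram matrix
    Phi = np.linalg.solve(A, X.T @ Y).T       # shape (n_s, k * n_s)
    return Phi @ Z[-1], Phi

# Usage on a toy noise-free system: a 90-degree rotation in the plane.
W = np.array([[0.0, -1.0], [1.0, 0.0]])
states = [np.array([1.0, 0.0])]
for _ in range(8):
    states.append(W @ states[-1])
prediction, Phi_hat = ridge_next_token(np.array(states), k=1, lam=1e-8)
```

In the toy usage the estimate `Phi_hat` should be close to `W`, since the system is fully observed and noise-free; for partially-observed data one would pass `k > 1`.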
Interestingly, this token binding mechanism exactly coincides with the first layer of the induction head circuit [4, 5, 8], which has inspired the design of new neural architectures [45, 46, 47, 48]. Through their analysis of Transformers trained on natural language modeling, Olsson et al. provide compelling evidence that the appearance of this mechanism during training is strongly correlated with improvements in in-context learning performance. Here, we interpret this phenomenon as part of a multi-layer circuit for mesa-optimization. In light of our theory, token binding can be understood as constructing an in-context training set of appropriate input-output associations. Once this step is concluded, the mesa-objective function is defined, and in-context optimization can take place.

Evidence for mesa-optimizers in linear attention-only models

We proceed with our analysis of trained Transformers, focusing in this section on simplified linear-attention-only architectures that can in theory be explained by Propositions 1 and 2. Having shown that a token binding mechanism can be learned, and aiming for the simplest deep Transformer setup, we directly feed our models with aggregate token inputs, $e_t = [0, s_t, s_{t-1}]$, as assumed by our theory. Moreover, we focus on fully-observable linear tasks.

Figure 4: Evidence for mesa-optimization in linear self-attention networks. (A) As training proceeds, the test loss of a single layer of linear self-attention (Linear-SA-1, green lines) converges to the loss achieved by 1-step gradient descent (Proposition-1, gray line) with optimized learning rate and initial parameters. A single mesa-layer (red lines) strongly outperforms a single linear self-attention layer, consistent with the fact that it yields recursively the optimal (least-squares) solution at every time step.
(B) Same analysis, now for a 6-layer linear self-attention model. The increase in the number of attention layers reduces the gap towards the mesa-layer. The test loss of this model converges to that of the CompressedAlg-6 expression (black line), which comprises a small fraction (0.5%) of parameters of the original model, reflecting the highly-structured parameters obtained after training. (C) At convergence, trained models exhibit the same in-context learning performance (measured as the loss as a function of sequence length) as 1 step of gradient descent (dashed line). (D) Similarly for 6-layer models, which can be almost perfectly described by the multi-layer mesa-optimizer of Proposition 2 (dashed line). (E) Linear probing of next-token targets $s_{t+1}$ from the internal Transformer activations improves with depth and context length, consistent with mesa-optimization for next-token prediction. (F) Linear probing of preconditioned inputs $(S_{t-1} S_{t-1}^\top + 1/\lambda\, I)^{-1} s_t$ improves with depth and context length, consistent with the mesa-optimizer of Proposition 2.

The results for single-layer networks are strikingly clear. After next-token prediction training, these networks implement the one-step gradient descent algorithm of Proposition 1 in a near-exact fashion. This can be seen by visual inspection, Fig. 3, or quantitatively by comparing the loss reached by the trained layer with that of a linear autoregressive model learned through one step of gradient descent, cf. Fig. 4A-C. We find that we can perfectly fit the outputs of our trained layer when using all degrees of freedom of our theory, including not only a learned learning rate $\eta$, but also a learned set of initial weights $\Phi_0$.
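To make the correspondence between one gradient step and a linear attention readout concrete, the following NumPy sketch compares two routes to the same prediction. It is an illustration of the underlying algebra only, not the exact token layout or parameterization of Proposition 1: the function names, the use of raw states as keys and queries, and the values `eta` and `Phi0` in the usage lines are assumptions made for this example.

```python
import numpy as np

def gd_one_step_prediction(S, Phi0, eta):
    """Take one gradient step on L(Phi) = sum_t 1/2 ||s_{t+1} - Phi s_t||^2
    starting from Phi0, then predict the successor of the last state."""
    X, Y = S[:-1], S[1:]                 # inputs s_t and targets s_{t+1}
    residuals = Y - X @ Phi0.T           # rows are s_{t+1} - Phi0 s_t
    grad = -residuals.T @ X              # dL/dPhi evaluated at Phi0
    Phi1 = Phi0 - eta * grad
    return Phi1 @ S[-1]

def linear_attention_prediction(S, Phi0, eta):
    """Same prediction written as unnormalized linear attention:
    keys s_t, values s_{t+1} - Phi0 s_t, query equal to the last state."""
    X, Y = S[:-1], S[1:]
    keys, values, query = X, Y - X @ Phi0.T, S[-1]
    scores = keys @ query                # dot-product scores, no softmax
    return Phi0 @ query + eta * values.T @ scores

# Usage with arbitrary illustrative numbers.
rng = np.random.default_rng(0)
S_demo = rng.normal(size=(6, 3))
Phi0_demo = 0.1 * rng.normal(size=(3, 3))
a = gd_one_step_prediction(S_demo, Phi0_demo, eta=0.05)
b = linear_attention_prediction(S_demo, Phi0_demo, eta=0.05)
```

The two outputs agree up to floating-point error, because the gradient of the squared loss is a sum of outer products between residuals and inputs, which is exactly the kind of quantity a linear attention layer accumulates.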
Next-token prediction therefore installs in the Transformer an in-context variant of the model-agnostic metalearning algorithm due to Finn et al. [49].

Deep linear attention networks correspond to high-degree polynomial functions with a large number of terms. Despite their complexity, for such deep networks training once again leads to highly-structured sparse model parameters $\theta$; we provide visual examples in the SI Appendix. This allows us to construct an expression (CompressedAlg-d, where d denotes model depth) comprising only 16 parameters (instead of 3200) per layer head. We find that this compressed, albeit convoluted, expression can describe a trained deep linear Transformer. In particular, it allows interpolating between actual Transformer and CompressedAlg-d weights (or Proposition 1, for the single-layer case) in an almost lossless fashion, cf. Figure 4C. Further details can be found in Materials and Methods.

While the CompressedAlg-d expression explains a trained deep linear self-attention model with a small number of free parameters, it is difficult to interpret it through the lens of mesa-optimization and connect it exactly to the theoretical construction of Proposition 2. We therefore resort to a linear probing analysis [39] to look for signatures of our hypothesized mesa-optimization algorithms. Based on Propositions 1 and 2 we design (i) target probes measuring optimization progress, regressing $k$-th layer representations $e_t^{(k)}$ against the next token $s_{t+1}$ to be predicted, where we could expect multiple steps of gradient descent gradually approaching the target; and (ii) preconditioning probes regressing against preconditioned inputs $(S_{t-1} S_{t-1}^\top + 1/\lambda\, I)^{-1} s_t$, cf. Materials and Methods. As shown in Fig.
4E-F we see that both probes succeed, with linear decoding performance increasing with sequence length and network depth. Base-optimization has therefore discovered a hybrid algorithm that descends over layers the mesa-objective $L_t(\Phi)$ while simultaneously improving the condition number of the mesa-optimization problem. This leads to a fast descent of the mesa-objective.

Examining next-token prediction error, we find that it decreases quickly with depth, cf. Figure 4C, with a 6-layer model coming close to but still not matching a single mesa-layer. The high performance of the mesa-layer in this setup can be explained by the fact that it yields the optimal (least-squares) predictor provided that the correct query, key and value inputs are fed to it. Moreover, prediction error decreases monotonically with sequence length both for the mesa-layer and for multi-layer linear Transformers. This improvement with context size matches the operational definition of in-context learning proposed by Kaplan et al. [50]; in this sense, the models are strong in-context learners, behaving similarly to regularized least-squares.

Notably, we see in Fig. 4 that performance-wise a deep model with $k$ linear attention layers can be almost perfectly explained by $k$ steps of the multi-layer mesa-optimizer described in Proposition 2, with appropriately tuned hyperparameters (cf. Materials and Methods). Importantly, these hyperparameters are tuned for maximal $k$-step performance and not to reproduce Transformer behavior. This is one additional point of evidence that our theoretical mesa-optimizers describe the computations performed by Transformers trained by next-token prediction error minimization.

Trained softmax self-attention layers behave like linear attention

We return to standard Transformer models, which feature MLPs, LayerNorm and softmax self-attention layers.
We train multi-layer versions of such networks on fully-observed linear tasks, under a standard tokenization scheme ($e_t = s_t$). Recalling that our theoretical mesa-optimizers (Propositions 1 and 2) rely on linear self-attention operations, we now ask whether base-optimization renders the softmax attention nonlinearity in an approximately linear regime, when driven by sequences such as those seen during training.

Figure 5: Linearization analysis of softmax Transformers. (A) The test loss achieved by a linearized Transformer, where one attention layer at a given depth d (intensity color-coded) is linearized, normalized relative to reference model loss. As the input dimension $n_s$ grows, the linear approximation improves for all layers except for the first. The highly-nonlinear behavior exhibited by this layer is consistent with its special role in implementing a token binding mechanism (Figs. 1 and 5). (B) The test loss of an autoregressive linear model learned by regularized least-squares (LSQ, yellow line), the algorithm we hypothesize that a trained Transformer implements, does not suffer from the curse of dimensionality, whereas a generic interpolation algorithm (red line) that can be implemented in softmax attention layers does.

In Fig. 5A, we analyze the test set loss achieved by a Transformer after replacing a softmax self-attention layer by its linear counterpart at a given depth, keeping the architecture otherwise intact. We obtain this control model through a process known as distillation [51]: we first record the outputs produced by the to-be-replaced softmax attention layer, when the Transformer is applied to a set of training sequences, and then train a linear attention layer to reproduce these outputs by squared error minimization. As we observe in Fig.
5, for sufficiently large input dimension $n_s$, from the second layer onwards the linear attention models behave as their reference counterparts to a very good approximation. We further observe that the first attention layer behaves in an entirely different, nonlinear manner. This is consistent with the fact that softmax self-attention can implement near-exact token copying [4], as required by our token binding mechanism (cf. Figs. 1 and 2).

On induction heads

The low linearization error achieved at high enough data dimension seen in Fig. 5 is at odds with previous theories explaining in-context learning as best-match (or nearest neighbor) pattern retrieval, which rely on the softmax nonlinearity [4, 5]. To better understand this phenomenon, let us compare the scaling behavior of two competing mechanistic explanations for in-context learning in Transformers, as we let the input dimension $n_s$ grow: the theory studied here, where a linear model is learned by regularized least-squares, and nonparametric regression under a softmax kernel. The latter is in fact the algorithm implemented by the full (two-layer) induction head mechanism [4, 5]. While we have seen previously that the first token-binding layer of an induction head circuit is precisely what Propositions 1 and 2 require, the subsequent layers differ in the two theories, as we briefly review next.

In basic terms, an induction head predicts the next token by first retrieving the most similar past inputs, and then outputting a similarity-weighted combination of the tokens that appeared afterwards. This yields the next-token prediction $\hat{s}_{t+1}^{\mathrm{n}} = \sum_{t'=1}^{t-1} s_{t'+1}\, \mathrm{softmax}(\beta\, s_{t'}^\top s_t)$.
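The induction-head-style prediction rule above can be written as a short NumPy function. This is a minimal sketch of softmax kernel regression under the stated formula; the function name `softmax_kernel_prediction` and the toy sequence in the usage lines are assumptions introduced for illustration.

```python
import numpy as np

def softmax_kernel_prediction(S, beta):
    """Similarity-weighted average of the tokens that followed past inputs.

    S has shape (T, n_s) and holds s_1, ..., s_T; the query is s_T.
    Weights are a softmax over beta * <s_t', s_T> for t' = 1, ..., T-1.
    """
    X, Y, query = S[:-1], S[1:], S[-1]
    logits = beta * (X @ query)
    weights = np.exp(logits - logits.max())   # subtract max for stability
    weights /= weights.sum()
    return weights @ Y

# Usage: the last token repeats the first one, so a sharp softmax
# retrieves the token that followed the first occurrence.
S_toy = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
sharp = softmax_kernel_prediction(S_toy, beta=50.0)   # close to [0, 1]
flat = softmax_kernel_prediction(S_toy, beta=0.0)     # uniform average
```

With a large `beta` the rule approaches single-pattern retrieval, while `beta = 0` averages all successor tokens uniformly.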
Unlike least-squares mesa-optimizers, an induction head operates in the highly nonlinear regime of the softmax attention function, with the scalar $\beta \in \mathbb{R}_+$ set large enough so as to approximate single-pattern retrieval ($\beta \to \infty$). Thus, with regards to the linearity of the attention function, an induction head and the mesa-optimizers studied here sit on two opposite extremes.

The theory of nonparametric regression has sought to characterize such interpolants, revealing that generalization error scales in general exponentially with input dimension [52]. By contrast, it can be shown analytically in the simpler non-autoregressive case that the generalization error is independent of input dimension for optimally-regularized linear regression [53, 54], assuming that the task difficulty (measured as the context size per dimension $T/n_s$) is conserved, which is the regime we study here. These theoretical considerations are reflected in the experiments with fully-observed linear dynamics reported in Fig. 5B, where we report the scaling of cumulative next-token prediction mean-squared error loss for softmax kernel regression with optimally-tuned $\beta$ (per dimension) against an autoregressive linear model learned by optimally-regularized least-squares (LSQ). We see that next-token prediction performance is always best and only weakly depends on $n_s$ for the latter, whereas it degrades for the former.

The findings presented in Fig. 5B highlight the merits of performing proper latent variable inference under the correct generative model, over applying a generic interpolation algorithm. The degradation of softmax kernel regression reflects the curse of dimensionality [55], here unveiled at the level of in-context learning. One strategy to deal with this problem is to embed the data in an appropriate learned space before applying a nearest-neighbor-type method [56].
For the synthetic autoregressive tasks considered in this paper, the curse of dimensionality can be defeated if base-optimization discovers the multi-layer mesa-optimizer of Proposition 2. Below, we provide further evidence that this actually occurs in trained Transformers.

Figure 6: Evidence for mesa-optimization in standard (softmax) Transformers. (A) Linear probes decode next-token target $s_{t+1}$ from internal Transformer activations, with decoding performance improving with depth (intensity color-coded) and context length, consistent with gradual optimization of an internal next-token prediction model. (B) Likewise for preconditioned input $(S_{t-1} S_{t-1}^\top + 1/\lambda\, I)^{-1} s_t$ probing, consistent with the mesa-optimizer of Proposition 2. (C) Next-token prediction error of a 3-layer and a 7-layer Transformer (light and dark blue lines) decreases with context length in almost exactly the same way as 3 or respectively 7 steps of Proposition 2 (light and dark dashed yellow lines), with hyperparameters of the latter set for best performance, not to match Transformer behavior.

Mesa-optimization theory describes complete Transformers

We continue studying complete Transformers, repeating the analyses carried out for deep linear attention models paired with special input tokens, as required by Propositions 1 and 2. Moreover, we examine all three task types (linear systems with either full or partial observability, as well as nonlinear systems), now always using conventional input token formatting ($e_t = s_t$).

Figure 7: Comparison of the next-token prediction error of 7-layer softmax Transformers (blue lines) and 2-layer softmax-mesa Transformers (red lines) on three families of tasks: fully-observed linear systems (A), partially-observed linear systems (B), and nonlinear systems (C).
To validate the mesa-optimization theory developed here, we also report the performance achieved after applying 7 steps of the mesa-optimizer of Proposition 2 to learn the parameters of a linear model (Proposition-2-linear; yellow lines). For partially-observed and nonlinear tasks, we further report the loss achieved when Proposition 2 is used to train a linear model applied to the groundtruth feature transformation, given by an optimal number of concatenated past tokens to resolve partial observability, or the $\mathrm{MLP}^*(s_t)$ used by the nonlinear sequence generator, respectively (Proposition-2-nonlinear; light blue lines). These two control models accurately describe the behavior of actual trained standard Transformers. Moreover and also in accordance with the theory developed here, the hybrid-mesa architecture serves as a strong baseline for all three tasks.

In short, our main findings on simplified linear attention-only models translate to standard Transformers. We have already seen in Fig. 2 that these models learn appropriate MLP basis functions when faced with nonlinear tasks, and that they construct internal training sets by binding tokens together. Repeating our probing analysis, we now confirm that subsequent layers execute an algorithm that simultaneously improves next-token predictions and mesa-optimization conditioning (Fig. 6), as was the case for linear attention-only models.

In terms of next-token prediction performance, we see that $k$ steps of Proposition 2 can essentially describe the performance of $k$-attention-layer Transformers trained on all three task types considered here (Figs. 6C and 7), once again in line with our previous findings on simplified linear attention-only models.
Moreover, we find that a hybrid two-attention-layer architecture, stacking one mesa-layer after a standard softmax attention layer, is the strongest of all models considered here despite its low parameter count and shallow depth. This hybrid architecture design is directly inspired by our mesa-optimization theory, leveraging the fact that softmax attention layers can easily implement a token binding operation, and that mesa-layers implement efficient in-context least-squares solvers. The fact that a fixed-depth, 2-layer softmax-mesa Transformer provides a performance upper bound, approached as the depth of standard softmax Transformers increases, is additional evidence that such models are well described by the mesa-optimization theory developed here.

Autoregressive Transformers are few-shot learners

Brown et al. [2] established in-context learning in large autoregressive language models, showing that LLMs can solve new tasks when provided with a small number of (‘few-shot’) labeled examples in-context. Here, we investigate whether a similar phenomenon occurs in the autoregressive models studied thus far. To that end, we take the Transformers analyzed above and present them post-training with in-context linear regression tasks (cf. Materials and Methods). Despite the fact that the models were trained to predict the evolution of linear dynamical systems, and not to perform supervised in-context learning, we observe that regression loss decreases with sequence length (Fig. 8A). The models can thus use additional in-context training data to improve predictions. Our results therefore show that training Transformers on simple autoregressive tasks can give rise to in-context few-shot learning, complementing previous evidence for this phenomenon in large-scale models [2]. As a control, we report the performance reached by autoregressive least-squares on the same dataset, which yields a similar error curve.
Figure 8: Autoregressive Transformers display in-context few-shot learning capabilities. After training a standard 7-layer Transformer on autoregressive sequence prediction problems, we measure its ability to solve linear regression tasks in-context, without further parameter fine-tuning. The task training set is presented to the model in sequence, with each token corresponding either to an input or to its corresponding label. A final test input is provided and the loss is measured after completing the sequence using the autoregressive Transformer. (A) The mesa-optimizers installed by autoregressive pretraining can be leveraged off-the-shelf to solve in-context supervised regression tasks, but yield sub-optimal regression performance (lightest blue lines). In-context learning performance can be improved following the standard strategies of prompt (TF+EOS, light blue lines) and prefix fine-tuning (TF+EOS+P, dark blue lines). For comparison, we provide the loss achieved by an autoregressive linear model learned by least-squares (LSQ, yellow lines). (B) Same analysis, now presenting two tasks in a row. The autoregressive models develop some in-context continual learning capabilities.

We note that the autoregressive in-context learning algorithm uncovered above is sub-optimal with respect to linear regression. Close inspection reveals that the origin of its sub-optimality lies in the learned token binding mechanism (analyzed in Fig. 2) that binds every consecutive pair of tokens, in an overlapping fashion. In a training set of size $n$, this introduces $n - 1$ spurious associations, where a regression target $y_i$ is incorrectly associated to the next independent input $x_{i+1}$, whereas only inputs $x_i$ should be associated with their respective targets $y_i$.
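The count of spurious associations can be checked with a small pure-Python sketch. It only enumerates the token pairs produced by overlapping consecutive binding on an interleaved sequence $[x_1, y_1, \ldots, x_n, y_n]$; the function names and the symbolic token representation are hypothetical and chosen for illustration.

```python
def consecutive_pairs(tokens):
    """Bind every token with its successor, in an overlapping fashion."""
    return list(zip(tokens[:-1], tokens[1:]))

def count_associations(n):
    """Return (correct, spurious) pair counts for n interleaved examples.

    Tokens are symbolic tuples ("x", i) or ("y", i). A pair is correct when
    it links x_i to y_i, and spurious when it links y_i to x_{i+1}.
    """
    tokens = []
    for i in range(1, n + 1):
        tokens += [("x", i), ("y", i)]
    pairs = consecutive_pairs(tokens)
    correct = sum(1 for a, b in pairs
                  if a[0] == "x" and b[0] == "y" and a[1] == b[1])
    spurious = sum(1 for a, b in pairs if a[0] == "y" and b[0] == "x")
    return correct, spurious

# Usage: five training examples give five correct and four spurious pairs.
counts = count_associations(5)
```

For any $n$, the sequence has $2n$ tokens and hence $2n - 1$ consecutive pairs, of which $n$ are correct and $n - 1$ are spurious.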
Interestingly, these spurious associations give rise not only to convergence to a sub-optimal solution, but also to the early ascent phenomenon present in LLMs [57]. This refers to in-context learning performance first undergoing a brief but statistically significant period of loss increase, before actual improvements start taking place. Note that early ascent is not specific to autoregressive Transformers; we can observe it on the autoregressive linear least-squares control model as well (LSQ; Fig. 8A). We therefore identify one cause for this poorly-understood phenomenon, tracing it back to the internal mechanics of mesa-optimization for next-token prediction.

To mitigate this effect, we investigate a common approach, known as prompt-tuning, which can lead to significant performance improvements when applied to large language models [58, 59]. Concretely, we fine-tune a single token, which we refer to as the EOS token, on the linear regression objective. When presenting data sequentially as $[x_1, y_1, \mathrm{EOS}, x_2, y_2, \ldots, \mathrm{EOS}, x_N, y_N]$, where $x_i$ and $y_i$ resp. denote regression inputs and labels, we observe a considerable performance improvement after prompt-tuning, see Fig. 8A. Furthermore, to instruct the model to perform few-shot tasks, we learn a single prefix-prompt P which we prepend to a sequence with EOS tokens. This appears to further improve the few-shot performance for early data-pairs. Additional experimental details can be found in Materials and Methods.

Lastly, we demonstrate the capability of autoregressive Transformers to learn multiple tasks in a row.
We study the minimal setup where the model has to learn two tasks, generated from two distinct groundtruth linear models, resulting in a sequence of data of the form $[x^1_1, y^1_1, \ldots, x^1_N, y^1_N, x^2_1, y^2_1, \ldots, x^2_N, y^2_N]$. In Fig. 8B, we see that the trained Transformer can learn a second task in-context, even though it was never explicitly trained to solve such sequential learning problems. This behavior is expected, given the autoregressive linear model optimizer uncovered in the preceding sections. This finding suggests further characterizing the continual in-context learning abilities of Transformers, as Irie et al. [60] have begun to investigate.

Discussion

We have presented evidence that Transformer models develop gradient-based learning algorithms when trained on sequence prediction tasks under a standard autoregressive objective. Moreover, we have seen that the resulting prediction algorithms can be repurposed without retraining to solve supervised in-context learning tasks, capturing LLM phenomena such as early ascent or the effectiveness of prompt fine-tuning techniques in improving in-context learning. The fact that we were able to reproduce these findings in our synthetic data setup is surprising, given that the state-space sequence generators studied here are far from language models: most notably, they operate in continuous space, and lack deep hierarchical structure. Our results serve as a case in point that autoregressive Transformers can exhibit in-context learning capabilities outside language modeling, and point towards the universality of certain properties of these acquired learning algorithms.
There has been significant debate on whether LLMs, and learned next-token predictors more generally, are limited to memorizing correlations present in the training set [having been called stochastic parrots; 61]. This view has been challenged by a number of studies, analyzing for example autoregressive models trained to predict legal moves in board games [62, 63, 64]. In a purely observational manner and without any a priori game knowledge, self-supervised next-token prediction models learn latent representations of the board state and track the moves of each opponent. Our findings provide complementary evidence that next-token prediction objectives can lead to the discovery of algorithms that correctly infer the hidden state of the world: the in-context learning algorithm we identified can be precisely cast as maximum a posteriori inference under the correct Bayesian prior and likelihood function. Moreover, the multi-layer mesa-optimizers installed by next-token prediction objectives are highly efficient (i.e., achieve significant loss reduction in only a few layers) thanks to precise tuning of their hyperparameters to the sequence generative model. The idea that a Transformer generates its predictions by solving internal optimization problems has ties to many different lines of thought in machine learning. One closely related line of work explores the concept of a declarative node: a differentiable layer whose output is defined implicitly as the solution of an optimization problem [65, 66, 67]. We note that subsuming an entire chain of layers by a single declarative node is not only potentially more efficient, but also more interpretable. The mesa-layer is an example of such a node, adding to recent studies exploring the advantages of including declarative nodes within attention-based models [68, 69, 38, 70]. 
Our analysis of trained models revealed that stochastic gradient descent in effect discovered a declarative node, preferring to pick an optimization algorithm among alternative solutions in the configuration space of autoregressive Transformers. This can be partly explained by the fact that recursive least-squares can be leveraged to solve the tasks considered here, and by the fact that Transformers can efficiently approximate this algorithm through Proposition 2. Our results complement the theoretical work of Hubinger et al. [27], by providing a concrete toy model where mesa-optimization occurs. However, more work is still needed to characterize this phenomenon outside the controlled experimental setting considered in this paper.

The mesa-layer developed here can also be seen as a locally optimal fast weight programmer [29]. In his seminal work [29], Schmidhuber proposed to dynamically reprogram the weights of a feedforward neural network using a Hebbian rule. As pointed out by Schlag et al. [33], this is precisely what a linear self-attention layer does: it generates predictions using an effective weight matrix that is learned by taking outer products of values and keys, a Hebbian associative rule [71]. In this work, we instead frame fast weight learning as an optimization problem that is efficiently solved at every moment in time by the mesa-layer. This form of optimal fast learning is strictly superior to Hebb’s rule, both in terms of generalization and memory capacity [72]. The mesa-layer is therefore also closely related to the Delta-Net [33], which uses the delta rule [73] for fast weight learning. Unlike the mesa-layer, which is optimal at every step, the delta rule requires multiple steps to converge, though it is cheaper to implement. The strong performance of the mesa-layer observed here on synthetic tasks suggests investigating its application to natural data at larger scales, for which we provide preliminary language modeling results in Appendix G.
Our work has an unexpected connection to research on local learning rules, a question of great interest in theoretical neuroscience [74]. Decomposing a global learning problem into a series of local quadratic optimization problems, like the objective functions of the mesa-optimizers studied here, is at the heart of the target propagation [75], predictive coding [76] and control-based [77] theories of learning in the brain. Moreover, previous studies have proposed greedy layerwise learning algorithms that do not require global error information [78, 79, 80, 81, 82]. Much in the same vein but now on the fast timescale of inference, the mesa-optimizers uncovered here implement greedy, local learning algorithms which only use bottom-up information.

We conclude by discussing our findings in the light of predictive processing theories of intelligence, where learning predictive models is presumed to underwrite intelligent behavior [83, 84]. A number of influential predictive processing models have adopted a Bayesian approach, starting from the assumption that the world obeys a certain generative model, and then hand-designing approximate inference algorithms for the assumed model [85, 86, 87, 88, 89]. Here, directly inspired by LLMs, we took a powerful neural sequence model and trained it to maximize the likelihood of upcoming inputs given the past, without making explicit probabilistic assumptions about the latent structure of the world. The network was nonetheless able to discover the correct underlying model of the data, and appropriately exploit its knowledge to generate predictions. This finding provides further evidence that direct maximization of future prediction performance by simple gradient-based methods (as opposed to hierarchical probabilistic methods, and the typically intractable inference problems that they bring) might be sufficient to build the predictive processing backbone of an intelligent system.
Methods

Transformer architectures

The Transformer models studied here follow the widely-used GPT-2 specification [90]. This architecture comprises multiple identical blocks, with one block consisting of the softmax self-attention layer defined in equation 2 followed by a one-hidden-layer MLP. The inputs of both layers are normalized:

$$
\begin{aligned}
e_t &\leftarrow e_t + \Delta e_t^{\text{sa}}(\mathrm{LN}(e_t)) \\
e_t &\leftarrow e_t + \Delta e_t^{\text{mlp}}(\mathrm{LN}(e_t)),
\end{aligned}
$$

where $\mathrm{LN}(\cdot)$ denotes the LayerNorm operation [28], and $\Delta e_t^{\text{mlp}}(e_t) = W_2\,\mathrm{GELU}(W_1 e_t)$ with $\mathrm{GELU}(e) := e\,G(e)$, and $G(\cdot)$ the Gaussian cumulative distribution function, applied elementwise [91]. We set $W_1$ such that $W_1 e$ has four times more neurons compared to $e$, which itself is four times larger than $s$. Additional architectural details are provided in the SI Appendix. The predictions are read out directly from the first dimensions of last-layer token outputs, and we add a positional encoding to every input following the original method of Vaswani et al. [3]. If not explicitly stated otherwise, for models that incorporate the mesa-layer, we leave the architecture configuration unchanged but replace $\Delta e_t^{\text{sa}}$ with $\Delta e_t^{\text{mesa}}$ in the appropriate places. A hybrid-mesa Transformer features two self-attention layers, with the first being standard softmax self-attention, and the second a mesa-layer.

Base optimizers

All models are trained by online autoregressive loss (Eq. 1) minimization using the AdamW [92] optimizer with learning rate warm-up followed by a cosine decay.

Statistics

All numerical results are averaged across five random seeds, with shaded areas representing standard deviation.

Synthetic sequence generators

The tasks considered in this paper involve predicting the next observation $s_{t+1} \in \mathbb{R}^{n_s}$ from a sequence of past observations $(s_{t'})_{t'=1}^{t}$ generated by discrete-time dynamical systems, whose state is denoted by $h_t \in \mathbb{R}^{n_h}$. Starting from a random initial state $h_1 \sim \mathcal{N}(0, 1)$, we generate observations by letting a groundtruth system evolve according to

$$
\begin{aligned}
h_{t+1} &= W^* f^*(h_t) + \epsilon_{h,t} \\
s_t &= C^* h_t + \epsilon_{s,t},
\end{aligned}
$$

where $\epsilon_{h,t} \sim \mathcal{N}(0, \sigma_h^2)$ is a noise input and $\epsilon_{s,t} \sim \mathcal{N}(0, \sigma_s^2)$ is an observation noise term. We set the transition matrix $W^* \in \mathbb{R}^{n_h \times n_h}$ to a random orthogonal matrix, and we consider both fully-observed ($C^* = I$) and partially-observed tasks, where $n_s < n_h$, and $C^*_{ij} \sim \mathcal{N}(0, 0.5)$. Our tasks can be further categorized as linear (by setting $f^*$ to the identity function) or nonlinear.
For the nonlinear case, we always take $C^* = I$ and introduce a nonlinear transformation $\mathrm{MLP}^*(\cdot)$ in state-space, $h_{t+1} = W^*\,\mathrm{MLP}^*(h_t) + \epsilon_t$. The MLP computation is described by $\mathrm{MLP}^*(h_t) = B \cdot \mathrm{GELU}(A \cdot h_t)$, where $A \in \mathbb{R}^{n_m \times n_h} \sim \mathcal{N}(0, 1.1)$ and $B \in \mathbb{R}^{n_h \times n_m} \sim \mathcal{N}(0, 1.1)$.

Importantly, we draw new transition and readout matrices $W^*$ and $C^*$ for every sequence. These parameters are analogous to task-specific variables in multi-task learning [93], adapted to the problem of unsupervised sequence modeling. We introduce sequence-specific variables to reflect the high degree of variability that is observed in large datasets of real-world data, such as in LLM training corpora [90]. Under such a generative model, rote memorization solutions are excluded from the global minimizers of Eq. 1: a trained Transformer cannot achieve minimal loss by memorizing a single set of $W^*$ and $C^*$ in its parameters $\theta$. Instead, it must deal with inherent uncertainty in every sequence, and infer in-context a set of latent variables whose values vary from sequence to sequence. The main goal of this paper is to characterize this in-context inference process.
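The linear variant of the generative model above (with $f^*$ the identity) can be sketched in a few lines. This is a minimal illustration using only the Python standard library, not the authors' code: the function names `random_orthogonal` and `sample_sequence`, the Gram-Schmidt construction of the orthogonal matrix, the default noise levels, and the reading of $\mathcal{N}(0, 0.5)$ as a variance of 0.5 are all assumptions made here.

```python
import math
import random


def random_orthogonal(n, rng):
    """Orthonormalize Gaussian vectors with Gram-Schmidt.

    One possible way to draw a random orthogonal matrix (an assumption of
    this sketch; the sampling scheme is not specified in the text).
    """
    rows = []
    while len(rows) < n:
        v = [rng.gauss(0.0, 1.0) for _ in range(n)]
        for u in rows:  # remove the components along previously accepted rows
            proj = sum(a * b for a, b in zip(v, u))
            v = [a - proj * b for a, b in zip(v, u)]
        norm = math.sqrt(sum(a * a for a in v))
        if norm > 1e-8:
            rows.append([a / norm for a in v])
    return rows


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def sample_sequence(T, n_h, n_s=None, sigma_h=0.0, sigma_s=0.0, seed=0):
    """Draw fresh W*, C* and return (W, C, [s_1, ..., s_T]) for f* = identity."""
    rng = random.Random(seed)
    W = random_orthogonal(n_h, rng)
    if n_s is None:  # fully observed: C* = I
        C = [[1.0 if i == j else 0.0 for j in range(n_h)] for i in range(n_h)]
    else:  # partially observed: Gaussian readout (variance 0.5 assumed)
        C = [[rng.gauss(0.0, math.sqrt(0.5)) for _ in range(n_h)]
             for _ in range(n_s)]
    h = [rng.gauss(0.0, 1.0) for _ in range(n_h)]  # h_1 ~ N(0, 1)
    seq = []
    for _ in range(T):
        seq.append([c + rng.gauss(0.0, sigma_s) for c in matvec(C, h)])
        h = [w + rng.gauss(0.0, sigma_h) for w in matvec(W, h)]
    return W, C, seq
```

Calling `sample_sequence` with a different `seed` draws a new $W^*$ and $C^*$, mirroring the per-sequence resampling described above.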
Proof of Proposition 1

Starting with the token construction $e_t = [\Phi_0 s_t, s_t, s_{t-1}]$, we now show that the parameter construction of Proposition 1 induces the following gradient-based change to all tokens in parallel: $e_t \leftarrow [(\Phi_0 - \eta \nabla L_t(\Phi_0)) s_t, s_t, s_{t-1}]$. When plugging the proposed weights into a linear self-attention layer head we obtain

$$
\begin{aligned}
\begin{bmatrix} \Delta\hat{\Phi}_t s_t \\ 0 \\ 0 \end{bmatrix}
&= P W_V \sum_{t'=1}^{t}
\begin{bmatrix} 0 \\ s_{t'} \\ s_{t'-1} \end{bmatrix}
\begin{bmatrix} 0 \\ s_{t'} \\ s_{t'-1} \end{bmatrix}^\top
\begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & I_s & 0 \end{bmatrix}
\begin{bmatrix} 0 \\ s_t \\ s_{t-1} \end{bmatrix} \\
&= P W_V \sum_{t'=1}^{t}
\begin{bmatrix} 0 \\ s_{t'} s_{t'-1}^\top s_t \\ s_{t'-1} s_{t'-1}^\top s_t \end{bmatrix} \\
&= \begin{bmatrix} 0 & \eta I_s & -\eta \Phi_0 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{bmatrix}
\sum_{t'=1}^{t}
\begin{bmatrix} 0 \\ s_{t'} s_{t'-1}^\top s_t \\ s_{t'-1} s_{t'-1}^\top s_t \end{bmatrix} \\
&= -\eta \sum_{t'=1}^{t}
\begin{bmatrix} (\Phi_0 s_{t'-1} - s_{t'}) s_{t'-1}^\top s_t \\ 0 \\ 0 \end{bmatrix}
= \begin{bmatrix} -\eta \nabla L_t(\Phi_0) s_t \\ 0 \\ 0 \end{bmatrix}.
\end{aligned}
$$

Adding the above result to the layer input, an operation that is supported in Transformers by a residual connection or a second attention head, yields the desired output.

Full statement and proof of Proposition 2

We present here the linear self-attention parameter construction which supports the claim of Proposition 2. First, we restate the goal of the autoregressive Transformer, namely, to solve a regularized least-squares problem:

$$
\min_{\Phi} \sum_{t'=1}^{t-1} \frac{1}{2} \| s_{t'+1} - \Phi s_{t'} \|^2 + \frac{1}{2\lambda} \| \Phi \|_F^2,
$$

for all time steps simultaneously. This amounts to computing a (recursive) least-squares solution, where time-shifted (by one) sequence elements play the role of inputs and desired outputs in a dataset, with inputs $S_{t-1}$, targets $S_t$, and test input $s_t$. With the limited expressivity of one layer, we have already established that Transformers can, and do, implement a single gradient step on the corresponding regression problems $\sum_{t'=1}^{t-1} \| s_{t'+1} - \Phi s_{t'} \|^2 \;\forall t$ in parallel, both in theory and in practice.
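The single-gradient-step claim can be checked numerically on a toy problem. The sketch below, written with the Python standard library only, builds the block matrices from the proof of Proposition 1 for $n_s = 2$ and compares the causally masked linear self-attention output with $-\eta \nabla L_t(\Phi_0) s_t$ computed directly, taking $\nabla L_t(\Phi_0) = \sum_{t'=1}^{t} (\Phi_0 s_{t'-1} - s_{t'}) s_{t'-1}^\top$. The dimensions, step size, random data, and all function names are illustrative assumptions.

```python
import random

n, T, eta = 2, 6, 0.1
rng = random.Random(0)
s = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(T + 1)]  # s_0..s_T
Phi0 = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def token(t):
    """e_t = [Phi0 s_t, s_t, s_{t-1}] as a flat list of length 3n."""
    return matvec(Phi0, s[t]) + s[t] + s[t - 1]


# Block weight matrices (3n x 3n) following the proof.
PWv = [[0.0] * (3 * n) for _ in range(3 * n)]
WkTWq = [[0.0] * (3 * n) for _ in range(3 * n)]
for i in range(n):
    PWv[i][n + i] = eta                        # eta * I_s acting on s_{t'}
    for j in range(n):
        PWv[i][2 * n + j] = -eta * Phi0[i][j]  # -eta * Phi0 acting on s_{t'-1}
    WkTWq[2 * n + i][n + i] = 1.0              # yields the score s_{t'-1}^T s_t


def attention_update(t):
    """Causally masked linear self-attention output for token t."""
    q = matvec(WkTWq, token(t))
    out = [0.0] * (3 * n)
    for tp in range(1, t + 1):
        e = token(tp)
        score = dot(e, q)
        out = [o + score * v for o, v in zip(out, matvec(PWv, e))]
    return out


def gradient_step_update(t):
    """-eta * grad L_t(Phi0) s_t computed directly from the definition."""
    out = [0.0] * n
    for tp in range(1, t + 1):
        err = [a - b for a, b in zip(matvec(Phi0, s[tp - 1]), s[tp])]
        score = dot(s[tp - 1], s[t])
        out = [o - eta * score * e for o, e in zip(out, err)]
    return out
```

For every `t`, the first `n` entries of `attention_update(t)` agree with `gradient_step_update(t)` up to floating-point error, and the remaining entries are zero, so adding the result to the token through the residual connection only changes its first block.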
The key observation here is that given a preconditioning matrix $H^*_t = (S_{t-1} S_{t-1}^\top + \frac{1}{\lambda} I)^{-1}$ which changes the loss to $\sum_{t'=1}^{t-1} \| s_{t'+1} - \Phi H^*_t s_{t'} \|^2$, a single gradient descent step immediately yields the desired regularized least-squares solution. Based on this simple observation, we provide a theoretical construction that shows how Transformers can approximate $(S_{t-1} S_{t-1}^\top + \frac{1}{\lambda} I)^{-1} s_t$ layer by layer in their forward pass, leading to improved single-step gradient descent performance. Note that this is equivalent to iteratively solving the systems of linear equations $\{(S_{t'-1} S_{t'-1}^\top + \frac{1}{\lambda} I) x = s_{t'}\}_{t'=1}^{t}$.
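To make the role of the preconditioner concrete, the following worked check uses two-dimensional observations so that the matrix inverse has a closed form. It follows one reading of the statement above, which is an interpretation rather than the authors' exact construction: the ridge prediction $\Phi^* s_t$, with $\Phi^* = (\sum_{t'} s_{t'+1} s_{t'}^\top) H^*_t$, coincides with the prediction of a single gradient step from $\Phi_0 = 0$ with $\eta = 1$ on the unregularized loss, evaluated at the preconditioned query $H^*_t s_t$. The data, the value of $\lambda$, and all names are illustrative assumptions.

```python
import random

lam = 2.0
rng = random.Random(0)
s = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(8)]  # s_1..s_t
inputs, targets, query = s[:-1], s[1:], s[-1]


def outer_sum(ys, xs):
    """sum_i ys[i] xs[i]^T for two-dimensional vectors."""
    return [[sum(y[i] * x[j] for y, x in zip(ys, xs)) for j in range(2)]
            for i in range(2)]


def inv2(M):
    """Closed-form inverse of a 2 x 2 matrix."""
    (a, b), (c, d) = M
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


# A = S_{t-1} S_{t-1}^T + (1/lambda) I and its inverse H (the preconditioner).
A = outer_sum(inputs, inputs)
A[0][0] += 1.0 / lam
A[1][1] += 1.0 / lam
H = inv2(A)

# One gradient step from Phi = 0 with eta = 1 on 1/2 sum ||s_{t'+1} - Phi s_{t'}||^2.
Phi_one_step = outer_sum(targets, inputs)

# Closed-form ridge solution and the two ways of forming the prediction.
Phi_ridge = matmul(Phi_one_step, H)
pred_ridge = matvec(Phi_ridge, query)
pred_preconditioned = matvec(Phi_one_step, matvec(H, query))
```

Both predictions agree because $H^*_t$ can be applied either to the weights or to the query; the second form is the one that only requires approximating $H^*_t s_t$.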
Let us now approximate the above expression with a truncated Neumann series:

$$
\begin{aligned}
H^*_t s_t \approx \tilde{s}^{K+1}_t
&= \sum_{k=0}^{K} \left( I - \left( S_{t-1} S_{t-1}^\top + \tfrac{1}{\lambda} I \right) \right)^k s_t \\
&= \sum_{k=0}^{K} \left( \left(1 - \tfrac{1}{\lambda}\right) I - S_{t-1} S_{t-1}^\top \right)^k s_t \\
&= \tilde{s}^{K}_t + \left( \left(1 - \tfrac{1}{\lambda}\right) I - S_{t-1} S_{t-1}^\top \right) \tilde{s}^{K}_t
= \tilde{s}^{K}_t + \tilde{H}^*_t \tilde{s}^{K}_t
\end{aligned}
$$

with $\tilde{H}^*_t := \left( \left(1 - \tfrac{1}{\lambda}\right) I - S_{t-1} S_{t-1}^\top \right)$. This corresponds to the Richardson iteration [94] method for solving linear systems iteratively, which can be augmented with a stepwise parameter (or learning rate) $\alpha_K$ and an additional term adding the difference between former approximations, resembling a momentum term. This variant is termed second-order Richardson or Chebyshev [95] iteration, and it can speed up convergence:

$$
\tilde{s}^{K+1}_t = \tilde{s}^{K}_t - \alpha_K \tilde{H}^*_t \tilde{s}^{K}_t - \beta_K \left( \tilde{s}^{K}_t - \tilde{s}^{K-1}_t \right). \qquad (9)
$$

We now show that a single step of these iteration methods can be mapped to a single layer of linear self-attention, allowing deep Transformers to solve the aforementioned set of linear equations efficiently in parallel. Starting with a token construction similar to the one of Proposition 1, i.e., with aggregate tokens $[\tilde{s}_{t'}^{K}, \tilde{s}_{t'}^{K-1}, s_{t'-1}]$ with $\tilde{s}_t^{0} = s_t$, we can compute $\tilde{s}_t^{K+1}$ with a single causally masked linear self-attention layer, in parallel for all $t$.
Indeed, with

$$
W_k^\top W_q = \begin{bmatrix} 0 & 0 & -\alpha_K I_s \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{bmatrix}
\quad \text{and} \quad
P W_v = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ -\alpha_K I_s & 0 & 0 \end{bmatrix}
$$

the linear self-attention equation, similar to the derivation above, results in

$$
P W_v \sum_{t'=1}^{t}
\begin{bmatrix} \tilde{s}_{t'}^{K} \\ \tilde{s}_{t'}^{K-1} \\ s_{t'-1} \end{bmatrix}
\begin{bmatrix} \tilde{s}_{t'}^{K} \\ \tilde{s}_{t'}^{K-1} \\ s_{t'-1} \end{bmatrix}^\top
W_k^\top W_q
\begin{bmatrix} \tilde{s}_{t}^{K} \\ \tilde{s}_{t}^{K-1} \\ s_{t-1} \end{bmatrix}
= \begin{bmatrix} -\alpha_K S_{t-1} S_{t-1}^\top \tilde{s}_t^{K} \\ 0 \\ 0 \end{bmatrix}.
$$

Therefore, the matrix-matrix-vector products needed to compute equation 9 can be computed inside a single linear self-attention layer in parallel, for all time steps. The remaining terms in equation 9 are simple scaled additions of $\tilde{s}_t^{K}, \tilde{s}_t^{K+1}$ for which multiple alternative constructions exist. Note that for the construction above to hold, we need to have $s_{t-1}$ available at every layer and push forward $\tilde{s}_t^{K}$ such that it can be used to compute $(\tilde{s}_t^{K+1} - \tilde{s}_t^{K})$ in the next iteration, which again is easy to accomplish within the residual stream. We therefore conclude that deep Transformer models can approximate the solutions of the set of systems of linear equations $\{(S_{t'-1} S_{t'-1}^\top + \frac{1}{\lambda} I) x = s_{t'}\}_{t'=1}^{t}$ efficiently in parallel.
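The iterative solve that each layer is argued to perform can be illustrated outside of any Transformer. The sketch below runs the Richardson update in its standard textbook form, $x \leftarrow x + \alpha (b - A x) + \beta (x - x_{\text{prev}})$, on $A = S S^\top + \frac{1}{\lambda} I$, starting from $x = b$ (the analogue of $\tilde{s}_t^0 = s_t$). With $\alpha = 1$ and $\beta = 0$ this reproduces the partial sums of the Neumann series. The sign convention of the momentum term, the step size $\alpha = 1/\mathrm{trace}(A)$, the constant $\beta$, and all names are assumptions of this sketch and are not taken from equation 9.

```python
import random


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def richardson(A, b, alpha, steps, beta=0.0):
    """Iterate x <- x + alpha (b - A x) + beta (x - x_prev), starting at x = b.

    Each step plays the role of one layer in the construction discussed above.
    """
    x_prev = x = list(b)
    for _ in range(steps):
        r = [bi - ai for bi, ai in zip(b, matvec(A, x))]  # residual b - A x
        x_new = [xi + alpha * ri + beta * (xi - xp)
                 for xi, ri, xp in zip(x, r, x_prev)]
        x_prev, x = x, x_new
    return x


def residual_norm(A, b, x):
    """Largest absolute entry of b - A x."""
    return max(abs(bi - ai) for bi, ai in zip(b, matvec(A, x)))


lam, n = 1.0, 3
rng = random.Random(0)
S = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(5)]  # past observations
A = [[sum(v[i] * v[j] for v in S) + (1.0 / lam if i == j else 0.0)
      for j in range(n)] for i in range(n)]
b = [rng.gauss(0.0, 1.0) for _ in range(n)]  # plays the role of s_t

# 1/trace(A) is a safe step size because trace(A) bounds the largest eigenvalue.
alpha = 1.0 / sum(A[i][i] for i in range(n))

x_few = richardson(A, b, alpha, steps=3)
x_many = richardson(A, b, alpha, steps=2000)
x_momentum = richardson(A, b, alpha, steps=2000, beta=0.5)
```

Comparing `residual_norm(A, b, x_few)` with `residual_norm(A, b, x_many)` shows how the approximation of $H^*_t s_t$ improves with the number of steps, which is the depth of the network in the construction.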
This results in a preconditioning of the least-squares problems $\left\{ \sum_{t''=1}^{t'-1} \| s_{t''+1} - \Phi H^*_{t'} s_{t''} \|^2 \right\}_{t'=1}^{t}$, which can then be solved with a single gradient step, again in parallel and by a single additional linear self-attention layer, built after Proposition 1.

Mesa-optimizers solve partially-observed linear tasks

We now show that Propositions 1 and 2 can be leveraged to solve next-token prediction problems involving linear latent variable dynamics, as in our experiments with partially-observed linear dynamical systems. We analyze here the deterministic setting, i.e., when no noise is added to the state transitions and observations; for an extension to the stochastic case, see the SI Appendix. We investigate a simple construction where we concatenate the last $k$ observations into a single 'state' vector $z$, and use this state vector in a least-squares problem to estimate the linear transition between $z_{t+1}$ and $z_t$. As $z_{t+1}$ contains $s_{t+1}$, this state prediction can be used straightforwardly to predict the next observation. Let us define

$$
z_t^k = \begin{bmatrix} s_{t-k+1} \\ \vdots \\ s_t \end{bmatrix}.
$$

We first investigate whether the transition between $z^k_t$ and $z^k_{t+1}$ is a linear operator.
For this, let us define the observation matrix as

$$
\mathcal{O}_k = \begin{bmatrix} C^* \\ C^* W^* \\ \vdots \\ C^* {W^*}^{k-1} \end{bmatrix}.
$$

Now we have that $z_t^k = \mathcal{O}_k h_{t-k+1}$ and $z_{t+1}^k = \mathcal{O}_k W^* h_{t-k+1}$. We want to find a matrix $\Phi_k$ such that $z_{t+1}^k = \Phi_k z_t^k$. As this should hold for all possible system initializations and hence $h_{t-k+1}$, we have that $\mathcal{O}_k W^* = \Phi_k \mathcal{O}_k$. If $k n_s \geq n_h$, we have an underdetermined or fully-determined (in case of equality) set of linear equations, assuming no rank-deficient matrices. The minimum-norm solution for $\Phi_k$ is given by

$$
\Phi_k = \mathcal{O}_k W^* \mathcal{O}_k^\dagger,
$$

with $\mathcal{O}_k^\dagger$ the Moore-Penrose pseudoinverse of $\mathcal{O}_k$. If the dimension of the concatenated observations $z^k$ is smaller than the dimension of the groundtruth state $h$ ($k n_s < n_h$), the linear system is overdetermined and in general there does not exist a solution for $\Phi_k$.
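A small worked example may help here. The sketch below takes the fully-determined case $k n_s = n_h$ with $n_h = 2$, $n_s = 1$ and $k = 2$, where $\mathcal{O}_k$ is square and the pseudoinverse reduces to the ordinary inverse, so only the Python standard library is needed. The rotation angle, the readout vector, the initial state, and all names are illustrative assumptions (in the experiments $C^*$ is random).

```python
import math


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]


def inv2(M):
    """Closed-form inverse of a 2 x 2 matrix."""
    (a, b), (c, d) = M
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


theta = 0.7
W = [[math.cos(theta), -math.sin(theta)],
     [math.sin(theta), math.cos(theta)]]  # orthogonal transition matrix W*
C = [[1.0, 0.5]]                          # 1 x 2 readout C*, so n_s = 1 < n_h = 2

# Observation matrix O_2 = [C*; C* W*]; square here, hence invertible in general.
O = [C[0], matmul(C, W)[0]]
Phi = matmul(matmul(O, W), inv2(O))       # Phi_k = O_k W* O_k^{-1}

# Noise-free rollout of the latent system and its scalar observations.
h = [1.0, -2.0]
obs = []
for _ in range(8):
    obs.append(matvec(C, h)[0])
    h = matvec(W, h)

# Delay-embedded states z_t = [s_{t-1}, s_t].
z = [[obs[t - 1], obs[t]] for t in range(1, len(obs))]
```

In this example `matvec(Phi, z[i])` reproduces `z[i + 1]` up to floating-point error, even though the latent state is never observed directly. Here `Phi` has the companion form with rows `[0, 1]` and `[-1, 2 cos(theta)]`, reflecting that the scalar observations of a rotation obey a second-order linear recurrence.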
Hence, in order to make optimal predictions, we need to concatenate enough observations into $z_t^k$ such that $k n_s \geq n_h$. As there exists a linear map between $z_{t+1}^k$ and $z_t^k$, and $z_{t+1}^k$ can be used directly to predict $s_{t+1}$, a Transformer can solve the least-squares problem in-context on $z_{t+1}^k$. One possible implementation is the following construction: (i) copy the last $k$ observations into a concatenated state vector $z_t^k$; (ii) format tokens as required by Propositions 1 and 2, now with $z_t^k$ instead of $h_t$, which can be done by the same self-attention layer as the first step; (iii) solve the mesa-optimization problem by directly leveraging the aforementioned propositions.

CompressedAlg-d

After training a single- or multi-layer linear attention model, we obtain structured matrix products $W_K^\top W_Q, P W_v$ per head and layer. When inspecting the trained weight matrix products, we observe strong block-diagonal structure across all layers. We extract the mean values of these block-diagonals and construct sparse weight matrices, consisting only of identity sub-matrices scaled by the respective mean value, and evaluate this constructed compressed algorithm on test sequences. Then, during a second training run (for the same initial conditions), we compute the test loss achieved by a control model with interpolated parameters, obtained by averaging (with equal averaging weight) the compressed per-head weight-matrix products and the actual trained layer parameters.

Probing analyses

In Figs. 2, 4 and 6 we show the performance of linear decoders trained to predict certain features (e.g., a given past input token $e_{t'}$, in Fig. 2A) from internal model activations at various depths, time steps, and stages of training. For every such probing experiment (i.e., for each layer, context length, or base training step, depending on the analysis at hand) we train a separate linear decoder on a batch of activations to predict the respective probing targets by mean-squared error minimization (linear regression). For the preconditioning probes, we compute the 6-step Chebyshev approximation of $(S_{t'-1} S_{t'-1}^\top + \frac{1}{\lambda} I)^{-1} s_{t'}$ at each time step $t'$, and linearly regress the activations after each layer at the respective time step against this preconditioning target.

In-context few-shot learning: generative model

To generate a few-shot task, we sample a groundtruth random orthogonal matrix $W^*$ as done during training, but now use this groundtruth model to generate a labeled training set $\{x_i, y_i\}_{i=1}^{N}$, with inputs $x_i \sim \mathcal{N}(0, I_x)$ and targets $y_i = W^* x_i$. We then present this dataset to our autoregressive Transformers as a sequence of tokens, $e^{\text{few-shot}} = [x_1, y_1, \dots, x_N, y_N]$ of length $T = 2N$, cf. Figure 8.
As the sequence unfolds, and more training data is presented, we measure in-context learning performance through the mean squared error between the Transformer output $f_\theta(e_{2i-1}; e^{\text{few-shot}}_{1:2i-1})$ and the corresponding target $y_i = e_{2i}$. We emphasize that both the sequence generative model and loss function differ from the ones used during training; compare the task performance metric

$$
L^{\text{few-shot}} = \frac{1}{2} \sum_{i=1}^{N} \| e_{2i} - f_\theta(e_{2i-1}; e^{\text{few-shot}}_{1:2i-1}) \|^2
$$

used to evaluate in-context learning performance in this section with the actual loss used to train the Transformer, Eq. 1.

As a control, we further report the performance reached by the least-squares solution (LSQ) obtained on the dataset $D_N^{\text{mesa}} = \{(x_i, y_i)\}_{i=1}^{N} \cup \{(y_i, x_{i+1})\}_{i=1}^{N-1}$, and observe a similar decrease in loss after a phase of early ascent. This dataset, where half of the associations consist of wrong input-output pairs $D^{\text{spurious}}_N = \{(y_i, x_{i+1})\}_{i=1}^{N-1}$ as illustrated in Figure 8A, corresponds to the training set an autoregressive Transformer imbued with the mesa-optimizers uncovered in the previous section learns from.
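The few-shot sequence, the metric $L^{\text{few-shot}}$, and the two datasets can be written down compactly. The following is a minimal sketch with two-dimensional inputs and the Python standard library only. It fits least squares once on the full dataset rather than incrementally as the sequence unfolds, and uses a fixed rotation as $W^*$; both are simplifications made here, and all function and variable names are hypothetical.

```python
import math
import random


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def lsq(pairs):
    """Least-squares map for 2-d (input, target) pairs via the normal equations
    Phi = (sum y x^T)(sum x x^T)^{-1}; needs at least two generic inputs."""
    yx = [[sum(y[i] * x[j] for x, y in pairs) for j in range(2)]
          for i in range(2)]
    (a, b), (c, d) = [[sum(x[i] * x[j] for x, _ in pairs) for j in range(2)]
                      for i in range(2)]
    det = a * d - b * c
    inv = [[d / det, -b / det], [-c / det, a / det]]
    return [[sum(yx[i][k] * inv[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


def few_shot_loss(predict, tokens):
    """1/2 sum_i ||e_{2i} - predict(e_{1:2i-1})||^2 over the targets y_i."""
    total = 0.0
    for i in range(1, len(tokens), 2):  # 0-based positions of y_1, ..., y_N
        pred = predict(tokens[:i])      # context ends with the query x_i
        total += 0.5 * sum((p - y) ** 2 for p, y in zip(pred, tokens[i]))
    return total


N, theta = 5, 0.9
rng = random.Random(0)
W = [[math.cos(theta), -math.sin(theta)],
     [math.sin(theta), math.cos(theta)]]  # an orthogonal groundtruth W*
xs = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(N)]
ys = [matvec(W, x) for x in xs]
tokens = [v for pair in zip(xs, ys) for v in pair]  # [x_1, y_1, ..., x_N, y_N]

D_clean = list(zip(xs, ys))              # {(x_i, y_i)}
D_spurious = list(zip(ys[:-1], xs[1:]))  # {(y_i, x_{i+1})}
D_mesa = D_clean + D_spurious

Phi_clean = lsq(D_clean)
Phi_mesa = lsq(D_mesa)
loss_clean = few_shot_loss(lambda ctx: matvec(Phi_clean, ctx[-1]), tokens)
loss_mesa = few_shot_loss(lambda ctx: matvec(Phi_mesa, ctx[-1]), tokens)
```

Fitting on `D_clean` recovers `W` in this noise-free setting, whereas fitting on `D_mesa` is biased by the spurious pairs, which is why `loss_mesa` exceeds `loss_clean`.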
Acknowledgements

João Sacramento and Johannes von Oswald thank Angelika Steger and Jyrki Alakuijala for their support and guidance. The authors also thank Marc Kaufmann, Yassir Akram, Andrey Zhmoginov, Yanick Schimpf, Oliver Sieberling and Luca Versari for fruitful discussions and insights, and Luke Sernau, Maciej Wolczyk, Simon Schug and Robert T. Lange for valuable comments on the manuscript. João Sacramento and Nicolas Zucchet were supported by an Ambizione grant (PZ00P3_186027) from the Swiss National Science Foundation and an ETH Research Grant (ETH-23 21-1).

References

Bommasani et al. [2022] Rishi Bommasani, Drew A. Hudson, Ehsan Adeli, Russ Altman, Simran Arora, Sydney von Arx, Michael S. Bernstein, Jeannette Bohg, Antoine Bosselut, Emma Brunskill, Erik Brynjolfsson, Shyamal Buch, Dallas Card, Rodrigo Castellon, Niladri Chatterji, Annie Chen, Kathleen Creel, Jared Quincy Davis, Dora Demszky, Chris Donahue, Moussa Doumbouya, Esin Durmus, Stefano Ermon, John Etchemendy, Kawin Ethayarajh, Li Fei-Fei, Chelsea Finn, Trevor Gale, Lauren Gillespie, Karan Goel, Noah Goodman, Shelby Grossman, Neel Guha, Tatsunori Hashimoto, Peter Henderson, John Hewitt, Daniel E. Ho, Jenny Hong, Kyle Hsu, Jing Huang, Thomas Icard, Saahil Jain, Dan Jurafsky, Pratyusha Kalluri, Siddharth Karamcheti, Geoff Keeling, Fereshte Khani, Omar Khattab, Pang Wei Koh, Mark Krass, Ranjay Krishna, Rohith Kuditipudi, Ananya Kumar, Faisal Ladhak, Mina Lee, Tony Lee, Jure Leskovec, Isabelle Levent, Xiang Lisa Li, Xuechen Li, Tengyu Ma, Ali Malik, Christopher D.
Manning, Suvir Mirchandani, Eric Mitchell, Zanele Munyikwa, Suraj Nair, Avanika Narayan, Deepak Narayanan, Ben Newman, Allen Nie, Juan Carlos Niebles, Hamed Nilforoshan, Julian Nyarko, Giray Ogut, Laurel Orr, Isabel Papadimitriou, Joon Sung Park, Chris Piech, Eva Portelance, Christopher Potts, Aditi Raghunathan, Rob Reich, Hongyu Ren, Frieda Rong, Yusuf Roohani, Camilo Ruiz, Jack Ryan, Christopher Ré, Dorsa Sadigh, Shiori Sagawa, Keshav Santhanam, Andy Shih, Krishnan Srinivasan, Alex Tamkin, Rohan Taori, Armin W. Thomas, Florian Tramèr, Rose E. Wang, William Wang, Bohan Wu, Jiajun Wu, Yuhuai Wu, Sang Michael Xie, Michihiro Yasunaga, Jiaxuan You, Matei Zaharia, Michael Zhang, Tianyi Zhang, Xikun Zhang, Yuhui Zhang, Lucia Zheng, Kaitlyn Zhou, and Percy Liang. On the opportunities and risks of foundation models. arXiv preprint arXiv:2108.07258, 2022. Brown et al. [2020] Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners. In Advances in Neural Information Processing Systems, volume 33, 2020. Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, volume 30, 2017. Elhage et al. 
[2021] Nelson Elhage, Neel Nanda, Catherine Olsson, Tom Henighan, Nicholas Joseph, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Nova DasSarma, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. A Mathematical Framework for Transformer Circuits. Transformer Circuits Thread, 2021. Olsson et al. [2022] Catherine Olsson, Nelson Elhage, Neel Nanda, Nicholas Joseph, Nova DasSarma, Tom Henighan, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Scott Johnston, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. In-context learning and induction heads. Transformer Circuits Thread, 2022. Chan et al. [2022] Stephanie C. Y. Chan, Adam Santoro, Andrew K. Lampinen, Jane X. Wang, Aaditya Singh, Pierre H. Richemond, Jay McClelland, and Felix Hill. Data distributional properties drive emergent in-context learning in transformers. In Advances in Neural Information Processing Systems, volume 35, 2022. Xie et al. [2022] Sang Michael Xie, Aditi Raghunathan, Percy Liang, and Tengyu Ma. An explanation of in-context learning as implicit Bayesian inference. In International Conference of Learning Representations, 2022. Singh et al. [2023] Aaditya Singh, Stephanie Chan, Ted Moskovitz, Erin Grant, Andrew Saxe, and Felix Hill. The transient nature of emergent in-context learning in transformers. In Advances in Neural Information Processing Systems, volume 36, 2023. Hochreiter et al. [2001] Sepp Hochreiter, A. Steven Younger, and Peter R. Conwell. Learning to learn using gradient descent. In Artificial Neural Networks — ICANN 2001, 2001. Duan et al. [2016] Yan Duan, John Schulman, Xi Chen, Peter L. Bartlett, Ilya Sutskever, and Pieter Abbeel. 
RL2: Fast reinforcement learning via slow reinforcement learning. arXiv preprint arXiv:1611.02779, 2016. Wang et al. [2016] Jane X Wang, Zeb Kurth-Nelson, Dhruva Tirumala, Hubert Soyer, Joel Z Leibo, Remi Munos, Charles Blundell, Dharshan Kumaran, and Matt Botvinick. Learning to reinforcement learn. arXiv preprint arXiv:1611.05763, 2016. Rabinowitz [2019] Neil C. Rabinowitz. Meta-learners’ learning dynamics are unlike learners’. arXiv preprint arXiv:1905.01320, 2019. Garg et al. [2022] Shivam Garg, Dimitris Tsipras, Percy S. Liang, and Gregory Valiant. What can transformers learn in-context? A case study of simple function classes. In Advances in Neural Information Processing Systems, volume 35, 2022. Akyürek et al. [2023] Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm is in-context learning? Investigations with linear models. In International Conference of Learning Representations, 2023. Dai et al. [2022] Damai Dai, Yutao Sun, Li Dong, Yaru Hao, Zhifang Sui, and Furu Wei. Why Can GPT Learn In-Context? Language Models Secretly Perform Gradient Descent as Meta-Optimizers, December 2022. URL http://arxiv.org/abs/2212.10559. arXiv:2212.10559 [cs]. Kirsch et al. [2022] Louis Kirsch, James Harrison, Jascha Sohl-Dickstein, and Luke Metz. General-purpose in-context learning by meta-learning transformers. In Sixth Workshop on Meta-Learning at the Conference on Neural Information Processing Systems, 2022. von Oswald et al. [2023] Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, and Max Vladymyrov. Transformers learn in-context by gradient descent. In International Conference on Machine Learning, 2023. Zhang et al. [2023] Ruiqi Zhang, Spencer Frei, and Peter L. Bartlett. Trained transformers learn linear models in-context. arXiv preprint arXiv:2306.09927, 2023. Mahankali et al. [2023] Arvind Mahankali, Tatsunori B. Hashimoto, and Tengyu Ma. 
One step of gradient descent is provably the optimal in-context learner with one layer of linear self-attention. arXiv preprint arXiv:2307.03576, 2023. Ahn et al. [2023] Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, and Suvrit Sra. Transformers learn to implement preconditioned gradient descent for in-context learning. arXiv preprint arXiv:2306.00297, 2023. Li et al. [2023a] Yingcong Li, Muhammed Emrullah Ildiz, Dimitris Papailiopoulos, and Samet Oymak. Transformers as algorithms: Generalization and stability in in-context learning. In International Conference on Machine Learning, 2023a. Raventós et al. [2023] Allan Raventós, Mansheej Paul, Feng Chen, and Surya Ganguli. Pretraining task diversity and the emergence of non-Bayesian in-context learning for regression. arXiv preprint arXiv:2306.15063, 2023. Ding et al. [2023] Nan Ding, Tomer Levinboim, Jialin Wu, Sebastian Goodman, and Radu Soricut. CausalLM is not optimal for in-context learning. arXiv preprint arXiv:2308.06912, 2023. Vladymyrov et al. [2024] Max Vladymyrov, Johannes von Oswald, Mark Sandler, and Rong Ge. Linear transformers are versatile in-context learners. arXiv preprint arXiv:2402.14180, 2024. Fu et al. [2023a] Deqing Fu, Tian-Qi Chen, Robin Jia, and Vatsal Sharan. Transformers learn higher-order optimization methods for in-context learning: A study with linear models. arXiv preprint arXiv:2310.17086, 2023a. Giannou et al. [2024] Angeliki Giannou, Liu Yang, Tianhao Wang, Dimitris Papailiopoulos, and Jason D Lee. How well can transformers emulate in-context Newton’s method? arXiv preprint arXiv:2403.03183, 2024. Hubinger et al. [2019] Evan Hubinger, Chris van Merwijk, Vladimir Mikulik, Joar Skalse, and Scott Garrabrant. Risks from learned optimization in advanced machine learning systems. arXiv preprint 1906.01820, 2019. Ba et al. [2016a] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. arXiv preprint 1607.06450, 2016a. Schmidhuber [1992] Jürgen Schmidhuber. 
Learning to control fast-weight memories: an alternative to dynamic recurrent networks. Neural Computation, 4(1):131–139, 1992. Ba et al. [2016b] Jimmy Ba, Geoffrey E. Hinton, Volodymyr Mnih, Joel Z. Leibo, and Catalin Ionescu. Using fast weights to attend to the recent past. In Advances in Neural Information Processing Systems, volume 29, 2016b. Katharopoulos et al. [2020] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are RNNs: fast autoregressive transformers with linear attention. In International Conference on Machine Learning, 2020. Wang et al. [2020] Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, and Hao Ma. Linformer: self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020. Schlag et al. [2021] Imanol Schlag, Kazuki Irie, and Jürgen Schmidhuber. Linear transformers are secretly fast weight programmers. In International Conference on Machine Learning, 2021. Choromanski et al. [2021] Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, David Belanger, Lucy Colwell, and Adrian Weller. Rethinking attention with performers. In International Conference of Learning Representations, 2021. Ahn et al. [2024] Kwangjun Ahn, Xiang Cheng, Minhak Song, Chulhee Yun, Ali Jadbabaie, and Suvrit Sra. Linear attention is (maybe) all you need (to understand transformer optimization). In International Conference of Learning Representations, 2024. Sherman and Morrison [1950] Jack Sherman and Winifred J. Morrison. Adjustment of an inverse matrix corresponding to a change in one element of a given matrix. The Annals of Mathematical Statistics, 21(1):124–127, 1950. Gauss [1821] Carl Friedrich Gauss. Theoria combinationis observationum: erroribus minimis obnoxiae. Societas Regia Scientiarum Gottingensis, 1821. Garnelo and Czarnecki [2023] Marta Garnelo and Wojciech Marian Czarnecki. 
Exploring the space of key-value-query models with intention. arXiv preprint arXiv:2305.10203, 2023. Alain and Bengio [2017] Guillaume Alain and Yoshua Bengio. Understanding intermediate layers using linear classifier probes. In International Conference of Learning Representations, 2017. Kalman [1960] R. E. Kalman. A new approach to linear filtering and prediction problems. Journal of Basic Engineering, 82(1):35–45, March 1960. Tangirala [2018] Arun K. Tangirala. Principles of system identification: theory and practice. CRC Press, 2018. Marino et al. [2018] Joseph Marino, Milan Cvitkovic, and Yisong Yue. A general method for amortizing variational filtering. In Advances in Neural Information Processing Systems, volume 31, 2018. Willems et al. [2005] Jan C. Willems, Paolo Rapisarda, Ivan Markovsky, and Bart L. M. De Moor. A note on persistency of excitation. Systems & Control Letters, 54(4):325–329, April 2005. De Persis and Tesi [2020] Claudio De Persis and Pietro Tesi. Formulas for data-driven control: stabilization, optimality, and robustness. IEEE Transactions on Automatic Control, 65(3):909–924, March 2020. Fu et al. [2023b] Daniel Y. Fu, Tri Dao, Khaled K. Saab, Armin W. Thomas, Atri Rudra, and Christopher Ré. Hungry hungry hippos: towards language modeling with state space models. In International Conference of Learning Representations, 2023b. Arora et al. [2023] Simran Arora, Sabri Eyuboglu, Aman Timalsina, Isys Johnson, Michael Poli, James Zou, Atri Rudra, and Christopher Ré. Zoology: Measuring and improving recall in efficient language models. arXiv preprint arXiv:2312.04927, 2023. Poli et al. [2023] Michael Poli, Stefano Massaroli, Eric Nguyen, Daniel Y Fu, Tri Dao, Stephen Baccus, Yoshua Bengio, Stefano Ermon, and Christopher Ré. Hyena hierarchy: Towards larger convolutional language models. In International Conference on Machine Learning, pages 28043–28078, 2023. De et al. [2024] Soham De, Samuel L. 
Smith, Anushan Fernando, Aleksandar Botev, George Cristian-Muraru, Albert Gu, Ruba Haroun, Leonard Berrada, Yutian Chen, Srivatsan Srinivasan, Guillaume Desjardins, Arnaud Doucet, David Budden, Yee Whye Teh, Razvan Pascanu, Nando De Freitas, and Caglar Gulcehre. Griffin: mixing gated linear recurrences with local attention for efficient language models, February 2024. URL http://arxiv.org/abs/2402.19427. arXiv:2402.19427 [cs]. Finn et al. [2017] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In International Conference on Machine Learning, 2017. Kaplan et al. [2020] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models. arXiv preprint arXiv:2001.08361, 2020. Hinton et al. [2015] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015. Györfi et al. [2002] László Györfi, Michael Kohler, Adam Krzyzak, and Harro Walk. A distribution-free theory of nonparametric regression. Springer, New York, 2002. Krogh and Hertz [1992] A. Krogh and J. A. Hertz. Generalization in a linear perceptron in the presence of noise. Journal of Physics A: Mathematical and General, 25(5):1135, March 1992. Advani et al. [2020] Madhu S. Advani, Andrew M. Saxe, and Haim Sompolinsky. High-dimensional dynamics of generalization error in neural networks. Neural Networks, 132:428–446, December 2020. Bellman [1961] Richard Bellman. Adaptive Control Processes: A Guided Tour. Princeton University Press, 1961. Snell et al. [2017] Jake Snell, Kevin Swersky, and Richard Zemel. Prototypical networks for few-shot learning. Advances in Neural Information Processing Systems, 30, 2017. Lin and Lee [2024] Ziqian Lin and Kangwook Lee. Dual operating modes of in-context learning. arXiv preprint arXiv:2402.18819, 2024. 
Li and Liang [2021] Xiang Lisa Li and Percy Liang. Prefix-tuning: optimizing continuous prompts for generation. In Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics, 2021. Lester et al. [2021] Brian Lester, Rami Al-Rfou, and Noah Constant. The power of scale for parameter-efficient prompt tuning. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, 2021. Irie et al. [2023] Kazuki Irie, Róbert Csordás, and Jürgen Schmidhuber. Automating continual learning. arXiv preprint arXiv:2312.00276, 2023. Bender et al. [2021] Emily M. Bender, Timnit Gebru, Angelina McMillan-Major, and Shmargaret Shmitchell. On the dangers of stochastic parrots: can language models be too big? In Proceedings of the 2021 ACM Conference on Fairness, Accountability, and Transparency, 2021. Toshniwal et al. [2022] Shubham Toshniwal, Sam Wiseman, Karen Livescu, and Kevin Gimpel. Chess as a testbed for language model state tracking. In Proceedings of the AAAI Conference on Artificial Intelligence, 2022. Li et al. [2023b] Yingcong Li, Kartik Sreenivasan, Angeliki Giannou, Dimitris Papailiopoulos, and Samet Oymak. Dissecting chain-of-thought: a study on compositional in-context learning of MLPs. arXiv preprint arXiv:2305.18869, 2023b. Nanda et al. [2023] Neel Nanda, Andrew Lee, and Martin Wattenberg. Emergent linear representations in world models of self-supervised sequence models. arXiv preprint arXiv:2309.00941, 2023. Amos and Kolter [2017] Brandon Amos and J. Zico Kolter. OptNet: Differentiable optimization as a layer in neural networks. In International Conference on Machine Learning, 2017. Gould et al. [2021] Stephen Gould, Richard Hartley, and Dylan John Campbell. Deep declarative networks. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021. Zucchet and Sacramento [2022] Nicolas Zucchet and João Sacramento. 
Beyond backpropagation: bilevel optimization through implicit differentiation and equilibrium propagation. Neural Computation, 34(12), 2022. Ramsauer et al. [2021] Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, Philipp Seidl, Michael Widrich, Lukas Gruber, Markus Holzleitner, Thomas Adler, David Kreil, Michael K. Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter. Hopfield networks is all you need. In International Conference on Learning Representations, 2021. Martins et al. [2020] André Martins, António Farinhas, Marcos Treviso, Vlad Niculae, Pedro Aguiar, and Mario Figueiredo. Sparse and continuous attention mechanisms. In Advances in Neural Information Processing Systems, volume 33, 2020. Hoover et al. [2023] Benjamin Hoover, Yuchen Liang, Bao Pham, Rameswar Panda, Hendrik Strobelt, Duen Horng Chau, Mohammed Zaki, and Dmitry Krotov. Energy transformer. In Advances in Neural Information Processing Systems, volume 36, 2023. Hebb [1949] Donald O. Hebb. The Organization of Behavior: A Neuropsychological Theory. Wiley, New York, 1949. Hertz et al. [1991] John Hertz, Richard G. Palmer, and Anders S. Krogh. Introduction to the Theory of Neural Computation. Perseus Publishing, 1st edition, 1991. Widrow and Hoff [1960] Bernard Widrow and Marcian E. Hoff. Adaptive switching circuits. In IRE WESCON convention record, volume 4, 1960. Lillicrap et al. [2020] Timothy P. Lillicrap, Adam Santoro, Luke Marris, Colin J. Akerman, and Geoffrey Hinton. Backpropagation and the brain. Nature Reviews Neuroscience, 21(6):335–346, 2020. Lee et al. [2015] Dong-Hyun Lee, Saizheng Zhang, Asja Fischer, and Yoshua Bengio. Difference target propagation. In Joint European Conference on Machine Learning and Knowledge Discovery in Databases, 2015. Whittington and Bogacz [2017] James C. R. Whittington and Rafal Bogacz. An approximation of the error backpropagation algorithm in a predictive coding network with local Hebbian synaptic plasticity. 
Neural Computation, 29(5):1229–1262, 2017. Meulemans et al. [2022] Alexander Meulemans, Nicolas Zucchet, Seijin Kobayashi, Johannes von Oswald, and João Sacramento. The least-control principle for local learning at equilibrium. In Advances in Neural Information Processing Systems, volume 35, 2022. Hinton et al. [2006] Geoffrey Hinton, Simon Osindero, and Yee Whye Teh. A Fast Learning Algorithm for Deep Belief Nets. Neural Computation, 18:1527–1554, 2006. Nøkland and Eidnes [2019] Arild Nøkland and Lars Hiller Eidnes. Training neural networks with local error signals. In International Conference on Machine Learning, 2019. Belilovsky et al. [2019] Eugene Belilovsky, Michael Eickenberg, and Edouard Oyallon. Greedy layerwise learning can scale to ImageNet. In International Conference on Machine Learning, 2019. Löwe et al. [2019] Sindy Löwe, Peter O’Connor, and Bastiaan Veeling. Putting an end to end-to-end: Gradient-isolated learning of representations. In Advances in Neural Information Processing Systems, volume 32, 2019. Hinton [2022] Geoffrey Hinton. The forward-forward algorithm: Some preliminary investigations. arXiv preprint arXiv:2212.13345, 2022. Hawkins and Blakeslee [2004] Jeff Hawkins and Sandra Blakeslee. On intelligence. Macmillan, 2004. Clark [2013] Andy Clark. Whatever next? Predictive brains, situated agents, and the future of cognitive science. Behavioral and Brain Sciences, 36(3):181–204, 2013. Hinton et al. [1995] Geoffrey E. Hinton, Peter Dayan, Brendan J. Frey, and Radford M. Neal. The "wake-sleep" algorithm for unsupervised neural networks. Science, 268(5214):1158–1161, 1995. Rao and Ballard [1999] Rajesh P. N. Rao and Dana H. Ballard. Predictive coding in the visual cortex: a functional interpretation of some extra-classical receptive-field effects. Nature Neuroscience, 2(1):79–87, 1999. Lee and Mumford [2003] Tai Sing Lee and David Mumford. Hierarchical Bayesian inference in the visual cortex. 
Journal of the Optical Society of America A, 20(7):1434, July 2003. Friston et al. [2006] Karl Friston, James Kilner, and Lee Harrison. A free energy principle for the brain. Journal of Physiology-Paris, 100(1-3):70–87, 2006. Keller and Mrsic-Flogel [2018] Georg B. Keller and Thomas D. Mrsic-Flogel. Predictive processing: a canonical cortical computation. Neuron, 100(2):424–435, 2018. Radford et al. [2018] Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners. OpenAI blog, 1(8), 2018. Hendrycks and Gimpel [2016] Dan Hendrycks and Kevin Gimpel. Gaussian Error Linear Units (GELUs). arXiv preprint arXiv:1606.08415, 2016. Loshchilov and Hutter [2019] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. In International Conference of Learning Representations, 2019. Caruana [1997] Rich Caruana. Multitask learning. Machine Learning, 28(1):41–75, July 1997. Richardson [1911] Lewis Fry Richardson. The approximate arithmetical solution by finite differences of physical problems involving differential equations, with an application to the stresses in a masonry dam. Philosophical Transactions of the Royal Society of London. Series A, Containing Papers of a Mathematical or Physical Character, 210(459-470):307–357, January 1911. Golub and Varga [1961] Gene H. Golub and Richard S. Varga. Chebyshev semi-iterative methods, successive overrelaxation iterative methods, and second order Richardson iterative methods: Part I. Numerische Mathematik, 3(1):147–156, December 1961. Johnstone et al. [1982] Richard M Johnstone, C Richard Johnson Jr, Robert R Bitmead, and Brian DO Anderson. Exponential convergence of recursive least squares with exponential forgetting factor. Systems & Control Letters, 2(2):77–82, 1982. Publisher: Elsevier. Bradbury et al.
[2018] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transformations of Python+NumPy programs, 2018. Kingma and Ba [2015] Diederik P. Kingma and Jimmy Ba. Adam: a method for stochastic optimization. In International Conference on Learning Representations, 2015. Gao et al. [2020] Leo Gao, Stella Biderman, Sid Black, Laurence Golding, Travis Hoppe, Charles Foster, Jason Phang, Horace He, Anish Thite, Noa Nabeshima, Shawn Presser, and Connor Leahy. The pile: an 800GB dataset of diverse text for language modeling. arXiv preprint arXiv:2101.00027, 2020. Hahnloser et al. [2000] Richard H. R. Hahnloser, Rahul Sarpeshkar, Misha A. Mahowald, Rodney J. Douglas, and H. Sebastian Seung. Digital selection and analogue amplification coexist in a cortex-inspired silicon circuit. Nature, 405(6789):947–951, 2000. Harris et al. [2020] Charles R. Harris, K. Jarrod Millman, Stéfan J. van der Walt, Ralf Gommers, Pauli Virtanen, David Cournapeau, Eric Wieser, Julian Taylor, Sebastian Berg, Nathaniel J. Smith, Robert Kern, Matti Picus, Stephan Hoyer, Marten H. van Kerkwijk, Matthew Brett, Allan Haldane, Jaime Fernández del Río, Mark Wiebe, Pearu Peterson, Pierre Gérard-Marchant, Kevin Sheppard, Tyler Reddy, Warren Weckesser, Hameer Abbasi, Christoph Gohlke, and Travis E. Oliphant. Array programming with NumPy. Nature, 585(7825):357–362, 2020. Hunter [2007] J. D. Hunter. Matplotlib: A 2D graphics environment. Computing in Science & Engineering, 9(3):90–95, 2007. Hennigan et al. [2020] Tom Hennigan, Trevor Cai, Tamara Norman, Lena Martens, and Igor Babuschkin. Haiku: Sonnet for JAX, 2020. Babuschkin et al. 
[2020] Igor Babuschkin, Kate Baumli, Alison Bell, Surya Bhupatiraju, Jake Bruce, Peter Buchlovsky, David Budden, Trevor Cai, Aidan Clark, Ivo Danihelka, Antoine Dedieu, Claudio Fantacci, Jonathan Godwin, Chris Jones, Ross Hemsley, Tom Hennigan, Matteo Hessel, Shaobo Hou, Steven Kapturowski, Thomas Keck, Iurii Kemaev, Michael King, Markus Kunesch, Lena Martens, Hamza Merzic, Vladimir Mikulik, Tamara Norman, George Papamakarios, John Quan, Roman Ring, Francisco Ruiz, Alvaro Sanchez, Laurent Sartran, Rosalia Schneider, Eren Sezener, Stephen Spencer, Srivatsan Srinivasan, Miloš Stanojević, Wojciech Stokowiec, Luyu Wang, Guangyao Zhou, and Fabio Viola. The DeepMind JAX Ecosystem, 2020.

Appendix A Visualization of weights and attention maps of trained multi-layer Transformers

Figure 9: Weights of the deep 6-layer linear Transformer trained on constructed tokens $e_t = (0, s_t, s_t, s_{t-1})$. We observe clear structure in the trained Transformer weight products $W_K^\top W_Q$ as well as $P W_V$ in all 4 heads. Note that this structure seems to be sufficient to approximate $(S_{t-1} S_{t-1}^\top + 1/\lambda I)^{-1} s_t$, cf. probing experiments and weight construction in the main text. We show here all 4 heads (from left to right) of the first (top 2 rows), the second (next 2 rows), and the fourth (last 2 rows) linear layer.

Figure 10: Weight products of the deep 7-layer softmax Transformers trained on unconstructed tokens $e_t = s_t$. We observe diagonal structure in the trained Transformer weight products $W_K^\top W_Q$ as well as $P W_V$.
Note that this structure seems to be sufficient to approximate layer-wise the final prediction $s_{t+1}$ as well as $(S_{t-1} S_{t-1}^\top + 1/\lambda I)^{-1} s_t$, cf. probing experiments and weight construction in the main text. We show here all 4 heads (from left to right) of the first (top 2 rows), the second (middle 2 rows), and the fourth (last 2 rows) layers after the first (potential) copying-softmax-layer.

Appendix B Additional details on the mesa-layer

In this section, we provide a detailed derivation of the forward and backward (reverse-mode differentiation) pass of the mesa-layer. For completeness, we consider a generalized version of the mesa-layer, which includes an additional forget factor $\Gamma_{h,t} = (\gamma_{h,t'})_{t'=1}^{t}$, where $\gamma_{h,t'} \in (0, 1]$, inspired by the recursive least-squares with forget factor algorithm [96].
Given again a set of tokens $E_t$, the generalized mesa-layer changes the tokens as follows:

$$\Delta e_t^{\text{mesa}} = \sum_{h=1}^{H} P_h \hat{\Phi}_{h,t}^{\text{mesa}} q_{h,t}, \tag{10}$$

$$\text{with} \quad \hat{\Phi}_{h,t}^{\text{mesa}} = \operatorname*{arg\,min}_{\Phi} \left\{ \frac{1}{2} \sum_{t'=1}^{t} \left( \prod_{t''=t'+1}^{t} \gamma_{h,t''} \right) \left\| \Phi k_{h,t'} - v_{h,t'} \right\|^2 + \frac{\prod_{t''=1}^{t} \gamma_{h,t''}}{2 \lambda_h} \left\| \Phi \right\|_F^2 \right\}. \tag{11}$$

For notational simplicity we drop the subscript in $h$ and ignore the sum over the heads in the following derivation.
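To make the discounted objective of Eq. 11 concrete for a single head, the sketch below evaluates it in NumPy and obtains its minimizer from the normal equations, which follow from setting the gradient with respect to $\Phi$ to zero. The function names, the array layout (one key or value per row) and the 0-indexed forget factors are assumptions made for illustration only; they are not taken from the authors' code.

```python
import numpy as np


def discount_weights(gammas):
    """w[t'] = prod_{t'' = t'+1}^{t} gamma_{t''}; the most recent pair has weight 1."""
    gammas = np.asarray(gammas, dtype=float)
    tail = np.cumprod(gammas[::-1])[::-1]  # tail[j] = product of gammas[j:]
    return np.append(tail[1:], 1.0)


def mesa_objective(phi, keys, values, gammas, lam):
    """Discounted squared error plus discounted Frobenius penalty (single head)."""
    w = discount_weights(gammas)
    residuals = keys @ phi.T - values  # row t' holds Phi k_t' - v_t'
    data_term = 0.5 * np.sum(w * np.sum(residuals ** 2, axis=1))
    reg_term = np.prod(gammas) / (2.0 * lam) * np.sum(phi ** 2)
    return data_term + reg_term


def mesa_minimizer(keys, values, gammas, lam):
    """Solve Phi (sum_t' w k k^T + prod(gamma) / lam * I) = sum_t' w v k^T for Phi."""
    w = discount_weights(gammas)
    d = keys.shape[1]
    vk = (values * w[:, None]).T @ keys
    kk = (keys * w[:, None]).T @ keys + np.prod(gammas) / lam * np.eye(d)
    # kk is symmetric, so Phi = vk @ inv(kk) equals solve(kk, vk.T).T.
    return np.linalg.solve(kk, vk.T).T
```

Because the penalty term is strictly positive, the objective is strictly convex in $\Phi$, so any perturbation of the returned matrix should increase the objective value. With all forget factors equal to 1, the weights reduce to ones and the problem becomes ordinary ridge regression.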
It can be shown that the analytical solution of the optimization problem is

$$\hat{\Phi}^{\text{mesa}}_t = \left( \sum_{t'=1}^{t} \left( \prod_{t''=t'+1}^{t} \gamma_{t''} \right) v_{t'} k_{t'}^\top \right) \left( \sum_{t'=1}^{t} \left( \prod_{t''=t'+1}^{t} \gamma_{t''} \right) k_{t'} k_{t'}^\top + \frac{\prod_{t''=1}^{t} \gamma_{t''}}{\lambda} I \right)^{-1} \tag{12}$$

We will now see how $\Delta e_t^{\text{mesa}}$ can be efficiently computed in a forward pass.

B.1 Computing the inverse term within $\hat{\Phi}^{\text{mesa}}_t$

Computing the full-fledged inverse at every timestep is computationally too expensive. We resort to using the Sherman-Morrison formula to efficiently compute the inverse term for all timesteps sequentially in time. We redefine

$$R_t = \left( \sum_{t'=1}^{t} \left( \prod_{t''=t'+1}^{t} \gamma_{t''} \right) k_{t'} k_{t'}^\top + \frac{\prod_{t''=1}^{t} \gamma_{t''}}{\lambda} I \right)^{-1}. \tag{13}$$

It satisfies the recursive formula

$$R_{t+1} = \left( \gamma_{t+1} R_t^{-1} + k_{t+1} k_{t+1}^\top \right)^{-1} \tag{14}$$

with $R_0 = \lambda I$, and the Sherman-Morrison formula thus gives

$$R_{t+1} = \gamma_{t+1}^{-1} \left( R_t^{-1} + \gamma_{t+1}^{-1} k_{t+1} k_{t+1}^\top \right)^{-1} \tag{15}$$

$$= \gamma_{t+1}^{-1} \left( R_t - \frac{\gamma_{t+1}^{-1} R_t k_{t+1} k_{t+1}^\top R_t}{1 + \gamma_{t+1}^{-1} k_{t+1}^\top R_t k_{t+1}} \right) \tag{16}$$

$$= \gamma_{t+1}^{-1} \left( R_t - \frac{R_t k_{t+1} k_{t+1}^\top R_t}{\gamma_{t+1} + k_{t+1}^\top R_t k_{t+1}} \right). \tag{17}$$

B.2 Computing $\Delta e_t^{\text{mesa}}$

Given $R_{h,t}$ for all heads, we can rewrite the token update as

$$\Delta e_t^{\text{mesa}} = \sum_{h=1}^{H} P_h \left( \sum_{t'=1}^{t} \left( \prod_{t''=t'+1}^{t} \gamma_{h,t''} \right) v_{h,t'} k_{h,t'}^\top \right) R_{h,t} q_{h,t} \tag{18}$$

$$= \sum_{h=1}^{H} P_h V_h \left( \left( \mathbb{1}_{t' \leq t} \prod_{t''=t'+1}^{t} \gamma_{h,t''} \right)_{t'=1}^{\top} \odot K_h^\top \tilde{q}_{h,t} \right) \tag{19}$$

$$= \sum_{h=1}^{H} P_h V_h \left( M_{:,t} \odot K_h^\top \tilde{q}_{h,t} \right) \tag{20}$$

where $\tilde{q}_{h,t} = R_{h,t} q_{h,t}$ and $M_{t',t} := \mathbb{1}_{t' \leq t} \prod_{t''=t'+1}^{t} \gamma_{h,t''}$.
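The recursion of Eq. 17 and the masked rewrite of Eq. 20 can be checked numerically for a single head with the projection $P_h$ omitted. The sketch below is an illustration under assumed names and array layouts (keys, values and queries stored one per row, forget factors 0-indexed); it is not the authors' implementation. It compares the sequential Sherman-Morrison computation with a direct evaluation that inverts the full matrix at every timestep.

```python
import numpy as np


def sherman_morrison_step(r_prev, key, gamma):
    """R_{t+1} from R_t, k_{t+1} and gamma_{t+1}; R_t is symmetric."""
    rk = r_prev @ key
    return (r_prev - np.outer(rk, rk) / (gamma + key @ rk)) / gamma


def mask_column(gammas, t, t_max):
    """Column M[:, t]: discounted weights for t' <= t, zeros for t' > t."""
    col = np.zeros(t_max)
    col[: t + 1] = np.append(np.cumprod(gammas[t:0:-1])[::-1], 1.0)
    return col


def mesa_forward_recursive(keys, values, queries, gammas, lam):
    """Single-head outputs V (M[:, t] * K^T R_t q_t) for every t."""
    gammas = np.asarray(gammas, dtype=float)
    t_max, d = keys.shape
    r = lam * np.eye(d)  # R_0 = lambda * I
    out = np.zeros((t_max, values.shape[1]))
    for t in range(t_max):
        r = sherman_morrison_step(r, keys[t], gammas[t])  # O(d^2) per step
        q_tilde = r @ queries[t]
        out[t] = values.T @ (mask_column(gammas, t, t_max) * (keys @ q_tilde))
    return out


def mesa_forward_direct(keys, values, queries, gammas, lam):
    """Reference: build and invert the full matrix at each timestep."""
    gammas = np.asarray(gammas, dtype=float)
    t_max, d = keys.shape
    out = np.zeros((t_max, values.shape[1]))
    for t in range(t_max):
        w = mask_column(gammas, t, t_max)[: t + 1]
        k, v = keys[: t + 1], values[: t + 1]
        kk = (k * w[:, None]).T @ k + np.prod(gammas[: t + 1]) / lam * np.eye(d)
        phi = (v * w[:, None]).T @ k @ np.linalg.inv(kk)
        out[t] = phi @ queries[t]
    return out
```

The two functions should agree up to floating-point error. The recursive version avoids a matrix inversion per timestep, replacing it with a rank-one update whose cost is quadratic in the key dimension.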
Note that we apply a form of causal masking here: we take the key matrices $K_h \in \mathbb{R}^{D_a \times T}$ and value matrices $V_h \in \mathbb{R}^{D_a \times T}$ with all the sequence timesteps and select the entries occurring before time $t$. The main difference with the usual causal mask $(\mathbb{1}_{t' \le t})_{t',t}$ is the inclusion of the forget factors. It can be efficiently computed leveraging partial products. We conclude by remarking that the same mask can be applied to softmax attention layers, applying it to the key-query products before the softmax.

Appendix C Mesa-layer differentiation

C.1 Mesa-layer backward pass computation via Sherman-Morrison

We now detail how to compute the backward pass of the mesa-layer. Summarizing the results above, its forward pass is computed recursively following:

$$R_{h,t+1} = \gamma_{h,t+1}^{-1}\left(R_{h,t} - \frac{R_{h,t}k_{h,t+1}k_{h,t+1}^\top R_{h,t}}{\gamma_{h,t+1} + k_{h,t+1}^\top R_{h,t}k_{h,t+1}}\right) \tag{21}$$

$$\Delta e_t^{\text{mesa}} = \sum_{h=1}^{H} P_h V_h \left(M_{:,t} \odot K_h^\top \tilde{q}_{h,t}\right) \tag{22}$$

with $R_{h,0} = \lambda_h I$. These computations can be decomposed into 3 steps:

1. First, the matrices $R_{h,t}$ are computed sequentially.
2. Then, for all $t$ and $h$, the transformed queries $\tilde{q}_{h,t} = R_{h,t}\,q_{h,t}$ are computed.
3. Finally, using the transformed queries $\tilde{Q}_h = (\tilde{q}_{h,t})_t$ as the queries, a standard cross-attention operation is computed from $(V_h, K_h, \tilde{Q}_h)$ using the causal mask $M$ that includes forgetting rates.

While the backward pass of steps 2 and 3 can be computed easily with automatic differentiation tools without much overhead compared to standard attention layers, the same cannot be said of step 1. We will here discuss how the backward pass of the computation of $\tilde{Q}_h$ can be computed in a memory-efficient way. Without loss of generality, we drop the subscript $h$ for notational simplicity.

The issue with out-of-the-box automatic differentiation. For all time steps $t$, $\tilde{q}_t = R_t q_t$ depends on $q_t$, but also on $K_t$, $\Gamma_t$ and $\lambda$ through the variable $R_t$. In the backward pass, we are given as input the gradient of the loss function w.r.t. $\tilde{Q}$, namely $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}$ for all $t$. The goal is then to compute the gradient of the loss w.r.t. the inputs of $\tilde{Q}$, namely $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}k_t}$, $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\gamma_t}$, $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}q_t}$ and $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\lambda}$, which can be achieved via the chain rule. While using automatic differentiation out of the box would take care of this computation, it would in particular require storing all intermediate variables $R_t$, which can be prohibitively expensive.

Memory efficient custom backward pass. Instead, we will show that storing the matrices $K, \Gamma, Q$ as well as $R_T$, where $T$ is the last time step of the training sequence, is sufficient to exactly compute the backward pass. Indeed, given the aforementioned inputs, all $R_t$ can be recomputed in linear complexity w.r.t. $T$, which means we can recursively reconstruct the inputs of $\tilde{q}_t$ at all time steps.
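The three-step forward pass described above can be sketched as follows for a single head, under the simplifying assumption that the projection $P_h$ is the identity. The helper names `forget_mask` and `mesa_forward`, the (T, D) array layout (the transpose of the paper's $D_a \times T$ convention), and all numerical values are hypothetical choices made for illustration; the mask is built from cumulative ("partial") products and assumes non-zero forget factors.

```python
import numpy as np

def forget_mask(gammas):
    """M[t', t] = 1[t' <= t] * prod_{t''=t'+1}^{t} gamma_{t''} (0-based indices)."""
    c = np.cumprod(gammas)                      # partial products of the forget factors
    return np.triu(c[None, :] / c[:, None])     # keep only entries with t' <= t

def mesa_forward(K, V, Q, gammas, lam):
    """Single-head sketch of Equations 21-22 with P = I; K, V, Q have shape (T, D)."""
    T, D = K.shape
    M = forget_mask(gammas)
    R = lam * np.eye(D)                         # R_0 = lambda * I
    out = np.zeros_like(V)
    for t in range(T):
        Rk = R @ K[t]
        R = (R - np.outer(Rk, Rk) / (gammas[t] + K[t] @ Rk)) / gammas[t]  # step 1
        q_tilde = R @ Q[t]                                                 # step 2
        out[t] = V.T @ (M[:, t] * (K @ q_tilde))                           # step 3
    return out

rng = np.random.default_rng(0)
T, D, lam = 6, 4, 0.5
K, V, Q = (rng.normal(size=(T, D)) for _ in range(3))
gammas = rng.uniform(0.8, 1.0, size=T)
out = mesa_forward(K, V, Q, gammas, lam)
```

Step 3 is written as an explicit masked product here; in practice it would be delegated to an optimized attention kernel.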
By noticing that $R_{t-1} = \gamma_t\left(R_t^{-1} - k_t k_t^\top\right)^{-1}$, we can apply the Sherman-Morrison formula backwards to obtain $R_{t-1}$ as

$$R_{t-1} = \gamma_t\left(R_t - \frac{R_t(-k_t)k_t^\top R_t}{1 + (-k_t)^\top R_t k_t}\right) \tag{23}$$

$$= \gamma_t\left(R_t - \frac{R_t k_t k_t^\top R_t}{k_t^\top R_t k_t - 1}\right) \tag{24}$$

We will now show how accumulating the right error signal and leveraging the vector-Jacobian product trick together with automatic differentiation tools is sufficient for computing the full backward pass recursively.
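To illustrate the reconstruction in Equation 24, the sketch below runs the forward recursion, keeps the full history only for verification, and then recovers every $R_{t-1}$ from $R_T$ alone. The function names and toy dimensions are assumptions made for this example; for long sequences the backward recursion may accumulate floating-point error, which this small test does not probe.

```python
import numpy as np

def forward_step(R, k, gamma):
    """Equation 17: R_t from R_{t-1}."""
    Rk = R @ k
    return (R - np.outer(Rk, Rk) / (gamma + k @ Rk)) / gamma

def backward_step(R, k, gamma):
    """Equation 24: R_{t-1} from R_t, using the same key and forget factor."""
    Rk = R @ k
    return gamma * (R - np.outer(Rk, Rk) / (k @ Rk - 1.0))

rng = np.random.default_rng(0)
D, T, lam = 4, 6, 0.5
keys = rng.normal(size=(T, D))
gammas = rng.uniform(0.8, 1.0, size=T)

history = [lam * np.eye(D)]                 # stored here only to check the result
for k, g in zip(keys, gammas):
    history.append(forward_step(history[-1], k, g))

R = history[-1]                             # in practice only R_T would be kept
errors = []
for t in range(T - 1, -1, -1):
    R = backward_step(R, keys[t], gammas[t])
    errors.append(np.max(np.abs(R - history[t])))
```

The denominator $k_t^\top R_t k_t - 1$ is strictly negative in exact arithmetic, since $R_t^{-1} = \gamma_t R_{t-1}^{-1} + k_t k_t^\top$ dominates $k_t k_t^\top$.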
Firstly, the error signal and the reconstructed $R_t$ allow the computation of $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}q_t}$ via

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}q_t} = \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}\frac{\mathrm{d}\tilde{q}_t}{\mathrm{d}q_t} = \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}R_t \tag{25}$$

Secondly, we rewrite $\tilde{q}_t$ as a function of $k_t, \gamma_t, R_{t-1}$ and $q_t$, i.e.

$$\tilde{q}_t = \mathcal{R}^{\text{forward}}(R_{t-1}, k_t, \gamma_t)\,q_t \tag{26}$$

Since $\mathcal{L}$ depends on $k_t$ only via both $\tilde{q}_t$ and $R_t$, we can then rewrite

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}k_t} = \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}\frac{\mathrm{d}\tilde{q}_t}{\mathrm{d}k_t} + \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_t}\frac{\mathrm{d}R_t}{\mathrm{d}k_t} \tag{27}$$

$$= \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}\frac{\partial\tilde{q}_t}{\partial k_t} + \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_t}\frac{\partial R_t}{\partial k_t} \tag{28}$$

where, provided $R_{t-1}, k_t, \gamma_t$ and $q_t$, $\frac{\partial\tilde{q}_t}{\partial k_t}$ can be computed easily using e.g. automatic differentiation tools. Similarly, we have

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\gamma_t} = \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}\frac{\partial\tilde{q}_t}{\partial\gamma_t} + \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_t}\frac{\partial R_t}{\partial\gamma_t} \tag{29}$$

Notice that $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_t}$ can be computed recursively following the chain rule

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_{t-1}} = \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_t}\frac{\partial R_t}{\partial R_{t-1}} + \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}\frac{\partial\tilde{q}_t}{\partial R_{t-1}} \tag{30}$$

where again, provided $R_{t-1}, k_t, \gamma_t$ and $q_t$, both terms can be computed efficiently with standard automatic differentiation tools coupled with the well-known vector-Jacobian product trick, given the quantities $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_t}$ and $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}$. Thirdly, we can show that

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\lambda} = \mathrm{Tr}\left[\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_0}\right] \tag{31}$$

Combining everything, we can now implement the backward computation recursively via the following equations:

$$R_{t-1} = \gamma_t\left(R_t - \frac{R_t k_t k_t^\top R_t}{k_t^\top R_t k_t - 1}\right) \tag{32}$$

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_{t-1}} = \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_t}\frac{\partial R_t}{\partial R_{t-1}} + \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}\frac{\partial\tilde{q}_t}{\partial R_{t-1}} \tag{33}$$

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}k_t} = \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}\frac{\partial\tilde{q}_t}{\partial k_t} + \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_t}\frac{\partial R_t}{\partial k_t} \tag{34}$$

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\gamma_t} = \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}\frac{\partial\tilde{q}_t}{\partial\gamma_t} + \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_t}\frac{\partial R_t}{\partial\gamma_t} \tag{35}$$

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}q_t} = \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\tilde{q}_t}R_t \tag{36}$$

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\lambda} = \mathrm{Tr}\left[\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_0}\right] \tag{37}$$

$R_T$ is assumed to be given and $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_T} = 0$. The above equations only require the storage of $\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_t}, \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}R_{t-1}}, R_t, R_{t-1}$ at any given time, and compute the backward pass with time and memory complexity similar to that of the forward pass. The derivation is identical without forgetting factors, by setting all $\gamma$ to 1.

Comment on runtime. We highlight that, although this implementation of the mesa-layer reduces the memory footprint of the forward and backward pass substantially, the layer still runs sequentially forward (and backward) in time. This prevents the computation of all mesa-layer outputs in parallel during training, a crucial advantage of softmax as well as linear attention.
On the other hand, at test time the mesa-layer benefits from the same advantages as linear self-attention or RNNs and predicts the next token without needing to store and attend to the past. In the next sections, we present two potential avenues to improve the training time: a solution based on linear solvers, and a solution that approximates the necessary inversions by a Neumann series running in parallel.

C.2 Alternative derivation through the implicit function theorem

We here present an alternative way of deriving the gradients presented above that leverages the implicit function theorem. The key is to remark that $\hat{\Phi}_t^{\text{mesa}}$ is a point at which the gradient of the least-squares regression loss $L$ is 0. For simplicity, we restrict ourselves to the case in which the output dimension of $\hat{\Phi}_t^{\text{mesa}}$ is one, that is $\hat{\Phi}_t^{\text{mesa}} = \hat{\phi}_t^\top$ for some column vector $\hat{\phi}_t$, and remark that we have to repeat the same operation over all rows of $\hat{\Phi}_t^{\text{mesa}}$ to obtain the full gradient, as all output coordinates are independent in the least-squares regression problem.
Therefore, $\hat{\phi}_t$ is defined through the implicit function

$$\frac{\mathrm{d}L}{\mathrm{d}\phi}(\hat{\phi}_t) = \sum_{t'=1}^{t} M_{t',t}\left(\hat{\phi}_t^\top k_{t'} - v_{t'}\right)k_{t'}^\top + \frac{M_{1,t}}{\lambda}\hat{\phi}_t^\top = 0. \tag{38}$$

We can then use the implicit function theorem and compute the derivative of $\hat{\phi}_t$ with respect to any quantity $\cdot$ through

$$\frac{\mathrm{d}\hat{\phi}_t}{\mathrm{d}\,\cdot} = -\left(\frac{\mathrm{d}^2 L_t}{\mathrm{d}\phi^2}(\hat{\phi}_t)\right)^{-1}\frac{\mathrm{d}^2 L_t(\hat{\phi}_t)}{\mathrm{d}\cdot\,\mathrm{d}\phi} \tag{39}$$

$$= -R_t\,\frac{\mathrm{d}^2 L_t(\hat{\phi}_t)}{\mathrm{d}\cdot\,\mathrm{d}\phi}. \tag{40}$$

For example, this yields

$$\frac{\mathrm{d}\hat{\phi}_t}{\mathrm{d}v_{t'}} = M_{t',t}\,R_t\,k_{t'}. \tag{41}$$

Finally, we can recover the desired gradient by combining the previous equation with the chain rule.
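Equation 41 can be checked numerically with finite differences. The sketch below solves the stationarity condition of Equation 38 for scalar targets and compares the analytic Jacobian with a perturbation-based estimate. The weights `m` are arbitrary positive numbers standing in for the mask column $M_{:,t}$, and all names and sizes are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D, lam = 5, 3, 2.0
K = rng.normal(size=(T, D))            # rows are the keys k_{t'}
v = rng.normal(size=T)                 # scalar targets v_{t'}
m = rng.uniform(0.5, 1.0, size=T)      # hypothetical stand-in for M_{:, t}

def solve_phi(targets):
    """Returns the minimiser phi_hat of Equation 38 and R_t (the inverse Hessian)."""
    H = (K.T * m) @ K + (m[0] / lam) * np.eye(D)   # sum_t' M k k^T + (M_{1,t}/lam) I
    return np.linalg.solve(H, K.T @ (m * targets)), np.linalg.inv(H)

phi_hat, R = solve_phi(v)
analytic = (R @ K.T) * m               # column t' holds M_{t',t} R_t k_{t'}

eps = 1e-6
numeric = np.stack(
    [(solve_phi(v + eps * np.eye(T)[i])[0] - phi_hat) / eps for i in range(T)],
    axis=1,
)
```

Because $\hat{\phi}_t$ is linear in the targets, the finite-difference estimate agrees with the analytic Jacobian up to round-off.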
C.3 Parallel backward pass through Neumann series approximation

Although the previous custom backward gradient computation allows for dramatic memory savings during training, the underlying recursive least-squares computation still suffers from linear scaling in time, similar to recurrent neural networks, as we cannot parallelize computation across the time dimension. Here, we discuss an alternative forward pass that can be used when one can afford storing all intermediate matrices $R_{h,t}$ over time. This forward pass leverages a $K$-step truncated Neumann series to approximate the inverses in parallel, and is compatible with automatic differentiation tools out of the box. Interestingly, we can do this by simply repeating (with the same weights) a slightly altered linear self-attention layer $K$ times.

Our goal is now to efficiently compute the terms $\tilde{q}_t := R_t q_t = \left(K_t K_t^\top + \frac{1}{\lambda}I\right)^{-1} q_t$ for all time steps in parallel. Indeed, once given these vectors, one can leverage Equation 20 and efficient dot-product attention (DPA) layer implementations (see https://flax.readthedocs.io/en/latest/_modules/flax/linen/attention.html for an implementation of DPA in JAX [97]). Note that we here ignore the forgetting factors, but their partial products can easily be integrated in one of the $K_t$ in $K_t K_t^\top$ to recover the version with forget rates described above.
Given an invertible matrix $X$ such that $I - X$ has operator norm less than 1 (for instance, a symmetric positive-definite $X$ with operator norm less than 1), the truncated Neumann series approximates its inverse by

$$X^{-1} \approx \tilde{X}^{-1}_{(K)} := \sum_{k=0}^{K}(I - X)^k. \tag{42}$$

When multiplying a vector from the right, we see that

$$\tilde{x}^{(K)} := \tilde{X}^{-1}_{(K)}x = \sum_{k=0}^{K}(I - X)^k x \tag{43}$$

$$= \sum_{k=1}^{K}(I - X)^k x + x \tag{44}$$

$$= (I - X)\sum_{k=0}^{K-1}(I - X)^k x + x \tag{45}$$

$$= (I - X)\,\tilde{x}^{(K-1)} + x \tag{46}$$

An advantage of the truncated Neumann series compared to other approximate inverse techniques such as Newton iteration is that we can compute more series elements without passing intermediate matrices across algorithmic steps, which in turn makes it memory efficient and straightforward to use with automatic differentiation. We only need to keep the original matrix we wish to invert in memory at all times and store the intermediate vectors $\tilde{x}^{(k)}$ for the backward pass.
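The recursion in Equation 46 can be sketched in a few lines of NumPy. The function name `neumann_solve`, the test matrix, and the rescaling constant are illustrative assumptions; the matrix is rescaled by slightly more than its spectral norm so that the series converges.

```python
import numpy as np

def neumann_solve(X, x, num_steps):
    """Approximates X^{-1} x via x^(k+1) = (I - X) x^(k) + x (Equation 46)."""
    approx = x.copy()                       # x^(0) = x, the K = 0 truncation
    for _ in range(num_steps):
        approx = approx - X @ approx + x    # only matrix-vector products are needed
    return approx

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
X = 0.1 * A @ A.T + np.eye(5)               # symmetric positive definite
X = X / (1.1 * np.linalg.norm(X, 2))        # rescale so the operator norm is below 1
x = rng.normal(size=5)

err = np.linalg.norm(neumann_solve(X, x, 100) - np.linalg.solve(X, x))
```

Only the vector iterate changes between steps, which is why no intermediate matrices need to be stored.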
We now look at the quantities we wish to compute, that is $\tilde{q}_t = \left(K_t K_t^\top + \frac{1}{\lambda}I\right)^{-1} q_t$, and approximate them by $\tilde{q}_t^{(K)}$, obtained by multiplying $q_t$ by the $K$-step truncated Neumann series approximating the inverse term $\left(K_t K_t^\top + \frac{1}{\lambda}I\right)^{-1}$. Note that a normalization by the operator norm of the matrix inside the inverse is necessary for the approximation to hold. Then, $\tilde{q}_t^{(K)}$ can be computed recursively as

$$\tilde{q}_t^{(k+1)} = \left(I - \left(K_t K_t^\top + \frac{1}{\lambda}I\right)\right)\tilde{q}_t^{(k)} + q_t \tag{47}$$

$$= q_t + \left(1 - \frac{1}{\lambda}\right)\tilde{q}_t^{(k)} - K_t K_t^\top \tilde{q}_t^{(k)} \tag{48}$$

and thus, by denoting $\tilde{Q}_t^{(k)} := (\tilde{q}_{t'}^{(k)})_{t'=1}^{t}$, we have

$$\tilde{Q}_t^{(k+1)} = Q_t + \left(1 - \frac{1}{\lambda}\right)\tilde{Q}_t^{(k)} - K_t K_t^\top \tilde{Q}_t^{(k)} \tag{49}$$

which is the sum of simple terms with a DPA computed between $K_t, K_t, \tilde{Q}_t^{(k)}$. After obtaining $\tilde{Q}_t^{(K)}$ to approximate $\tilde{Q}_t$, we compute the approximate least-squares solution as described above. Note that other implementations could save us from effectively recomputing $(K_t K_t^\top)$ at every iteration of Equation 49 by simply pre-computing these terms before running the Neumann approximation. We nevertheless observe the former version to be faster when timing the forward and backward computation, and speculate that the reason is the highly optimized implementation of DPA as the backbone of the self-attention layer. Note that a simple byproduct of the derivations here is the insight that chaining linear self-attention layers can easily implement a truncated Neumann series computation, especially if the goal is an inverse multiplied by a known vector. See the materials and methods section of the main text for an in-depth analysis.

Appendix D Probabilistic latent-state inference in Transformers

In this section, we generalize our results on latent-state inference in partially-observed deterministic linear systems towards noisy linear systems. Our aim is to show that the optimal maximum-likelihood estimator (MLE) of the next observation $s_{t+1}$ is a linear map of the concatenated previous observations, possibly encoded into a lower-dimensional subspace by a linear encoder.
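The iteration in Equation 49 amounts to repeatedly applying a causally masked linear attention with keys and values both equal to $K$. The following NumPy sketch illustrates this without forget factors; the function name `neumann_queries` and the small-key toy data are assumptions, and no operator-norm normalization is applied, so convergence relies on the keys being small enough that every $I - (K_t K_t^\top + \frac{1}{\lambda}I)$ has operator norm below 1.

```python
import numpy as np

def neumann_queries(K, Q, lam, num_steps):
    """Approximates q~_t = (K_t K_t^T + I/lam)^{-1} q_t for all t at once.

    K, Q: arrays of shape (T, D); row t of K is the key k_t.
    """
    T = K.shape[0]
    causal = np.tril(np.ones((T, T)))           # causal[t, t'] = 1 if t' <= t
    Q_tilde = Q.copy()                          # zeroth-order term of the series
    for _ in range(num_steps):
        scores = (Q_tilde @ K.T) * causal       # scores[t, t'] = k_{t'}^T q~_t, masked
        Q_tilde = Q + (1.0 - 1.0 / lam) * Q_tilde - scores @ K   # Equation 49
    return Q_tilde

rng = np.random.default_rng(0)
T, D, lam = 6, 4, 2.0
K = 0.1 * rng.normal(size=(T, D))               # small keys keep the series convergent
Q = rng.normal(size=(T, D))
approx = neumann_queries(K, Q, lam, num_steps=60)
```

Each loop iteration touches all time steps simultaneously, in contrast to the sequential Sherman-Morrison recursion.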
First, we show that in the Gaussian noise setting, the MLE of $s_{t+1}$ is a linear map of the MLE of the latent state $h_{t+1}$. Second, we show that the MLE of the latent state $h_{t+1}$ is a linear map of a concatenation of the previous $k$ observations. Finally, we generalize our setting to allow for a linear encoding of all the previous observations into a fixed low-dimensional subspace, instead of explicitly concatenating $k$ observations. Taken together, these results show that performing least-squares linear regression on tokens that encode or concatenate previous observations is an optimal strategy for predicting the next observation according to the maximum-likelihood estimator.

D.1 The MLE of $s_{t+1}$ is a linear map of the MLE of $h_{t+1}$

As we consider linear dynamics with additive Gaussian noise, the distributions $p(s_{t+1} \mid z_t^k)$ and $p(s_{t+1}, h_{t+1} \mid z_t^k)$ are multivariate Gaussians. Let us now consider the MLE estimators of the marginal $p(s_{t+1} \mid z_t^k)$ and joint distribution $p(s_{t+1}, h_{t+1} \mid z_t^k)$.
$$\hat{s}^{\text{marginal}}_{t+1} = \operatorname*{arg\,max}_{s_{t+1}}\; p(s_{t+1} \mid z_t^k)$$

$$\hat{s}^{\text{joint}}_{t+1}, \hat{h}^{\text{joint}}_{t+1} = \operatorname*{arg\,max}_{s_{t+1},\,h_{t+1}}\; p(s_{t+1} \mid z_t^k)\,p(h_{t+1} \mid s_{t+1}, z_t^k)$$

$p(h_{t+1} \mid s_{t+1}, z_t^k)$ is Gaussian, as conditional distributions of jointly distributed Gaussian variables are also Gaussian. Furthermore, the covariance of a Gaussian conditional distribution only depends on the covariance of the joint distribution, not on the specific value of the conditioned variable $s_{t+1}$. Hence, the maximum (not the $\operatorname{arg\,max}$) of $p(h_{t+1} \mid s_{t+1}, z_t^k)$ is independent of $s_{t+1}$, and we therefore have that the MLE $\hat{s}^{\text{marginal}}_{t+1}$ is equal to $\hat{s}^{\text{joint}}_{t+1}$.
Rewriting the joint distribution as $p(h_{t+1} \mid z_t^k)\,p(s_{t+1} \mid h_{t+1}, z_t^k)$, and repeating the same arguments, we have that $\hat{s}^{\text{marginal}}_{t+1} = C^*\hat{h}^{\text{joint}}_{t+1} = C^*\hat{h}^{\text{marginal}}_{t+1}$, with $\hat{h}^{\text{marginal}}_{t+1}$ the MLE of $p(h_{t+1} \mid z_t^k)$. Hence, the MLE of $s_{t+1}$ is a linear map of the MLE of the latent state $h_{t+1}$.

D.2 The MLE of $h_{t+1}$ is a linear map of $z_t^k$

Now we turn our focus to showing that $\hat{h}_{t+1}^{\text{MLE}} = \operatorname*{arg\,max}_{h_{t+1}} p(h_{t+1} \mid z_t^k)$ is a linear map of $z_t^k$. First, by similar arguments as before, we have that $\hat{h}_{t+1}^{\text{MLE}} = A\,\hat{h}_t^{\text{MLE}}$, with $\hat{h}_t^{\text{MLE}} = \operatorname{arg\,max}\, p(h_t \mid z_t^k)$.
In the following, we show that $\hat{h}_t^{\text{MLE}}$ is a linear map of $z_t^k$, thereby completing our goal of this section. Running the noisy dynamics backwards gives us $h_{t-1} = W^{*-1}(h_t - \epsilon_{h,t-1})$. Repeating this $k$ times gives us

$$z_t^k = v_t^k + \begin{bmatrix} C^*W^{*-(k-1)} \\ \vdots \\ C^*W^{*-1} \\ C^* \end{bmatrix} h_t - \begin{bmatrix} C^*W^{*-(k-1)} \\ \vdots \\ C^*W^{*-1} \\ 0 \end{bmatrix} \epsilon_{h,t-1} - \begin{bmatrix} C^*W^{*-(k-2)} \\ \vdots \\ C^*W^{*-1} \\ 0 \\ 0 \end{bmatrix} \epsilon_{h,t-2} - \dots \tag{50}$$

$$= v_t^k + \mathcal{F}_k h_t - \mathcal{F}_k^1\epsilon_{h,t-1} - \mathcal{F}_k^2\epsilon_{h,t-2} - \dots \tag{51}$$

$$= v_t^k + \mathcal{F}_k h_t - \sum_{l=1}^{k-1}\mathcal{F}_k^l\,\epsilon_{h,t-l} \tag{52}$$

with $v_t^k$ the concatenated observation noise variables $\epsilon_{s,t}$ of the last $k$ timesteps, and $\mathcal{F}_k^l$ shifted versions of the filter matrix $\mathcal{F}_k$ obtained by inserting $l$ zero block matrices from below:

$$\mathcal{F}_k = \begin{bmatrix} C^*W^{*-(k-1)} \\ \vdots \\ C^*W^{*-1} \\ C^* \end{bmatrix}$$

Now we want to extract the maximum-likelihood estimate of $h_t$. We have that

$$p(h_t, v_t^k, \epsilon_{h,t-(k-1):t-1} \mid s_{t-(k-1):t}) = p(h_t \mid s_{t-(k-1):t})\;p(v_t^k, \epsilon_{h,t-(k-1):t-1} \mid h_t, s_{t-(k-1):t}) \tag{53}$$

Importantly, all variables are Gaussian, as we have linear dynamics and Gaussian noise.
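The decomposition in Equation 52 can be verified on a small simulated system. The sketch below assumes dynamics $h_t = W^* h_{t-1} + \epsilon_{h,t-1}$ and observations $s_t = C^* h_t + \epsilon_{s,t}$, stacks the last $k$ observations oldest-first, and builds the block matrices directly from the powers of $W^{*-1}$ shown in Equation 50. The function name `shifted_filter`, the orthogonal choice of $W^*$, and all sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_h, n_s, k = 3, 2, 4
W, _ = np.linalg.qr(rng.normal(size=(n_h, n_h)))   # hypothetical invertible W*
C = rng.normal(size=(n_s, n_h))                    # hypothetical C*
W_inv = np.linalg.inv(W)

def shifted_filter(l):
    """F_k for l = 0, otherwise F_k^l; block rows ordered from s_{t-(k-1)} to s_t."""
    blocks = []
    for j in range(k - 1, -1, -1):                 # j = how many steps before t
        if l == 0:
            blocks.append(C @ np.linalg.matrix_power(W_inv, j))
        elif j >= l:
            blocks.append(C @ np.linalg.matrix_power(W_inv, j - l + 1))
        else:
            blocks.append(np.zeros((n_s, n_h)))    # noise after s_{t-j} cannot affect it
    return np.vstack(blocks)

# Simulate k latent states; eps_h[i] is the noise added between h[i] and h[i + 1].
eps_h = 0.1 * rng.normal(size=(k - 1, n_h))
eps_s = 0.1 * rng.normal(size=(k, n_s))
h = [rng.normal(size=n_h)]
for i in range(k - 1):
    h.append(W @ h[i] + eps_h[i])

z = np.concatenate([C @ h[i] + eps_s[i] for i in range(k)])   # z_t^k
v = eps_s.reshape(-1)                                          # v_t^k
# eps_{h,t-l} corresponds to eps_h[k - 1 - l].
rhs = v + shifted_filter(0) @ h[-1] - sum(
    shifted_filter(l) @ eps_h[k - 1 - l] for l in range(1, k)
)
```

An orthogonal $W^*$ is used only to guarantee invertibility and good conditioning in this toy check.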
Due to the property of Gaussian conditional distributions discussed before, the maximum of $p(v_t^k, \epsilon_{h,t-(k-1):t-1} \mid h_t, s_{t-(k-1):t})$ only depends on the covariance matrix of the distribution and hence does not depend on the value of $h_t$. Consequently, the value of $h_t$ that maximizes $p(h_t, v_t^k, \epsilon_{h,t-(k-1):t-1} \mid s_{t-(k-1):t})$ is the same one that maximizes $p(h_t \mid s_{t-(k-1):t})$. This is convenient, as it is much more tractable to maximize the joint distribution w.r.t. $h_t$ and the noise variables than to maximize the marginal distribution w.r.t. $h_t$, for which we would need to compute integrals. As the noise variables are Gaussian (with covariances which we assume to be equal to $\sigma^2 I$ for simplicity), maximizing the joint log-probability is equivalent to the following optimization problem:

$$
\operatorname*{arg\,min}_{h_t,\, \epsilon_{h,t-1:t-k+1},\, v_t^k} \; \frac{1}{2\sigma^2}\|v_t^k\|^2 + \frac{1}{2\sigma^2}\sum_{l=1}^{k-1}\|\epsilon_{h,t-l}\|^2 \quad \text{s.t.} \quad z_t^k = v_t^k + \mathcal{F}_k h_t - \sum_{l=1}^{k-1}\mathcal{F}_k^l \epsilon_{h,t-l}. \tag{54}
$$

We solve it with the Lagrange multiplier method:

$$
\mathcal{L} = \frac{1}{2\sigma^2}\|v_t^k\|^2 + \frac{1}{2\sigma^2}\sum_{l=1}^{k-1}\|\epsilon_{h,t-l}\|^2 + \lambda^\top\left(-z_t^k + v_t^k + \mathcal{F}_k h_t - \sum_{l=1}^{k-1}\mathcal{F}_k^l \epsilon_{h,t-l}\right) \tag{55}
$$

Taking the gradients of this Lagrangian and equating them to zero (absorbing the factor $1/\sigma^2$ into $\lambda$) gives us the following linear system with $kn_h + 2kn_s$ equations and $kn_h + 2kn_s$ variables:

\begin{align}
\nabla_{h_t}\mathcal{L} &= \mathcal{F}_k^\top \lambda = 0 \tag{56} \\
\nabla_{\epsilon_{h,t-l}}\mathcal{L} &= \epsilon_{h,t-l} - \mathcal{F}_k^{l\top}\lambda = 0 \tag{57} \\
\nabla_{v_t^k}\mathcal{L} &= v_t^k + \lambda = 0 \tag{58} \\
\nabla_{\lambda}\mathcal{L} &= -z_t^k + v_t^k + \mathcal{F}_k h_t - \sum_{l=1}^{k-1}\mathcal{F}_k^l \epsilon_{h,t-l} = 0 \tag{59}
\end{align}

We can structure this set of equations in a big matrix equation

$$
S \begin{bmatrix} h_t \\ \epsilon_{h,t-1} \\ \vdots \\ \epsilon_{h,t-(k-1)} \\ v_t^k \\ \lambda \end{bmatrix} = \begin{bmatrix} 0 \\ 0 \\ \vdots \\ 0 \\ z_t^k \end{bmatrix} \tag{60}
$$

where $S$ contains the coefficients that multiply the variables, and the right-hand side contains all other terms (only $z_t^k$ in our case). We can solve this system by inverting $S$ (assuming it is invertible).
Now we can extract our maximum likelihood estimate of $h_t$ as

$$
\hat{h}_t = \begin{bmatrix} I & 0 & \dots & 0 \end{bmatrix} S^{-1} \begin{bmatrix} 0 \\ 0 \\ \vdots \\ 0 \\ z_t^k \end{bmatrix} = \left[S^{-1}\right]_{0,k+2} z_t^k \tag{61}
$$

with $\left[S^{-1}\right]_{0,k+2}$ the upper right block of $S^{-1}$. So after this slightly more complicated derivation, we again end up with a simple linear map from $z_t^k$ to decode the maximum likelihood hidden state. Let us rename it for ease of notation, $U = \left[S^{-1}\right]_{0,k+2}$:

$$
\hat{h}_t = U z_t^k \tag{62}
$$

Using this state estimation, we can predict the next observation as $\hat{s}_{t+1} = C^* W^* \hat{h}_t$.
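The linear system of equations 56 to 60 can be assembled and solved explicitly. The sketch below is illustrative only: the function names, the block ordering of $S$, and the dimensions are assumptions, and $\sigma^2$ is absorbed into $\lambda$. As a sanity check it compares the upper right block of $S^{-1}$ with the generalized least-squares form $U = (\mathcal{F}_k^\top M^{-1} \mathcal{F}_k)^{-1} \mathcal{F}_k^\top M^{-1}$, $M = I + \sum_l \mathcal{F}_k^l \mathcal{F}_k^{l\top}$, which follows from eliminating $v_t^k$, $\epsilon_{h,t-l}$ and $\lambda$ by hand; this closed form is not stated in the text.

```python
import numpy as np

def stacked_filters(C, W, k):
    """Return F_k and the list [F_k^1, ..., F_k^{k-1}] (block forms read off equation 50)."""
    n_s, n_h = C.shape
    W_inv = np.linalg.inv(W)
    P = [C @ np.linalg.matrix_power(W_inv, j) for j in range(k)]   # P[j] = C W^{-j}
    zero = np.zeros((n_s, n_h))
    F = np.vstack(P[::-1])
    F_shift = [np.vstack(P[k - l:0:-1] + [zero] * l) for l in range(1, k)]
    return F, F_shift

def kkt_matrix(F, F_shift):
    """Matrix S of equation 60 for the unknowns [h_t, eps_1..eps_{k-1}, v, lambda]."""
    m, n_h = F.shape                        # m = k * n_s
    n_eps = n_h * len(F_shift)
    N = n_h + n_eps + 2 * m                 # = k n_h + 2 k n_s
    v = slice(n_h + n_eps, n_h + n_eps + m)
    lam = slice(n_h + n_eps + m, N)
    S = np.zeros((N, N))
    S[:n_h, lam] = F.T                      # eq. 56: F^T lambda = 0
    for i, F_l in enumerate(F_shift):       # eq. 57: eps_l - F_l^T lambda = 0
        rows = slice(n_h + i * n_h, n_h + (i + 1) * n_h)
        S[rows, rows] = np.eye(n_h)
        S[rows, lam] = -F_l.T
        S[lam, rows] = -F_l                 # eps_l term of the constraint (eq. 59)
    S[v, v] = np.eye(m)                     # eq. 58: v + lambda = 0
    S[v, lam] = np.eye(m)
    S[lam, v] = np.eye(m)                   # eq. 59: v + F h - sum_l F_l eps_l = z
    S[lam, :n_h] = F
    return S

rng = np.random.default_rng(1)
n_h, n_s, k = 4, 2, 4                       # hypothetical sizes with k * n_s >= n_h
W, _ = np.linalg.qr(rng.normal(size=(n_h, n_h)))
C = rng.normal(size=(n_s, n_h))
F, F_shift = stacked_filters(C, W, k)
S = kkt_matrix(F, F_shift)

# U is the block of S^{-1} mapping the z-part of the right-hand side to h_t.
m = k * n_s
rhs = np.zeros((S.shape[0], m))
rhs[-m:] = np.eye(m)
U_kkt = np.linalg.solve(S, rhs)[:n_h]

# Hand-derived generalized least-squares form, used only as a cross-check.
M = np.eye(m) + sum(F_l @ F_l.T for F_l in F_shift)
A = np.linalg.solve(M, F)                   # M^{-1} F
U_gls = np.linalg.solve(F.T @ A, A.T)
```

Solving against the right-hand side with `np.linalg.solve` avoids forming $S^{-1}$ explicitly; the system is only solvable when $\mathcal{F}_k$ has full column rank, which is the invertibility assumption made in the text.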
This leads us to the following optimal candidate for the linear map $z_{t+1}^k = \Phi z_t^k$:

$$
\Phi = \begin{bmatrix} 0 & I & 0 & \dots & 0 \\ \vdots & \ddots & \ddots & \ddots & \vdots \\ 0 & 0 & \dots & 0 & I \\ & & C^* W^* U & & \end{bmatrix} \tag{63}
$$

where the last block row is the full-width $n_s \times kn_s$ matrix $C^* W^* U$. As there exists an optimal map between $z_t^k$ and $z_{t+1}^k$ that is linear, this map can be found by performing least-squares on an autoregressive dataset with $z_t^k$.

D.3 Capacity constraints on the representation

Previously, we derived results for a fixed $k$. Now, we consider the case with a capacity bottleneck on the representations of the Transformer. Let us assume that the Transformer can allocate a $d$-dimensional subspace to store some representation of the past observations $s_{0:t-1}$. Instead of concatenating $k$ previous observations into this subspace, with the constraint that $k \leq d/n_s$ with $n_s$ the observation dimension, we can consider a more general case where we have an encoding $u_t = E s_{0:t} = E z_t^T$.
Here, $E \in \mathbb{R}^{d \times Tn_s}$, with $T$ the sequence length. For $t < T$, we prepend zeros to $s_{0:t}$ to make the dimensions fit. When $E$ consists of identity matrices on the diagonals corresponding to the $k$ last observations, we recover the previous case. However, it might be better to copy partial information from more than $k$ observations, resulting in a different encoding matrix $E$. We are interested in three main points. First, we need to formalize a bottleneck objective that the encoding matrix $E$ should optimize. Second, we need to show that the MLE for $s_t$ is still a linear map of the encoded observations $u_t$. Finally, we need some algorithm or strategy to compute the optimal encoding matrix $E$, such that we can compare it to the learned weights of the Transformer.

Bottleneck objective. We want the encoding to capture as much useful information about past observations as possible in order to predict the future observation. Hence, we want the MLE $\hat{s}_{t+1}$ conditioned on $u_t$ to be as close as possible to the MLE conditioned on the full past $s_{0:t}$.
We can formalize this in the following bilevel optimization problem

$$
\min_E \left\| \operatorname*{arg\,max}_{s_{t+1}} p(s_{t+1} \mid s_{0:t}) - \operatorname*{arg\,max}_{s_{t+1}} p(s_{t+1} \mid u_t) \right\|^2 \tag{64}
$$

As both $p(s_{t+1} \mid s_{0:t})$ and $p(s_{t+1} \mid u_t)$ are Gaussian, we have that the MLE of $p(s_{t+1} \mid s_{0:t})$ and $p(s_{t+1}, h_{t+1} \mid s_{0:t})$ are the same (see previous section), and hence we can rewrite the bilevel optimization problem into an equivalent form:

$$
\min_E \left\| C \left[ \operatorname*{arg\,max}_{h_{t+1}} p(h_{t+1} \mid s_{0:t}) - \operatorname*{arg\,max}_{h_{t+1}} p(h_{t+1} \mid u_t) \right] \right\|^2 \tag{65}
$$

MLE of $h_{t+1}$ is a linear map of $u_t$. For a fixed encoding $E$, it is easy to see that the MLE $\hat{h}_{t+1}$, and hence the MLE $\hat{s}_{t+1} = C\hat{h}_{t+1}$ as well, are a linear map of $u_t$. We have that $u_t = E z_t^T$.
Hence, we can repeat the calculations of the previous section, now with a new linear constraint $u_t = E\left[v_t^T + \mathcal{F}_T h_t - \sum_{l=1}^{T-1}\mathcal{F}_T^l \epsilon_{h,t-l}\right]$ for the MLE objective of equation 54. The main result that the MLE $\hat{h}_{t+1}$ is a linear map of $u_t$ holds in this case as well, as all equations for the first-order optimality conditions remain linear.

How to compute the optimal encoding? Now that we have derived $\operatorname*{arg\,max}_{h_{t+1}} p(h_{t+1} \mid u_t)$ as a function of $E$, we can use this to optimize the encoding objective of equation 65 w.r.t. $E$ by computing its gradients. Concretely, we need to iterate the following two steps:

1. Compute the MLE $\hat{s}_t = C^*\hat{h}_t$ conditioned on $u_t$, by solving the linear system resulting from the MLE objective of equation 54 with the new constraint $u_t = E\left[v_t^T + \mathcal{F}_T h_t - \sum_{l=1}^{T-1}\mathcal{F}_T^l \epsilon_{h,t-l}\right]$. Use a differentiable linear solver (e.g. torch.linalg.solve), such that we can backpropagate through it in step 2.
2. Compute the encoding loss of equation 65 and compute the gradients w.r.t. $E$ on a training dataset consisting of multiple teacher systems.
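Step 1 and the loss evaluation of step 2 can be sketched in NumPy for a single teacher. This is a rough illustration under several assumptions that go beyond the text: it uses a hand-derived generalized least-squares form of the estimator (obtained by eliminating the noise variables and multipliers) rather than the explicit linear system, it takes $\hat{h}_{t+1} = W^*\hat{h}_t$, it requires $E$ to have full row rank, and all function names and sizes are hypothetical. NumPy does not provide gradients, so an actual optimization over $E$ would need an autodiff framework, as the text notes.

```python
import numpy as np

def noise_model(C, W, T):
    """F_T and M_T = I + sum_l F_T^l F_T^{l,T} (covariance of z_t^T given h_t, up to sigma^2)."""
    n_s, n_h = C.shape
    W_inv = np.linalg.inv(W)
    P = [C @ np.linalg.matrix_power(W_inv, j) for j in range(T)]   # P[j] = C W^{-j}
    zero = np.zeros((n_s, n_h))
    F = np.vstack(P[::-1])
    M = np.eye(T * n_s)
    for l in range(1, T):
        F_l = np.vstack(P[T - l:0:-1] + [zero] * l)
        M += F_l @ F_l.T
    return F, M

def mle_map(E, F, M):
    """Linear map U_E with h_hat_t = U_E u_t for the encoding u_t = E z_t^T."""
    EF, EME = E @ F, E @ M @ E.T
    A = np.linalg.solve(EME, EF)            # (E M E^T)^{-1} E F
    return np.linalg.solve(EF.T @ A, A.T)

def encoding_loss(E, C, W, F, M, Z):
    """Equation-65-style loss, averaged over the columns of Z (full histories z_t^T)."""
    U_full = mle_map(np.eye(F.shape[0]), F, M)
    U_enc = mle_map(E, F, M)
    diff = C @ W @ (U_full @ Z - U_enc @ (E @ Z))
    return float(np.mean(np.sum(diff ** 2, axis=0)))

rng = np.random.default_rng(2)
n_h, n_s, T, k = 4, 2, 6, 3                 # hypothetical sizes
W, _ = np.linalg.qr(rng.normal(size=(n_h, n_h)))
C = rng.normal(size=(n_s, n_h))
F_T, M_T = noise_model(C, W, T)

# Encoding that copies the last k observations (identity blocks on the last k slots).
E_last_k = np.hstack([np.zeros((k * n_s, (T - k) * n_s)), np.eye(k * n_s)])
U_enc = mle_map(E_last_k, F_T, M_T)

# The fixed-k estimator of the previous section, built directly from F_k and M_k.
F_k, M_k = noise_model(C, W, k)
U_fixed_k = mle_map(np.eye(k * n_s), F_k, M_k)

Z = rng.normal(size=(T * n_s, 32))          # placeholder histories, not real teacher data
loss_identity = encoding_loss(np.eye(T * n_s), C, W, F_T, M_T, Z)
loss_last_k = encoding_loss(E_last_k, C, W, F_T, M_T, Z)
```

In this sketch the identity encoding gives zero loss by construction, and the last-$k$ encoding reproduces the fixed-$k$ estimator, which matches the remark that identity blocks on the last $k$ observations recover the previous case.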
Appendix E Additional experiments with different sequence generator distributions

For our main text experiments, the ground-truth transition matrix $W^*$ was set to a random orthogonal matrix. Here we briefly analyze Transformers trained on systems with different transition matrix statistics. For all settings in this section, we assume full observability, that is, $s_t = h_t$ for all time steps $t$.

E.1 Contracting linear dynamics

Here we show preliminary results for teachers that are no longer purely orthogonal: we sample $W \sim \mathcal{N}(0, I)$ to construct the sequences presented to the Transformer and restrict its eigenvalues to the band $[0.3, 0.9]$. We notice that with these $W$ approximately $2\%$ of the sequences lead to very large values. To ease trainability, we therefore clip all the values of those sequences to the range $[-2, 2]$. When training a single layer of linear self-attention, see Figure 11, we again observe that the trained layer matches the performance of a single step of gradient descent. We furthermore find clean weight structure, comparable to the weights trained on sequences generated by an orthogonal teacher, see Figure 13. For multi-layer linear Transformers, we find that linear probing of the preconditioned inputs required by our hypothesis (Proposition 2) improves gradually with depth, and that performance improves gradually for deeper Transformers.

Figure 11: Evidence for mesa-optimization in Transformers trained on contracting linear dynamics. (A) At convergence, models trained on contracting sequences exhibit the same in-context learning performance (measured as the loss as a function of sequence length) as 1 step of gradient descent (dashed line), as in our findings for models trained on data generated by an orthogonal teacher.
(B) In six-layer linear self-attention models trained on constructed tokens, we find that linear probing of preconditioned inputs $(S_{t-1}S_{t-1}^\top + 1/\lambda I)^{-1}s_t$ improves with depth and context length, consistent with the mesa-optimizer of Proposition 2 and our findings for the orthogonal-teacher setting. (C) For deeper models, performance in this setting increases. We find that the mesa-layer outperforms any other model and that a six-layer linear self-attention model can be explained by Proposition 2. (D) We again find highly structured weights that, in the shown two-head-one-layer case, can implement an update step of gradient descent.

E.2 Fixed-teacher linear dynamics

Here we analyse the setting where every sequence shares the same single fixed orthogonal transition matrix $W^* \in \mathbb{R}^{n_h \times n_h}$, and only the initial state $h_1 \sim \mathcal{N}(0, 1)$ is sequence-specific. Thus, in this setting there is no need to infer $W^*$ in-context. We report the results for the experiments in Figure 12. We observe that for this case even a one-layer linear self-attention Transformer drastically outperforms an update step of gradient descent. Furthermore, we find no evidence for the mesa-optimizers of Propositions 1 and 2, neither in the weights, which appear less structured and less interpretable, nor in linear probings of preconditioned tokens, where we barely observe a gradual improvement over layers as well as an overall worse probing performance.
Lastly, all trained Transformers, including a single mesa-layer, seem to outperform the optimization algorithms in this setting, indicating that the models learn the fixed teacher and thereby predict with very low error already very early in the sequence, as we also find in the next-token prediction analyses.

Figure 12: No evidence for mesa-optimization in Transformers trained on fixed-teacher linear dynamics, as predicted by our theory. (A) At convergence, one-layer linear self-attention Transformers trained on fixed-teacher linear sequences significantly outperform a single update step of gradient descent (dashed line). (B) In six-layer linear self-attention models trained on constructed tokens, we only find very weak linear probing of preconditioned inputs $(S_{t-1}S_{t-1}^\top + 1/\lambda I)^{-1}s_t$ and only barely see gradual improvement over depth. (C) Various deep linear self-attention and mesa-Transformers drastically outperform optimization algorithms when evaluated on test sequences for the same fixed teacher. (D) We find less structured and less interpretable weights in a trained one-layer Transformer.

Appendix F Experimental details

F.1 Training Transformers on fully observable linear dynamical systems

We provide here the details of training the Transformer models in the fully observable linear dynamics setting. As already stated in the main text, we train all Transformer models by minimizing the following classical autoregressive regression loss:

$$
\mathcal{L}(\theta) = \mathbb{E}_{e \sim p(e)}\left[\frac{1}{2}\sum_{t=1}^{T-1}\|e_{t+1} - f_t(e_{1:t}, \theta)\|^2\right]. \tag{66}
$$

In all of our experiments, we employ causal masking during self-attention, implemented in the same way as in the majority of autoregressive language modeling experiments. Specifically, during the self-attention operation we zero out the elements corresponding to the upper triangular matrix of the attention map, except for the diagonal itself. We do this both for the linear attention layer and for the mesa-layer. In practice, for softmax self-attention the masked incoming logits to the softmax are set to $-10^{30}$. We ran into stability issues, especially when training models with linear layers. To mitigate those, we simply clipped the activations of the forward pass to the range $[-4, 4]$ for linear self-attention Transformer layers, which stabilized training significantly. Hyperparameters and other experimental details can be found in Table 1.

Table 1: Hyperparameters for all settings and model variants when training on simple fully observable linear dynamics.

| Hyperparameter | Value/Description |
| --- | --- |
| Context size | We used length 50, except for the ICL experiments, where we used length 224, and the softmax-linearization experiments, where we vary the context size according to the ratio context size $= 4 \cdot n_h$. |
| Optimizer | Adam [98] with $\epsilon = 10^{-8}$, $\beta_1 = 0.9$, $\beta_2 = 0.999$. |
| Weight decay | 0.1 for constructed tokens, 0.05 otherwise. |
| Batch size | 256, except for the ICL and linearization experiments, where memory constraints required 128 and 64, respectively. |
| Gradient clipping | 1.0 across models. |
| Activation clipping | Clip to $[-4, 4]$ for all linear models trained on constructed tokens, no clipping otherwise. |
| Positional encodings | We concatenate positional encodings of dimension 40 to queries and keys before computing the self-attention in the first layer for models trained on unconstructed tokens, otherwise no positional encodings. |
| Dropout | We do not use dropout for any model. |
| Architecture 1-L., Constr. | We use a 1-layer, 2-head, key-size 20, dim-40-tokens, no input- or output-embedding architecture for single-layer models trained on constructed tokens. |
| Architecture k-L. ($k > 1$), Constr. | We use a k-layer, 4-head, key-size 20, dim-40-tokens, no input- or output-embedding architecture for the multi-layer models (softmax and linear) trained on constructed tokens for the probing analysis, and used key-size 40 for the interpolation. |
| Architecture Full-softmax, No Constr. | We use a 7-layer, 4-head, key-size 20, dim-10-tokens, dim-40-embedding architecture with input- and output-embedding layers for full-fledged softmax-only models. |
| Architecture Hybrid-mesa, No Constr. | We use a 2-layer, 4-head, key-size 20, dim-10-tokens, dim-40-embedding architecture with input- and output-embedding layers. First a softmax self-attention layer, then a single mesa-layer. |
| Architecture Full-mesa, No Constr. | We use a 2-layer, 4-head, key-size 20, dim-10-tokens, dim-40-embedding architecture with input- and output-embedding layers. Both layers are mesa-layers. |
| Weight initialization | $W \sim \mathcal{N}(0, \sigma^2)$ with $\sigma^2 = 0.0002$ for models trained on constructed tokens and $\sigma = 0.05$ for all other models. We always fixed the bias parameters to zero. |
| Learning rate (& scheduler) | For models trained on non-constructed tokens, we used linear warm-up from $0$ to $7 \cdot 10^{-4}$ in 1000 steps, followed by cosine annealing to $10^{-5}$ over the next 10000 (single-layer interpolation experiments) or 30000 (other experiments) steps. We note that we train the models for at most 10000 steps, except for the ICL setting, where we do cosine annealing for 60000 steps and train for 40000 steps. For models trained on constructed tokens, we used a fixed learning rate of $10^{-4}$. |
| Mesa regularization $\lambda$ | We initialize the learnable regularization parameter $\lambda$ for every mesa-head to 1. |

F.1.1 Single-layer linear self-attention Transformer

We analyze single-layer, two-head, key-size-20 linear self-attention Transformers, trained on constructed tokens, by comparing their performance with other models and by interpolating in parameter space between trained Transformers and the provided construction for Proposition 1, which is described by only a few hyperparameters. We read out the predictions from the first $D_s$ entries of the outputs (which initially contain a zero vector). For the performance analysis, these models are compared to Proposition 1, i.e. a single gradient descent update step on the autoregressive loss. The optimal learning rate for this gradient descent step is found by line search.

Figure 13: Mesa-optimization in a trained linear self-attention layer. We inspect the parameters of a two-headed, linear self-attention layer trained to predict the future state of a linear dynamical system. The dominant pattern obtained after learning corresponds to our mesa-gradient descent construction. The faint additional structure can be further reverse-engineered, and results from a modified mesa-objective function, $L_t(\Phi) = \sum_{t'=1}^{t-1}\frac{1}{2}\|s_{t'+1} - \Phi s_{t'}\|^2$, discovered by base-optimization of Equation 66.
Compare this to the similar structure of the weight matrix products of our construction. Note that these matrices are actually of shape $40 \times 40$; here we only show the $30 \times 30$ sub-matrix containing nonzero entries.

Interpolation details: We first train a Transformer, then extract the scalar parameters of the mesa-optimization algorithm from the $D_s \times D_s$-shaped sub-matrices by taking the mean of the sub-diagonals of the matrix products $W_k^\top W_q$ and $PW_v$ (cf. Figure 13). We then use these parameters both to build a construction of sparse weight matrices, each consisting only of identity sub-matrices (scaled by the respective parameters), and, for the single-layer case, to directly compute a loss for the hard-coded implementation of Proposition 1 with the respective hyperparameters. Then, during a second training run of a Transformer from the same initial conditions, we simultaneously compute the test loss for an interpolation, where we average equally not between the single weight matrices, but between the correct weight-matrix products per head, to obtain a new, interpolated model. The reason for this procedure is that the weight matrices yielding the found matrix products are not unique. We repeat this procedure for 5 different seeds, train a newly initialized Transformer each time, and plot the obtained mean and standard deviation of the test loss during training.

F.1.2 Multi-layer linear self-attention Transformer

For the multi-layer experiments, we use different settings. For the experiments with constructed tokens, we use a k-layer ($k > 1$) architecture without input- or output-embedding layers. We found that clipping the forward-pass activations after each layer greatly stabilized training in linear self-attention based Transformers, and hence clip activations to the range $[-4, 4]$.
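The causal masking described in F.1 and the per-layer activation clipping used here can be sketched for one linear self-attention head as follows. This is only an illustration: the function name, the single-head form, the residual connection, and the sizes are assumptions rather than details confirmed by the text.

```python
import numpy as np

def causal_linear_attention(E, W_q, W_k, W_v, P, clip=4.0):
    """One linear self-attention head on tokens E of shape (T, d), with causal mask and clipping."""
    Q, K, V = E @ W_q.T, E @ W_k.T, E @ W_v.T
    scores = Q @ K.T                        # scores[t, t'] = q_t . k_t' (no softmax)
    scores = np.tril(scores)                # zero the upper triangle, keep the diagonal
    out = E + (scores @ V) @ P.T            # residual update (assumed)
    return np.clip(out, -clip, clip)        # activation clipping to [-4, 4]

rng = np.random.default_rng(3)
T, d = 6, 4                                 # hypothetical sequence length and token size
E = rng.normal(size=(T, d))
W_q, W_k, W_v, P = (rng.normal(size=(d, d)) for _ in range(4))
out = causal_linear_attention(E, W_q, W_k, W_v, P)

# Perturbing only the last token must leave all earlier outputs unchanged.
E_future = E.copy()
E_future[-1] += 10.0
out_future = causal_linear_attention(E_future, W_q, W_k, W_v, P)
```

For softmax attention the same mask would instead be applied by setting the masked logits to a large negative value before the softmax, as described in F.1.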
Interpolation details: For the interpolation of multi-layer Transformers trained on the token construction, we follow the procedure described in the previous subsection, per layer, but extend it to 4-head, key-size 40 self-attention layers. We read off the parameters as the mean of the diagonals of the respective $n_s \times n_s$ sub-matrices of the resulting weight matrix products $W_k^\top W_q$ and $PW_v$ per head of a trained Transformer. Then we construct sparse weight matrices consisting of identity sub-matrices (scaled by the respective parameters). We name this algorithm Compressed-Alg-6. We proceed as for the single-layer experiment and re-train the Transformer from the initial conditions, but during training also report the test loss of a model that is obtained by equally averaging the weight products of our construction for Compressed-Alg-6 and the Transformer. We average the products and not the single weight matrices for the same reasons stated in the previous subsection F.1.1, and report the loss obtained in runs for 5 different seeds.

F.1.3 Full-fledged Transformers

For the experiments with full-fledged Transformers, we use either a 7-layer full-softmax architecture or 1+1 softmax-mesa and mesa-mesa hybrid models. In all full-fledged models, we have input- and output-embedding layers, and the first layer always incorporates the logic for the positional encodings, while the other Transformer layers are either 6 softmax self-attention layers or 1 mesa-layer (1+1-layer architecture). The positional encodings are concatenated to the outputs of the key and query projections before the computation of the attention.
Analysing copying behaviour in full-fledged Transformers: We examine Transformers trained on linear sequence models to understand if they learn a token-binding process in early layers to construct aggregate internal token representations, which are necessary for the proposed mesa-optimization algorithms in subsequent layers. We analyse the causally masked attention maps of trained models (cf. Figures 14 and 15) and find clear data-independent attention on both the current and the previous token at each time step. Furthermore, we propose a token-probing and a gradient sensitivity experiment (cf. Figures 16 and 17) to understand if the transformed tokens after the first Transformer layer contain both the current and the previous token in the sequence, as necessary for our hypothesis. For the token probing, we report the performance of linear decoders trained to predict previous tokens from the layer output. There, we linearly regress a batch of sequences at a single time step against a range of previous time steps and report the obtained MSE loss. We find that, as predicted by our hypothesis, Transformers learn a process that data-independently binds previous and current tokens at each time step to construct the proposed representations internally. We support this evidence by further analyses where we compute the sensitivity norm $\|\nabla_{s_{t'}} f_t^{(1)}(s_{1:t}, \theta)\|$ of the output of the first layer for all time steps $t' \leq t$. Furthermore, we analyse full-mesa models (first and second layer mesa) and report the findings for the above experiments. Here, we find a weaker and less clear, but still existing, binding of previous tokens at each time step.

Figure 14: Softmax attention maps of the first softmax self-attention layer when training a softmax-only Transformer on unconstructed inputs.
We visualize all four heads of the first softmax-attention layer and observe strong copying behavior in the heads, as predicted by the provided theory, i.e. full attention on the current and the previous token. We average the attention maps over a batch of 2048.

Figure 15: Softmax attention maps of the first softmax self-attention layer when training a hybrid-mesa Transformer on unconstructed inputs. We visualize all four heads of the first softmax-attention layer and observe strong copying behavior in the heads, as predicted by the provided theory, i.e. full attention on the current and the previous token. We average the attention maps over a batch of 2048.

Figure 16: Gradient sensitivity analysis of activations after the first layer in various Transformer models over the course of training. The first softmax layer groups together neighboring tokens. This can be seen in the high sensitivity to the current and previous tokens of the outputs of the first layer of a softmax-only Transformer. For full-mesa models we find less clear binding of all previous tokens, which is also reflected in the token probing analyses, cf. Figure 17.

Figure 17: Token probing for various full-fledged Transformer models trained on fully observable linear sequence models. We find further evidence for a learned token binding process in the first layer, indicated by a very low decoding loss for both the current and the previous token at a chosen time step (50) over batches of test sequences.

Analysing optimization algorithms in full-fledged Transformers: We proceed by analysing later layers in a variety of experiments. First, we compare the performance across fresh test sequences of the full-fledged model architectures and a hard-coded implementation of our proposed mesa-optimization that consists of six steps of preconditioning an internal optimization problem which is then solved in the last layer by an update step of gradient descent.
Beforehand, we learn the parameters of the Chebyshev-iteration method for inverting matrices (as necessary for the proposed optimization procedure) by optimizing directly for solving fully observable linear sequence models generated by the same teacher as used in this setting. Furthermore, we find strong evidence for mesa-optimization in various activation-probing experiments. We linearly regress activations separately per time step against targets and preconditioned inputs as predicted by our Proposition 2, $(S_{t-1}S_{t-1}^\top + 1/\lambda I)^{-1}s_t$, and find gradually increasing performance over layers in both experiments.

F.1.4 Testing autoregressively trained Transformers on few-shot in-context-learning

We provide here details about the post-training in-context learning experiment. For this experiment, we exclusively analyse full-fledged Transformers. After training, we "prompt" the model with few-shot regression datasets, i.e., we simply switch from sequences $[x_1,x_2,\dots,x_{t-1},x_t]$, where $x_{t+1}=Wx_t$ and $x_0\sim\mathcal{N}(0,I)$, to $[x_1,y_1,\dots,x_N,y_N]$, where $y_i=Wx_i$ and all $x_i\sim\mathcal{N}(0,I)$. Note that, unlike in the autoregressive case, there is no relation between $y_i$ and $x_{i+1}$. In both cases we sample $W$, if not stated otherwise, from the same distribution, i.e., as random orthogonal matrices. This results in a sequence length of $t=2N$, and of $t=3N$ when incorporating EOS tokens.
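As an illustration of the two sequence formats just described, the following NumPy sketch builds an autoregressive sequence and a few-shot prompt from the same random orthogonal matrix. The function names, the QR-based sampler, and the placement of the EOS vector between consecutive pairs are assumptions made for this example; it is not taken from the authors' code.

```python
import numpy as np

def random_orthogonal(n, rng):
    """Sample a random orthogonal matrix via QR of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(n, n)))
    # Fix column signs so the result does not depend on the QR sign convention.
    return q * np.sign(np.diag(r))

def autoregressive_sequence(W, length, rng):
    """Training-style sequence: x_{t+1} = W x_t, started from a Gaussian sample."""
    seq = [rng.normal(size=W.shape[0])]
    for _ in range(length - 1):
        seq.append(W @ seq[-1])
    return np.stack(seq)

def few_shot_sequence(W, num_pairs, rng, eos=None):
    """Prompt-style sequence [x_1, y_1, (EOS,) x_2, ..., x_N, y_N] with y_i = W x_i."""
    tokens = []
    for i in range(num_pairs):
        x = rng.normal(size=W.shape[0])   # inputs are i.i.d., unrelated to y_{i-1}
        if eos is not None and i > 0:
            tokens.append(eos)            # assumed: EOS separates consecutive pairs
        tokens.extend([x, W @ x])
    return np.stack(tokens)

rng = np.random.default_rng(0)
W = random_orthogonal(4, rng)
eos = rng.normal(size=4)                  # placeholder for the fine-tuned EOS vector
ar = autoregressive_sequence(W, length=6, rng=rng)
plain = few_shot_sequence(W, num_pairs=5, rng=rng)
with_eos = few_shot_sequence(W, num_pairs=5, rng=rng, eos=eos)
```

With this layout the input $x_i$ sits at (1-indexed) position $2i-1$ in the plain prompt and at position $3i-2$ in the EOS-interleaved prompt.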
Throughout the sequence we measure

$$\mathcal{L}_i=\mathbb{E}\!\left[\tfrac{1}{2}\left\|y_i-f_{2i-1}\big(x_i;\{(y_j,x_j)\}_{j=1}^{i-1}\big)\right\|^2\right] \tag{67}$$

for $i\geq 2$, depicted e.g. in Figure 18. For the EOS-token fine-tuning experiments, we initialize a single vector $\mathrm{EOS}\sim\mathcal{N}(0,I)$ and optimize this single vector on the same loss

$$\mathcal{L}(\mathrm{EOS})=\mathbb{E}\!\left[\tfrac{1}{2}\sum_{i=1}^{N}\left\|y_i-f_{3i-2}\big(x_i,\mathrm{EOS};\{(y_j,x_j)\}_{j=1}^{i-1}\big)\right\|^2\right] \tag{68}$$

via batch gradient descent for 5000 steps with batch size 256 on randomly sampled training data. Note that we interleave every data pair with an EOS token, i.e., $[x_1,y_1,\mathrm{EOS},x_2,\dots,y_{N-1},\mathrm{EOS},x_N,y_N]$, and we therefore increase the sequence length from $2N$ to $3N$. For the prefix prompt $P$, we fine-tune a single sequence of 20 tokens which we prepend to every in-context learning sequence.
Here again, before training, we initialize all vectors of the soft prompt as $P_i\sim\mathcal{N}(0,I)$ and optimize the same loss, with or without the additional (pre-trained, see above) EOS token,

$$\mathcal{L}(P)=\mathbb{E}\!\left[\tfrac{1}{2}\sum_{i=21}^{N-20}\left\|y_{i-20}-f_{3i-2+20}\big(x_{i-20},P,\mathrm{EOS};\{(y_j,x_j)\}_{j=1}^{i-21}\big)\right\|^2\right], \tag{69}$$

via batch gradient descent for 5000 steps with batch size 256 on randomly sampled training data, resulting in sequences $[P_1,\dots,P_{20},x_1,y_1,\mathrm{EOS},x_2,\dots,y_{N-1},\mathrm{EOS},x_N,y_N]$. We extend this analysis with a continual in-context learning experiment, in which we demonstrate the in-context learning capabilities of autoregressively trained Transformers on two tasks shown one after the other in context.

Figure 18: Autoregressive Transformers display in-context few-shot learning capabilities. After training a hybrid-mesa Transformer on autoregressive sequence prediction problems, we measure its ability to solve linear regression tasks in-context, without further parameter fine-tuning. The task training set is presented to the model in sequence, with each token corresponding either to an input or to its corresponding label. A final test input is provided and the loss is measured after completing the sequence using the autoregressive Transformer.
(A) The mesa-optimizers installed by autoregressive pretraining can be leveraged off-the-shelf to solve in-context supervised regression tasks, but yield sub-optimal regression performance (lightest red lines). In-context learning performance can be improved following the standard strategies of prompt (TF+EOS, light red lines) and prefix fine-tuning (TF+EOS+P, dark red lines). For comparison, we provide the loss achieved by an autoregressive linear model learned by least-squares (LSQ, yellow lines). (B) Same analysis, now presenting two tasks in a row. The autoregressive models develop some in-context continual learning capabilities.

F.2 Linearizing softmax-Transformers

We provide here details and additional results about the linearization experiments. For the linearization analysis presented in the main text, we proceed as follows. First, we fix the ratio of context size to (observed) data dimension to $4:1$. Then, for each of the listed settings ($n_s\in[4,6,10,20,40,60]$ and $T$ according to the fixed ratio), we first train a classical full-fledged softmax-attention Transformer model on data generated by a linear-sequence-generating teacher. We note that for larger dimensions, training becomes significantly more difficult in this setting. Then, for each layer in the model, we distill a separate linear self-attention layer by training it to 'behave' like its softmax counterpart. To this end, we record the outputs of the softmax layer for a new input sequence. Note that the inputs to the linear layer that we are training are not the original input sequences, but rather the (transformed) sequences that are the activations before the softmax layer in the multi-layer softmax Transformer.
Hence, the distillation process is described by optimizing this objective:

$$\mathcal{L}(\theta_{\text{linear}})=\mathbb{E}\!\left[\tfrac{1}{2}\sum_{t=1}^{T-1}\left\|\mathrm{SA}^{(l)}(s_{1:t},\theta_{\text{softmax},l})-\mathrm{LSA}\big(f_t^{(l-1)}(s_{1:t},\theta_{\text{TF}}),\theta_{\text{linear}}\big)\right\|\right]. \tag{70}$$

Here, $\mathrm{SA}^{(l)}$ denotes the softmax attention operation at the $l$-th layer of the full-softmax Transformer, $\theta_{\text{softmax},l}$ the (learned) parameters for this operation, $\mathrm{LSA}$ the linear self-attention layer, and $f_t^{(l-1)}(s_{1:t},\theta_{\text{TF}})$ the activation after the $(l-1)$-th layer in the trained full-softmax Transformer, which is the input to the linear layer we aim to distill. After this distillation process is completed, we construct a model in which we swap out the softmax operation at the respective layer of the full-softmax model and replace it with the distilled layer. Then we compare the performance of this new 'linearized' Transformer with the original full-softmax model on a batch of test sequences and report the measured test loss. Furthermore, we find that the distilled weights that were trained on the inputs and outputs of a specific softmax layer appear to be very similar in structure to those of softmax-attention layers, cf. Figure 19.

Figure 19: The weights of a distilled linear layer are surprisingly similar to those of the original full-softmax model.
Here, we present the resulting weights for a linearization of a full-softmax, 2-layer-1-head model trained on constructed data with $T=80$, $n_s=20$.

Furthermore, we analyse and compare the performance of autoregressive models learned by regularized least squares and by a generic interpolation algorithm, softmax kernel regression, in various settings as described above. We line-search the parameters necessary for regularization. Here, we extend these results and also analyse these settings for varying noise levels in the generating model. We report mean and standard deviation for three different seeds, each using generated data of batch size 32, in Figure 20.

Figure 20: For different noise levels $\sigma_h^2$ in the sequence generation process, we analyse the performance of autoregressive models learned by regularized least squares and softmax kernel regression for increasing dimensions to underline the effect of the 'curse of dimensionality' in our setting.

F.3 Training Transformers on partially observable linear dynamical systems

For the experiments with partially observable linear dynamical systems, we directly analyze full-fledged Transformers trained on the observations. In detail, we use either a 7-layer full-softmax architecture or 1+1 softmax-mesa hybrid models. In all models, we have input- and output-embedding layers, and the first layer always incorporates the logic for the positional encodings, while the other Transformer layers are either 6 softmax self-attention layers or 1 mesa layer. The positional encodings are concatenated to the outputs of the key and query projections before the computation of the attention. Generally, our models are trained on $n_s=5$-dimensional observations from a process with $n_h=15$-dimensional hidden states. Further training details can be found in Table 2.
Table 2: Hyperparameters for all settings and model variants when training on partially observable linear dynamics.
Hyperparameter | Value/Description
Context size | We use context size $T=50$
Optimizer | Adam [98] with $\epsilon=10^{-8}$, $\beta_1=0.9$, $\beta_2=0.999$
Weight decay | 0.05 across models
Batchsize | Batchsize 256
Gradient clipping | 1.0 across models
Activation clipping | No activation clipping
Positional encodings | We concatenate positional encodings of dimension = embedding-dimension to queries and keys before computing the self-attention in the first layer for all models
Dropout | We do not use Dropout for any model.
Architecture Full-softmax, No Constr. | We use a 7-layer, 4-head, key-size $\min(20, \text{embedding-dim.})$, dim-5-input-tokens architecture with varying embedding dimensions in $[5,10,15,20,30,50,80]$ with input- and output-embedding layers for full-fledged softmax-only models.
Architecture Hybrid-mesa, No Constr. | We use a 2-layer, 4-head, key-size $\min(20, \text{embedding-dim.})$, dim-5-input-tokens architecture with embedding dimensions in $[5,10,15,20,30,50,80]$ with input- and output-embedding layers. First a softmax self-attention layer, then a single Mesa-layer.
Weight initialization | $W\sim\mathcal{N}(0,\sigma^2)$ with $\sigma^2=0.05$ for all models. We always fixed the bias parameters to zero.
Learning rate (& scheduler) | We used linear warm-up starting from $0$ to $4\times10^{-4}$ in 1000 steps, then cosine annealing to $1\times10^{-5}$ for the next 30000 steps.
Mesa regularization $\lambda$ | We initialize the learnable regularization parameter $\lambda$ for every mesa-head to 1.

Analysing copying behaviour in full-fledged Transformers: We examine Transformers trained on partially observable linear sequence models to understand whether they learn a token-binding process in early layers that constructs aggregate internal token representations, which are necessary for the proposed mesa-optimization algorithms in subsequent layers. As in the fully observable setting, we use both a token-probing and a gradient-sensitivity experiment to test whether the transformed tokens after the first Transformer layer contain the previous tokens in the sequence, as required by our hypothesis for partially observable models. For the token probing (cf. Figure 21), we report the performance of linear decoders trained to predict previous tokens from the layer output. Specifically, we linearly regress a batch of sequences at a single time step against a range of previous time steps and report the obtained MSE loss for models of varying embedding dimension. We find that as the embedding dimension grows, the probing of previous tokens becomes clearer and more stable. Hence we infer that, as expected by our hypothesis, Transformers learn a process that data-independently binds previous and current tokens at each time step to construct the proposed representations internally. We support this evidence with further analyses in which we compute the sensitivity norm $\|\nabla_{s_{t'}} f_t^{(1)}(s_{1:t},\theta)\|$ of the output of the first layer for all time steps $t'\leq t$ (cf. Figure 22).

Figure 21: Token probing for Transformers trained on partially observable data. If we vary the embedding dimension of the Transformers, we find that larger Transformers use the provided space to copy over relevant tokens.
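The token-probing procedure can be illustrated with a small NumPy sketch. It is a simplified stand-in under explicit assumptions: the "first-layer output" is replaced by a hypothetical idealised activation that concatenates the current and the previous token, the probe is fit by ordinary least squares, and the loss is evaluated on the same batch used for fitting. The names `probe_mse` and `first_layer_out` are introduced only for this example.

```python
import numpy as np

def probe_mse(features, targets):
    """Fit a linear decoder by least squares and return its mean squared error."""
    weights, *_ = np.linalg.lstsq(features, targets, rcond=None)
    return float(np.mean((features @ weights - targets) ** 2))

rng = np.random.default_rng(0)
batch, seq_len, dim = 512, 10, 4
tokens = rng.normal(size=(batch, seq_len, dim))   # placeholder i.i.d. tokens

# Hypothetical stand-in for first-layer outputs at time step t: the current
# and the previous token concatenated, i.e. an idealised "bound" token.
t = 7
first_layer_out = np.concatenate([tokens[:, t], tokens[:, t - 1]], axis=-1)

# Probe the same activations for the tokens at lags 0, 1, 2, ... before t.
losses = {lag: probe_mse(first_layer_out, tokens[:, t - lag]) for lag in range(5)}
for lag, loss in losses.items():
    print(f"lag {lag}: probing MSE {loss:.3e}")
```

In this idealised setting the decoding loss is essentially zero for the tokens that were copied into the activation (lags 0 and 1) and stays close to the token variance for all others, which is the qualitative signature the probing figures look for.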
Figure 22: Gradient sensitivity analysis of activations after the first layer in full-softmax (A) and hybrid-mesa (B) Transformer models over the course of training. The first softmax layer groups together the current and multiple previous tokens as predicted by our hypothesis. This can be seen in the high sensitivity to the current and previous tokens of the outputs of the first layer of the Transformer models.

Analysing optimization algorithms in full-fledged Transformers trained on partially observable linear dynamical systems: We proceed by analysing later layers using the same method as in the fully observable setting (cf. Figure 24). We compare the performance across fresh test sequences of the full-fledged model architectures and a hard-coded implementation of our proposed mesa-optimization, which consists of six steps of preconditioning an internal optimization problem that is then solved in the last layer by an update step of gradient descent. Beforehand, we learn the parameters of the Chebyshev-iteration method for inverting matrices (as necessary for the proposed optimization procedure) by optimizing directly for solving fully observable linear sequence models generated by the same teacher as used in this setting. Furthermore, we find strong evidence for mesa-optimization in various activation-probing experiments. We linearly regress activations separately per time step against targets and preconditioned inputs as predicted by our Proposition 2 for partially observable linear sequence models, $(Z_{t-1}Z_{t-1}^\top + 1/\lambda I)^{-1}z_t$ (here $z_t$ refers to an aggregation of the previous $k=5$ tokens in one constructed token), and find gradually increasing performance over layers in both experiments.
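To make the preconditioned probing target concrete, here is a minimal NumPy sketch. It assumes that a constructed token $z_t$ stacks the current observation with its $k-1$ predecessors (zero-padded at the start) and that the columns of $Z_{t-1}$ are the constructed tokens $z_1,\dots,z_{t-1}$; both are illustrative assumptions. The six-step solver uses a plain Richardson iteration with a fixed step size, which is swapped in for the learned Chebyshev iteration mentioned in the text and is not the authors' method. Observations are random placeholders rather than samples from a dynamical system.

```python
import numpy as np

def construct_tokens(obs, k):
    """Stack each observation with its k-1 predecessors (zero-padded at the start)."""
    seq_len, n = obs.shape
    padded = np.vstack([np.zeros((k - 1, n)), obs])
    # Row t holds [s_t, s_{t-1}, ..., s_{t-k+1}] flattened into one vector.
    return np.stack([padded[t:t + k][::-1].ravel() for t in range(seq_len)])

def gram(Z_prev, lam):
    """Regularised matrix Z Z^T + I / lam, with past tokens stored as rows of Z_prev."""
    return Z_prev.T @ Z_prev + np.eye(Z_prev.shape[1]) / lam

def preconditioned_exact(Z_prev, z_t, lam):
    """Closed-form probing target (Z Z^T + I / lam)^{-1} z_t."""
    return np.linalg.solve(gram(Z_prev, lam), z_t)

def preconditioned_iterative(Z_prev, z_t, lam, steps=6):
    """Richardson iteration v <- v + eta * (z_t - A v), a stand-in for Chebyshev."""
    A = gram(Z_prev, lam)
    eta = 1.0 / np.linalg.eigvalsh(A)[-1]   # step size: 1 / largest eigenvalue
    v = np.zeros_like(z_t)
    for _ in range(steps):
        v = v + eta * (z_t - A @ v)
    return v

rng = np.random.default_rng(0)
obs = rng.normal(size=(50, 5))      # T = 50 placeholder observations, n_s = 5
Z = construct_tokens(obs, k=5)      # constructed tokens of dimension 25
t, lam = 40, 1.0
target = preconditioned_exact(Z[:t], Z[t], lam)
approx = preconditioned_iterative(Z[:t], Z[t], lam)
print(np.linalg.norm(approx - target) / np.linalg.norm(target))
```

A fixed-step iteration like this converges slowly when the matrix is ill-conditioned, which is one plausible reason to tune the iteration parameters instead, as the text describes for the Chebyshev method.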
F.4 Training Transformers on fully observable nonlinear dynamical systems

For the experiments with fully observable nonlinear dynamical systems, we also directly analyze full-fledged Transformers trained on non-constructed observation tokens. In detail, we use either a 7-layer full-softmax architecture or 1+1 softmax-mesa hybrid models. In all models, we have input- and output-embedding layers, and the first layer always incorporates the logic for the positional encodings, while the other Transformer layers are either 6 softmax self-attention layers or 1 mesa layer (1+1-layer architecture). The positional encodings are concatenated to the outputs of the key and query projections before the computation of the attention. Furthermore, we use MLPs with hidden dimension 300 (a factor of $5\times$ compared with the embedding dimension of the models, which we set to 50) and a specialized version of normalization, sum normalization, as introduced by [33], where we divide the query and key projections by their respective sums of components. Further training details can be found in Table 3.

Table 3: Hyperparameters for all settings and model variants when training on fully observable nonlinear dynamics.
Hyperparameter | Value/Description
Context size | We use context size $T=50$
Optimizer | Adam [98] with $\epsilon=10^{-8}$, $\beta_1=0.9$, $\beta_2=0.999$
Weight decay | 0.05 across models
Batchsize | Batchsize 256
Gradient clipping | 1.0 across models
Activation clipping | No activation clipping
Positional encodings | We concatenate positional encodings of dimension 60 to queries and keys before computing the self-attention in the first layer for all models
Dropout | We do not use Dropout for any model.
Architecture Full-softmax, No Constr. | We use a 7-layer, 4-head, key-size 20, dim-10-input-tokens architecture with embedding dimension 60 and input- and output-embedding layers for full-fledged softmax-only models. The models comprise MLPs with hidden dimension 300 and layer-normalization of query and key projections at each layer.
Architecture Hybrid-mesa, No Constr. | We use a 2-layer, 4-head, key-size 20, dim-10-input-tokens architecture with embedding dimension 60 and input- and output-embedding layers. First a softmax self-attention layer, then a single Mesa-layer. The models comprise MLPs with hidden dimension 300 and layer-normalization of query and key projections at each layer.
Weight initialization | $W\sim\mathcal{N}(0,\sigma^2)$ with $\sigma^2=0.05$ for all models. We always fixed the bias parameters to zero.
Learning rate (& scheduler) | We used linear warm-up starting from $0$ to $4\times10^{-4}$ for hybrid-mesa and $1\times10^{-3}$ for full-softmax models in 1000 steps, then cosine annealing to $1\times10^{-5}$ for the next 50000 steps. We only train for 40000 steps.
Mesa regularization $\lambda$ | We initialize the learnable regularization parameter $\lambda$ for every mesa-head to 1.

Analysing copying behaviour in full-fledged Transformers: As in the fully and partially observable linear settings, we use both a token-probing and a gradient-sensitivity experiment to test whether trained Transformers learn a token-binding mechanism in early layers. For the token probing, we report the performance of linear decoders trained to predict previous tokens from the layer output. Specifically, we linearly regress the transformed token after the first layer, for a batch of sequences at a single time step, against nonlinearly transformed tokens from a range of previous time steps and report the obtained MSE loss.
To this end, we employ the teacher used during training, $\mathrm{MLP}^*$. Here, we also show further analyses in which we compute the sensitivity norm $\|\nabla_{s_{t'}} f_t^{(1)}(s_{1:t},\theta)\|$ of the output of the first layer for all time steps $t'\leq t$. We report the results in Figure 23.

Figure 23: Gradient sensitivity analysis of activations after the first layer in full-softmax and hybrid-mesa Transformer models trained on fully observable nonlinear dynamical systems over the course of training. The first softmax layer groups together the current and multiple previous tokens as predicted by our hypothesis. This can be seen in the high sensitivity to the current and previous tokens of the outputs of the first layer of the Transformer models.

Analysing optimization algorithms in full-fledged Transformers trained on fully observable nonlinear dynamical systems: We proceed by analysing later layers using the same method as in the linear settings (cf. Figure 25). We compare the performance across fresh test sequences of the full-fledged model architectures and a hard-coded implementation of our proposed mesa-optimization, which consists of six steps of preconditioning an internal optimization problem that is then solved in the last layer by an update step of gradient descent. Beforehand, we learn the parameters of the Chebyshev-iteration method for inverting matrices (as necessary for the proposed optimization procedure) by optimizing directly for solving fully observable nonlinear sequence models generated by the same teacher as used during training. Furthermore, we find strong evidence for mesa-optimization in various activation-probing experiments.
We linearly regress activations separately per time step against targets and preconditioned inputs as predicted by our Proposition 2 for partially observable linear sequence models, $(F_{t-1}F_{t-1}^\top + 1/\lambda I)^{-1}f_t$ (here $f_t$ refers to the nonlinearly transformed tokens $\mathrm{MLP}^*(s_t)$ obtained with the nonlinear teacher), and find gradually increasing performance over layers in both experiments.

Figure 24: Evidence for mesa-optimization in standard (softmax) Transformers trained on partially observable linear dynamical systems. (A) Linear probes decode next-token target $s_{t+1}$ from internal Transformer activations, with decoding performance improving with depth (intensity color-coded) and context length, consistent with gradual optimization of an internal next-token prediction model. (B) Likewise for preconditioned input $(Z_{t-1}Z_{t-1}^\top + 1/\lambda I)^{-1}z_t$ probing, where $z_t$ are constructed tokens comprising the past 5 observations, consistent with our findings in token probing for Transformers trained on partially observable dynamics and the mesa-optimizer of Proposition 2. (C) Next-token prediction error of a 7-layer Transformer (blue line) decreases with context length in a very similar way to 7 steps of Proposition 2 on constructed tokens, as predicted by our hypothesis for partially observable linear dynamical systems (dashed yellow line), with hyperparameters of the latter set for best performance, not to match Transformer behavior.

Figure 25: Evidence for mesa-optimization in standard (softmax) Transformers trained on fully observable nonlinear dynamical systems.
(A) Linear probes decode next-token target $s_{t+1}$ from internal Transformer activations, with decoding performance improving with depth (intensity color-coded) and context length, consistent with gradual optimization of an internal next-token prediction model. (B) Likewise for preconditioned input $(F_{t-1}F_{t-1}^\top + 1/\lambda I)^{-1}f_t$ probing, where $f_t$ are the nonlinearly transformed observations $f_t=\mathrm{MLP}^*(s_t)$ obtained with the teacher MLP, consistent with the mesa-optimizer of Proposition 2. (C) Next-token prediction error of a 7-layer Transformer (blue line) decreases with context length in almost exactly the same way as 7 steps of Proposition 2 (dashed yellow line), with hyperparameters of the latter set for best performance, not to match Transformer behavior.

Appendix G Language modeling

Figure 26: Single-layer Transformers with key-shifts, the Pile. We observe improved (A) perplexity and (B) in-context learning scores when comparing one linear to one mesa layer with different DPFP sizes $\nu\in\{0,1,2,3\}$, corresponding inversely to color fade. Mesa layers consistently outperform linear layers, catching up with softmax.

We present here preliminary results on the performance of models which replace (some) softmax self-attention layers with the mesa-layer. Our hypothesis is that the mesa-layer will improve the in-context learning and working memory capabilities of a Transformer, in particular of the linear kind. We further hypothesize that this in turn translates to language modeling improvements, based on the high correlation between in-context learning and actual autoregressive loss reported by Kaplan et al. [50].
We therefore quantify performance along two axes: the next-token prediction loss, the actual objective of base-optimization; and the ability to learn in-context, measured as the difference in loss calculated over two timepoints within a sequence, as defined by Kaplan et al. [50] and Olsson et al. [5]. We train Transformers with various architectural configurations on the Pile [99], a large compilation of various English text datasets including parts of Wikipedia, arXiv, and code. In all experiments, we model the first layer using softmax self-attention. This decision is based on insights from our previous experiments, where base-optimization consistently attributed a mesa-objective creation role to this layer. We then compare pure softmax-only Transformers to two types of hybrid models, where the subsequent layers are either linear or mesa. We vary the depth of our models, from 2-layer attention-only to deeper 4-attention-layer models endowed with tokenwise MLPs, which are present by default in standard Transformers. By transforming the data nonlinearly, MLP layers allow solving nonlinear regression problems by mesa-gradient descent. Following this reasoning, we further adopt in our hybrid-linear and hybrid-mesa Transformers the deterministic parameter-free projection (DPFP, size denoted by $\nu$) due to Schlag et al. [33], a non-learned and simple-to-compute nonlinear transformation of keys and queries. We found that this significantly improved the performance of non-softmax attention layers. Finally, to represent discrete input symbols as real-valued vectors, we learn a vocabulary of real-valued vectors using the standard GPT-2 tokenizer. We note that all models have an (almost) identical number of parameters. In line with our synthetic experiments, we observe stable learning of copying layers across all model types, indicated by the constant attention to tokens in direct or close proximity, as shown in Figure 28.
We therefore reproduce the findings of Olsson et al. [5], extending them to models that include other forms of attention. This phenomenon is predicted by the mesa-optimization theory presented here, where copy layers serve the purpose of constructing internal mesa-objective functions. We note that, in contrast to our previous synthetic linear prediction tasks, the Pile is no longer Markovian of order 1. This is reflected in the more complicated attention maps, indicating more involved copying behavior. Additionally, we run an ablation in which we compare to a single-layer control model whose first softmax layer is removed and replaced by a hardcoded one-step key-shift operator. Interestingly, such an operator can be found in previous work [5, 45]. Again, we verify the findings of [5] and observe strong in-context learning scores within a single layer, with the mesa-layer performing on par with softmax, see Figure 26. As in [33], DPFP features substantially improve performance; we fix $\nu=3$ for the linear as well as the mesa layer in all other language modeling experiments.

Figure 27: Language modeling experiments on the Pile. We observe improved perplexity and in-context learning scores across all our language modeling experiments when switching from standard linear self-attention to the mesa-layer. As hypothesized, we confirm that in all models various copying heads can be found in the first softmax layer, see Figure 28 for visualizations of the attention heads. (A&B) 2-layer Transformers without MLPs, with a softmax self-attention first layer and a second layer that is either softmax, mesa or linear. (C&D) 4-layer Transformers with MLPs, with a softmax self-attention first layer and the remaining layers either all softmax, all mesa or all linear.
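The in-context learning score used in Figures 26 and 27 is described above only as a difference in loss between two timepoints within a sequence. A minimal sketch of such a score is given below; the two token positions (50 and 500), the sign convention (late minus early, so that more negative means stronger in-context learning) and the function name are assumptions for illustration, and the per-token losses are synthetic.

```python
import numpy as np

def in_context_learning_score(token_losses, early=50, late=500):
    """Mean loss at a late position minus mean loss at an early position.

    token_losses has shape (num_sequences, context_size); positions are
    0-indexed. Under this convention, more negative values indicate that
    the model benefits more from additional context.
    """
    token_losses = np.asarray(token_losses, dtype=float)
    return float(token_losses[:, late].mean() - token_losses[:, early].mean())

# Synthetic per-token losses that decay with position, for illustration only.
positions = np.arange(1024)
losses = np.tile(3.0 + 2.0 / np.sqrt(positions + 1.0), (8, 1))
score = in_context_learning_score(losses)
print(f"in-context learning score: {score:.4f}")
```

A model whose loss does not depend on position would obtain a score of zero under this definition.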
We find that the hybrid-mesa Transformers dominate their hybrid-linear counterparts in terms of performance across all configurations, essentially matching (for 2-layer models) or coming closer to (for 4-layer models with MLPs) pure-softmax Transformers, cf. Figure 27. This is reflected both in terms of perplexity and in-context learning scores. We leave for future work studying the mesa-layer equipped with forgetting factors, see Appendix C.1, which could further improve upon our results here. Strictly speaking, these results are not sufficient to make claims on whether mesa-optimization is occurring within standard Transformers. However, the high performance achieved by the hybrid-mesa models, which operate on mesa-optimization principles by design, suggests that mesa-optimization might be happening within conventional Transformers. More reverse-engineering work is needed to add weight to this conjecture.

We now provide additional details about the language modeling experiments. We use standard values found in the literature and the same hyperparameters, which we did not tune, across all experiments. If not stated otherwise, we use the standard GPT-2 Transformer architecture with LayerNorm [28], MLPs between self-attention layers, and a skip connection after every layer, which we train with a standard (autoregressively) masked cross-entropy loss. We do not use an input embedding layer but an output projection before computing the logits. To enable stable training of the linear as well as the mesa-layer, we apply the key and query normalization proposed by Schlag et al. [33] and simply divide them by their L2 norm. Intriguingly, this also drastically stabilizes training for the mesa-layer, after which we did not observe any more instabilities. Note that this is very similar to using additional LayerNorm [28] on the keys and queries. Apart from this normalization, all models are constructed and trained identically.
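The key and query normalization described above amounts to a one-line operation. The sketch below shows it in NumPy; the small constant `eps`, the tensor layout (sequence length, heads, key size) and the function name are assumptions added for the example and are not specified in the text.

```python
import numpy as np

def l2_normalize(x, eps=1e-6):
    """Divide each vector along the last axis by its L2 norm (eps avoids division by zero)."""
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / (norm + eps)

rng = np.random.default_rng(0)
# Hypothetical shapes: 16 positions, 12 heads, key size 64.
keys = rng.normal(size=(16, 12, 64))
queries = rng.normal(size=(16, 12, 64))
keys_n, queries_n = l2_normalize(keys), l2_normalize(queries)

# After normalization every query-key dot product is a cosine similarity,
# so the unnormalised attention scores are bounded in [-1, 1].
scores = np.einsum("thd,shd->hts", queries_n, keys_n)
print(scores.shape, float(np.abs(scores).max()))
```

Bounding the scale of keys and queries in this way keeps the entries of key-key products bounded as well, which is one plausible reason for the improved stability reported for the linear and mesa layers.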
See Table 4 for an overview of all design decisions and hyperparameters. We also refer to the appendix of [33] for how to compute the DPFP kernels that non-linearly alter the key and query features; we use $\nu = 3$ if not stated otherwise.

Table 4: Hyperparameters for language modeling experiments across all Transformer variants, i.e. pure softmax, linear-hybrid and mesa-hybrid, with and without MLPs.

| Hyperparameter | Value |
| --- | --- |
| Dataset | The Pile [99] |
| Tokenizer | GPT-2 tokenizer; we append a special "EOS" token between every sequence |
| Context size | 1024 |
| Vocabulary size | 50257 |
| Vocabulary dim | 756 |
| Optimizer | Adam [98] with $\epsilon = 10^{-8}$, $\beta_1 = 0.9$, $\beta_2 = 0.95$ |
| Weight decay | 0.1 |
| Batch size | 256 |
| Gradient clipping | Global norm of 1 |
| Positional encodings | We add standard positional encodings. |
| Dropout | We use embedding dropout of 0.1 right after adding positional encodings. |
| Architecture details | 12 heads, key size 64, token size 756, no input embedding but an output embedding |
| Weight init | $W \sim \mathcal{N}(0, \sigma^2)$ with $\sigma = 0.02$ and bias parameters set to zero. We scale all weight matrices before a skip connection by $\frac{1}{2\sqrt{N}}$, with $N$ the number of layers. |
| Learning rate scheduler | Linear warm-up from $10^{-6}$ to $3 \times 10^{-4}$ in the first 8000 training steps, then cosine annealing to $2 \times 10^{-4}$ for the next 300 billion tokens |
| MLP size | Widening factor 4, i.e. hidden dimension $4 \cdot 756$, with ReLU non-linearities [100] |
| Mesa regularization $\lambda$ | We initialize the learnable regularization parameter $\lambda$ for every mesa-head to 1. |

Figure 28: Softmax attention maps of the 2-layer softmax-only Transformer trained on the Pile. We average the attention maps of the first softmax-attention layer over a batch of size 256 and observe stable off-diagonals with different offsets and widths, indicating clean copying behavior based on positional encodings in multiple heads.

Appendix H Software

The results reported in this paper were produced with open-source software. We used the Python programming language together with the Google JAX [97] framework, and the NumPy [101], Matplotlib [102], Flax [103] and Optax [104] packages.
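As a concrete reading of the learning-rate schedule listed in Table 4, the sketch below implements a linear warm-up followed by cosine annealing in plain Python. It is an illustration under stated assumptions, not the authors' training code: the function name `learning_rate` is hypothetical, and the conversion of 300 billion tokens into optimizer steps assumes that every step consumes batch size times context size tokens (256 x 1024).

```python
import math

# Values taken from Table 4.
WARMUP_STEPS = 8000
INIT_LR = 1e-6
PEAK_LR = 3e-4
END_LR = 2e-4

# Assumption: each step processes batch_size * context_size tokens, so the
# 300 billion token annealing phase is converted into a number of steps.
TOKENS_PER_STEP = 256 * 1024
DECAY_STEPS = 300_000_000_000 // TOKENS_PER_STEP


def learning_rate(step):
    """Linear warm-up to PEAK_LR, then cosine annealing down to END_LR."""
    if step < WARMUP_STEPS:
        frac = step / WARMUP_STEPS
        return INIT_LR + frac * (PEAK_LR - INIT_LR)
    # Progress through the annealing phase, clipped to [0, 1].
    progress = min((step - WARMUP_STEPS) / DECAY_STEPS, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return END_LR + (PEAK_LR - END_LR) * cosine
```

Under the tokens-per-step assumption, the annealing phase spans roughly 1.14 million steps. Optax provides a comparable combined schedule in `optax.warmup_cosine_decay_schedule`; whether the authors used that helper is not stated in the paper.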