
Paper deep dive

Compact Proofs of Model Performance via Mechanistic Interpretability

Jason Gross, Rajashree Agrawal, Thomas Kwa, Euan Ong, Chun Hei Yip, Alex Gibson, Soufiane Noubir, Lawrence Chan

Year: 2024 · Venue: NeurIPS 2024 · Area: Formal/Theoretical · Type: Empirical · Embeddings: 183

Models: Small one-layer attention-only transformers (151 random seeds)

Abstract

We propose using mechanistic interpretability -- techniques for reverse engineering model weights into human-interpretable algorithms -- to derive and compactly prove formal guarantees on model performance. We prototype this approach by formally proving accuracy lower bounds for a small transformer trained on Max-of-K, validating proof transferability across 151 random seeds and four values of K. We create 102 different computer-assisted proof strategies and assess their length and tightness of bound on each of our models. Using quantitative metrics, we find that shorter proofs seem to require and provide more mechanistic understanding. Moreover, we find that more faithful mechanistic understanding leads to tighter performance bounds. We confirm these connections by qualitatively examining a subset of our proofs. Finally, we identify compounding structureless errors as a key challenge for using mechanistic interpretability to generate compact proofs on model performance.

Tags

ai-safety (imported, 100%) · empirical (suggested, 88%) · formaltheoretical (suggested, 92%) · interpretability (suggested, 80%)


Intelligence

Status: succeeded | Model: google/gemini-3.1-flash-lite-preview | Prompt: intel-v1 | Confidence: 92%

Last extracted: 3/12/2026, 7:20:42 PM

Summary

The paper proposes using mechanistic interpretability to derive compact, formally verified performance guarantees for neural networks. By reverse-engineering a one-layer attention-only transformer trained on the 'Max-of-K' task, the authors demonstrate that more faithful mechanistic understanding leads to tighter performance bounds, while shorter proofs require more mechanistic insight. They identify compounding structureless errors as a primary challenge in generating these compact proofs.

Entities (4)

Max-of-K · task · 99%
Mechanistic Interpretability · methodology · 98%
Attention-only Transformer · model-architecture · 95%
Compounding structureless errors · challenge · 92%

Relation Signals (3)

Mechanistic Interpretability enables Compact Proofs

confidence 95% · We propose using mechanistic interpretability... to derive and compactly prove formal guarantees on model performance.

Mechanistic Understanding leads to Tighter Performance Bounds

confidence 93% · Moreover, we find that more faithful mechanistic understanding leads to tighter performance bounds.

Compounding structureless errors hinders Compact Proofs

confidence 90% · Finally, we identify compounding structureless errors as a key challenge for using mechanistic interpretability to generate compact proofs on model performance.

Cypher Suggestions (2)

Find all proof strategies and their associated performance metrics · confidence 85% · unvalidated

MATCH (s:Strategy)-[:HAS_METRIC]->(m:Metric) RETURN s.name, m.type, m.value

Map the relationship between mechanistic understanding and proof compactness · confidence 80% · unvalidated

MATCH (m:Methodology)-[:IMPROVES]->(p:Proof) WHERE p.compactness = 'high' RETURN m, p

Full Text

182,241 characters extracted from source content.


Compact Proofs of Model Performance via Mechanistic Interpretability
Jason Gross∗, Rajashree Agrawal, Thomas Kwa†, Euan Ong†, Chun Hei Yip†, Alex Gibson‡, Soufiane Noubir‡, Lawrence Chan
Abstract
We propose using mechanistic interpretability (techniques for reverse engineering model weights into human-interpretable algorithms) to derive and compactly prove formal guarantees on model performance. We prototype this approach by formally proving accuracy lower bounds for a small transformer trained on Max-of-K, validating proof transferability across 151 random seeds and four values of K. We create 102 different computer-assisted proof strategies and assess their length and tightness of bound on each of our models. Using quantitative metrics, we find that shorter proofs seem to require and provide more mechanistic understanding. Moreover, we find that more faithful mechanistic understanding leads to tighter performance bounds. We confirm these connections by qualitatively examining a subset of our proofs. Finally, we identify compounding structureless errors as a key challenge for using mechanistic interpretability to generate compact proofs on model performance.
1 Introduction
One approach to ensuring the safety and reliability of powerful AI systems is via formally verified proofs of model performance [48, 11]. If we hope to deploy formal verification on increasingly large models [24, 27] with powerful emergent capabilities [56] across more diverse and broader domains [5, 46], we will need compact proofs of generalization bounds on specific models that certify global robustness. However, existing approaches tend to use proof strategies that suffer from bad asymptotic complexity, while verifying either generalization properties of training procedures or local robustness properties of specific models.
One key challenge to verification is that neural network architectures are highly expressive [51, 58], and models with similar training procedures and performance may still have learned significantly different weights [38, 9]. This expressivity makes it difficult to adequately compress explanations of global model behavior in ways that correspond closely enough to the model's actual mechanisms to be useful for efficient verification without being too lossy, especially when using only knowledge of the architecture or training procedure.
We propose verifying model performance using understanding derived from mechanistic interpretability (Section 2), that is, reverse engineering the specific implementation of the algorithm from the learned weights of particular models. Knowledge of the specific implementation allows us to construct less lossy simplifications of the model and to reason more efficiently about model performance over possible inputs.
In this work, we provide a case study of translating mechanistic interpretations into compact proofs. We train an attention-only transformer on a Max-of-K task with 151 random seeds (Section 3), and then reverse engineer the models using standard mechanistic interpretability techniques.
∗ Corresponding author. Please direct correspondence to jgross@mit.edu. † These authors contributed equally to this work. ‡ These authors contributed equally to this work.
Preprint. Under review. arXiv:2406.11779v14 [cs.LG] 24 Dec 2024
[Figure 1 diagram: four panels, "True Model", "Brute Force Proof", "Cubic Proof", and "Subcubic Proofs", each mapping input tokens $t_0, t_1, t_2, t_3$ through the embed, the QK circuit, the OV circuit, the direct path, and the unembed to logits $(\ell_0, \ell_1, \ldots, \ell_{63})$. Annotation: the QK circuit decomposes into a large "size" component (singular value $7.4 \times 10^3$) and small "noise" components (other singular values $< 1.5 \times 10^1$).]
Figure 1 (bottom row):
Proof | FLOPs required | Accuracy lower bound | Unexplained dimension | Asymptotic complexity
Brute force | $1.41 \times 10^{14}$ | 99.92% | $1.07 \times 10^9$ | $O(d_{\text{vocab}}^{n_{\text{ctx}}})$
Cubic | $3.51 \times 10^7$ | 95.31% | $1.28 \times 10^4$ | $O(d_{\text{vocab}}^3 \cdot n_{\text{ctx}})$
Subcubic | $4.68 \times 10^6$ | 28.41% | $4.42 \times 10^3$ | $O(d_{\text{vocab}} \cdot d_{\text{model}}^2 \cdot n_{\text{ctx}})$
Figure 1: We construct proofs using different degrees of mechanistic interpretation. (Left) The models we consider in this paper are one-layer attention-only transformers, and so contain three "paths": the OV circuit, the QK circuit, and the direct path. (Right) For the brute-force proof (Section 4.1), we treat the model as a black box and thus need to check all possible combinations of inputs. For the cubic proof (Section 4.2), we decompose the model into its three corresponding paths, but still check the correctness of each path via brute force. Finally, in some subcubic proofs (Section 4.3), we use all parts of the mechanistic interpretation presented in Section 3. (Bottom) For each of the three categories of proof, we report the number of FLOPs used in computing the certificate (lower is better, Appendix A.6), the lower bound on model accuracy (higher is better), the effective dimension of the unexplained parts of the model (lower is better, Appendix A.5), and the asymptotic complexity of the proof strategy as we scale the inputs and model (lower is better). Significantly more compact proofs have vacuous accuracy bounds by default. Using more mechanistic understanding allows us to recover some, but not all, of the accuracy bounds on these more compact proofs, as our understanding is not fully faithful to the model internals.
We use our understanding to define a set of 102 different computer-assisted proof strategies with varying tightness of bound and with different asymptotic complexity and number of required floating-point operations (Section 4).⁴ We validate our technique against an additional 604 models for varying values of K (Appendix A.2.1).
We define a quantitative metric to assess the mechanistic understanding used in a proof strategy by the dimensionality of the function space that the proof strategy must consider, which we deem the unexplained dimensionality of the proof strategy (Sections 5.1 and A.5). Using this metric, we find a negative relationship between proof length and degree of understanding. We qualitatively examine proof strategies to confirm and explain this relationship, finding that more compact proofs both require and provide more mechanistic understanding. We also find suggestive evidence that the trade-off between proof length and tightness of bound is modulated by the faithfulness of the mechanistic understanding used to derive the proof (Section 5.2).⁵
However, we also identify compounding structureless error terms as a key challenge for generating compact proofs on model behavior (Sections 5.3 and G.2.5). The implementation of algorithms inside of neural networks may contain components that defy mechanistic understanding and appear to us as "noise". When we don't know how noise composes across model components, establishing a bound requires pessimizing over the ways the composition could occur. Worst-case noise can quickly grow over components even when the empirical noise is small, leading to vacuous performance bounds (Appendix G.2.5).
2 Mechanistic interpretability for proofs
Generalization bounds on global performance. In the style of prior mechanistic interpretability evaluation work [6], we target theorem templates that establish bounds on the expected global performance of the model.
Let $M: X \to Y$ be a model (here assumed to be a neural network), $D$ be a probability distribution over input-label pairs $(l, t) \in L \times X$, notated as $D|_X$ when marginalized over labels, and $f: L \times Y \to \mathbb{R}$ be a scoring function for evaluating the performance of the model. Then, we seek to establish lower bounds $b$ on the expected score $\bar{s}$ of the form:
$\bar{s} := \mathbb{E}_{(l,t) \sim D}\,[f(l, M(t))] \ge b \qquad (1)$
⁴ Our 102 proof strategies can be broken up as $1 + 1 + 10 \times 5 \times 2$: two standalone strategies, and a class of strategies parameterized on three axes of cardinality 10, 5, and 2 (Appendix H).
⁵ Code for reproducing our results can be found at https://github.com/JasonGross/guarantees-based-mechanistic-interpretability/. A cache of generated data can be found at https://github.com/JasonGross/guarantees-based-mechanistic-interpretability/tree/with-data.
As $f$ can be any metric, this is a fully general template for theorems that can capture any aspect of model performance for which we have a formal specification. However, in this work we restrict $f$ to be the accuracy and $D|_X$ to be uniform, so our theorems lower bound the accuracy of the model. Our proof methodology generalizes straightforwardly to other input distributions (Appendix A.8), and only a little work is required to generalize from accuracy to log-loss (Appendix A.11).
Proof template. The proofs of model performance in this work have two components: a computational component $C: \text{model weights} \to \mathbb{R}$ and a non-computational component $Q$ arguing that for any model $M'$, $C(M') \le \mathbb{E}_{(l,t) \sim D}\, f(l, M'(t))$, thus implying that $C$ generates a valid lower bound for the performance of $M$. The whole proof is $Q$ paired with a trace of running $C$ that certifies its output on $M$.⁶ Here, $b = C(M)$. As even the size of the model parameters is much larger than any reasonable $Q$, we approximate the length of a proof pair $(C, Q)$ by the length of a trace of $C(M)$.
Proof compactness vs. tightness of bound. Different proof strategies make different tradeoffs between compactness and tightness of bound.
For example, consider two extreme proof strategies: we can "prove" a vacuous bound using a null proof. On the other hand, in the brute-force proof, we simply run the model on the entirety of $D$ to achieve $b = \bar{s}$, albeit with a very long proof. We quantify the length of $C(M)$ using two metrics: the asymptotic time complexity of $C$ as we scale the size of the model and the input $t$, as well as the empirical average number of floating point operations required to evaluate $C(M')$ over a given set of models $M_i$. We measure tightness of bound of $C(M)$ using the ratio of the bound to the true accuracy: $b/\bar{s}$.
Proof as pessimal ablation. A standard way of assessing the faithfulness of mechanistic interpretability is by ablating the parts of the model that your interpretation does not explain [54, 6, 23]. In this framework, proofs can be thought of as performing a pessimal ablation over the unexplained parts of the model: we set the remaining components of the model (the "noise" or error terms) to values over $X$ that minimize the performance of the model. However, the number of ablations required for a complete argument might be quite high. Thus, we construct relaxations (Appendix A.4) over input sequences, such that performing pessimal ablations on a smaller number of relaxed input sequences is sufficient to lower bound the performance on $D$.
3 Experimental setting
We study our approach to generating compact proofs in a simple toy setting: Max-of-K.
Model architecture. We study one-layer, one-head, attention-only transformers with no biases but with learned positional embeddings, with vocabulary size $d_{\text{vocab}}$, model and head dimension $d = d_{\text{model}} = d_{\text{head}}$, and context length $n_{\text{ctx}} := k$. The model parameters consist of the $n_{\text{ctx}} \times d_{\text{model}}$ positional embedding $P$; the $d_{\text{vocab}} \times d_{\text{model}}$ token embedding $E$; the $d_{\text{model}} \times d_{\text{model}}$ query, key, value, and output matrices of the attention head $Q$, $K$, $V$, and $O$; as well as the $d_{\text{model}} \times d_{\text{vocab}}$ unembedding matrix $U$. We assume (as is standard in language modeling) that $d_{\text{model}} < d_{\text{vocab}}$.
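As a quick consistency check on the shapes just listed, the sketch below multiplies them out for the settings given in the Task paragraph ($n_{\text{ctx}} = 4$, $d_{\text{model}} = 32$, $d_{\text{vocab}} = 64$). The function name `count_parameters` is hypothetical (it is not from the authors' code); it assumes exactly the bias-free shapes stated above, with $d_{\text{head}} = d_{\text{model}}$.

```python
def count_parameters(n_ctx, d_model, d_vocab):
    """Tally the weights of a one-layer, one-head, attention-only transformer
    with no biases, using the matrix shapes listed in the text (d_head = d_model)."""
    shapes = {
        "P": (n_ctx, d_model),    # positional embedding
        "E": (d_vocab, d_model),  # token embedding
        "Q": (d_model, d_model),  # query
        "K": (d_model, d_model),  # key
        "V": (d_model, d_model),  # value
        "O": (d_model, d_model),  # output
        "U": (d_model, d_vocab),  # unembedding
    }
    counts = {name: rows * cols for name, (rows, cols) in shapes.items()}
    counts["total"] = sum(counts.values())  # computed before "total" is added
    return counts


print(count_parameters(n_ctx=4, d_model=32, d_vocab=64))
```

For these settings the total is 8,320 weights, 4,096 of which sit in the four $d_{\text{model}} \times d_{\text{model}}$ head matrices.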
For an $n_{\text{ctx}} \times d_{\text{vocab}}$ one-hot encoding $x = [x_0, x_1, \ldots, x_{n_{\text{ctx}}-1}]$ of an input sequence $t = [t_0, t_1, \ldots, t_{n_{\text{ctx}}-1}]$, we compute the logits of the model as follows:
$h^{(0)} = xE + P$ (initial residual stream, $n_{\text{ctx}} \times d_{\text{model}}$)
$\alpha = h^{(0)} QK^\top h^{(0)\top} / \sqrt{d}$ (attention matrix, $n_{\text{ctx}} \times n_{\text{ctx}}$)
$h^{(1)} = \sigma^*(\alpha) \cdot h^{(0)} VO + h^{(0)}$ (final residual stream, $n_{\text{ctx}} \times d_{\text{model}}$)
$M(t) = \ell = h^{(1)}_{n_{\text{ctx}}-1}\, U$ (final sequence position logits, $d_{\text{vocab}}$)
where $\sigma^*$ is the masked softmax function used in causal attention. Because we only look at outputs of the model above the final sequence position $i = n_{\text{ctx}} - 1$, we also denote this position as the "query position" and the value of the token in this position as $t_{\text{query}}$, one-hot encoded as $x_{\text{query}}$. The model's prediction is the token corresponding to the max-valued logit $\ell_{\max}$.
⁶ Other components of the proof to account for the difference between floating point numbers and reals are described in Appendix A.7. Note that all proofs explicitly given in this paper are of $Q$ only; we do not include any traces of running $C$.
Task. Specifically, we study the setting with $n_{\text{ctx}} = k = 4$ because it is the largest sequence length for which we can feasibly evaluate the brute-force proof. We set hidden dimension $d_{\text{model}} = 32$ and a vocabulary of size $d_{\text{vocab}} = 64$ comprising integers between 0 and 63 inclusive. For an input sequence $t = [t_0, t_1, t_2, t_3]$, we denote the true maximum of the sequence by $t_{\max}$. Outputting the correct behavior is equivalent to outputting logits $\ell$ such that $\Delta\ell_{t^*} := \ell_{t^*} - \ell_{\max} < 0$ for all $t^* \neq t_{\max}$.
We trained 151 models on this task. Models achieved an average accuracy of $0.9992 \pm 0.0015$ over the entire data distribution.
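The forward pass and the correctness criterion above can be mirrored in a few lines of dependency-free Python. This is a minimal sketch under stated assumptions: nested lists stand in for the weight matrices, only the final (query) row is computed because only it is read out, and the helper names (`final_logits`, `is_correct`, `brute_force_accuracy`) and the toy weights are hypothetical, not taken from the authors' repository. Enumerating every input, as `brute_force_accuracy` does, is the idea behind the brute-force baseline, at a cost that grows as $d_{\text{vocab}}^{n_{\text{ctx}}}$.

```python
import itertools
import math


def matmul(a, b):
    """Matrix product of nested lists: (n x m) @ (m x p) -> (n x p)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def final_logits(tokens, E, P, Q, K, V, O, U):
    """Logits M(t) at the final (query) position of a one-layer, one-head,
    attention-only transformer with no biases."""
    d = len(Q)  # d = d_model = d_head
    h0 = [[e + p for e, p in zip(E[t], P[i])] for i, t in enumerate(tokens)]  # xE + P
    query = [h0[-1]]  # row n_ctx - 1 of h0
    scores = matmul(matmul(matmul(query, Q), transpose(K)), transpose(h0))[0]
    # The causal mask hides nothing from the last position, so an ordinary
    # softmax over all n_ctx scores equals the masked softmax for this row.
    attn = softmax([s / math.sqrt(d) for s in scores])
    mixed = [[sum(a * row[j] for a, row in zip(attn, h0)) for j in range(d)]]
    head_out = matmul(matmul(mixed, V), O)[0]
    h1_last = [u + r for u, r in zip(head_out, h0[-1])]  # add the skip connection
    return matmul([h1_last], U)[0]


def is_correct(logits, tokens):
    """Correct iff every logit other than the true max's is strictly smaller."""
    t_max = max(tokens)
    return all(logit < logits[t_max] for t, logit in enumerate(logits) if t != t_max)


def brute_force_accuracy(weights, d_vocab, n_ctx):
    """Exact accuracy under the uniform distribution over all d_vocab**n_ctx inputs."""
    seqs = list(itertools.product(range(d_vocab), repeat=n_ctx))
    return sum(is_correct(final_logits(s, *weights), s) for s in seqs) / len(seqs)


def identity(n):
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def zeros(rows, cols):
    return [[0.0] * cols for _ in range(rows)]


# Hypothetical toy weights (E, P, Q, K, V, O, U): zero QK gives uniform attention,
# identity E/V/O/U copies tokens straight through.
toy = (identity(3), zeros(2, 3), zeros(3, 3), zeros(3, 3), identity(3), identity(3), identity(3))
print(brute_force_accuracy(toy, d_vocab=3, n_ctx=2))  # 0.6666666666666666
```

With these toy weights the final position averages the one-hots and adds its own token, so the model is right exactly when the last token is (one of) the largest: 6 of the 9 length-2 sequences over a 3-token vocabulary.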
Path decomposition. Following prior work [13], we expand the logits of the model and split the paths through the model into three components: the QK circuit, the OV circuit, and the direct path:
$M(t) = \sigma^*\Big(\underbrace{(x_{\text{query}}E + P_{\text{query}})\,QK^\top (xE + P)^\top}_{\text{QK circuit}} / \sqrt{d}\Big) \cdot \underbrace{(xE + P)\,VOU}_{\text{OV circuit}} + \underbrace{(x_{\text{query}}E + P_{\text{query}})\,U}_{\text{direct path}} \qquad (2)$
Intuitively, the QK circuit determines which tokens the model attends to from a particular query token and sequence position, while the OV circuit processes the tokens and sequence positions the model attends to. The direct path is simply the skip connection around the attention head.
We further divide the QK and OV circuits into token (position-independent) and position-dependent components. Let $P_{\text{avg}} = \sum_i P_i / n_{\text{ctx}}$ be the average position embedding across positions (of size $d_{\text{model}}$), and let $\bar{P}$ denote either $\mathbf{1}_{n_{\text{ctx}}} \otimes P_{\text{avg}}$ or $\mathbf{1}_{d_{\text{vocab}}} \otimes P_{\text{avg}}$ depending on context, the result of broadcasting $P_{\text{avg}}$ back into the shape of $P$ or $E$ (that is, $n_{\text{ctx}} \times d_{\text{model}}$ or $d_{\text{vocab}} \times d_{\text{model}}$). Similarly, let $P_q = \mathbf{1}_{d_{\text{vocab}}} \otimes P_{\text{query}}$ be the result of broadcasting $P_{\text{query}}$. Then for one-hot encoded $x$, we can rewrite the QK and OV circuits, as well as the direct path, as follows:
QK circuit $= x_{\text{query}}\Big(\underbrace{E_q QK^\top \bar{E}^\top}_{\text{EQKE}}\, x^\top + \underbrace{E_q QK^\top \hat{P}^\top}_{\text{EQKP}}\Big)$
OV circuit $= x\,\underbrace{\bar{E}\, VOU}_{\text{EVOU}} + \underbrace{\hat{P}\, VOU}_{\text{PVOU}}$
Direct path $= x_{\text{query}}\,\underbrace{E_q U}_{\text{EU}}$
where $\hat{P} = P - \bar{P}$, $\bar{E} = E + \bar{P}$, and $E_q = E + P_q$ (since $h^{(0)} = x\bar{E} + \hat{P}$).
3.1 Mechanistic interpretation of learned models
[Figure 2 diagram: input tokens $t_0, t_1, t_2, t_3$ pass through the embed; the QK circuit "attends to larger tokens more", the OV circuit "performs low-rank copying", and the direct path "does nothing"; the unembed produces logits $(\ell_0, \ell_1, \ldots, \ell_{63})$.]
Figure 2: The models in our setting implement Max-of-K by attending exponentially more to larger tokens and copying the attended-to tokens (Section 3.1).
Using standard empirical mechanistic interpretability techniques, we interpret one of our learned models (our "mainline" model) by independently examining the QK and OV circuits and the direct path.⁷
We find that the model outputs the largest logit on the true max token $t_{\max}$ by attending more to larger tokens via the QK circuit and copying the tokens it attends to via the OV circuit. We then quantitatively confirm that these interpretations hold for all 151 models by reporting the mean plus or minus standard deviation for various summary statistics. Plots for this section are available in Appendix B.2.
⁷ All of our trained models behave similarly; see Appendix B.3.
QK circuit. By qualitatively examining the position-independent QK component EQKE, we find the amount of pre-softmax attention paid to a key token is approximately independent of the value of the query token $t_{\text{query}}$, and increases monotonically based on the size of the key token. We confirm this hypothesis by performing a singular-value decomposition (SVD) of the EQKE matrices (Appendix G.2.3), and find that it contains a single large rank-one component with singular value around $7800 \pm 380$, around $620 \pm 130$ times larger than the second largest component with singular value $13 \pm 3$. The left (query-side) singular vector is approximately constant in all dimensions, with value $0.1243 \pm 0.0003 \approx 1/8 = 1/\sqrt{d_{\text{vocab}}}$. The right (key-side) singular vector of this component is monotonically increasing as we increase the size of the key token, with ($1/\sqrt{d}$-scaled) pre-softmax attention increasing by an average of $1.236 \pm 0.056$ when the key token increases by 1.⁸
In comparison, each $1/\sqrt{d}$-scaled entry of the position-dependent QK component EQKP has negligible size (average $0.31 \pm 0.18$), suggesting that EQKP is unimportant to the functioning of the model. We confirm this by zero ablating EQKP, which changes the models' accuracies from $0.9992 \pm 0.0015$ to $0.9993 \pm 0.0011$. Combined with our interpretation of EQKE, this implies that the attention pattern of the model depends only on the token values and not the ordering of the sequence.
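Footnote 8 converts the average per-token increase of 1.236 in scaled pre-softmax attention into an attention ratio of $\exp(1.236) \approx 3.442$. The sketch below reproduces that arithmetic and, under a simplifying assumption of ours (not a measured result), namely that the scaled score is exactly $1.236 \times$ key token with no dependence on the query token or position, shows how the softmax concentrates on the maximum. The helper `attention_weights` is hypothetical.

```python
import math

SLOPE = 1.236  # mean increase in 1/sqrt(d)-scaled pre-softmax attention per +1 in key token


def attention_weights(tokens, slope=SLOPE):
    """Softmax over idealized scores slope * key_token. Assumes (hypothetically)
    that the score is exactly linear in the key token and ignores the query
    token and position, as the rank-1 'size' component of EQKE suggests."""
    scores = [slope * t for t in tokens]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


print(round(math.exp(SLOPE), 3))  # attention ratio between key tokens t and t - 1: 3.442
print(1 / math.sqrt(64))          # entry of a constant unit vector in 64 dims: 0.125
weights = attention_weights([60, 61, 62, 63])
print([round(w, 3) for w in weights])
```

Under this idealization, roughly 71% of the final position's attention in [60, 61, 62, 63] lands on 63. Because softmax is shift-invariant, the weights depend only on the gaps between tokens, not on their absolute values.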
OV circuit. Then, by qualitatively examining the position-independent OV component EVOU, we see that it has large positive entries along the diagonal. In fact, the entry along the diagonal is the largest in the row for all rows corresponding to $t > 6.6 \pm 1.2$. Since each entry in the sequence is uniformly sampled and $d_{\text{vocab}} = 64$, this means that EVOU is a good approximation for the identity matrix for all but $\approx (7/64)^4 \approx 1.2 \times 10^{-2}\%$ of the sequences. As with the position-dependent QK component, the position-dependent OV component PVOU also has negligible size and is unimportant to model performance. Taken together with the above results on EVOU, this suggests that the attention head copies the tokens it attends to.
Direct path. As with the two position-dependent components, the entries in EU have small absolute magnitude $2.54 \pm 0.20$,⁹ and contribute negligibly to model performance.
4 Proofs of model performance
In this section we describe intuitions for three categories of proof that are developed around different mechanistic interpretations and methods for using the interpretations. The strategies result in proofs of different complexities with varying bound tightness (Table 1). We provide detailed theorem statements, proofs, algorithms, and explanations of proof search in Appendices C, D, E, F, and G.
Our theorem statements for $Q$ will all be of the form $\forall M',\; C_{\text{specific strategy}}(M') \le \mathbb{E}_{t \sim D|_X}\, f(t_{\max}, M'(t))$. We leave implicit the traces of running $C_{\text{specific strategy}}$ on our specific models to give the overall theorem. We report the computational complexity or estimated FLOPs of running $C_{\text{specific strategy}}$ as approximations for our proof lengths.
4.1 The brute-force baseline
We start by considering the brute-force proof (Appendix D), which treats the model as a black box and evaluates it on all possible sequences.¹⁰ However, this proof strategy has bad asymptotic complexity and is untenable for larger models and larger input distributions.
So in subsequent sections, we use knowledge of the model drawn from the interpretation in Section 3.1 to derive more compact proofs.
4.2 A cubic proof
Next, we use the fact that the model is composed of the direct path and the QK and OV circuits (Section 3) to decrease the number of sequences that we need to consider, and the fact that only the position-independent components EQKE and EVOU contribute meaningfully to performance (Section 3.1) to pessimize over sequence ordering.
First, let a pure sequence $\xi$ be a sequence with at most three distinct tokens: the max token $t_{\max}$, the final token $t_{\text{query}} \le t_{\max}$, and optionally a third token $t' < t_{\max}$, and let $\Xi^{\text{pure}}$ be the set of all pure sequences in $X$.¹¹ For a given input sequence $t$, define the adjacent pure sequences $\mathrm{Adj}(t)$ as the set of sequences that share the same max and query token, and only take on values in $t$:
$\mathrm{Adj}(t) = \{\, \xi \in \Xi^{\text{pure}} \mid \max_i \xi_i = t_{\max},\ \xi_{\text{query}} = t_{\text{query}},\ \forall i < n_{\text{ctx}},\ \xi_i \in t \,\}$
⁸ This implies that the ratio of attention paid to token $t$ and $t - 1$ is approximately $\exp(1.236) \approx 3.442$.
⁹ For comparison, the average off-diagonal element of EVOU is $21.68 \pm 0.83$ below the corresponding diagonal element.
¹⁰ Appendix A.10 discusses how to compute the "brute-force" accuracy of a model on an infinite distribution.
¹¹ In Section 4.3, we will consider a smaller set of "pure sequences".
Table 1: We report the proof complexity, accuracy bound, and estimated flops required (Equation 2), as well as unexplained dimensionality (Section 5). We round the FLOP and unexplained dimension counts to the closest power of 2, and report the mean/standard deviation of the bound averaged across all 151 models. As we include more aspects of the mechanistic interpretation (reflected by a lower number of unexplained dimensions), we get more compact proofs (in terms of both asymptotic complexity and FLOPs), albeit with worse bounds. For space reasons, we use $k := n_{\text{ctx}}$, $d := d_{\text{model}}$, and $v := d_{\text{vocab}}$.
Description of proof | Complexity cost | Bound | Est. FLOPs | Unexplained dimensions
Brute force | $O(v^{k+1} k d)$ | $0.9992 \pm 0.0015$ | $2^{47}$ | $2^{30}$
Cubic | $O(v^3 k^2)$ | $0.9531 \pm 0.0087$ | $2^{25}$ | $2^{14}$
Sub-cubic | $O(v^2 k^2 + v^2 d)$ | $0.702 \pm 0.033$ | $2^{21}$ | $2^{13}$
  w/o mean+diff | | $0.349 \pm 0.080$ | $2^{21}$ | $2^{13}$
Low-rank QK | $O(v^2 k^2 + \underbrace{v d^2}_{\text{QK}} + \underbrace{v^2 d}_{\text{EU\&OV}})$ | $0.675 \pm 0.035$ | $2^{22}$ | $2^{12}$
  SVD only | | $0.284 \pm 0.072$ | $2^{22}$ | $2^{12}$
Low-rank EU | $O(v^2 k^2 + \underbrace{v d}_{\text{EU}} + \underbrace{v^2 d}_{\text{QK\&OV}})$ | $0.633 \pm 0.062$ | $2^{21}$ | $2^{13}$
  SVD only | | $(3.38 \pm 0.06) \times 10^{-6}$ | $2^{21}$ | $2^{13}$
Low-rank QK&EU | $O(v^2 k^2 + \underbrace{v d^2}_{\text{QK}} + \underbrace{v d}_{\text{EU}} + \underbrace{v^2 d}_{\text{OV}})$ | $0.610 \pm 0.060$ | $2^{21}$ | $2^{13}$
  SVD only | | $(3.38 \pm 0.06) \times 10^{-6}$ | $2^{22}$ | $2^{13}$
Quadratic QK | $O(v^2 k^2 + \underbrace{v d}_{\text{QK}} + \underbrace{v^2 d}_{\text{EU\&OV}})$ | $0.316 \pm 0.037$ | $2^{21}$ | $2^{12}$
Quadratic QK&EU | $O(v^2 k^2 + \underbrace{v d}_{\text{QK\&EU}} + \underbrace{v^2 d}_{\text{OV}})$ | $0.283 \pm 0.036$ | $2^{21}$ | $2^{13}$
Using the convexity of softmax and the fact that the model contains three paths, we can show that one-layer attention-only transformers satisfy a variant of the following convexity property: for a given $t$, if $M(\xi)$ is correct for all $\xi \in \mathrm{Adj}(t)$, then $M(t)$ is correct. That is, for these transformers, we can bound the accuracy on all sequences by evaluating $M$ on only the $O(d_{\text{vocab}}^3 (n_{\text{ctx}} - 1)!)$ pure sequences. This allows us to bound the accuracy of our actual $M$ on all $d_{\text{vocab}}^{n_{\text{ctx}}}$ sequences, while evaluating it on only $O(d_{\text{vocab}}^3 (n_{\text{ctx}} - 1)!)$ sequences.
We can reduce the number of sequences that we need to evaluate by pessimizing over the order of a sequence. For a given tuple $(t_{\max}, t_{\text{query}}, t')$, there are $(n_{\text{ctx}} - 1)!$ pure sequences, corresponding to the permutations of the tuple. Pessimizing over the order of sequences reduces the number of sequences to consider for each $(t_{\max}, t_{\text{query}}, t')$ tuple to the number of copies of $t'$ in the pure sequence, and the total number of sequences to $O(d_{\text{vocab}}^3 n_{\text{ctx}})$. By precomputing the five component matrices EU, EQKE, EQKP, EVOU, and PVOU and cleverly caching intermediate outputs, we can reduce the additional work of each sequence to the $O(n_{\text{ctx}})$ required to compute the softmax over $n_{\text{ctx}}$ elements, resulting in asymptotic complexity $O(d_{\text{vocab}}^3 n_{\text{ctx}}^2)$ (Theorem 12, additional details in Appendix E).
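To make the definitions of pure sequences and $\mathrm{Adj}(t)$ concrete, here is a brute-force enumerator for small cases. It is a sketch, not the authors' implementation: it reads "pure" as "at most one distinct token outside $\{t_{\max}, t_{\text{query}}\}$", treats the last position as the query position, and the names `is_pure` and `adjacent_pure_sequences` are hypothetical. It enumerates candidates exhaustively, so it illustrates the definition rather than the efficient cubic algorithm.

```python
import itertools


def is_pure(seq):
    """Pure: at most one distinct token besides t_max and the final (query) token.
    Any such extra token t' is automatically below t_max, since t_max = max(seq)."""
    t_max, t_query = max(seq), seq[-1]
    return len(set(seq) - {t_max, t_query}) <= 1


def adjacent_pure_sequences(t):
    """Adj(t): pure sequences xi with max(xi) = t_max, xi[-1] = t_query,
    and every xi_i drawn from the token values present in t."""
    t = tuple(t)
    t_max, t_query = max(t), t[-1]
    values = sorted(set(t))
    adj = []
    for prefix in itertools.product(values, repeat=len(t) - 1):
        xi = prefix + (t_query,)
        if max(xi) == t_max and is_pure(xi):
            adj.append(xi)
    return adj


t = (1, 3, 0, 2)
adj = adjacent_pure_sequences(t)
print(len(adj), is_pure(t))  # 31 False
```

Here $t$ itself is not pure (both 0 and 1 lie outside $\{t_{\max}, t_{\text{query}}\} = \{3, 2\}$), yet each of its 31 adjacent pure sequences keeps a 3 somewhere and ends in 2. This exhaustive version pays $|t|^{n_{\text{ctx}}-1}$ per input; the cubic proof instead works with the pure sequences directly and pessimizes over their orderings.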
4.3 Sub-cubic proofs
We now consider proofs that are more compact than $O(d_{\text{vocab}}^3)$. These require avoiding iteration over any set of size $O(d_{\text{vocab}}^3)$ (e.g. the set of pure sequences) and avoiding operations that take $O(d_{\text{vocab}})$ time on each of $O(d_{\text{vocab}}^2)$ combinations. Unfortunately, some methods of avoiding these operations can lead to vacuous bounds (i.e. accuracy lower bounds near 0%). In order to recover non-vacuous bounds, we introduce two tricks: the "mean+diff trick" to better approximate the sum of two components with unequal variance, and the "max row-diff trick" to improve upon the low-rank approximations for EU and EQKE. We consider applying variants of these tricks at different locations in the naïve subcubic proof, leading to 100 distinct subcubic proof strategies. See Appendix G.2 for a formal description of these strategies.
4.3.1 Removing cubic-time computations
Reducing the number of cases by pessimizing over sufficiently small tokens. Previously, we considered $\Theta(d_{\text{vocab}}^3 n_{\text{ctx}})$ pure sequences $\xi$, with $\xi$ parameterized by $(t_{\max}, t_{\text{query}}, t', c)$. Recall from our mechanistic interpretation in Section 3.1 that the pre-softmax attention paid from $t_{\text{query}}$ to a key token $t'$ is broadly invariant in $t_{\text{query}}$ and increases roughly linearly with the size of $t'$. This allows us to pessimize over the OV circuit over all "sufficiently small" tokens.
More formally, suppose we are given some gap $g \in \mathbb{N}$. For each pure sequence $\xi$ with max token $t_{\max}$, query token $t_{\text{query}} \le t_{\max} - g$, and $c$ copies of the third token type $t' \le t_{\max} - g$, we pessimally ablate the OV circuit over the set $\Xi^{\text{pure}}(t_{\max}, t_{\text{query}}, c; g)$ of pure sequences $\xi'$ with the same max and query tokens and $c$ copies of the third token type $t'$. If the model gets all sequences in $\Xi^{\text{pure}}(t_{\max}, t_{\text{query}}, c; g)$ correct, then we can conclude that it gets $\xi$ correct; otherwise, we treat the model as having gotten $\xi$ wrong.
This means that it suffices to consider only the $O(d_{\text{vocab}}^2 n_{\text{ctx}})$ pessimal pure sequences of the $O(d_{\text{vocab}}^2 n_{\text{ctx}})$ sets of the form $\Xi^{\text{pure}}(t_{\max}, t_{\text{query}}, c; g)$.
Decoupling and pessimizing computations that require $O(d_{\text{vocab}}^3)$ computations. Many parts of our cubic certificate require iterating through $O(d_{\text{vocab}}^2)$ cases parameterized by $(t_{\max}, t_{\text{query}})$ or $(t_{\max}, t')$. For example, as part of the pessimization procedure over pure sequences, for each of the $d_{\text{vocab}}$ possible values of $t_{\max}$, we need to consider the relative effects on the $d_{\text{vocab}}$-sized logits of attending to each of the $O(d_{\text{vocab}})$ other tokens $t' < t_{\max}$, and for each $t_{\max}$ and $t_{\text{query}}$, we need to check that the contribution of the direct path on logits $x_{\text{query}}\,\text{EU}$ is not large enough to overwhelm the contribution from $x_{\max}\,\text{EVOU}$. We independently pessimize over each of these components over one of the $d_{\text{vocab}}$-sized axes: for example, instead of computing $x_{\max}\,\text{EVOU} + x_{\text{query}}\,\text{EU}$ for each $(t_{\max}, t_{\text{query}})$ pair, we first pessimally ablate the direct path along the query token (which takes $O(d_{\text{vocab}}^2)$ time, as it does not depend on $t_{\max}$), and then consider the sum $x_{\max}\,\text{EVOU} + \max_{x'} x'\,\text{EU}$. Since this sum no longer depends on $t_{\text{query}}$, we only need to perform it $O(d_{\text{vocab}})$ times, for a total cost of $O(d_{\text{vocab}}^2)$.
Low-rank approximations to EQKE and EU. Recall from Section 3.1 that EQKE is approximately rank 1, where the sole direction of variation is the size of the key token. By computing only the low-rank approximation to EQKE, we can more cheaply compute the most significant component of the behavior in the QK circuit. To bound the remaining error, we can use the fact that after pulling off the first principal component from each of the four matrices we multiply, very little structure remains. We can find the rank 1/2 approximations by performing SVD on EQKE.
We can efficiently compute the SVD in $O(d_{\text{vocab}} d_{\text{model}}^2)$ time by using the fact that EQKE can be written as the product of a $d_{\text{vocab}} \times d_{\text{model}}$ matrix and a $d_{\text{model}} \times d_{\text{vocab}}$ matrix. This allows us to avoid performing the $O(d_{\text{vocab}}^2 d_{\text{model}})$-cost matrix multiplications needed to compute EQKE explicitly. Similarly, we can more efficiently check that the direct path EU contributes negligibly to the model outputs by using SVD to decompose EU into a sum of rank-1 products (which we can evaluate exactly) and a high-rank error term that we can cheaply bound.
4.3.2 Additional subcubic proof strategies
Tighter bounds for sums of variables with unequal variance via the "mean+diff trick". Suppose we want to lower bound the minimum of the sum of two functions over three variables, $h(x,y,z) = f(x,y) + g(y,z)$, while only iterating over two variables at a time. The naïve way is to minimize $f(x,y)$ and $g(y,z)$ independently:
$\min_{x,y,z} h(x,y,z) \ge \min_{x,y} f(x,y) + \min_{y,z} g(y,z)$
Here, the error comes from allowing $y$ to take different values in $f$ and in $g$. But in cases where $g(y,z)$ varies significantly with $y$ and only slightly with $z$, rewriting $g$ as the sum of a component that is independent of $z$ (only varying along $y$) and a component that depends on $z$ yields a better lower bound:
$\min_{x,y,z} h(x,y,z) \ge \min_{x,y}\big(f(x,y) + \mathbb{E}_{z'}\, g(y,z')\big) + \min_{y,z}\big(g(y,z) - \mathbb{E}_{z'}\, g(y,z')\big)$
This estimate will have error at most $\varepsilon$, where $\varepsilon$ bounds how much $g(y,z)$ varies with $z$ for fixed $y$, while the naïve estimator can have arbitrarily large error. We refer to this rewrite as the "mean+diff trick".¹² From the mechanistic interpretation in Section 3.1, we know that some of the components barely vary along one or more axes, so we can apply the mean+diff trick to get tighter lower bounds.
¹² In fact, this is the motivation behind the standard rewrites of QK and OV into position-independent and position-dependent components (Section 3).
[Figure 3 plot: normalized accuracy bound (0.0 to 1.0) against FLOPs to verify proof (approximate). Legend: brute force (acc: 0.9992 ± 0.0015); cubic (rel acc: 0.9539 ± 0.0080); subcubic (rel acc: 0.700 ± 0.036); attention-$d_{\text{vocab}} d_{\text{model}}^2$ (rel acc: 0.675 ± 0.035); direct-quadratic (rel acc: 0.633 ± 0.062); attention-$d_{\text{vocab}} d_{\text{model}}^2$, direct-quadratic (rel acc: 0.610 ± 0.060); attention-quadratic (rel acc: 0.316 ± 0.037); attention-quadratic, direct-quadratic (rel acc: 0.283 ± 0.037); brute-force linear baseline.]
Figure 3: For each of the proofs in Section 4, we plot the number of FLOPs used to compute the certificate, as well as the normalized accuracy lower bound ($b/\bar{s}$). The brute-force proof (Section 4.1) computes the exact performance but uses orders of magnitude more compute than other approaches. The cubic proof (Section 4.2) uses a small amount of mechanistic understanding and less compute, while still retaining good accuracy lower bounds. Finally, subcubic proofs (Section 4.3) require the entirety of the mechanistic interpretation of the model to attain non-vacuous bounds; this understanding allows us to further reduce compute costs, but we still achieve worse bounds. See Appendix H.2.1 for a detailed description of the various proof strategies.
Avoiding matrix multiplications using the "max row-diff trick". Using properties of linear algebra, we derive a cheap approximation to the max row-diff for the product of matrices $AB$ in terms of the product of the max row-diff of $B$ and the absolute value of $A$, which we deem the "max row-diff trick". We apply this trick to get a better cheap bound on the error terms of low-rank approximations, without having to multiply out the full matrices. See Appendix G.2.2 for more details.
See Appendix F for more variants and combinations of these strategies.
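The mean+diff trick can be checked numerically on small tables. The sketch below (function names hypothetical; the values are arbitrary illustrative numbers, not model data) compares the exact minimum of $h(x,y,z) = f(x,y) + g(y,z)$ with the naïve split bound and the mean+diff bound. Only `exact_min` iterates over all three variables; the two bounds each touch at most two at a time, as in the text.

```python
from itertools import product


def exact_min(f, g):
    """min over x, y, z of f[x][y] + g[y][z] (iterates over all three variables)."""
    return min(f[x][y] + g[y][z]
               for x, y, z in product(range(len(f)), range(len(g)), range(len(g[0]))))


def naive_bound(f, g):
    """Minimize f and g independently; y may differ between the two terms."""
    return min(min(row) for row in f) + min(min(row) for row in g)


def mean_diff_bound(f, g):
    """Mean+diff trick: split g(y, z) into its mean over z plus a residual."""
    g_mean = [sum(row) / len(row) for row in g]  # E_z' g(y, z') for each y
    first = min(f[x][y] + g_mean[y] for x in range(len(f)) for y in range(len(g)))
    second = min(g[y][z] - g_mean[y] for y in range(len(g)) for z in range(len(g[0])))
    return first + second


# f[x][y] prefers y = 0; g[y][z] varies a lot with y but only by 0.5 along z.
f = [[0.0, 10.0], [1.0, 12.0]]
g = [[10.0, 10.5], [0.0, 0.5]]
print(exact_min(f, g), naive_bound(f, g), mean_diff_bound(f, g))  # 10.0 0.0 10.0
```

Because $g$ varies by only 0.5 along $z$, the mean+diff bound lands within 0.5 of the true minimum (here it is exact), whereas the naïve bound is off by 10 because it pairs $f$'s best $y$ with $g$'s best $y$. The trick is not a free win: with `f = [[3, 1], [2, 5]]` and `g = [[0, 4], [1, 2]]`, where $g$ varies strongly along $z$, it gives 0.5 against a naïve 1 and a true minimum of 2. This is why it is applied only where the interpretation says a component barely varies along an axis.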
5 Results

We run each of the proof strategies of different asymptotic complexity on each of the 151 transformers, and analyze these proofs to empirically examine the relationship between proof length, bound tightness, and degree of understanding. For each proof on each transformer, we approximate the length of the proof by estimating the number of FLOPs used, and plot this against the ratio of the certified bound to the true accuracy, $b/\bar{s}$ (Equation 2), in Figure 3. There is a clear trade-off between bound tightness and proof compactness: more compact proofs yield looser bounds, and tighter bounds are associated with more expensive proofs.

5.1 Compact proofs both require and provide mechanistic understanding

Quantifying mechanistic understanding using unexplained dimensionality. We first quantify the amount of mechanistic understanding used in a proof by measuring its unexplained dimensionality: the number of free parameters required to fully describe model behavior, assuming the structural assumptions of the proof are correct. More detailed mechanistic interpretations leave fewer free parameters that need to be filled in via empirical observation (details in Appendix A.5). In Figure 5, we plot the two axes and find a suggestive correlation: proofs based on less mechanistic understanding are longer.

More mechanistic understanding allows for more compact proofs. In addition to the constructions in Section 4, the parts of proofs we were unable to compact seem to correspond to components that we do not mechanistically understand. For example, we could not cheaply bound the behavior of EVOU without multiplying out the matrices, and this seems to be in part because we do not have a mechanistic understanding of how EVOU implements low-rank copying.

Compact proofs seem to provide understanding. By examining compact proofs, we can extract understanding about the model.
For example, the fact that replacing each row of EU with its average across rows has little effect on the bound implies that EU does not vary much with $t_{\text{query}}$.

5.2 Proof length vs. bound tightness trade-off is modulated by faithfulness of interpretation

Compact proofs are less faithful to model internals. To derive more compact proofs, we use our mechanistic understanding to simplify the model computation in ways that diverge from the original model internals. For example, in some subcubic proofs (Section 4.3), we approximate EQKE with a rank-1 approximation corresponding to the “size direction”. However, while the other components are small, they are nonzero, so this approximation is not faithful to the model internals.

[Figure 4: normalized accuracy bound versus EPQKE singular ratio $\sigma_1/\sigma_2$. Legend: mean+max-diff; max-diff; mean+max-diff-subproduct; max-diff-subproduct; max-diff-exact; svd; mean-recursive+max-diff-subproduct-recursive; mean+max-diff-subproduct-recursive; max-diff-subproduct-recursive.]

Figure 4: We plot the normalized accuracy bound versus the ratio of the first and second singular values of EQKE, for various types of subcubic proofs that depend on a rank-1 approximation of EQKE. For each class of proof, the closer EQKE is to rank 1, the tighter the accuracy bound. This suggests that more faithful interpretations lead to tighter bounds even when proof length is held fixed. Note that the “svd” proof strategy has the clearest upward trend ($b/\bar{s} = 0.00020\,(\sigma_1/\sigma_2) + 0.44$, $R^2 = 0.41$). See Appendix H.2.2 for a detailed description of the various proof strategies.

[Figure 5: estimated unexplained dimension versus approximate FLOPs to verify the proof. Legend: brute force; cubic; subcubic; attention-$d_{\text{vocab}} d_{\text{model}}^2$; direct-quadratic; attention-$d_{\text{vocab}} d_{\text{model}}^2$, direct-quadratic; attention-quadratic; attention-quadratic, direct-quadratic.]

Figure 5: We plot, for each proof, the approximate number of FLOPs required to evaluate the proof versus the unexplained dimensionality (Section 5.1). More mechanistic understanding leaves fewer dimensions unexplained. We observe that more compact proofs seem to leave fewer unexplained dimensions, which is indicative of the relationship between mechanistic understanding and compact proofs. See Appendix H.2.1 for a detailed description of the various proof strategies.

Less faithful interpretations lead to worse bounds on performance. To confirm that faithfulness of understanding affects the tightness of the bound independently of proof length, we plot the normalized accuracy bound of subcubic proofs that perform a rank-1 approximation of EQKE versus the ratio of the first two singular values. A larger ratio between them implies that the rank-1 approximation is more faithful. In Figure 4, we see a positive correlation between the two axes: when the interpretation is more faithful, the bounds are tighter, even at a fixed proof length.

5.3 Compounding structureless noise is a big challenge for compacting global-behavior proofs

Pessimal error terms compound in the absence of known structure. The rank-1 approximation of EQKE has small error. However, when we make rank-1 approximations of each of the constituent matrices E, Q, and K, pessimizing over the worst way of composing the individual small error terms leads to a bound on the error term of EQKE that is orders of magnitude larger than the actual error term. Because we do not understand how the matrices compose without the errors compounding (short of multiplying out the matrices), this approximation leads to a trivial bound on performance (Appendix G.2.5). We speculate that in many cases there is no short human-interpretable description of why random noise or approximation errors do not compound across layers of neural networks (e.g., see the error correction results on randomly initialized neural networks from Hänni et al.
[21]), and thus that compounding structureless errors may be an issue in practice.

6 Related Work

Generalization Bounds. Prior work in the PAC-Bayes framework [58, 36, 12] proves generalization bounds over learning procedures, which are similar to the global performance bounds we consider in this work. These proofs tend to provide statistical guarantees [25, 26] about the outputs of a known stochastic training procedure, while we seek to bound the performance of particular trained models.

Formally verifying neural networks. Most prior work formally verifies neural networks either via model checking [28, 7] or by relaxing the problem setting and taking an automated theorem proving approach [17, 50, 18, 35, 43] to verify local robustness properties. These proof strategies tend to be derived by examining only the network architecture. We take an approach more akin to interactive theorem proving [22] and verify global performance properties by reverse-engineering the neural network weights.

Mechanistic Interpretability. Finally, mechanistic interpretability is a subfield of the broader field of understanding model internals [45], which is too large to faithfully summarize here. Our work takes its most direct inspiration from efforts to deeply understand how either toy models [38, 9, 53, 2] or small pretrained text transformers [54, 20] implement algorithmic tasks, generally by performing ablations and SVD. In contrast, we formally prove that a transformer implements an algorithm. Nichani et al. [39] prove that, in a significantly simplified 2-layer, 1-head attention-only transformer model and for the task of in-context bigram statistics, gradient descent will create induction heads [40]. Our results concern transformers with fixed weights. In concurrent work, Michaud et al. [34] use techniques inspired by mechanistic interpretability to perform automated program synthesis on 2-dimensional RNNs, while our work studies significantly larger transformer models.
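Returning briefly to Sections 5.2 and 5.3, the two quantities they rely on can be illustrated on synthetic low-rank-plus-noise matrices. This is a hedged sketch, not the authors' code: the two random factors `A` and `B` are hypothetical stand-ins (the paper composes three matrices, E, Q, and K, and no “size direction” is modeled here), and the helper names are our own. By the Eckart–Young theorem, the best rank-1 approximation of $M$ has spectral-norm error exactly $\sigma_2$, so its relative error is $\sigma_2/\sigma_1$, the reciprocal of the singular ratio on the x-axis of Figure 4. The second helper bounds the error of a product of per-factor rank-1 approximations using only norms of the factor errors, the kind of pessimal composition Section 5.3 describes; it is always valid, but it can never be tighter than the true error.

```python
import numpy as np


def best_rank1(M):
    """Best rank-1 approximation of M (Eckart-Young) and the singular values of M."""
    U, S, Vt = np.linalg.svd(M, full_matrices=False)
    return S[0] * np.outer(U[:, 0], Vt[0]), S


def rank1_faithfulness(M):
    """Return (sigma_1 / sigma_2, relative spectral-norm error of the rank-1 approximation)."""
    M1, S = best_rank1(M)
    rel_err = np.linalg.norm(M - M1, ord=2) / np.linalg.norm(M, ord=2)
    return S[0] / S[1], rel_err


def factorwise_error_bound(A, B, A1, B1):
    """Worst-case bound on ||A @ B - A1 @ B1||_2 from the factor errors alone:
    A B - A1 B1 = A1 dB + dA B1 + dA dB, followed by submultiplicativity."""
    def norm(X):
        return np.linalg.norm(X, ord=2)  # spectral norm (largest singular value)

    dA, dB = A - A1, B - B1
    return norm(A1) * norm(dB) + norm(dA) * norm(B1) + norm(dA) * norm(dB)


rng = np.random.default_rng(0)
v, d, noise = 64, 16, 0.05

# Hypothetical near-rank-1 factors: one shared direction plus small structureless noise.
u, w = rng.normal(size=v), rng.normal(size=d)
A = np.outer(u, w) + noise * rng.normal(size=(v, d))
B = np.outer(w, u) + noise * rng.normal(size=(d, v))
M = A @ B

ratio, rel_err = rank1_faithfulness(M)
print(f"sigma1/sigma2 = {ratio:.1f}, rank-1 relative error = {rel_err:.2e}")

A1, _ = best_rank1(A)
B1, _ = best_rank1(B)
true_err = np.linalg.norm(M - A1 @ B1, ord=2)
bound = factorwise_error_bound(A, B, A1, B1)
sigma2 = np.linalg.svd(M, compute_uv=False)[1]
print(f"best rank-1 error {sigma2:.3g} <= factor-wise error {true_err:.3g} <= bound {bound:.3g}")
```

The sketch only guarantees the ordering `sigma2 <= true_err <= bound`; how large the gap between `true_err` and `bound` becomes depends on the actual weights, which is what Section 5.3 reports for the trained models.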
7 Conclusion and Future Work

Summary. In this work, we used a Max-of-K setting to prototype the use of mechanistic interpretability to derive compact proofs of model behavior. Using varying amounts of understanding, we derived more efficient proof computations that lower-bound model accuracy. We found preliminary evidence that mechanistic understanding can make proofs more compact. Moreover, we observed that the tightness of the lower bound offered by various proof strategies can be used to grade the faithfulness of our mechanistic interpretation. Finally, we identified compounding structureless errors as a key obstacle to deriving compact proofs of model behavior.

Limitations and future work. We study one-layer attention-only transformers on a toy algorithmic task. Future work should explore the viability of deriving proofs via interpretability for larger models featuring MLPs or layernorm in more complex domains. In addition, we were unable to significantly compact the part of the proof involving the OV circuit, which future work can explore. The proofs we explored in this work also did not lead to qualitatively novel insights; future work may be able to derive such insights with improved techniques. Finally, future work can address the problem of compounding structureless errors, perhaps by relaxing from worst-case pessimal ablations to typical-case heuristic guarantees [8].

Acknowledgments and Disclosure of Funding

We are immensely grateful to Paul Christiano for providing the initial support for this project and for his invaluable research advice, encouragement, and feedback throughout its duration. Additionally, we are thankful for clarifying discussions and feedback from Jacob Hilton, Matthew Coudron, Adrià Garriga-Alonso, Aryan Bhatt, Leo Gao, Jenny Nitishinskaya, Somsubhro Bagchi, Gabriel Wu, Erik Jenner, Ryan Greenblatt, Ronak Mehta, Louis Jaburi, and many others. Louis Jaburi in particular contributed the text of the final proof of Theorem 11 in Appendix E.
We are indebted to various organizations for their support:
• Alignment Research Center, for funding this project and making it possible at all
• Mentorship for Alignment Research Students (MARS) program of the Cambridge AI Safety Hub (CAISH), for setting up the collaboration between a subset of authors, and providing funding for compute and in-person research sprints
• Constellation and FAR Labs, for hosting a subset of the authors and providing an excellent research environment, including as part of the Visiting Fellows Program and Astra Fellowship

Author Contributions

Jason Gross led the project, including managing the team and conceptualizing the proofs approach. He ran the Max-of-4 experiments, devised the proof strategies, and wrote up the formal proofs. He worked on various case studies and developed general methodology for computing complexity and length bounds for proofs. He also developed the particular convex relaxations presented in the paper.

Rajashree Agrawal was invaluable in steering the direction of the project, including contributing to the preliminary experiment on Max-of-2 and developing the pessimal ablation approach. She worked on framing the results, and contributed text to the paper.

Thomas Kwa and Euan Ong extended the preliminary experiments to larger values of k and contributed substantially to the cubic proof. Chun Hei Yip, Alex Gibson, and Soufiane Noubir worked on case studies other than the Max-of-K task and informed discussion on proof complexity.

Lawrence Chan spearheaded the writing of the paper, including turning informal claims into formal theorem statements, creating figures, and writing the core text. He also developed the unexplained dimensionality metric for clarifying the takeaway of the paper.

References

[1] Behzad Akbarpour, Amr Abdel-Hamid, Sofiène Tahar, and John Harrison. Verifying a synthesized implementation of IEEE-754 floating-point exponential function using HOL. The Computer Journal, 53:465–488, May 2010.
doi: 10.1093/comjnl/bxp023.
[2] Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm is in-context learning? Investigations with linear models, 2022.
[3] Andrew Appel and Ariel Kellison. VCFloat2: Floating-point error analysis in Coq. In Proceedings of the 13th ACM SIGPLAN International Conference on Certified Programs and Proofs, CPP 2024, pages 14–29, New York, NY, USA, 2024. Association for Computing Machinery. ISBN 9798400704888. doi: 10.1145/3636501.3636953.
[4] Sylvie Boldo and Guillaume Melquiond. Flocq: A unified library for proving floating-point algorithms in Coq. In 2011 IEEE 20th Symposium on Computer Arithmetic, pages 243–252, July 2011. doi: 10.1109/ARITH.2011.40.
[5] Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners, 2020. URL https://arxiv.org/abs/2005.14165.
[6] Lawrence Chan, Adrià Garriga-Alonso, Nicholas Goldowsky-Dill, Ryan Greenblatt, Jenny Nitishinskaya, Ansh Radhakrishnan, Buck Shlegeris, and Nate Thomas. Causal scrubbing, a method for rigorously testing interpretability hypotheses. AI Alignment Forum, 2022. URL https://w.alignmentforum.org/posts/JvZhhzycHu2Yd57RN/causal-scrubbing-a-method-for-rigorously-testing.
[7] Chih-Hong Cheng, Georg Nührenberg, and Harald Ruess. Maximum resilience of artificial neural networks. In Automated Technology for Verification and Analysis: 15th International Symposium, ATVA 2017, Pune, India, October 3–6, 2017, Proceedings 15, pages 251–268. Springer, 2017.
[8] Paul Christiano, Eric Neyman, and Mark Xu.
Formalizing the presumption of independence. arXiv preprint arXiv:2211.06738, 2022. doi: 10.48550/arxiv.2211.06738.
[9] Bilal Chughtai, Lawrence Chan, and Neel Nanda. A toy model of universality: Reverse engineering how networks learn group operations, 2023.
[10] Edmund M. Clarke, William Klieber, Miloš Nováček, and Paolo Zuliani. Model Checking and the State Explosion Problem, pages 1–30. Springer Berlin Heidelberg, Berlin, Heidelberg, 2012. ISBN 978-3-642-35746-6. doi: 10.1007/978-3-642-35746-6_1. URL https://doi.org/10.1007/978-3-642-35746-6_1.
[11] David Dalrymple, Joar Skalse, Yoshua Bengio, Stuart Russell, Max Tegmark, Sanjit Seshia, Steve Omohundro, Christian Szegedy, Ben Goldhaber, Nora Ammann, et al. Towards guaranteed safe AI: A framework for ensuring robust and reliable AI systems. arXiv preprint arXiv:2405.06624, 2024.
[12] Gintare Karolina Dziugaite and Daniel M. Roy. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. Proceedings of the Thirty-Third Conference on Uncertainty in Artificial Intelligence, UAI 2017, August 11–15, 2017, Sydney, NSW, Australia, 2017.
[13] Nelson Elhage, Neel Nanda, Catherine Olsson, Tom Henighan, Nicholas Joseph, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Nova DasSarma, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. A mathematical framework for transformer circuits. Transformer Circuits Thread, 2021. URL https://transformer-circuits.pub/2021/framework/index.html.
[14] Martín H. Escardó. Synthetic topology of data types and classical spaces. Electronic Notes in Theoretical Computer Science, 87:21–156, November 2004.
[15] Martín H. Escardó. Infinite sets that admit fast exhaustive search.
In Proceedings of the 22nd Annual IEEE Symposium on Logic in Computer Science (LICS 2007), Wrocław, Poland, July 2007.
[16] Martín H. Escardó. Seemingly impossible functional programs, 2007. URL https://math.andrej.com/2007/09/28/seemingly-impossible-functional-programs/. Accessed: 2024-05-15.
[17] T. Gehr, M. Mirman, D. Drachsler-Cohen, P. Tsankov, S. Chaudhuri, and M. Vechev. AI2: Safety and robustness certification of neural networks with abstract interpretation. In 2018 IEEE Symposium on Security and Privacy (SP), pages 3–18, Los Alamitos, CA, USA, May 2018. IEEE Computer Society. doi: 10.1109/SP.2018.00058.
[18] Sven Gowal, Krishnamurthy Dvijotham, Robert Stanforth, Rudy Bunel, Chongli Qin, Jonathan Uesato, Relja Arandjelovic, Timothy Mann, and Pushmeet Kohli. On the effectiveness of interval bound propagation for training verifiably robust models. arXiv preprint arXiv:1810.12715, 2018.
[19] Jason S. Gross. Performance Engineering of Proof-Based Software Systems at Scale. PhD thesis, Massachusetts Institute of Technology, February 2021. URL https://dspace.mit.edu/handle/1721.1/130763.
[20] Michael Hanna, Ollie Liu, and Alexandre Variengien. How does GPT-2 compute greater-than?: Interpreting mathematical abilities in a pre-trained language model, 2:11, 2023.
[21] Kaarel Hänni, Jake Mendel, Dmitry Vaintrob, and Lawrence Chan. Mathematical models of computation in superposition. In ICML 2024 Workshop on Mechanistic Interpretability, 2024. URL https://openreview.net/forum?id=OcVJP8kClR.
[22] John Harrison, Josef Urban, and Freek Wiedijk. History of interactive theorem proving. In Handbook of the History of Logic, volume 9, pages 135–214. Elsevier, 2014.
[23] Stefan Heimersheim and Neel Nanda. How to use and interpret activation patching. arXiv preprint arXiv:2404.15255, 2024.
[24] Jordan Hoffmann, Sebastian Borgeaud, Arthur Mensch, Elena Buchatskaya, Trevor Cai, Eliza Rutherford, Diego de Las Casas, Lisa Anne Hendricks, Johannes Welbl, Aidan Clark, et al.
Training compute-optimal large language models. arXiv preprint arXiv:2203.15556, 2022.
[25] Xiaowei Huang, Daniel Kroening, Wenjie Ruan, James Sharp, Youcheng Sun, Emese Thamo, Min Wu, and Xinping Yi. A survey of safety and trustworthiness of deep neural networks: Verification, testing, adversarial attack and defence, and interpretability. arXiv preprint arXiv:1812.08342, 2018.
[26] Xiaowei Huang, Wenjie Ruan, Wei Huang, Gaojie Jin, Yi Dong, Changshun Wu, Saddek Bensalem, Ronghui Mu, Yi Qi, Xingyu Zhao, Kaiwen Cai, Yanghao Zhang, Sihao Wu, Peipei Xu, Dengyu Wu, Andre Freitas, and Mustafa A. Mustafa. A survey of safety and trustworthiness of large language models through the lens of verification and validation, 2023.
[27] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B. Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models. arXiv preprint arXiv:2001.08361, 2020.
[28] Guy Katz, Clark Barrett, David L. Dill, Kyle Julian, and Mykel J. Kochenderfer. Reluplex: An efficient SMT solver for verifying deep neural networks. In Computer Aided Verification: 29th International Conference, CAV 2017, Heidelberg, Germany, July 24–28, 2017, Proceedings, Part I 30, pages 97–117. Springer, 2017.
[29] A. E. Kellison, A. W. Appel, M. Tekriwal, and D. Bindel. LAProof: A library of formal proofs of accuracy and correctness for linear algebra programs. In 2023 IEEE 30th Symposium on Computer Arithmetic (ARITH), pages 36–43, Los Alamitos, CA, USA, September 2023. IEEE Computer Society. doi: 10.1109/ARITH58626.2023.00021.
[30] Gerwin Klein, Kevin Elphinstone, Gernot Heiser, June Andronick, David Cock, Philip Derrin, Dhammika Elkaduwe, Kai Engelhardt, Rafal Kolanski, Michael Norrish, Thomas Sewell, Harvey Tuch, and Simon Winwood. seL4: Formal verification of an OS kernel. In Proceedings of the ACM SIGOPS 22nd Symposium on Operating Systems Principles, SOSP ’09, pages 207–220, New York, NY, USA, 2009.
Association for Computing Machinery. ISBN 9781605587523. doi: 10.1145/1629575.1629596. URL https://doi.org/10.1145/1629575.1629596.
[31] Xavier Leroy. A formally verified compiler back-end. Journal of Automated Reasoning, 43:363–446, 2009.
[32] Chuan Li. OpenAI’s GPT-3 language model: A technical overview, June 2020. URL https://lambdalabs.com/blog/demystifying-gpt-3. Lambda Labs Blog, accessed October 30, 2024.
[33] Wes McKinney. Data Structures for Statistical Computing in Python. In Stéfan van der Walt and Jarrod Millman, editors, Proceedings of the 9th Python in Science Conference, pages 56–61, 2010. doi: 10.25080/Majora-92bf1922-00a.
[34] Eric J. Michaud, Isaac Liao, Vedang Lad, Ziming Liu, Anish Mudide, Chloe Loughridge, Zifan Carl Guo, Tara Rezaei Kheirkhah, Mateja Vukelić, and Max Tegmark. Opening the AI black box: program synthesis via mechanistic interpretability, 2024.
[35] Matthew Mirman, Timon Gehr, and Martin Vechev. Differentiable abstract interpretation for provably robust neural networks. In International Conference on Machine Learning, pages 3578–3586. PMLR, 2018.
[36] Vaishnavh Nagarajan and J. Zico Kolter. Uniform convergence may be unable to explain generalization in deep learning. Advances in Neural Information Processing Systems, 32, 2019.
[37] Neel Nanda and Joseph Bloom. TransformerLens. https://github.com/TransformerLensOrg/TransformerLens, 2022.
[38] Neel Nanda, Lawrence Chan, Tom Lieberum, Jess Smith, and Jacob Steinhardt. Progress measures for grokking via mechanistic interpretability. arXiv preprint, 2023. doi: 10.48550/arXiv.2301.05217.
[39] Eshaan Nichani, Alex Damian, and Jason D. Lee. How transformers learn causal structure with gradient descent. arXiv preprint, 2024. doi: 10.48550/arXiv.2402.14735.
[40] Catherine Olsson, Nelson Elhage, Neel Nanda, Nicholas Joseph, Nova DasSarma, Tom Henighan, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Scott Johnston, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. In-context learning and induction heads. Transformer Circuits Thread, 2022. URL https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html.
[41] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, et al. PyTorch: An imperative style, high-performance deep learning library. Advances in Neural Information Processing Systems, 32, 2019.
[42] Plotly Technologies Inc. Collaborative data science, 2015. URL https://plot.ly.
[43] Aditi Raghunathan, Jacob Steinhardt, and Percy S. Liang. Semidefinite relaxations for certifying robustness to adversarial examples. Advances in Neural Information Processing Systems, 31, 2018.
[44] Tahina Ramananandro, Paul Mountcastle, Benoît Meister, and Richard Lethin. A unified Coq framework for verifying C programs with floating-point computations. In Proceedings of the 5th ACM SIGPLAN Conference on Certified Programs and Proofs, CPP 2016, pages 15–26, New York, NY, USA, 2016. Association for Computing Machinery. ISBN 9781450341271. doi: 10.1145/2854065.2854066.
[45] Tilman Räuker, Anson Ho, Stephen Casper, and Dylan Hadfield-Menell. Toward transparent AI: A survey on interpreting the inner structures of deep neural networks. In First IEEE Conference on Secure and Trustworthy Machine Learning, 2022. doi: 10.48550/arxiv.2207.13243.
[46] Scott Reed, Konrad Zolna, Emilio Parisotto, Sergio Gomez Colmenarejo, Alexander Novikov, Gabriel Barth-Maron, Mai Gimenez, Yury Sulsky, Jackie Kay, Jost Tobias Springenberg, Tom Eccles, Jake Bruce, Ali Razavi, Ashley Edwards, Nicolas Heess, Yutian Chen, Raia Hadsell, Oriol Vinyals, Mahyar Bordbar, and Nando de Freitas. A generalist agent, 2022. URL https://arxiv.org/abs/2205.06175.
[47] Alex Rogozhnikov. Einops: Clear and reliable tensor manipulations with Einstein-like notation. In International Conference on Learning Representations, 2022. URL https://openreview.net/forum?id=oapKSVM2bcj.
[48] Sanjit A. Seshia, Dorsa Sadigh, and S. Shankar Sastry. Toward verified artificial intelligence: Making AI more trustworthy with a formal methods-based approach to AI system verification and validation.
[49] Alex K. Simpson. Lazy functional algorithms for exact real functionals. In Luboš Brim, Jozef Gruska, and Jiří Zlatuška, editors, Mathematical Foundations of Computer Science 1998, pages 456–464, Berlin, Heidelberg, 1998. Springer Berlin Heidelberg. ISBN 978-3-540-68532-6.
[50] Gagandeep Singh, Timon Gehr, Markus Püschel, and Martin Vechev. An abstract domain for certifying neural networks. Proc. ACM Program. Lang., 3(POPL), January 2019. doi: 10.1145/3290354.
[51] Christian Szegedy, Wojciech Zaremba, Ilya Sutskever, Joan Bruna, Dumitru Erhan, Ian Goodfellow, and Rob Fergus. Intriguing properties of neural networks. arXiv preprint arXiv:1312.6199, 2013.
[52] Pauli Virtanen, Ralf Gommers, Travis E. Oliphant, Matt Haberland, Tyler Reddy, David Cournapeau, Evgeni Burovski, Pearu Peterson, Warren Weckesser, Jonathan Bright, Stéfan J. van der Walt, Matthew Brett, Joshua Wilson, K. Jarrod Millman, Nikolay Mayorov, Andrew R. J. Nelson, Eric Jones, Robert Kern, Eric Larson, C. J. Carey, İlhan Polat, Yu Feng, Eric W. Moore, Jake VanderPlas, Denis Laxalde, Josef Perktold, Robert Cimrman, Ian Henriksen, E. A. Quintero, Charles R. Harris, Anne M. Archibald, Antônio H.
Ribeiro, Fabian Pedregosa, Paul van Mulbregt, and SciPy 1.0 Contributors. SciPy 1.0: Fundamental algorithms for scientific computing in Python. Nature Methods, 17:261–272, 2020. doi: 10.1038/s41592-019-0686-2.
[53] Johannes Von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, and Max Vladymyrov. Transformers learn in-context by gradient descent. In International Conference on Machine Learning, pages 35151–35174. PMLR, 2023.
[54] Kevin Wang, Alexandre Variengien, Arthur Conmy, Buck Shlegeris, and Jacob Steinhardt. Interpretability in the wild: a circuit for indirect object identification in GPT-2 small. arXiv preprint, 2022. doi: 10.48550/arXiv.2211.00593.
[55] Michael L. Waskom. seaborn: statistical data visualization. Journal of Open Source Software, 6(60):3021, 2021. doi: 10.21105/joss.03021.
[56] Jason Wei, Yi Tay, Rishi Bommasani, Colin Raffel, Barret Zoph, Sebastian Borgeaud, Dani Yogatama, Maarten Bosma, Denny Zhou, Donald Metzler, et al. Emergent abilities of large language models. arXiv preprint arXiv:2206.07682, 2022.
[57] Eric Wong and Zico Kolter. Provable defenses against adversarial examples via the convex outer adversarial polytope. In International Conference on Machine Learning, pages 5286–5295. PMLR, 2018.
[58] Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning (still) requires rethinking generalization. Communications of the ACM, 64(3):107–115, 2021.

A Subtleties of our approach

In this section, we address some subtleties of, and frequently asked questions about, our approach.

A.1 Why study this simple task?

Formal reasoning is computationally expensive; very few large software projects have ever been verified [31, 30], and none of them are comparable to large transformer models [10, 19]. Separately, there is a high fixed cost to taking on any verification project, regardless of the computational efficiency of the verification itself.
Thus, we picked the simplest setting in which to study the question of interest: is it even possible to reason formally about model behavior more efficiently than by brute force?

A.2 Scalability

In this section, we address concerns about the scalability of our approach.

A.2.1 Larger input spaces

We demonstrate that our proof strategies can be reused on larger input spaces while scaling better than the brute-force approach does. We applied our proof strategies to models trained for Max-of-5, Max-of-10, and Max-of-20. While running the brute-force proof on Max-of-20 would require approximately $2^{148}$ FLOPs, which is about $2^{70}\times$ the cost of training GPT-3 [32], our cubic proof achieves bounds of (94.1 ± 1.1)% (Max-of-5), (91.4 ± 2.1)% (Max-of-10), and (88.4 ± 4.0)% (Max-of-20) in under two minutes. See Tables 2, 3, 4, and 5 for more detailed numbers, and Figures 6 and 7 for visualizations.

A.2.2 Different tasks

In this paper, we worked on highly optimizing our relaxation to make our bounds as tight as possible while incorporating as little understanding as possible. This is not necessary for deriving proofs. Our general formalization of mechanistic interpretability is replicable: (1) theorem statements are exact expressions for the difference between the actual behavior of the model and the purported behavior, and (2) proofs are computations that bound the expression. Furthermore, our convexity theorems and proofs apply much more generally to element retrieval tasks.

A.2.3 More complicated architectures

We worked on a simple model studied in A Mathematical Framework for Transformer Circuits [13]. In follow-up work, we will extend this approach to proving bounds on 1L transformers with ReLU MLPs trained on modular addition.
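The cost comparison in A.2.1 can be checked with a line of arithmetic in log space. This sketch assumes the widely cited estimate of about $3.14 \times 10^{23}$ FLOPs for training GPT-3 (175B), which we take to be the figure behind the comparison with [32]; the exponent 148 is the rounded brute-force cost for Max-of-20 quoted in A.2.1.

```python
import math

# Assumption: commonly cited total training compute for GPT-3 (175B).
GPT3_TRAINING_FLOPS = 3.14e23
# From A.2.1: brute-force proof cost for Max-of-20, rounded to a power of 2.
LOG2_BRUTE_FORCE_MAX_OF_20 = 148

# Work in log2 so the comparison stays exact-looking and avoids huge floats.
log2_gpt3 = math.log2(GPT3_TRAINING_FLOPS)
log2_ratio = LOG2_BRUTE_FORCE_MAX_OF_20 - log2_gpt3

print(f"GPT-3 training: about 2^{log2_gpt3:.1f} FLOPs")   # about 2^78.1
print(f"brute force / GPT-3: about 2^{log2_ratio:.1f}")   # about 2^69.9
```

Rounding the last exponent gives the $2^{70}\times$ factor stated in A.2.1.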
A.2.4 Larger models

It is an open question whether or not the mechanistic interpretability approach to proofs can scale to larger models. However, a large part of this question lies in the feasibility of deriving a high degree of faithful mechanistic understanding from large models, that is, whether mechanistic interpretability itself will scale. This is widely recognized in the field, and scaling interpretability approaches while getting both a high degree of mechanistic understanding and assurances that said understanding is faithful to the model is an active area of research. Broadly, we see the compact proofs approach as a metric on the quality of mechanistic understanding: we are not purporting to have a general solution to the problem of scaling interpretability, but instead claim that the challenges in proofs are in fact challenges in understanding networks.

Table 2: Version of Table 1 from Section 4.1 with $n_{\text{ctx}} = 5$, $d_{\text{vocab}} = 64$. We report the proof complexity, accuracy bound, and estimated FLOPs required (Equation 2), as well as unexplained dimensionality (Section 5). Unlike Table 1, which computes the brute-force bound exactly, we instead use importance sampling to estimate the bound; estimated FLOPs are reported for what the full brute-force proof would take. We round the FLOP and unexplained dimension counts to the closest power of 2, and report the mean/standard deviation of the bound averaged across all 151 models. For space reasons, we use $k := n_{\text{ctx}}$, $d := d_{\text{model}}$, and $v := d_{\text{vocab}}$.

| Description of Proof | Complexity Cost | Bound | Est. FLOPs | Unexplained Dimensions |
|---|---|---|---|---|
| Brute force | $O(v^{k+1}kd)$ | 0.9990 ± 0.0018 | $2^{54}$ | $2^{36}$ |
| Cubic | $O(v^3k^2)$ | 0.941 ± 0.011 | $2^{26}$ | $2^{14}$ |
| Sub-cubic | $O(v^2 \cdot k^2 + v^2 \cdot d)$ | 0.705 ± 0.031 | $2^{22}$ | $2^{13}$ |
| w/o mean+diff | | 0.405 ± 0.073 | $2^{22}$ | $2^{13}$ |
| Low-rank QK | $O(v^2k^2 + \underbrace{vd^2}_{\text{QK}} + \underbrace{v^2d}_{\text{EU\&OV}})$ | 0.682 ± 0.033 | $2^{22}$ | $2^{12}$ |
| SVD only | | 0.335 ± 0.066 | $2^{22}$ | $2^{12}$ |
| Low-rank EU | $O(v^2k^2 + \underbrace{vd}_{\text{EU}} + \underbrace{v^2d}_{\text{QK\&OV}})$ | 0.649 ± 0.055 | $2^{21}$ | $2^{13}$ |
| SVD only | | (4.8 ± 0.1) × 10⁻⁸ | $2^{21}$ | $2^{13}$ |
| Low-rank QK&EU | $O(v^2k^2 + \underbrace{vd^2}_{\text{QK}} + \underbrace{vd}_{\text{EU}} + \underbrace{v^2d}_{\text{OV}})$ | 0.628 ± 0.053 | $2^{22}$ | $2^{13}$ |
| SVD only | | (4.8 ± 0.1) × 10⁻⁸ | $2^{22}$ | $2^{13}$ |
| Quadratic QK | $O(v^2k^2 + \underbrace{vd}_{\text{QK}} + \underbrace{v^2d}_{\text{EU\&OV}})$ | 0.354 ± 0.034 | $2^{21}$ | $2^{12}$ |
| Quadratic QK&EU | $O(v^2k^2 + \underbrace{vd}_{\text{QK\&EU}} + \underbrace{v^2d}_{\text{OV}})$ | 0.335 ± 0.033 | $2^{21}$ | $2^{13}$ |

Table 3: Version of Table 1 from Section 4.1 with $n_{\text{ctx}} = 10$, $d_{\text{vocab}} = 64$. We report the proof complexity, accuracy bound, and estimated FLOPs required (Equation 2), as well as unexplained dimensionality (Section 5). Unlike Table 1, which computes the brute-force bound exactly, we instead use importance sampling to estimate the bound; estimated FLOPs are reported for what the full brute-force proof would take. We round the FLOP and unexplained dimension counts to the closest power of 2, and report the mean/standard deviation of the bound averaged across all 151 models. For space reasons, we use $k := n_{\text{ctx}}$, $d := d_{\text{model}}$, and $v := d_{\text{vocab}}$.

| Description of Proof | Complexity Cost | Bound | Est. FLOPs | Unexplained Dimensions |
|---|---|---|---|---|
| Brute force | $O(v^{k+1}kd)$ | 0.9988 ± 0.0013 | $2^{86}$ | $2^{66}$ |
| Cubic | $O(v^3k^2)$ | 0.914 ± 0.021 | $2^{28}$ | $2^{14}$ |
| Sub-cubic | $O(v^2 \cdot k^2 + v^2 \cdot d)$ | 0.674 ± 0.028 | $2^{23}$ | $2^{13}$ |
| w/o mean+diff | | 0.539 ± 0.061 | $2^{23}$ | $2^{13}$ |
| Low-rank QK | $O(v^2k^2 + \underbrace{vd^2}_{\text{QK}} + \underbrace{v^2d}_{\text{EU\&OV}})$ | 0.657 ± 0.028 | $2^{23}$ | $2^{12}$ |
| SVD only | | 0.469 ± 0.059 | $2^{23}$ | $2^{12}$ |
| Low-rank EU | $O(v^2k^2 + \underbrace{vd}_{\text{EU}} + \underbrace{v^2d}_{\text{QK\&OV}})$ | 0.639 ± 0.032 | $2^{23}$ | $2^{13}$ |
| SVD only | | (0 ± 100) × 10⁻¹² | $2^{22}$ | $2^{13}$ |
| Low-rank QK&EU | $O(v^2k^2 + \underbrace{vd^2}_{\text{QK}} + \underbrace{vd}_{\text{EU}} + \underbrace{v^2d}_{\text{OV}})$ | 0.625 ± 0.031 | $2^{23}$ | $2^{13}$ |
| SVD only | | (2.9 ± 0.1) × 10⁻¹⁷ | $2^{23}$ | $2^{13}$ |
| Quadratic QK | $O(v^2k^2 + \underbrace{vd}_{\text{QK}} + \underbrace{v^2d}_{\text{EU\&OV}})$ | 0.392 ± 0.030 | $2^{22}$ | $2^{12}$ |
| Quadratic QK&EU | $O(v^2k^2 + \underbrace{vd}_{\text{QK\&EU}} + \underbrace{v^2d}_{\text{OV}})$ | 0.390 ± 0.028 | $2^{22}$ | $2^{13}$ |

Table 4: Version of Table 1 from Section 4.1 with $n_{\text{ctx}} = 10$ and $d_{\text{vocab}} = 128$. We report the proof complexity, accuracy bound, and estimated FLOPs required (Equation 2), as well as unexplained dimensionality (Section 5). Unlike Table 1, which computes the brute-force bound exactly, we instead use importance sampling to estimate the bound; estimated FLOPs are reported for what the full brute-force proof would take. We round the FLOP and unexplained dimension counts to the closest power of 2, and report the mean/standard deviation of the bound averaged across all 151 models. For space reasons, we use $k := n_{\text{ctx}}$, $d := d_{\text{model}}$, and $v := d_{\text{vocab}}$.

| Description of Proof | Complexity Cost | Bound | Est. FLOPs | Unexplained Dimensions |
|---|---|---|---|---|
| Brute force | $O(v^{k+1}kd)$ | 0.9972 ± 0.0031 | $2^{96}$ | $2^{77}$ |
| Cubic | $O(v^3k^2)$ | 0.882 ± 0.012 | $2^{31}$ | $2^{16}$ |
| Sub-cubic | $O(v^2 \cdot k^2 + v^2 \cdot d)$ | 0.622 ± 0.031 | $2^{24}$ | $2^{15}$ |
| w/o mean+diff | | 0.390 ± 0.070 | $2^{24}$ | $2^{15}$ |
| Low-rank QK | $O(v^2k^2 + \underbrace{vd^2}_{\text{QK}} + \underbrace{v^2d}_{\text{EU\&OV}})$ | 0.594 ± 0.035 | $2^{24}$ | $2^{14}$ |
| SVD only | | 0.320 ± 0.053 | $2^{25}$ | $2^{14}$ |
| Low-rank EU | $O(v^2k^2 + \underbrace{vd}_{\text{EU}} + \underbrace{v^2d}_{\text{QK\&OV}})$ | 0.607 ± 0.031 | $2^{24}$ | $2^{15}$ |
| SVD only | | (5.4 ± 0.2) × 10⁻²⁰ | $2^{24}$ | $2^{15}$ |
| Low-rank QK&EU | $O(v^2k^2 + \underbrace{vd^2}_{\text{QK}} + \underbrace{vd}_{\text{EU}} + \underbrace{v^2d}_{\text{OV}})$ | 0.595 ± 0.030 | $2^{24}$ | $2^{14}$ |
| SVD only | | (5.4 ± 0.2) × 10⁻²⁰ | $2^{25}$ | $2^{14}$ |
| Quadratic QK | $O(v^2k^2 + \underbrace{vd}_{\text{QK}} + \underbrace{v^2d}_{\text{EU\&OV}})$ | 0.350 ± 0.029 | $2^{24}$ | $2^{14}$ |
| Quadratic QK&EU | $O(v^2k^2 + \underbrace{vd}_{\text{QK\&EU}} + \underbrace{v^2d}_{\text{OV}})$ | 0.384 ± 0.025 | $2^{24}$ | $2^{14}$ |

Table 5: Version of Table 1 from Section 4.1 with $n_{\text{ctx}} = 20$, $d_{\text{vocab}} = 64$. We report the proof complexity, accuracy bound, and estimated FLOPs required (Equation 2), as well as unexplained dimensionality (Section 5). Unlike Table 1, which computes the brute-force bound exactly, we instead use importance sampling to estimate the bound; estimated FLOPs are reported for what the full brute-force proof would take. We round the FLOP and unexplained dimension counts to the closest power of 2, and report the mean/standard deviation of the bound averaged across all 151 models. For space reasons, we use $k := n_{\text{ctx}}$, $d := d_{\text{model}}$, and $v := d_{\text{vocab}}$.

| Description of Proof | Complexity Cost | Bound | Est. FLOPs | Unexplained Dimensions |
|---|---|---|---|---|
| Brute force | $O(v^{k+1}kd)$ | 0.995 ± 0.015 | $2^{148}$ | $2^{126}$ |
| Cubic | $O(v^3k^2)$ | 0.884 ± 0.040 | $2^{29}$ | $2^{14}$ |
| Sub-cubic | $O(v^2 \cdot k^2 + v^2 \cdot d)$ | 0.561 ± 0.043 | $2^{24}$ | $2^{13}$ |
| w/o mean+diff | | 0.486 ± 0.060 | $2^{24}$ | $2^{13}$ |
| Low-rank QK | $O(v^2k^2 + \underbrace{vd^2}_{\text{QK}} + \underbrace{v^2d}_{\text{EU\&OV}})$ | 0.547 ± 0.043 | $2^{24}$ | $2^{12}$ |
| SVD only | | 0.431 ± 0.060 | $2^{24}$ | $2^{12}$ |
| Low-rank EU | $O(v^2k^2 + \underbrace{vd}_{\text{EU}} + \underbrace{v^2d}_{\text{QK\&OV}})$ | 0.538 ± 0.043 | $2^{24}$ | $2^{13}$ |
| SVD only | | (1.0 ± 6.0) × 10⁻⁴ | $2^{24}$ | $2^{13}$ |
| Low-rank QK&EU | $O(v^2k^2 + \underbrace{vd^2}_{\text{QK}} + \underbrace{vd}_{\text{EU}} + \underbrace{v^2d}_{\text{OV}})$ | 0.526 ± 0.041 | $2^{24}$ | $2^{13}$ |
| SVD only | | (1.0 ± 5.0) × 10⁻⁴ | $2^{24}$ | $2^{13}$ |
| Quadratic QK | $O(v^2k^2 + \underbrace{vd}_{\text{QK}} + \underbrace{v^2d}_{\text{EU\&OV}})$ | 0.322 ± 0.035 | $2^{24}$ | $2^{12}$ |
| Quadratic QK&EU | $O(v^2k^2 + \underbrace{vd}_{\text{QK\&EU}} + \underbrace{v^2d}_{\text{OV}})$ | 0.321 ± 0.035 | $2^{24}$ | $2^{13}$ |

[Plots: normalized accuracy bound versus approximate FLOPs to verify the proof.
(a) $n_{\text{ctx}} = 5$, $d_{\text{vocab}} = 64$. Legend: brute force (estimated) (acc: 0.9990 ± 0.0018); cubic (rel acc: 0.942 ± 0.010); subcubic (rel acc: 0.702 ± 0.033); direct-quadratic (rel acc: 0.649 ± 0.055); attention-$d_{\text{vocab}} d_{\text{model}}^2$ (rel acc: 0.682 ± 0.033); attention-$d_{\text{vocab}} d_{\text{model}}^2$, direct-quadratic (rel acc: 0.629 ± 0.053); attention-quadratic (rel acc: 0.354 ± 0.034); attention-quadratic, direct-quadratic (rel acc: 0.336 ± 0.033).
(b) $n_{\text{ctx}} = 10$, $d_{\text{vocab}} = 64$. Legend: brute force (estimated) (acc: 0.9988 ± 0.0013); cubic (rel acc: 0.915 ± 0.021); subcubic (rel acc: 0.673 ± 0.030); direct-quadratic (rel acc: 0.640 ± 0.032); attention-$d_{\text{vocab}} d_{\text{model}}^2$ (rel acc: 0.658 ± 0.028); attention-$d_{\text{vocab}} d_{\text{model}}^2$, direct-quadratic (rel acc: 0.626 ± 0.031); attention-quadratic (rel acc: 0.393 ± 0.030).
Legend: brute force (estimated) (acc: 0.9972 ± 0.0031); cubic (rel acc: 0.885 ± 0.012); direct-quadratic (rel acc: 0.609 ± 0.030); subcubic (rel acc: 0.613 ± 0.037); attention-$d_{\text{vocab}} d_{\text{model}}^2$, direct-quadratic (rel acc: 0.596 ± 0.030); attention-$d_{\text{vocab}} d_{\text{model}}^2$ (rel acc: 0.595 ± 0.035)]
attention-quadratic, direct-quadratic (rel acc: 0.385 ± 0.025) (c)n ctx = 10andd vocab = 128 2 63 2 109 0.0 0.2 0.4 0.6 0.8 1.0 FLOPs to Verify Proof (approximate) Normalized Accuracy Bound brute force (estimated) (acc: 0.995 ± 0.015) cubic (rel acc: 0.888 ± 0.036) subcubic (rel acc: 0.563 ± 0.041) attention-d vocab d model 2 (rel acc: 0.549 ± 0.041) direct-quadratic (rel acc: 0.540 ± 0.040) attention-d vocab d model 2 , direct-quadratic (rel acc: 0.528 ± 0.039) attention-quadratic, direct-quadratic (rel acc: 0.322 ± 0.034) attention-quadratic (rel acc: 0.323 ± 0.035) (d)n ctx = 20,d vocab = 64 Figure 6:Version of Figure 3 from page 8 with varyingn ctx andd vocab . The brute-force proof (Section 4.1) computes the exact performance uses orders of magnitude more compute than other approaches; unlike in Figure 3, here we use importance sampling to estimate the bound. 19 40060080010001200 0.0 0.2 0.4 0.6 0.8 1.0 EPQKE Singular Ratio:σ 1 /σ 2 Normalized Accuracy Bound max-diff-exact max-diff-subproduct mean+max-diff-subproduct max-diff mean+max-diff svd mean+max-diff-subproduct-recursive max-diff-subproduct-recursive mean-recursive+max-diff-subproduct-recursive (a)n ctx = 5,d vocab = 64. The “svd” proof strategy best-fit line has equationb/ ̄s= 0.000 15(σ 1 /σ 2 ) + 0.48, R 2 = 0.37. 40060080010001200140016001800 0.0 0.2 0.4 0.6 0.8 1.0 EPQKE Singular Ratio:σ 1 /σ 2 Normalized Accuracy Bound max-diff-exact max-diff-subproduct mean+max-diff-subproduct max-diff mean+max-diff svd max-diff-subproduct-recursive mean+max-diff-subproduct-recursive mean-recursive+max-diff-subproduct-recursive (b)n ctx = 10,d vocab = 64. The “svd” proof strategy best-fit line has equationb/ ̄s= 0.000 074(σ 1 /σ 2 )+0.51, R 2 = 0.23. 
800100012001400160018002000 0.0 0.2 0.4 0.6 0.8 1.0 EPQKE Singular Ratio:σ 1 /σ 2 Normalized Accuracy Bound max-diff-exact max-diff-subproduct mean+max-diff-subproduct max-diff mean+max-diff svd max-diff-subproduct-recursive mean-recursive+max-diff-subproduct-recursive mean+max-diff-subproduct-recursive (c)n ctx = 10andd vocab = 128. The “svd” proof strategy best-fit line has equationb/ ̄s= 0.000 085(σ 1 /σ 2 ) + 0.40,R 2 = 0.42. 400600800100012001400 0.0 0.2 0.4 0.6 0.8 1.0 EPQKE Singular Ratio:σ 1 /σ 2 Normalized Accuracy Bound max-diff-exact max-diff-subproduct mean+max-diff-subproduct mean+max-diff max-diff svd mean-recursive+max-diff-subproduct-recursive mean+max-diff-subproduct-recursive max-diff-subproduct-recursive (d)n ctx = 20,d vocab = 64. The “svd” proof strategy best-fit line has equationb/ ̄s= 0.000 098(σ 1 /σ 2 )+0.41, R 2 = 0.22. Figure 7:Version of Figure 4 from page 9 with varyingn ctx andd vocab . Note that the “svd” proof strategy has a clear upward trend, especially on early points. 20 A.3 Why is more mechanistic understanding correlated with worse bounds? Figure 3 exhibits Simpson’s Paradox: although more faithful mechanistic understanding is correlated with better bounds within each class of proof (and moreover the most extensive mechanistic under- standing results in the greatest improvement in bound tightness over baseline), when we aggregate across all proof strategies, we find that more mechanistic understanding is correlated with worse bounds. This relationship is summarized in Figure 8. From the compression perspective, more mechanistic understanding is about having more compres- sion. Unless the model is losselessly compressible, we should expect that more compression will inherently be more lossy, no matter how good our compression scheme is. Correspondingly, using more understanding to get more compression will often result in a weaker bound, no matter how good our understanding is. 
Conversely, we can think of the quality of proofs (the combination of tightness of bound and length of proof) as a metric for how good our mechanistic understanding is. From this perspective, the fact that mechanistic-interpretability-derived bounds are bad suggests gaps in our mechanistic understanding. As the field matures and we develop tools that enable more faithful and complete understanding of model behavior, we expect that the quality of bounds we derive from mechanistic understanding will improve.

A.4 Convex relaxation

In this work, we construct convex relaxations to perform the pessimal ablations for our proofs. In what sense are we using “convexity”? The intuition is that we are attempting to optimize a function $f$ over its domain $X$ by incrementally making local changes to the input sequence, such as replacing one token with another or changing the order of tokens. Convex optimization problems are easy to solve because all local extrema are global extrema. This is not the case for our optimization problem, so we find a relaxation of $f$ and its domain such that all local extrema are in fact global extrema.

Furthermore, most convex optimizers perform optimization at runtime by repeatedly stepping towards extrema. In this work, we “optimize by hand”, performing the optimization in the proofs of our general theorems. The computation of the bound then only needs to instantiate the precomputed possible extrema with the actual values of the model’s parameters to determine what the extrema actually are.

We now give a formal description of what we mean by “convex relaxation”. For a set of inputs $X_i$, we define a set of “relaxed inputs” $X_i^{\text{relaxed}}$ with an injection $T_i : X_i \hookrightarrow \mathcal{P}(X_i^{\text{relaxed}})$ mapping each input to the model to the set of corresponding relaxed inputs.
On the relaxed inputs, we define a function $h_i : X_i^{\text{relaxed}} \to \mathbb{R}$ such that for all $t \in X_i$ and all labels $l$ for which $(l, t)$ is supported by (has non-zero probability in) $D$, we can find $t^{\text{relaxed}} \in T_i(t)$ with $f(l, M(t)) \ge h_i(t^{\text{relaxed}})$. We proceed by finding a small subset of “boundary” examples $B_i \subset X_i^{\text{relaxed}}$ and proving that if $h_i(t^{\text{relaxed}}) \ge b_i$ for all $t^{\text{relaxed}} \in B_i$, then $h_i(t^{\text{relaxed}}) \ge b_i$ for all $t^{\text{relaxed}} \in X_i^{\text{relaxed}}$. The computational component $C$ of the proof then evaluates $h_i$ on $B_i$, which, by this result, validates that $h_i(t^{\text{relaxed}}) \ge b_i$ for some $b_i$ and all $t^{\text{relaxed}} \in X_i^{\text{relaxed}}$. This allows us to conclude that $f(l, M(t)) \ge b_i$ for all $t \in X_i$.

A.5 Computing unexplained dimensionality

We claim in Figure 5 that we can use unexplained dimensionality as a metric for understanding. Here we describe how we compute the unexplained dimensionality of a proof strategy. As in Figure 1, for any given proof, we can separate our treatment of transformer components into “black-box” components (e.g., matrix multiplication) and “white-box” components (e.g., specifying that the QK circuit is approximately rank one; pessimizing over non-max tokens). Treating the performance score as a large white-box component which may reference black boxes internally, we define the unexplained dimensionality of a single black-box computation as the log-cardinality of its function space (so, e.g., $2 \cdot 64$ for a function $\mathbf{64} \to \mathbb{R}^2$, whose cardinality is $(\mathbb{R}^2)^{64}$, where $\mathbf{64}$ denotes the finite set on 64 elements). The unexplained dimensionality of the entire proof is the sum of the unexplained dimensions of all black-box components. Intuitively speaking, unexplained dimensionality tries to capture the degrees of freedom that we have to check via brute enumeration over black-box computations. Proofs with less unexplained dimensionality contain more mechanistic understanding, and vice versa.

[Figure 8: two schematic panels plotting Performance Lower Bound against FLOPs to Verify Proof. (a) The baseline of using inference to generate proofs: the brute-force proof and the trivial proof are marked relative to the true performance, illustrating the impractical baseline of guaranteeing performance via inference. (b) Shorter proofs by default have a worse performance bound $b$: decreased length of proof leads to looser bounds, and faithful understanding allows us to recover significant, but not complete, bound tightness with minimal proof-length overhead.]

Figure 8: The theoretical relationship between proof length and bound tightness.

A.6 Computing approximate FLOPs

In Table 1 (page 6) and Figure 3 (page 8), we display approximate floating point operations. We instrument our code to execute on phantom tensors that track their shape and accumulate an approximate count of floating point operations. We count matrix additions and multiplications in the obvious way. We take the FLOP count of an SVD to be the cost of verifying that its output is a valid decomposition: that we have a pair of orthonormal bases which, when multiplied out together with the singular values, give the original matrix.

A.7 IEEE 754 vs. $\mathbb{R}$

In Section 2 we defined $C$ and $Q$ and glossed over whether we were reasoning over reals or floats. Here we clarify this point, which we have so far been sweeping under the rug. Let $\mathbb{F}$ denote the set of the relevant flavor of IEEE 754 floating-point numbers (generally 32-bit for our concrete models, but everything would hold just as well for 64-bit). Let $\mathbb{F}^*$ denote $\mathbb{F}$ restricted to finite numbers (that is, without NaNs and without $\pm\infty$). We parameterize $C$, $M$, and $D$ over the real field[^13] they operate on, so that, e.g., $C_{\mathbb{F}} : \text{model weights} \to \mathbb{F}$. Then we have $Q$ establishing that for any model $M'$, $C_{\mathbb{R}}(M'_{\mathbb{R}}) \le \mathbb{E}_{(l,t)\sim D_{\mathbb{R}}} f_{\mathbb{R}}(l, M'_{\mathbb{R}}(t))$, and we have a trace demonstrating that $C_{\mathbb{F}}(M_{\mathbb{F}}) = b$.

[^13]: Technically the floating point numbers are not a field. We gloss over this point here, since they define all of the field operations, even if those operations do not satisfy the field axioms.
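To make the gap between $C_{\mathbb{F}}$ and $C_{\mathbb{R}}$ concrete, here is a minimal Python sketch (an illustration only, not part of the paper's artifacts) that sums the same decimal inputs twice: once in emulated binary32 arithmetic and once exactly with `fractions.Fraction`, which stands in for $\mathbb{R}$. The helper names are hypothetical. A gap measured on one example is not a valid $\varepsilon$ for a proof; a proof needs a bound that holds for every input.

```python
import struct
from fractions import Fraction


def to_f32(x: float) -> float:
    """Round a binary64 Python float to the nearest IEEE 754 binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def f32_sum(values):
    """Left-to-right sum, rounding to binary32 after every addition.

    Binary64 has more than twice the precision of binary32, so adding two
    binary32 values in binary64 and rounding again gives the correctly rounded
    binary32 sum; this emulates a float32 accumulator.
    """
    acc = 0.0
    for v in values:
        acc = to_f32(acc + to_f32(v))
    return acc


def real_vs_float_gap(decimal_strings):
    """|exact rational sum - float32 sum| for inputs given as decimal strings."""
    exact = sum((Fraction(s) for s in decimal_strings), Fraction(0))
    approx = f32_sum(float(s) for s in decimal_strings)
    # Fraction(approx) is the exact real value of the float result, playing the role of i(.)
    return abs(Fraction(approx) - exact)


gap = real_vs_float_gap(["0.1"] * 10)
print(float(gap))
```

Inputs that are exactly representable and whose partial sums stay representable (for example `["0.5", "0.25"]`) give a gap of zero, while ten copies of `"0.1"` give a small but nonzero gap.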
Let $i : \mathbb{F}^* \to \mathbb{R}$ be any injection that maps each floating point number to some real number that it is “closest to”. Supposing that $b \in \mathbb{F}^*$ and thus $i(b) \in \mathbb{R}$, we need two additional components of the proof. We need to find $\varepsilon, \varepsilon' \in \mathbb{R}^+$ and prove that
$$\left\lvert C_{\mathbb{R}}(M_{\mathbb{R}}) - i\big(C_{\mathbb{F}}(M_{\mathbb{F}})\big)\right\rvert < \varepsilon \quad\text{and}\quad \left\lvert \mathbb{E}_{(l,t)\sim D_{\mathbb{R}}} f_{\mathbb{R}}(l, M_{\mathbb{R}}(t)) - i\Big(\mathbb{E}_{(l,t)\sim D_{\mathbb{F}}} f_{\mathbb{F}}(l, M_{\mathbb{F}}(t))\Big)\right\rvert < \varepsilon'.$$
Then we can chain these proofs to prove that
$$i\Big(\mathbb{E}_{(l,t)\sim D_{\mathbb{F}}} f_{\mathbb{F}}(l, M_{\mathbb{F}}(t))\Big) \ge i(b) - \varepsilon - \varepsilon'.$$
Such $\varepsilon$-ball robustness proofs should be well within the scope of existing approaches to formal methods on neural nets; see, e.g., [44, 3, 4, 29, 1, 57]. We leave actually dealing with the gap between floating point numbers and real numbers to future work.

A.8 Non-uniform distributions

In Equation 1 in Section 2 we defined the expected model performance as the expectation over the distribution $D$: $\bar{s} := \mathbb{E}_{(l,t)\sim D}[f(l, M(t))] \ge b$. We then immediately specialized to the case where the marginalization $D|_X$ of $D$ over labels is uniform. As we’ll see in Theorem 1 in Appendix D and Algorithm 3 in Appendix E, the bound computation is modularized between a function that bounds the performance $f(l, M(t))$ over a restricted collection of inputs, and a much simpler function that combines the bounds on individual cases into a bound on the expectation over the entire distribution. The per-input bound computation is CORRECTNESS in Algorithm 1 and RELAXED-CORRECTNESS-PESSIMIZING-OVER-POSITION in Algorithm 3; the expectation computation is BRUTE-FORCE in Algorithm 1 and CUBIC in Algorithm 3. Since the expectation computation is modularized, it is straightforward to extend our approach to non-uniform distributions simply by adjusting the weighting of each region of inputs. However, if the distribution is too far from the uniform training distribution, the bound we get may not be very good, as we may not be allocating adequate computation to the high-probability regions of the input space.

A.9 Adversarial robustness via flexibility in $D$

There is flexibility inherent in Equation 1.
Normally, when we speak of out-of-distribution (OOD) or adversarial inputs, we suppose that there is a distribution $D_{\text{in}}$ that is used for training and (in-distribution) validation, and another distribution $D'$ that is the deployment distribution or is generated by an adversary. If we had knowledge of $D'$, we could compute the expected performance on inputs sampled from $D'$. Even if we don't have exact knowledge of $D'$, we can still define a very broad distribution $D$ that covers the possible $D'$s. In this work, $D$ is the uniform distribution over all $64^4$ possible valid input sequences. In addition, as our proofs partition $D$ into subdistributions and bound the performance on each subdistribution, we can bound the model's performance on any possible distribution over valid input sequences.

A.10 Infinite distributions

In the brute-force proof in Section 4.1, we run the model on the entirety of $D$. This operation is straightforward when $X$ is finite. Perhaps surprisingly, we can do this even if $X$ is infinite, as long as the PDF $L \times X \to \mathbb{R}$ of $D$ is computable and the natural computational topology of $X$ is compact [16, 15, 14], because integration of computable functions on computable reals is computable [49].

A.11 Using alternate loss functions

Building on the point from Appendix A.8, it is also relatively straightforward to extend our approach from bounding expected accuracy to bounding log-loss. We will see in Figure 12 that accuracy and log-loss share a subterm $\Delta\ell_i$. Since we compute this subterm in all of our algorithms, we can easily extend our approach to log-loss by combining $\Delta\ell_i$ directly rather than merely checking that the value is negative, as we currently do in RELAXED-CORRECTNESS-PESSIMIZING-OVER-POSITION in Algorithm 3.
Although this is sufficient for the brute-force and cubic proofs, for the subcubic proof using Algorithm 6 in Appendix F we would additionally have to compute a log-loss bound for the sequences where the largest non-max token is “too close” to the max token, which we currently neglect by considering the model to get them wrong in the worst case.

A.12 Proving upper bounds

In this work, we focus on proving lower bounds on model performance. Most of our theorems, for example those in Appendices E and F, prove two-sided bounds. Most of the other theorems can be straightforwardly adapted to proving upper bounds by swapping uses of min and max. Therefore, we expect that proving upper bounds on model performance should be straightforward.

A.13 What proof system?

The length of a proof depends on what proof system we use. We permit any proof system where proof-checking time is linear in the length of the proof. This excludes dependently typed proof systems such as Martin-Löf type theory, but such proof systems can easily be accommodated by considering a proof-checking trace rather than the proof object itself. Alternatively, a more conventional proof system like ZF, ZFC, or the proof system underlying Isabelle/HOL should suffice.

B Experimental details

B.1 Training details

To train each model, we generate 384,000 random sequences of 4 integers picked uniformly at random, corresponding to less than 2.5% of the input distribution. We use AdamW with batch_size = 128, lr = 0.001, betas = (0.9, 0.999), and weight_decay left at the default 0.01. We train for 1 epoch (3000 steps). Over our 151 seeds, models trained with this procedure achieve $(99.92 \pm 0.15)\%$ train accuracy and a loss of $(4 \pm 8) \times 10^{-3}$.[^14] When qualitatively examining a single model (for example in Section 3.1 or Appendix H.1), we use the model with config seed 123, model seed 613947648.[^15] As our models are sufficiently small, we did not need any GPUs to accelerate training or inference.
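The arithmetic behind “less than 2.5%” and “3000 steps” can be checked in a few lines of Python. This is a minimal sketch under stated assumptions: sequences are drawn independently and with replacement, `make_dataset` is a hypothetical helper, and its seeding is illustrative rather than the config/model seed scheme described above. The AdamW settings are not reproduced here.

```python
import random

D_VOCAB = 64            # d_vocab
N_CTX = 4               # n_ctx, i.e. K in Max-of-K
N_SEQUENCES = 384_000   # training sequences per model
BATCH_SIZE = 128

# Upper bound on the fraction of the 64^4 possible inputs seen in training
# (draws are with replacement, so duplicates can only lower the coverage).
coverage_upper_bound = N_SEQUENCES / D_VOCAB ** N_CTX
steps_per_epoch = N_SEQUENCES // BATCH_SIZE


def make_dataset(n: int, seed: int = 0):
    """Draw n Max-of-K examples: tokens i.i.d. uniform on {0, ..., d_vocab - 1}, label = max."""
    rng = random.Random(seed)  # hypothetical seeding, not the paper's seed derivation
    data = []
    for _ in range(n):
        tokens = [rng.randrange(D_VOCAB) for _ in range(N_CTX)]
        data.append((tokens, max(tokens)))
    return data


print(f"coverage <= {coverage_upper_bound:.4f}, steps per epoch = {steps_per_epoch}")
# coverage <= 0.0229, steps per epoch = 3000
```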
Each training run takes less than a single CPU-hour to complete. In total, the experiments in this paper took less than 1000 CPU-hours. We use the following software packages in our work: Paszke et al. [41], Plotly Technologies Inc. [42], Nanda and Bloom [37], Rogozhnikov [47], Virtanen et al. [52], McKinney [33], Waskom [55].

B.2 Additional details supporting our mechanistic interpretation of the model

We provide heatmaps of the matrices corresponding to the five components described/defined in Section 3, for the mainline model.

[^14]: Numbers reported as mean across training runs ± standard deviation across training runs of mean accuracy and loss.
[^15]: The model seed is deterministically pseudorandomly derived from the seed 123.

[Figure 9 heatmaps: (a) EQKE $= E_qQK^T\bar{E}^T$, key token (x axis) vs. query token (y axis), color scale roughly −200 to 200; (b) EQKP $= E_qQK^T\hat{P}^T$, key position (x axis) vs. query token (y axis), color scale roughly −3 to 1.]

Figure 9: The QK circuit can be decomposed into the position-independent and position-dependent components EQKE and EQKP. It computes the pre-softmax attention score for the model. The positional contribution to the attention score, shown in panel (b), is minimal. In panel (a), the gradient from left to right along the key axis indicates that the single attention head pays more attention to larger tokens. The uniformity along the query axis suggests that this behavior is largely independent of the query token. Further, the light and dark bands imply that some queries are better than others at focusing more on larger tokens.

[Figure 10 heatmaps: (a) EVOU $= \bar{E}VOU$, output logit token vs. input token, color scale roughly −30 to 30; (b) PVOU $= \hat{P}VOU$, output logit token vs. input position, color scale roughly −0.6 to 0.8; (c) Direct Path $= E_qU$, output logit token vs. input token, color scale roughly −2 to 2.]

Figure 10: The OV circuit is a sum of EVOU and PVOU.
In panel (a) we see that EVOU “copies” (with the exception of input tokens ≤ 5; 6.6 ± 1.2 across all models) by virtue of the fact that above 5, the diagonal is larger than all the other elements in the same row. The range in panel (b) is much smaller than in panel (a), indicating that the positional contribution to the copying is minimal. In panel (c) we see that direct-path values matter a bit more than PVOU, being only ≈20× smaller than the typical EVOU difference, though they are still small enough not to matter much. Additionally, the vertical banding indicates that the primary effect of the direct path is a largely query-independent bias towards larger numbers, reflecting the fact that the input distribution is biased towards larger numbers being the maximum. The weak diagonal pattern indicates a slight bias towards upweighting the query token itself as a (possible) maximum token.

B.3 Distribution of model mechanisms

We provide some analysis of the distribution of the mechanisms of the models trained on the same configuration. At a glance, there is not much variation across models. The statistics of interest are: (1) $\sigma_1/\sigma_2$, the ratio of the first two singular values of EQKE, a measure of the extent to which the attention score computation is low-rank; (2) $\bar{s}$, the average score (accuracy) of the model across the entire input distribution; (3) $b_{\text{cubic}}/\bar{s}$, the percent-score-recovered accuracy bound achieved by the cubic proof from Section 4.2; (4) $b_{\text{subcubic}}/\bar{s}$, the percent-score-recovered accuracy bound achieved by the (per-model best)[^16] subcubic proof from Section 4.3.

For each statistic of interest, Table 6 presents an eleven-number summary of the statistic. Plots, seeds, and statistic values are shown for the models whose values are closest to each of the corresponding summary statistics.[^17]
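A minimal sketch of how the per-statistic summaries and the “closest model” selection could be computed. Assumptions, not taken from the paper: percentiles use linear interpolation between order statistics (the paper does not say which rule it uses), the input is a dict from seed to statistic value, and the function names are hypothetical. The sketch computes the statistics drawn in each boxplot: minimum, maximum, mean, quartiles, median, and the four tail percentiles.

```python
from statistics import mean


def percentile(sorted_vals, p):
    """Percentile by linear interpolation between order statistics (p in [0, 100])."""
    pos = (len(sorted_vals) - 1) * p / 100
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)


# Percentiles drawn in each boxplot, in addition to min, max, and mean.
PERCENTILES = {"p2.15": 2.15, "p8.87": 8.87, "q1": 25, "median": 50,
               "q3": 75, "p91.13": 91.13, "p97.85": 97.85}


def summary(values):
    """Summary statistics of one statistic of interest across model seeds."""
    s = sorted(values)
    stats = {"min": s[0], "max": s[-1], "mean": mean(s)}
    stats.update({name: percentile(s, p) for name, p in PERCENTILES.items()})
    return stats


def closest_models(stat_by_seed):
    """Map each summary statistic to the seed whose value is nearest to it."""
    return {
        name: min(stat_by_seed, key=lambda seed: abs(stat_by_seed[seed] - target))
        for name, target in summary(stat_by_seed.values()).items()
    }
```

When two statistics map to the same seed (as the mean and median often will), that model would only need to be shown once.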
Additionally, each group contains a boxplot of the summary, in which the statistics are displayed as follows:
• the minimum and maximum: the top and bottom of the vertical whisker lines;
• the first and third quartiles: the top and bottom of the box;
• the median and mean: the horizontal line inside the box, and the square;
• the percentiles 2.15%, 97.85%, 8.87%, and 91.13%: the horizontal whisker lines and whisker crosshatches.

[^16]: “Per-model best” here means that for each model seed, we select the variant of the subcubic proof with the highest bound.
[^17]: If a single model is the closest to two statistics, for example when the mean and median are very similar, the model is shown only once.

Table 6: Plots of various models. The statistics of interest are: (1) $\sigma_1/\sigma_2$, the ratio of the first two singular values of EQKE, a measure of the extent to which the attention score computation is low-rank; (2) $\bar{s}$, the average score (accuracy) of the model across the entire input distribution; (3) $b_{\text{cubic}}/\bar{s}$, the percent-score-recovered accuracy bound achieved by the cubic proof from Section 4.2; and (4) $b_{\text{subcubic}}/\bar{s}$, the percent-score-recovered accuracy bound achieved by the (per-model best) subcubic proof from Section 4.3. The y axes are: for EQKE and EQKP, the query token; for EVOU, the input token; for $E_qU$, the input query token; for PVOU, the input position. The x axes are: for EQKE, the key token; for EQKP, the key position; for EVOU, PVOU, and $E_qU$, the output logit token. All token axes range from 0 at the top (or left) to $d_{\text{vocab}} - 1 = 63$ at the bottom (or right). All position axes range from 0 at the top (or left) to $n_{\text{ctx}} - 1 = 3$ at the bottom (or right).
[Table 6 body: one row for the selected seed 123 and one for each model closest to a summary statistic of $\sigma_1/\sigma_2$, $\bar{s}$, $b_{\text{cubic}}/\bar{s}$, or $b_{\text{subcubic}}/\bar{s}$, with columns for the seed, the statistic of interest, a boxplot of its distribution, its value, and heatmaps of EQKE ($E_qQK^T\bar{E}^T$), EQKP ($E_qQK^T\hat{P}^T$), EVOU ($\bar{E}VOU$), PVOU ($\hat{P}VOU$), and $E_qU$.]

C Mathematical definitions

We provide a detailed breakdown of the mathematical notation used in the appendix. Let
• $\mathbf{n}$ be the finite set on $n$ elements; we write $\mathbb{N}_{<n} := \{0, 1, \ldots, n-1\}$ when we care about the elements of $\mathbf{n}$
• $\sigma(v) := e^{v} / \sum_i e^{v_i}$ be the softmax function
• $\sigma^*(v)$ be the causally-masked softmax function, $\sigma^*(v)_i := e^{v_i} / \sum_{j \le i} e^{v_j}$
• $d_{\text{vocab}}$ be the size of the vocabulary
• $d$ be the dimension of the attention head, in our case equal to the hidden dimension of the model $d_{\text{model}}$ (assumption: $d < d_{\text{vocab}}$)
• $n_{\text{ctx}}$ be the context length, the number of tokens in the input sequence, equal to $K$ in Max-of-$K$
• $P$ be the $n_{\text{ctx}} \times d_{\text{model}}$ positional embedding
• $E$ be the $d_{\text{vocab}} \times d_{\text{model}}$ token embedding
• $Q, K, V, O$ be the $d_{\text{model}} \times d_{\text{model}}$ query, key, value, and output matrices of the attention head
• $U$ be the $d_{\text{model}} \times d_{\text{vocab}}$ unembedding matrix
• $t$ be the input token sequence $[t_0, t_1, \ldots, t_{n_{\text{ctx}}-1}]$
• $x$ be the $n_{\text{ctx}} \times d_{\text{vocab}}$ one-hot-encoded input token sequence $[x_0, x_1, \ldots, x_{n_{\text{ctx}}-1}]$
• $x_{\text{query}} := x_{-1} := x_{n_{\text{ctx}}-1}$ be the one-hot-encoded query token
• $t_{\text{query}} := t_{-1} := t_{n_{\text{ctx}}-1}$ be the query token
• $t_{\max}$ be the true maximum token in the input sequence, $\max_i t_i$
• $M$ of type $X \to Y$ be the model[^18]; sometimes we write $\ell$ for the logits of the model $M(t)$
• $D$ be a probability distribution over input-label pairs $(t, l) \in X \times L$
• $D|_X$ be $D$ marginalized to a distribution over $t \in X$
• $f$ of type $L \times Y \to \mathbb{R}$ be a scoring function for evaluating the performance of the model
• $P_{\text{avg}} := \frac{1}{n_{\text{ctx}}}\sum_i P_i$ be the average position embedding across positions (of size $d_{\text{model}}$)
• $\bar{P}$ be either $\mathbf{1}_{n_{\text{ctx}}} \otimes P_{\text{avg}}$ or $\mathbf{1}_{d_{\text{vocab}}} \otimes P_{\text{avg}}$ depending on context, that is, the result of broadcasting $P_{\text{avg}}$ back into the shape of $P$ or $E$ ($n_{\text{ctx}} \times d_{\text{model}}$ or $d_{\text{vocab}} \times d_{\text{model}}$)
• $P_q$ be $\mathbf{1}_{d_{\text{vocab}}} \otimes P_{\text{query}}$, the broadcasting of $P_{\text{query}}$
• $\hat{P}$, $\bar{E}$, $E_q$ be $P - \bar{P}$, $E + \bar{P}$, and $E + P_q$ respectively

For any vector-valued function $v$ of length $d_{\text{vocab}}$, parameterized over the input sequence $t$, let $\Delta v_i := v_i - v_{t_{\max}}$ be the difference between the $i$-th element of $v$ and the element of $v$ corresponding to the true maximum of the input sequence $t_{\max}$.

For our particular model, we will have
$$d_{\text{vocab}} := 64 \qquad d = d_{\text{model}} := 32 \qquad n_{\text{ctx}} := 4 \qquad X := (\mathbb{N}_{<d_{\text{vocab}}})^{n_{\text{ctx}}} \cong \mathbf{64}^4 \qquad L := \mathbb{N}_{<d_{\text{vocab}}} \cong \mathbf{64}$$
and $D|_X := U(\{0, 1, \ldots, d_{\text{vocab}}-1\})^{n_{\text{ctx}}}$, the uniform distribution.

Figure 11: Preliminary model definitions

$$\underbrace{\mathbb{E}_{(l,t)\sim D}[f(l, M(t))]}_{\bar{s}\ \text{(average model score)}} \ge \underbrace{b}_{\text{lower bound (theorem statement)}} \tag{3}$$

We define the two typical performance functions corresponding to accuracy and log-loss. Note the shared subterm $\Delta\ell_i$.
$$f_{\text{accuracy}}(t_{\max}, \ell) := \mathbb{1}\big[\operatorname{argmax}_i \ell_i = t_{\max}\big] = \mathbb{1}\Big[0 > \max_{i \ne t_{\max}} \underbrace{\ell_i - \ell_{t_{\max}}}_{\Delta\ell_i}\Big] \tag{4}$$
$$f_{\text{log-loss}}(t_{\max}, \ell) := -\log\big(\sigma(\ell)\big)_{t_{\max}} = \log\Big(\sum_i \exp\big(\underbrace{\ell_i - \ell_{t_{\max}}}_{\Delta\ell_i}\big)\Big) \tag{5}$$

We present the model definition in four different regroupings, using underbrace labels to define various useful quantities:
$$M(t) = \ell(t) = \sigma^*\Big(\underbrace{(x_{\text{query}}E + P_{\text{query}})QK^T(xE + P)^T}_{\text{QK circuit}} / \sqrt{d}\Big) \cdot \underbrace{(xE + P)VOU}_{\text{OV circuit}} + \underbrace{(x_{\text{query}}E + P_{\text{query}})U}_{\text{direct path}} \tag{6}$$
$$= \underbrace{\sigma^*\Big(x_{\text{query}}\big(\underbrace{E_qQK^T\bar{E}^T}_{\text{EQKE}}x^T + \underbrace{E_qQK^T\hat{P}^T}_{\text{EQKP}}\big)/\sqrt{d}\Big)}_{\alpha^*(t)} \cdot \underbrace{\Big(x\underbrace{\bar{E}VOU}_{\text{EVOU}} + \underbrace{\hat{P}VOU}_{\text{PVOU}}\Big)}_{\text{EPVOU}(t)} + x_{\text{query}}\underbrace{E_qU}_{\text{EU}} \tag{7}$$
$$= \underbrace{\alpha^*(t) \cdot x\bar{E}VOU}_{\ell^{\text{EVOU}}(t)} + \underbrace{\alpha^*(t) \cdot \hat{P}VOU}_{\ell^{\text{PVOU}}(t)} + \underbrace{x_{\text{query}}E_qU}_{\ell^{\text{EU}}(x_{\text{query}})} \tag{8}$$
$$= \sum_{i=0}^{n_{\text{ctx}}-1}\underbrace{\Big(\underbrace{(\alpha^*(t))_i\,x_i\bar{E}VOU}_{\ell^{\text{EVOU},i}(t)} + \underbrace{(\alpha^*(t))_i\,\hat{P}_iVOU}_{\ell^{\text{PVOU},i}(t)}\Big)}_{\ell^i(t)} + \underbrace{x_{\text{query}}E_qU}_{\ell^{\text{EU}}(x_{\text{query}})} \tag{9}$$

Figure 12: Definitions of the model behavior

D Brute-force proof

Theorem 1. For BRUTE-FORCE($d_{\text{vocab}}$, $n_{\text{ctx}}$, $M$) as defined in Algorithm 1,
$$\mathbb{E}_{t \sim U(\{0, 1, \ldots, d_{\text{vocab}}-1\})^{n_{\text{ctx}}}}\,\mathbb{1}\Big[\operatorname{argmax}_i (M(t))_i = \max_i t_i\Big] \ge \text{BRUTE-FORCE}(d_{\text{vocab}}, n_{\text{ctx}}, M)$$

Proof. In fact, the two sides of the inequality are equal by definition. Hence the inequality follows by reflexivity of $\ge$.

Algorithm 1 Counting Correct Sequences By Brute Force
1: function CORRECTNESS(M, input-sequence)
2:   return MODEL-BEHAVIOR(M, input-sequence) == MAX(input-sequence)
3: end function
4: function BRUTE-FORCE(d_vocab, n_ctx, M)
5:   return (1 / d_vocab^n_ctx) · SUM(CORRECTNESS(M, tokens) for tokens ∈ (RANGE(d_vocab))^n_ctx)
6: end function

E Details of cubic proof

In this section, we formally prove the result used in Section 4.2 (“A cubic proof”). At its heart, the convexity of softmax[^19] is an extension of a simple idea: a weighted average of scalar values is extremized by putting 100% of the weight on an extremal value. Using this simple version of the theorem, however, gives a useless bound of 0% accuracy: if we pay no attention to the maximum of the sequence, of course we are going to get the wrong answer. Since in fact the space of possible weightings we may see in practice is much smaller (finite, in fact, with at most $d_{\text{vocab}}^{n_{\text{ctx}}}$ values), we may look for a more general version of this idea that gives tighter bounds while still covering the space of possible weightings.

The weights are not independently choosable (softmax is non-linear), so extremal values do not necessarily result from putting maximal attention on the worst token. When trying to find the worst case, it may be that some positions are so dispreferred that it makes more sense to choose a token that is “less bad” for those positions, if it draws enough attention away from the correct token. See Lemma 3 for details. We thus spend this section characterizing a relaxation of the constraints on weights:
1. that contains all actually possible weightings,
2. that is extremized at weights that still correspond to some notion of “put the most weight on the extremal tokens”, and
3. for which computing the extremal weightings is computationally efficient.
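The claim that a “less bad” token can still produce the worst case is easy to check numerically. The sketch below uses hypothetical toy numbers (not from any trained model) and a single softmax over two positions, with the sign convention of the sequence score in Definition 3: positive scores favor the wrong output $t^*$. Naively pessimizing by choosing the token with the highest value barely moves the score, because that token receives almost no attention, while a token with a smaller value but a large attention logit pulls attention away from the maximum and flips the sign of the score.

```python
import math


def sequence_score(values, attn_logits):
    """Softmax-weighted average of per-token values: sum(v * e^a) / sum(e^a).

    Positive scores favor the wrong output t*; negative scores favor t_max.
    """
    weights = [math.exp(a) for a in attn_logits]
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


# Hypothetical toy numbers. The maximum token strongly supports the correct
# answer (value -10) and has attention logit 0.
MAX_TOKEN = (-10.0, 0.0)

# Two candidates for the remaining position, as (value, attention logit).
CANDIDATES = {
    "highest value, little attention": (5.0, -5.0),
    "lower value, lots of attention": (1.0, 3.0),
}

scores = {
    name: sequence_score([MAX_TOKEN[0], v], [MAX_TOKEN[1], a])
    for name, (v, a) in CANDIDATES.items()
}
worst = max(scores, key=scores.get)
print(worst, scores)
```

Here the first candidate leaves the score near −9.9, while the second pushes it above zero, so the worst case is the token with the lower value.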
Before diving in, let’s recall the proof that a weighted average of scalar values is extremized by putting 100% of the weight on extremal values:

Theorem 2 (Warmup: Extremizing weighted averages). Fix a set of values $v_i \in \mathbb{R}$. The weighted average is bounded by the extremal values: for any $w_i$ such that $\sum_i w_i = 1$ and $0 \le w_i \le 1$,
$$\min_i v_i \le \sum_i w_i v_i \le \max_i v_i.$$

Proof. The proof is simple. We have
$$\sum_i w_i v_i - \min_i v_i = \sum_i w_i\big(v_i - \min_j v_j\big) \ge 0 \quad\text{and}\quad \max_i v_i - \sum_i w_i v_i = \sum_i w_i\big(\max_j v_j - v_i\big) \ge 0,$$
so the result follows.

[^18]: Note that while in the main body $M(t)$ referred to the pre-softmax output logits, in the appendix we abuse notation and occasionally use it to refer to the maximum token indicated by the logits where appropriate.
[^19]: See Appendix A.4 for the reason that we call this “convexity”. Note that our use of “convexity” is purely descriptive in this section; all theorems are written out explicitly.

E.1 Proof strategy

The model computes the true maximum $t_{\max}$ when its output logits $\ell$ are such that $\Delta\ell_{t^*} := \ell_{t^*} - \ell_{t_{\max}} < 0$ for all $t^* \ne t_{\max}$.[^20] As a result, it suffices to lower-bound the proportion of sequences where (an upper bound on) $\Delta\ell_{t^*}$ is negative for all $t^* \ne t_{\max}$. In particular, we will upper-bound the contribution from incorrect tokens $t$ in positions $i$ to the difference $\Delta\ell^i$ between incorrect ($t^*$) and correct ($t_{\max}$) output tokens, $\Delta\ell^i_{t^*} = \ell^i_{t^*} - \ell^i_{t_{\max}}$. We do this by arguing that the logit difference $\Delta\ell_{t^*}$ satisfies a certain notion of convexity over the space of a relaxation of sequences (Theorem 6), and by constructing a set of $\Theta(d_{\text{vocab}}^3 n_{\text{ctx}})$ “extremal” relaxed sequences where the position and token embedding components of attention are pessimized independently.
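The correctness criterion above, together with Equations (4) and (5), can be sketched in a few lines of Python. This is an illustration under our own naming (all helper functions are hypothetical): accuracy and log-loss are both computed from the same logit differences $\Delta\ell_i$, and `certified_correct` shows why an upper bound on each $\Delta\ell_i$ is enough to certify a correct prediction.

```python
import math


def logit_diffs(logits, t_max):
    """Delta ell_i = ell_i - ell_{t_max} for every output token i."""
    return [x - logits[t_max] for x in logits]


def accuracy(logits, t_max):
    """Eq. (4): 1 iff every incorrect token has a strictly negative logit difference."""
    d = logit_diffs(logits, t_max)
    return int(all(di < 0 for i, di in enumerate(d) if i != t_max))


def log_loss(logits, t_max):
    """Eq. (5): log sum_i exp(Delta ell_i), computed stably by factoring out the max."""
    d = logit_diffs(logits, t_max)
    m = max(d)  # m >= 0, because d[t_max] == 0
    return m + math.log(sum(math.exp(di - m) for di in d))


def certified_correct(upper_bounds, t_max):
    """If upper_bounds[i] >= Delta ell_i for all i, a True result implies accuracy 1."""
    return all(u < 0 for i, u in enumerate(upper_bounds) if i != t_max)
```

For example, with logits `[1.0, 3.0, 2.0]` the prediction is correct when the true maximum is token 1 and incorrect when it is token 2.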
We start by rewriting the contribution of each token through the attention head to the logit difference as the sum of the contributions involving PVOU and EVOU:
$$\Delta\ell^{i}_{t^*}(t) = \Delta\ell^{\text{PVOU},i}_{t^*}(t) + \Delta\ell^{\text{EVOU},i}_{t^*}(t)$$
We then upper bound $\Delta\ell^{\text{PVOU},i}_{t^*}(t)$ by noting that, because the softmax attention is a weighted average of PVOU,
$$\begin{aligned}\Delta\ell^{\text{PVOU},i}_{t^*}(t) &= \ell^{\text{PVOU},i}(t)_{t^*} - \ell^{\text{PVOU},i}(t)_{\max_j t_j} \\ &= \alpha^*_i(t)\,\text{PVOU}_{i,t^*} - \alpha^*_i(t)\,\text{PVOU}_{i,\max_j t_j} \\ &= \alpha^*_i(t)\big(\text{PVOU}_{i,t^*} - \text{PVOU}_{i,\max_j t_j}\big) \\ &\le \alpha^*_i(t)\max_{i'}\big(\text{PVOU}_{i',t^*} - \text{PVOU}_{i',\max_j t_j}\big)\end{aligned}$$
Since $\sum_i \alpha^*_i(t) = 1$, we have
$$\sum_{i=0}^{n_{\text{ctx}}-1}\Delta\ell^{\text{PVOU},i}_{t^*}(t) \le \max_{i}\big(\text{PVOU}_{i,t^*} - \text{PVOU}_{i,\max_j t_j}\big)$$
We then construct a set $\Xi^{\text{pure}}$ of “pure sequences” consisting of only three types of tokens in one of two orders, and show that for each input sequence $t$ and readoff logit $t^*$, we can bound the logit difference from the token embeddings $\Delta\ell^{\text{EVOU},i}_{t^*}(t)$ using a small subset $X$ of $\Xi^{\text{pure}}$:
$$\sum_{i=0}^{n_{\text{ctx}}-1}\Delta\ell^{\text{EVOU},i}_{t^*}(t) \le \max_{\xi \in X}\sum_{i=0}^{n_{\text{ctx}}-1}\Delta\ell^{\text{EVOU},i}_{t^*}(\xi)$$
We construct a set $X^{\text{relaxed}}$ of relaxed sequences, where each relaxed sequence $t^{\text{relaxed}}$ consists of a sequence and a position $(t, i)$, and $\Delta\ell_{t^*}(t, i)$ is evaluated by separately considering the positional contribution through attention (that is, the attention-weighted PVOU), the token contribution (that is, the attention-weighted EVOU), and the direct contribution (the logit difference through the skip connection EU). Note that $i$ indicates the position to which we pay 100% of the attention for the PVOU contribution.

[^20]: We use the logit difference $\Delta\ell_{t^*}$ because: (a) it is shared in the computation of $f_{0\text{-}1}$ and $f_{\text{log-loss}}$; (b) it is a linear function of the various paths through the model, which can therefore be analyzed separately; (c) it leaves open both the options of pessimizing over output logit before or after combining contributions of various paths through the model.
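The PVOU step says that, whatever attention pattern the model produces, the positional contribution to $\Delta\ell_{t^*}$ cannot exceed the largest per-position difference. The sketch below checks this numerically with a hypothetical random PVOU matrix (not from a trained model; the helper names and the choices $t^* = 5$, $t_{\max} = 40$ are ours) against many random softmax attention patterns. The bound is attained when all attention sits on the maximizing position.

```python
import math
import random


def pvou_bound(pvou, t_star, t_max):
    """max_i (PVOU[i][t*] - PVOU[i][t_max]); independent of the attention pattern."""
    return max(row[t_star] - row[t_max] for row in pvou)


def pvou_contribution(pvou, attention, t_star, t_max):
    """sum_i alpha_i * (PVOU[i][t*] - PVOU[i][t_max]) for one attention distribution alpha."""
    return sum(a * (row[t_star] - row[t_max]) for a, row in zip(attention, pvou))


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical toy PVOU: n_ctx = 4 positions by d_vocab = 64 output tokens.
rng = random.Random(0)
N_CTX, D_VOCAB = 4, 64
PVOU = [[rng.gauss(0.0, 0.3) for _ in range(D_VOCAB)] for _ in range(N_CTX)]


def check(trials=1000, t_star=5, t_max=40):
    """Return the bound and the largest contribution seen over random attention patterns."""
    bound = pvou_bound(PVOU, t_star, t_max)
    worst = -math.inf
    for _ in range(trials):
        alpha = softmax([rng.gauss(0.0, 3.0) for _ in range(N_CTX)])
        worst = max(worst, pvou_contribution(PVOU, alpha, t_star, t_max))
    return bound, worst
```

Because the bound does not depend on $\alpha^*(t)$, it can be computed once per $(t^*, t_{\max})$ pair, which is what lets the proof pessimize the positional path separately from the token path.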
We argue that $\Delta\ell_{t^*}(\mathbf{t}, i)$ satisfies a certain notion of convexity over mixtures of sequences. As a result, we need only evaluate it on a set of $\Theta(d_{\text{vocab}}^3 n_{\text{ctx}})$ "extremal" sequences, taking $O(d_{\text{vocab}}^3 n_{\text{ctx}})$ total time, to bound $\Delta\ell_{t^*}(\mathbf{t}, i)$ for every possible input sequence. We then use the extremal sequences that the model gets correct to lower-bound the proportion of all sequences that the model gets correct. Specifically, we argue that Algorithm 3 provides a valid lower bound on the proportion of sequences the model gets correct.

E.2 Proof outline

We now proceed to the main results of this section.

Math fact: For each token $t^*$, the logit difference $\Delta\ell_{t^*}$ for any sequence $\mathbf{t}$ decomposes into three parts: the direct contribution from the embeddings ($\ell^{\mathrm{EU}}$), the attention-weighted position contribution ($\mathrm{PVOU}$), and the attention-weighted token contribution ($\mathrm{EVOU}$). Therefore, it suffices to upper-bound each of the three components independently, since summing these upper bounds gives a valid upper bound on the logit difference.

We can compute the direct contribution $\ell^{\mathrm{EU}}$ exactly. We first compute $\mathrm{EU} = E_q U$ and then, for each max token, subtract the logit of the max token from each row of the matrix. No theorems are needed.

For each max token, we can bound the position contribution by its maximum over positions (Theorem 6).

To upper-bound the token contribution, we argue that any mixed sequence is upper-bounded by the maximum of the corresponding pure sequences (Theorem 7). We then argue that, for pure sequences, it suffices to consider orderings in which equal tokens appear contiguously (Theorem 4).

E.3 Formal proof

For this subsection, all theorems are parameterized over the following quantities.

Definition 1 (Common theorem parameters). Fix a token value function (à la a row difference in $\mathrm{EVOU}$) $v : \mathbb{N}_{<d_{\text{vocab}}} \to \mathbb{R}$ and a token attention function (à la $\mathrm{EQKE}$ for a fixed query token) $a : \mathbb{N}_{<d_{\text{vocab}}} \to \mathbb{R}$.
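Fix a position value function (à la a row difference in $\mathrm{PVOU}$) $w : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{R}$ and a position attention function (à la $\mathrm{EQKP}$ for a fixed query token) $b : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{R}$. In practice, we fix a query token $t_{\text{query}}$, an output token of interest $t^*$, and a maximum token $t_{\max}$. We then take
$$v_t = \mathrm{EVOU}_{t,t^*} - \mathrm{EVOU}_{t,t_{\max}}, \qquad a_t = \mathrm{EQKE}_{t_{\text{query}},t}/\sqrt{d}, \qquad w_i = \mathrm{PVOU}_{i,t^*} - \mathrm{PVOU}_{i,t_{\max}}, \qquad b_i = \mathrm{EQKP}_{t_{\text{query}},i}/\sqrt{d}.$$

Definition 2 (Sequence via sorted tokens and a position permutation). We can define a sequence of tokens via sorted tokens and a position permutation. To do so, we specify a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}_{<d_{\text{vocab}}}$ paired with a permutation $\sigma : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{N}_{<n_{\text{ctx}}}$.

Definition 3 (Sequence score). Given a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}_{<d_{\text{vocab}}}$ and a permutation $\sigma : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{N}_{<n_{\text{ctx}}}$, define the sequence score $s_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma}$ as
$$s_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma} := \frac{\sum_{0\le i<n_{\text{ctx}}} v_{t_i} e^{a_{t_i} + b_{\sigma(i)}}}{\sum_{0\le i<n_{\text{ctx}}} e^{a_{t_i} + b_{\sigma(i)}}}.$$
We drop the token subscript, writing only $s_\sigma$, when the token values are unambiguous from context. The sequence score here computes $\Delta\ell^{\mathrm{EVOU}}_{t^*}$ for some fixed $t^*$ and $t_{\max}$. Under these definitions:

- high scores predict $t^*$ (and are thus bad);
- negative scores predict $t_{\max}$ (and are thus good);
- the more negative the score, the stronger the prediction of $t_{\max}$.

Definition 4 (Swap permutation). Given a permutation $\sigma : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{N}_{<n_{\text{ctx}}}$ of the $n_{\text{ctx}}$ positions and two indices $0 \le i, j < n_{\text{ctx}}$, define the swap permutation $\sigma_{i\leftrightarrow j}$. It is the permutation $\sigma$ except with $i$ and $j$ swapped:
$$\sigma_{i\leftrightarrow j}(k) = \begin{cases}\sigma(i) & \text{if } k = j\\ \sigma(j) & \text{if } k = i\\ \sigma(k) & \text{otherwise.}\end{cases}$$
Define $\Delta_{\sigma,i\leftrightarrow j}$ to be the difference in sequence scores when we swap $i$ and $j$:
$$\Delta_{\sigma,i\leftrightarrow j} := s_{\sigma_{i\leftrightarrow j}} - s_\sigma.$$

Lemma 3 (Characterization of swapping tokens). Fix a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}$. Let $\sigma : \mathbb{N} \to \mathbb{N}$ be a permutation of the $n_{\text{ctx}}$ positions, and fix indices $0 \le i, j < n_{\text{ctx}}$. Then there are two cases for $\operatorname{sign}(\Delta_{\sigma,i\leftrightarrow j})$:
1. If $a_{t_i} = a_{t_j}$, then $\operatorname{sign}(\Delta_{\sigma,i\leftrightarrow j}) = -\operatorname{sign}\left(b_{\sigma(i)} - b_{\sigma(j)}\right)\operatorname{sign}\left(v_{t_i} - v_{t_j}\right)$.
2. Otherwise, $\operatorname{sign}(\Delta_{\sigma,i\leftrightarrow j}) = \operatorname{sign}\left(a_{t_i} - a_{t_j}\right)\operatorname{sign}\left(b_{\sigma(i)} - b_{\sigma(j)}\right)\operatorname{sign}\left(s_\sigma - \dfrac{v_{t_i}e^{a_{t_i}} - v_{t_j}e^{a_{t_j}}}{e^{a_{t_i}} - e^{a_{t_j}}}\right)$.

Intuitively, Lemma 3 covers two situations. Suppose the token contribution to attention is equal between tokens $t_i$ and $t_j$. Then the impact of swapping their positions $\sigma(i)$ and $\sigma(j)$ is entirely determined by two things: how much attention is paid to the positions of $i$ and $j$, and the relative difference in their values. (Notably, swapping these tokens does not affect the attention paid to other tokens, so the effect of the change does not depend on the values of the other tokens.) Suppose instead that the attentions are not equal. Then swapping the positions changes the allocation of attention to other tokens in the sequence. It may be the case that this change in the allocation of attention dominates the attention-weighted values of these two tokens.

Proof. First note that the lemma is trivial for $i = j$; for the rest of the proof, we take $i \ne j$. The proof proceeds by algebraic manipulation with no deep insight. We first state the fact we use, then proceed to computing $\operatorname{sign}(\Delta_{\sigma,i\leftrightarrow j})$. We abbreviate $\sigma_{i\leftrightarrow j}$ as $\sigma'$.
$$\operatorname{sign}\left(e^{b_{\sigma(i)}} - e^{b_{\sigma(j)}}\right) = \operatorname{sign}\left(b_{\sigma(i)} - b_{\sigma(j)}\right)$$
$$\operatorname{sign}(\Delta_{\sigma,i\leftrightarrow j}) = \operatorname{sign}(s_{\sigma'} - s_\sigma) = \operatorname{sign}\left(\frac{\sum_{0\le p<n_{\text{ctx}}} v_{t_p} e^{a_{t_p} + b_{\sigma'(p)}}}{\sum_{0\le p<n_{\text{ctx}}} e^{a_{t_p} + b_{\sigma'(p)}}} - s_\sigma\right)$$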
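Multiplying through by the denominator, which is positive:
$$\begin{aligned}
&= \operatorname{sign}\left(\sum_{0\le p<n_{\text{ctx}}} v_{t_p} e^{a_{t_p}+b_{\sigma'(p)}} - s_\sigma \sum_{0\le p<n_{\text{ctx}}} e^{a_{t_p}+b_{\sigma'(p)}}\right)\\
&= \operatorname{sign}\Bigg(\sum_{0\le p<n_{\text{ctx}}} v_{t_p} e^{a_{t_p}+b_{\sigma(p)}} - v_{t_i}e^{a_{t_i}}\left(e^{b_{\sigma(i)}} - e^{b_{\sigma'(i)}}\right) - v_{t_j}e^{a_{t_j}}\left(e^{b_{\sigma(j)}} - e^{b_{\sigma'(j)}}\right)\\
&\qquad\quad - s_\sigma \sum_{0\le p<n_{\text{ctx}}} e^{a_{t_p}+b_{\sigma(p)}} + s_\sigma e^{a_{t_i}}\left(e^{b_{\sigma(i)}} - e^{b_{\sigma'(i)}}\right) + s_\sigma e^{a_{t_j}}\left(e^{b_{\sigma(j)}} - e^{b_{\sigma'(j)}}\right)\Bigg)\\
&= \operatorname{sign}\Bigg(\sum_{0\le p<n_{\text{ctx}}} v_{t_p} e^{a_{t_p}+b_{\sigma(p)}} - v_{t_i}e^{a_{t_i}}\left(e^{b_{\sigma(i)}} - e^{b_{\sigma(j)}}\right) - v_{t_j}e^{a_{t_j}}\left(e^{b_{\sigma(j)}} - e^{b_{\sigma(i)}}\right)\\
&\qquad\quad - \sum_{0\le p<n_{\text{ctx}}} v_{t_p} e^{a_{t_p}+b_{\sigma(p)}} + s_\sigma e^{a_{t_i}}\left(e^{b_{\sigma(i)}} - e^{b_{\sigma(j)}}\right) + s_\sigma e^{a_{t_j}}\left(e^{b_{\sigma(j)}} - e^{b_{\sigma(i)}}\right)\Bigg)\\
&= \operatorname{sign}\left(\left(v_{t_j}e^{a_{t_j}} - v_{t_i}e^{a_{t_i}}\right)\left(e^{b_{\sigma(i)}} - e^{b_{\sigma(j)}}\right) + s_\sigma\left(e^{a_{t_i}} - e^{a_{t_j}}\right)\left(e^{b_{\sigma(i)}} - e^{b_{\sigma(j)}}\right)\right)\\
&= \operatorname{sign}\left(e^{b_{\sigma(i)}} - e^{b_{\sigma(j)}}\right)\operatorname{sign}\left(v_{t_j}e^{a_{t_j}} - v_{t_i}e^{a_{t_i}} + s_\sigma\left(e^{a_{t_i}} - e^{a_{t_j}}\right)\right)\\
&= \operatorname{sign}\left(b_{\sigma(i)} - b_{\sigma(j)}\right)\operatorname{sign}\left(s_\sigma\left(e^{a_{t_i}} - e^{a_{t_j}}\right) - \left(v_{t_i}e^{a_{t_i}} - v_{t_j}e^{a_{t_j}}\right)\right)
\end{aligned}$$
The third step uses $\sigma'(i) = \sigma(j)$, $\sigma'(j) = \sigma(i)$, and $s_\sigma \sum_p e^{a_{t_p}+b_{\sigma(p)}} = \sum_p v_{t_p} e^{a_{t_p}+b_{\sigma(p)}}$. Dividing through by non-zero values when possible:
$$\begin{aligned}
&= \operatorname{sign}\left(b_{\sigma(i)} - b_{\sigma(j)}\right)\cdot\begin{cases}-\operatorname{sign}\left(v_{t_i} - v_{t_j}\right) & \text{if } a_{t_i} = a_{t_j}\\ \operatorname{sign}\left(e^{a_{t_i}} - e^{a_{t_j}}\right)\operatorname{sign}\left(s_\sigma - \dfrac{v_{t_i}e^{a_{t_i}} - v_{t_j}e^{a_{t_j}}}{e^{a_{t_i}} - e^{a_{t_j}}}\right) & \text{otherwise}\end{cases}\\
&= \begin{cases}-\operatorname{sign}\left(b_{\sigma(i)} - b_{\sigma(j)}\right)\operatorname{sign}\left(v_{t_i} - v_{t_j}\right) & \text{if } a_{t_i} = a_{t_j}\\ \operatorname{sign}\left(a_{t_i} - a_{t_j}\right)\operatorname{sign}\left(b_{\sigma(i)} - b_{\sigma(j)}\right)\operatorname{sign}\left(s_\sigma - \dfrac{v_{t_i}e^{a_{t_i}} - v_{t_j}e^{a_{t_j}}}{e^{a_{t_i}} - e^{a_{t_j}}}\right) & \text{otherwise.}\end{cases}
\end{aligned}$$

Definition 5 ($\sigma$ fixes $F$). Fix a set of fixed indices $F \subseteq \mathbb{N}_{<n_{\text{ctx}}}$ and an assignment of token values to each of the fixed positions, $t_F : F \to \mathbb{N}_{<d_{\text{vocab}}}$. ($F$ is the set of positions for which we are not pessimizing over the value of the token in that position.) Fix a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}$. Given a permutation $\sigma : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{N}_{<n_{\text{ctx}}}$, say that $\sigma$ fixes $F$ (relative to $t_0, \ldots, t_{n_{\text{ctx}}-1}$) if $t_i = t_F(\sigma(i))$ whenever $\sigma(i) \in F$.

In this section, for the cubic proofs, we will in fact generally take $F = \{n_{\text{ctx}} - 1\}$, so that we are fixing the final query token. In Theorem 7 and Lemmas 8 and 9, $F$ will also contain all positions holding the maximum token $t_{\max}$. In Appendix F, we will take either $F = \emptyset$ or $F$ equal to the set of positions of the maximum token.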
However, none of these theorems are specific to $F$ being a subsingleton, and we prove them in generality.

Definition 6 (Position-sorting permutation). Fix a set of fixed indices $F \subseteq \mathbb{N}_{<n_{\text{ctx}}}$ and an assignment of token values to each of the fixed positions $t_F : F \to \mathbb{N}_{<d_{\text{vocab}}}$. The position-sorting permutation fixing indices in $F$, written $\sigma_s : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{N}_{<n_{\text{ctx}}}$, is the permutation that sorts the indices not in $F$ according to $b$. That is, for $0 \le i, j < n_{\text{ctx}}$ with $i, j \notin F$, we have $b_i \le b_j$ whenever $\sigma_s(i) < \sigma_s(j)$; and $\sigma_s(i) = i$ for $i \in F$.

Definition 7 (Contiguous on equal tokens). Fix a set of fixed indices $F \subseteq \mathbb{N}_{<n_{\text{ctx}}}$ and an assignment of token values to each of the fixed positions $t_F : F \to \mathbb{N}_{<d_{\text{vocab}}}$. Fix a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}$. Say that the sequence represented by a permutation $\sigma : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{N}_{<n_{\text{ctx}}}$ is contiguous on equal tokens if the following holds. For all $0 \le i, j, k < n_{\text{ctx}}$ with $t_i = t_j \ne t_k$ and $i, j, k \notin \sigma^{-1}(F)$, it is never the case that $\sigma_s(\sigma(i)) < \sigma_s(\sigma(k)) < \sigma_s(\sigma(j))$.

Theorem 4 (Pessimization over sequence ordering is possible and results in contiguous sequences). Fix a set of fixed indices $F \subseteq \mathbb{N}_{<n_{\text{ctx}}}$ and an assignment of token values to each of the fixed positions $t_F : F \to \mathbb{N}_{<d_{\text{vocab}}}$. Fix a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}$. Let $\sigma_{\min}, \sigma_{\max} : \mathbb{N} \to \mathbb{N}$ be permutations of the $n_{\text{ctx}}$ positions, fixing positions in $F$, with the following property. For every permutation $\sigma : \mathbb{N} \to \mathbb{N}$ fixing $F$, we have
$$s_{\sigma_{\min}} \le s_\sigma \le s_{\sigma_{\max}}. \tag{10}$$
(Such permutations are guaranteed to exist because the permutation group on $n_{\text{ctx}}$ elements is finite.) Then $\sigma_{\max}$ and $\sigma_{\min}$ may be taken to be contiguous on equal tokens. That is, there exist $\sigma_{\max}$ and $\sigma_{\min}$ satisfying Equation 10 that additionally satisfy Definition 7.

The basic idea is to assume that one of $\sigma_{\max}$ and $\sigma_{\min}$ cannot be contiguous on equal tokens, and to derive a contradiction.
We will pick the extremal permutation that is closest to being contiguous and take a contiguity violation. We then show one of two things. Either we can correct the contiguity violation without changing the score, which violates the presumption that the permutation is closest to being contiguous. Or we find one swap of indices that decreases the score and another swap that increases it, which violates the presumption of extremality.

In slightly more detail, but still informally, we will consider the sign of the difference between two scores: that of our purported extremal permutation, and that of a permutation with some indices swapped. The theorem follows from showing that there exists a triple of indices $i, j, k$ such that the sign of the score difference from swapping $i$ and $k$ differs from the sign of the score difference from swapping $j$ and $k$.

First, a definition and some helpful facts about it.

Definition 8 (Contiguous on equally-attended positions). Fix a set of fixed indices $F \subseteq \mathbb{N}_{<n_{\text{ctx}}}$ and an assignment of token values to each of the fixed positions $t_F : F \to \mathbb{N}_{<d_{\text{vocab}}}$. Fix a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}$. Say that a permutation $\sigma$ is contiguous on equally-attended positions if the following holds for all $0 \le i < n_{\text{ctx}}$ with $i \notin \sigma^{-1}(F)$. Consider the contiguous block of positions whose contribution to the attention score equals that of $\sigma(i)$,
$$\{\sigma(j) \mid b_{\sigma(j)} = b_{\sigma(i)} \text{ and } \sigma(j) \notin F\}.$$
The sorting order according to $\sigma_s$ on this block must match a second sorting order. That second order ranks by the fraction of tokens equal to $t_j$ with $b$-values greater than $b_{\sigma(i)}$, with ties broken by the value of $t_j$. Equationally, this second sorting order is defined by the score
$$\frac{\left|\{k \mid t_k = t_j \text{ and } b_{\sigma(k)} > b_{\sigma(i)} \text{ and } \sigma(k) \notin F\}\right| + \dfrac{t_j}{d_{\text{vocab}}}}{\left|\{k \mid t_k = t_j \text{ and } \sigma(k) \notin F\}\right|}.$$
Most importantly, any permutation that is contiguous on equally-attended positions has the following property. Take any indices $0 \le i, j, k < n_{\text{ctx}}$ with $i, j, k \notin \sigma^{-1}(F)$, $t_i = t_j \ne t_k$, and $\sigma_s(\sigma(i)) < \sigma_s(\sigma(k)) < \sigma_s(\sigma(j))$. Then we have the strict inequality $b_{\sigma(i)} < b_{\sigma(k)} < b_{\sigma(j)}$.
Additionally, we may always sort equally-attended positions to make any permutation contiguous on equally-attended positions. We now define an additional notion of contiguity violation. We avoid such violations up front by arbitrarily swapping the involved indices, which does not change the score $s_\sigma$.

Definition 9 (Needlessly non-contiguous). Fix a set of fixed indices $F \subseteq \mathbb{N}_{<n_{\text{ctx}}}$ and an assignment of token values to each of the fixed positions $t_F : F \to \mathbb{N}_{<d_{\text{vocab}}}$. Fix a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}$. Consider indices $0 \le i, j, k < n_{\text{ctx}}$ with $i, j, k \notin \sigma^{-1}(F)$, $t_i = t_j \ne t_k$, and $\sigma_s(\sigma(i)) < \sigma_s(\sigma(k)) < \sigma_s(\sigma(j))$. Say that a permutation $\sigma$ is needlessly non-contiguous at $i, j, k$ if $\Delta_{\sigma,i\leftrightarrow k} = 0$ or $\Delta_{\sigma,j\leftrightarrow k} = 0$. Say that a permutation $\sigma$ is needlessly non-contiguous if it is needlessly non-contiguous at any $i, j, k \notin \sigma^{-1}(F)$.

Lemma 5. Fix a set of fixed indices $F \subseteq \mathbb{N}_{<n_{\text{ctx}}}$ and an assignment of token values to each of the fixed positions $t_F : F \to \mathbb{N}_{<d_{\text{vocab}}}$. Fix a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}$. Let $\sigma$ be any needlessly non-contiguous sequence that fixes $F$. Then $\sigma$ can be made into a sequence $\sigma'$ with $s_\sigma = s_{\sigma'}$ that still fixes $F$ and is both contiguous on equally-attended positions and not needlessly non-contiguous.

Proof. First, sort regions of equally-attended positions to make $\sigma$ contiguous on equally-attended positions. If the resulting permutation is not needlessly non-contiguous, then we are done. Otherwise, there are some $0 \le i, j, k < n_{\text{ctx}}$ with $i, j, k \notin \sigma^{-1}(F)$, $t_i = t_j \ne t_k$, and $\sigma_s(\sigma(i)) < \sigma_s(\sigma(k)) < \sigma_s(\sigma(j))$ such that $\Delta_{\sigma,i\leftrightarrow k} = 0$ or $\Delta_{\sigma,j\leftrightarrow k} = 0$. Since the sequence is contiguous on equally-attended positions, we have the strict inequality $b_{\sigma(i)} < b_{\sigma(k)} < b_{\sigma(j)}$. By Lemma 3, we have two cases.
Noting that $t_i = t_j$, we can write them as
1. $v_{t_k} = v_{t_i}$ and $a_{t_i} = a_{t_k}$;
2. $a_{t_i} \ne a_{t_k}$ and $s_\sigma = \dfrac{v_{t_i}e^{a_{t_i}} - v_{t_k}e^{a_{t_k}}}{e^{a_{t_i}} - e^{a_{t_k}}}$.

In the first case, we may freely interchange tokens equal to $t_i$ with tokens equal to $t_k$ without changing the score. We may therefore use the token value as a sorting tie-breaker and swap tokens until no needlessly non-contiguous triples falling into case (1) remain. In the second case, swapping tokens does not change $s_\sigma$, so the property continues to hold for these tokens after the swap. We may then swap tokens, again using the token value as a tie-breaker, until no needlessly non-contiguous triples falling into case (2) remain.

We can now make our argument for Theorem 4 precise.

Proof of Theorem 4. Using Lemma 5, choose $\sigma_{\max}$ and $\sigma_{\min}$ to be contiguous on equally-attended positions and not needlessly non-contiguous. Suppose that some $\sigma \in \{\sigma_{\max}, \sigma_{\min}\}$ has a contiguity violation. That is, there are $0 \le i, j, k < n_{\text{ctx}}$ with $i, j, k \notin \sigma^{-1}(F)$ and $t_i = t_j \ne t_k$ such that $b_{\sigma(i)} < b_{\sigma(k)} < b_{\sigma(j)}$. We will derive a contradiction with the presumption that $\sigma$ is extremal. To do so, we show that swapping $i$ and $k$ changes the score in one direction, while swapping $j$ and $k$ changes it in the other direction.

Take $\sigma'_0$ to be $\sigma$ with $i$ and $k$ swapped, and $\sigma'_1$ to be $\sigma$ with $j$ and $k$ swapped. We now consider the signs of the score differences $\Delta_0 := s_{\sigma'_0} - s_\sigma$ and $\Delta_1 := s_{\sigma'_1} - s_\sigma$. Since $\sigma$ is not needlessly non-contiguous, $\Delta_z \ne 0$ for $z \in \{0, 1\}$. Suppose we can show that the sign of $\Delta_0$ differs from the sign of $\Delta_1$. Then either $s_{\sigma'_0} < s_\sigma < s_{\sigma'_1}$ or $s_{\sigma'_1} < s_\sigma < s_{\sigma'_0}$, contradicting extremality. That is, the swaps $i \leftrightarrow k$ and $j \leftrightarrow k$ would give one lower and one higher score, so $\sigma$ would not be extremal.
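Noting that $t_i = t_j$, Lemma 3 gives
$$\operatorname{sign}(\Delta_0) = \operatorname{sign}\left(b_{\sigma(i)} - b_{\sigma(k)}\right)\cdot\begin{cases}\operatorname{sign}\left(v_{t_k} - v_{t_i}\right) & \text{if } a_{t_i} = a_{t_k}\\ \operatorname{sign}\left(a_{t_i} - a_{t_k}\right)\operatorname{sign}\left(s_\sigma - \dfrac{v_{t_i}e^{a_{t_i}} - v_{t_k}e^{a_{t_k}}}{e^{a_{t_i}} - e^{a_{t_k}}}\right) & \text{otherwise}\end{cases}$$
$$\operatorname{sign}(\Delta_1) = \operatorname{sign}\left(b_{\sigma(j)} - b_{\sigma(k)}\right)\cdot\begin{cases}\operatorname{sign}\left(v_{t_k} - v_{t_i}\right) & \text{if } a_{t_i} = a_{t_k}\\ \operatorname{sign}\left(a_{t_i} - a_{t_k}\right)\operatorname{sign}\left(s_\sigma - \dfrac{v_{t_i}e^{a_{t_i}} - v_{t_k}e^{a_{t_k}}}{e^{a_{t_i}} - e^{a_{t_k}}}\right) & \text{otherwise}\end{cases}$$
Three observations give the desired contradiction:

- each product is non-zero by presumption;
- the right-hand factor is the same for $\Delta_0$ and $\Delta_1$;
- $\operatorname{sign}(b_{\sigma(i)} - b_{\sigma(k)}) = -1$ while $\operatorname{sign}(b_{\sigma(j)} - b_{\sigma(k)}) = 1$.

Note that the proof of Theorem 4 does not go through if we include the position value function $w$ in the score, because we may trade off the position value function against the token value function. We now show that we can independently pessimize over positional attention.

Definition 10 (Full sequence score). Given a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}_{<d_{\text{vocab}}}$ and a permutation $\sigma : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{N}_{<n_{\text{ctx}}}$, define the full sequence score $s'_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma}$ as
$$s'_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma} := \frac{\sum_{0\le i<n_{\text{ctx}}} \left(v_{t_i} + w_{\sigma(i)}\right) e^{a_{t_i}+b_{\sigma(i)}}}{\sum_{0\le i<n_{\text{ctx}}} e^{a_{t_i}+b_{\sigma(i)}}}.$$
We drop the token subscript, writing only $s'_\sigma$, when the token values are unambiguous from context. The full sequence score computes $\Delta\ell^{\mathrm{EPVOU}}_{t^*} := \Delta\ell^{\mathrm{EVOU}}_{t^*} + \Delta\ell^{\mathrm{PVOU}}_{t^*}$ for some fixed $t^*$ and $t_{\max}$. As with Definition 3:

- high scores predict $t^*$ (and are thus bad);
- negative scores predict $t_{\max}$ (and are thus good);
- the more negative the score, the stronger the prediction of $t_{\max}$.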
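Lemma 3 and Theorem 4 can be spot-checked by brute force on small random instances. The sketch below is a hypothetical test harness, not the paper's code, built on these assumptions:

- It takes $F = \emptyset$.
- It places tokens directly at positions, so `arr[p]` plays the role of $t_i$ with $\sigma(i) = p$.
- The values $v$, $a$, $b$ are random reals, so exact ties (the needlessly non-contiguous case) have probability zero.

The harness compares the sign of every swap against Lemma 3's case analysis. It also confirms that contiguous orderings already attain the minimum and maximum of $s_\sigma$ over all orderings.

```python
import itertools
import math
import random

def score(arr, v, a, b):
    """Sequence score s_sigma (Definition 3) with token arr[p] at position p and F empty."""
    ws = [math.exp(a[t] + b[p]) for p, t in enumerate(arr)]
    return sum(v[t] * w for t, w in zip(arr, ws)) / sum(ws)

def sign(x):
    return (x > 0) - (x < 0)

def lemma3_sign(arr, p, q, v, a, b):
    """Sign of the score change from swapping the tokens at positions p and q, per Lemma 3."""
    ti, tj = arr[p], arr[q]
    if a[ti] == a[tj]:
        return -sign(b[p] - b[q]) * sign(v[ti] - v[tj])
    ei, ej = math.exp(a[ti]), math.exp(a[tj])
    pivot = (v[ti] * ei - v[tj] * ej) / (ei - ej)
    return sign(a[ti] - a[tj]) * sign(b[p] - b[q]) * sign(score(arr, v, a, b) - pivot)

def is_contiguous(arr, b):
    """Definition 7 with F empty: each token occupies one block of b-sorted positions."""
    ordered = [arr[p] for p in sorted(range(len(arr)), key=lambda p: b[p])]
    runs = [t for k, t in enumerate(ordered) if k == 0 or ordered[k - 1] != t]
    return len(runs) == len(set(runs))

def extremes(tokens, v, a, b, contiguous_only):
    """Min and max score over all distinct orderings (optionally only contiguous ones)."""
    scores = [score(arr, v, a, b) for arr in set(itertools.permutations(tokens))
              if not contiguous_only or is_contiguous(arr, b)]
    return min(scores), max(scores)

def random_instance(rng, n_types, n_ctx):
    v = [rng.uniform(-1, 1) for _ in range(n_types)]
    a = [rng.uniform(-2, 2) for _ in range(n_types)]
    b = [rng.uniform(-2, 2) for _ in range(n_ctx)]
    return v, a, b

rng = random.Random(0)
tokens = (0, 0, 1, 1, 2, 2)
for _ in range(20):
    v, a, b = random_instance(rng, 3, len(tokens))
    # Lemma 3: the predicted sign matches the actual score change for every swap.
    arr = list(tokens)
    rng.shuffle(arr)
    for p, q in itertools.combinations(range(len(arr)), 2):
        if arr[p] == arr[q]:
            continue
        swapped = list(arr)
        swapped[p], swapped[q] = swapped[q], swapped[p]
        actual = sign(score(swapped, v, a, b) - score(arr, v, a, b))
        assert actual == lemma3_sign(arr, p, q, v, a, b)
    # Theorem 4: contiguous orderings already attain the min and max score.
    lo_all, hi_all = extremes(tokens, v, a, b, contiguous_only=False)
    lo_c, hi_c = extremes(tokens, v, a, b, contiguous_only=True)
    assert abs(lo_all - lo_c) < 1e-12 and abs(hi_all - hi_c) < 1e-12
```

Restricting to contiguous orderings shrinks the search from 90 orderings to a handful. This is the saving that Theorem 4 licenses.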
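Definition 11 (Relaxed sequence score). Given a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}_{<d_{\text{vocab}}}$ and a permutation $\sigma : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{N}_{<n_{\text{ctx}}}$, define the relaxed sequence scores $r_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma,\min}$ and $r_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma,\max}$ as
$$r_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma,\min} := s_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma} + \min_{0\le i<n_{\text{ctx}}} w_i, \qquad r_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma,\max} := s_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma} + \max_{0\le i<n_{\text{ctx}}} w_i.$$
We drop the token subscript, writing only $r_{\sigma,\min}$ or $r_{\sigma,\max}$, when the token values are unambiguous from context.

Theorem 6 (Independent pessimization over positional contributions is possible). Fix non-decreasing sequences of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}$ and $t'_0 \le \cdots \le t'_{n_{\text{ctx}}-1} \in \mathbb{N}$ and permutations $\sigma, \sigma' : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{N}_{<n_{\text{ctx}}}$. Let $r_{\sigma,\min}$ and $r_{\sigma,\max}$ denote $r_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma,\min}$ and $r_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma,\max}$. Let $s_\sigma$ denote $s_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma}$. Let $s_{\sigma'}$ and $s'_{\sigma'}$ denote $s_{t'_0,\ldots,t'_{n_{\text{ctx}}-1},\sigma'}$ and $s'_{t'_0,\ldots,t'_{n_{\text{ctx}}-1},\sigma'}$. Then
$$\min_{0\le i<n_{\text{ctx}}} w_i = r_{\sigma,\min} - s_\sigma \le s'_{\sigma'} - s_{\sigma'} \le r_{\sigma,\max} - s_\sigma = \max_{0\le i<n_{\text{ctx}}} w_i.$$
That is, take the difference between the relaxed sequence score and the sequence score of any given sequence. It always bounds the difference between the full sequence score and the sequence score of any other (related or unrelated) sequence.

Proof. This follows straightforwardly from the softmax weighting being an affine (indeed convex) weighting: the weights are non-negative and sum to one. We have
$$r_{\sigma,\min} - s_\sigma = \min_{0\le i<n_{\text{ctx}}} w_{\sigma(i)} = \min_{0\le i<n_{\text{ctx}}} w_i, \qquad r_{\sigma,\max} - s_\sigma = \max_{0\le i<n_{\text{ctx}}} w_{\sigma(i)} = \max_{0\le i<n_{\text{ctx}}} w_i,$$
$$s'_{\sigma'} - s_{\sigma'} = \frac{\sum_{0\le i<n_{\text{ctx}}} w_{\sigma'(i)} e^{a_{t'_i}+b_{\sigma'(i)}}}{\sum_{0\le i<n_{\text{ctx}}} e^{a_{t'_i}+b_{\sigma'(i)}}} = \sum_{0\le i<n_{\text{ctx}}} w_{\sigma'(i)} \frac{e^{a_{t'_i}+b_{\sigma'(i)}}}{\sum_{0\le j<n_{\text{ctx}}} e^{a_{t'_j}+b_{\sigma'(j)}}}.$$
Since $e^x$ is positive for all real $x$, we have
$$\min_{0\le j<n_{\text{ctx}}} w_{\sigma'(j)}\,\frac{\sum_{0\le i<n_{\text{ctx}}} e^{a_{t'_i}+b_{\sigma'(i)}}}{\sum_{0\le j<n_{\text{ctx}}} e^{a_{t'_j}+b_{\sigma'(j)}}} \le s'_{\sigma'} - s_{\sigma'} \le \max_{0\le j<n_{\text{ctx}}} w_{\sigma'(j)}\,\frac{\sum_{0\le i<n_{\text{ctx}}} e^{a_{t'_i}+b_{\sigma'(i)}}}{\sum_{0\le j<n_{\text{ctx}}} e^{a_{t'_j}+b_{\sigma'(j)}}}.$$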
Since the ratio of sums on both sides equals 1 and $\sigma'$ is a permutation, we get, as desired,
$$\min_{0\le i<n_{\text{ctx}}} w_i \le s'_{\sigma'} - s_{\sigma'} \le \max_{0\le i<n_{\text{ctx}}} w_i.$$

We could prove a more fine-grained theorem that pessimizes over attention paid to positions only for sequences compatible with the chosen fixed tokens $F$ and $t_F$. Since the positional contribution is so small, we do not bother.

Theorem 7 (For a fixed ordering, softmax is convex over token counts and only pure sequences need be considered). Fix a set of fixed indices $F \subseteq \mathbb{N}_{<n_{\text{ctx}}}$ and an assignment of token values to each of the fixed positions $t_F : F \to \mathbb{N}_{<d_{\text{vocab}}}$. Fix a set $S \subseteq \mathbb{N}_{<d_{\text{vocab}}}$ of valid other tokens in the sequence. (In our uses of this theorem, $S$ will be the largest subset of $\mathbb{N}_{<t_{\max}}$ for which we can guarantee correct model behavior on all sequences compatible with $F$ and $t_{\max}$ and with tokens otherwise drawn from $S$.) Define a comparison on non-negative integers less than $d_{\text{vocab}}$:
$$c := \sum_{i\in F} v_{t_F(i)} e^{a_{t_F(i)} + b_i}, \qquad d := \sum_{i\in F} e^{a_{t_F(i)} + b_i}, \qquad f := \sum_{\substack{0\le i<n_{\text{ctx}}\\ i\notin F}} e^{b_i},$$
$$\operatorname{cmp}(x, y) := \operatorname{sign}\left(d\left(e^{a_x}v_x - e^{a_y}v_y\right) - c\left(e^{a_x} - e^{a_y}\right) + f\left(v_x e^{a_x+a_y} - v_y e^{a_x+a_y}\right)\right).$$
Let $t_{\mathrm{cmp}\min}$ and $t_{\mathrm{cmp}\max}$ be the minimum and maximum elements of $S$ according to $\operatorname{cmp}$.²¹ Fix a non-decreasing sequence of tokens $t_0 \le \cdots \le t_{n_{\text{ctx}}-1} \in \mathbb{N}$ compatible with $F$ and $S$. Fix a permutation $\sigma : \mathbb{N} \to \mathbb{N}$ of the $n_{\text{ctx}}$ positions fixing $F$, so that $t_i = t_F(\sigma(i))$ for $\sigma(i) \in F$ and $t_i \in S$ for $\sigma(i) \notin F$. Let $s_{\sigma,\min}$ denote $s_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma}$ with $t_i = t_{\mathrm{cmp}\min}$ for all $\sigma(i) \notin F$. Let $s_{\sigma,\max}$ denote the same score with $t_i = t_{\mathrm{cmp}\max}$ instead. Then, for all such choices of sequence-permutation pairs,
$$s_{\sigma,\min} \le s_{t_0,\ldots,t_{n_{\text{ctx}}-1},\sigma} \le s_{\sigma,\max}.$$

²¹ We will prove that $\operatorname{cmp}$ is transitive in the process of proving this theorem.

This theorem follows by chaining two lemmas: that scores are extremized by pure sequences, and that the extremal pure sequences match the comparison function defined in the theorem statement.

Lemma 8 (Sequence scores are extremized on purer sequences). Fix all the same quantities as in Theorem 7.
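Take any indices $0 \le i < j < n_{\text{ctx}}$ and token values $x, y \in S$. The score for a sequence with $t_i = x \ne y = t_j$ is bounded on both sides by the scores of the sequences with $t_i = t_j = x$ and with $t_i = t_j = y$.

Proof. Let $s_{\alpha,\beta}$ be the sequence score with $t_i = \alpha$ and $t_j = \beta$. Define the score differences $\Delta_x := s_{x,x} - s_{x,y}$ and $\Delta_y := s_{y,y} - s_{x,y}$. It suffices to show that $\operatorname{sign}(\Delta_x \Delta_y) \le 0$. For this, we need only compute the sign of $\Delta_\alpha$ for $\alpha \in \{x, y\}$ and show that, whenever both $\Delta_x$ and $\Delta_y$ are non-zero, they have opposite signs. We proceed by computation after defining some convenience variables:
$$C := \sum_{\substack{0\le k<n_{\text{ctx}}\\ k\ne i,j}} v_{t_k} e^{a_{t_k}+b_{\sigma(k)}}, \qquad D := \sum_{\substack{0\le k<n_{\text{ctx}}\\ k\ne i,j}} e^{a_{t_k}+b_{\sigma(k)}},$$
$$\tilde\alpha := \begin{cases} y & \text{if } \alpha = x\\ x & \text{if } \alpha = y\end{cases}, \qquad i_\alpha := \begin{cases} i & \text{if } \alpha = x\\ j & \text{if } \alpha = y\end{cases}, \qquad i_{\tilde\alpha} := \begin{cases} i & \text{if } \tilde\alpha = x\\ j & \text{if } \tilde\alpha = y.\end{cases}$$
Then
$$\begin{aligned}
\operatorname{sign}(\Delta_\alpha) &= \operatorname{sign}\left(\frac{v_\alpha e^{a_\alpha + b_{\sigma(i)}} + v_\alpha e^{a_\alpha + b_{\sigma(j)}} + C}{e^{a_\alpha + b_{\sigma(i)}} + e^{a_\alpha + b_{\sigma(j)}} + D} - \frac{v_x e^{a_x + b_{\sigma(i)}} + v_y e^{a_y + b_{\sigma(j)}} + C}{e^{a_x + b_{\sigma(i)}} + e^{a_y + b_{\sigma(j)}} + D}\right)\\
&= \operatorname{sign}\left(\frac{v_\alpha e^{a_\alpha + b_{\sigma(i_\alpha)}} + v_\alpha e^{a_\alpha + b_{\sigma(i_{\tilde\alpha})}} + C}{e^{a_\alpha + b_{\sigma(i_\alpha)}} + e^{a_\alpha + b_{\sigma(i_{\tilde\alpha})}} + D} - \frac{v_\alpha e^{a_\alpha + b_{\sigma(i_\alpha)}} + v_{\tilde\alpha} e^{a_{\tilde\alpha} + b_{\sigma(i_{\tilde\alpha})}} + C}{e^{a_\alpha + b_{\sigma(i_\alpha)}} + e^{a_{\tilde\alpha} + b_{\sigma(i_{\tilde\alpha})}} + D}\right).
\end{aligned}$$
Multiplying through by the positive denominators and simplifying:
$$\begin{aligned}
&= \operatorname{sign}\Big(C\left(e^{a_{\tilde\alpha} + b_{\sigma(i_{\tilde\alpha})}} - e^{b_{\sigma(i_{\tilde\alpha})} + a_\alpha}\right) + D\left(v_\alpha e^{b_{\sigma(i_{\tilde\alpha})} + a_\alpha} - v_{\tilde\alpha} e^{a_{\tilde\alpha} + b_{\sigma(i_{\tilde\alpha})}}\right)\\
&\qquad + v_\alpha\left(e^{b_{\sigma(i_{\tilde\alpha})}} + e^{b_{\sigma(i_\alpha)}}\right)e^{a_{\tilde\alpha} + b_{\sigma(i_{\tilde\alpha})} + a_\alpha} - v_{\tilde\alpha}\left(e^{b_{\sigma(i_{\tilde\alpha})}} + e^{b_{\sigma(i_\alpha)}}\right)e^{a_{\tilde\alpha} + b_{\sigma(i_{\tilde\alpha})} + a_\alpha}\Big)
\end{aligned}$$
Pulling out the positive factor $e^{b_{\sigma(i_{\tilde\alpha})}}$:
$$= \operatorname{sign}\left(e^{a_{\tilde\alpha} + a_\alpha}\left(e^{b_{\sigma(i_{\tilde\alpha})}} + e^{b_{\sigma(i_\alpha)}}\right)\left(v_\alpha - v_{\tilde\alpha}\right) + C\left(e^{a_{\tilde\alpha}} - e^{a_\alpha}\right) + D\left(e^{a_\alpha}v_\alpha - e^{a_{\tilde\alpha}}v_{\tilde\alpha}\right)\right).$$
Swapping $\alpha$ and $\tilde\alpha$ negates the argument of the sign. Hence $\operatorname{sign}(\Delta_x) = -\operatorname{sign}(\Delta_y)$, and so either $s_{x,x} \le s_{x,y} \le s_{y,y}$ or $s_{y,y} \le s_{x,y} \le s_{x,x}$, as desired.

Lemma 9 (Pure sequences are sorted according to $\operatorname{cmp}$ in Theorem 7). Fix all the same quantities as in Theorem 7. Fix tokens $x, y \in S$. Let $n := n_{\text{ctx}} - |F|$ be the number of non-fixed tokens.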
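Fix sequences with $n$ copies of $x$ and of $y$, respectively, as follows:

- Fix $t_{x,0} \le \cdots \le t_{x,n_{\text{ctx}}-1} \in \mathbb{N}$ and $t_{y,0} \le \cdots \le t_{y,n_{\text{ctx}}-1} \in \mathbb{N}$ compatible with $F$ and $S$.
- Fix permutations $\sigma_x, \sigma_y : \mathbb{N} \to \mathbb{N}$ of the $n_{\text{ctx}}$ positions fixing $F$: $t_{x,i} = t_F(\sigma_x(i))$ for $\sigma_x(i) \in F$, and $t_{y,i} = t_F(\sigma_y(i))$ for $\sigma_y(i) \in F$.
- Set $t_{x,i} = x$ for $\sigma_x(i) \notin F$, and $t_{y,i} = y$ for $\sigma_y(i) \notin F$.

Then
$$\operatorname{sign}\left(s_{\sigma_x,t_{x,0},\ldots,t_{x,n_{\text{ctx}}-1}} - s_{\sigma_y,t_{y,0},\ldots,t_{y,n_{\text{ctx}}-1}}\right) = \operatorname{cmp}(x, y).$$

$$\begin{aligned}
\mathrm{EQKE}(t_{-1}, t_i) &:= t_{-1} E_q Q K^T \bar{E}^T t_i^T / \sqrt{d}\\
\mathrm{EQKP}(t_{-1}, i) &:= t_{-1} E_q Q K^T \hat{P}_i^T / \sqrt{d}\\
\mathrm{EVOU}(t_i) &:= t_i \bar{E} V O U\\
\mathrm{PVOU}(i) &:= \hat{P}_i V O U\\
\ell^{\mathrm{EU}}(t_{-1}) &:= t_{-1} E_q U\\
\Delta\ell^{\mathrm{EU}}_{t^*}(t_{-1}, \max_i t_i) &:= \ell^{\mathrm{EU}}(t_{-1})_{t^*} - \ell^{\mathrm{EU}}(t_{-1})_{\max_i t_i}
\end{aligned}$$
Figure 13: Recapitulation of some relevant definitions from Figure 12, parameterized by the arguments they actually depend on.

Proof. The proof goes by straightforward computation:
$$\operatorname{sign}\left(s_{\sigma_x,t_{x,0},\ldots,t_{x,n_{\text{ctx}}-1}} - s_{\sigma_y,t_{y,0},\ldots,t_{y,n_{\text{ctx}}-1}}\right) = \operatorname{sign}\left(\frac{v_x e^{a_x} f + c}{e^{a_x} f + d} - \frac{v_y e^{a_y} f + c}{e^{a_y} f + d}\right).$$
Multiplying through by the positive denominators:
$$\begin{aligned}
&= \operatorname{sign}\left(\left(v_x e^{a_x} f + c\right)\left(e^{a_y} f + d\right) - \left(v_y e^{a_y} f + c\right)\left(e^{a_x} f + d\right)\right)\\
&= \operatorname{sign}\left(-cfe^{a_x} + cfe^{a_y} + dfv_xe^{a_x} - dfv_ye^{a_y} + f^2v_xe^{a_x+a_y} - f^2v_ye^{a_x+a_y}\right).
\end{aligned}$$
Dividing by $f > 0$:
$$\begin{aligned}
&= \operatorname{sign}\left(-ce^{a_x} + ce^{a_y} + dv_xe^{a_x} - dv_ye^{a_y} + fv_xe^{a_x+a_y} - fv_ye^{a_x+a_y}\right)\\
&= \operatorname{sign}\left(c\left(e^{a_y} - e^{a_x}\right) + d\left(v_xe^{a_x} - v_ye^{a_y}\right) + f\left(v_xe^{a_x+a_y} - v_ye^{a_x+a_y}\right)\right)\\
&= \operatorname{cmp}(x, y).
\end{aligned}$$

Corollary 10. Define the relation $\le_{\mathrm{cmp}}$ by $x \le_{\mathrm{cmp}} y$ if and only if $\operatorname{cmp}(x, y) \in \{-1, 0\}$. The relation $\le_{\mathrm{cmp}}$ is always transitive.

Proof. By Lemma 9, $\operatorname{cmp}$ compares two sequence scores. Since $\le$ is transitive over the reals, the relation $\le_{\mathrm{cmp}}$ is also transitive.

Finally, we combine the previous lemmas to complete our proof of Theorem 7.

Proof of Theorem 7. Extremal sequences with scores $s_{\sigma,\min}$ and $s_{\sigma,\max}$ are guaranteed to exist because $S$ has only finitely many elements, so there are only finitely many sequences. By Lemma 8, the extremal sequences may be taken to be pure (having $t_i = t_j$ whenever $\sigma(i), \sigma(j) \notin F$).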
By Lemma 9, the extremal sequences must have tokens that are extremal according to $\operatorname{cmp}$.

We now have all the tools necessary to prove the following theorem. For a definition of the CUBIC algorithm, see Algorithm 3 and Algorithm 4, or the proof of Theorem 11.

Theorem 11.
$$\mathbb{E}_{\mathbf{t}\sim U(\{0,1,\ldots,d_{\text{vocab}}-1\})^{n_{\text{ctx}}}}\left[\operatorname{argmax}_i (M(\mathbf{t}))_i = \max_i t_i\right] \ge \mathrm{CUBIC}(d_{\text{vocab}}, n_{\text{ctx}}, M)$$

Before we give the proof of this theorem, we introduce some helpful notation.

Definition 12. Fix an element $(r_m, r_q, c) \in \{0, \ldots, d_{\text{vocab}}\}^2 \times \{0, \ldots, 3\}$ such that $r_m \ge r_q$. We define $X_{(r_m,r_q,c)}$ to be the set of sequences $\mathbf{t}$ such that
1. the max token $t_{\max}$ is equal to $r_m$,
2. the query token $t_{\text{query}}$ is equal to $r_q$, and
3. the number of tokens that are not at the query position and not equal to $t_{\max}$ is equal to $c$.

For clarity, we list all the possible cases (here with $n_{\text{ctx}} = 4$). We always take $t_{\text{query}} \le t_{\max}$. We let $S_3$ act on sequences by permuting the first three factors (i.e., keeping the query position fixed).
1. If $c = 0$, then $X_{(t_{\max},t_{\text{query}},0)} = \{[t_{\max}, t_{\max}, t_{\max}, t_{\text{query}}]\}$.
2. If $c = 1$, then $X_{(t_{\max},t_{\text{query}},1)} = S_3 \cdot \{[t_1, t_{\max}, t_{\max}, t_{\text{query}}] \mid t_1 < t_{\max}\}$.
3. If $c = 2$, then $X_{(t_{\max},t_{\text{query}},2)} = S_3 \cdot \{[t_1, t_2, t_{\max}, t_{\text{query}}] \mid t_i < t_{\max}\}$.
4. If $c = 3$, then $X_{(t_{\max},t_{\max},3)} = S_3 \cdot \{[t_1, t_2, t_3, t_{\max}] \mid t_i < t_{\max}\}$.

Algorithm 2: Counting Correct Sequences in Cubic Time: Preliminaries

```text
function CORRECTNESS(M, input-sequence)
    ▷ the model is correct iff every incorrect logit is below the max token's logit
    return MODEL-BEHAVIOR(M, input-sequence) < 0
end function

function MODEL-BEHAVIOR(M, input-sequence)
    Require: input-sequence is a tensor of shape (n_ctx,) with values in N_{<d_vocab}
    t_max ← MAX(input-sequence)                                  ▷ t_max ← max-token
    t ← input-sequence
    skip-score[t*] ← Δℓ^EU_{t*}(t[n_ctx − 1], t_max)
    attn-weights-unscaled[i] ← EQKE(t[n_ctx − 1], t[i]) + EQKP(t[n_ctx − 1], i)
    attn-weights ← SOFTMAX(attn-weights-unscaled / √d)
    v[t] ← EVOU(t)
    w[i] ← PVOU(i)
    Δv[t, t*] ← v[t, t*] − v[t, t_max]
    Δw[i, t*] ← w[i, t*] − w[i, t_max]
    return max_{t* ≠ t_max} (skip-score[t*] + Σ_{i=0}^{n_ctx−1} (Δv[t[i], t*] + Δw[i, t*]) · attn-weights[i])
end function

function CORRECTNESS-PESSIMIZING-OVER-POSITION-SLOW(M, input-sequence)
    t ← input-sequence
    return ALL(CORRECTNESS(M, perm + [t[−1]]) for all perm ∈ PERMUTATIONS(t[0:−1]))
end function
```

Definition 13. Let $\mathbf{t} \in X$ be a sequence. We say $\mathbf{t}$ is pure if it has at most three distinct tokens: the max token $t_{\max}$, the query token $t_{\text{query}}$, and optionally a third token $t^* < t_{\max}$. We denote by $X^{\text{pure}}$ the subset of pure sequences. For any subset $Y \subset X$, we set $Y^{\text{pure}} := Y \cap X^{\text{pure}}$.

We now come to the proof of Theorem 11. We will show how to use the previous theorems to get explicit bounds, and explain how $\mathrm{CUBIC}(d_{\text{vocab}}, n_{\text{ctx}}, M)$ computes these bounds.

Proof of Theorem 11. First, we note that the algorithm $\mathrm{CUBIC} = \mathrm{CUBIC}(d_{\text{vocab}}, n_{\text{ctx}}, M)$ yields a lower bound for the accuracy on each set $X_{(t_{\max},t_{\text{query}},c)}$. Write $X = \coprod_{(t_{\max},t_{\text{query}},c)} X_{(t_{\max},t_{\text{query}},c)}$. We can therefore compute the bound on $X$ by computing it for each such choice of $(t_{\max}, t_{\text{query}}, c)$ and summing:
$$\mathbb{E}_{\mathbf{t}\sim U(\{0,1,\ldots,d_{\text{vocab}}-1\})^{n_{\text{ctx}}}}\left[\operatorname{argmax}_i (M(\mathbf{t}))_i = \max_i t_i\right] \ge \sum_{(t_{\max},t_{\text{query}},c)} \mathrm{CUBIC}\left(X_{(t_{\max},t_{\text{query}},c)}\right).$$
So from now on we fix one such subset $X_{(t_{\max},t_{\text{query}},c)}$. We begin by defining a map $f : X_{(t_{\max},t_{\text{query}},c)} \to \{0, \ldots, d_{\text{vocab}}\}^c$. It sends a sequence to the subsequence of elements that are not at the query position and not equal to $t_{\max}$. Then Theorem 7 can be restated as follows:²²

²² In fact, the theorem yields a stronger result, but we will only need the following formulation.

Algorithm 3: Counting Correct Sequences in Cubic Time, Part I. Lines are annotated with comments indicating the parameters for a cache to avoid duplicate computations.
```text
1:  function MODEL-BEHAVIOR-RELAXED(M, query-tok, max-tok, non-max-tok, n-copies-nonmax)
2:      t_query ← query-tok, t_max ← max-tok, t' ← non-max-tok, c ← n-copies-nonmax
        Require: 0 ≤ t_query ≤ t_max < d_vocab, 0 ≤ t' ≤ t_max < d_vocab, 0 ≤ c < n_ctx
        Require: if n-copies-nonmax = 0 then non-max-tok = max-tok
        Require: if query-tok ≠ max-tok then n-copies-nonmax < n_ctx − 1
        Ensure: return ≥ MODEL-BEHAVIOR(M, t) for all t with the specified t_query, c copies
                of t' in non-query positions, and the remainder of the tokens equal to t_max
3:      skip-score[t*] ← Δℓ^EU_{t*}(t_query, t_max)                        ▷ Cache by t_max, t_query, t*
4:      w[i] ← PVOU(i) for 0 ≤ i < n_ctx                                    ▷ Cache by i
5:      Δw[max, t*] ← max_{0≤i<n_ctx} (w[i, t*] − w[i, t_max])              ▷ Cache by t_max, t*
6:      v[t] ← EVOU(t), Δv[t, t*] ← v[t, t*] − v[t, t_max] for t ∈ {t_query, t_max, t'}
                                                                            ▷ Cache by t_max, t, t*
7:      a[t] ← EQKE(t_query, t)/√d for t ∈ {t_query, t_max, t'}            ▷ Cache by t_query, t
8:      b[n_ctx − 1] ← EQKP(t_query, n_ctx − 1)/√d                          ▷ Cache by t_query
9:      b[0, :−1] ← SORT(EQKP(t_query, :−1))/√d                             ▷ Cache by t_query, i
10:     b[1, :−1] ← REVERSE(b[0, :−1])
11:     attn-weights-unscaled[:, n_ctx − 1] ← a[t_query] + b[n_ctx − 1]    ▷ Cache by t_query
12:     attn-weights-unscaled[0, i] ← a[t_max] + b[0, i] for 0 ≤ i < n_ctx − c − 1
                                                                            ▷ Cache by t_max, c, i, t_query
13:     attn-weights-unscaled[1, i] ← a[t_max] + b[1, i] for 0 ≤ i < n_ctx − c − 1
                                                                            ▷ Cache by t_max, c, i, t_query
14:     attn-weights-unscaled[0, i] ← a[t'] + b[0, i] for n_ctx − c − 1 ≤ i < n_ctx − 1
                                                                            ▷ Cache by t', c, i, t_query
15:     attn-weights-unscaled[1, i] ← a[t'] + b[1, i] for n_ctx − c − 1 ≤ i < n_ctx − 1
                                                                            ▷ Cache by t', c, i, t_query
16:     attn-weights[0] ← SOFTMAX(attn-weights-unscaled[0])                 ▷ Cache by t_max, t', c, i, t_query
17:     attn-weights[1] ← SOFTMAX(attn-weights-unscaled[1])                 ▷ Cache by t_max, t', c, i, t_query
18:     if c = 0 then    ▷ in this case attn-weights[0, i] = attn-weights[1, i], so we drop the first index
19:         return max_{t* ≠ t_max} (skip-score[t*] + Δw[max, t*] + Δv[t_{−1}, t*] · attn-weights[−1]
                                      + Δv[t_max, t*] · Σ_{i=0}^{n_ctx−2} attn-weights[i])
20:     else
21:         Δv[i, t*] ← Δv[t_max, t*] for 0 ≤ i < n_ctx − c − 1
22:         Δv[i, t*] ← Δv[t', t*] for n_ctx − c − 1 ≤ i < n_ctx − 1
23:         Δv[n_ctx − 1, t*] ← Δv[t_query, t*]
24:         return max_{t* ≠ t_max} skip-score[t*]
                   + max(Σ_{i=0}^{n_ctx−1} max_{t* ≠ t_max} (Δw[max, t*] + Δv[i, t*]) · attn-weights[0, i],
                         Σ_{i=0}^{n_ctx−1} max_{t* ≠ t_max} (Δw[max, t*] + Δv[i, t*]) · attn-weights[1, i])
25:     end if
26: end function
27: function RELAXED-CORRECTNESS-PESSIMIZING-OVER-POSITION(M, t_query, t_max, t', c)
28:     ▷ runs the model on a relaxed variant of input sequences compatible with the arguments
        Ensure: return is False if CORRECTNESS-PESSIMIZING-OVER-POSITION-SLOW(M, t) is False
                for any t with the specified t_query, c copies of t' in non-query positions,
                and the remainder of the tokens equal to t_max
29:     return MODEL-BEHAVIOR-RELAXED(M, t_query, t_max, t', c) < 0
30: end function
```

Algorithm 4: Counting Correct Sequences in Cubic Time, Part II

```text
1:  function CUBIC(d_vocab, n_ctx, M)
2:      count ← 0                                                   ▷ # of correct sequences
3:      for t_max ∈ RANGE(d_vocab) do                               ▷ t_max ← max-token
4:          for 0 ≤ t_query ≤ t_max do                              ▷ t_query ← query-token
5:              c_max ← n_ctx − 1 if t_query = t_max else n_ctx − 2  ▷ maximum copies of nonmax
6:              for 0 ≤ c ≤ c_max do                                ▷ c ← number of copies of the non-max token
7:                  RCPOP(χ⃗) ← RELAXED-CORRECTNESS-PESSIMIZING-OVER-POSITION(M, χ⃗)
8:                  if c = 0 then
9:                      t-count ← 1 if RCPOP(t_query, t_max, t_max, 0) else 0
10:                 else
11:                     t-count ← Σ_{t'=0}^{t_max−1} (1 if RCPOP(t_query, t_max, t', c) else 0)
12:                 end if
13:                 count ← count + binom(n_ctx − 1, c) · (t-count)^c   ▷ taking 0^0 = 0 conventionally
14:             end for
15:         end for
16:     end for
17:     return count · 1/d_vocab^{n_ctx}
18: end function
```

Let $S \subset \{0, \ldots, d_{\text{vocab}}\}$, and write $S^c$ for the $c$-fold Cartesian product of $S$. Then full accuracy on $f^{-1}(S^c)^{\text{pure}} := X^{\text{pure}}_{(t_{\max},t_{\text{query}},c)} \cap f^{-1}(S^c)$ implies full accuracy on $f^{-1}(S^c)$.

Now, instead of computing the output of the model for every element of $f^{-1}(S^c)^{\text{pure}}$, we use Theorem 4 (combined with Theorem 6) to run a relaxed version of this computation. In particular, we may assume that the pure sequence is contiguous on equal tokens. Here, contiguous on equal tokens refers to the positional part of the attention (i.e., the $\mathrm{EQKP}$ part). It means that either $b_{t_{\max}} < b_i, b_j$ or $b_{t_{\max}} > b_i, b_j$, where $i, j \in \{0, \ldots, n_{\text{ctx}} - 1\}$ are indices of tokens not equal to $t_{\max}$.
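The counting loop of Algorithm 4 can be sketched in Python. This is a minimal illustration under stated assumptions, not the paper's implementation:

- `rcpop` is a hypothetical callable standing in for RELAXED-CORRECTNESS-PESSIMIZING-OVER-POSITION.
- The toy criterion used to exercise it is invented for this example. Under it, a sequence is correct exactly when every non-max, non-query token individually passes. The count should therefore match brute-force enumeration over all $d_{\text{vocab}}^{n_{\text{ctx}}}$ sequences exactly.

Exact `Fraction` arithmetic avoids any rounding in the comparison.

```python
import itertools
from fractions import Fraction
from math import comb

def cubic_accuracy_bound(d_vocab, n_ctx, rcpop):
    """Counting loop of Algorithm 4; rcpop(t_query, t_max, t_prime, c) -> bool (hypothetical)."""
    count = 0
    for t_max in range(d_vocab):
        for t_query in range(t_max + 1):
            # Unless the query is t_max, at least one non-query slot must hold t_max.
            c_max = n_ctx - 1 if t_query == t_max else n_ctx - 2
            for c in range(c_max + 1):
                if c == 0:
                    t_count = 1 if rcpop(t_query, t_max, t_max, 0) else 0
                else:
                    t_count = sum(1 for t_prime in range(t_max)
                                  if rcpop(t_query, t_max, t_prime, c))
                # Python evaluates 0 ** 0 as 1, so apply the 0^0 = 0 convention explicitly.
                term = t_count ** c if t_count > 0 else 0
                count += comb(n_ctx - 1, c) * term
    return Fraction(count, d_vocab ** n_ctx)

def brute_force_accuracy(d_vocab, n_ctx, is_correct):
    """Exact fraction of all d_vocab ** n_ctx sequences accepted by is_correct."""
    seqs = itertools.product(range(d_vocab), repeat=n_ctx)
    return Fraction(sum(1 for s in seqs if is_correct(s)), d_vocab ** n_ctx)

# Hypothetical toy criterion: the "model" is right iff every non-query token
# that is not the max is at least 2 below the max (the query token is unconstrained).
def toy_rcpop(t_query, t_max, t_prime, c):
    return c == 0 or t_max - t_prime >= 2

def toy_is_correct(seq):
    t_max = max(seq)
    return all(t == t_max or t <= t_max - 2 for t in seq[:-1])

assert cubic_accuracy_bound(5, 4, toy_rcpop) == brute_force_accuracy(5, 4, toy_is_correct)
# With an always-correct oracle, the (t_max, t_query, c) cases cover every sequence exactly once.
assert cubic_accuracy_bound(5, 4, lambda *args: True) == 1
```

The loop makes $O(d_{\text{vocab}}^3 n_{\text{ctx}})$ oracle calls, while the brute-force check enumerates all $d_{\text{vocab}}^{n_{\text{ctx}}}$ sequences. Only the loop remains feasible at realistic sizes.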
For the algorithm $\mathrm{CUBIC}(d_{\text{vocab}}, n_{\text{ctx}}, M)$, we fix a non-max token $t' \in \{0, \ldots, t_{\max} - 1\}$ (unless $c = 0$, in which case there is no such choice). We then run the relaxed correctness computation $\mathrm{RCPOP}(t_{\text{query}}, t_{\max}, t', c)$ as described in Theorem 6. If it succeeds (that is, if $\text{MODEL-BEHAVIOR-RELAXED}(M, t_{\text{query}}, t_{\max}, t', c) < 0$), we add $t'$ to $S$. By construction of $S$, we therefore get full accuracy on $f^{-1}(S^c)^{\text{pure}}$, and hence on $f^{-1}(S^c)$. We then count the cardinality of $f^{-1}(S^c)$ and add it to the count of correct sequences.

Theorem 12. The running time of Algorithm 3, after using caching to avoid duplicate computations, is $O(d_{\text{vocab}}^3 n_{\text{ctx}}^2)$.

Proof. The nested loops in CUBIC execute the innermost body $O(d_{\text{vocab}}^2 n_{\text{ctx}})$ times, and the summation on Line 13 costs $O(n_{\text{ctx}})$ per iteration. It remains to bound the cost of the call $\text{RELAXED-CORRECTNESS-PESSIMIZING-OVER-POSITION}(M, t_{\text{query}}, t_{\max}, t', c)$. We show it costs $O(n_{\text{ctx}})$ when $c \ne 0$, and at most $O(d_{\text{vocab}} n_{\text{ctx}})$ when $c = 0$ and $t' = t_{\max}$.

The matrix multiplications in $\mathrm{EQKE}$, $\mathrm{EQKP}$, $\mathrm{EVOU}$, $\mathrm{PVOU}$, and $\ell^{\mathrm{EU}}$ can be cached up front. This costs $O(\max(d_{\text{vocab}}, d_{\text{model}}, n_{\text{ctx}})^2 d_{\text{model}}) \le O(d_{\text{vocab}}^3)$, since we assumed $d_{\text{vocab}} > d_{\text{model}}$ and $d_{\text{vocab}} > n_{\text{ctx}}$. The sorting on Line 9 can also be cached up front (per $t_{\text{query}}$), costing $O(d_{\text{vocab}} n_{\text{ctx}} \log n_{\text{ctx}})$. Each variable assignment in RELAXED-CORRECTNESS-PESSIMIZING-OVER-POSITION can be cached into a table. Each such table is parameterized over at most three variables that range over $d_{\text{vocab}}$ and at most two variables that range over $n_{\text{ctx}}$. What remains is the return statements. When $c = 0$, we have on Line 19:
$$\text{return } \max_{t^*\ne t_{\max}}\left(\text{skip-score}_{t^*} + \Delta w_{\max,t^*} + \Delta v_{t_{-1},t^*}\,\text{attn-weights}_{-1} + \Delta v_{t_{\max},t^*}\sum_{i=0}^{n_{\text{ctx}}-2}\text{attn-weights}_i\right).$$
This is $O(d_{\text{vocab}} n_{\text{ctx}})$, as desired.
When $c \neq 0$, we have on Line 24:
$$\text{return } \max_{t^* \neq t_{\max}} \text{skip-score}_{t^*} + \max\left( \sum_{i=0}^{n_{\text{ctx}}-1} \max_{t^* \neq t_{\max}} \left(\Delta w_{\max,t^*} + \Delta v_{i,t^*}\right) \cdot \text{attn-weights}_{0,i},\ \sum_{i=0}^{n_{\text{ctx}}-1} \max_{t^* \neq t_{\max}} \left(\Delta w_{\max,t^*} + \Delta v_{i,t^*}\right) \cdot \text{attn-weights}_{1,i} \right)$$
We can cache $\max_{t^* \neq t_{\max}} \text{skip-score}_{t^*}$ per $t_{\max}$ and $t_{\text{query}}$, costing $O(d_{\text{vocab}}^3 n_{\text{ctx}})$. We can cache $\max_{t^* \neq t_{\max}} (\Delta w_{\max,t^*} + \Delta v_{i,t^*})$ per $t_{\max}$ and $t'$, costing $O(d_{\text{vocab}}^3)$, since each $\Delta v_{i,t^*}$ will be $\Delta v_{t,t^*}$ for some $t \in \{t_{\text{query}}, t_{\max}, t'\}$. Finally, we can compute the summation at cost $O(n_{\text{ctx}})$ per loop iteration, as required.

F Quadratic counting for a sub-cubic proof

In this section we fill in the details omitted from Section 4.3. In Appendix E we proved an intricate version of convexity of softmax in which, modulo pessimizing in unrealistic ways over the attention paid to positions for the computation done on positional encodings, all extremal relaxed sequences correspond to actual sequences. When we only have a budget of $O(d_{\text{vocab}}^2 n_{\text{ctx}})$ extremal relaxed cases to consider, though, we must pessimize more, which gives us a simpler version of the convexity theorem and proof. Notably, when we restrict our sequences to have only two tokens (the max token $t_{\max}$ and the non-max token $t'$), most of the theorems from Appendix E.3 become significantly simpler.

Additionally, we must pessimize separately over the token value ($v$) and token attention ($a$) computations in order to allow efficient computation (Theorem 15).

F.1 Proof of baseline sub-cubic result

For this subsection, all theorems are parameterized over the following quantities.

Definition 14 (Common theorem parameters). Fix a total number of tokens $n_{\text{ctx}}$. Fix a token value function (à la a row-difference in EVOU) $v : \mathbb{N}_{<d_{\text{vocab}}} \to \mathbb{R}$ and a token attention function (à la EQKE for a fixed query token) $a : \mathbb{N}_{<d_{\text{vocab}}} \to \mathbb{R}$. Fix a position value function (à la a row-difference in PVOU) $w : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{R}$ and a position attention function (à la EQKP for a fixed query token) $b : \mathbb{N}_{<n_{\text{ctx}}} \to \mathbb{R}$.
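In practice, as in Appendix E.3, we'll take, for a fixed query token $t_{\text{query}}$,
$$v_t = \text{EVOU}_{t,t^*} - \text{EVOU}_{t,t_{\max}}, \qquad a_t = \text{EQKE}_{t_{\text{query}},t} / \sqrt{d}, \qquad w_i = \text{PVOU}_{i,t^*} - \text{PVOU}_{i,t_{\max}}, \qquad b_i = \text{EQKP}_{t_{\text{query}},i} / \sqrt{d}.$$
Note that unlike in Appendix E.3, we pessimize independently over the query token and the non-max token, so the “fixed” query token may not in fact appear in any key-side position in the relaxed sequence we consider.

Definition 15 (Sequence via mapping from positions). We can define a sequence of tokens via a mapping from positions by specifying a subset of valid tokens $S \subseteq \mathbb{N}_{<d_{\text{vocab}}}$ paired with a function $T : \mathbb{N}_{<n_{\text{ctx}}} \to S$ specifying which token is in each position.

Definition 16 (Sequence score). Given a subset of valid tokens $S \subseteq \mathbb{N}_{<d_{\text{vocab}}}$ and a function $T : \mathbb{N}_{<n_{\text{ctx}}} \to S$ specifying which token is in each position, define the sequence score
$$s_T := \frac{\sum_{0 \le i < n_{\text{ctx}}} v_{T(i)} e^{a_{T(i)} + b_i}}{\sum_{0 \le i < n_{\text{ctx}}} e^{a_{T(i)} + b_i}}.$$

Definition 17 (Swapped mapping). Given a subset of valid tokens $S \subseteq \mathbb{N}_{<d_{\text{vocab}}}$, a function $T : \mathbb{N}_{<n_{\text{ctx}}} \to S$ specifying which token is in each position, and two indices $0 \le i, j < n_{\text{ctx}}$, define the swapped mapping $T_{i \leftrightarrow j}$ to be the function that is $T$ except with positions $i$ and $j$ swapped:
$$T_{i \leftrightarrow j}(k) = \begin{cases} T(i) & \text{if } k = j \\ T(j) & \text{if } k = i \\ T(k) & \text{otherwise.} \end{cases}$$

Lemma 13 (Characterization of swapping tokens in a two-token sequence). Fix two tokens $t_0 < t_1 \in \mathbb{N}$ and a function $T : \mathbb{N}_{<n_{\text{ctx}}} \to \{t_0, t_1\}$ specifying which token is in each position. Define $\Delta_{T,i \leftrightarrow j}$ to be the difference in sequence scores when you swap $i$ and $j$:
$$\Delta_{T,i \leftrightarrow j} := s_{T_{i \leftrightarrow j}} - s_T.$$
Then
$$\operatorname{sign}(\Delta_{T,i \leftrightarrow j}) = -\operatorname{sign}(b_i - b_j) \operatorname{sign}(v_{T(i)} - v_{T(j)}).$$

Proof. Lemma 3 gives us the result directly when $a_{T(i)} = a_{T(j)}$. Otherwise, we get
$$\operatorname{sign}(\Delta_{T,i \leftrightarrow j}) = \operatorname{sign}(a_{T(i)} - a_{T(j)}) \operatorname{sign}(b_i - b_j) \operatorname{sign}\left(s_T - \frac{v_{T(i)} e^{a_{T(i)}} - v_{T(j)} e^{a_{T(j)}}}{e^{a_{T(i)}} - e^{a_{T(j)}}}\right).$$
Multiplying the argument of the last factor by $e^{a_{T(i)}} - e^{a_{T(j)}}$, whose sign is $\operatorname{sign}(a_{T(i)} - a_{T(j)})$, we see that all that remains is to show that
$$\operatorname{sign}\left(s_T \left(e^{a_{T(i)}} - e^{a_{T(j)}}\right) - v_{T(i)} e^{a_{T(i)}} + v_{T(j)} e^{a_{T(j)}}\right) = -\operatorname{sign}(v_{T(i)} - v_{T(j)}).$$
Define $\bar{v} := \frac{1}{2}(v_{T(i)} + v_{T(j)})$ and $\Delta v := \frac{1}{2}(v_{T(i)} - v_{T(j)})$, so that $v_{T(i)} = \bar{v} + \Delta v$ and $v_{T(j)} = \bar{v} - \Delta v$.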
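Assume WLOG that $T(i) = 0$ and $T(j) = 1$, so that $v_{T(p)} = \bar{v} + (-1)^{T(p)} \Delta v$ for all $p$. Then we have
$$\begin{aligned}
&\operatorname{sign}\left(s_T (e^{a_{T(i)}} - e^{a_{T(j)}}) - v_{T(i)} e^{a_{T(i)}} + v_{T(j)} e^{a_{T(j)}}\right) \\
&= \operatorname{sign}\left(s_T (e^{a_{T(i)}} - e^{a_{T(j)}}) - \bar{v}(e^{a_{T(i)}} - e^{a_{T(j)}}) - \Delta v\, e^{a_{T(i)}} - \Delta v\, e^{a_{T(j)}}\right) \\
&= \operatorname{sign}\left(\frac{\sum_{0 \le p < n_{\text{ctx}}} v_{T(p)} e^{a_{T(p)} + b_p}}{\sum_{0 \le p < n_{\text{ctx}}} e^{a_{T(p)} + b_p}} (e^{a_{T(i)}} - e^{a_{T(j)}}) - \bar{v}(e^{a_{T(i)}} - e^{a_{T(j)}}) - \Delta v (e^{a_{T(i)}} + e^{a_{T(j)}})\right) \\
&= \operatorname{sign}\left(\frac{\sum_{0 \le p < n_{\text{ctx}}} \left(\bar{v} + (-1)^{T(p)} \Delta v\right) e^{a_{T(p)} + b_p}}{\sum_{0 \le p < n_{\text{ctx}}} e^{a_{T(p)} + b_p}} (e^{a_{T(i)}} - e^{a_{T(j)}}) - \bar{v}(e^{a_{T(i)}} - e^{a_{T(j)}}) - \Delta v (e^{a_{T(i)}} + e^{a_{T(j)}})\right) \\
&= \operatorname{sign}(\Delta v) \operatorname{sign}\left(\frac{e^{a_{T(i)}} \sum_{p : T(p) = T(i)} e^{b_p} - e^{a_{T(j)}} \sum_{p : T(p) = T(j)} e^{b_p}}{\sum_{0 \le p < n_{\text{ctx}}} e^{a_{T(p)} + b_p}} (e^{a_{T(i)}} - e^{a_{T(j)}}) - e^{a_{T(i)}} - e^{a_{T(j)}}\right) \\
&= \operatorname{sign}(v_{T(i)} - v_{T(j)}) \operatorname{sign}\left(\frac{e^{a_{T(i)}} \sum_{p : T(p) = T(i)} e^{b_p} - e^{a_{T(j)}} \sum_{p : T(p) = T(j)} e^{b_p}}{\sum_{0 \le p < n_{\text{ctx}}} e^{a_{T(p)} + b_p}} (e^{a_{T(i)}} - e^{a_{T(j)}}) - e^{a_{T(i)}} - e^{a_{T(j)}}\right),
\end{aligned}$$
where the $\bar{v}$ terms cancel between the fourth and fifth lines. Define
$$P_i := \sum_{0 \le p < n_{\text{ctx}},\ T(p) = T(i)} e^{b_p}, \qquad P_j := \sum_{0 \le p < n_{\text{ctx}},\ T(p) = T(j)} e^{b_p},$$
so that we get
$$\operatorname{sign}\left(s_T (e^{a_{T(i)}} - e^{a_{T(j)}}) - v_{T(i)} e^{a_{T(i)}} + v_{T(j)} e^{a_{T(j)}}\right) = \operatorname{sign}(v_{T(i)} - v_{T(j)}) \operatorname{sign}\left(\frac{e^{a_{T(i)}} P_i - e^{a_{T(j)}} P_j}{e^{a_{T(i)}} P_i + e^{a_{T(j)}} P_j} (e^{a_{T(i)}} - e^{a_{T(j)}}) - e^{a_{T(i)}} - e^{a_{T(j)}}\right).$$
Multiplying through by the positive denominator and expanding, we get
$$\begin{aligned}
&= \operatorname{sign}(v_{T(i)} - v_{T(j)}) \operatorname{sign}\left(-2 e^{a_{T(i)} + a_{T(j)}} P_i - 2 e^{a_{T(i)} + a_{T(j)}} P_j\right) \\
&= -\operatorname{sign}(v_{T(i)} - v_{T(j)}) \operatorname{sign}\left(e^{a_{T(i)} + a_{T(j)}} P_i + e^{a_{T(i)} + a_{T(j)}} P_j\right) \\
&= -\operatorname{sign}(v_{T(i)} - v_{T(j)}).
\end{aligned}$$

Theorem 14 (Pessimization over sequence ordering for two-token sequences is simple). Let $\sigma_s : \mathbb{N} \to \mathbb{N}$ denote a permutation of the $n_{\text{ctx}}$ positions that sorts them according to $b$: for $0 \le i, j < n_{\text{ctx}}$, $b_i \le b_j$ whenever $\sigma_s(i) < \sigma_s(j)$. Fix two tokens $t_0 < t_1 \in \mathbb{N}$ and a function $T : \mathbb{N}_{<n_{\text{ctx}}} \to \{t_0, t_1\}$. Let $n_{t_0}$ be the number of $p \in [0, n_{\text{ctx}})$ with $T(p) = t_0$ and let $n_{t_1}$ be the number of $p \in [0, n_{\text{ctx}})$ with $T(p) = t_1$. Note that $n_{t_0} + n_{t_1} = n_{\text{ctx}}$. Define $t_{\min} := \operatorname{argmin}_{t \in \{t_0, t_1\}} v_t$ and $t_{\max} := \operatorname{argmax}_{t \in \{t_0, t_1\}} v_t$.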
Define $T_{\min}, T_{\max} : \mathbb{N}_{<n_{\text{ctx}}} \to \{t_0, t_1\}$ to be the assignments of tokens to positions that pay the least (respectively, most) attention to $t_{\max}$:
$$T_{\min}(i) := \begin{cases} t_{\max} & \text{if } 0 \le \sigma_s(i) < n_{t_{\max}} \\ t_{\min} & \text{if } n_{t_{\max}} \le \sigma_s(i) < n_{\text{ctx}} \end{cases} \qquad T_{\max}(i) := \begin{cases} t_{\min} & \text{if } 0 \le \sigma_s(i) < n_{t_{\min}} \\ t_{\max} & \text{if } n_{t_{\min}} \le \sigma_s(i) < n_{\text{ctx}} \end{cases}$$
Then we have that
$$s_{T_{\min}} \le s_T \le s_{T_{\max}}.$$

Proof. The extremality of $s_{T_{\min}}$ and $s_{T_{\max}}$ follows straightforwardly from Theorem 4. All that remains is $s_{T_{\min}} \le s_{T_{\max}}$. This follows from noting, by Lemma 13, that swapping two distinct tokens in $T_{\min}$ increases the sequence score, while the reverse is true of $T_{\max}$; thus it must be $s_{T_{\min}}$ that is the minimum and $s_{T_{\max}}$ that is the maximum, and not vice versa.

Definition 18 (Full sequence score). Given a subset of valid tokens $S \subseteq \mathbb{N}_{<d_{\text{vocab}}}$ and a function $T : \mathbb{N}_{<n_{\text{ctx}}} \to S$ specifying which token is in each position, define the full sequence score
$$s'_T := \frac{\sum_{0 \le i < n_{\text{ctx}}} (v_{T(i)} + w_i) e^{a_{T(i)} + b_i}}{\sum_{0 \le i < n_{\text{ctx}}} e^{a_{T(i)} + b_i}}.$$

Theorem 15 (Independent pessimization over positional contributions, token attention, and token value is possible). Fix two tokens $t_0 < t_1 \in \mathbb{N}$. Let $T_{\min}, T_{\max} : \mathbb{N}_{<n_{\text{ctx}}} \to \{t_0, t_1\}$ and $t_{\max}, t_{\min}$ be as in Theorem 14. Fix a set $S$ of valid tokens with $t_0, t_1 \in S$. Define relaxed versions $T'_{\max}, T'_{\min} : \mathbb{N}_{<n_{\text{ctx}}} \to S$ of $T_{\max}$ and $T_{\min}$:
$$T'_{\max}(i) := \begin{cases} T_{\max}(i) & \text{if } T_{\max}(i) = t_{\max} \\ \operatorname{argmin}_{j \in S,\, j \neq t_{\max}} a_j & \text{otherwise} \end{cases} \qquad T'_{\min}(i) := \begin{cases} T_{\min}(i) & \text{if } T_{\min}(i) = t_{\max} \\ \operatorname{argmax}_{j \in S,\, j \neq t_{\max}} a_j & \text{otherwise} \end{cases}$$
That is, $T'_{\max}$ replaces $t_{\min}$ with whichever token in $S$ draws the least attention away from $t_{\max}$, while $T'_{\min}$ replaces $t_{\min}$ with whichever token in $S$ draws the most attention away from $t_{\max}$.

Define the relaxed extremal sequence scores $r_{T_{\max}}, r_{T_{\min}}$:
$$r_{T_{\min}} := \min_{0 \le i < n_{\text{ctx}}} w_i + \frac{\sum_{0 \le i < n_{\text{ctx}}} v_{T_{\min}(i)} e^{a_{T'_{\min}(i)} + b_i}}{\sum_{0 \le i < n_{\text{ctx}}} e^{a_{T'_{\min}(i)} + b_i}} \qquad r_{T_{\max}} := \max_{0 \le i < n_{\text{ctx}}} w_i + \frac{\sum_{0 \le i < n_{\text{ctx}}} v_{T_{\max}(i)} e^{a_{T'_{\max}(i)} + b_i}}{\sum_{0 \le i < n_{\text{ctx}}} e^{a_{T'_{\max}(i)} + b_i}}$$
Then $r_{T_{\min}} \le s'_{T_{\min}}$ and $s'_{T_{\max}} \le r_{T_{\max}}$.
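The ordering claim of Theorem 14, which Theorem 15 builds on, can be checked numerically on small random instances. The Python sketch below enumerates every placement of $n_{\text{hi}}$ copies of one token among $n_{\text{ctx}}$ positions and compares the minimum and maximum of $s_T$ (Definition 16) with the sorted assignments $T_{\min}$ and $T_{\max}$. To avoid a clash with the max token of the sequence, the theorem's $t_{\max}$ and $t_{\min}$ (the argmax and argmin of $v$) are called `t_hi` and `t_lo` here; the function names and the random parameters are illustrative assumptions.

```python
import math
import random
from itertools import combinations


def sequence_score(tokens, v, a, b):
    """s_T of Definition 16: the softmax(a_T(i) + b_i)-weighted mean of v_T(i)."""
    weights = [math.exp(a[t] + b_i) for t, b_i in zip(tokens, b)]
    return sum(v[t] * w for t, w in zip(tokens, weights)) / sum(weights)


def sorted_assignments(b, n_hi, t_lo, t_hi):
    """T_min and T_max of Theorem 14, where v[t_hi] > v[t_lo]."""
    n = len(b)
    order = sorted(range(n), key=lambda i: b[i])  # positions by ascending b
    assign_min, assign_max = [t_lo] * n, [t_lo] * n
    for rank, pos in enumerate(order):
        if rank < n_hi:
            assign_min[pos] = t_hi  # t_hi on the n_hi lowest-b positions
        if rank >= n - n_hi:
            assign_max[pos] = t_hi  # t_hi on the n_hi highest-b positions
    return assign_min, assign_max


def all_scores(b, n_hi, t_lo, t_hi, v, a):
    """Brute force: s_T for every placement of n_hi copies of t_hi."""
    n = len(b)
    scores = []
    for hi_positions in combinations(range(n), n_hi):
        tokens = [t_hi if i in hi_positions else t_lo for i in range(n)]
        scores.append(sequence_score(tokens, v, a, b))
    return scores


def check_theorem_14(seed, n_ctx=7, n_hi=3):
    """Return (min over all T minus s_{T_min}, max over all T minus s_{T_max})."""
    rng = random.Random(seed)
    t_lo, t_hi = 0, 1
    v = {t_lo: -1.0, t_hi: 2.0}  # v[t_hi] > v[t_lo], as the theorem requires
    a = {t_lo: rng.gauss(0, 1), t_hi: rng.gauss(0, 1)}
    b = [rng.gauss(0, 1) for _ in range(n_ctx)]
    assign_min, assign_max = sorted_assignments(b, n_hi, t_lo, t_hi)
    scores = all_scores(b, n_hi, t_lo, t_hi, v, a)
    gap_min = min(scores) - sequence_score(assign_min, v, a, b)
    gap_max = max(scores) - sequence_score(assign_max, v, a, b)
    return gap_min, gap_max


print(check_theorem_14(seed=0))
```

Both gaps should be zero up to floating-point rounding for any token attentions $a$, since with counts fixed the attention mass on `t_hi` only grows as its positions move to larger $b$.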
Proof (sketch). Essentially the same as the proof of Theorem 6.

Note that in practice, we take $S$ to be the set of all tokens less than $t_{\max} - g$ for some minimum gap $g$. This allows us to share computation across the various maximum tokens to reduce overall computational complexity.

Algorithm 5 Counting Correct Sequences in Subcubic Time, Preliminaries
1: function INPUT-SEQUENCE-COMPATIBLE-WITH(input-sequence, $d_{\text{vocab}}, n_{\text{ctx}}, t_{\max}, t_{\text{query}}, c, g$)
2:   $t$ ← input-sequence
3:   return False if $t \notin (\mathbb{N}_{<d_{\text{vocab}}})^{n_{\text{ctx}}}$ ▷ the sequence is not made of valid tokens
4:   return False if $t_{-1} \neq t_{\text{query}}$ ▷ wrong query token
5:   return False if $\max_i t_i \neq t_{\max}$ ▷ wrong max token
6:   return False if $\lvert\{i \in \mathbb{N}_{<n_{\text{ctx}}} \mid t_i \neq t_{\max}\}\rvert \neq c$ ▷ wrong count of non-max toks
7:   return ALL($t_i = t_{\max}$ or $t_{\max} - t_i \ge g$ for $0 \le i < n_{\text{ctx}}$) ▷ check gap on non-max toks
8: end function
9: function CORRECTNESS-PESSIMIZING-OVER-GAP-SLOW($M, d_{\text{vocab}}, n_{\text{ctx}}, t_{\max}, t_{\text{query}}, c, g$)
10:  return ALL(CORRECTNESS($M, t$) for all $t$ s.t. INPUT-SEQUENCE-COMPATIBLE-WITH($t, d_{\text{vocab}}, n_{\text{ctx}}, t_{\max}, t_{\text{query}}, c, g$))
11: end function
12: function SUBCUBIC($d_{\text{vocab}}, n_{\text{ctx}}, M, G$)
13:  count ← 0 ▷ # of correct sequences
14:  $G_{t_{\max},t_{\text{query}},c} \leftarrow \text{MIN}(t_{\max}, \text{MAX}(1, G_{t_{\max},t_{\text{query}},c}))$ ▷ Clip $G$ to valid range
15:  $G^*_{t_{\max},c} \leftarrow \min_{t \le t_{\max}} G_{t_{\max},t,c}$ ▷ Cache running minima
16:  for $t_{\max} \in$ RANGE($d_{\text{vocab}}$) do ▷ $t_{\max}$ ← max-token
17:    for $0 \le t_{\text{query}} \le t_{\max}$ do ▷ $t_{\text{query}}$ ← query-token
18:      $c_{\min} \leftarrow 0$ if $t_{\text{query}} = t_{\max}$ else $1$ ▷ minimum copies of nonmax
19:      $c_{\max} \leftarrow 0$ if $t_{\max} = 0$, else $n_{\text{ctx}} - 1$ ▷ maximum copies of nonmax
20:      for $c_{\min} \le c \le c_{\max}$ do ▷ valid choices for the number of non-max tokens
21:        $g \leftarrow G_{t_{\max},t_{\text{query}},c}$
22:        $g^* \leftarrow G^*_{t_{\max},c}$
23:        q-gap ← $t_{\max} - t_{\text{query}}$
24:        RCPOG($\vec{\chi}$) ← RELAXED-CORRECTNESS-PESSIMIZING-OVER-GAP($M, \vec{\chi}$)
25:        if (q-gap $= 0$ or q-gap $\ge g$) and RCPOG($d_{\text{vocab}}, n_{\text{ctx}}, t_{\max}, t_{\text{query}}, c, g, g^*$) then
26:          $c' \leftarrow c$ if $t_{\text{query}} = t_{\max}$ else $c - 1$ ▷ # of non-max non-query tokens
27:          count ← count + $\binom{n_{\text{ctx}}-1}{c'} (t_{\max} - g)^{c'}$ ▷ taking $0^0 = 1$ conventionally
28:        end if
29:      end for
30:    end for
31:  end for
32:  return count $\cdot \frac{1}{d_{\text{vocab}}^{n_{\text{ctx}}}}$
33: end function

Algorithm 6 Counting Correct Sequences in Subcubic Time
1: function MODEL-BEHAVIOR-RELAXED-OVER-GAP($M, t_{\max}, t_{\text{query}}, c, g, g^*$)
   Ensure: CORRECTNESS-PESSIMIZING-OVER-GAP-SLOW is False ⟹ result is False
   Require: $0 \le g^* \le g \le t_{\max}$
   Require: if $c = 0$ then $t_{\text{query}} = t_{\max}$
2:   skip-score ← $\max_{t^*} \ell_{\text{EU}}(t_{\text{query}})_{t^*} - \min_{t^*} \ell_{\text{EU}}(t_{\text{query}})_{t^*}$ ▷ Cache by $t_{\text{query}}$
3:   $v_t \leftarrow \text{EVOU}(t)$
4:   $w_i \leftarrow \text{PVOU}(i)$
5:   $\Delta w_{\max,t^*} \leftarrow \max_i (w_{i,t^*} - w_{i,t_{\max}})$ ▷ Cache by $t_{\max}, t^*$
6:   $\Delta w_{\max,\max} \leftarrow \max_{t^*} \Delta w_{\max,t^*}$ ▷ Cache by $t_{\max}$
7:   $\Delta v_t \leftarrow \max_{t^*} v_{t,t^*} - \min_{t^*} v_{t,t^*}$ ▷ Cache by $t$
8:   $\Delta v_{\max} \leftarrow \max_{0 \le t \le t_{\max} - g^*} \Delta v_t$ ▷ Cache by $t_{\max} - g^*$
9:   $\Delta v^{t_{\max}}_{t^*} \leftarrow v_{t_{\max},t^*} - v_{t_{\max},t_{\max}}$ ▷ Cache by $t_{\max}$
10:  $\Delta v^{t_{\max}}_{\max} \leftarrow \max_{t^* \neq t_{\max}} \Delta v^{t_{\max}}_{t^*}$ ▷ Cache by $t_{\max}$
11:  if $c = 0$ then
12:    $\ell_{t^*} \leftarrow \ell_{\text{EU}}(t_{\max})_{t^*} + v_{t_{\max},t^*} + \Delta w_{\max,t^*}$
13:    return $\max_{t^* \neq t_{\max}} (\ell_{t^*} - \ell_{t_{\max}})$
14:  end if
15:  $b_{:,n_{\text{ctx}}-1} \leftarrow \text{EQKP}(t_{\text{query}}, n_{\text{ctx}} - 1)/\sqrt{d}$ ▷ Cache by $t_{\text{query}}$
16:  $b_{0,:-1} \leftarrow \text{SORT}(\text{EQKP}(t_{\text{query}}, :-1))/\sqrt{d}$ ▷ Cache by $t_{\text{query}}, i$
17:  $b_{1,:-1} \leftarrow \text{REVERSE}(b_{0,:-1})$
18:  $a_t \leftarrow \text{EQKE}(t_{\text{query}}, t)/\sqrt{d}$ ▷ Cache by $t_{\text{query}}, t$
19:  $a_{\min,t} \leftarrow \min_{0 \le t' \le t} a_{t'}$ ▷ Cache by $t_{\text{query}}, t$; compute in amortized $O(d_{\text{vocab}}^2)$
20:  $a_{\max,t} \leftarrow \max_{0 \le t' \le t} a_{t'}$ ▷ Cache by $t_{\text{query}}, t$; compute in amortized $O(d_{\text{vocab}}^2)$
21:  $\Delta a_{\max} \leftarrow a_{t_{\max}} - a_{\min,t_{\max}-g}$ ▷ Cache by $t_{\text{query}}, t_{\max}, c$
22:  $\Delta a_{\min} \leftarrow a_{t_{\max}} - a_{\max,t_{\max}-g}$ ▷ Cache by $t_{\text{query}}, t_{\max}, c$
23:  idx-set ← $\{0, \dots, n_{\text{ctx}} - c - 1\}$ if $t_{\max} \neq t_{\text{query}}$ else $\{0, \dots, n_{\text{ctx}} - c - 2, n_{\text{ctx}} - 1\}$
24:  attn-weights-unscaled$_{0,i} \leftarrow b_{0,i} + (\Delta a_{\min}$ if $i \in$ idx-set else $0)$
25:  attn-weights-unscaled$_{1,i} \leftarrow b_{1,i} + (\Delta a_{\max}$ if $i \in$ idx-set else $0)$ ▷ Cache by $t_{\text{query}}, t_{\max}, i, c$
26:  attn-weights$_0$ ← SOFTMAX(attn-weights-unscaled$_0$) ▷ Cache by $t_{\text{query}}, t_{\max}, i, c$
27:  attn-weights$_1$ ← SOFTMAX(attn-weights-unscaled$_1$) ▷ Cache by $t_{\text{query}}, t_{\max}, i, c$
28:  attn-max$_0$ ← $\sum_{i \in \text{idx-set}}$ attn-weights$_{0,i}$
29:  attn-max$_1$ ← $\sum_{i \in \text{idx-set}}$ attn-weights$_{1,i}$
30:  attn-max ← attn-max$_0$ if $\Delta v^{t_{\max}}_{\max} < \Delta v_{\max}$ else attn-max$_1$
31:  ▷ Recall that $\Delta v^{t_{\max}}_{\max}$ is negative when the model outputs the correct answer
32:  return skip-score + $\Delta w_{\max,\max}$ + attn-max $\cdot\, \Delta v^{t_{\max}}_{\max}$ + (1 − attn-max) $\Delta v_{\max}$
33: end function
34: function RELAXED-CORRECTNESS-PESSIMIZING-OVER-GAP($M, d_{\text{vocab}}, n_{\text{ctx}}, t_{\max}, t_{\text{query}}, c, g, g^*$)
35:  ▷ runs the model on a relaxed variant of input sequences compatible with the arguments
     Ensure: CORRECTNESS-PESSIMIZING-OVER-GAP-SLOW is False ⟹ result is False
     Ensure: return is False if CORRECTNESS-PESSIMIZING-OVER-GAP-SLOW($M, t$) is False for any $t$ with the specified $t_{\max}$, $t_{\text{query}}$, and $c$ tokens not equal to $t_{\max}$
36:  return MODEL-BEHAVIOR-RELAXED-OVER-GAP($M, t_{\max}, t_{\text{query}}, c, g, g^*$) $< 0$
37: end function

Theorem 16. For all $G$,
$$\mathbb{E}_{t \sim U(\{0, 1, \dots, d_{\text{vocab}} - 1\}^{n_{\text{ctx}}})}\left[\operatorname{argmax}_i (M(t))_i = \max_i t_i\right] \ge \text{SUBCUBIC}(d_{\text{vocab}}, n_{\text{ctx}}, M, G).$$

Proof (sketch). Apply the preceding lemmas and theorems to Algorithm 6.

Theorem 17. The running time of Algorithm 6, after using caching to avoid duplicate computations, is $O(d_{\text{vocab}}^2 d_{\text{model}} + d_{\text{vocab}}^2 n_{\text{ctx}}^2)$.

Proof (sketch). Sum the complexities indicated along the right side of Algorithm 6. The $d_{\text{vocab}}^2 d_{\text{model}}$ term comes from precomputing EVOU, EU, and EQKP. The $d_{\text{vocab}}^2 n_{\text{ctx}}^2$ term comes from the softmax over $n_{\text{ctx}}$ tokens for $O(d_{\text{vocab}}^2 n_{\text{ctx}})$ pessimized pure sequences. Confirming that none of the complexities on the right side exceeds $O(d_{\text{vocab}}^2 d_{\text{model}} + d_{\text{vocab}}^2 n_{\text{ctx}}^2)$ completes the proof.

G Subcubic proof strategies

In this section, we present a number of proof strategies that we use to reduce the computational cost of the proof, ultimately driving down the cost of EU and EQKE verification to $O(d_{\text{vocab}} d_{\text{model}})$, while unfortunately leaving the cost of EVOU verification at $O(d_{\text{vocab}}^2 d_{\text{model}})$. The three main tricks we cover are the mean+diff trick (Appendix G.1), the max row-diff trick (Appendix G.2.2), and the rank one / rank two SVD decomposition of EQKE (Appendix G.2.3).
While the mean+diff trick is useful for getting slightly better bounds, the SVD decomposition of EQKE is the place where we get to insert the most understanding (without which we'd have no hope of non-vacuous bounds below $O(d_{\text{vocab}}^2 d_{\text{model}})$), and the max row-diff trick is the workhorse that allows us to drive down the error term computations from cubic to quadratic without getting completely vacuous bounds.

G.1 The mean+diff trick

Suppose we have quantities $f_{x,y}$ and $g_{y,z}$ and we want to pessimize (WLOG, suppose minimize) the quantity $f_{x,y} + g_{y,z}$ over $x$, $y$, and $z$ in time less than $O(n_x n_y n_z)$; say we allow $O(n_x n_y + n_y n_z + n_x n_z)$. Also suppose the variation of $f$ over the $y$-axis is much larger than the variation of $f$ over the $x$-axis. We can of course say
$$\min_{x,y} f_{x,y} + \min_{y,z} g_{y,z} \le f_{x,y} + g_{y,z}.$$
But we can do better! Note that
$$f_{x,y} = \mathbb{E}_x f_{x,y} + (f_{x,y} - \mathbb{E}_x f_{x,y}).$$
Suppose that $f_{x,y}$ varies much less over $x$ than it does over $y$, and much less than $g_{y,z}$ varies over either of $y$ and $z$. This will make the following bound a good approximation, though the bound is sound even without this assumption. We can write
$$\begin{aligned}
f_{x,y} + g_{y,z} &\ge \min_{x,y,z} [f_{x,y} + g_{y,z}] \\
&= \min_{x,y,z} [\mathbb{E}_x f_{x,y} + g_{y,z} + f_{x,y} - \mathbb{E}_x f_{x,y}] \\
&\ge \min_{x,y,z} [\mathbb{E}_x f_{x,y} + g_{y,z}] + \min_{x,y,z} [f_{x,y} - \mathbb{E}_x f_{x,y}] \\
&= \min_{y,z} [\mathbb{E}_x f_{x,y} + g_{y,z}] + \min_{x,y} [f_{x,y} - \mathbb{E}_x f_{x,y}].
\end{aligned}$$
By averaging the variation over certain axes, we have the following.

Theorem 18 (Mean+Diff).
$$\min_{x,y,z} f_{x,y} + g_{y,z} \ge \min_{y,z} [\mathbb{E}_x f_{x,y} + g_{y,z}] + \min_{x,y} [f_{x,y} - \mathbb{E}_x f_{x,y}]$$
$$\max_{x,y,z} f_{x,y} + g_{y,z} \le \max_{y,z} [\mathbb{E}_x f_{x,y} + g_{y,z}] + \max_{x,y} [f_{x,y} - \mathbb{E}_x f_{x,y}]$$
and the right-hand sides can be computed in time $O(n_x n_y + n_y n_z + n_x n_z)$, for $n_x$, $n_y$, and $n_z$ the number of possible values of $x$, $y$, and $z$, respectively.
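Theorem 18 can be tried out on an instance where $f$ varies much more over $y$ than over $x$, namely $f_{x,y} = k(y) + \varepsilon_1(x,y)$ and $g_{y,z} = -k(y) + \varepsilon_2(y,z)$ with small $\varepsilon$. The Python sketch below (an illustrative assumption, with plain nested lists standing in for the tensors) compares the exact cubic-time minimum, the naive bound $\min f + \min g$, and the quadratic-time mean+diff bound.

```python
import random
from statistics import fmean


def exact_min(f, g):
    """min over x, y, z of f[x][y] + g[y][z], in O(n_x n_y n_z) time."""
    return min(f[x][y] + g[y][z]
               for x in range(len(f)) for y in range(len(g)) for z in range(len(g[0])))


def naive_bound(f, g):
    """min_{x,y} f + min_{y,z} g: sound, but ignores that both terms share y."""
    return min(map(min, f)) + min(map(min, g))


def mean_diff_bound(f, g):
    """Theorem 18 lower bound, in O(n_x n_y + n_y n_z) time."""
    n_x, n_y, n_z = len(f), len(g), len(g[0])
    f_mean = [fmean(f[x][y] for x in range(n_x)) for y in range(n_y)]  # E_x f_{x,y}
    first = min(f_mean[y] + g[y][z] for y in range(n_y) for z in range(n_z))
    second = min(f[x][y] - f_mean[y] for x in range(n_x) for y in range(n_y))
    return first + second


rng = random.Random(1)
n_x, n_y, n_z = 5, 6, 7
k = [10.0 * y for y in range(n_y)]  # large variation along y
f = [[k[y] + rng.uniform(-0.1, 0.1) for y in range(n_y)] for _ in range(n_x)]
g = [[-k[y] + rng.uniform(-0.1, 0.1) for _ in range(n_z)] for y in range(n_y)]
print(naive_bound(f, g), mean_diff_bound(f, g), exact_min(f, g))
```

Here the naive bound is off by roughly $\max_y k(y) - \min_y k(y) = 50$, while the mean+diff bound stays close to the exact minimum; with large variation along $x$ the comparison can flip, as the text notes.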
Example of how this helps with small variation: take any function $k(y)$ and then take
$$f_{x,y} := k(y) + \varepsilon_1(x,y), \qquad g_{y,z} := -k(y) + \varepsilon_2(y,z).$$
Then we have
$$\begin{aligned}
\min_{x,y,z} [f_{x,y} + g_{y,z}] &= \min_{x,y,z} [\varepsilon_1(x,y) + \varepsilon_2(y,z)] \\
\min_{x,y} f_{x,y} + \min_{y,z} g_{y,z} &\approx \min_y k(y) + \min_y (-k(y)) + \min_{x,y} \varepsilon_1(x,y) + \min_{y,z} \varepsilon_2(y,z) \\
&= \min_y k(y) - \max_y k(y) + \min_{x,y} \varepsilon_1(x,y) + \min_{y,z} \varepsilon_2(y,z) \\
\min_{x,y} [f_{x,y} - \mathbb{E}_x f_{x,y}] + \min_{y,z} [g_{y,z} + \mathbb{E}_x f_{x,y}] &= \min_{x,y} [\varepsilon_1(x,y) - \mathbb{E}_x \varepsilon_1(x,y)] + \min_{y,z} [\varepsilon_2(y,z) + \mathbb{E}_x \varepsilon_1(x,y)]
\end{aligned}$$
If $\varepsilon_1$ and $\varepsilon_2$ are small compared to $\max_y k(y) - \min_y k(y)$, then using $\mathbb{E}_x f_{x,y}$ gives a much better bound. Note, though, that this could be a worse bound if the assumption of small variation does not hold.

Note also that this trick is not restricted to adding and subtracting $\mathbb{E}_x f_{x,y}$. If $f$ is a matrix indexed by $x$ and $y$, we might also try taking an SVD and using the first principal component instead. A basic application of the triangle inequality gives the following, more general, result:

Theorem 19 (Summarize+Diff). For any $h_y$ which can be computed in time $O(n_h)$,
$$\min_{x,y,z} f_{x,y} + g_{y,z} \ge \min_{y,z} [h_y + g_{y,z}] + \min_{x,y} [f_{x,y} - h_y]$$
$$\max_{x,y,z} f_{x,y} + g_{y,z} \le \max_{y,z} [h_y + g_{y,z}] + \max_{x,y} [f_{x,y} - h_y]$$
and the right-hand sides can be computed in time $O(n_x n_y + n_y n_z + n_h)$, for $n_x$, $n_y$, and $n_z$ the number of possible values of $x$, $y$, and $z$, respectively.

We see that if the variation of $f$ along the $x$-axis is indeed much smaller than the variation along the $y$-axis, then letting $f_{x,y} = h_y + \varepsilon_{x,y}$ we get
$$\begin{aligned}
&\left\lvert \min_{x,y,z} [f_{x,y} + g_{y,z}] - \min_{y,z} [h_y + g_{y,z}] - \min_{x,y} [f_{x,y} - h_y] \right\rvert \\
&\le \left\lvert \min_{x,y,z} [f_{x,y} + g_{y,z}] - \min_{y,z} [h_y + g_{y,z}] \right\rvert + \left\lvert \min_{x,y} \varepsilon_{x,y} \right\rvert \\
&\le 2 \max_{x,y} \lvert \varepsilon_{x,y} \rvert,
\end{aligned}$$
so indeed this bound isn't much worse, and we are able to compute it in quadratic rather than cubic time.

G.2 Details of SVD of QK proof

As discussed in Section 4.3.1, to further reduce the computational cost of the proof, we need to avoid computing the residual stream, EVOU, and EPQKE matrices fully.
Using mechanistic insight or otherwise, we observe that these matrices (apart from EVOU) can be well-approximated by rank one matrices. This will remove the dominant computation cost of $O(d_{\text{vocab}}^2 \cdot d_{\text{model}})$.

G.2.1 Comments on relationship between mechanistic insight and proof size

Up to this point, we haven't really said much in our proofs about what the model is doing. All the mechanistic insight has been of the form “the model varies more along this axis than this other axis”, or “the input data is distributed such that handling these inputs is more important than handling these other inputs”, or, at best, “the model computes the answer by attending to the maximum token of the sequence; everything else is noise”. Here, finally, our proof-size constraints are tight enough that we will see something that we could plausibly call “how the model pays attention to the maximum token more than anything else”, i.e., (if we squint a bit) “the model pays more attention to larger tokens in general”.

G.2.2 The max row-diff trick

As stated above, we are breaking matrices into their rank one approximation and some error term. To bound the error, i.e., to bound expressions of the form $\prod_i (A_i + E_i) - \prod_i A_i$, where the $E_i$ denote the matrix errors, we can use the following trick:

Lemma 20 (Max Row-Diff (vector-matrix version)). For a row vector $a$ and a matrix $B$,
$$\max_{i,j} \left((aB)_i - (aB)_j\right) \le \sum_k \lvert a_k \rvert \max_{i,j} (B_{k,i} - B_{k,j}).$$
Moreover, for a collection of $n$ row vectors $A_r$, if the shape of $B$ is $m \times p$, the right-hand side can be computed for all $r$ in time $O(nm + mp)$.

Proof.
$$\begin{aligned}
\max_{i,j} \left((aB)_i - (aB)_j\right) &= \max_{i,j} \sum_k a_k (B_{k,i} - B_{k,j}) \\
&\le \sum_k \max_{i,j} a_k (B_{k,i} - B_{k,j}) \\
&= \sum_k \begin{cases} a_k \max_{i,j} (B_{k,i} - B_{k,j}) & \text{if } a_k \ge 0 \\ a_k \min_{i,j} (B_{k,i} - B_{k,j}) & \text{if } a_k < 0 \end{cases} \\
&= \sum_k \begin{cases} a_k \max_{i,j} (B_{k,i} - B_{k,j}) & \text{if } a_k \ge 0 \\ -a_k \max_{i,j} (B_{k,i} - B_{k,j}) & \text{if } a_k < 0 \end{cases} \\
&= \sum_k \lvert a_k \rvert \max_{i,j} (B_{k,i} - B_{k,j}),
\end{aligned}$$
where the fourth line uses $\min_{i,j} (B_{k,i} - B_{k,j}) = -\max_{i,j} (B_{k,i} - B_{k,j})$. The asymptotic complexity of computing the result follows from caching the computation of $\max_{i,j} (B_{k,i} - B_{k,j})$ for each $k$ independently of $r$, as this computation does not depend on $A_r$.

Theorem 21 (Max Row-Diff). For matrices $A$ and $B$,
$$\max_{r,i,j} \left((AB)_{r,i} - (AB)_{r,j}\right) \le \max_r \sum_k \lvert A_{r,k} \rvert \max_{i,j} (B_{k,i} - B_{k,j}).$$

Proof. By taking the max of Lemma 20 over rows $r$ of $A$.

Lemma 20 can also be applied recursively to a product of more than two matrices.

Lemma 22 (Max Row-Diff (vector-matrix recursive version)). For a row vector $a$ and a sequence of $n$ matrices $B_p$ of shapes $r_p \times c_p$,
$$\max_{i,j} \left( \left(a \prod_p B_p\right)_i - \left(a \prod_p B_p\right)_j \right) \le \sum_{k_0} \lvert a_{k_0} \rvert \cdots \sum_{k_{n-1}} \lvert (B_{n-1})_{k_{n-2},k_{n-1}} \rvert \max_{i,j} \left((B_n)_{k_{n-1},i} - (B_n)_{k_{n-1},j}\right).$$
Moreover, for a collection of $q$ row vectors $A_\alpha$, the right-hand side can be computed for all $\alpha$ in time $O(q r_0 + \sum_p r_p c_p)$.

Proof. We proceed by induction on $n$. For $n = 1$, the statement is identical to Lemma 20. Suppose the statement holds for $n = s$; we show that it holds for $n = s + 1$. We reassociate the matrix multiplication as
$$\max_{i,j} \left( \left(a \prod_{p=1}^{s+1} B_p\right)_i - \left(a \prod_{p=1}^{s+1} B_p\right)_j \right) = \max_{i,j} \left( \left((aB_1) \prod_{p=2}^{s+1} B_p\right)_i - \left((aB_1) \prod_{p=2}^{s+1} B_p\right)_j \right).$$
Using the induction hypothesis (for the row vector $aB_1$ and the $s$ matrices $B_2, \dots, B_{s+1}$) gives
$$\le \sum_{k_1} \left\lvert \sum_{k_0} a_{k_0} (B_1)_{k_0,k_1} \right\rvert \sum_{k_2} \lvert (B_2)_{k_1,k_2} \rvert \cdots \sum_{k_s} \lvert (B_s)_{k_{s-1},k_s} \rvert \max_{i,j} \left((B_{s+1})_{k_s,i} - (B_{s+1})_{k_s,j}\right).$$
The triangle inequality gives
$$\le \sum_{k_1} \sum_{k_0} \lvert a_{k_0} (B_1)_{k_0,k_1} \rvert \sum_{k_2} \lvert (B_2)_{k_1,k_2} \rvert \cdots \sum_{k_s} \lvert (B_s)_{k_{s-1},k_s} \rvert \max_{i,j} \left((B_{s+1})_{k_s,i} - (B_{s+1})_{k_s,j}\right),$$
and algebra gives
$$= \sum_{k_0} \lvert a_{k_0} \rvert \sum_{k_1} \lvert (B_1)_{k_0,k_1} \rvert \sum_{k_2} \lvert (B_2)_{k_1,k_2} \rvert \cdots \sum_{k_s} \lvert (B_s)_{k_{s-1},k_s} \rvert \max_{i,j} \left((B_{s+1})_{k_s,i} - (B_{s+1})_{k_s,j}\right).$$
The asymptotic complexity of computing the right-hand side also follows straightforwardly by induction.

Theorem 23 (Max Row-Diff (recursive)). For a sequence of $n + 1$ matrices $A_0, \dots, A_n$,
$$\max_{r,i,j} \left( \left(\prod_p A_p\right)_{r,i} - \left(\prod_p A_p\right)_{r,j} \right) \le \max_r \sum_{k_0} \lvert (A_0)_{r,k_0} \rvert \cdots \sum_{k_{n-1}} \lvert (A_{n-1})_{k_{n-2},k_{n-1}} \rvert \max_{i,j} \left((A_n)_{k_{n-1},i} - (A_n)_{k_{n-1},j}\right).$$

Proof. By taking the max of Lemma 22 over rows $r$ of $A_0$.

Note that Theorem 21 is compatible with the mean+diff trick of Appendix G.1.

Theorem 24 (Combined Mean+Diff and Max Row-Diff). For matrices $A$ and $B$, and any column-wise summary vector $H_k$ of $A$ (for example, we may take $H_k := \mathbb{E}_r A_{r,k}$),
$$\max_{r,i,j} \left((AB)_{r,i} - (AB)_{r,j}\right) \le \left(\max_{i,j} \sum_k H_k (B_{k,i} - B_{k,j})\right) + \max_r \sum_k \lvert A_{r,k} - H_k \rvert \max_{i,j} (B_{k,i} - B_{k,j}).$$

Proof.
$$\begin{aligned}
\max_{r,i,j} \left((AB)_{r,i} - (AB)_{r,j}\right) &= \max_{r,i,j} \sum_k A_{r,k} (B_{k,i} - B_{k,j}) \\
&= \max_{r,i,j} \sum_k \left(H_k + (A_{r,k} - H_k)\right) (B_{k,i} - B_{k,j}) \\
&\le \left(\max_{i,j} \sum_k H_k (B_{k,i} - B_{k,j})\right) + \max_{r,i,j} \sum_k (A_{r,k} - H_k)(B_{k,i} - B_{k,j}) \\
&\le \left(\max_{i,j} \sum_k H_k (B_{k,i} - B_{k,j})\right) + \max_r \sum_k \max_{i,j} (A_{r,k} - H_k)(B_{k,i} - B_{k,j}) \\
&\le \left(\max_{i,j} \sum_k H_k (B_{k,i} - B_{k,j})\right) + \max_r \sum_k \lvert A_{r,k} - H_k \rvert \max_{i,j} (B_{k,i} - B_{k,j}).
\end{aligned}$$

Theorem 25 (Combined Mean+Diff and Vector-Matrix Recursive Max Row-Diff). For a row vector $a$, a vector of summaries $h$ corresponding to $a$ (for example, if $a$ is a row of a matrix, $h$ might be the average of the rows), a sequence of $n$ matrices $B_p$ of shapes $r_p \times c_p$, and a corresponding sequence of column-wise summary vectors $h_p$ of $B_p$ (for example, we may take $(h_p)_k := \mathbb{E}_r (B_p)_{r,k}$),
$$\begin{aligned}
\max_{i,j} \left( \left(a \prod_p B_p\right)_i - \left(a \prod_p B_p\right)_j \right) \le{} &\max_{i,j} \left( \sum_{k_0} h_{k_0} \cdots \sum_{k_{n-1}} (B_{n-1})_{k_{n-2},k_{n-1}} \left((B_n)_{k_{n-1},i} - (B_n)_{k_{n-1},j}\right) \right) \\
&+ \sum_{k_0} \lvert a_{k_0} - h_{k_0} \rvert \cdots \left( \max_{i,j} \sum_{k_{n-1}} (h_{n-1})_{k_{n-1}} \left((B_n)_{k_{n-1},i} - (B_n)_{k_{n-1},j}\right) \right. \\
&\qquad \left. + \sum_{k_{n-1}} \left\lvert (B_{n-1})_{k_{n-2},k_{n-1}} - (h_{n-1})_{k_{n-1}} \right\rvert \max_{i,j} \left((B_n)_{k_{n-1},i} - (B_n)_{k_{n-1},j}\right) \right).
\end{aligned}$$
Moreover, for a collection of $q$ row vectors $A_\alpha$, the right-hand side can be computed for all $\alpha$ in time $O(q r_0 + \sum_p r_p c_p)$.

Proof sketch. Apply the triangle inequality recursively, fusing the proofs of Lemma 22 and Theorem 24.

Theorem 26 (Combined Mean+Diff and Recursive Max Row-Diff). For a sequence of $n + 1$ matrices $A_0, \dots, A_n$, and corresponding column-wise summary vectors $h_0, \dots, h_{n-1}$ of $A_0, \dots, A_{n-1}$,
$$\begin{aligned}
\max_{r,i,j} \left( \left(\prod_p A_p\right)_{r,i} - \left(\prod_p A_p\right)_{r,j} \right) \le{} &\max_{i,j} \left( \sum_{k_0} (h_0)_{k_0} \cdots \sum_{k_{n-1}} (A_{n-1})_{k_{n-2},k_{n-1}} \left((A_n)_{k_{n-1},i} - (A_n)_{k_{n-1},j}\right) \right) \\
&+ \max_r \sum_{k_0} \lvert (A_0)_{r,k_0} - (h_0)_{k_0} \rvert \cdots \left( \max_{i,j} \sum_{k_{n-1}} (h_{n-1})_{k_{n-1}} \left((A_n)_{k_{n-1},i} - (A_n)_{k_{n-1},j}\right) \right. \\
&\qquad \left. + \sum_{k_{n-1}} \left\lvert (A_{n-1})_{k_{n-2},k_{n-1}} - (h_{n-1})_{k_{n-1}} \right\rvert \max_{i,j} \left((A_n)_{k_{n-1},i} - (A_n)_{k_{n-1},j}\right) \right).
\end{aligned}$$
Moreover, if the matrices $A_p$ have shapes $r_p \times c_p$, the right-hand side can be computed in time $O(\sum_p r_p c_p)$.

Proof. By taking the max of Theorem 25 over rows $r$ of $A_0$.

G.2.3 Exploring rank one approximation via SVD

Let us first look at $\text{EQKE} := E_q Q K^T \bar{E}^T$. From Figure 9a, we see that there is not much variation along the query token direction. We can confirm this by performing a singular value decomposition (SVD) on EQKE, as seen in Figure 14.
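The max row-diff bounds above (Theorem 21, its recursive form Theorem 23, and the mean+diff combination of Theorem 24) only multiply a vector of per-row spreads by absolute values of the other factors, so the product itself is never formed. A Python sketch on random matrices, with hypothetical helper names and plain lists in place of tensors, checks each bound against the exact maximum row difference:

```python
import random


def row_diff(row):
    """max_{i,j} (row_i - row_j), which is just max(row) - min(row)."""
    return max(row) - min(row)


def matmul(A, B):
    return [[sum(A[r][k] * B[k][c] for k in range(len(B))) for c in range(len(B[0]))]
            for r in range(len(A))]


def exact_max_row_diff(M):
    return max(row_diff(row) for row in M)


def weighted_abs(row, spreads):
    """sum_k |row_k| * spreads_k."""
    return sum(abs(x) * s for x, s in zip(row, spreads))


def max_row_diff_bound(A, B):
    """Theorem 21: max_r sum_k |A[r][k]| * row_diff(B[k])."""
    spreads = [row_diff(row) for row in B]  # cached once, independent of r
    return max(weighted_abs(row, spreads) for row in A)


def recursive_bound(mats):
    """Theorem 23 for A_0, ..., A_n, propagating spreads right to left."""
    spreads = [row_diff(row) for row in mats[-1]]
    for M in reversed(mats[1:-1]):
        spreads = [weighted_abs(row, spreads) for row in M]
    return max(weighted_abs(row, spreads) for row in mats[0])


def mean_diff_row_bound(A, B):
    """Theorem 24 with the column means H_k = E_r A[r][k] as the summary."""
    H = [sum(row[k] for row in A) / len(A) for k in range(len(B))]
    summary_row = [sum(H[k] * B[k][c] for k in range(len(B))) for c in range(len(B[0]))]
    spreads = [row_diff(row) for row in B]
    residual = max(weighted_abs([x - h for x, h in zip(row, H)], spreads) for row in A)
    return row_diff(summary_row) + residual


rng = random.Random(0)
A = [[rng.gauss(0, 1) for _ in range(5)] for _ in range(4)]
B = [[rng.gauss(0, 1) for _ in range(6)] for _ in range(5)]
C = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(6)]
print(exact_max_row_diff(matmul(A, B)), max_row_diff_bound(A, B), mean_diff_row_bound(A, B))
print(exact_max_row_diff(matmul(matmul(A, B), C)), recursive_bound([A, B, C]))
```

With two factors, `recursive_bound` reduces to `max_row_diff_bound`, matching the $n = 1$ case of Lemma 22.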
[Figure 14 panels: Query-Side SVD (Singular Index vs. Query Token), Singular Values (Singular Index vs. Value), Key-Side SVD (Singular Index vs. Key Token).]
Figure 14: SVD of EQKE for seed 123, with principal component vectors scaled by the square root of the corresponding singular value. This scaling allows us to see visually that there is not much going on beyond the first singular component. Numerically: the first singular value is just over 7440, while the second singular value is just under 15.

The first singular value is just over 7440 ($7800 \pm 380$ across all seeds), while the second singular value is just under 15 ($13.1 \pm 2.8$ across all seeds). The ratio across all seeds is $620 \pm 130$. There's really not much going on here beyond the first singular component.$^{23}$

Call the first singular component of EQKE the “query direction” $d_q$ and the “size direction” $d_k$ on the query side and key side, respectively. There are two ways that we can decompose EQKE into a low-rank component that we can compute exactly and a full-rank error term that we bound approximately.

G.2.4 The simple SVD decomposition of QK

In time $O(d_{\text{vocab}} d_{\text{model}}^2)$ we can perform SVD on each of the four component matrices $E_q$, $Q$, $K$, $\bar{E}$ and perform low-rank SVD on the matrix product $E_q Q K^T \bar{E}^T$. We can then bound the difference between two elements in the same row of EQKE by computing exactly the difference between the two elements in the same row of the rank one approximation of EQKE, and adding to that a bound on the difference between the two elements in the same row of the error term.

That is, we can decompose $E_q$ into a part parallel to $d_q$ and a part orthogonal to $d_q$, say $E_q = E_q^{\parallel} + E_q^{\perp}$, and similarly $\bar{E} = E_k^{\parallel} + E_k^{\perp}$. Note that $E_q^{\parallel}$ and $E_k^{\parallel}$ are both rank one, and hence can be multiplied with other matrices of shape $d_{\text{model}} \times a$ in time $O(d_{\text{model}} a)$ rather than time $O(d_{\text{vocab}} d_{\text{model}} a)$.
Hence we can define $\text{EQKE\_err}_1$ (subscript one for “rank one”) and decompose EQKE as
$$\text{EQKE} = E_q^{\parallel} Q K^T (E_k^{\parallel})^T + \text{EQKE\_err}_1.$$
Define, for any vector $v$,
$$\Delta_{i,j} v := v_i - v_j,$$
so that we get
$$\Delta_{i,j} \left(E_q^{\parallel} Q K^T (E_k^{\parallel})^T\right)_{t_{\text{query}}} + \min_{i' \neq j'} \Delta_{i',j'} (\text{EQKE\_err}_1)_{t_{\text{query}}} \le \Delta_{i,j} \text{EQKE}_{t_{\text{query}}} \le \Delta_{i,j} \left(E_q^{\parallel} Q K^T (E_k^{\parallel})^T\right)_{t_{\text{query}}} + \max_{i' \neq j'} \Delta_{i',j'} (\text{EQKE\_err}_1)_{t_{\text{query}}}.$$
Then we may use any method we please to pessimize $\Delta_{i,j} (\text{EQKE\_err}_1)_{t_{\text{query}}}$ quickly. For example, since for any matrix $M$ we have $\sigma_1(M) = \sup_x \lVert Mx \rVert / \lVert x \rVert$, considering vectors with one $1$, one $-1$, and zeros elsewhere shows that the maximum difference between elements in a row is upper bounded by $\sqrt{2}\,\sigma_1(M)$:
$$\left\lvert \Delta_{i,j} (\text{EQKE\_err}_1)_{t_{\text{query}}} \right\rvert \le \sqrt{2}\, \sigma_1(\text{EQKE\_err}_1) \tag{11}$$

$^{23}$ We might be tempted to keep analyzing the SVD and notice that the query direction is mostly uniform, while the key direction is monotonic (nearly linear, even). But the proof complexity doesn't demand this level of analysis yet, and so we can't expect that any automated compact proof discovery system will give it to us.

[Figure 15: heatmap with axes key token and query token; color scale from −1.00 to 1.00.]
Figure 15: The error term EQKE_err for seed 123.

G.2.5 The complicated SVD decomposition of QK

While the “most mechanistic” interpretation would proceed with the analysis in terms of $E_q^{\parallel}$ and $E_k^{\parallel}$, perhaps decomposing them further, we can get more bang for our buck by extracting out all the low-rank structure available in $E$, $Q$, and $K$, so as to make our error bounds as tight as possible. To this end, we perform SVD on $E_q^{\perp}$, $E_k^{\perp}$, $Q$, and $K$ and peel off the first singular components so as to get the decomposition
$$E_q = E_q^{\parallel} + E_{q,2} + E_{q,2}^{\perp}, \qquad \bar{E} = E_k^{\parallel} + E_{k,2} + E_{k,2}^{\perp}, \qquad Q = Q_0 + Q^{\perp}, \qquad K = K_0 + K^{\perp}.$$
Then EQKE, a product of these four matrices, can be expressed as a sum of $2^2 \cdot 3^2 - 1 = 35$ rank one products and one high-rank error term. We can compute the sum of the rank one products in time $O(d_{\text{vocab}}^2)$ and express EQKE as, say, $\text{EQKE}_2 + E_{q,2}^{\perp} Q^{\perp} (E_{k,2}^{\perp} K^{\perp})^T$. Call the second term EQKE_err (Figure 15).
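The simple decomposition of Appendix G.2.4 can be illustrated on a toy matrix with the structure described in footnote 23 (rows nearly identical, entries nearly linear in the key index). In the Python sketch below, a power-iteration rank-one approximation stands in for the SVD, and the residual's row differences are bounded as in Equation (11) but with the Frobenius norm in place of $\sigma_1$: power iteration only estimates $\sigma_1$ from below, whereas $\sigma_1(R) \le \lVert R \rVert_F$ always holds, so the Frobenius version stays sound. All names, sizes, and noise levels are illustrative assumptions.

```python
import math
import random


def rank_one_approx(M, iters=100):
    """sigma * u v^T from power iteration on M (an approximate top SVD triple).

    Soundness does not depend on convergence: the residual is computed exactly,
    so a poor approximation only loosens the bound.
    """
    n_rows, n_cols = len(M), len(M[0])
    v = [1.0] * n_cols
    for _ in range(iters):
        u = [sum(M[r][c] * v[c] for c in range(n_cols)) for r in range(n_rows)]
        norm_u = math.sqrt(sum(x * x for x in u))
        u = [x / norm_u for x in u]
        v = [sum(M[r][c] * u[r] for r in range(n_rows)) for c in range(n_cols)]
        sigma = math.sqrt(sum(x * x for x in v))
        v = [x / sigma for x in v]
    return [[sigma * u[r] * v[c] for c in range(n_cols)] for r in range(n_rows)]


def row_diff_error_bound(M, M1):
    """sqrt(2) * ||M - M1||_F >= sqrt(2) * sigma_1(M - M1) >= any |R[q][i] - R[q][j]|."""
    frob = math.sqrt(sum((x - y) ** 2 for row, row1 in zip(M, M1) for x, y in zip(row, row1)))
    return math.sqrt(2) * frob


rng = random.Random(0)
M = [[0.5 * t + rng.uniform(-0.05, 0.05) for t in range(10)] for _ in range(8)]
M1 = rank_one_approx(M)
err = row_diff_error_bound(M, M1)
# Largest deviation between a row difference of M and the same difference of M1.
worst = max(abs((M[q][i] - M[q][j]) - (M1[q][i] - M1[q][j]))
            for q in range(8) for i in range(10) for j in range(10))
print(err, worst)
```

The row differences of the rank-one part are computed exactly, and only the residual is pessimized, mirroring the sandwich inequality for $\Delta_{i,j}\text{EQKE}_{t_{\text{query}}}$ above.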
We must now bound, for each $q$ and $m$, the quantity $\max_{i \le m - G} \text{EQKE\_err}[q,i] - \text{EQKE\_err}[q,m]$. How big is this? Even if we relax to $\max_{i,j} \text{EQKE\_err}[q,i] - \text{EQKE\_err}[q,j]$, the maximum such value across all rows is under 1.85 ($1.99 \pm 0.68$ across all seeds). And the rows don't have any particular structure to them; the maximum absolute element of the entire matrix is just barely over 1 ($1.12 \pm 0.40$ across all seeds), so doubling that doesn't give too bad an estimate.

But we somehow need to compute this value without multiplying out the four matrices. One option is to try to use singular value decomposition again. Since $\sigma_1(M) = \sup_x \lVert Mx \rVert / \lVert x \rVert$, considering vectors with one $1$, one $-1$, and zeros elsewhere shows that the maximum difference between elements in a row is upper bounded by $\sqrt{2}\,\sigma_1(M)$. The largest singular value of EQKE_err (Figure 16) is just under 7.6 ($8.4 \pm 2.0$ across all seeds), giving a row-diff bound of about 10.7 ($11.8 \pm 2.8$ across all seeds), which is large but not unusably so. If we perform SVD before multiplying out the matrices (Figure 17), however, their first singular values are about 4, 1.4, 1.4, and 4, giving a product of about 30, which when multiplied by $\sqrt{2}$ is about 43. (Across all seeds, these numbers are $3.79 \pm 0.12$, $1.525 \pm 0.067$, $1.513 \pm 0.073$, and $3.78 \pm 0.12$, giving a product of about $33.1 \pm 2.9$, which when multiplied by $\sqrt{2}$ is about $46.8 \pm 4.2$.) This works because $\sigma_1(AB) \le \sigma_1(A)\sigma_1(B)$, but note that we can do factored SVD without needing to use this technique. This bound is still usable, but pretty big.

Can we use Frobenius? Note that using anything close to this method to drop below $d_{\text{vocab}} d_{\text{model}}^2$ might seem infeasible (it'll eventually turn out not to be). For example, the best bound we know on the largest singular value that can be verified, even in the worst case, in strictly less time than it takes to compute the full SVD is the Frobenius norm, which is defined as $\sqrt{\operatorname{tr}(M^T M)}$, can be computed in $d_{\text{model}} d_{\text{vocab}}$ time, and is equal to the square root of the sum of the squares of the singular values.

[Figure 16 panels: U, Singular Values for EQKE_err, V.]
Figure 16: SVD of EQKE_err for seed 123.

[Figure 17 panels: U, singular values, and V for each of $E_{q,2}^{\perp}$, $Q^{\perp}$, $K^{\perp}$, and $E_{k,2}^{\perp}$.]
Figure 17: SVD of the four component matrices of EQKE_err for seed 123. Matrices look like noise.

While the Frobenius norm of EQKE_err is only about 12 (giving a bound of about 17 on the row-diff), the Frobenius norms of the four multiplicand matrices are a bit over 10, 4, 4, and 10, giving a product of 1932 and a bound of 2732 (!). (Across all seeds, the Frobenius norm of EQKE_err is about $13.1 \pm 1.9$ (giving a bound of about $18.6 \pm 2.7$ on the row-diff), and the Frobenius norms of the four multiplicand matrices are about $9.92 \pm 0.19$, $4.43 \pm 0.01$, $4.361 \pm 0.095$, and $9.85 \pm 0.19$, giving a product of $1888 \pm 99$ and a bound of $2670 \pm 140$.) This is unusably large.

However, we can get a much better bound on the max row-diff of EQKE_err without having to multiply out all four matrices. We can use an approach vaguely similar to the mean+diff trick, as follows. If we want to compute the max row-diff of a product of matrices $AB$, we can compute by Theorem 21
$$\max_{r,i,j} \left((AB)_{r,i} - (AB)_{r,j}\right) \le \max_r \sum_k \lvert A_{r,k} \rvert \max_{i,j} (B_{k,i} - B_{k,j}) \tag{12}$$
or, by combining this approximation with Theorem 18 via Theorem 24, we may compute
$$\max_{r,i,j} \left((AB)_{r,i} - (AB)_{r,j}\right) \le \left(\max_{i,j} \sum_k \mathbb{E}_r A_{r,k} (B_{k,i} - B_{k,j})\right) + \max_r \sum_k \lvert A_{r,k} - \mathbb{E}_r A_{r,k} \rvert \max_{i,j} (B_{k,i} - B_{k,j}),$$
taking whichever bound is better.

The first gives us a bound of 7.94 on the maximum row-diff, which is better than we can get by doing SVD on the product of the matrices! We can get an even better bound by peeling off the first two singular values of all four matrices before multiplying them; this gives us a bound of 5.67. Combining it with the mean+diff trick wouldn't give us much (8.05 and 5.66, respectively), as we've effectively already done this by peeling off the leading singular contributions; the mean of EQKE_err over dimension zero has norm 0.025 ($0.030 \pm 0.012$ across all seeds).

Although this error bound is no longer the leading asymptotic bottleneck, we can peek ahead to what we get if we want to be linear in parameter count. In this case, we can apply the recursive version of Equation 12 via Theorem 23, giving a bound of 97.06 on the maximum row-diff.

The mechanistic understanding we get here is roughly “for any given basis vector of the residual stream, the difference between the overlap of any two input tokens with this direction is small once we factor out the first two singular components”, and this is sufficient to drive a low error term overall if we factor out the leading singular components in other places. We don't mechanistically understand how to combine the $E_q Q K^T$ (without multiplying them out) in a way that allows getting a good bound, though, which corresponds to our inability to drop below $d_{\text{vocab}} d_{\text{model}}^2$ here.

If we use this trick on QK only, and use the mean+diff trick on final attention handling (without which we lose about 19%), we can achieve a bound of 0.7840 ($0.661 \pm 0.035$ across all seeds). If we use this trick on the skip connection (EU) only, we can achieve a bound of 0.6768 ($0.632 \pm 0.061$ across all seeds). Using this trick on both EU and QK drops us down only to 0.6354 ($0.601 \pm 0.060$ across all seeds).
If we use this trick on EU and the recursive version of this trick on QK, we get a bound of 0.2927 (0.281 ± 0.036 across all seeds). Unfortunately, it is not clear how this trick would apply to EVOU. A fancier convex-hull-checking algorithm seems to be required, and an analysis thereof is in progress.

G.3 The algorithm

We now put all of these tricks together into the subcubic algorithm Algorithm 7, which is the full version of Algorithm 6. The version given here is parameterized over three choices:
• the summarization strategy (from Theorem 19 in Appendix G.1),
• the decomposition of EQKE, and
• the handling of EQKE_err and EU.

Algorithm 7: Counting Correct Sequences in Subcubic Time
1: function MODEL-BEHAVIOR-RELAXED-OVER-GAP(M, t_max, t_query, c, g, g*)
   Ensure: CORRECTNESS-PESSIMIZING-OVER-GAP-SLOW is False ⟹ result is False
   Require: 0 ≤ g* ≤ g ≤ t_max
   Require: if c = 0 then t_query = t_max
2:   skip-score_{t*} ← SUMMARIZE_{EU, t_query}(ℓ_EU(t_query)_{t*})   ▷ Cache by t*
3:   skip-score ← max_{t*} ℓ_EU(t_query)_{t*} − min_{t*} ℓ_EU(t_query)_{t*}   ▷ Cache by t_query
4:   v_t ← EVOU(t)
5:   w_i ← PVOU(i)
6:   Δw_{max,t*} ← max_i (w_{i,t*} − w_{i,t_max})   ▷ Cache by t_max, t*
7:   Δw_{max,max} ← max_{t*} Δw_{max,t*}   ▷ Cache by t_max
8:   Δv_t ← max_{t*} v_{t,t*} − min_{t*} v_{t,t*}   ▷ Cache by t
9:   Δv_max ← max_{0 ≤ t ≤ t_max − g*} Δv_t   ▷ Cache by t_max − g*
10:  Δv^{t_max}_{t*} ← v_{t_max,t*} − v_{t_max,t_max}   ▷ Cache by t_max
11:  Δv^{t_max}_max ← max_{t* ≠ t_max} Δv^{t_max}_{t*}   ▷ Cache by t_max
12:  if c = 0 then
13:    ℓ_{t*} ← ℓ_EU(t_max)_{t*} + v_{t_max,t*} + Δw_{max,t*}
14:    return max_{t* ≠ t_max} (ℓ_{t*} − ℓ_{t_max})
15:  end if
16:  b_{:,n_ctx−1} ← EQKP(t_query, n_ctx − 1)   ▷ Cache by t_query
17:  b_{0,:−1} ← SORT(EQKP(t_query, :−1))   ▷ Cache by t_query, i
18:  b_{1,:−1} ← REVERSE(b_{0,:−1})
19:  EQKE^(1), EQKE_err ← DECOMPOSE(EQKE)
     Require: EQKE^(1)(t_query, t) − EQKE^(1)(t_query, t_max) − EQKE_err_{t_query} ≤ EQKE(t_query, t) − EQKE(t_query, t_max) ≤ EQKE^(1)(t_query, t) − EQKE^(1)(t_query, t_max) + EQKE_err_{t_query}
20:  a_t ← EQKE^(1)(t_query, t)   ▷ Cache by t_query, t
21:  a_{min,t} ← min_{0 ≤ t′ ≤ t} a_{t′}   ▷ Cache by t_query, t; compute in amortized O(d_vocab²)
22:  a_{max,t} ← max_{0 ≤ t′ ≤ t} a_{t′}   ▷ Cache by t_query, t; compute in amortized O(d_vocab²)
23:  Δa_max ← a_{t_max} − a_{min,t_max−g} + EQKE_err_{t_query}   ▷ Cache by t_query, t_max, c
24:  Δa_min ← a_{t_max} − a_{max,t_max−g} − EQKE_err_{t_query}   ▷ Cache by t_query, t_max, c
25:  idx-set ← {0, …, n_ctx − c − 1} if t_max ≠ t_query else {0, …, n_ctx − c − 2, n_ctx − 1}
26:  attn-weights-unscaled_{0,i} ← b_{0,i} + (Δa_min if i ∈ idx-set else 0)
27:  attn-weights-unscaled_{1,i} ← b_{1,i} + (Δa_max if i ∈ idx-set else 0)   ▷ Cache by t_query, t_max, i, c
28:  attn-weights_0 ← SOFTMAX(attn-weights-unscaled_0 / √d)   ▷ Cache by t_query, t_max, i, c
29:  attn-weights_1 ← SOFTMAX(attn-weights-unscaled_1 / √d)   ▷ Cache by t_query, t_max, i, c
30:  attn-max_0 ← Σ_{i ∈ idx-set} attn-weights_{0,i}
31:  attn-max_1 ← Σ_{i ∈ idx-set} attn-weights_{1,i}
32:  attn-max ← attn-max_0 if Δv^{t_max}_max ≥ Δv_max else attn-max_1
33:  $\overline{\text{attn-max}}$ ← SUMMARIZE_{attn, t_query}(attn-max)   ▷ Cache by t_max, c
34:  attn-max′ ← attn-max − $\overline{\text{attn-max}}$
35:  summary_{t*} ← Δw_{max,t*} + skip-score_{t*} + $\overline{\text{attn-max}}$ · Δv^{t_max}_{t*} + (1 − $\overline{\text{attn-max}}$) · Δv_max   ▷ Cache by t_max, t*
36:  return skip-score + attn-max′ · Δv^{t_max}_max + (−attn-max′) · Δv_max + max_{t* ≠ t_max} summary_{t*}
37: end function

H Comparison of proof strategies

In this section, we compare the various proof strategies developed in Appendix G. In Appendix H.1, we do some traditional mechanistic interpretability analysis. It justifies that the choices we made could be expected to lead to reasonably good bounds. In Appendix H.2, we compare the complexities and performance of the various proof strategies, lining up with the legends of Figures 3 and 4. We close with a figure relating the various categories of proof strategies.

H.1 Justification of pessimization choices

In Sections 4.3, F, and G, we make a number of choices about which axes of variation are more or less important to track at various points in the bound computation.
Here we do some more traditional mechanistic interpretability analysis to justify that the choices we made could be expected to lead to reasonably good bounds.

H.1.1 Justifying the gap

We take advantage of two facts. First, attention is mostly monotonically increasing in input integers. Second, for most sequences, the attentional contribution of the particular query token matters much more than that of the particular non-max token in the sequence. We justify this as follows.

We can look at the typical difference, when attending to the max token, between the largest non-max logit and the max logit. As shown in Figure 18a, the largest difference between an off-diagonal entry of EVOU and the diagonal entry of that row is typically at most −7.[24] The typical worst contribution to the wrong logit from a non-max token is around 43, as shown in Figure 18b. This is typical over non-max tokens and worst over the choice of output logit index.

The difference in attention between tokens is approximately linear in the gap between the tokens, as seen in Figure 19. The slope of the line is approximately 1.2; this slope is the difference in pre-softmax attention scores divided by the gap between the key token and the max token. Exponentiating, the post-softmax attention paid to the max is typically about 3× larger than that paid to the token one below the max ($e^{1.2} \approx 3.3$). Here the logit difference between the max and non-max token is significant, typically around 13 (about 43/3.3) for the worst output logit. But by the time the gap is 3, this difference has dropped to about 1.1, and by the time the gap is 4 it is around 0.3.

[24] "Typically" here means about 96% of the time.

Figure 18: Plots of the difference in logit for the attention computation, EVOU := $\bar{E}VOU$, for seed 123. (a) The attention computation weighted by the number of sequences with the particular max (x-axis: logit − diag; y-axis: count × # sequences with given max). The computation is $\max_{j \ne i} \mathrm{EVOU}_{i,j} - \mathrm{EVOU}_{i,i}$. μ ± σ: −9.9 ± 2.1; range: (−0 ± 120) × 10⁻¹. (b) Histogram of the maximum difference between two logits contributed by a single row of EVOU (x-axis: logit diff; y-axis: count). The computation is, for each $i$, $\max_j \mathrm{EVOU}_{i,j} - \min_j \mathrm{EVOU}_{i,j}$. μ ± σ: 43.4 ± 9.0; range: 52 ± 20.

Figure 19: Plots of attention difference vs. token gap, for EQKE := $E_q Q K^T \bar{E}^T$, for seed 123. (a) $(\mathrm{EQKE}_i - \mathrm{EQKE}_j)/\sqrt{d}$ vs. $i - j$ (x-axis: token gap; y-axis: attention difference). (b) $(\mathrm{EQKE}_i - \mathrm{EQKE}_j)/(\sqrt{d}\,(i - j))$, weighted by sequence count; μ ± σ = 1.22 ± 0.13. The difference in attention between tokens is approximately linear in the gap between the tokens.

Figure 20: Histogram of the minimum gap between the max token and the largest non-max token, for seed 123 (x-axis: gap; y-axis: count × # sequences).

So when the largest non-max and the max are close together, the particular structure of the non-max EVOU matters a lot. When the max is separated from the largest non-max by a modest gap, that structure does not matter so much. The upshot is that to handle most sequences, we need only ask an oracle for the minimum gap $g > 0$ between the max token $t_\text{max}$ and the largest non-max token $t' \ne t_\text{max}$. The gap must be such that the model outputs the correct answer for all sequences where the non-max, non-query tokens have value at most $t_\text{max} - g$. Computing this gap may be expensive. Indeed, the naïve computation of the oracle takes longer than the brute-force proof, though it should be very easy to optimize. However, we don't have to pay the cost of computing the gap in the size of the proof. We pay only the cost of storing the gap table ($O(d_\text{vocab}^2 n_\text{ctx})$) and of verifying the gap.
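The decay figures quoted above for gaps of 1 to 4 follow from simple arithmetic. The sketch below assumes a simplification: the worst non-max logit contribution (about 43) is scaled down by the attention ratio $e^{1.2 \cdot \text{gap}}$ between the max token and a token `gap` positions below it. Other tokens in the softmax are ignored. The constant and function names are hypothetical.

```python
import math

SLOPE = 1.2         # pre-softmax attention gained per unit of token gap (Figure 19)
WORST_LOGIT = 43.0  # typical worst logit contribution of a non-max token (Figure 18b)

def damped_contribution(gap):
    """Worst-case logit contribution of a token `gap` below the max, scaled by its
    attention relative to the max token, exp(-SLOPE * gap). Ignores other tokens."""
    return WORST_LOGIT * math.exp(-SLOPE * gap)

effects = {gap: round(damped_contribution(gap), 2) for gap in range(1, 5)}
print(effects)  # {1: 12.95, 2: 3.9, 3: 1.17, 4: 0.35}
```

This gives roughly 13, 3.9, 1.2, and 0.35 for gaps 1 through 4. These match the values quoted in the text up to rounding.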
Empirically, gaps are typically 1–5, as seen in Figure 20. If we rely on the gaps, we leave behind about 6.9% of sequences.

Picking up more sequences. In this paragraph and the bulleted list below, we sketch how we might pick up more sequences to get a tighter bound. This is not implemented and is left as future work. We propose computing the following quantities:
• First, we could build, in time $O(d_\text{vocab}^2)$, a table indexed on pairs $(t, t_\text{max})$ of a non-maximum token and the maximum token. The table would store pessimal logit contributions from $t$ to maximum output tokens $\le t_\text{max}$. It could be further split to pessimize separately for tokens within and outside of the gap window.
• Compute, in time $O(d_\text{vocab}^2)$, a table of pre-softmax attention differences between tokens $t$ and $t + 1$.
• Next, sort the queries by overlap with the query direction.
• Compute, for each number of queries handled and for each maximum input token $t_\text{max}$, how many of the query tokens $t_\text{query}$ fall strictly below $t_\text{max}$. Here we assume we handle all queries with greater overlap than the current one. Also record whether or not the model succeeds when $t_\text{max} = t_\text{query}$. This tells us how many query tokens we can count for a given maximum token.
• Compute a table, indexed on pairs of the number of queries handled and input tokens $t$. It stores the smallest amount by which the attention paid to $t + 1$ exceeds that paid to $t$ ($O(d_\text{vocab}^2)$).
• Compute a table, indexed on pairs $(t_\text{max}, t)$. It stores an upper bound on how much more attention Oracle-permitted query tokens pay to non-maximum tokens than to $t_\text{max}$ ($O(d_\text{vocab}^2)$). The Oracle is indexed only on $t_\text{max}$.
• For each number of queries permitted, compute for each $t_\text{max}, t, c$ whether the non-maximum token $t$ contributes little enough to incorrect logits. It contributes little enough if, even with the worst skip connection, the model still gets the correct answer.

Figure 21: The distribution of entries of the four residual matrices, for seed 123 (x-axis: matrix element value; y-axis: count). The matrices are taken after removing two principal components from $E_q$ and $\bar{E}$ and one principal component from $Q$ and $K$. Fitted normals: $E^\perp_{q,2}$: N(−0.00, 0.22); $Q^\perp$: N(−0.01, 0.14); $K^\perp$: N(−0.00, 0.14); $E^\perp_{k,2}$: N(0.00, 0.22). Distributions look pretty close to normal.

Figure 22: Recreations of Figure 3 for ease of viewing of the legend. Top is a strict recreation; bottom includes points not on the Pareto frontier. Axes: FLOPs to verify proof (approximate, $2^{25}$ to $2^{43}$) vs. normalized accuracy bound. Legend:
• brute force (acc: 0.9992 ± 0.0015)
• cubic (rel acc: 0.9539 ± 0.0080)
• subcubic (rel acc: 0.821 ± 0.013)
• attention-$d_\text{vocab} d_\text{model}^2$ (rel acc: 0.795 ± 0.014)
• direct-quadratic (rel acc: 0.653 ± 0.060)
• attention-$d_\text{vocab} d_\text{model}^2$, direct-quadratic (rel acc: 0.628 ± 0.060)
• attention-quadratic (rel acc: 0.390 ± 0.032)
• attention-quadratic, direct-quadratic (rel acc: 0.286 ± 0.036)

H.1.2 Stopping after 1–2 principal components of QK

Did we miss out on any structure in the error term of EQKE? The distribution of entries of the four matrices looks pretty close to normal, as seen in Figure 21.
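One way to check for leftover structure is a null model. Resample the entries of each residual matrix, recompute the max row-diff of their product, and see where the observed value falls. Below is a minimal NumPy sketch. The entry scales are those fitted in Figure 21. The matrix shapes, sample count, seeds, helper names, and the stand-in `mats` are hypothetical, because the trained model's residual matrices are not reproduced here.

```python
import numpy as np

def max_row_diff(M):
    return float(np.max(M.max(axis=1) - M.min(axis=1)))

def product_max_row_diff(E_q, Q, K, E_k):
    """Max row-diff of E_q Q K^T E_k^T, the shape of the EQKE_err product."""
    return max_row_diff(E_q @ Q @ K.T @ E_k.T)

def null_distribution(mats, n_samples=100, mode="normal", seed=0):
    """Max row-diff of the product after replacing each matrix's entries with
    random values. mode="normal": i.i.d. normal matching that matrix's mean and
    std. mode="permute": a random permutation of its own entries, i.e. sampling
    without replacement from the empirical distribution."""
    if mode not in ("normal", "permute"):
        raise ValueError(f"unknown mode: {mode}")
    rng = np.random.default_rng(seed)
    samples = []
    for _ in range(n_samples):
        if mode == "normal":
            fake = [rng.normal(M.mean(), M.std(), size=M.shape) for M in mats]
        else:
            fake = [rng.permutation(M.ravel()).reshape(M.shape) for M in mats]
        samples.append(product_max_row_diff(*fake))
    return np.array(samples)

# Hypothetical stand-ins with the Figure 21 entry scales; substitute the trained
# model's residual matrices to run the actual check.
rng = np.random.default_rng(123)
d_vocab, d_model = 64, 32
mats = [rng.normal(0, 0.22, (d_vocab, d_model)),   # E_q residual
        rng.normal(0, 0.14, (d_model, d_model)),   # Q residual
        rng.normal(0, 0.14, (d_model, d_model)),   # K residual
        rng.normal(0, 0.22, (d_vocab, d_model))]   # E_k residual

observed = product_max_row_diff(*mats)
for mode in ("permute", "normal"):
    null = null_distribution(mats, mode=mode)
    z = (observed - null.mean()) / null.std()
    print(f"{mode}: null {null.mean():.3f} ± {null.std():.3f}, "
          f"observed {observed:.3f}, z = {z:.1f}")
```

With i.i.d. stand-ins like these, the observed value is just an ordinary draw from the null and should typically land within a couple of σ. With the trained model's matrices, the text reports a value about 4σ high.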
We replaced the entries of $E^\perp_{q,2}$, $E^\perp_{k,2}$, $Q^\perp$, and $K^\perp$ with randomly sampled values (sample size 100). The maximum row-diff of the product of the matrices is then approximately 1.31 ± 0.13 when sampling without replacement from the empirical distribution. When sampling from the normal distribution, it is 1.31 ± 0.14. So our actual max row-diff is unusually high (by about 4σ).[25]

[25] This shows up in the bias towards having larger values (both positive and negative) in the lower-right corner of the plot, indicating that errors are larger for larger query and key values. We hypothesize that this is due to the distribution of data. Larger values are more likely to have more space between the maximum and next-most-maximum token, so a bit of noise matters less for larger maxes than for smaller ones.

H.2 How various combinations of tricks perform

Recall Figures 3 and 4 (on pages 8 and 9) from Section 5, recapitulated here without captions for convenience as Figures 22 and 23. We describe what each subcubic proof strategy in the legend means. All subcubic proof strategies (that is, all proof strategies except "brute force" and "cubic") use the quadratic counting algorithm of Appendix F.

Figure 23: Recreation of Figure 4 for ease of viewing of the legend. Axes: EPQKE singular ratio σ₁/σ₂ (roughly 400 to 1000) vs. normalized accuracy bound. Legend:
• max-diff-exact
• mean+max-diff-subproduct
• max-diff-subproduct
• max-diff
• mean+max-diff
• svd
• mean+max-diff-subproduct-recursive
• max-diff-subproduct-recursive
• mean-recursive+max-diff-subproduct-recursive

H.2.1 Proof strategies grouped by complexity

In Figures 3 and 22, proof strategies are grouped by computational complexity. The 102 proof strategies break down into 1 + 1 + 5 × 10 × 2 strategies. The brute force and cubic proofs (1 + 1) were fully covered in Appendices D and E.

There are 5 options for handling EU:
• direct-quadratic refers to handling EU in time $O(d_\text{vocab} d_\text{model})$. This uses either the max row-diff trick (Appendix G.2.2)[26] or the max row-diff trick fused with mean+diff or some other summary statistic (Theorem 24).[27]
• When direct is not mentioned, we handle EU in time $O(d_\text{vocab}^2 d_\text{model})$ by first multiplying out $E_q U$. We then take either the maximum row-diff in each row[28] or the maximum row-diff across all rows.[29] The latter is included purely for comparison's sake and never gives a tighter bound than the former.

[26] This strategy is labeled "max_diff" in the Python source code.
[27] These strategies are labeled "mean_query+max_diff" and "svd_query+max_diff" in the Python source code.
[28] "max_diff_exact"
[29] "global_max_diff_exact"

There are 10 options for handling the high-rank attention error term EQKE_err:
• attention-quadratic refers to handling EQKE_err from Appendix G.2.5 in time $O(d_\text{vocab} d_\text{model})$. This uses either the recursive max row-diff trick (Theorem 23)[30] or that trick fused with the mean+diff trick (Theorem 26). The fusion is applied either just on the query side[31] or throughout.[32]
• attention-$d_\text{vocab} d_\text{model}^2$ indicates one of the various $O(d_\text{vocab} d_\text{model}^2)$ strategies for handling EQKE_err_1 from Appendix G.2.4 or EQKE_err from Appendix G.2.5. These strategies are:
  – using $\sqrt{2}\sigma_1$, computed via low-rank SVD, as the bound (Equation 11);[33]
  – considering all ways of multiplying out a subset of the matrices and taking the maximum row-diff of the resulting pair of matrices (Theorem 21);[34]
  – fusing the max row-diff trick with the mean+diff trick (Theorem 24).[35]
• When attention is not mentioned, we handle the attention error term in time $O(d_\text{vocab}^2 d_\text{model})$. We either take the per-row maximum row-diff[36] or use the full-rank EQKE matrix and take the per-row maximum row-diff.[37]
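The attention-quadratic strategies rely on applying Equation 12 recursively. The statement of Theorem 23 is not reproduced in this section, so the sketch below assumes the natural recursion. First, take the exact row ranges of the rightmost factor. Then propagate them leftwards through $|A|$, one matrix-vector product per factor. This costs time linear in the number of parameters. The function name, sizes, and stand-in matrices are hypothetical.

```python
import numpy as np

def recursive_max_row_diff_bound(*factors):
    """Upper-bound max_{r,i,j} (P[r,i] - P[r,j]) for P = factors[0] @ ... @ factors[-1]
    without forming P: take the exact per-row ranges of the last factor, then push
    them leftwards through |A| (Equation 12 applied once per factor)."""
    last = factors[-1]
    ranges = last.max(axis=1) - last.min(axis=1)
    for A in reversed(factors[:-1]):
        ranges = np.abs(A) @ ranges  # one matrix-vector product per factor
    return float(ranges.max())

# Hypothetical stand-ins for the factors of EQKE_err.
rng = np.random.default_rng(1)
d_vocab, d_model = 64, 32
E_q = rng.normal(0, 0.22, (d_vocab, d_model))
Q = rng.normal(0, 0.14, (d_model, d_model))
K = rng.normal(0, 0.14, (d_model, d_model))
E_k = rng.normal(0, 0.22, (d_vocab, d_model))

P = E_q @ Q @ K.T @ E_k.T
exact = float(np.max(P.max(axis=1) - P.min(axis=1)))
one_step = recursive_max_row_diff_bound(E_q, Q @ K.T @ E_k.T)   # plain Equation 12
fully_recursive = recursive_max_row_diff_bound(E_q, Q, K.T, E_k.T)
print(f"exact {exact:.3f} <= one step {one_step:.3f} <= fully recursive {fully_recursive:.3f}")
```

Each recursion step can only loosen the bound, since $|A|$ is entrywise nonnegative. This is consistent with the fully recursive bound for seed 123 (97.06) being much looser than the bound obtained after multiplying out some factors (7.94).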
Finally, in combining the rank-one attention computation with EVOU, PVOU, and EU, we may either use the mean+diff trick[38] (Appendix G.1) or not.[39] This makes up the final factor of 2.

[30] "max_diff_subproduct_recursive"
[31] "mean+max_diff_subproduct_recursive"
[32] "mean_recursive+max_diff_subproduct_recursive"
[33] "svd"
[34] "max_diff" for EQKE_err_1, "max_diff_subproduct" for EQKE_err
[35] "mean+max_diff" for EQKE_err_1, "mean+max_diff_subproduct" for EQKE_err
[36] "max_diff_exact"
[37] "exact_EQKE+max_diff_exact"
[38] "mean_query+diff"
[39] "drop_average_query_per_output_logit_reasoning"

H.2.2 Proof strategies grouped by attention handling

This section slightly reorganizes the information in Appendix H.2.1 so that it corresponds to the legend. Here we group by the strategy used to handle the attention error term. Strategies that use the full-rank EQKE matrix are elided. The dashed descriptors here correspond to the underscore-joined descriptors in the footnotes of Appendix H.2.1.

max-diff-exact ($O(d_\text{vocab}^2 d_\text{model})$) corresponds to taking the full-rank EQKE_err_1 term and taking the maximum row-diff in each row.

mean+max-diff-subproduct ($O(d_\text{vocab} d_\text{model})$) corresponds to fusing the max row-diff trick with the mean+diff trick (Theorem 24) and considering all ways of associating the multiplication of EQKE_err.

max-diff-subproduct ($O(d_\text{vocab} d_\text{model})$) corresponds to using the max row-diff trick (Theorem 21) and considering all ways of associating the multiplication of EQKE_err.

max-diff ($O(d_\text{vocab} d_\text{model}^2)$) corresponds to using the max row-diff trick (Theorem 21) on the factored SVD of EQKE_err_1.

mean+max-diff ($O(d_\text{vocab} d_\text{model}^2)$) corresponds to fusing the max row-diff trick with the mean+diff trick (Theorem 24) and applying it to the factored SVD of EQKE_err_1.

svd ($O(d_\text{vocab} d_\text{model}^2)$) corresponds to using $\sqrt{2}\sigma_1$, computed via low-rank SVD, as the bound (Equation 11).
mean+max-diff-subproduct-recursive ($O(d_\text{vocab} d_\text{model})$) corresponds to handling the high-rank attention error term EQKE_err from Appendix G.2.5 with the recursive max row-diff trick fused with the mean+diff trick on the query side only (Theorem 26, taking all but the first summary vector to be zero).

max-diff-subproduct-recursive ($O(d_\text{vocab} d_\text{model})$) corresponds to handling the high-rank attention error term EQKE_err from Appendix G.2.5 with the recursive max row-diff trick (Theorem 23).

mean-recursive+max-diff-subproduct-recursive ($O(d_\text{vocab} d_\text{model})$) corresponds to handling the high-rank attention error term EQKE_err from Appendix G.2.5 with the recursive max row-diff trick recursively fused with the mean+diff trick (Theorem 26).

H.2.3 What understanding do we get from each proof strategy?

Throughout most of this paper, we discuss doing mechanistic interpretability and using the resulting understanding to make compact proofs with tighter bounds. We can also look at the reverse problem. We take a collection of proof strategies, check by brute force which strategies give the tightest bounds for each model, and ask what this implies about how that model works. We do this here.

In general, which proof methods perform best indicates where structure exists in the model. For example, consider quadratic EU proofs where max_diff performs worse than mean_query+max_diff and svd_query+max_diff. This indicates that E has a relatively strong behavioral component shared across query tokens that U is not good at filtering out. Similarly, suppose mean_recursive+max_diff_subproduct_recursive performs better than max_diff_subproduct_recursive. This indicates that common structure remains even after removing the first one or two principal components from $E_q$, $Q$, $K$, and $\bar{E}$, enough that it is worth factoring out the mean behavior.
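The diagnostic reading of `max_diff` versus `mean_query+max_diff` can be illustrated on a stand-in. Below is a minimal sketch that compares Equation 12 with its mean+diff variant (Theorem 24). It assumes an E whose rows share one large common vector, and a random U. Sizes, scales, seed, and helper names are hypothetical; this is not the paper's code.

```python
import numpy as np

def max_diff_bound(A, B):
    """Equation 12: max_r sum_k |A[r, k]| * range(B[k, :])."""
    ranges = B.max(axis=1) - B.min(axis=1)
    return float(np.max(np.abs(A) @ ranges))

def mean_plus_max_diff_bound(A, B):
    """Mean+diff variant: handle the row-mean of A exactly, and bound only the
    deviations A - mean(A) with Equation 12."""
    mean_row = A.mean(axis=0)              # E_r A[r, k]
    v = mean_row @ B                       # contribution of the shared component
    shared = float(v.max() - v.min())      # max_{i,j} (v[i] - v[j])
    return shared + max_diff_bound(A - mean_row, B)

def exact_max_row_diff(M):
    return float(np.max(M.max(axis=1) - M.min(axis=1)))

rng = np.random.default_rng(2)
d_vocab, d_model = 64, 32
shared_direction = rng.normal(0, 1.0, d_model)
E = shared_direction + rng.normal(0, 0.1, (d_vocab, d_model))  # strong shared component
U = rng.normal(0, 0.2, (d_model, d_vocab))

exact = exact_max_row_diff(E @ U)
plain = max_diff_bound(E, U)
fused = mean_plus_max_diff_bound(E, U)
print(f"exact {exact:.2f}, max_diff {plain:.2f}, mean+max_diff {fused:.2f}")
```

The shared row component inflates every $|A_{r,k}|$ in Equation 12. The fused bound handles that component exactly, so it comes out several times tighter here. For rows with no shared component, the two bounds would be roughly equal.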