
Paper deep dive

AtP*: An efficient and scalable method for localizing LLM behaviour to components

János Kramár, Tom Lieberum, Rohin Shah, Neel Nanda

Year: 2024 · Venue: arXiv preprint · Area: Mechanistic Interp. · Type: Empirical · Embeddings: 251

Models: Pythia-12B, Pythia-1B, Pythia-2.8B, Pythia-410M, Pythia-6.9B

Intelligence

Status: succeeded | Model: google/gemini-3.1-flash-lite-preview | Prompt: intel-v1 | Confidence: 94%

Last extracted: 3/12/2026, 8:09:54 PM

Summary

The paper introduces AtP*, an improved, scalable method for localizing Large Language Model (LLM) behavior to specific components. It addresses failure modes in the existing Attribution Patching (AtP) method, specifically false negatives caused by attention saturation and by cancellation between direct and indirect effects, by recomputing the attention softmax when patching queries and keys and by applying dropout to the backward pass (GradDrop). AtP* demonstrates superior performance in identifying causally important nodes compared with the other approximation methods investigated.

Entities (5)

Activation Patching · method · 100%
AtP* · method · 100%
Attribution Patching · method · 100%
LLMs · technology · 95%
Mechanistic Interpretability · field · 95%

Relation Signals (3)

Attribution Patching approximates Activation Patching

confidence 95% · Attribution Patching (AtP) (Nanda, 2022), a fast gradient-based approximation to Activation Patching

AtP* improves Attribution Patching

confidence 95% · We propose a variant of AtP called AtP*, with two changes to address these failure modes while retaining scalability.

AtP* localizes LLM behavior

confidence 90% · An efficient and scalable method for localizing LLM behaviour to components

Cypher Suggestions (2)

Identify improvements made to Attribution Patching · confidence 90% · unvalidated

MATCH (a:Method {name: 'AtP*'})-[:IMPROVES]->(b:Method {name: 'Attribution Patching'}) RETURN a, b

Find all methods related to causal attribution in LLMs · confidence 80% · unvalidated

MATCH (m:Method)-[:USED_FOR]->(a:Task {name: 'Causal Attribution'}) RETURN m.name

Abstract

Activation Patching is a method of directly computing causal attributions of behavior to model components. However, applying it exhaustively requires a sweep with cost scaling linearly in the number of model components, which can be prohibitively expensive for SoTA Large Language Models (LLMs). We investigate Attribution Patching (AtP), a fast gradient-based approximation to Activation Patching and find two classes of failure modes of AtP which lead to significant false negatives. We propose a variant of AtP called AtP*, with two changes to address these failure modes while retaining scalability. We present the first systematic study of AtP and alternative methods for faster activation patching and show that AtP significantly outperforms all other investigated methods, with AtP* providing further significant improvement. Finally, we provide a method to bound the probability of remaining false negatives of AtP* estimates.

Tags

ai-safety (imported, 100%) · empirical (suggested, 88%) · mechanistic-interp (suggested, 92%)


Full Text

250,353 characters extracted from source content.


AtP*: An efficient and scalable method for localizing LLM behaviour to components

János Kramár, Tom Lieberum, Rohin Shah, Neel Nanda (Google DeepMind)
janosk@google.com

Abstract

Activation Patching is a method of directly computing causal attributions of behavior to model components. However, applying it exhaustively requires a sweep with cost scaling linearly in the number of model components, which can be prohibitively expensive for SoTA Large Language Models (LLMs). We investigate Attribution Patching (AtP) (Nanda, 2022), a fast gradient-based approximation to Activation Patching and find two classes of failure modes of AtP which lead to significant false negatives. We propose a variant of AtP called AtP*, with two changes to address these failure modes while retaining scalability. We present the first systematic study of AtP and alternative methods for faster activation patching and show that AtP significantly outperforms all other investigated methods, with AtP* providing further significant improvement. Finally, we provide a method to bound the probability of remaining false negatives of AtP* estimates.

1 Introduction

As LLMs become ubiquitous and integrated into numerous digital applications, it's an increasingly pressing research problem to understand the internal mechanisms that underlie their behaviour – this is the problem of mechanistic interpretability. A fundamental subproblem is to causally attribute particular behaviours to individual parts of the transformer forward pass, corresponding to specific components (such as attention heads, neurons, layer contributions, or residual streams), often at specific positions in the input token sequence.
This is important because, in numerous case studies, complex behaviours have been found to be driven by sparse subgraphs within the model (Olsson et al., 2022; Wang et al., 2022; Meng et al., 2023).

A classic form of causal attribution uses zero-ablation, or knock-out, where a component is deleted and we see if this negatively affects a model's output – a negative effect implies the component was causally important. More recent work has generalised this to replacing a component's activations with samples from some baseline distribution (with zero-ablation being a special case where activations are resampled to be zero). We focus on the popular and widely used method of Activation Patching (also known as causal mediation analysis) (Geiger et al., 2022; Meng et al., 2023; Chan et al., 2022), where the baseline distribution is a component's activations on some corrupted input, such as an alternate string with a different answer (Pearl, 2001; Robins and Greenland, 1992).

Given a causal attribution method, it is common to sweep across all model components, directly evaluating the effect of intervening on each of them via resampling (Meng et al., 2023). However, when working with SoTA models it can be expensive to attribute behaviour, especially to small components (e.g. heads or neurons): each intervention requires a separate forward pass, and so the number of forward passes can easily climb into the millions or billions. For example, on a prompt of length 1024, there are $2.7 \cdot 10^9$ neuron nodes in Chinchilla 70B (Hoffmann et al., 2022).

We propose to accelerate this process by using Attribution Patching (AtP) (Nanda, 2022), a faster, approximate, causal attribution method, as a prefiltering step: after running AtP, we iterate through the nodes in decreasing order of absolute value of the AtP estimate, then use Activation Patching to more reliably evaluate these nodes and filter out false positives – we call this verification.
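The prefilter-then-verify loop described above can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: `patch_effect` is a hypothetical callback standing in for one activation-patching forward pass, and the fixed `budget` and `threshold` stopping rule is an assumption introduced here for concreteness.

```python
from typing import Callable, List, Sequence, Tuple


def verify_in_atp_order(
    atp_scores: Sequence[float],
    patch_effect: Callable[[int], float],
    budget: int,
    threshold: float,
) -> Tuple[List[Tuple[int, float]], int]:
    """Verify nodes with activation patching, in decreasing order of |AtP estimate|.

    patch_effect(node) is a hypothetical stand-in for one activation-patching forward pass.
    Returns the confirmed (node, true effect) pairs and the number of forward passes spent.
    """
    order = sorted(range(len(atp_scores)), key=lambda i: abs(atp_scores[i]), reverse=True)
    confirmed: List[Tuple[int, float]] = []
    passes = 0
    for node in order[:budget]:
        effect = patch_effect(node)   # one forward pass per verified node
        passes += 1
        if abs(effect) >= threshold:  # drop AtP false positives
            confirmed.append((node, effect))
    return confirmed, passes


atp = [0.9, 0.05, 0.5, 0.01]                       # toy AtP estimates
true_effects = {0: 0.02, 1: 0.8, 2: 0.6, 3: 0.0}   # toy activation-patching results
print(verify_in_atp_order(atp, true_effects.__getitem__, budget=2, threshold=0.1))
# ([(2, 0.6)], 2)
```

With these toy numbers, node 0 is an AtP false positive that verification rejects, while node 1 is an AtP false negative that a budget of two forward passes never reaches; false negatives cannot be repaired by verification alone.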
We typically care about a small set of top contributing nodes, so verification is far cheaper than iterating over all nodes.

Our contributions:

• We investigate the performance of AtP, finding two classes of failure modes which produce false negatives. We propose a variant of AtP called AtP*, with two changes to address these failure modes while retaining scalability:
  – When patching queries and keys, recomputing the attention softmax and using a gradient-based approximation from then on, as gradients are a poor approximation to saturated attention.
  – Using dropout on the backwards pass to fix brittle false negatives, where significant positive and negative effects cancel out.
• We introduce several alternative methods to approximate Activation Patching as baselines to AtP which outperform brute force Activation Patching.
• We present the first systematic study of AtP and these alternatives and show that AtP significantly outperforms all other investigated methods, with AtP* providing further significant improvement.
• To estimate the residual error of AtP* and statistically bound the sizes of any remaining false negatives, we provide a diagnostic method, based on using AtP to filter out high impact nodes, and then patching random subsets of the remainder. Good diagnostics mean that practitioners may still gauge whether AtP is reliable in relevant domains without the costs of exhaustive verification.

Finally, we provide some guidance in Section 5.4 on how to successfully perform causal attribution in practice and what attribution methods are likely to be useful and under what circumstances.

Figure 1 (panels: (a) MLP neurons, on CITY-P; (b) Attention nodes, on IOI-P): Costs of finding the most causally-important nodes in Pythia-12B using different methods, on sample prompt pairs (see Table 1).
The shading indicates geometric standard deviation. Cost is measured in forward passes, thus each point's y-coordinate gives the number of forward passes required to find the top x nodes. Note that each node must be verified, thus $y \geq x$, so all lines are above the diagonal, and an oracle for the verification order would produce the diagonal line. For a detailed description see Section 4.3.

Figure 2 (panels: (a) MLP neurons, on CITY-P; (b) Attention nodes, on IOI-P): Relative costs of methods across models, on sample prompt pairs. The costs are relative to having an oracle, which would verify nodes in decreasing order of true contribution size. Costs are aggregated using an inverse-rank-weighted geometric mean. This means they correspond to the area above the diagonal for each curve in Figure 1 and are relative to the area under the dotted (oracle) line. See Section 4.2 for more details on this metric. Note that GradDrop (the difference between AtP+QKfix and AtP*) comes with a noticeable upfront cost and so looks worse in this comparison while still helping avoid false negatives, as shown in Figure 1.

2 Background

2.1 Problem Statement

Our goal is to identify the contributions to model behavior by individual model components. We first formalize model components, then formalize model behaviour, and finally state the contribution problem in causal language. While we state the formalism in terms of a decoder-only transformer language model (Vaswani et al., 2017; Radford et al., 2018), and conduct all our experiments on models of that class, the formalism is also straightforwardly applicable to other model classes.

Model components. We are given a model $\mathcal{M}: X \to \mathbb{R}^V$ that maps a prompt (token sequence) $x \in X := \{1, \ldots, V\}^T$ to output logits over a set of $V$ tokens, aiming to predict the next token in the sequence.
We will view the model $\mathcal{M}$ as a computational graph $(N, E)$ where the node set $N$ is the set of model components, and a directed edge $e = (n_1, n_2) \in E$ is present iff the output of $n_1$ is a direct input into the computation of $n_2$. We will use $n(x)$ to represent the activation (intermediate computation result) of $n$ when computing $\mathcal{M}(x)$.

The choice of $N$ determines how fine-grained the attribution will be. For example, for transformer models, we could have a relatively coarse-grained attribution where each layer is considered a single node. In this paper we will primarily consider more fine-grained attributions that are more expensive to compute (see Section 4 for details); we revisit this issue in Section 5.

Model behaviour. Following past work (Geiger et al., 2022; Chan et al., 2022; Wang et al., 2022), we assume a distribution $\mathcal{D}$ over pairs of inputs $x^{\text{clean}}, x^{\text{noise}}$, where $x^{\text{clean}}$ is a prompt on which the behaviour occurs, and $x^{\text{noise}}$ is a reference prompt which we use as a source of noise to intervene with.[1] We are also given a metric[2] $\mathcal{L}: \mathbb{R}^V \to \mathbb{R}$, which quantifies the behaviour of interest.

[1] This precludes interventions which use activation values that are never actually realized, such as zero-ablation or mean ablation. An alternative formulation via distributions of activation values is also possible.
[2] Common metrics in language models are next token prediction loss, difference in log prob between a correct and incorrect next token, probability of the correct next token, etc.

Contribution of a component.
Similarly to the work referenced above, we define the contribution $c(n)$ of a node $n$ to the model's behaviour as the counterfactual absolute[3] expected impact of replacing that node on the clean prompt with its value on the reference prompt $x^{\text{noise}}$. Using do-calculus notation (Pearl, 2000) this can be expressed as $c(n) := |\mathcal{I}(n)|$, where

$$\mathcal{I}(n) := \mathbb{E}_{(x^{\text{clean}}, x^{\text{noise}}) \sim \mathcal{D}} \left[ \mathcal{I}(n; x^{\text{clean}}, x^{\text{noise}}) \right], \tag{1}$$

where we define the intervention effect $\mathcal{I}$ for $x^{\text{clean}}, x^{\text{noise}}$ as

$$\mathcal{I}(n; x^{\text{clean}}, x^{\text{noise}}) := \mathcal{L}\big(\mathcal{M}(x^{\text{clean}} \mid \operatorname{do}(n \leftarrow n(x^{\text{noise}})))\big) - \mathcal{L}\big(\mathcal{M}(x^{\text{clean}})\big). \tag{2}$$

Note that the need to average the effect across a distribution adds a potentially large multiplicative factor to the cost of computing $c(n)$, further motivating this work.

[3] The sign of the impact may be of interest, but in this work we'll focus on the magnitude, as a measure of causal importance.

We can also intervene on a set of nodes $\eta = \{n_i\}$. To do so, we overwrite the values of all nodes in $\eta$ with their values from a reference prompt. Abusing notation, we write $\eta(x)$ as the set of activations of the nodes in $\eta$ when computing $\mathcal{M}(x)$.
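To make Equations 1 and 2 and the node-set intervention concrete, here is a toy NumPy sketch. It assumes, unlike a real transformer, that the three nodes are independent leaves feeding a made-up metric, so patching one node does not change any other node; `toy_metric`, `patch_effect` and `contribution` are hypothetical names.

```python
import numpy as np


def toy_metric(acts):
    # Hypothetical stand-in for L(M(x)) computed from three node activations.
    # The product term makes nodes 0 and 1 interact, so effects are not additive.
    return float(acts[0] * acts[1] + acts[2])


def patch_effect(acts_clean, acts_noise, nodes):
    """I(eta; x_clean, x_noise): overwrite the nodes in `nodes` with their noise values."""
    clean = np.asarray(acts_clean, dtype=float)
    patched = clean.copy()
    idx = list(nodes)
    patched[idx] = np.asarray(acts_noise, dtype=float)[idx]
    return toy_metric(patched) - toy_metric(clean)


def contribution(pairs, node):
    """c(n) = |E_D[I(n; x_clean, x_noise)]| (Eq. 1): average the signed effect, then take |.|."""
    return abs(float(np.mean([patch_effect(c, z, [node]) for c, z in pairs])))


pairs = [([1.0, 1.0, 0.0], [0.0, 0.0, 0.0]), ([1.0, -1.0, 0.0], [0.0, 0.0, 0.0])]
print(patch_effect(*pairs[0], [0]), patch_effect(*pairs[0], [0, 1]))  # -1.0 -1.0
print(contribution(pairs, 0))                                         # 0.0
```

Patching node 0 alone changes the metric by -1, and patching nodes 0 and 1 together also gives -1, so single-node effects need not add up. Averaging over the two toy pairs gives $c(0) = 0$ even though each pair has an effect of magnitude 1; this sign cancellation across prompt pairs is what Equation 5 guards against by taking the absolute value inside the expectation.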
$$\mathcal{I}(\eta; x^{\text{clean}}, x^{\text{noise}}) := \mathcal{L}\big(\mathcal{M}(x^{\text{clean}} \mid \operatorname{do}(\eta \leftarrow \eta(x^{\text{noise}})))\big) - \mathcal{L}\big(\mathcal{M}(x^{\text{clean}})\big) \tag{3}$$

We note that it is also valid to define contribution as the expected impact of replacing a node on the reference prompt with its value on the clean prompt, also known as denoising or knock-in. We follow Chan et al. (2022) and Wang et al. (2022) in using noising; however, denoising is also widely used in the literature (Meng et al., 2023; Lieberum et al., 2023). We briefly consider how this choice affects AtP in Section 5.2.

2.2 Attribution Patching

On state-of-the-art models, computing $c(n)$ for all $n$ can be prohibitively expensive, as there may be billions or more nodes. Furthermore, computing this value precisely requires evaluating it on all prompt pairs, so the runtime cost of Equation 1 for each $n$ scales with the size of the support of $\mathcal{D}$. We thus turn to a fast approximation of Equation 1. As suggested by Nanda (2022), Figurnov et al. (2016) and Molchanov et al. (2017), we can make a first-order Taylor expansion of $\mathcal{I}(n; x^{\text{clean}}, x^{\text{noise}})$ around $n(x^{\text{noise}}) \approx n(x^{\text{clean}})$:

$$\hat{\mathcal{I}}_{\text{AtP}}(n; x^{\text{clean}}, x^{\text{noise}}) := \big(n(x^{\text{noise}}) - n(x^{\text{clean}})\big)^\top \left.\frac{\partial \mathcal{L}(\mathcal{M}(x^{\text{clean}}))}{\partial n}\right|_{n = n(x^{\text{clean}})} \tag{4}$$

Then, similarly to Syed et al. (2023), we apply this to a distribution by taking the absolute value inside the expectation in Equation 1 rather than outside; this decreases the chance that estimates across prompt pairs with positive and negative effects might erroneously lead to a significantly smaller estimate. (We briefly explore the amount of cancellation behaviour in the true effect distribution in Section B.2.) As a result, we get an estimate

$$\hat{c}_{\text{AtP}}(n) := \mathbb{E}_{x^{\text{clean}}, x^{\text{noise}}}\left[\left|\hat{\mathcal{I}}_{\text{AtP}}(n; x^{\text{clean}}, x^{\text{noise}})\right|\right]. \tag{5}$$

This procedure is also called Attribution Patching (Nanda, 2022) or AtP. AtP requires two forward passes and one backward pass to compute an estimate score for all nodes on a given prompt pair, and so provides a very significant speedup over brute force activation patching.

3 Methods

We now describe some failure modes of AtP and address them, yielding an improved method AtP*.
We then discuss some alternative methods for estimating $c(n)$, to put AtP(*)'s performance in context. Finally, we discuss how to combine Subsampling, one such alternative method described in Section 3.3, and AtP* to give a diagnostic to statistically test whether AtP* may have missed important false negatives.

3.1 AtP improvements

We identify two common classes of false negatives occurring when using AtP.

The first failure mode occurs when the preactivation on $x^{\text{clean}}$ is in a flat region of the activation function (e.g. produces a saturated attention weight), but the preactivation on $x^{\text{noise}}$ is not in that region. As is apparent from Equation 4, AtP uses a linear approximation to the ground truth in Equation 1, so if the non-linear function is badly approximated by the local gradient, AtP ceases to be accurate – see Figure 3 for an illustration, and Figure 4, which denotes in color the maximal difference in attention observed between prompt pairs, suggesting that this failure mode occurs in practice.

Figure 3: A linear approximation to the attention probability is a particularly poor approximation in cases where one or both of the endpoints are in a saturated region of the softmax. Note that when varying only a single key, the softmax becomes a sigmoid of the dot product of that key and the query.
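The saturation failure mode can be reproduced on a single scalar node. Following the Figure 3 caption, the sketch below treats one key's attention probability as a sigmoid of its query-key dot product $s$ and, as a simplifying assumption not taken from the paper, takes the metric to be linear in that probability, $\mathcal{L} = w \cdot \sigma(s)$. It compares activation patching (Equation 2) with the AtP estimate (Equation 4) and the pair-averaged score of Equation 5; all names are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def true_effect(s_clean, s_noise, w):
    """Activation patching (Eq. 2) for the toy metric L = w * sigmoid(s): re-run the nonlinearity."""
    return w * (sigmoid(s_noise) - sigmoid(s_clean))


def atp_effect(s_clean, s_noise, w):
    """AtP (Eq. 4): (n(x_noise) - n(x_clean)) times dL/dn, with the gradient taken at the clean point."""
    p = sigmoid(s_clean)
    return (s_noise - s_clean) * w * p * (1.0 - p)  # d sigmoid / ds = p * (1 - p)


def atp_score(pairs, w):
    """Eq. 5: mean of |AtP estimate| over (s_clean, s_noise) prompt pairs."""
    return float(np.mean([abs(atp_effect(s_c, s_n, w)) for s_c, s_n in pairs]))


# Unsaturated clean point: the linear estimate is close (about 0.122 vs 0.125).
print(true_effect(0.0, 0.5, 1.0), atp_effect(0.0, 0.5, 1.0))
# Saturated clean point: true change about -0.878, AtP only about -0.020 (a false negative).
print(true_effect(6.0, -2.0, 1.0), atp_effect(6.0, -2.0, 1.0))
```

At $s = 6$ the sigmoid is nearly flat, so the local gradient badly underestimates a change that moves the attention probability from almost 1 to about 0.12.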
Another, unrelated failure mode occurs due to cancellation between direct and indirect effects: roughly, if the total effect (on some prompt pair) is a sum of direct and indirect effects (Pearl, 2001), $\mathcal{I}(n) = \mathcal{I}^{\text{direct}}(n) + \mathcal{I}^{\text{indirect}}(n)$, and these are close to cancelling, then a small multiplicative approximation error in $\hat{\mathcal{I}}^{\text{indirect}}_{\text{AtP}}(n)$, due to non-linearities such as GELU and softmax, can accidentally cause $|\hat{\mathcal{I}}^{\text{direct}}_{\text{AtP}}(n) + \hat{\mathcal{I}}^{\text{indirect}}_{\text{AtP}}(n)|$ to be orders of magnitude smaller than $|\mathcal{I}(n)|$.

3.1.1 False negatives from attention saturation

AtP relies on the gradient at each activation being reflective of the true behaviour of the function with respect to intervention at that activation. In some cases, though, a node may immediately feed into a non-linearity whose effect may not be adequately predicted by the gradient; for example, attention key and query nodes feeding into the attention softmax non-linearity. To showcase this, we plot the true rank of each node's effect against its rank assigned by AtP in Figure 4 (left). The plot shows that there are many pronounced false negatives (below the dashed line), especially among keys and queries.

Normal activation patching for queries and keys involves changing a query or key and then re-running the rest of the model, keeping all else the same. AtP takes a linear approximation to the entire rest of the model rather than re-running it. We propose explicitly re-computing the first step of the rest of the model, i.e. the attention softmax, and then taking a linear approximation to the rest.
Formally, for attention key and query nodes, instead of using the gradient on those nodes directly, we take the difference in attention weight caused by that key or query, multiplied by the gradient on the attention weights themselves. This requires finding the change in attention weights from each key and query patch, but that can be done efficiently using (for all keys and queries in total) less compute than two transformer forward passes. This correction avoids the problem of saturated attention, while otherwise retaining the performance of AtP.

Queries. For the queries, we can easily compute the adjusted effect by running the model on $x^{\text{noise}}$ and caching the noise queries. We then run the model on $x^{\text{clean}}$ and cache the attention keys and weights. Finally, we compute the attention weights that result from combining all the keys from the $x^{\text{clean}}$ forward pass with the queries from the $x^{\text{noise}}$ forward pass. This costs approximately as much as the unperturbed attention computation of the transformer forward pass. For each query node $n$ we refer to the resulting weight vector as $\operatorname{attn}(n)_{\text{patch}}$, in contrast with the weights $\operatorname{attn}(n)(x^{\text{clean}})$ from the clean forward pass.
The improved attribution estimate for $n$ is then

$$\hat{\mathcal{I}}^{Q}_{\text{AtPfix}}(n; x^{\text{clean}}, x^{\text{noise}}) := \sum_k \hat{\mathcal{I}}_{\text{AtP}}\big(\operatorname{attn}(n)_k; x^{\text{clean}}, x^{\text{noise}}\big) \tag{6}$$

$$= \big(\operatorname{attn}(n)_{\text{patch}} - \operatorname{attn}(n)(x^{\text{clean}})\big)^\top \left.\frac{\partial \mathcal{L}(\mathcal{M}(x^{\text{clean}}))}{\partial \operatorname{attn}(n)}\right|_{\operatorname{attn}(n) = \operatorname{attn}(n)(x^{\text{clean}})} \tag{7}$$

Keys. For the keys we first describe a simple but inefficient method. We again run the model on $x^{\text{noise}}$, caching the noise keys. We also run it on $x^{\text{clean}}$, caching the clean queries and attention probabilities. Let the key nodes for a single attention head be $n^k_1, \ldots, n^k_T$, and let $\operatorname{queries}(n^k_t) = \{n^q_1, \ldots, n^q_T\}$ be the set of query nodes for the same head as node $n^k_t$.
We then define

$$\operatorname{attn}^t_{\text{patch}}(n^q) := \operatorname{attn}(n^q)\big(x^{\text{clean}} \mid \operatorname{do}(n^k_t \leftarrow n^k_t(x^{\text{noise}}))\big) \tag{8}$$

$$\Delta_t \operatorname{attn}(n^q) := \operatorname{attn}^t_{\text{patch}}(n^q) - \operatorname{attn}(n^q)(x^{\text{clean}}) \tag{9}$$

The improved attribution estimate for $n^k_t$ is then

$$\hat{\mathcal{I}}^{K}_{\text{AtPfix}}(n^k_t; x^{\text{clean}}, x^{\text{noise}}) := \sum_{n^q \in \operatorname{queries}(n^k_t)} \Delta_t \operatorname{attn}(n^q)^\top \left.\frac{\partial \mathcal{L}(\mathcal{M}(x^{\text{clean}}))}{\partial \operatorname{attn}(n^q)}\right|_{\operatorname{attn}(n^q) = \operatorname{attn}(n^q)(x^{\text{clean}})} \tag{10}$$

However, the procedure we just described is costly to execute, as it requires $O(T^3)$ flops to naively compute Equation 9 for all $T$ keys. In Section A.2.1 we describe a more efficient variant that takes no more compute than the forward pass attention computation itself (requiring $O(T^2)$ flops). Since Equation 6 is also cheaper to compute than a forward pass, the full QK fix requires less than two transformer forward passes (since the latter also includes MLP computations).
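A NumPy sketch of the QK fix for a single attention head follows. It assumes standard causal scaled dot-product attention (the $1/\sqrt{d}$ scaling is an assumption) and that `G` already holds the clean-run gradient of the metric with respect to the attention weights. The key fix uses the simple $O(T^3)$ route described above, patching one key at a time, not the more efficient variant from Section A.2.1; all function names are hypothetical.

```python
import numpy as np


def causal_softmax(scores):
    """Row-wise softmax with a causal mask (query i attends to keys 0..i)."""
    T = scores.shape[0]
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    s = np.where(future, -np.inf, scores)
    s = s - s.max(axis=1, keepdims=True)  # the diagonal entry keeps each row max finite
    e = np.exp(s)
    return e / e.sum(axis=1, keepdims=True)


def attn(Q, K):
    """Attention pattern of one head; Q and K have shape (T, d)."""
    return causal_softmax(Q @ K.T / np.sqrt(Q.shape[1]))


def query_fix(Q_clean, Q_noise, K_clean, G):
    """Eq. 7 for every query node at once; G[i, j] is assumed to be dL/d attn[i, j] (clean run).

    Row i of attn(Q_noise, K_clean) is attn(n)_patch for query i, since row i depends only on query i.
    """
    delta = attn(Q_noise, K_clean) - attn(Q_clean, K_clean)
    return np.sum(delta * G, axis=1)


def key_fix_naive(Q_clean, K_clean, K_noise, G):
    """Eqs. 8-10 via the simple O(T^3) route: patch one key at a time and recompute attention."""
    T = K_clean.shape[0]
    A_clean = attn(Q_clean, K_clean)
    est = np.zeros(T)
    for t in range(T):
        K = K_clean.copy()
        K[t] = K_noise[t]                   # do(n^k_t <- n^k_t(x_noise))
        delta = attn(Q_clean, K) - A_clean  # Delta_t attn(n^q) for every query row (Eq. 9)
        est[t] = np.sum(delta * G)          # sum over queries of Delta^T gradient (Eq. 10)
    return est


rng = np.random.default_rng(0)
T, d = 4, 8
Q_clean, Q_noise, K_clean, K_noise = (rng.normal(size=(T, d)) for _ in range(4))
G = rng.normal(size=(T, T))  # stand-in for dL/d attn from a clean backward pass
print(query_fix(Q_clean, Q_noise, K_clean, G))
print(key_fix_naive(Q_clean, K_clean, K_noise, G))
```

Because row $i$ of the attention pattern depends only on query $i$, a single attention computation with all noise queries yields every $\operatorname{attn}(n)_{\text{patch}}$ at once, which is why the query fix costs about one attention pass. If the metric were exactly linear in the attention weights, both estimates would equal the true patching effect.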
For attention nodes, we show the effects of applying the query and key fixes in Figure 4 (middle). We observe that the propagation of Q/K effects has a major impact on reducing the false negative rate.

Figure 4: Ranks of $c(n)$ against ranks of $\hat{c}_{\text{AtP}}(n)$, on Pythia-12B on CITY-P. Both improvements to AtP reduce the number of false negatives (bottom right triangle area), where in this case most improvements come from the QK fix. Coloration indicates the maximum absolute difference in attention probability when comparing $x^{\text{clean}}$ and patching a given query or key. Many false negatives are keys and queries with significant maximum difference in attention probability, suggesting they are due to attention saturation as illustrated in Figure 3. Output and value nodes are colored in grey as they do not contribute to the attention probability.

3.1.2 False negatives from cancellation

This form of cancellation occurs when the backpropagated gradient from indirect effects is combined with the gradient from the direct effect. We propose a way to modify the backpropagation within the attribution patching to reduce this issue. If we artificially zero out the gradient at a downstream layer that contributes to the indirect effect, the cancellation is disrupted. (This is also equivalent to patching in clean activations at the outputs of the layer.) Thus we propose to do this iteratively, sweeping across the layers. Any node whose effect does not route through the layer being gradient-zeroed will have its estimate unaffected.

We call this method GradDrop. For every layer $\ell \in \{1, \ldots, L\}$ in the model, GradDrop computes an AtP estimate for all nodes, where gradients on the residual contribution from $\ell$ are set to 0, including the propagation to earlier layers. This provides a different estimate for all nodes, for each layer that was dropped.
We call the so-modified gradient

$$\frac{\partial \mathcal{L}_\ell}{\partial n} = \frac{\partial \mathcal{L}}{\partial n}\Big(\mathcal{M}\big(x^{\text{clean}} \mid \operatorname{do}(n^{\text{out}}_\ell \leftarrow n^{\text{out}}_\ell(x^{\text{clean}}))\big)\Big)$$

when dropping layer $\ell$, where $n^{\text{out}}_\ell$ is the contribution to the residual stream across all positions. Using $\frac{\partial \mathcal{L}_\ell}{\partial n}$ in place of $\frac{\partial \mathcal{L}}{\partial n}$ in the AtP formula produces an estimate $\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)$. Then, the estimates are aggregated by averaging their absolute values, and then scaling by $\frac{L}{L-1}$ to avoid changing the direct-effect path's contribution (which is otherwise zeroed out when dropping the layer the node is in).
$$\hat{c}_{\text{AtP+GD}}(n) := \mathbb{E}_{x^{\text{clean}}, x^{\text{noise}}}\left[\frac{1}{L-1} \sum_{\ell=1}^{L} \left|\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n; x^{\text{clean}}, x^{\text{noise}})\right|\right] \tag{11}$$

Note that the forward passes required for computing $\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n; x^{\text{clean}}, x^{\text{noise}})$ don't depend on $\ell$, so the extra compute needed for GradDrop is $L$ backwards passes from the same intermediate activations on a clean forward pass. This is also the case with the QK fix: the corrected attributions $\hat{\mathcal{I}}_{\text{AtPfix}}$ are dot products with the attention weight gradients, so the only thing that needs to be recomputed for $\hat{\mathcal{I}}_{\text{AtPfix+GD}_\ell}(n)$ is the modified gradient $\frac{\partial \mathcal{L}_\ell}{\partial \operatorname{attn}(n)}$. Thus, computing Equation 11 takes $L$ backwards passes[4] on top of the costs for AtP.

[4] This can be reduced to $(L+1)/2$ by reusing intermediate results.

We show the result of applying GradDrop on attention nodes in Figure 4 (right) and on MLP nodes in Figure 5. In Figure 5, we show the true effect magnitude rank against the AtP+GradDrop rank, while highlighting nodes which improved drastically by applying GradDrop. We give some arguments and intuitions on the benefit of GradDrop in Section A.2.2.
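GradDrop's aggregation (Equation 11) can be seen on a hand-built scalar toy in which a node's direct and indirect gradients cancel exactly. This is a sketch under strong assumptions: the three-layer structure, the metric, and the function names are invented for illustration, and the gradient paths are routed by hand rather than by zeroing gradients in a real model's backward pass.

```python
import math


def toy_metric(a, w_direct, u, v):
    # Hypothetical 3-layer toy: node a is written by layer 1, reaches the metric directly
    # (w_direct * a) and indirectly through layer 2 (u * tanh(v * a)); layer 3 touches neither path.
    return w_direct * a + u * math.tanh(v * a)


def path_gradients(a_clean, w_direct, u, v):
    """Direct and indirect components of dL/da at the clean activation."""
    direct = w_direct
    indirect = u * v * (1.0 - math.tanh(v * a_clean) ** 2)
    return direct, indirect


def atp(a_clean, a_noise, w_direct, u, v):
    """Plain AtP (Eq. 4): both gradient paths are summed before multiplying by the delta."""
    direct, indirect = path_gradients(a_clean, w_direct, u, v)
    return (a_noise - a_clean) * (direct + indirect)


def atp_graddrop(a_clean, a_noise, w_direct, u, v, num_layers=3):
    """Eq. 11 for one prompt pair: sum |AtP| over dropped layers, divided by L - 1."""
    direct, indirect = path_gradients(a_clean, w_direct, u, v)
    delta = a_noise - a_clean
    total = 0.0
    for dropped in range(1, num_layers + 1):
        if dropped == 1:
            grad = 0.0                # node's own layer: every path out of the node is zeroed
        elif dropped == 2:
            grad = direct             # only the indirect path runs through layer 2
        else:
            grad = direct + indirect  # dropping layer 3 leaves both paths intact
        total += abs(delta * grad)
    return total / (num_layers - 1)


args = dict(w_direct=1.0, u=-1.0, v=1.0)
print(toy_metric(2.0, **args) - toy_metric(0.0, **args))  # true effect, about 1.036
print(atp(0.0, 2.0, **args))                              # 0.0
print(atp_graddrop(0.0, 2.0, **args))                     # 1.0
```

Plain AtP scores this node at exactly zero, a false negative, because the direct gradient (+1) and the indirect gradient (-1) cancel at the clean point. Dropping layer 2 exposes the direct path, and the GradDrop estimate of 1.0 lands close to the true effect.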
Direct Effect Ratio. To provide some evidence that the observed false negatives are due to cancellation, we compute the ratio between the direct effect $c^{\text{direct}}(n)$ and the total effect $c(n)$. A higher direct effect ratio indicates more cancellation. We observe that the most significant false negatives corrected by GradDrop in Figure 5 (highlighted) have high direct effect ratios of 5.35, 12.2, and 0 (no direct effect), while the median direct effect ratio of all nodes is 0 (if counting all nodes) or 0.77 (if only counting nodes that have direct effect). Note that the direct effect ratio is only applicable to nodes which in fact have a direct connection to the output, and not e.g. to MLP nodes at non-final token positions, since all disconnected nodes have a direct effect of 0 by definition.

Figure 5: True rank and rank of AtP estimates with and without GradDrop, using Pythia-12B on the CITY-P distribution with NeuronNodes. GradDrop provides a significant improvement to the largest neuron false negatives (red circles) relative to Default AtP (orange crosses).

3.2 Diagnostics

Despite the improvements we have proposed in Section 3.1, there is no guarantee that AtP* produces no false negatives. Thus, it is desirable to obtain an upper confidence bound on the effect size of nodes that might be missed by AtP*, i.e. that aren't in the top $K$ AtP* estimates, for some $K$. Let the top $K$ nodes be $\mathrm{Top}^K_{\text{AtP*}}$. It so happens that we can use subset sampling to obtain such a bound.
As described in Algorithm 1 and Section 3.3, the subset sampling algorithm returns the following summary statistics for each node $n$: the average effect size $\bar{i}^n_\pm$ of a subset, conditional on the node being contained in that subset ($+$) or not ($-$); the sample standard deviations $s^n_\pm$; and the sample sizes $\mathrm{count}^n_\pm$. Given these, consider a null hypothesis $H_0^n$ that $|\mathcal{I}(n)| \ge \theta$, for some threshold $\theta$, versus the alternative hypothesis $H_1^n$ that $|\mathcal{I}(n)| < \theta$. (Footnote 5: This is an unconventional form of $H_0$; typically a null hypothesis will say that an effect is insignificant. However, the framework of statistical hypothesis testing is based on determining whether the data let us reject the null hypothesis, and in this case the hypothesis we want to reject is the presence, rather than the absence, of a significant false negative.)

We use a one-sided Welch's t-test to test this hypothesis. (Footnote 6: This relies on the populations being approximately unbiased and normally distributed, and not skewed. This tended to be true on inspection, and it is what the additivity assumption (see Section 3.3) predicts for a single prompt pair, but a nonparametric bootstrap test may be more reliable, at the cost of additional compute.) The general practice with a compound null hypothesis is to select the simple sub-hypothesis that gives the greatest p-value, so, to be conservative, the simple null hypothesis is that $\mathcal{I}(n) = \theta\,\mathrm{sign}(\bar{i}^n_+ - \bar{i}^n_-)$. This gives the test statistic

$$t^n = \frac{\theta - |\bar{i}^n_+ - \bar{i}^n_-|}{s^n_{\mathrm{Welch}}},$$

where $s^n_{\mathrm{Welch}} = \sqrt{(s^n_+)^2/\mathrm{count}^n_+ + (s^n_-)^2/\mathrm{count}^n_-}$ is the Welch standard error, and the p-value $p^n = \mathbb{P}_{T \sim t_{\nu^n_{\mathrm{Welch}}}}(T > t^n)$, where $\nu^n_{\mathrm{Welch}}$ is the corresponding Welch-Satterthwaite degrees of freedom. To get a combined conclusion across all nodes in $N \setminus \mathrm{Top}^K_{\mathrm{AtP^*}}$, consider the hypothesis $H_0 = \bigvee_{n \in N \setminus \mathrm{Top}^K_{\mathrm{AtP^*}}} H_0^n$ that any of those nodes has true effect $|\mathcal{I}(n)| \ge \theta$. Since this is also a compound null hypothesis, $\max_n p^n$ is the corresponding p-value. Then, to find an upper confidence bound with specified confidence level $1-p$, we invert this procedure to find the lowest $\theta$ for which we still have at least that level of confidence. We repeat this for various settings of the sample size $m$ in Algorithm 1. The exact algorithm is described in Section A.3.
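A dependency-free sketch of this diagnostic follows, under stated assumptions: the per-node summary statistics from Algorithm 1 are passed in as tuples for the nodes outside $\mathrm{Top}^K_{\mathrm{AtP^*}}$; the Student-t tail probability is computed from the regularized incomplete beta function via its standard continued fraction (modified Lentz method), since the Python standard library has no t-distribution; and the inversion bisects on $\theta$, using the fact that $\max_n p^n$ decreases as $\theta$ grows. This is one way to perform the inversion, not necessarily the exact procedure of Section A.3, and all function names are hypothetical.

```python
import math


def _betacf(a, b, x, max_iter=10_000, eps=1e-14):
    """Continued fraction for the incomplete beta function (modified Lentz)."""
    tiny = 1e-300
    qab, qap, qam = a + b, a + 1.0, a - 1.0
    c, d = 1.0, 1.0 - qab * x / qap
    d = 1.0 / (d if abs(d) > tiny else tiny)
    h = d
    for m in range(1, max_iter + 1):
        m2 = 2 * m
        # Even and odd steps of the continued fraction.
        for aa in (m * (b - m) * x / ((qam + m2) * (a + m2)),
                   -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2))):
            d = 1.0 + aa * d
            d = 1.0 / (d if abs(d) > tiny else tiny)
            c = 1.0 + aa / c
            c = c if abs(c) > tiny else tiny
            h *= d * c
        if abs(d * c - 1.0) < eps:
            break
    return h


def _betainc(a, b, x):
    """Regularized incomplete beta function I_x(a, b)."""
    if x <= 0.0:
        return 0.0
    if x >= 1.0:
        return 1.0
    front = math.exp(math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
                     + a * math.log(x) + b * math.log1p(-x))
    if x < (a + 1.0) / (a + b + 2.0):
        return front * _betacf(a, b, x) / a
    return 1.0 - front * _betacf(b, a, 1.0 - x) / b


def t_sf(t, nu):
    """P(T > t) for T ~ Student t with nu degrees of freedom."""
    upper = 0.5 * _betainc(0.5 * nu, 0.5, nu / (nu + t * t))  # = P(T > |t|)
    return upper if t > 0 else 1.0 - upper


def welch_p_value(theta, mean_plus, s_plus, n_plus, mean_minus, s_minus, n_minus):
    """One-sided p-value p^n for the null hypothesis |I(n)| >= theta."""
    v_plus, v_minus = s_plus**2 / n_plus, s_minus**2 / n_minus
    se = math.sqrt(v_plus + v_minus)                     # s^n_Welch
    nu = (v_plus + v_minus) ** 2 / (v_plus**2 / (n_plus - 1)
                                    + v_minus**2 / (n_minus - 1))  # nu^n_Welch
    t = (theta - abs(mean_plus - mean_minus)) / se
    return t_sf(t, nu)


def upper_confidence_bound(node_stats, p=0.01, tol=1e-6):
    """Smallest theta at which max_n p^n <= p, i.e. H0 is rejected at level p.

    node_stats: (mean_plus, s_plus, n_plus, mean_minus, s_minus, n_minus) per node.
    """
    node_stats = list(node_stats)

    def combined(theta):  # p-value of the compound null hypothesis
        return max(welch_p_value(theta, *st) for st in node_stats)

    lo, hi = 0.0, max(abs(st[0] - st[3]) for st in node_stats) + 1.0
    while combined(hi) > p:   # grow the bracket until H0 is rejected
        hi *= 2.0
    while hi - lo > tol:      # bisection; combined() decreases in theta
        mid = 0.5 * (lo + hi)
        if combined(mid) > p:
            lo = mid
        else:
            hi = mid
    return hi
```

Where SciPy is available, `scipy.stats.t.sf(t, nu)` could replace `t_sf`, and the bound has the closed form $\max_n \big(|\bar{i}^n_+ - \bar{i}^n_-| + s^n_{\mathrm{Welch}}\, t_{1-p,\,\nu^n_{\mathrm{Welch}}}\big)$, with $t_{1-p,\nu}$ the $(1-p)$-quantile of the t-distribution; bisection is used here only to keep the sketch self-contained.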
In Figure 6, we report the upper confidence bounds at confidence levels 90%, 99%, and 99.9% from running Algorithm 1 with a given $m$ (right subplots), as well as the number of nodes that have a true contribution $c(n)$ greater than $\theta$ (left subplots).

(a) IOI-P (b) IOI
Figure 6: Upper confidence bounds on effect magnitudes of false negatives (i.e. nodes not in the top 1024 nodes according to AtP*), at 3 confidence levels, varying the sampling budget. On the left we show in red the true effect of the nodes which are ranked highest by AtP*. We also show the true effect magnitude at various ranks of the remaining nodes in orange. We can see that the bound for (a) finds the true biggest false negative reasonably early, while for (b), where there is no large false negative, we progressively keep gaining confidence with more data. Note that the costs involved per prompt pair are substantially different between the subplots, and in particular this diagnostic for the distributional case (b) is substantially cheaper to compute than the verification cost of 1024 samples per prompt pair.

3.3 Baselines

Iterative The most straightforward method is to directly do Activation Patching to find the true effect $c(n)$ of each node, in some uninformed random order. This is necessarily inefficient. However, if we are scaling to a distribution, it is possible to improve on this by alternating between phases of (i) for each unverified node, picking a not-yet-measured prompt pair on which to patch it, and (ii) ranking the not-yet-verified nodes by the average observed patch effect magnitudes, taking the top $|N|/|\mathcal{D}|$ nodes, and verifying them. This balances the computational expenditure on the two tasks, and allows us to find large nodes sooner, at least as long as their large effect shows up on many prompt pairs.
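A toy simulation of the Iterative baseline on a distribution can make the two alternating phases concrete. It assumes the per-pair patch effects are available as a (num_nodes, num_pairs) array, where reading one entry stands in for one activation-patching forward pass; the function name and the rule of taking the first unmeasured pair are illustrative choices, not the authors' implementation.

```python
import numpy as np


def iterative_baseline(effects):
    """Simulate the Iterative baseline; effects[n, k] is node n's patch effect on pair k.

    Returns the order in which nodes are verified and the number of forward passes.
    """
    num_nodes, num_pairs = effects.shape
    batch = max(1, num_nodes // num_pairs)          # top |N| / |D| nodes per round
    measured = np.zeros(effects.shape, dtype=bool)
    verified = np.zeros(num_nodes, dtype=bool)
    order, passes = [], 0
    while not verified.all():
        unverified = np.flatnonzero(~verified)
        # Phase (i): patch each unverified node on one not-yet-measured pair.
        for n in unverified:
            todo = np.flatnonzero(~measured[n])
            if todo.size:
                measured[n, todo[0]] = True
                passes += 1
        # Phase (ii): rank by mean observed |effect| and verify the top batch.
        scores = np.array([np.abs(effects[n, measured[n]]).mean() for n in unverified])
        for n in unverified[np.argsort(-scores, kind="stable")[:batch]]:
            passes += int((~measured[n]).sum())     # patch on all remaining pairs
            measured[n] = True
            verified[n] = True
            order.append(int(n))
    return order, passes
```

Every (node, pair) entry is measured exactly once, so the total always ends at $|N| \cdot |\mathcal{D}|$ passes; the benefit lies in how early the large-effect nodes appear in `order`.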
Our remaining baseline methods rely on an approximate node additivity assumption: that when intervening on a set of nodes $\eta$, the measured effect $\mathcal{I}(\eta; x^{\mathrm{clean}}, x^{\mathrm{noise}})$ is approximately equal to $\sum_{n \in \eta} \mathcal{I}(n; x^{\mathrm{clean}}, x^{\mathrm{noise}})$.

Subsampling Under the approximate node additivity assumption, we can construct an approximately unbiased estimator of $c(n)$. We select the sets $\eta_k$ to contain each node independently with some probability $p$, and additionally sample prompt pairs $x^{\mathrm{clean}}_k, x^{\mathrm{noise}}_k \sim \mathcal{D}$. For any node $n$ and sets of nodes $\eta_k \subset N$, let $\eta^+(n)$ be the collection of all those sets that contain $n$, and $\eta^-(n)$ the collection of those that don't contain $n$; we'll write these node sets as $\eta^+_k(n)$ and $\eta^-_k(n)$, and the corresponding prompt pairs as ${x^{\mathrm{clean}}_k}^{+}(n), {x^{\mathrm{noise}}_k}^{+}(n)$ and ${x^{\mathrm{clean}}_k}^{-}(n), {x^{\mathrm{noise}}_k}^{-}(n)$.
The subsampling (or subset sampling) estimator is then given by

$$\hat{\mathcal{I}}_{\mathrm{S}}(n) := \frac{1}{|\eta^+(n)|}\sum_{k=1}^{|\eta^+(n)|}\mathcal{I}\big(\eta^+_k(n); {x^{\mathrm{clean}}_k}^{+}(n), {x^{\mathrm{noise}}_k}^{+}(n)\big) - \frac{1}{|\eta^-(n)|}\sum_{k=1}^{|\eta^-(n)|}\mathcal{I}\big(\eta^-_k(n); {x^{\mathrm{clean}}_k}^{-}(n), {x^{\mathrm{noise}}_k}^{-}(n)\big) \qquad (12)$$

$$\hat{c}_{\mathrm{S}}(n) := |\hat{\mathcal{I}}_{\mathrm{S}}(n)| \qquad (13)$$

The estimator $\hat{\mathcal{I}}_{\mathrm{S}}(n)$ is unbiased if there are no interaction effects, and has a small bias proportional to $p$ under a simple interaction model (see Section A.1.1 for proof). In practice, we compute all the estimates $\hat{c}_{\mathrm{S}}(n)$ by sampling a binary mask over all nodes from i.i.d. $\mathrm{Bernoulli}^{|N|}(p)$; each binary mask can be identified with a node set $\eta$. In Algorithm 1, we describe how to compute summary statistics related to Equation 13 efficiently for all nodes $n \in N$.
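Why the difference of conditional means in Equation 12 recovers $\mathcal{I}(n)$ without interactions: under exact additivity, every other node $m$ is included with probability $p$ whether or not $n$ is, so it contributes $p\,\mathcal{I}(m)$ in expectation to both conditional means, and $\mathbb{E}[\mathcal{I}(\eta) \mid n \in \eta] - \mathbb{E}[\mathcal{I}(\eta) \mid n \notin \eta] = \mathcal{I}(n)$. The NumPy sketch below mirrors Algorithm 1 (Subsampling) and can be used to check this numerically. The callables `patch_effect(mask, pair)` (assumed to run one forward pass patching every node with `mask[n] == True`) and `sample_pair(rng)` are hypothetical stand-ins for the model and $\mathcal{D}$.

```python
import numpy as np


def subsampling(patch_effect, sample_pair, num_nodes, p, m, rng):
    """NumPy sketch of Algorithm 1 (Subsampling).

    Each returned array has shape (2, num_nodes): row 0 holds the "+" statistics
    (node patched), row 1 the "-" statistics (node not patched).
    """
    count = np.zeros((2, num_nodes))
    run_sum = np.zeros((2, num_nodes))
    run_squared_sum = np.zeros((2, num_nodes))
    for _ in range(m):
        pair = sample_pair(rng)                  # x_clean, x_noise ~ D
        mask_plus = rng.random(num_nodes) < p    # Bernoulli(p) for every node
        masks = np.stack([mask_plus, ~mask_plus]).astype(float)
        i = patch_effect(mask_plus, pair)        # effect of patching eta^+
        count += masks
        run_sum += i * masks
        run_squared_sum += i**2 * masks
    mean = run_sum / count
    # Unbiased sample variance: (sum of squares - count * mean^2) / (count - 1).
    var = (run_squared_sum - count * mean**2) / (count - 1)
    std = np.sqrt(np.maximum(var, 0.0))          # guard against tiny negative round-off
    return count, mean, std
```

With a toy additive effect such as `lambda mask, pair: float(mask.astype(float) @ true_effects)`, `mean[0] - mean[1]` approaches `true_effects` as `m` grows, and its absolute value is the estimate $\hat{c}_{\mathrm{S}}(n)$ of Equation 13.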
The means $\bar{i}^\pm$ are enough to compute $\hat{c}_{\mathrm{S}}(n)$, while the other summary statistics are used to bound the magnitude of a false negative (cf. Section 3.2). (Note that $\mathrm{count}^\pm_n$ is just an alternate notation for $|\eta^\pm(n)|$.)

Algorithm 1 Subsampling
1: Input: $p \in (0,1)$, model $\mathcal{M}$, metric $\mathcal{L}$, prompt pair distribution $\mathcal{D}$, number of samples $m$
2: $\mathrm{count}^\pm, \mathrm{runSum}^\pm, \mathrm{runSquaredSum}^\pm \leftarrow 0^{|N|}$ ▷ Init counts and running sums to 0 vectors
3: for $j \leftarrow 1$ to $m$ do
4:   $x^{\mathrm{clean}}, x^{\mathrm{noise}} \sim \mathcal{D}$
5:   $\mathrm{mask}^+ \leftarrow \mathrm{Bernoulli}^{|N|}(p)$ ▷ Sample binary mask for patching
6:   $\mathrm{mask}^- \leftarrow 1 - \mathrm{mask}^+$
7:   $i \leftarrow \mathcal{I}(\{n \in N : \mathrm{mask}^+_n = 1\}; x^{\mathrm{clean}}, x^{\mathrm{noise}})$ ▷ $\eta^+ = \{n \in N : \mathrm{mask}^+_n = 1\}$
8:   $\mathrm{count}^\pm \leftarrow \mathrm{count}^\pm + \mathrm{mask}^\pm$
9:   $\mathrm{runSum}^\pm \leftarrow \mathrm{runSum}^\pm + i \cdot \mathrm{mask}^\pm$
10:  $\mathrm{runSquaredSum}^\pm \leftarrow \mathrm{runSquaredSum}^\pm + i^2 \cdot \mathrm{mask}^\pm$
11: $\bar{i}^\pm \leftarrow \mathrm{runSum}^\pm / \mathrm{count}^\pm$
12: $s^\pm \leftarrow \sqrt{\big(\mathrm{runSquaredSum}^\pm - \mathrm{count}^\pm \cdot (\bar{i}^\pm)^2\big) / (\mathrm{count}^\pm - 1)}$
13: return $\mathrm{count}^\pm, \bar{i}^\pm, s^\pm$ ▷ If diagnostics are not required, $\bar{i}^\pm$ is sufficient.

Blocks & Hierarchical Instead of sampling each $\eta$ independently, we can group nodes into fixed "blocks" $\eta$ of some size, and patch each block to find its aggregated contribution $c(\eta)$; we can then traverse the nodes, starting with high-contribution blocks and proceeding from there. There is a tradeoff in terms of the block size: using large blocks increases the compute required to traverse a high-contribution block, but using small blocks increases the compute required to finish traversing all of the blocks. We refer to the fixed block size setting as Blocks. Another way to handle this tradeoff is to add recursion: the blocks can be grouped into higher-level blocks, and so forth. We call this method Hierarchical. We present results from both methods in our comparison plots, but relegate details to Section A.1.2.
Relative to subsampling, these grouping-based methods have the disadvantage that on distributions, their cost scales linearly with the size of $\mathcal{D}$'s support, in addition to scaling with the number of nodes. (Footnote 7: AtP* also scales linearly in the same way, but with far fewer forward passes per prompt pair.)

4 Experiments

4.1 Setup

Nodes When attributing model behavior to components, an important choice is the partition of the model's computational graph into units of analysis or 'nodes' $N \ni n$ (cf. Section 2.1). We investigate two settings for the choice of $N$, AttentionNodes and NeuronNodes. For NeuronNodes, each MLP neuron is a separate node. (Footnote 8: We use the neuron post-activation for the node; this makes no difference when causally intervening, but for AtP it's beneficial, because it makes the $n \mapsto \mathcal{L}(n)$ function more linear.) For AttentionNodes, we consider the query, key, and value vector for each head as distinct nodes, as well as the pre-linear per-head attention output. (Footnote 9: We include the output node because it provides additional information about what function an attention head is serving, particularly in the case where its queries have negligible patch effects relative to its keys and/or values. This may happen as a result of choosing $x^{\mathrm{clean}}, x^{\mathrm{noise}}$ such that the query does not differ across the prompts.) We also refer to these units as 'sites'. For each site, we consider each copy of that site at different token positions as a separate node. As a result, we can identify each node $n \in N$ with a pair $(T, S)$ from the product TokenPosition × Site. Since our two settings for $N$ use different levels of granularity and are expected to have different per-node effect magnitudes, we present results on them separately.

Models We investigate transformer language models from the Pythia suite (Biderman et al., 2023) of sizes between 410M and 12B parameters.
This allows us to demonstrate that our methods are applicable across scale. Our cost-of-verified-recall plots in Figures 1, 7 and 8 refer to Pythia-12B. Results for other model sizes are presented via the relative-cost plots (cf. Section 4.2) in the main body (Figure 9), and disaggregated via cost of verified recall in Section B.3.

Effect Metric $\mathcal{L}$ All reported results use the negative log probability as their loss function $\mathcal{L}$. (Footnote 10: Another popular metric is the difference in logits between the clean and noise target. As opposed to the negative logprob, the logit difference is linear in the final logits and thus might favor AtP. A downside of logit difference is that it is sensitive to the noise target, which may not be meaningful if there are multiple plausible completions, such as in IOI.) We compute $\mathcal{L}$ relative to targets from the clean prompt $x^{\mathrm{clean}}$. We briefly explore other metrics in Section B.4.

4.2 Measuring Effectiveness and Efficiency

Cost of verified recall As mentioned in the introduction, we're primarily interested in finding the largest-effect nodes; see Appendix D for the distribution of $c(n)$ across models and distributions. Once we have obtained node estimates via a given method, it is relatively cheap to directly measure the true effects of top nodes one at a time; we refer to this as "verification". Incorporating this into our methodology, we find that false positives are typically not a big issue; they are simply revealed during verification. In contrast, false negatives are not so easy to remedy without verifying all nodes, which is what we were trying to avoid. We compare methods on the basis of the total compute cost (in number of forward passes) to verify the $K$ nodes with the biggest true effect magnitude, for varying $K$.
The procedure being measured is to first compute estimates (incurring an estimation cost), and then sweep through nodes in decreasing order of estimated magnitude, measuring their individual effects $c(n)$ (i.e. verifying them) and incurring a verification cost. The total cost is the sum of these two costs.

Inverse-rank-weighted geometric mean cost Sometimes we find it useful to summarize method performance with a scalar; this is useful for comparing methods at a glance across different settings (e.g. model sizes, as in Figure 2), or for selecting hyperparameters (cf. Section B.5). The cost of verified recall of the top $K$ nodes is of interest for $K$ at varying orders of magnitude. In order to avoid the performance metric being dominated by small or large $K$, we assign similar total weight to different orders of magnitude: we use a weighted average with weight $1/K$ for the cost of the top $K$ nodes. Similarly, since the costs themselves may have different orders of magnitude, we average them on a log scale, i.e., we take a geometric mean. This metric is also proportional to the area under the curve in plots like Figure 1. To produce a more understandable result, we always report it relative to (i.e. divided by) the oracle verification cost on the same metric; the diagonal line is the oracle, with relative cost 1. We refer to this as the IRWRGM (inverse-rank-weighted relative geometric mean) cost, or the relative cost.

Note that an individual practitioner's preferences may differ, such that this metric no longer accurately measures the rank regime that matters to them. For example, AtP* pays a notable upfront cost relative to AtP or AtP+QKfix, which puts it at a disadvantage when it doesn't manage to find additional false negatives; but this may or may not be practically significant. To understand the performance in more detail, we advise referring to the cost of verified recall plots, like Figure 1 (or many more in Section B.3).
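Both quantities in this subsection can be sketched in a few lines, assuming a single prompt pair where verifying one node costs one forward pass (on distributions the unit would be forward passes per prompt pair). `cost_of_verified_recall` sweeps nodes in decreasing order of estimated magnitude and records, for each $K$, the cost at which all of the $K$ largest true-effect nodes have been verified; `irwrgm` then averages log cost ratios against the oracle, whose cost for the top $K$ is $K$ passes, with weights $1/K$. Both functions are our own reading of the text, not the authors' code.

```python
import numpy as np


def cost_of_verified_recall(estimates, true_effects, estimation_cost):
    """costs[K-1]: forward passes until the K nodes with largest true |effect|
    have all been verified (estimation cost + one pass per verified node)."""
    order = np.argsort(-np.abs(estimates), kind="stable")    # verification order
    step = np.empty(len(order), dtype=int)
    step[order] = np.arange(1, len(order) + 1)                # when each node is verified
    true_order = np.argsort(-np.abs(true_effects), kind="stable")
    # The top-K true nodes are all verified once the latest of them is.
    steps_needed = np.maximum.accumulate(step[true_order])
    return estimation_cost + steps_needed


def irwrgm(costs):
    """Inverse-rank-weighted geometric mean of costs[K-1] / (oracle cost K)."""
    costs = np.asarray(costs, dtype=float)
    k = np.arange(1, len(costs) + 1)
    weights = 1.0 / k
    log_ratio = np.log(costs) - np.log(k)     # oracle verifies the top K in K passes
    return float(np.exp(np.sum(weights * log_ratio) / np.sum(weights)))
```

Because `estimation_cost` is added for every $K$, it dominates the ratio against the oracle at small $K$, which is how an upfront cost such as AtP*'s can raise the relative cost even when few additional false negatives are found.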
4.3 Single Prompt Pairs versus Distributions

We focus many of our experiments on single prompt pairs. This is primarily because it's easier to set up and get ground truth data. It's also a simpler setting in which to investigate the question, and one that's more universally applicable, since a distribution to generalize to is not always available.

(a) NeuronNodes on CITY-P (b) AttentionNodes on IOI-P
Figure 7: Costs of finding the most causally-important nodes in Pythia-12B using different methods on clean prompt pairs, with 90% target recall. This highlights that the AtP* false negatives in Figure 1 are a small minority of nodes.

Clean single prompt pairs As a starting point we report results on single prompt pairs which we expect to have relatively clean circuitry. (Footnote 11: Formally, these represent prompt distributions via the delta distribution $p(x^{\mathrm{clean}}, x^{\mathrm{noise}}) = \delta_{x^{\mathrm{clean}}_1, x^{\mathrm{noise}}_1}(x^{\mathrm{clean}}, x^{\mathrm{noise}})$, where $x^{\mathrm{clean}}_1, x^{\mathrm{noise}}_1$ is the single prompt pair.) All single prompt pairs are shown in Table 1. IOI-P is chosen to resemble an instance from the indirect object identification (IOI) task (Wang et al., 2022), a task predominantly involving attention heads. CITY-P is chosen to elicit factual recall, which previous research suggests involves early MLPs and a small number of late attention heads (Meng et al., 2023; Geva et al., 2023; Nanda et al., 2023). The country/city combinations were chosen such that Pythia-410M achieved low loss on both $x^{\mathrm{clean}}$ and $x^{\mathrm{noise}}$ and such that all places were represented by a single token.
CITY-P
  Clean prompt: ⟨BOS⟩City:␣Barcelona Country:␣Spain
  Noise source prompt: ⟨BOS⟩City:␣Beijing Country:␣China
IOI-P
  Clean prompt: ⟨BOS⟩When␣Michael␣and␣Jessica␣went␣to␣the␣bar,␣Michael␣gave␣a␣drink␣to␣Jessica
  Noise source prompt: ⟨BOS⟩When␣Michael␣and␣Jessica␣went␣to␣the␣bar,␣Ashley␣gave␣a␣drink␣to␣Michael
RAND-P
  Clean prompt: ⟨BOS⟩Her␣biggest␣worry␣was␣the␣festival␣might␣suffer␣and␣people␣might␣erroneously␣think
  Noise source prompt: ⟨BOS⟩also␣think␣that␣there␣should␣be␣the␣same␣rules␣or␣regulations␣when␣it
Table 1: Clean and noise source prompts for single prompt pair distributions. Vertical lines denote tokenization boundaries. All prompts are preceded by the BOS (beginning of sequence) token. The last token is not part of the input. The last token of the clean prompt is used as the target in $\mathcal{L}$.

We show the cost of verified 100% recall for various methods in Figure 1, where we focus on NeuronNodes for CITY-P and AttentionNodes for IOI-P. Exhaustive results for smaller Pythia models are shown in Section B.3. Figure 2 shows the aggregated relative costs for all models on CITY-P and IOI-P.

Instead of applying the strict criterion of recalling all important nodes, we can also relax this constraint. In Figure 7, we show the cost of verified 90% recall in the two clean prompt pair settings.

Random prompt pair The previous prompt pairs may in fact be best-case scenarios: the interventions they create will be fairly localized to a specific circuit, and this may make it easy for AtP to approximate the contributions. It may thus be informative to see how the methods generalize to settings where the interventions are less surgical. To do this, we also report results in Figure 8 (top) and Figure 9 on a random prompt pair chosen from a non-copyright-protected section of The Pile (Gao et al., 2020), which we refer to as RAND-P. The prompt pair was chosen such that Pythia-410M still achieved low loss on both prompts.

(a) RAND-P MLP neurons. (b) RAND-P Attention nodes. (c) A-AN MLP neurons. (d) IOI Attention nodes.
Figure 8: Costs of finding the most causally-important nodes in Pythia-12B using different methods, on a random prompt pair (see Table 1) and on distributions. The shading indicates geometric standard deviation. Cost is measured in forward passes, or forward passes per prompt pair in the distributional case.

(a) RAND-P MLP neurons. (b) RAND-P Attention nodes. (c) A-AN MLP neurons. (d) IOI Attention nodes.
Figure 9: Costs of methods across models, on random prompt pair and on distributions. The costs are relative to having an oracle (and thus verifying nodes in decreasing order of true contribution size); they're aggregated using an inverse-rank-weighted geometric mean. This means they correspond to the area above the diagonal for each curve in Figure 8.

We find that AtP/AtP* is only somewhat less effective here; this provides tentative evidence that the strong performance of AtP/AtP* isn't reliant on the clean prompt using a particularly crisp circuit, or on the noise prompt being a precise control.

Distributions Causal attribution is often of most interest when evaluated across a distribution, as laid out in Section 2. Of the methods, AtP, AtP*, and Subsampling scale reasonably to distributions: the former two because they're inexpensive, so running them $|\mathcal{D}|$ times is not prohibitive, and Subsampling because it intrinsically averages across the distribution and thus becomes proportionally cheaper relative to the verification via activation patching. In addition, having a distribution enables a more performant Iterative method, as described in Section 3.3. We present a comparison of these methods on two distributional settings. The first is a reduced version of IOI (Wang et al., 2022) on 6 names, resulting in $6 \times 5 \times 4 = 120$ prompt pairs, where we evaluate AttentionNodes. The other distribution prompts the model to output an indefinite article ' a' or ' an', where we evaluate NeuronNodes.
See Section B.1 for details on constructing these distributions. Results are shown in Figure 8 for Pythia-12B, and in Figure 9 across models. The results show that AtP continues to perform well, especially with the QK fix; in addition, the cancellation failure mode tends to be sensitive to the particular input prompt pair, and as a result, averaging across a distribution diminishes the benefit of GradDrop.

An implication of Subsampling scaling well to this setting is that diagnostics may give reasonable confidence in not missing false negatives with much less overhead than in the single-prompt-pair case; this is illustrated in Figure 6.

5 Discussion

5.1 Limitations

Prompt pair distributions We only considered a small set of prompt pair distributions, which were often limited to a single prompt pair, since evaluating the ground truth can be quite costly. While we aimed to evaluate on distributions that are reasonably representative, our results may not generalize to other distributions.

Choice of Nodes $N$ In the NeuronNodes setting, we took MLP neurons as our fundamental unit of analysis. However, there is mounting evidence (Bricken et al., 2023) that the decomposition of signals into neuron contributions does not correspond directly to a semantically meaningful decomposition. Instead, achieving such a decomposition seems to require finding the right set of directions in neuron activation space (Bricken et al., 2023; Gurnee et al., 2023), which we viewed as being out of scope for this paper. In Section 5.2 we further discuss the applicability of AtP to sparse autoencoders, a method of finding these decompositions.

More generally, we only considered relatively fine-grained nodes, because this is a case where very exhaustive verification is prohibitively expensive, justifying the need for an approximate, fast method.
Nanda (2022) speculates that AtP may perform worse on coarser components like full layers or entire residual streams, as a larger change may have more of a non-linear effect. There may still be benefit in speeding up such an analysis, particularly if the context length is long; our alternative methods may have something to offer here, though we leave investigation of this to future work. It is popular in the literature to do Activation Patching with these larger components, with short contexts; this doesn't pose a performance issue, and so our work would not provide any benefit here.

Caveats of $c(n)$ as importance measure In this work we took the ground truth of activation patching, as defined in Equation 1, as our evaluation target. As discussed by McGrath et al. (2023), Equation 1 often significantly disagrees with a different evaluation target, the "direct effect", by putting lower weight on some contributions when later components would shift their behaviour to compensate for the earlier patched component. In the worst case this could be seen as producing additional false negatives not accounted for by our metrics. To some degree this is likely to be mitigated by the GradDrop formula in Eq. 11, which will include a term dropping out the effect of that downstream shift.

However, it is also questionable whether we need to concern ourselves with finding high-direct-effect nodes. For example, direct effect is easy to compute efficiently for all nodes, as explored by nostalgebraist (2020), so there is no need for fast approximations like AtP if direct effect is the quantity of interest. This ease of computation is no free lunch, though, because direct effect is also more limited as a tool for finding causally important nodes: it would not be able to locate any nodes that contribute only instrumentally to the circuit rather than producing its output. For example, there is no direct effect from nodes at non-final token positions.
We discuss the direct effect further in Section 3.1.2 and Section A.2.2.

Another nuance of our ground-truth definition occurs in the distributional setting. Some nodes may have a real and significant effect, but only on a single clean prompt (e.g. they only respond to a particular name in IOI or object in A-AN). (Footnote 12: We did observe this particular behavior in a few instances.) Since the effect is averaged over the distribution, the ground truth will not assign these nodes large causal importance. Depending on the goal of the practitioner, this may or may not be desirable.

Effect size versus rank estimation When evaluating the performance of various estimators, we focused on evaluating the relative rank of estimates, since our main goal was to identify important components (with effect size only instrumentally useful to this end), and we assumed a further verification step of the nodes with the highest estimated effects one at a time, in contexts where knowing effect size is important. Thus, we do not present evidence about how closely the estimated effect magnitudes from AtP or AtP* match the ground truth. Similarly, we did not assess the prevalence of false positives in our analysis, because they can be filtered out via the verification process. Finally, we did not compare to past manual interpretability work to check whether our methods find the same nodes to be causally important as discovered by human researchers, as done in prior work (Conmy et al., 2023; Syed et al., 2023).

Other LLMs While we think it likely that our results on the Pythia model family (Biderman et al., 2023) will transfer to other LLM families, we cannot rule out qualitatively different behavior without further evidence, especially on SotA-scale models or models that significantly deviate from the standard decoder-only transformer architecture.
5.2 Extensions/Variants

Edge Patching While we focus on computing the effects of individual nodes, edge activation patching can give more fine-grained information about which paths in the computational graph matter. However, it suffers from an even larger blowup in number of forward passes if done naively. Fortunately, AtP is easy to generalize to estimating the effects of edges between nodes (Nanda, 2022; Syed et al., 2023), while AtP* may provide further improvement. We discuss edge-AtP, and how to efficiently carry over the insights from AtP*, in Section C.2.

Coarser nodes $N$ We focused on fine-grained attribution, rather than full layers or sliding windows (Meng et al., 2023; Geva et al., 2023). In the latter case there's less computational blowup to resolve, but for long contexts there may still be benefit in considering speedups like ours; on the other hand, they may be less linear, thus favouring other methods over AtP*. We leave investigation of this to future work.

Layer normalization Nanda (2022) observed that AtP's approximation to layer normalization may be a worse approximation when it comes to patching larger/coarser nodes: on average the patched and clean activations are likely to have similar norm, but may not have high cosine similarity. They recommend treating the denominator in layer normalization as fixed, e.g. using a stop-gradient operator in the implementation. In Section C.1 we explore the effect of this, and illustrate the behaviour of this alternative form of AtP. It seems likely that this variant would indeed produce better results, particularly when patching residual-stream nodes, but we leave empirical investigation of this to future work.
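A small numerical illustration of the layer normalization point: the toy below linearizes $\mathrm{LN}(x) = (x - \mu)/\sigma$ (without gain, bias, or epsilon, a simplifying assumption) around a clean vector and compares two first-order predictions of the change caused by a patch, one using the full Jacobian and one holding $\sigma$ fixed, which is what a stop-gradient on the denominator yields. The example vectors are hypothetical and chosen to have equal norm but zero cosine similarity, the case described above.

```python
import numpy as np


def layer_norm(x):
    """LayerNorm without gain, bias, or epsilon (simplifying assumption)."""
    centered = x - x.mean()
    return centered / np.sqrt((centered**2).mean())


def ln_linearized_change(x, dx, freeze_scale):
    """First-order prediction of layer_norm(x + dx) - layer_norm(x).

    freeze_scale=True treats sigma as a constant (stop-gradient on the
    denominator); otherwise the full Jacobian (I - 11^T/d - y y^T/d) / sigma
    is applied, with y = layer_norm(x).
    """
    d = x.size
    sigma = np.sqrt(((x - x.mean())**2).mean())
    centered_dx = dx - dx.mean()
    if freeze_scale:
        return centered_dx / sigma
    y = layer_norm(x)
    return (centered_dx - y * (y @ dx) / d) / sigma


x_clean = np.array([1.0, -1.0, 0.0, 0.0])
x_patch = np.array([0.0, 0.0, 1.0, -1.0])   # same norm, orthogonal to x_clean
true_change = layer_norm(x_patch) - layer_norm(x_clean)
frozen = ln_linearized_change(x_clean, x_patch - x_clean, freeze_scale=True)
full = ln_linearized_change(x_clean, x_patch - x_clean, freeze_scale=False)
```

Here both vectors are centered with equal norm, so the patch leaves $\sigma$ unchanged and the frozen-denominator prediction happens to be exact, while the full-Jacobian prediction misses half of the change; in general the frozen version is still only an approximation.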
Denoising Denoising (Meng et al., 2023; Lieberum et al., 2023) is a different use case for patching, which may produce moderately different results: the difference is that each forward pass is run on $x^{\mathrm{noise}}$ with the activation to patch taken from $x^{\mathrm{clean}}$. Colloquially, this tests whether the patched activation is sufficient to recover model performance on $x^{\mathrm{clean}}$, rather than necessary. We provide some preliminary evidence on the effect of this choice in Section B.4 but leave a more thorough investigation to future work.

Other forms of ablation Further, in some settings it may be of interest to do mean-ablation, or even zero-ablation, and our tweaks remain applicable there; the random-prompt-pair result suggests AtP* isn't overly sensitive to the noise distribution, so we speculate the results are likely to carry over.

5.3 Applications

Automated Circuit Finding A natural application of the methods we discussed in this work is the automatic identification and localization of sparse subgraphs or 'circuits' (Cammarata et al., 2020). A variant of this was already discussed in concurrent work by Syed et al. (2023), who combined edge attribution patching with the ACDC algorithm (Conmy et al., 2023). As we mentioned in the edge patching discussion, AtP* can be generalized to edge attribution patching, which may bring additional benefit for automated circuit discovery.

Another approach is to learn a (probabilistic) mask over nodes, similar to Louizos et al. (2018) and Cao et al. (2021), where the probability scales with the currently estimated node contribution $c(n)$. For that approach, a fast method to estimate all node effects given the current mask probabilities could prove vital.
Sparse Autoencoders Recently there has been increased interest in the community in using sparse autoencoders (SAEs) to construct disentangled sparse representations with potentially more semantic coherence than transformer-native units such as neurons (Cunningham et al., 2023; Bricken et al., 2023). SAEs usually have many more nodes than the corresponding transformer block they are applied to. This could make computing activation patching effects even more costly, making the speedup of AtP* more valuable. However, due to the sparsity of the SAE, on a given forward pass the effect of most features will be zero. For example, some successful SAEs by Bricken et al. (2023) have 10-20 active features for 500 neurons at a given token position, which reduces the number of nodes by 20-50x relative to the MLP setting, increasing the scale at which existing iterative methods remain practical. It is still an open research question, however, what degree of sparsity is feasible with tolerable reconstruction error for practically relevant or SOTA-scale models, where the methods discussed in this work may become more important again.

Steering LLMs AtP* could be used to discover single nodes in the model that can be leveraged for targeted inference-time interventions to control the model's behavior. In contrast to previous work (Li et al., 2023; Turner et al., 2023; Zou et al., 2023), it might provide more localized interventions with less impact on the rest of the model's computation. One potentially exciting direction would be to use AtP* (or other gradient-based approximations) to see which sparse autoencoder features, if activated, would have a significant effect.

5.4 Recommendation

Our results suggest that if a practitioner is trying to do fast causal attribution, there are two main factors to consider: (i) the desired granularity of localization, and (ii) the confidence vs compute tradeoff.

Regarding (i), the desired granularity, smaller components (e.g. MLP neurons or attention heads) are more numerous but more linear, likely yielding better results from gradient-based methods like AtP. We are less sure AtP will be a good approximation if patching layers or sliding windows of layers, and in this case practitioners may want to do normal patching. If the number of forward passes required remains prohibitive (e.g. a long context times many layers, when doing per token × layer patching), our other baselines may be useful. For a single prompt pair we particularly recommend trying Blocks, as it's easy to make sense of; for a distribution we recommend Subsampling because it scales better to many prompt pairs.

Regarding (ii), the confidence vs compute tradeoff, depending on the application it may be desirable to run AtP as an activation patching prefilter followed by running the diagnostic to increase confidence. On the other hand, if false negatives aren't a big concern then it may be preferable to skip the diagnostic, and if false positives aren't either, then in certain cases practitioners may want to skip activation patching verification entirely. In addition, if the prompt pair distribution does not adequately highlight the specific circuit/behaviour of interest, this may also limit what can be learned from any localization method.

If AtP is appropriate, our results suggest the best variant to use is probably AtP* for single prompt pairs, AtP+QKFix for AttentionNodes on distributions, and AtP for NeuronNodes (or other sites that aren't immediately before a nonlinearity) on distributions. Of course, these recommendations are best substantiated in settings similar to those we studied: focused prompt pairs/distributions, attention node or neuron sites, nodewise attribution, and measuring cross-entropy loss on the clean-prompt next token. If departing from these assumptions we recommend looking before you leap.
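A minimal sketch of the "AtP as a prefilter, then verify with activation patching" workflow from the confidence vs compute discussion above, on a hypothetical toy model in which each node reaches the metric through a `tanh`. The node values, weights and `k` are made up for illustration, and the paper's false-negative diagnostic is not implemented here. Node 1 is deliberately placed in the saturated region of `tanh`, so its gradient is tiny and AtP ranks it low even though its true patch effect is the largest.

```python
import math


def toy_metric(h, w):
    """Hypothetical toy model: each node reaches the metric through a tanh."""
    return sum(wi * math.tanh(hi) for wi, hi in zip(w, h))


def atp_estimates(h_clean, h_noise, w):
    """AtP: (n(x_noise) - n(x_clean)) times dL/dn evaluated on the clean run."""
    return [(hn - hc) * wi * (1.0 - math.tanh(hc) ** 2)
            for hc, hn, wi in zip(h_clean, h_noise, w)]


def patch_effect(i, h_clean, h_noise, w):
    """Activation patching: one forward pass with node i set to its noise value."""
    patched = list(h_clean)
    patched[i] = h_noise[i]
    return toy_metric(patched, w) - toy_metric(h_clean, w)


def prefilter_then_verify(h_clean, h_noise, w, k):
    """Rank nodes by |AtP estimate| and verify only the top k with real patching."""
    est = atp_estimates(h_clean, h_noise, w)
    ranked = sorted(range(len(est)), key=lambda i: -abs(est[i]))
    return {i: patch_effect(i, h_clean, h_noise, w) for i in ranked[:k]}


h_clean = [0.1, 3.0, -0.2, 0.0]  # node 1 sits in tanh's saturated region
h_noise = [1.0, -3.0, -0.2, 0.5]
w = [1.0, 1.0, 1.0, 0.2]
verified = prefilter_then_verify(h_clean, h_noise, w, k=2)
```

`verified` holds exact patch effects for nodes 0 and 3 (about 0.66 and 0.09), so a false positive among the shortlisted nodes would be caught; node 1, whose true effect is about -1.99, is never checked. That missed node is the kind of false negative the diagnostic is meant to surface.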
6 Related work

Localization and Mediation Analysis This work is concerned with identifying the effect of all (important) nodes in a causal graph (Pearl, 2000), in the specific case where the graph represents a language model's computation. A key method for finding important intermediate nodes in a causal graph is intervening on those nodes and observing the effect, which was first discussed under the name of causal mediation analysis by Robins and Greenland (1992) and Pearl (2001).

Activation Patching In recent years there has been increasing success at applying the ideas of causal mediation analysis to identify causally important nodes in deep neural networks, in particular via the method of activation patching, where the output of a model component is intervened on. This technique has been widely used by the community and successfully applied in a range of contexts (Olsson et al., 2022; Vig et al., 2020; Soulos et al., 2020; Meng et al., 2023; Wang et al., 2022; Hase et al., 2023; Lieberum et al., 2023; Conmy et al., 2023; Hanna et al., 2023; Geva et al., 2023; Huang et al., 2023; Tigges et al., 2023; Merullo et al., 2023; McDougall et al., 2023; Goldowsky-Dill et al., 2023; Stolfo et al., 2023; Feng and Steinhardt, 2023; Hendel et al., 2023; Todd et al., 2023; Cunningham et al., 2023; Finlayson et al., 2021; Nanda et al., 2023). Chan et al.
(2022) introduce causal scrubbing, a generalized algorithm to verify a hypothesis about the internal mechanism underlying a model's behavior, and detail their motivation for performing noising and resample ablation rather than denoising or using mean or zero ablation: they interpret the hypothesis as implying the computation is invariant to some large set of perturbations, so their starting point is the clean unperturbed forward pass. (Footnote 13: Our motivation for focusing on noising rather than denoising was a closely related one: we were motivated by automated circuit discovery, where gradually noising more and more of the model is the basic methodology for both of the approaches discussed in Section 5.3.)

Another line of research concerning formalizing causal abstractions focuses on finding and verifying high-level causal abstractions of low-level variables (Geiger et al., 2020, 2021, 2022, 2023). See Jenner et al. (2022) for more details on how these different frameworks agree and differ. In contrast to those works, we are chiefly concerned with identifying the important low-level variables in the computational graph, and we do not investigate their semantics or potential groupings of lower-level into higher-level variables.

In addition to causal mediation analysis, intervening on node activations in the model forward pass has also been studied as a way of steering models towards desirable behavior (Rimsky et al., 2023; Zou et al., 2023; Turner et al., 2023; Jorgensen et al., 2023; Li et al., 2023; Belrose et al., 2023).

Attribution Patching / Gradient-based Masking While we use the resample-ablation variant of AtP as formulated in Nanda (2022), similar formulations have been used in the past to successfully prune deep neural networks (Figurnov et al., 2016; Molchanov et al., 2017; Michel et al., 2019), or even to identify causally important nodes for interpretability (Cao et al., 2021). Concurrent work by Syed et al.
(2023) also demonstrates that AtP can help with automatically finding causally important circuits in a way that agrees with previous manual circuit identification work. In contrast to Syed et al. (2023), we provide further analysis of AtP's failure modes, give improvements in the form of AtP*, and evaluate both methods as well as several baselines on a suite of larger models against a ground truth that is independent of human researchers' judgement.

7 Conclusion

In this paper, we have explored the use of attribution patching for node patch effect evaluation. We have compared attribution patching with alternatives and augmentations, characterized its failure modes, and presented reliability diagnostics. We have also discussed the implications of our contributions for other settings in which patching can be of interest, such as circuit discovery, edge localization, coarse-grained localization, and causal abstraction.

Our results show that AtP* can be a more reliable and scalable approach to node patch effect evaluation than alternatives. However, it is important to be aware of the failure modes of attribution patching, such as cancellation and saturation. We explored these in some detail, and provided mitigations, as well as recommendations for diagnostics to ensure that the results are reliable.

We believe that our work makes an important contribution to the field of mechanistic interpretability and will help to advance the development of more reliable and scalable methods for understanding the behavior of deep neural networks.

8 Author Contributions

János Kramár was research lead, and Tom Lieberum was also a core contributor; both were highly involved in most aspects of the project. Rohin Shah and Neel Nanda served as advisors and gave feedback and guidance throughout.

References

Belrose et al. (2023) N. Belrose, D. Schneider-Joseph, S. Ravfogel, R. Cotterell, E. Raff, and S. Biderman.
LEACE: Perfect linear concept erasure in closed form. arXiv preprint arXiv:2306.03819, 2023. Biderman et al. (2023) S. Biderman, H. Schoelkopf, Q. G. Anthony, H. Bradley, K. O’Brien, E. Hallahan, M. A. Khan, S. Purohit, U. S. Prashanth, E. Raff, A. Skowron, L. Sutawika, and O. van der Wal. Pythia: A suite for analyzing large language models across training and scaling. In A. Krause, E. Brunskill, K. Cho, B. Engelhardt, S. Sabato, and J. Scarlett, editors, International Conference on Machine Learning, ICML 2023, 23-29 July 2023, Honolulu, Hawaii, USA, volume 202 of Proceedings of Machine Learning Research, pages 2397–2430. PMLR, 2023. URL https://proceedings.mlr.press/v202/biderman23a.html. Bricken et al. (2023) T. Bricken, A. Templeton, J. Batson, B. Chen, A. Jermyn, T. Conerly, N. Turner, C. Anil, C. Denison, A. Askell, R. Lasenby, Y. Wu, S. Kravec, N. Schiefer, T. Maxwell, N. Joseph, Z. Hatfield-Dodds, A. Tamkin, K. Nguyen, B. McLean, J. E. Burke, T. Hume, S. Carter, T. Henighan, and C. Olah. Towards monosemanticity: Decomposing language models with dictionary learning. Transformer Circuits Thread, 2023. https://transformer-circuits.pub/2023/monosemantic-features/index.html. Cammarata et al. (2020) N. Cammarata, S. Carter, G. Goh, C. Olah, M. Petrov, L. Schubert, C. Voss, B. Egan, and S. K. Lim. Thread: Circuits. Distill, 2020. 10.23915/distill.00024. https://distill.pub/2020/circuits. Cao et al. (2021) N. D. Cao, L. Schmid, D. Hupkes, and I. Titov. Sparse interventions in language models with differentiable masking, 2021. Chan et al. (2022) L. Chan, A. Garriga-Alonso, N. Goldowsky-Dill, R. Greenblatt, J. Nitishinskaya, A. Radhakrishnan, B. Shlegeris, and N. Thomas. Causal scrubbing, a method for rigorously testing interpretability hypotheses. AI Alignment Forum, 2022. https://w.alignmentforum.org/posts/JvZhhzycHu2Yd57RN/causal-scrubbing-a-method-for-rigorously-testing. Conmy et al. (2023) A. Conmy, A. N. Mavor-Parker, A. Lynch, S. Heimersheim, and A.
Garriga-Alonso. Towards automated circuit discovery for mechanistic interpretability, 2023. Cunningham et al. (2023) H. Cunningham, A. Ewart, L. Riggs, R. Huben, and L. Sharkey. Sparse autoencoders find highly interpretable features in language models, 2023. Feng and Steinhardt (2023) J. Feng and J. Steinhardt. How do language models bind entities in context?, 2023. Figurnov et al. (2016) M. Figurnov, A. Ibraimova, D. P. Vetrov, and P. Kohli. Perforatedcnns: Acceleration through elimination of redundant convolutions. In D. Lee, M. Sugiyama, U. Luxburg, I. Guyon, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc., 2016. URL https://proceedings.neurips.c/paper_files/paper/2016/file/f0e52b27a7a5d6a1a87373dffa53dbe5-Paper.pdf. Finlayson et al. (2021) M. Finlayson, A. Mueller, S. Gehrmann, S. Shieber, T. Linzen, and Y. Belinkov. Causal analysis of syntactic agreement mechanisms in neural language models. In C. Zong, F. Xia, W. Li, and R. Navigli, editors, Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers), pages 1828–1843, Online, Aug. 2021. Association for Computational Linguistics. 10.18653/v1/2021.acl-long.144. URL https://aclanthology.org/2021.acl-long.144. Gao et al. (2020) L. Gao, S. Biderman, S. Black, L. Golding, T. Hoppe, C. Foster, J. Phang, H. He, A. Thite, N. Nabeshima, S. Presser, and C. Leahy. The Pile: An 800gb dataset of diverse text for language modeling. arXiv preprint arXiv:2101.00027, 2020. Geiger et al. (2020) A. Geiger, K. Richardson, and C. Potts. Neural natural language inference models partially embed theories of lexical entailment and negation, 2020. Geiger et al. (2021) A. Geiger, H. Lu, T. Icard, and C. Potts. Causal abstractions of neural networks, 2021. Geiger et al. (2022) A. Geiger, Z. Wu, H. Lu, J. Rozner, E. Kreiss, T. Icard, N. D. 
Goodman, and C. Potts. Inducing causal structure for interpretable neural networks, 2022. Geiger et al. (2023) A. Geiger, C. Potts, and T. Icard. Causal abstraction for faithful model interpretation, 2023. Geva et al. (2023) M. Geva, J. Bastings, K. Filippova, and A. Globerson. Dissecting recall of factual associations in auto-regressive language models, 2023. Goldowsky-Dill et al. (2023) N. Goldowsky-Dill, C. MacLeod, L. Sato, and A. Arora. Localizing model behavior with path patching, 2023. Gurnee et al. (2023) W. Gurnee, N. Nanda, M. Pauly, K. Harvey, D. Troitskii, and D. Bertsimas. Finding neurons in a haystack: Case studies with sparse probing, 2023. Hanna et al. (2023) M. Hanna, O. Liu, and A. Variengien. How does gpt-2 compute greater-than?: Interpreting mathematical abilities in a pre-trained language model, 2023. Hase et al. (2023) P. Hase, M. Bansal, B. Kim, and A. Ghandeharioun. Does localization inform editing? surprising differences in causality-based localization vs. knowledge editing in language models, 2023. Hendel et al. (2023) R. Hendel, M. Geva, and A. Globerson. In-context learning creates task vectors, 2023. Hoffmann et al. (2022) J. Hoffmann, S. Borgeaud, A. Mensch, E. Buchatskaya, T. Cai, E. Rutherford, D. de Las Casas, L. A. Hendricks, J. Welbl, A. Clark, T. Hennigan, E. Noland, K. Millican, G. van den Driessche, B. Damoc, A. Guy, S. Osindero, K. Simonyan, E. Elsen, O. Vinyals, J. Rae, and L. Sifre. An empirical analysis of compute-optimal large language model training. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 30016–30030. Curran Associates, Inc., 2022. URL https://proceedings.neurips.c/paper_files/paper/2022/file/c1e2faff6f588870935f114ebe04a3e5-Paper-Conference.pdf. Huang et al. (2023) J. Huang, A. Geiger, K. D’Oosterlinck, Z. Wu, and C. Potts. Rigorously assessing natural language explanations of neurons, 2023. Jenner et al. (2022) E. 
Jenner, A. Garriga-Alonso, and E. Zverev. A comparison of causal scrubbing, causal abstractions, and related methods. AI Alignment Forum, 2022. https://w.alignmentforum.org/posts/uLMWMeBG3ruoBRhMW/a-comparison-of-causal-scrubbing-causal-abstractions-and. Jorgensen et al. (2023) O. Jorgensen, D. Cope, N. Schoots, and M. Shanahan. Improving activation steering in language models with mean-centring, 2023. Li et al. (2023) K. Li, O. Patel, F. Viégas, H. Pfister, and M. Wattenberg. Inference-time intervention: Eliciting truthful answers from a language model, 2023. Lieberum et al. (2023) T. Lieberum, M. Rahtz, J. Kramár, N. Nanda, G. Irving, R. Shah, and V. Mikulik. Does circuit analysis interpretability scale? evidence from multiple choice capabilities in chinchilla, 2023. Louizos et al. (2018) C. Louizos, M. Welling, and D. P. Kingma. Learning sparse neural networks through $L_0$ regularization, 2018. McDougall et al. (2023) C. McDougall, A. Conmy, C. Rushing, T. McGrath, and N. Nanda. Copy suppression: Comprehensively understanding an attention head, 2023. McGrath et al. (2023) T. McGrath, M. Rahtz, J. Kramár, V. Mikulik, and S. Legg. The hydra effect: Emergent self-repair in language model computations, 2023. Meng et al. (2023) K. Meng, D. Bau, A. Andonian, and Y. Belinkov. Locating and editing factual associations in gpt, 2023. Merullo et al. (2023) J. Merullo, C. Eickhoff, and E. Pavlick. Circuit component reuse across tasks in transformer language models, 2023. Michel et al. (2019) P. Michel, O. Levy, and G. Neubig. Are sixteen heads really better than one? In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019. URL https://proceedings.neurips.c/paper_files/paper/2019/file/2c601ad9d2f9bc8b282670cdd54f69f-Paper.pdf. Molchanov et al. (2017) P. Molchanov, S. Tyree, T. Karras, T. Aila, and J. Kautz.
Pruning convolutional neural networks for resource efficient inference. In International Conference on Learning Representations, 2017. URL https://openreview.net/forum?id=SJGCiw5gl. Nanda (2022) N. Nanda. Attribution patching: Activation patching at industrial scale. 2022. URL https://w.neelnanda.io/mechanistic-interpretability/attribution-patching. Nanda et al. (2023) N. Nanda, S. Rajamanoharan, J. Kramár, and R. Shah. Fact finding: Attempting to reverse-engineer factual recall on the neuron level, Dec 2023. URL https://w.alignmentforum.org/posts/iGuwZTHWb6DFY3sKB/fact-finding-attempting-to-reverse-engineer-factual-recall. nostalgebraist (2020) nostalgebraist. interpreting gpt: the logit lens. 2020. URL https://w.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens. Olsson et al. (2022) C. Olsson, N. Elhage, N. Nanda, N. Joseph, N. DasSarma, T. Henighan, B. Mann, A. Askell, Y. Bai, A. Chen, T. Conerly, D. Drain, D. Ganguli, Z. Hatfield-Dodds, D. Hernandez, S. Johnston, A. Jones, J. Kernion, L. Lovitt, K. Ndousse, D. Amodei, T. Brown, J. Clark, J. Kaplan, S. McCandlish, and C. Olah. In-context learning and induction heads. Transformer Circuits Thread, 2022. https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html. Pearl (2000) J. Pearl. Causality: Models, Reasoning and Inference. Cambridge University Press, 2000. Pearl (2001) J. Pearl. Direct and indirect effects, 2001. Radford et al. (2018) A. Radford, K. Narasimhan, T. Salimans, and I. Sutskever. Improving language understanding by generative pre-training, 2018. Rimsky et al. (2023) N. Rimsky, N. Gabrieli, J. Schulz, M. Tong, E. Hubinger, and A. M. Turner. Steering llama 2 via contrastive activation addition, 2023. Robins and Greenland (1992) J. M. Robins and S. Greenland. Identifiability and exchangeability for direct and indirect effects. Epidemiology, 3:143–155, 1992. URL https://api.semanticscholar.org/CorpusID:10757981. Soulos et al. (2020) P. Soulos, R. 
T. McCoy, T. Linzen, and P. Smolensky. Discovering the compositional structure of vector representations with role learning networks. In A. Alishahi, Y. Belinkov, G. Chrupała, D. Hupkes, Y. Pinter, and H. Sajjad, editors, Proceedings of the Third BlackboxNLP Workshop on Analyzing and Interpreting Neural Networks for NLP, pages 238–254, Online, Nov. 2020. Association for Computational Linguistics. 10.18653/v1/2020.blackboxnlp-1.23. URL https://aclanthology.org/2020.blackboxnlp-1.23. Stolfo et al. (2023) A. Stolfo, Y. Belinkov, and M. Sachan. A mechanistic interpretation of arithmetic reasoning in language models using causal mediation analysis, 2023. Syed et al. (2023) A. Syed, C. Rager, and A. Conmy. Attribution patching outperforms automated circuit discovery, 2023. Tigges et al. (2023) C. Tigges, O. J. Hollinsworth, A. Geiger, and N. Nanda. Linear representations of sentiment in large language models, 2023. Todd et al. (2023) E. Todd, M. L. Li, A. S. Sharma, A. Mueller, B. C. Wallace, and D. Bau. Function vectors in large language models, 2023. Turner et al. (2023) A. M. Turner, L. Thiergart, D. Udell, G. Leech, U. Mini, and M. MacDiarmid. Activation addition: Steering language models without optimization, 2023. Vaswani et al. (2017) A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin. Attention is all you need, 2017. Veit et al. (2016) A. Veit, M. J. Wilber, and S. Belongie. Residual networks behave like ensembles of relatively shallow networks. In D. Lee, M. Sugiyama, U. Luxburg, I. Guyon, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc., 2016. URL https://proceedings.neurips.c/paper_files/paper/2016/file/37bc2f75bf1bcfe8450a1a41c200364c-Paper.pdf. Vig et al. (2020) J. Vig, S. Gehrmann, Y. Belinkov, S. Qian, D. Nevo, Y. Singer, and S. Shieber. Investigating gender bias in language models using causal mediation analysis. In H. Larochelle, M. Ranzato, R. 
Hadsell, M. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 12388–12401. Curran Associates, Inc., 2020. URL https://proceedings.neurips.c/paper_files/paper/2020/file/92650b2e92217715fe312e6fa7b90d82-Paper.pdf. Wang et al. (2022) K. Wang, A. Variengien, A. Conmy, B. Shlegeris, and J. Steinhardt. Interpretability in the wild: a circuit for indirect object identification in gpt-2 small, 2022. Welch (1947) B. L. Welch. The generalization of ‘Student’s’ problem when several different population variances are involved. Biometrika, 34(1-2):28–35, 01 1947. ISSN 0006-3444. 10.1093/biomet/34.1-2.28. URL https://doi.org/10.1093/biomet/34.1-2.28. Zou et al. (2023) A. Zou, L. Phan, S. Chen, J. Campbell, P. Guo, R. Ren, A. Pan, X. Yin, M. Mazeika, A.-K. Dombrowski, S. Goel, N. Li, M. J. Byun, Z. Wang, A. Mallen, S. Basart, S. Koyejo, D. Song, M. Fredrikson, J. Z. Kolter, and D. Hendrycks. Representation engineering: A top-down approach to ai transparency, 2023.

Appendix A Method details

A.1 Baselines

A.1.1 Properties of Subsampling

Here we prove that the subsampling estimator $\hat{\mathcal{I}}_{\text{S}}(n)$ from Section 3.3 is unbiased in the case of no interaction effects. Furthermore, assuming a simple interaction model, we show that the bias of $\hat{\mathcal{I}}_{\text{S}}(n)$ is $p$ times the total interaction effect of $n$ with other nodes. We assume a pairwise interaction model.
That is, given a set of nodes $\eta$, we have

$$\mathcal{I}(\eta; x) = \sum_{n \in \eta} \mathcal{I}(n; x) + \sum_{\substack{n, n' \in \eta \\ n \neq n'}} \sigma_{n,n'}(x) \qquad (16)$$

with fixed constants $\sigma_{n,n'}(x) = \sigma_{n',n}(x) \in \mathbb{R}$ for each prompt pair $x \in \operatorname{support}(\mathcal{D})$, where the second sum counts each unordered pair $\{n, n'\}$ once. Let $\sigma_{n,n'} = \mathbb{E}_{x \sim \mathcal{D}}[\sigma_{n,n'}(x)]$. Let $p$ be the probability of including each node in a given $\eta$, and let $M$ be the number of samples, each consisting of a node mask drawn from $\operatorname{Bernoulli}^{|N|}(p)$ and a prompt pair $x$ drawn from $\mathcal{D}$. Writing $\eta^+(n)$ and $\eta^-(n)$ for the sampled masks that include and exclude $n$ respectively, with $x^\pm_k$ the prompt pair sampled alongside $\eta^\pm_k(n)$, we have

$$\begin{aligned}
\mathbb{E}\big[\hat{\mathcal{I}}_{\text{S}}(n)\big]
&= \mathbb{E}\left[\frac{1}{|\eta^+(n)|}\sum_{k=1}^{|\eta^+(n)|}\mathcal{I}(\eta^+_k(n); x^+_k) - \frac{1}{|\eta^-(n)|}\sum_{k=1}^{|\eta^-(n)|}\mathcal{I}(\eta^-_k(n); x^-_k)\right] && \text{(17a)}\\
&= \mathbb{E}\left[\mathbb{E}\left[\frac{1}{|\eta^+(n)|}\sum_{k=1}^{|\eta^+(n)|}\mathcal{I}(\eta^+_k(n); x^+_k) - \frac{1}{|\eta^-(n)|}\sum_{k=1}^{|\eta^-(n)|}\mathcal{I}(\eta^-_k(n); x^-_k) \,\middle|\, |\eta^+(n)|\right]\right] && \text{(17b)}\\
&= \mathbb{E}\left[\mathbb{E}\left[\frac{|\eta^+(n)|}{|\eta^+(n)|}\,\mathbb{E}\left[\mathcal{I}(\eta_1; x_1) \mid n \in \eta_1\right] - \frac{|\eta^-(n)|}{|\eta^-(n)|}\,\mathbb{E}\left[\mathcal{I}(\eta_1; x_1) \mid n \notin \eta_1\right] \,\middle|\, |\eta^+(n)|\right]\right] && \text{(17c)}\\
&= \mathbb{E}\left[\mathcal{I}(\eta_1; x_1) \mid n \in \eta_1\right] - \mathbb{E}\left[\mathcal{I}(\eta_1; x_1) \mid n \notin \eta_1\right] && \text{(17d)}\\
&= c(n) + \mathbb{E}\left[\sum_{n' \neq n} \mathbf{1}[n' \in \eta_1]\left(c(n') + \sigma_{nn'} + \frac{1}{2}\sum_{n'' \notin \{n', n\}} \mathbf{1}[n'' \in \eta_1]\,\sigma_{n'n''}\right) \,\middle|\, n \in \eta_1\right] && \text{(17e)}\\
&\quad - \mathbb{E}\left[\sum_{n' \neq n} \mathbf{1}[n' \in \eta_1]\left(c(n') + \frac{1}{2}\sum_{n'' \notin \{n', n\}} \mathbf{1}[n'' \in \eta_1]\,\sigma_{n'n''}\right) \,\middle|\, n \notin \eta_1\right] && \text{(17f)}\\
&= c(n) + p\sum_{n' \neq n}\sigma_{nn'} && \text{(17g)}
\end{aligned}$$

In Equation 17g, we observe that if the interaction terms $\sigma_{nn'}$ are all zero, the estimator is unbiased. Otherwise, the bias scales both with the sum of interaction effects and with $p$, as expected.
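The result in Equation 17g can be checked numerically. The sketch below assumes a single fixed prompt pair (so the $x$-dependence drops out and $\mathcal{I}$ is deterministic), four nodes, and made-up values of $c(n)$ and $\sigma_{n,n'}$. The helper `intervention_effect` implements Equation 16 with one term per unordered pair, and `subsampling_estimates` takes the difference of mean effects over masks that include vs exclude each node; both names are ours, for illustration only.

```python
import itertools
import random


def intervention_effect(mask, c, sigma):
    """Pairwise interaction model (Eq. 16): single-node effects plus one
    sigma term per unordered pair of patched nodes."""
    patched = [n for n, included in enumerate(mask) if included]
    total = sum(c[n] for n in patched)
    for a, b in itertools.combinations(patched, 2):  # yields pairs with a < b
        total += sigma.get((a, b), 0.0)
    return total


def subsampling_estimates(c, sigma, p, num_masks, seed=0):
    """Mean effect over masks containing n minus mean effect over masks without n."""
    rng = random.Random(seed)
    num_nodes = len(c)
    sum_in, sum_out, count_in = [0.0] * num_nodes, [0.0] * num_nodes, [0] * num_nodes
    for _ in range(num_masks):
        mask = [rng.random() < p for _ in range(num_nodes)]
        value = intervention_effect(mask, c, sigma)
        for n in range(num_nodes):
            if mask[n]:
                sum_in[n] += value
                count_in[n] += 1
            else:
                sum_out[n] += value
    if any(k == 0 or k == num_masks for k in count_in):
        raise ValueError("every node needs masks both with and without it")
    return [sum_in[n] / count_in[n] - sum_out[n] / (num_masks - count_in[n])
            for n in range(num_nodes)]


def predicted_expectation(n, c, sigma, p):
    """Eq. 17g: c(n) + p * (sum of n's pairwise interaction terms)."""
    return c[n] + p * sum(v for pair, v in sigma.items() if n in pair)


c = [1.0, 0.5, -0.3, 0.0]
sigma = {(0, 1): 0.2, (0, 2): -0.1, (1, 3): 0.3}  # keys are (smaller, larger) node ids
p = 0.3
estimates = subsampling_estimates(c, sigma, p, num_masks=40000)
```

With these values Equation 17g predicts 1.03, 0.65, -0.33 and 0.09 for the four nodes, and the Monte Carlo estimates land close to them; with `sigma = {}` the prediction collapses to `c`, matching the unbiasedness claim.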
A.1.2 Pseudocode for Blocks and Hierarchical baselines

In Algorithm 2 we detail the Blocks baseline algorithm. As explained in Section 3.3, it comes with a tradeoff in its "block size" hyperparameter B: a small block size requires many forward passes to evaluate all the blocks, while a large block size means many irrelevant nodes to evaluate in each high-contribution block.

Algorithm 2 Blocks algorithm for causal attribution.
1: Input: block size B, compute budget M, nodes N = {n_i}, prompts x^clean, x^noise, intervention function Ĩ : η ↦ ℐ(η; x^clean, x^noise)
2: numBlocks ← ⌈|N| / B⌉
3: π ← shuffle((⌊numBlocks · i / |N|⌋ | i ∈ {0, …, |N| − 1}))   ▷ Assign each node to a block.
4: for i ← 0 to numBlocks − 1 do
5:     blockContribution[i] ← |Ĩ(π⁻¹({i}))|   ▷ π⁻¹({i}) := {n ∈ N : π(n) = i}
6: spentBudget ← numBlocks
7: topNodeContribs ← CreateEmptyDictionary()
8: for all i ∈ {0, …, numBlocks − 1} in decreasing order of blockContribution[i] do
9:     for all n ∈ π⁻¹({i}) do   ▷ Eval all nodes in block.
10:        if spentBudget < M then
11:            topNodeContribs[n] ← |Ĩ({n})|
12:            spentBudget ← spentBudget + 1
13:        else
14:            return topNodeContribs
15: return topNodeContribs

The Hierarchical baseline algorithm aims to resolve this tradeoff by using small blocks, but grouped into superblocks, so that it is not necessary to traverse all the small blocks before finding the key nodes. In Algorithm 3 we detail the hierarchical algorithm in its iterative form, corresponding to batch size 1.

One aspect that might be surprising is that on line 22, we ensure a subblock is never added to the priority queue with higher priority than its ancestor superblocks. The reason for doing this is that in practice we use batched inference rather than patching a single block at a time, so depending on the batch size, we do evaluate blocks that aren't the highest-priority unevaluated blocks, and this might impose a significant delay in when some blocks are evaluated. In order to reduce this dependence on the batch size hyperparameter, line 22 ensures that every block is evaluated at most L batches later than it would be with batch size 1.

Algorithm 3 Hierarchical algorithm for causal attribution, in iterative form. In practice we do additional batching rather than evaluating a single block at a time on line 15.
1: Input: branching factor B, number of levels L, compute budget M, nodes N = {n_i}, intervention function ℐ
2: numTopLevelBlocks ← ⌈|N| / B^L⌉
3: π ← shuffle((⌊numTopLevelBlocks · i · B^L / |N|⌋ | i ∈ {0, …, |N| − 1}))
4: for all n_i ∈ N do
5:     (d_{L−1}, d_{L−2}, …, d_0) ← zero-padded final L base-B digits of π_i
6:     address(n_i) ← (⌊π_i / B^L⌋, d_{L−1}, …, d_0)
7: Q ← CreateEmptyPriorityQueue()
8: for i ← 0 to numTopLevelBlocks − 1 do
9:     PriorityQueueInsert(Q, [i], ∞)
10: spentBudget ← 0
11: topNodeContribs ← CreateEmptyDictionary()
12: repeat
13:     (addressPrefix, priority) ← PriorityQueuePop(Q)
14:     blockNodes ← {n ∈ N | StartsWith(address(n), addressPrefix)}
15:     blockContribution ← |ℐ(blockNodes)|
16:     spentBudget ← spentBudget + 1
17:     if blockNodes = {n} for some n ∈ N then
18:         topNodeContribs[n] ← blockContribution
19:     else
20:         for i ← 0 to B − 1 do
21:             if {n ∈ blockNodes | StartsWith(address(n), addressPrefix + [i])} ≠ ∅ then
22:                 PriorityQueueInsert(Q, addressPrefix + [i], min(blockContribution, priority))
23: until spentBudget = M or PriorityQueueEmpty(Q)
24: return topNodeContribs

A.2 AtP improvements

A.2.1 Pseudocode for corrected AtP on attention keys

As described in Section 3.1.1, computing Equation 10 naïvely for all nodes requires $O(T^3)$ flops at each attention head and prompt pair. Here we give a more efficient algorithm running in $O(T^2)$. In addition to keys, queries and attention probabilities, we now also cache attention logits (pre-softmax scaled key-query dot products). We define $\operatorname{attnLogits}_{\text{patch}}^t(n^q)$ and $\Delta_t \operatorname{attnLogits}(n^q)$ analogously to Equations 8 and 9.
For brevity we also define $\operatorname{attnLogits}_{\text{patch}}(n^q)_t := \operatorname{attnLogits}^t_{\text{patch}}(n^q)_t$ and $\Delta\operatorname{attnLogits}(n^q)_t := \Delta_t \operatorname{attnLogits}(n^q)_t$, since the aim of this algorithm is to avoid separately computing the effect of $\operatorname{do}(n^k_t \leftarrow n^k_t(x^{\text{noise}}))$ on any component of $\operatorname{attnLogits}$ other than the one for key node $n^k_t$. Note that, for a key $n^k_t$ at position $t$ in the sequence, the relative proportions of the components of $\operatorname{attn}(n^q)$ other than the $t$-th do not change when $\operatorname{attnLogits}(n^q)_t$ is changed. Hence $\Delta_t \operatorname{attn}(n^q)$ is simply $\operatorname{onehot}(t) - \operatorname{attn}(n^q)$ multiplied by some scalar $s_t$; to get the right attention weight on $n^k_t$, this scalar must be $s_t := \frac{\Delta_t \operatorname{attn}(n^q)_t}{1 - \operatorname{attn}(n^q)_t}$.
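As a quick numerical sanity check of this rank-one structure, the following minimal Python sketch patches a single attention logit and compares the exact change in attention with $s_t(\operatorname{onehot}(t) - \operatorname{attn}(n^q))$. The logit values and the helper names (`softmax`, `patch_one_logit`, `rank_one_delta`) are hypothetical choices for illustration, not code from the paper.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def patch_one_logit(logits, t, new_logit):
    """Copy of `logits` where only position t is replaced (patching key t)."""
    patched = list(logits)
    patched[t] = new_logit
    return patched


def rank_one_delta(attn, attn_patch_t, t):
    """Predicted Delta_t attn = s_t * (onehot(t) - attn),
    with s_t = (attn_patch_t - attn_t) / (1 - attn_t)."""
    s_t = (attn_patch_t - attn[t]) / (1.0 - attn[t])
    return [s_t * ((1.0 if u == t else 0.0) - a) for u, a in enumerate(attn)]


if __name__ == "__main__":
    logits = [0.3, -1.2, 2.0, 0.5]  # hypothetical clean attention logits
    t, new_logit = 1, 1.7           # hypothetical patched logit for key t
    attn = softmax(logits)
    attn_patch = softmax(patch_one_logit(logits, t, new_logit))
    exact = [p - a for p, a in zip(attn_patch, attn)]
    predicted = rank_one_delta(attn, attn_patch[t], t)
    # The discrepancy is at the level of floating-point rounding.
    print(max(abs(e - q) for e, q in zip(exact, predicted)))
```

Only the patched position's new weight is needed to recover the whole change in the attention pattern, which is what lets the algorithm avoid a separate softmax per patched key.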
Additionally, because patching key $t$ changes only the $t$-th logit, the odds of attending to position $t$ are multiplied by $e^{\Delta\operatorname{attnLogits}(n^q)_t}$:
$$\log\left(\frac{\operatorname{attn}_{\text{patch}}^t(n^q)_t}{1 - \operatorname{attn}_{\text{patch}}^t(n^q)_t}\right) = \log\left(\frac{\operatorname{attn}(n^q)_t}{1 - \operatorname{attn}(n^q)_t}\right) + \Delta\operatorname{attnLogits}(n^q)_t.$$
Since the log-odds function $p \mapsto \log\left(\frac{p}{1-p}\right)$ is the inverse of the sigmoid function $\sigma$, this gives $\operatorname{attn}_{\text{patch}}^t(n^q)_t = \sigma\left(\log\left(\frac{\operatorname{attn}(n^q)_t}{1 - \operatorname{attn}(n^q)_t}\right) + \Delta\operatorname{attnLogits}(n^q)_t\right)$. Putting this together, we can compute all of $\operatorname{attnLogits}_{\text{patch}}(n^q)$ by combining all keys from the $x^{\text{noise}}$ forward pass with all queries from the $x^{\text{clean}}$ forward pass, then compute $\Delta\operatorname{attnLogits}(n^q)$, all $\Delta_t \operatorname{attn}(n^q)_t$, and thus all $\hat{\mathcal{I}}_{\text{AtPfix}}^K(n^k_t; x^{\text{clean}}, x^{\text{noise}})$, using $O(T^2)$ flops per attention head.
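The log-odds shortcut can be sketched in Python as follows, assuming a single query, plain lists in place of tensors, and no causal mask (all hypothetical simplifications). `patched_weights_fast` uses the sigmoid form, which costs $O(T)$ per query and hence $O(T^2)$ per head, while `patched_weights_naive` re-runs the softmax once per patched key as a reference. Like the bare identity, the fast version assumes no attention weight rounds to exactly 0 or 1; Algorithm 4 below addresses that numerically.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def sigmoid(x):
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def attn_logits(q, keys):
    """Scaled dot-product logits of one query against a list of keys."""
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]


def patched_weights_fast(a, a_patch):
    """O(T) per query: attn_patch^t_t = sigmoid(logodds(attn_t) + a_patch_t - a_t).
    Assumes every clean attention weight lies strictly between 0 and 1."""
    attn = softmax(a)
    return [sigmoid(math.log(p / (1.0 - p)) + ap - at)
            for p, ap, at in zip(attn, a_patch, a)]


def patched_weights_naive(a, a_patch):
    """O(T^2) per query: re-run the softmax once per patched key position."""
    out = []
    for t in range(len(a)):
        logits = list(a)
        logits[t] = a_patch[t]
        out.append(softmax(logits)[t])
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    T, d_head = 6, 4
    q_clean = [rng.gauss(0, 1) for _ in range(d_head)]  # one clean query
    k_clean = [[rng.gauss(0, 1) for _ in range(d_head)] for _ in range(T)]
    k_noise = [[rng.gauss(0, 1) for _ in range(d_head)] for _ in range(T)]
    a = attn_logits(q_clean, k_clean)        # attnLogits(n^q)
    a_patch = attn_logits(q_clean, k_noise)  # clean query against noise keys
    fast = patched_weights_fast(a, a_patch)
    naive = patched_weights_naive(a, a_patch)
    print(max(abs(f - n) for f, n in zip(fast, naive)))
```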
Algorithm 4 computes the contribution of some query node $n^q$ and prompt pair $x^{\text{clean}}, x^{\text{noise}}$ to the corrected AtP estimates $\hat{c}_{\text{AtPfix}}^K(n^k_t)$ for key nodes $n^k_1, \dots, n^k_T$ from a single attention head, using $O(T)$ flops, while avoiding numerical overflows. We reuse the notation $\operatorname{attn}(n^q)$, $\operatorname{attn}_{\text{patch}}^t(n^q)$, $\Delta_t \operatorname{attn}(n^q)$, $\operatorname{attnLogits}(n^q)$, $\operatorname{attnLogits}_{\text{patch}}(n^q)$, and $s_t$ from Section 3.1.1, leaving the prompt pair implicit.
Algorithm 4 AtP correction for attention keys
1: $a := \operatorname{attnLogits}(n^q)$, $a^{\text{patch}} := \operatorname{attnLogits}_{\text{patch}}(n^q)$, $g := \frac{\partial \mathcal{L}(\mathcal{M}(x^{\text{clean}}))}{\partial \operatorname{attn}(n^q)}$
2: $t^* \leftarrow \operatorname{argmax}_t(a_t)$
3: $\ell \leftarrow a - a_{t^*} - \log\left(\sum_t e^{a_t - a_{t^*}}\right)$ ▷ Clean log attention weights, $\ell = \log(\operatorname{attn}(n^q))$
4: $d \leftarrow \ell - \log(1 - e^{\ell})$ ▷ Clean log-odds, $d_t = \log\left(\frac{\operatorname{attn}(n^q)_t}{1 - \operatorname{attn}(n^q)_t}\right)$
5: $d_{t^*} \leftarrow a_{t^*} - \max_{t \neq t^*} a_t - \log\left(\sum_{t' \neq t^*} e^{a_{t'} - \max_{t \neq t^*} a_t}\right)$ ▷ Adjust $d$; more stable for $a_{t^*} \gg \max_{t \neq t^*} a_t$
6: $\ell^{\text{patch}} \leftarrow \operatorname{logsigmoid}(d + a^{\text{patch}} - a)$ ▷ Patched log attention weights, $\ell^{\text{patch}}_t = \log(\operatorname{attn}_{\text{patch}}^t(n^q)_t)$
7: $\Delta\ell \leftarrow \ell^{\text{patch}} - \ell$ ▷ $\Delta\ell_t = \log\left(\frac{\operatorname{attn}_{\text{patch}}^t(n^q)_t}{\operatorname{attn}(n^q)_t}\right)$
8: $b \leftarrow \operatorname{softmax}(a)^\top g$ ▷ $b = \operatorname{attn}(n^q)^\top g$
9: for $t \leftarrow 1$ to $T$ do
10:   ▷ Compute scaling factor $s_t := \frac{\Delta_t \operatorname{attn}(n^q)_t}{1 - \operatorname{attn}(n^q)_t}$
11:   if $\ell^{\text{patch}}_t > \ell_t$ then ▷ Avoid overflow when $\ell^{\text{patch}}_t \gg \ell_t$
12:     $s_t \leftarrow e^{d_t + \Delta\ell_t + \log(1 - e^{-\Delta\ell_t})}$ ▷ $s_t = \frac{\operatorname{attn}(n^q)_t}{1 - \operatorname{attn}(n^q)_t} \cdot \frac{\operatorname{attn}_{\text{patch}}^t(n^q)_t}{\operatorname{attn}(n^q)_t}\left(1 - \frac{\operatorname{attn}(n^q)_t}{\operatorname{attn}_{\text{patch}}^t(n^q)_t}\right)$
13:   else ▷ Avoid overflow when $\ell^{\text{patch}}_t \ll \ell_t$
14:     $s_t \leftarrow -e^{d_t + \log(1 - e^{\Delta\ell_t})}$ ▷ $s_t = -\frac{\operatorname{attn}(n^q)_t}{1 - \operatorname{attn}(n^q)_t}\left(1 - \frac{\operatorname{attn}_{\text{patch}}^t(n^q)_t}{\operatorname{attn}(n^q)_t}\right)$
15:   $r_t \leftarrow s_t(g_t - b)$ ▷ $r_t = s_t(\operatorname{onehot}(t) - \operatorname{attn}(n^q))^\top g = \Delta_t \operatorname{attn}(n^q) \cdot \frac{\partial \mathcal{L}(\mathcal{M}(x^{\text{clean}}))}{\partial \operatorname{attn}(n^q)}$
16: return $r$

The corrected AtP estimates $\hat{c}_{\text{AtPfix}}^K(n^k_t)$ can then be computed using Equation 10; in other words, by summing the returned $r_t$ from Algorithm 4 over queries $n^q$ for this attention head, and averaging over $x^{\text{clean}}, x^{\text{noise}} \sim \mathcal{D}$.
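Below is a minimal pure-Python sketch of Algorithm 4 for a single query node, together with a brute-force reference that re-runs the softmax for each patched key. The helper names, the list-based representation and the requirement $T \geq 2$ are assumptions of this sketch. `log1mexp` computes $\log(1 - e^x)$ with the usual `expm1`/`log1p` split. Because Python's `math.log` raises an error at 0 instead of returning $-\infty$, the case $\Delta\ell_t = 0$ (where $s_t = 0$) is handled explicitly. The sketch mirrors the overflow guards of the listing but makes no attempt to remove other rounding error, for example in $g_t - b$ when attention is extremely saturated.

```python
import math


def softmax(a):
    """Numerically stable softmax."""
    m = max(a)
    e = [math.exp(x - m) for x in a]
    z = sum(e)
    return [v / z for v in e]


def logsigmoid(x):
    """Stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def log1mexp(x):
    """Stable log(1 - exp(x)) for x < 0."""
    if x > -math.log(2.0):
        return math.log(-math.expm1(x))
    return math.log1p(-math.exp(x))


def atp_key_correction(a, a_patch, g):
    """Sketch of Algorithm 4 for one query: returns r with
    r_t = Delta_t attn(n^q) . g for each key position t (assumes T >= 2)."""
    T = len(a)
    t_star = max(range(T), key=lambda t: a[t])  # step 2
    # Step 3: clean log attention weights ell = log(attn).
    lse = math.log(sum(math.exp(x - a[t_star]) for x in a))
    ell = [x - a[t_star] - lse for x in a]
    # Steps 4-5: clean log-odds; d[t_star] uses the more stable formula.
    d = [ell[t] - log1mexp(ell[t]) if t != t_star else 0.0 for t in range(T)]
    others = [a[t] for t in range(T) if t != t_star]
    m = max(others)
    d[t_star] = a[t_star] - m - math.log(sum(math.exp(x - m) for x in others))
    # Steps 6-7: patched log weight on each patched key, and its shift.
    ell_patch = [logsigmoid(d[t] + a_patch[t] - a[t]) for t in range(T)]
    dell = [lp - l for lp, l in zip(ell_patch, ell)]
    # Step 8: b = attn^T g.
    b = sum(math.exp(l) * gt for l, gt in zip(ell, g))
    r = []
    for t in range(T):
        # Steps 10-14: s_t = Delta_t attn_t / (1 - attn_t) without overflow.
        if dell[t] > 0.0:
            s = math.exp(d[t] + dell[t] + log1mexp(-dell[t]))
        elif dell[t] == 0.0:
            s = 0.0  # log(1 - e^0) = -inf, hence s_t = 0
        else:
            s = -math.exp(d[t] + log1mexp(dell[t]))
        r.append(s * (g[t] - b))  # step 15
    return r


def brute_force(a, a_patch, g):
    """Reference: r_t = (softmax with only logit t patched - softmax)^T g."""
    attn = softmax(a)
    r = []
    for t in range(len(a)):
        logits = list(a)
        logits[t] = a_patch[t]
        p = softmax(logits)
        r.append(sum((pu - au) * gu for pu, au, gu in zip(p, attn, g)))
    return r


if __name__ == "__main__":
    a = [0.2, 1.5, -0.7, 0.9]        # hypothetical clean logits
    a_patch = [1.1, -0.3, 0.4, 2.5]  # hypothetical patched logits
    g = [0.5, -1.0, 2.0, 0.3]        # hypothetical gradient dL/dattn
    print(atp_key_correction(a, a_patch, g))
    print(brute_force(a, a_patch, g))
```

Summing the returned `r` over all query positions of a head gives that head's per-key AtP estimates for one prompt pair.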
A.2.2 Properties of GradDrop

In Section 3.1.2 we introduced GradDrop to address an AtP failure mode arising from cancellation between direct and indirect effects. Roughly, if the total effect (on some prompt pair) is $\mathcal{I}(n) = \mathcal{I}^{\text{direct}}(n) + \mathcal{I}^{\text{indirect}}(n)$, and these two terms nearly cancel, then a small multiplicative approximation error in $\hat{\mathcal{I}}_{\text{AtP}}^{\text{indirect}}(n)$, due to nonlinearities, can accidentally cause $|\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n) + \hat{\mathcal{I}}_{\text{AtP}}^{\text{indirect}}(n)|$ to be orders of magnitude smaller than $|\mathcal{I}(n)|$. To address this failure mode with an improved estimator $\hat{c}_{\text{AtP+GD}}(n)$, GradDrop should satisfy three desiderata:

1. $\hat{c}_{\text{AtP+GD}}(n)$ shouldn't be much smaller than $\hat{c}_{\text{AtP}}(n)$, because that would risk creating more false negatives.
2. $\hat{c}_{\text{AtP+GD}}(n)$ should usually not be much larger than $\hat{c}_{\text{AtP}}(n)$, because that would create false positives, which also slow down verification and can effectively create false negatives at a given budget.
3. If $\hat{c}_{\text{AtP}}(n)$ suffers from the cancellation failure mode, then $\hat{c}_{\text{AtP+GD}}(n)$ should be significantly larger than $\hat{c}_{\text{AtP}}(n)$.
Let's recall how GradDrop was defined in Section 3.1.2, using a virtual node $n_\ell^{\text{out}}$ to represent the residual-stream contributions of layer $\ell$:
$$\begin{aligned}
\hat{c}_{\text{AtP+GD}}(n) &:= \mathbb{E}_{x^{\text{clean}}, x^{\text{noise}}}\left[\frac{1}{L-1}\sum_{\ell=1}^{L}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n; x^{\text{clean}}, x^{\text{noise}})\right|\right] \\
&= \mathbb{E}_{x^{\text{clean}}, x^{\text{noise}}}\left[\frac{1}{L-1}\sum_{\ell=1}^{L}\left|(n(x^{\text{noise}}) - n(x^{\text{clean}}))^\top \frac{\partial \mathcal{L}_\ell}{\partial n}\right|\right] \\
&= \mathbb{E}_{x^{\text{clean}}, x^{\text{noise}}}\left[\frac{1}{L-1}\sum_{\ell=1}^{L}\left|(n(x^{\text{noise}}) - n(x^{\text{clean}}))^\top \frac{\partial \mathcal{L}}{\partial n}\left(\mathcal{M}\left(x^{\text{clean}} \mid \operatorname{do}(n^{\text{out}}_\ell \leftarrow n^{\text{out}}_\ell(x^{\text{clean}}))\right)\right)\right|\right]
\end{aligned}$$
To better understand the behaviour of GradDrop, let's look more carefully at the gradient $\frac{\partial \mathcal{L}}{\partial n}$. The total gradient $\frac{\partial \mathcal{L}}{\partial n}$ can be expressed as a sum of all path gradients from the node $n$ to the output.
Each path is characterized by the set of layers $s$ it goes through (in contrast to routing via the skip connection). We write the gradient along a path $s$ as $\frac{\partial \mathcal{L}_s}{\partial n}$. Let $\mathcal{S}$ be the set of all subsets of the layers after the layer containing $n$. For example, the direct-effect path is given by $\emptyset \in \mathcal{S}$. Then the total gradient can be expressed as
$$\frac{\partial \mathcal{L}}{\partial n} = \sum_{s \in \mathcal{S}} \frac{\partial \mathcal{L}_s}{\partial n}. \tag{18}$$
We can analogously define $\hat{\mathcal{I}}_{\text{AtP}}^s(n) = (n(x^{\text{noise}}) - n(x^{\text{clean}}))^\top \frac{\partial \mathcal{L}_s}{\partial n}$, and decompose $\hat{\mathcal{I}}_{\text{AtP}}(n) = \sum_{s \in \mathcal{S}} \hat{\mathcal{I}}_{\text{AtP}}^s(n)$. The effect of applying GradDrop at some layer $\ell$ is then to drop all terms $\hat{\mathcal{I}}_{\text{AtP}}^s(n)$ with $\ell \in s$; in other words,
$$\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n) = \sum_{\substack{s \in \mathcal{S} \\ \ell \notin s}} \hat{\mathcal{I}}_{\text{AtP}}^s(n). \tag{21}$$
We now use this understanding to discuss the three desiderata.

Firstly, most node effects are approximately independent of most layers (see e.g. Veit et al. (2016)); for any layer $\ell$ that $n$'s effect is independent of, we have $\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n) = \hat{\mathcal{I}}_{\text{AtP}}(n)$. Letting $K$ be the set of downstream layers that matter, this guarantees
$$\frac{1}{L-1}\sum_{\ell=1}^{L}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n; x^{\text{clean}}, x^{\text{noise}})\right| \geq \frac{L-|K|-1}{L-1}\left|\hat{\mathcal{I}}_{\text{AtP}}(n; x^{\text{clean}}, x^{\text{noise}})\right|,$$
which meets the first desideratum.

Regarding the second desideratum: for each $\ell$ we have $|\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)| \leq \sum_{s \in \mathcal{S}} |\hat{\mathcal{I}}_{\text{AtP}}^s(n)|$, so overall
$$\frac{1}{L-1}\sum_{\ell=1}^{L}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)\right| \leq \frac{L-|K|-1}{L-1}\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| + \frac{|K|}{L-1}\sum_{s \in \mathcal{S}}\left|\hat{\mathcal{I}}_{\text{AtP}}^s(n)\right|.$$
For the RHS to be much larger (e.g. $\alpha$ times larger) than $\left|\sum_{s \in \mathcal{S}} \hat{\mathcal{I}}_{\text{AtP}}^s(n)\right| = |\hat{\mathcal{I}}_{\text{AtP}}(n)|$, there must be substantial cancellation between different paths, enough that $\sum_{s \in \mathcal{S}} |\hat{\mathcal{I}}_{\text{AtP}}^s(n)| \geq \frac{(L-1)\alpha}{|K|}\left|\sum_{s \in \mathcal{S}} \hat{\mathcal{I}}_{\text{AtP}}^s(n)\right|$. This is possible, but seems generally unlikely for e.g. $\alpha > 3$.

Now let's consider the third desideratum, i.e. suppose $n$ is a cancellation false negative, with $|\hat{\mathcal{I}}_{\text{AtP}}(n)| \ll |\mathcal{I}(n)| \ll |\mathcal{I}^{\text{direct}}(n)| \approx |\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)|$. Then $\left|\sum_{s \in \mathcal{S} \setminus \{\emptyset\}} \hat{\mathcal{I}}_{\text{AtP}}^s(n)\right| = \left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right| \gg |\mathcal{I}(n)|$.
The summands in $\sum_{s \in \mathcal{S} \setminus \{\emptyset\}} \hat{\mathcal{I}}_{\text{AtP}}^s(n)$ are the union, across layers $\ell$, of the summands in $\sum_{s \in \mathcal{S},\, \ell \in s} \hat{\mathcal{I}}_{\text{AtP}}^s(n) = \hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)$. It's then possible, but intuitively unlikely, that $\sum_\ell |\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)|$ would be much smaller than $|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)|$. Suppose the ratio is $\alpha$, i.e. suppose $\sum_\ell |\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)| = \alpha\,|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)|$.
For example, if all indirect effects use paths of length 1, then the union is a disjoint union, so $\sum_\ell |\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)| \geq \left|\sum_\ell (\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n))\right| = |\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)|$, and hence $\alpha \geq 1$. Now:
$$\begin{aligned}
\sum_{\ell \in K}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)\right| &\geq \sum_{\ell \in K}\left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)\right| - |K|\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \qquad (22)\\
&= \alpha\left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right| - |K|\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \qquad (23)\\
&\geq \alpha\left|\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right| - (|K| + \alpha)\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \qquad (24)\\
\therefore \frac{1}{L-1}\sum_{\ell=1}^{L}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)\right| &= \frac{1}{L-1}\sum_{\ell \in K}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)\right| + \frac{L-|K|-1}{L-1}\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \qquad (25)\\
&\geq \frac{\alpha}{L-1}\left|\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right| + \frac{L-2|K|-1-\alpha}{L-1}\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \qquad (26)
\end{aligned}$$
Here (22) is the triangle inequality, and (23) uses the fact that the terms with $\ell \notin K$ vanish. The RHS of (26) is an improvement over $|\hat{\mathcal{I}}_{\text{AtP}}(n)|$ as long as $\alpha|\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)| > (2|K| + \alpha)|\hat{\mathcal{I}}_{\text{AtP}}(n)|$, which is likely given the assumptions. Ultimately, though, the desiderata are validated by the experiments, which consistently show GradDrop either decreasing or leaving unchanged the number of false negatives, and thus improving performance apart from the upfront cost of the extra backward passes.
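To make Equation 21 concrete, here is a toy Python sketch in which a node's AtP estimate is split into path terms keyed by the set of downstream layers each path passes through. The layer indices and term values are invented for illustration, and for simplicity the score averages $|\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n)|$ over a supplied list of layers rather than using the exact $\frac{1}{L-1}\sum_{\ell=1}^{L}$ normalization above.

```python
def atp_total(path_terms):
    """I_AtP(n): the sum of all path terms I_AtP^s(n)."""
    return sum(path_terms.values())


def atp_graddrop_layer(path_terms, layer):
    """Equation (21): keep only the paths s that do not pass through `layer`."""
    return sum(v for s, v in path_terms.items() if layer not in s)


def toy_graddrop_score(path_terms, layers):
    """Toy score: mean of |I_AtP+GD_l(n)| over the given layers (simplified
    normalization, not the exact 1/(L-1) factor used in the text)."""
    return sum(abs(atp_graddrop_layer(path_terms, l)) for l in layers) / len(layers)


if __name__ == "__main__":
    layers = range(1, 9)  # hypothetical downstream layers 1..8
    # Direct effect (empty layer set) nearly cancelled by two indirect paths.
    cancelling = {frozenset(): 1.0, frozenset({3}): -0.6, frozenset({5}): -0.38}
    # No cancellation: the indirect path reinforces the direct effect.
    reinforcing = {frozenset(): 0.5, frozenset({3}): 0.2}
    for name, terms in [("cancelling", cancelling), ("reinforcing", reinforcing)]:
        print(name, abs(atp_total(terms)), toy_graddrop_score(terms, layers))
```

In the cancelling case $|\hat{\mathcal{I}}_{\text{AtP}}(n)| \approx 0.02$ while the toy GradDrop score is about $0.14$, because dropping layer 3 or layer 5 removes one of the cancelling paths (third desideratum). In the reinforcing case the two stay close ($0.7$ versus $0.675$), in line with the first two desiderata.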
A.3 Algorithm for computing diagnostics

Given summary statistics $\bar{i}_\pm$, $s_\pm$ and $\text{count}_\pm$ for every node $n$, obtained from Algorithm 1, and a threshold $\theta > 0$, we can use Welch's t-test (Welch, 1947) to test the hypothesis that $|\bar{i}_+ - \bar{i}_-| \geq \theta$. Concretely, we compute the t-statistic via
$$s_{\bar{i}_\pm} = \frac{s_\pm}{\sqrt{\text{count}_\pm}} \tag{28}$$
$$t = \frac{\theta - |\bar{i}_+ - \bar{i}_-|}{\sqrt{s_{\bar{i}_+}^2 + s_{\bar{i}_-}^2}}. \tag{29}$$
The effective degrees of freedom $\nu$ can be approximated with the Welch–Satterthwaite equation
$$\nu_{\text{Welch}} = \frac{\left(\frac{s_+^2}{\text{count}_+} + \frac{s_-^2}{\text{count}_-}\right)^2}{\frac{s_+^4}{\text{count}_+^2(\text{count}_+ - 1)} + \frac{s_-^4}{\text{count}_-^2(\text{count}_- - 1)}} \tag{30}$$
We then compute the probability (p-value) of obtaining a $t$ at least as large as observed, using the cumulative distribution function of Student's $t(x; \nu_{\text{Welch}})$ at the appropriate points. We take the maximum of the individual p-values over all nodes to obtain an aggregate upper bound. Finally, we use binary search to find the largest threshold $\theta$ that still has an aggregate p-value smaller than a given target p-value. We show multiple such diagnostic curves in Section B.3, for different confidence levels ($1 - p_{\text{target}}$).

Appendix B Experiments

B.1 Prompt Distributions

B.1.1 IOI

We use the following prompt template:

⟨BOS⟩When␣[A]␣and␣[B]␣went␣to␣the␣bar,␣[A/C]␣gave␣a␣drink␣to␣[B/A]

Each clean prompt $x^{\text{clean}}$ uses two names A and B with completion B, while a noise prompt $x^{\text{noise}}$ uses names A, B, and C with completion A. We construct all possible such assignments where names are chosen from the set {Michael, Jessica, Ashley, Joshua, David, Sarah}, resulting in 120 prompt pairs.
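The IOI prompt pairs described above can be enumerated with a short Python sketch: choosing ordered, distinct names A, B, C from six names gives $6 \cdot 5 \cdot 4 = 120$ pairs. The sketch assumes that the final bracketed name is the target completion rather than part of the prompt, and it leaves the BOS token to the tokenizer; the constant and function names are hypothetical.

```python
from itertools import permutations

NAMES = ["Michael", "Jessica", "Ashley", "Joshua", "David", "Sarah"]
PREFIX = "When {a} and {b} went to the bar, {giver} gave a drink to"


def ioi_prompt_pairs(names=NAMES):
    """(clean_prompt, clean_target, noise_prompt, noise_target) for every
    ordered choice of three distinct names A, B, C."""
    pairs = []
    for a, b, c in permutations(names, 3):
        clean = PREFIX.format(a=a, b=b, giver=a)  # A gives to ... -> B
        noise = PREFIX.format(a=a, b=b, giver=c)  # C gives to ... -> A
        pairs.append((clean, " " + b, noise, " " + a))
    return pairs


if __name__ == "__main__":
    pairs = ioi_prompt_pairs()
    print(len(pairs))  # 6 * 5 * 4 = 120
    print(pairs[0])
```

The leading space on each target mirrors the ␣ markers in the template, since the completion token is typically the name preceded by a space.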
B.1.2 A-AN

We use the following prompt template to induce the prediction of an indefinite article.

⟨BOS⟩I␣want␣one␣pear.␣Can␣you␣pick␣up␣a␣pear␣for␣me?
␣I␣want␣one␣orange.␣Can␣you␣pick␣up␣an␣orange␣for␣me?
␣I␣want␣one␣[OBJECT].␣Can␣you␣pick␣up␣[a/an]

We found that the zero-shot performance of small models was relatively low, but performance improved drastically when a single example of each case was provided. Model performance was sensitive to the ordering of the two examples but was better than random in all cases. The magnitude and sign of the effect of the few-shot ordering were inconsistent.

Clean prompts $x^{\text{clean}}$ contain objects inducing '␣a', one of {boat, coat, drum, horn, map, pipe, screw, stamp, tent, wall}. Noise prompts $x^{\text{noise}}$ contain objects inducing '␣an', one of {apple, ant, axe, award, elephant, egg, orange, oven, onion, umbrella}. This results in a total of 100 prompt pairs.

B.2 Cancellation across a distribution

As mentioned in Section 2, we average the magnitudes of effects across a distribution, rather than taking the magnitude of the average effect. We do this because cancellation of effects happens frequently across a distribution, which, together with imprecise estimates, could lead to significant false negatives. A proper ablation study to quantify this effect exactly is beyond the scope of this work. In Figure 10, we show the degree of cancellation across the IOI distribution for various model sizes.
For this we define the Cancellation Ratio of node $n$ as
$$1 - \frac{\left|\sum_{x^{\text{clean}}, x^{\text{noise}}} \mathcal{I}(n; x^{\text{clean}}, x^{\text{noise}})\right|}{\sum_{x^{\text{clean}}, x^{\text{noise}}} \left|\mathcal{I}(n; x^{\text{clean}}, x^{\text{noise}})\right|}.$$

Figure 10 (panels: (a) Pythia-410M, (b) Pythia-1B, (c) Pythia-2.8B, (d) Pythia-12B): Cancellation ratio across IOI for various model sizes. A ratio of 1 means positive and negative effects cancel out across the distribution, whereas a ratio of 0 means only negative or only positive effects exist across the distribution. We report the cancellation ratio for different percentiles of nodes based on $\sum_{x^{\text{clean}}, x^{\text{noise}}} |\mathcal{I}(n; x^{\text{clean}}, x^{\text{noise}})|$.

B.3 Additional detailed results

We show the diagnostic measurements for Pythia-12B across all investigated distributions in Figure 11(b), and the cost of verified 100% recall curves for all models and settings in Figures 12(c) and 13(c).

Figure 11: Diagnostic of false negatives for 12B across distributions.
Figure 11 panels: (a) AttentionNodes: IOI-P, RAND-P, IOI; (b) NeuronNodes: CITY-P, RAND-P, A-AN.

Figure 12: Cost of verified 100% recall curves, sweeping across models and settings for NeuronNodes. Panels: (a) CITY-P, (b) RAND-P, (c) A-AN distribution; each shown for Pythia 410M, 1B, 2.8B and 12B.

Figure 13: Cost of verified 100% recall curves, sweeping across models and settings for AttentionNodes. Panels: (a) IOI-P, (b) RAND-P, (c) IOI distribution; each shown for Pythia 410M, 1B, 2.8B and 12B.

B.4 Metrics

In this paper we focus on the difference in loss (negative log probability) as the metric $\mathcal{L}$. We provide some evidence that AtP(*) is not sensitive to the choice of $\mathcal{L}$. For Pythia-12B, on IOI-P and IOI, we show the rank scatter plots in Figure 14 for three different metrics. For IOI, we also show that the performance of AtP* looks notably worse when effects are evaluated via denoising instead of noising (cf. Section 2.1). As of now we do not have a satisfactory explanation for this observation.

Figure 14: True ranks against AtP* ranks on Pythia-12B using various metrics $\mathcal{L}$. The last row shows the effect in the denoising (rather than noising) setting; we speculate that the lower-right subplot (log-odds denoising) is similar to the lower-middle one (logit-diff denoising) because IOI produces a bimodal distribution over the correct and alternate next token.

B.5 Hyperparameter selection

The iterative baseline and the AtP-based methods have no hyperparameters.
In general, we used 5 random seeds for each hyperparameter setting, and selected the setting that produced the lowest IRWRGM cost (see Section 4.2). For Subsampling, the two hyperparameters are the Bernoulli sampling probability $p$ and the number of samples to collect before verifying nodes in decreasing order of $\hat{c}_{\text{SS}}$. The probability $p$ was chosen from $\{0.01, 0.03\}$ [Footnote 14: We observed early on that larger values of $p$ were consistently underperforming. We leave it to future work to investigate more granular and smaller values for $p$.]. The number of steps was chosen among power-of-2 numbers of batches, where the batch size depended on the setting. For Blocks, we swept across block sizes $\{2, 6, 20, 60, 250\}$. For Hierarchical, we used a branching factor of $B = 3$, because of the following heuristic argument. If all but one node had zero effect, then discovering that node would be a matter of iterating through the hierarchy levels. There would be $\log_B |N|$ levels, and at each level $B$ forward passes would be required to find which lower-level block the special node is in; thus the cost of finding the node would be $B \log_B |N| = \frac{B}{\log B} \log |N|$. The factor $\frac{B}{\log B}$ is minimized at $B = e$, or at $B = 3$ if $B$ must be an integer (it is $\approx 2.73$ at $B = 3$, versus $\approx 2.89$ at both $B = 2$ and $B = 4$). The other hyperparameter is the number of levels; we swept this from 2 to 12.

Appendix C AtP variants

C.1 Residual-site AtP and Layer normalization

Let's consider the behaviour of AtP on sites that contain much or all of the total signal in the residual stream, such as residual-stream sites. Nanda (2022) described a concern about this behaviour: that linear approximation of the layer normalization would do poorly if the patched value is significantly different from the clean one but has a similar norm.
The proposed modification to AtP to account for this was to hold the scaling factors (in the denominators) fixed when computing the backward pass. Here we'll present an analysis of how this modification would affect the approximation error of AtP. (Empirical investigation of this issue is beyond the scope of this paper.) Concretely, let the node under consideration be $n$, with clean and alternate values $n^{\mathrm{clean}}$ and $n^{\mathrm{noise}}$; and for simplicity, let's assume the model does nothing more than an unparametrized RMSNorm $\mathcal{M}(n) := n/|n|$. Let's now consider how well $\mathcal{M}(n^{\mathrm{noise}})$ is approximated, both by its first-order approximation

$$\hat{\mathcal{M}}_{\mathrm{AtP}}(n^{\mathrm{noise}}) := \mathcal{M}(n^{\mathrm{clean}}) + \mathcal{M}(n^{\mathrm{clean}})^{\perp}\left(n^{\mathrm{noise}} - n^{\mathrm{clean}}\right),$$

where $\mathcal{M}(n^{\mathrm{clean}})^{\perp} = I - \mathcal{M}(n^{\mathrm{clean}})\mathcal{M}(n^{\mathrm{clean}})^{\top}$ is the projection onto the hyperplane orthogonal to $\mathcal{M}(n^{\mathrm{clean}})$ (this is the Jacobian of $\mathcal{M}$ at $n^{\mathrm{clean}}$ under the normalization $|n^{\mathrm{clean}}| = 1$ adopted below), and by the variant that fixes the denominator:

$$\hat{\mathcal{M}}_{\mathrm{AtP+frozenLN}}(n^{\mathrm{noise}}) := n^{\mathrm{noise}} / |n^{\mathrm{clean}}|.$$

To quantify the error in the above, we'll measure the error $\epsilon$ in terms of Euclidean distance. Let's also assume, without loss of generality, that $|n^{\mathrm{clean}}| = 1$.
Geometrically, then, $\mathcal{M}(n)$ is a projection onto the unit hypersphere, $\hat{\mathcal{M}}_{\mathrm{AtP}}(n)$ is a projection onto the tangent hyperplane at $n^{\mathrm{clean}}$, and $\hat{\mathcal{M}}_{\mathrm{AtP+frozenLN}}$ is the identity function. Now, let's define orthogonal coordinates $(x, y)$ on the plane spanned by $n^{\mathrm{clean}}, n^{\mathrm{noise}}$, such that $n^{\mathrm{clean}}$ is mapped to $(1, 0)$ and $n^{\mathrm{noise}}$ is mapped to $(x, y)$, with $y \ge 0$. In these coordinates $\hat{\mathcal{M}}_{\mathrm{AtP}}(n^{\mathrm{noise}}) = (1, y)$ and $\mathcal{M}(n^{\mathrm{noise}}) = (x, y)/\sqrt{x^2 + y^2}$, so

$$\epsilon_{\mathrm{AtP}} := \left|\hat{\mathcal{M}}_{\mathrm{AtP}}(n^{\mathrm{noise}}) - \mathcal{M}(n^{\mathrm{noise}})\right| = \sqrt{2 + y^2 - 2\,\frac{x + y^2}{\sqrt{x^2 + y^2}}},$$

while

$$\epsilon_{\mathrm{AtP+frozenLN}} := \left|\hat{\mathcal{M}}_{\mathrm{AtP+frozenLN}}(n^{\mathrm{noise}}) - \mathcal{M}(n^{\mathrm{noise}})\right| = \left|\sqrt{x^2 + y^2} - 1\right|.$$

Plotting the error in Figure 15, we can see that, as might be expected, freezing the layer norm denominators helps whenever $n^{\mathrm{noise}}$ indeed has the same norm as $n^{\mathrm{clean}}$, and (barring unusual cases with $x > 1$) whenever the cosine similarity is less than $\frac{1}{2}$; but it largely hurts if $n^{\mathrm{noise}}$ is close to $n^{\mathrm{clean}}$.
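As a quick numerical check of the two error formulas, the sketch below (plain Python; the function names are ours, not the paper's) evaluates $\epsilon_{\mathrm{AtP}}$ both from the closed form and directly from the geometry, and compares it with $\epsilon_{\mathrm{AtP+frozenLN}}$ at a few sample points, under the same assumptions as above ($|n^{\mathrm{clean}}| = 1$, $y \ge 0$).

```python
import math


def eps_atp(x: float, y: float) -> float:
    """Closed-form error of plain AtP for M(n) = n/|n|, with n_clean = (1, 0) and n_noise = (x, y)."""
    r = math.hypot(x, y)
    # max(...) guards against tiny negative values caused by floating-point rounding
    return math.sqrt(max(0.0, 2 + y**2 - 2 * (x + y**2) / r))


def eps_atp_geometric(x: float, y: float) -> float:
    """Same error from the geometry: AtP gives (1, y), the exact model gives (x, y)/r."""
    r = math.hypot(x, y)
    return math.hypot(1.0 - x / r, y - y / r)


def eps_frozen_ln(x: float, y: float) -> float:
    """Error with the denominator frozen at |n_clean| = 1, i.e. the identity map."""
    return abs(math.hypot(x, y) - 1.0)


if __name__ == "__main__":
    for x, y in [(0.6, 0.8), (1.01, 0.01), (0.2, 0.9)]:
        a, f = eps_atp(x, y), eps_frozen_ln(x, y)
        print(f"(x, y) = ({x}, {y}): eps_AtP = {a:.4f}, eps_frozenLN = {f:.4f}, frozen helps: {f < a}")
```

The three sample points reproduce the qualitative pattern described in the text: equal norms (the first point) favour the frozen denominator, a small patch (the second) favours plain AtP, and a cosine similarity below $\frac{1}{2}$ (the third, about 0.22) favours the frozen denominator again.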
Figure 15 thus illustrates that, while freezing the denominators will generally be unhelpful when patch distances are small relative to the full residual signal (as with almost all nodes considered in this paper), it will likely be helpful in a different setting: patching residual streams, which could be quite unaligned but have similar norm.

Figure 15: A comparison of how AtP and AtP with frozen layernorm scaling behave in a toy setting where the model we're trying to approximate is just $\mathcal{M}(n) := n/|n|$. Panels: (a) $\epsilon_{\mathrm{AtP}}$; (b) $\epsilon_{\mathrm{AtP+frozenLN}}$; (c) $\epsilon_{\mathrm{AtP+frozenLN}} - \epsilon_{\mathrm{AtP}}$. The red region is where frozen layernorm scaling helps; the blue region is where it hurts. We find that unless $x > 1$, frozen layernorm scaling always has lower error when the cosine similarity between $n^{\mathrm{noise}}$ and $n^{\mathrm{clean}}$ is $< \frac{1}{2}$ (in other words, when the angle is $> 60^\circ$), but often has higher error otherwise.

C.2 Edge AtP and AtP*

Here we will investigate edge attribution patching, and how the cost scales if we use GradDrop and/or the QK fix. (For this section we'll focus on a single prompt pair.) First, let's review what edge attribution patching is trying to approximate, and how it works.

C.2.1 Edge intervention effects

Given nodes $n_1, n_2$ where $n_1$ is upstream of $n_2$, if we were to patch in an alternate value for $n_1$, this could impact $n_2$ in a complicated nonlinear way.
As discussed in Section 3.1.2, because LLMs have a residual stream, the "direct effect" can be understood as the effect obtained while holding all other possible intermediate nodes between $n_1$ and $n_2$ fixed. It is a relatively simple function, composed of transforming the alternate value $n_1(x^{\mathrm{noise}})$ to a residual stream contribution $r_{\mathrm{out},\ell_1}(x^{\mathrm{clean}} \mid \mathrm{do}(n_1 \leftarrow n_1(x^{\mathrm{noise}})))$, then carrying it along the residual stream to an input $r_{\mathrm{in},\ell_2} = r_{\mathrm{in},\ell_2}(x^{\mathrm{clean}}) + \left(r_{\mathrm{out},\ell_1} - r_{\mathrm{out},\ell_1}(x^{\mathrm{clean}})\right)$, and transforming that into a value $n_2^{\mathrm{direct}}$. In the above, $\ell_1$ and $\ell_2$ are the semilayers containing $n_1$ and $n_2$, respectively. Let's define $n_{(\ell_1,\ell_2)}$ to be the set of non-residual nodes between semilayers $\ell_1$ and $\ell_2$.
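To make the residual-stream bookkeeping concrete, here is a minimal sketch in a hypothetical toy model (the weights `U1`, `W_M`, `U_M`, `W_2` and the node definitions are invented for illustration, not taken from any real model). It computes $n_2$'s direct response to patching $n_1$ by holding the single intermediate node at its clean value, checks that the resulting $r_{\mathrm{in},\ell_2}$ matches $r_{\mathrm{in},\ell_2}(x^{\mathrm{clean}}) + (r_{\mathrm{out},\ell_1} - r_{\mathrm{out},\ell_1}(x^{\mathrm{clean}}))$, and contrasts it with the total effect, in which the intermediate node is free to respond.

```python
import math

# Hypothetical toy model: a 2-d residual stream, a scalar node n1 in semilayer l1,
# one intermediate non-residual node m between l1 and l2, and a scalar node n2 in l2.
U1 = [1.0, 0.5]    # direction n1 writes to the residual stream
W_M = [1.0, 0.0]   # direction the intermediate node m reads from
U_M = [0.0, 2.0]   # direction m writes to
W_2 = [0.5, 1.0]   # direction n2 reads from


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def add(a, b):
    return [p + q for p, q in zip(a, b)]


def sub(a, b):
    return [p - q for p, q in zip(a, b)]


def forward(x, n1_value=None, m_value=None):
    """Run the toy model on input x, optionally patching n1 and/or m (a do-intervention)."""
    r = list(x)                                     # residual stream starts at the input
    n1 = x[0] * x[1] if n1_value is None else n1_value
    r_out_1 = [n1 * u for u in U1]                  # r_out,l1: semilayer l1's contribution
    r = add(r, r_out_1)
    m = max(0.0, dot(W_M, r)) if m_value is None else m_value
    r = add(r, [m * u for u in U_M])
    return {"n1": n1, "m": m, "r_out_1": r_out_1, "r_in_2": r, "n2": math.tanh(dot(W_2, r))}


x_clean, x_noise = [1.0, 2.0], [1.0, -1.0]
clean, noise = forward(x_clean), forward(x_noise)

# Direct effect: patch n1 and hold the intermediate node m at its clean value.
direct = forward(x_clean, n1_value=noise["n1"], m_value=clean["m"])
# The same r_in,l2 from the residual-stream formula in the text.
r_in_direct = add(clean["r_in_2"], sub(direct["r_out_1"], clean["r_out_1"]))
n2_direct_formula = math.tanh(dot(W_2, r_in_direct))
# Total effect for comparison: m is allowed to respond to the patch.
total = forward(x_clean, n1_value=noise["n1"])
```

In this toy the intermediate node switches off under the patch, so the total effect on $n_2$ (which ends at $\tanh(1.5)$) differs sharply from the direct one (which ends at $\tanh(7.5)$).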
With this notation, we can define the resulting $n_2^{\mathrm{direct}}$ as:

$$n_2^{\mathrm{direct}\,\ell_1}\left(x^{\mathrm{clean}} \mid \mathrm{do}(n_1 \leftarrow n_1(x^{\mathrm{noise}}))\right) := n_2\left(x^{\mathrm{clean}} \mid \mathrm{do}(n_1 \leftarrow n_1(x^{\mathrm{noise}})),\ \mathrm{do}(n_{(\ell_1,\ell_2)} \leftarrow n_{(\ell_1,\ell_2)}(x^{\mathrm{clean}}))\right).$$

The residual-stream input $r_{\mathrm{in},\ell_2}^{\mathrm{direct}\,\ell_1}(x^{\mathrm{clean}} \mid \mathrm{do}(n_1 \leftarrow n_1(x^{\mathrm{noise}})))$ is defined similarly. Finally, $n_2$ itself isn't enough to compute the metric $\mathcal{L}$; for that we also need to let the forward pass $\mathcal{M}(x^{\mathrm{clean}})$ run using the modified $n_2^{\mathrm{direct}\,\ell_1}(x^{\mathrm{clean}} \mid \mathrm{do}(n_1 \leftarrow n_1(x^{\mathrm{noise}})))$, while removing all other effects of $n_1$ (i.e. not patching it).
Writing this out, we have the edge intervention effect

$$\mathcal{I}(n_1 \to n_2; x^{\mathrm{clean}}, x^{\mathrm{noise}}) := \mathcal{L}\left(\mathcal{M}\left(x^{\mathrm{clean}} \mid \mathrm{do}\left(n_2 \leftarrow n_2^{\mathrm{direct}\,\ell_1}(x^{\mathrm{clean}} \mid \mathrm{do}(n_1 \leftarrow n_1(x^{\mathrm{noise}})))\right)\right)\right) - \mathcal{L}(\mathcal{M}(x^{\mathrm{clean}})). \tag{31}$$

C.2.2 Nodes and Edges

Let's briefly consider which edges we'd want to evaluate this on. In Section 4.1, we were able to conveniently separate attention nodes from MLP neurons, knowing that to handle both kinds of nodes, we'd just need to be able to handle each kind of node on its own and then combine the results. For edge interventions this of course isn't true, because edges can go from MLP neurons to attention nodes, and vice versa. For the purposes of this section, we'll assume that the node set $N$ contains the attention nodes, and for MLPs either a node per layer (as in Syed et al. (2023)) or a node per neuron (as in the NeuronNodes setting).

Regarding the edges, the MLP nodes can reasonably be connected with any upstream or downstream node, but this isn't true for the attention nodes, which have more structure amongst themselves: the key, query, and value nodes for an attention head can only affect downstream nodes via the attention output nodes for that head, and vice versa. As a result, on edges between different semilayers, upstream attention nodes must be attention head outputs, and downstream attention nodes must be keys, queries, or values. In addition, there are some within-attention-head edges, connecting each query node to the output node in the same position, and each key and value node to output nodes in causally affectable positions.
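The between-semilayer edge rules above can be sketched as a small enumeration. This is a hypothetical illustration under simplifying assumptions of our own: one MLP node per layer (the Syed et al. (2023) style rather than NeuronNodes), one node per head for each of q, k, v and o, token positions ignored, and each layer split into an attention semilayer followed by an MLP semilayer.

```python
from collections import Counter
from typing import List, Optional, Tuple

Node = Tuple[int, str, Optional[int]]  # (layer, kind, head); kind in {"q", "k", "v", "o", "mlp"}

UPSTREAM_KINDS = {"o", "mlp"}              # heads write to the residual stream only via outputs
DOWNSTREAM_KINDS = {"q", "k", "v", "mlp"}  # and read from it only via queries, keys, values


def semilayer(node: Node) -> int:
    layer, kind, _ = node
    return 2 * layer + (1 if kind == "mlp" else 0)  # attention semilayer, then MLP semilayer


def make_nodes(n_layers: int, n_heads: int) -> List[Node]:
    nodes: List[Node] = []
    for layer in range(n_layers):
        for head in range(n_heads):
            for kind in ("q", "k", "v", "o"):
                nodes.append((layer, kind, head))
        nodes.append((layer, "mlp", None))  # one node per MLP layer
    return nodes


def between_semilayer_edges(nodes: List[Node]) -> List[Tuple[Node, Node]]:
    return [(up, down) for up in nodes for down in nodes
            if up[1] in UPSTREAM_KINDS and down[1] in DOWNSTREAM_KINDS
            and semilayer(up) < semilayer(down)]


def edge_type(edge: Tuple[Node, Node]) -> str:
    up, down = edge
    src = "O" if up[1] == "o" else "MLP"
    dst = {"q": "Q,K", "k": "Q,K", "v": "V", "mlp": "MLP"}[down[1]]
    return f"{src}->{dst}"


if __name__ == "__main__":
    edges = between_semilayer_edges(make_nodes(n_layers=2, n_heads=1))
    print(len(edges), Counter(edge_type(e) for e in edges))
```

For two layers and one head this yields 10 edges. With $H$ heads, the O→V and O→Q,K counts per layer pair grow as $H^2$ and $2H^2$, matching the $DH^2$ and $2DH^2$ structure of the corresponding columns of Table 2.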
C.2.3 Edge AtP

As with node activation patching, the edge intervention effect $\mathcal{I}(n_1 \to n_2; x^{\mathrm{clean}}, x^{\mathrm{noise}})$ is costly to evaluate directly for every edge, since a forward pass is required each time. However, as with AtP, we can apply first-order approximations: we define

$$\hat{\mathcal{I}}_{\mathrm{AtP}}(n_1 \to n_2; x^{\mathrm{clean}}, x^{\mathrm{noise}}) := \left(\Delta r_{n_1}^{\mathrm{AtP}}(x^{\mathrm{clean}}, x^{\mathrm{noise}})\right)^{\top} \nabla_{r_{n_2}}^{\mathrm{AtP}} \mathcal{L}(\mathcal{M}(x^{\mathrm{clean}})), \tag{32}$$

where

$$\Delta r_{n_1}^{\mathrm{AtP}}(x^{\mathrm{clean}}, x^{\mathrm{noise}}) := \mathrm{Jac}_{n_1}(r_{\mathrm{out},\ell_1})(n_1(x^{\mathrm{clean}}))\left(n_1(x^{\mathrm{noise}}) - n_1(x^{\mathrm{clean}})\right) \tag{33}$$

and

$$\nabla_{r_{n_2}}^{\mathrm{AtP}} \mathcal{L}(\mathcal{M}(x^{\mathrm{clean}})) := \left(\mathrm{Jac}_{r_{\mathrm{in},\ell_2}}(n_2)(r_{\mathrm{in},\ell_2}(x^{\mathrm{clean}}))\right)^{\top} \nabla_{n_2}\left(\mathcal{L}(\mathcal{M}(x^{\mathrm{clean}}))\right)(n_2(x^{\mathrm{clean}})), \tag{34}$$

and this is a close approximation when $n_1(x^{\mathrm{noise}}) \approx n_1(x^{\mathrm{clean}})$.

A key benefit of this decomposition is that the first term depends only on $n_1$, and the second term depends only on $n_2$; and they're both easy to compute from a forward and backward pass on $x^{\mathrm{clean}}$ and a forward pass on $x^{\mathrm{noise}}$, just like AtP itself. Then, to complete the edge-AtP evaluation, what remains computationally is to evaluate all the dot products between nodes in different semilayers, at each token position. This requires $d_{\mathrm{resid}} T \left(1 - \frac{1}{L}\right) |N|^2 / 2$ multiplications in total [Footnote 15: This formula omits edges within a single layer, for simplicity, but those are a small minority.], where $L$ is the number of layers, $T$ is the number of tokens, and $|N|$ is the total number of nodes.
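Because the two factors in Eq. (32) each depend on only one node, edge AtP reduces to precomputing one vector per node and then taking dot products. A minimal sketch for a single token position, assuming the per-node vectors $\Delta r_{n}^{\mathrm{AtP}}$ and $\nabla_{r_{n}}^{\mathrm{AtP}}\mathcal{L}$ have already been obtained from the two forward passes and one backward pass (the dictionaries and function names are hypothetical):

```python
from typing import Dict, List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(p * q for p, q in zip(a, b))


def edge_atp_scores(delta_r: Dict[str, Vector], grad_r: Dict[str, Vector],
                    semilayer: Dict[str, int]) -> Dict[Tuple[str, str], float]:
    """Edge-AtP estimates (Eq. 32) at one token position.

    delta_r[n]   : Delta r_n^AtP, node n's change in residual contribution (upstream role)
    grad_r[n]    : grad_{r_n}^AtP L, the gradient pulled back to n's residual input (downstream role)
    semilayer[n] : index of the semilayer containing n
    """
    scores = {}
    for n1, dr in delta_r.items():
        for n2, g in grad_r.items():
            if semilayer[n1] < semilayer[n2]:  # only edges between different semilayers
                scores[(n1, n2)] = dot(dr, g)  # d_resid multiplications per edge
    return scores


def quadratic_cost(d_resid: int, n_tokens: int, n_layers: int, n_nodes: int) -> float:
    """Approximate multiplication count d_resid * T * (1 - 1/L) * |N|^2 / 2."""
    return d_resid * n_tokens * (1 - 1 / n_layers) * n_nodes ** 2 / 2
```

The nested loop is exactly the quadratic term that `quadratic_cost` counts; in practice it would be a batched matrix product rather than Python loops.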
The cost of these dot products exceeds the cost of computing all $\Delta r_{n_1}^{\mathrm{AtP}}(x^{\mathrm{clean}}, x^{\mathrm{noise}})$ and $\nabla_{r_{n_2}}^{\mathrm{AtP}} \mathcal{L}(\mathcal{M}(x^{\mathrm{clean}}))$ on Pythia 2.8B even with a single node per MLP layer; if we look at a larger model, or especially if we consider single-neuron nodes even for small models, the gap grows significantly. Due to this observation, we'll focus our attention on the quadratic part of the compute cost, pertaining to two nodes rather than just one, i.e. the number of multiplications in computing all $\left(\Delta r_{n_1}^{\mathrm{AtP}}(x^{\mathrm{clean}}, x^{\mathrm{noise}})\right)^{\top} \nabla_{r_{n_2}}^{\mathrm{AtP}} \mathcal{L}(\mathcal{M}(x^{\mathrm{clean}}))$. Notably, we'll also exclude within-attention-head edges from the "quadratic cost": these edges, from some key, query, or value node to an attention output node, can be handled by minor variations of the nodewise AtP or AtP* methods for the corresponding key, query, or value node.

C.2.4 MLPs

There are a couple of issues that can come up around the MLP nodes. One is that, similarly to the attention saturation issue described in Section 3.1.1, the linear approximation to the MLP may be fairly bad in some cases, creating significant false negatives if $n_2$ is an MLP node.
Another issue is that if we use single-neuron nodes, then those are very numerous, making the $d_{\mathrm{resid}}$-dimensional dot product per edge quite costly.

MLP saturation and fix. Just as clean activations that saturate the attention probability may have small gradients that lead to strongly underestimated effects, the same is true of the MLP nonlinearity. A similar fix is applicable: instead of using a linear approximation to the function from $n_1$ to $n_2$, we can linearly approximate the function from $n_1$ to the preactivation $n_{2,\mathrm{pre}}$, and then recompute $n_2$ using that, before multiplying by the gradient. This kind of rearrangement, where the gradient-delta-activation dot product is computed in $d_{n_2}$ dimensions rather than $d_{\mathrm{resid}}$, will come up again; we'll call it the factored form of AtP.

If the nodes are neurons then the factored form requires no change to the number of multiplications; however, if they're MLP layers then there's a large increase in cost, by a factor of $d_{\mathrm{neurons}}$. This increase is mitigated by two factors: one is that this is a small minority of edges, outnumbered by the edges ending in attention nodes by a factor of $3 \times (\#\,\text{heads per layer})$; the other is the potential for parameter sharing.

Neuron edges and parameter sharing. A useful observation is that each edge, across different token positions [Footnote 16: Also across different batch entries, if we do this on more than one prompt pair.], reuses the same parameter matrices in $\mathrm{Jac}_{n_1}(r_{\mathrm{out},\ell_1})(n_1(x^{\mathrm{clean}}))$ and $\mathrm{Jac}_{r_{\mathrm{in},\ell_2}}(n_2)(r_{\mathrm{in},\ell_2}(x^{\mathrm{clean}}))$. Indeed, setting aside the MLP activation function, the only other nonlinearity in those functions is a layer normalization; if we freeze the scaling factor at its clean value as in Section C.1, the Jacobians are equal to the product of the corresponding parameter matrices, divided by the clean scaling factor. Thus if we premultiply the parameter matrices, we eliminate the need to do so at each token, which reduces the per-token quadratic cost by a factor of $d_{\mathrm{resid}}$ (i.e. to a scalar multiplication) for neuron-neuron edges, or by a factor of $d_{\mathrm{resid}}/d_{\mathrm{site}}$ (i.e. to a $d_{\mathrm{site}}$-dimensional dot product) for edges between neurons and some attention site.

It's worth noting, though, that these premultiplied parameter matrices (or, indeed, the edge-AtP estimates if we use neuron sites) will in total be many times larger than the MLP weights themselves (specifically, $(L-1)\frac{d_{\mathrm{neurons}}}{4 d_{\mathrm{resid}}}$ times larger), so storage may need to be considered carefully. It may be worth considering ways to find only the largest estimates, or the estimates over some threshold, rather than full estimates for all edges.
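The parameter-sharing trick can be sketched for neuron→neuron edges at one token. For illustration we assume a simplified setting: the layer norm is reduced to a single frozen clean scale (no learned gain or centering), the downstream node is the neuron's preactivation (setting the activation function aside, as in the text), and the weight rows are hypothetical. Premultiplying the output and input weight rows once turns each per-token $d_{\mathrm{resid}}$-dimensional dot product into a scalar product; `storage_ratio` evaluates the $(L-1)\frac{d_{\mathrm{neurons}}}{4 d_{\mathrm{resid}}}$ size factor.

```python
def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def neuron_edges_direct(delta_act, w_out, w_in, grad_pre, scale):
    """Per-token neuron->neuron edge estimates using a d_resid-dim dot product per edge.

    delta_act[i] : change in activation of upstream neuron i (layer l1)
    w_out[i]     : hypothetical d_resid-dim output weight row of neuron i
    w_in[j]      : hypothetical d_resid-dim input weight row of downstream neuron j (layer l2)
    grad_pre[j]  : gradient of the metric w.r.t. neuron j's preactivation
    scale        : frozen clean layer-norm scaling factor at l2 (Section C.1)
    """
    return [[dot([delta_act[i] * w for w in w_out[i]], w_in[j]) / scale * grad_pre[j]
             for i in range(len(w_out))] for j in range(len(w_in))]


def premultiply(w_out, w_in):
    """Token-independent matrix P[j][i] = w_in[j] . w_out[i], computed once per layer pair."""
    return [[dot(w_in[j], w_out[i]) for i in range(len(w_out))] for j in range(len(w_in))]


def neuron_edges_shared(delta_act, P, grad_pre, scale):
    """Same estimates, but each edge now costs O(1) scalar multiplications per token."""
    return [[delta_act[i] * P[j][i] / scale * grad_pre[j]
             for i in range(len(delta_act))] for j in range(len(grad_pre))]


def storage_ratio(n_layers, d_resid, d_neurons):
    """Size of all premultiplied matrices relative to the MLP weights: L(L-1)/2 layer pairs of
    d_neurons x d_neurons entries versus 2 * L * d_resid * d_neurons MLP weights."""
    return (n_layers - 1) * d_neurons / (4 * d_resid)
```

With $d_{\mathrm{neurons}} = 4 d_{\mathrm{resid}}$ the storage ratio is simply $L - 1$; for example, `storage_ratio(32, 2560, 10240)` returns 31.0.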
C.2.5 Edge AtP* costs

Let's now consider how to adapt the AtP* proposals from Section 3.1 to this setting. We've already seen that the MLP fix, which is similarly motivated to the QK fix, has negligible cost in the neuron-nodes case, but comes with a $d_{\mathrm{neurons}}/d_{\mathrm{resid}}$ overhead in quadratic cost when using one node per MLP layer, at least on edges into those MLP nodes. We'll consider the MLP fix to be part of edge-AtP*. Now let's investigate the two corrections in regular AtP*: GradDrops and the QK fix.

GradDrops. GradDrops works by replacing the single backward pass in the AtP formula with $L$ backward passes; this in effect means $L$ values for the multiplicand $\nabla_{r_{n_2}}^{\mathrm{AtP}} \mathcal{L}(\mathcal{M}(x^{\mathrm{clean}}))$, so this is a multiplicative factor of $L$ on the quadratic cost (though in fact some of these will be duplicates, and taking this into account lets us drive the multiplicative factor down to $(L+1)/2$). Notably this works equally well with "factored AtP", as used for neuron edges; in particular, if $n_2$ is a neuron, the gradients can easily be combined and shared across $n_1$s, eliminating the $(L+1)/2$ quadratic-cost overhead. However, the motivation for GradDrops was to account for multiple paths whose effects may cancel; in the edge-interventions setting, these can already be discovered in a different way (by identifying the responsible edges out of $n_2$), so the benefit of GradDrops is lessened. At the same time, the cost remains substantial. Thus, we'll omit GradDrops from our recommended edge-AtP* procedure.
QK fix. The QK fix applies to the $\nabla_{n_2}\left(\mathcal{L}(\mathcal{M}(x^{\mathrm{clean}}))\right)(n_2(x^{\mathrm{clean}}))$ term, i.e. it replaces the linear approximation to the softmax with an exact calculation of the change in softmax, for each different input $\Delta r_{n_1}^{\mathrm{AtP}}(x^{\mathrm{clean}}, x^{\mathrm{noise}})$. As in Section 3.1.1, there's the simpler case of accounting for $n_2$s that are query nodes, and the more complicated case of $n_2$s that are key nodes using Algorithm 4; but these are both cheap to do after computing the $\Delta\mathrm{attnLogits}$ corresponding to $n_2$. The "factored AtP" way is to matrix-multiply $\Delta r_{n_1}^{\mathrm{AtP}}(x^{\mathrm{clean}}, x^{\mathrm{noise}})$ with the key or query weights and then with the clean queries or keys, respectively. This means that instead of the $d_{\mathrm{resid}}$ multiplications required for each edge $n_1 \to n_2$ with AtP, we need $d_{\mathrm{resid}} d_{\mathrm{key}} + T d_{\mathrm{key}}$ multiplications (which, thanks to the causal mask, can be reduced to an average of $d_{\mathrm{key}}\left(d_{\mathrm{resid}} + (T+1)/2\right)$).
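For an edge into a query node, the QK fix can be sketched as follows; this is an illustrative reconstruction, not the paper's implementation. It assumes a single query position, a frozen layer norm, the standard $1/\sqrt{d_{\mathrm{key}}}$ logit scaling, and that the downstream effect is linear in the attention probabilities, with `grad_attn` taken from the clean backward pass; `w_q` and the other inputs are hypothetical. Following the factored form, it projects $\Delta r_{n_1}^{\mathrm{AtP}}$ through the query weights and dots the result with the clean keys to get $\Delta\mathrm{attnLogits}$, then compares AtP's linearized softmax with an exact recomputation. `qk_fix_overhead` evaluates the overhead factor $\min\left(\frac{T+1}{2}, d_{\mathrm{key}}\left(1 + \frac{T+1}{2 d_{\mathrm{resid}}}\right)\right)$, the cheaper of the unfactored and factored options.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def query_edge_estimates(delta_r, w_q, clean_keys, clean_logits, grad_attn):
    """AtP vs QK-fix estimates for one edge n1 -> (query node n2) at one query position.

    delta_r      : Delta r_{n1}^AtP (length d_resid)
    w_q          : hypothetical d_resid x d_key query weights (layer norm frozen)
    clean_keys   : clean keys at the causally visible key positions
    clean_logits : clean attention logits at those positions
    grad_attn    : gradient of the metric w.r.t. each attention probability (clean pass)
    """
    d_key = len(w_q[0])
    # Factored form: Delta q = W_Q^T Delta r costs d_resid * d_key multiplications ...
    dq = [sum(w_q[i][k] * delta_r[i] for i in range(len(delta_r))) for k in range(d_key)]
    # ... and Delta logits against the clean keys cost (#positions) * d_key more.
    d_logits = [sum(a * b for a, b in zip(dq, key)) / math.sqrt(d_key) for key in clean_keys]
    p = softmax(clean_logits)
    # Plain AtP: first-order change of the softmax, J d = p * (d - <p, d>).
    mean = sum(pi * di for pi, di in zip(p, d_logits))
    d_attn_linear = [pi * (di - mean) for pi, di in zip(p, d_logits)]
    # QK fix: recompute the softmax exactly on the perturbed logits.
    p_new = softmax([z + d for z, d in zip(clean_logits, d_logits)])
    d_attn_exact = [pn - pi for pn, pi in zip(p_new, p)]
    atp = sum(d * g for d, g in zip(d_attn_linear, grad_attn))
    qk_fix = sum(d * g for d, g in zip(d_attn_exact, grad_attn))
    return atp, qk_fix


def qk_fix_overhead(n_tokens, d_key, d_resid):
    """Overhead on the quadratic cost for edges into queries/keys (cheaper of the two options)."""
    return min((n_tokens + 1) / 2, d_key * (1 + (n_tokens + 1) / (2 * d_resid)))
```

In a saturated example, `query_edge_estimates([20.0], [[-1.0]], [[1.0], [0.0]], [10.0, 0.0], [0.0, 1.0])`, the patch flips which position is attended to: AtP reports an effect below $10^{-3}$ while the exact recomputation gives almost 1, the kind of false negative the QK fix is meant to remove.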
The "unfactored" option is to stay in the $r_{\mathrm{in},\ell_2}$ space: premultiply the clean queries or keys with the respective key or query weight matrices, and then take the dot product of $\Delta r_{n_1}^{\mathrm{AtP}}(x^{\mathrm{clean}}, x^{\mathrm{noise}})$ with each one. This way, the quadratic part of the compute cost contains $d_{\mathrm{resid}}(T+1)/2$ multiplications; this will be more efficient for short sequence lengths. This means that for edges into key and query nodes, the overhead of doing AtP+QKfix on the quadratic cost is a multiplicative factor of $\min\left(\frac{T+1}{2},\ d_{\mathrm{key}}\left(1 + \frac{T+1}{2 d_{\mathrm{resid}}}\right)\right)$.

QK fix + GradDrops. If the QK fix is being combined with GradDrops, then the first multiplication by the $d_{\mathrm{resid}} \times d_{\mathrm{key}}$ matrix can be shared between the different gradients; so the overhead on the quadratic cost of QKfix + GradDrops for edges into queries and keys, using the factored method, is $d_{\mathrm{key}}\left(1 + \frac{(T+1)(L+1)}{4 d_{\mathrm{resid}}}\right)$.

C.3 Conclusion

Considering all the above possibilities, it's not obvious where the best tradeoff between correctness and compute cost lies in all situations. In Table 2 we provide formulas measuring the number of multiplications in the quadratic cost for each kind of edge, across the variations we've mentioned.
In Figure 16 we plug in the 4 sizes of Pythia model used elsewhere in the paper, such as in Figure 2, to enable numerical comparison.

| MLP nodes | AtP variant | O→V | O→Q,K | O→MLP | MLP→V | MLP→Q,K | MLP→MLP |
|---|---|---|---|---|---|---|---|
| MLP layers | AtP | $DH^2$ | $2DH^2$ | $DH$ | $DH$ | $2DH$ | $D$ |
| | QKfix | $DH^2$ | $(T+1)DH^2$ | $DH$ | $DH$ | $(T+1)DH$ | $D$ |
| | QKfix+GD | $\frac{L+1}{2}DH^2$ | $\frac{(L+1)(T+1)}{2}DH^2$ | $\frac{L+1}{2}DH$ | $\frac{L+1}{2}DH$ | $\frac{(L+1)(T+1)}{2}DH$ | $\frac{L+1}{2}D$ |
| | AtP* | $DH^2$ | $(T+1)DH^2$ | $VNH$ | $DH$ | $(T+1)DH$ | $ND$ |
| | AtP*+GD | $\frac{L+1}{2}DH^2$ | $\frac{(L+1)(T+1)}{2}DH^2$ | $VNH$ | $\frac{L+1}{2}DH$ | $\frac{(L+1)(T+1)}{2}DH$ | $ND$ |
| | QKfix (long) | $DH^2$ | $(2D+T+1)KH^2$ | $DH$ | $DH$ | $(2D+T+1)KH$ | $D$ |
| | QKfix+GD (long) | $\frac{L+1}{2}DH^2$ | $\frac{L+1}{2}(2D+T+1)KH^2$ | $\frac{L+1}{2}DH$ | $\frac{L+1}{2}DH$ | $\frac{L+1}{2}(2D+T+1)KH$ | $\frac{L+1}{2}D$ |
| | AtP* (long) | $DH^2$ | $(2D+T+1)KH^2$ | $VNH$ | $DH$ | $(2D+T+1)KH$ | $ND$ |
| | AtP*+GD (long) | $\frac{L+1}{2}DH^2$ | $\frac{L+1}{2}(2D+T+1)KH^2$ | $VNH$ | $\frac{L+1}{2}DH$ | $\frac{L+1}{2}(2D+T+1)KH$ | $ND$ |
| Neurons | AtP | $DH^2$ | $2DH^2$ | $VNH$ | $VNH$ | $2KNH$ | $N^2$ |
| | MLPfix | $DH^2$ | $2DH^2$ | $VNH$ | $VNH$ | $2KNH$ | $N^2$ |
| | AtP* | $DH^2$ | $(T+1)DH^2$ | $VNH$ | $VNH$ | $(T+1)KNH$ | $N^2$ |
| | AtP*+GD | $\frac{L+1}{2}DH^2$ | $\frac{L+1}{2}(T+1)DH^2$ | $VNH$ | $\frac{L+1}{2}VNH$ | $\frac{(L+1)(T+1)}{2}KNH$ | $N^2$ |
| | AtP* (long) | $DH^2$ | $(2D+T+1)KH^2$ | $VNH$ | $VNH$ | $(T+1)KNH$ | $N^2$ |
| | AtP*+GD (long) | $\frac{L+1}{2}DH^2$ | $\frac{L+1}{2}(2D+T+1)KH^2$ | $VNH$ | $\frac{L+1}{2}VNH$ | $\frac{(L+1)(T+1)}{2}KNH$ | $N^2$ |

Table 2: Per-token per-layer-pair total quadratic cost of each kind of between-layers edge, across edge-AtP variants. For brevity, we omit the layer-pair factor $\binom{L}{2}$ that would otherwise be in every cell, and use $D := d_{\mathrm{resid}}$, $H := \#\,\text{heads per layer}$, $K := d_{\mathrm{key}}$, $V := d_{\mathrm{value}}$, $N := d_{\mathrm{neurons}}$.

Figure 16: A comparison of edge-AtP variants across model sizes and prompt lengths. AtP* here is defined to include QKfix and MLPfix, but not GradDrops. The costs vary across several orders of magnitude for each setting. In the setting with full-MLP nodes, MLPfix carries substantial cost for short prompts, but barely matters for long prompts. In the neuron-nodes setting, MLPfix is costless. But GradDrops in that setting continues to impose a large cost; even though it doesn't affect MLP→MLP edges, it does affect MLP→Q,K edges, which come out dominating the cost with QKfix.

Appendix D Distribution of true effects

In Figure 17, we show the distribution of $c(n)$ across models and distributions.

Figure 17: Distribution of true effects across models and prompt pair distributions. Panels: (a) Pythia-410M, (b) Pythia-1B, (c) Pythia-2.8B, (d) Pythia-12B; each for AttentionNodes and NeuronNodes.