Paper deep dive
Finding Alignments Between Interpretable Causal Variables and Distributed Neural Representations
Atticus Geiger, Zhengxuan Wu, Christopher Potts, Thomas Icard, Noah D. Goodman
Models: BERT
Intelligence
Status: succeeded | Model: google/gemini-3.1-flash-lite-preview | Prompt: intel-v1 | Confidence: 94%
Last extracted: 3/12/2026, 8:19:46 PM
Summary
The paper introduces Distributed Alignment Search (DAS), a method for causal abstraction that replaces brute-force search with gradient descent to find alignments between high-level causal models and low-level neural networks. DAS allows for distributed representations where individual neurons can play multiple roles, overcoming the limitations of localist approaches in explainable AI.
Entities (4)
Relation Signals (2)
Distributed Alignment Search → improves → Causal Abstraction
confidence 95% · In DAS, we find the alignment between high-level and low-level models using gradient descent rather than conducting a brute-force search
Interchange Intervention → used in → Causal Abstraction
confidence 90% · Geiger et al. (2021) show that the relevant causal abstraction relation obtains when interchange interventions on aligned high-level variables and low-level variables have equivalent effects.
Cypher Suggestions (2)
Find all methods related to causal abstraction · confidence 90% · unvalidated
MATCH (m:Method)-[:USED_IN]->(f:Framework {name: 'Causal Abstraction'}) RETURN m.name
Identify techniques used to analyze neural networks · confidence 85% · unvalidated
MATCH (t:Technique)-[:APPLIED_TO]->(n:System {type: 'Neural Network'}) RETURN t.name
Abstract: Causal abstraction is a promising theoretical framework for explainable artificial intelligence that defines when an interpretable high-level causal model is a faithful simplification of a low-level deep learning system. However, existing causal abstraction methods have two major limitations: they require a brute-force search over alignments between the high-level model and the low-level one, and they presuppose that variables in the high-level model will align with disjoint sets of neurons in the low-level one. In this paper, we present distributed alignment search (DAS), which overcomes these limitations. In DAS, we find the alignment between high-level and low-level models using gradient descent rather than conducting a brute-force search, and we allow individual neurons to play multiple distinct roles by analyzing representations in non-standard bases (distributed representations). Our experiments show that DAS can discover internal structure that prior approaches miss. Overall, DAS removes previous obstacles to conducting causal abstraction analyses and allows us to find conceptual structure in trained neural nets.
Tags
Links
Full Text
67,695 characters extracted from source content.
Atticus Geiger∗♢, Zhengxuan Wu† (∗†Equal contribution), Christopher Potts, Thomas Icard, and Noah D. Goodman
Pr(Ai)²R Group♢, Stanford University
{atticusg, wuzhengx, cgpotts, icard, ngoodman}@stanford.edu

Finding Alignments Between Interpretable Causal Variables and Distributed Neural Representations

Abstract
Causal abstraction is a promising theoretical framework for explainable artificial intelligence that defines when an interpretable high-level causal model is a faithful simplification of a low-level deep learning system. However, existing causal abstraction methods have two major limitations: they require a brute-force search over alignments between the high-level model and the low-level one, and they presuppose that variables in the high-level model will align with disjoint sets of neurons in the low-level one. In this paper, we present distributed alignment search (DAS), which overcomes these limitations. In DAS, we find the alignment between high-level and low-level models using gradient descent rather than conducting a brute-force search, and we allow individual neurons to play multiple distinct roles by analyzing representations in non-standard bases, i.e., distributed representations. Our experiments show that DAS can discover internal structure that prior approaches miss. Overall, DAS removes previous obstacles to uncovering conceptual structure in trained neural nets.

1 Introduction
Can an interpretable symbolic algorithm be used to faithfully explain a complex neural network model? This is a key question for interpretability; a positive answer can provide guarantees about how the model will behave, and a negative answer could lead to fundamental concerns about whether the model will be safe and trustworthy.
Causal abstraction provides a mathematical framework for precisely characterizing what it means for any complex causal system (e.g., a deep learning model) to implement a simpler causal system (e.g., a symbolic algorithm) (Rubenstein et al., 2017; Beckers et al., 2019; Massidda et al., 2023). For modern AI models, the fundamental operation for assessing whether this relationship holds in practice has been the interchange intervention (also known as activation patching), in which a neural network is provided a ‘base’ input, and sets of neurons are forced to take on the values they would have if different ‘source’ inputs were processed (Geiger et al., 2020; Vig et al., 2020; Finlayson et al., 2021; Meng et al., 2022). The counterfactuals that these interventions create are the basis for causal inferences about model behavior. Geiger et al. (2021) show that the relevant causal abstraction relation obtains when interchange interventions on aligned high-level variables and low-level variables have equivalent effects. This ideal relationship rarely obtains in practice, but the proportion of interchange interventions with the same effect (interchange intervention accuracy; IIA) provides a graded notion, and Geiger et al. (2023) formally ground this metric in the theory of approximate causal abstraction. Geiger et al. also use causal abstraction theory as a unified framework for a wide range of recent intervention-based analysis methods (Vig et al., 2020; Csordás et al., 2021; Feder et al., 2021; Ravfogel et al., 2020; Elazar et al., 2020; De Cao et al., 2021; Abraham et al., 2022; Olah et al., 2020; Olsson et al., 2022; Chan et al., 2022). Causal abstraction techniques have been applied to diverse problems (Geiger et al., 2019, 2020; Li et al., 2021; Huang et al., 2022). However, previous applications have faced two central challenges. 
First, causal abstraction requires a computationally intensive brute-force search process to find optimal alignments between the variables in the high-level model and the states of the low-level one. Where exhaustive search is intractable, we risk missing the best alignment entirely. Second, these prior methods are localist: they artificially limit the space of possible alignments by presupposing that high-level causal variables will be aligned with disjoint groups of neurons. There is no reason to assume this a priori, and indeed much recent work in model explanation (see especially Ravfogel et al. 2020, 2022; Elazar et al. 2020; Olah et al. 2020; Olsson et al. 2022) is converging on the insight of Smolensky (1986), Rumelhart et al. (1986), and McClelland et al. (1986) that individual neurons can play multiple conceptual roles. Smolensky (1986) identified distributed neural representations as “patterns” consisting of linear combinations of unit vectors. In the current paper, we propose distributed alignment search (DAS), which overcomes the above limitations of prior causal abstraction work. In DAS, we find the best alignment via gradient descent rather than conducting a brute-force search. In addition, we use distributed interchange interventions, which are “soft” interventions in which the causal mechanisms of a group of neurons are edited such that (1) their values are rotated with a change-of-basis matrix, (2) the targeted dimensions of the rotated neural representation are fixed to be the corresponding values in the rotated neural representation created for the source inputs, and (3) the representation is rotated back to the standard neuron-aligned basis. The key insight is that viewing a neural representation through an alternative basis that is not aligned with individual neurons can reveal interpretable dimensions (Smolensky, 1986). 
In our experiments, we evaluate the capabilities of DAS to provide faithful and interpretable explanations with two tasks that have obvious interpretable high-level algorithmic solutions with two intermediate variables. In both tasks, the distributed alignment learned by DAS is as good as or better than both the closest localist alignment and the best localist alignment in a brute-force search. In our first set of experiments, we focus on a hierarchical equality task that has been used extensively in developmental and cognitive psychology as a test of relational reasoning (Premack, 1983; Thompson et al., 1997; Geiger et al., 2022a): the inputs are sequences [w, x, y, z], and the label is given by (w = x) = (y = z). We train a simple feed-forward neural network on this task and show that it perfectly solves the task. Our key question: does this model implement a program that computes w = x and y = z as intermediate values, as we might hypothesize humans do? Using DAS, we find a distributed alignment with 100% IIA. In other words, the network is perfectly abstracted by the high-level model; the distinction between the learned neural model and the symbolic algorithm is thus one of implementation. Our second task models a natural language inference dataset (Geiger et al., 2020) where the inputs are premise and hypothesis sentences (p, h) that are identical but for the words w_p and w_h; the label is either entails (p makes h true) or contradicts/neutral (p makes h false). We fine-tune a pretrained language model to perfectly solve the task. With DAS, we find a perfect alignment (100% IIA) to a causal model with a binary variable for the entailment relation between the words w_p and w_h (e.g., dog entails mammal). In both sets of experiments, the DAS analyses reveal perfect abstraction relations.
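The labeling rule for the hierarchical equality task is easy to state directly; a minimal sketch (the function name is ours, not the paper's):

```python
def hierarchical_equality_label(w, x, y, z):
    """True iff the pairs (w, x) and (y, z) agree on equality:
    both pairs equal, or both pairs unequal."""
    return (w == x) == (y == z)

# Both pairs equal, or both pairs unequal, gives a positive label.
assert hierarchical_equality_label("a", "a", "b", "b") is True
assert hierarchical_equality_label("a", "b", "c", "d") is True
# One equal pair and one unequal pair gives a negative label.
assert hierarchical_equality_label("a", "a", "c", "d") is False
```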
However, we also identify an important difference between them. In the NLI case, the entailment relation can be decomposed into representations of w_p and w_h. What appears to be a representation of lexical entailment is, in this case, a "data structure" containing two representations of word identity, rather than an encoding of their entailment relation. By contrast, the hierarchical equality models learn representations of w = x and y = z that cannot be decomposed into representations of w, x, y, and z. In other words, these relations are entirely abstracted from the entities participating in the relation; DAS reveals that the neural network truly implements a symbolic, tree-structured algorithm.

2 Related Work
A theory of causal abstraction specifies exactly when a 'high-level causal model' can be seen as an abstract characterization of some 'low-level causal model' (Iwasaki and Simon, 1994; Chalupka et al., 2017; Rubenstein et al., 2017; Beckers et al., 2019). The basic idea is that high-level variables are associated with (potentially overlapping) sets of low-level variables that summarize their causal mechanisms with respect to a set of hard or soft interventions (Massidda et al., 2023). In practice, a graded notion of approximate causal abstraction is often more useful (Beckers et al., 2019; Rischel and Weichwald, 2021; Geiger et al., 2023). Geiger et al.
(2023) argue that causal abstraction is a generic theoretical framework for providing faithful (Jacovi and Goldberg, 2020; Lyu et al., 2022) and interpretable (Lipton, 2018) explanations of AI models and show that LIME (Ribeiro et al., 2016), causal effect estimation (Abraham et al., 2022; Feder et al., 2021), causal mediation analysis (Vig et al., 2020; Csordás et al., 2021; De Cao et al., 2021), iterated nullspace projection (Ravfogel et al., 2020; Elazar et al., 2020), and circuit-based explanations (Olah et al., 2020; Olsson et al., 2022; Wang et al., 2022; Chan et al., 2022) can all be understood as causal abstraction analysis. Interchange intervention training (IIT) objectives are minimized when a high-level causal model is an abstraction of a neural network under a given alignment (Geiger et al., 2022b; Wu et al., 2022; Huang et al., 2022). In this paper, we use IIT objectives to learn an alignment between a high-level causal model and a deep learning model.

3 Methods
We focus on acyclic causal models (Pearl, 2001; Spirtes et al., 2000) and seek to provide an intuitive overview of our method. An acyclic causal model consists of input, intermediate, and output variables, where each variable has an associated set of values it can take on and a causal mechanism that determines the value of the variable based on the values of its causal parents. As a simple running example, we modify the boolean conjunction models of Geiger et al. (2022b) to reveal key properties of DAS. A causal model B for this problem can be defined as below, where the inputs and outputs are booleans t and f. Alongside B, we also define a causal model N of a linear feed-forward neural network that solves the task.
Here we show B, N, and the parameters of N. The high-level model B has input variables P and Q with mechanisms

V1 = p,  V2 = q,  V3 = v1 ∧ v2.

The low-level model N has input variables X1 and X2 with mechanisms

H1 = [x1; x2] W1,  H2 = [x1; x2] W2,  O = [h1; h2] w + b,

where

W1 = [cos(20°); −sin(20°)],  W2 = [sin(20°); cos(20°)],  w = [1; 1],  b = −1.8.

The model N predicts t if O > 0 and f otherwise. This network solves the boolean conjunction problem perfectly in that all pairs of input boolean values are mapped to the intended output. An input x of a model M determines a unique total setting M(x) of all the variables in the model: the inputs are fixed to be x, and the causal mechanisms of the model determine the values of the remaining variables. We denote the values that M(x) assigns to the variable or variables Z as GetValues_Z(M(x)). For example, GetValues_V3(B([t, f])) = f.

3.1 Interventions
Interventions are a fundamental building block of causal models, and of causal abstraction analysis in particular. An intervention I ← i is a setting i of variables I.
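The running example can be checked numerically; a minimal sketch of B and N under the stated parameters (the function names are ours):

```python
import numpy as np

DEG = np.pi / 180.0
# Columns of W are W1 = [cos 20°; -sin 20°] and W2 = [sin 20°; cos 20°].
W = np.array([[np.cos(20 * DEG), np.sin(20 * DEG)],
              [-np.sin(20 * DEG), np.cos(20 * DEG)]])
w_out = np.array([1.0, 1.0])
b = -1.8

def N(x, return_hidden=False):
    """Linear network: H = x @ W, O = H @ w_out + b; predicts t iff O > 0."""
    h = np.array(x, dtype=float) @ W
    o = h @ w_out + b
    return (h, o) if return_hidden else o

def B(p, q):
    """High-level boolean conjunction model: V3 = V1 and V2."""
    return p and q

# N solves conjunction: O is positive only for input [1, 1].
for x, (p, q) in [([0, 0], (False, False)), ([0, 1], (False, True)),
                  ([1, 0], (True, False)), ([1, 1], (True, True))]:
    assert (N(x) > 0) == B(p, q)
```

For the input [0, 1] this gives h1 ≈ −0.34, h2 ≈ 0.94, and O ≈ −1.2, matching the values quoted in the text.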
Together, an intervention and an input setting x of a model M determine a unique total setting that we denote as M_{I ← i}(x). The inputs are fixed to be x, and the causal mechanisms of the model determine the values of the non-intervened variables, with the intervened variables I being fixed to i. We can define interventions on both our causal model B and our neural model N. For example, B_{V1 ← t}([f, t]) is our boolean model when it processes input [f, t] but with variable V1 set to t. This has the effect of changing the output value to t. Similarly, whereas N([0, 1]) leads to intermediate values h1 = −0.34 and h2 = 0.94 and output value −1.2, if we compute N_{h1 ← 1.34}([0, 1]), then the output value is 0.48. This has the effect of changing the predicted value to t, because 0.48 > 0.

3.2 Alignment
In causal abstraction analysis, we ask whether a specific low-level model like N implements a high-level algorithm like B. This is always relative to a specific alignment of variables between the two models. An alignment Π = (Π_X, τ_X)_X assigns to each high-level variable X a set of low-level variables Π_X and a function τ_X that maps from values of the low-level variables in Π_X to values of the aligned high-level variable X. One possible alignment between B and N is shown in the diagram above: Π is depicted by the dashed lines connecting B and N.
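The hard intervention from Section 3.1, h1 ← 1.34 on input [0, 1], can be replayed directly; a sketch with our own helper name, assuming the parameters defined above:

```python
import numpy as np

DEG = np.pi / 180.0
W = np.array([[np.cos(20 * DEG), np.sin(20 * DEG)],
              [-np.sin(20 * DEG), np.cos(20 * DEG)]])
w_out = np.array([1.0, 1.0])
b = -1.8

def N_intervened(x, h1=None):
    """Run N on input x, optionally fixing hidden unit H1 to a given value."""
    h = np.array(x, dtype=float) @ W
    if h1 is not None:
        h[0] = h1  # hard intervention: overwrite H1's mechanism output
    return h @ w_out + b

print(round(N_intervened([0, 1]), 2))           # -1.2: predicts f
print(round(N_intervened([0, 1], h1=1.34), 2))  # 0.48: predicts t
```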
We immediately know what the functions for high-level input and output variables are. For the inputs, t is encoded as 1 and f is encoded as 0, meaning τ_P(1) = τ_Q(1) = t and τ_P(0) = τ_Q(0) = f. For the output, the network predicts t only if O > 0, meaning τ_V3(x) = t if x > 0, else f. This is simply a consequence of how a neural network is used and trained. The functions for high-level intermediate variables τ_V1(x) and τ_V2(x) must be discovered and verified experimentally.

3.3 Constructive Causal Abstraction
Relative to an alignment like this, we can define abstraction:

Definition 3.1 (Constructive Causal Abstraction). A high-level causal model H is a constructive abstraction of a low-level causal model L under alignment Π exactly when the following holds for every low-level input setting x and low-level intervention I ← i:

τ(L_{I ← i}(x)) = H_{τ(I ← i)}(τ(x))

H being a causal abstraction of L under Π guarantees that the causal mechanism for each high-level variable X is a faithful rendering of the causal mechanisms for the low-level variables in Π_X. To assess the degree to which a high-level model is a constructive causal abstraction of a low-level model, we perform interchange interventions:

Definition 3.2
(Interchange Interventions). Given source input settings {s_j}_{j=1}^k and non-overlapping sets of intermediate variables {X_j}_{j=1}^k for model M, define the interchange intervention as the model

I(M, {s_j}_{j=1}^k, {X_j}_{j=1}^k) = M_{⋀_{j=1}^k ⟨X_j ← GetVals_{X_j}(M(s_j))⟩}

where ⋀_{j=1}^k ⟨·⟩ concatenates a set of interventions. A base input setting can be fed into the resulting model to compute the counterfactual output value. Consider the following interchange intervention:

I(B, {[t, t]}, {{V1}}) = B_{V1 ← GetVals_{V1}(B([t, t]))}

We process a base input and a source input, and then we intervene on a target variable, replacing it with the value obtained by processing the source. Our causal model is fully known, and so we know ahead of time that this interchange intervention yields t. For our neural network, the corresponding behavior is not known ahead of time.
The interchange intervention corresponding to the above (according to the alignment we are exploring) is as follows:

I(N, {[1, 1]}, {{H1}}) = N_{H1 ← GetVals_{H1}(N([1, 1]))}

And, indeed, the counterfactual behaviors of the model B and the network N are unequal. At the high level, intervening with V1 ← t on the base input [f, t] yields V1 = t, V2 = t, and V3 = t, so B outputs t. At the low level, processing the base input [0, 1] with H1 fixed to 0.6 (its value for the source input [1, 1]) yields H1 = 0.6, H2 = 0.94, and O = −0.26, so N predicts f. Under the given alignment, the interchange interventions at the low and high level have different effects. Thus, we have a counterexample to constructive abstraction as given in Definition 3.1. Although N has perfect behavioral accuracy, its accuracy under the counterfactuals created by our interventions is not perfect, and thus B is not a constructive abstraction of N under this alignment.

3.4 Distributed Interventions
The above conclusion is based on the kind of localist causal abstraction explored in the literature to date. As noted in Section 1, there are two risks associated with this conclusion: (1) we may have chosen a suboptimal alignment, and (2) we may be wrong to assume that the relevant structure will be encoded in the standard basis we have implicitly assumed throughout. If we simply rotate the representation [H1, H2] by −20° to get a new representation [Y1, Y2], then the resulting network has perfect behavioral and counterfactual accuracy when we align V1 and V2 with Y1 and Y2.
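This failing localist interchange intervention can be reproduced numerically; a sketch under the running example's parameters (helper names ours):

```python
import numpy as np

DEG = np.pi / 180.0
W = np.array([[np.cos(20 * DEG), np.sin(20 * DEG)],
              [-np.sin(20 * DEG), np.cos(20 * DEG)]])
w_out = np.array([1.0, 1.0])
b = -1.8

def hidden(x):
    """Hidden representation [H1, H2] for an input pair."""
    return np.array(x, dtype=float) @ W

def interchange_H1(base, source):
    """Run N on the base input, but fix the neuron H1 to its source value."""
    h = hidden(base)
    h[0] = hidden(source)[0]
    return h @ w_out + b

o = interchange_H1(base=[0, 1], source=[1, 1])
# Low level: O = -0.26 < 0, so N predicts f, while the aligned high-level
# intervention on B yields t: the two effects disagree.
assert o < 0
```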
What this reveals is that there is an alignment, but not in the basis we chose. Since the choice of basis was arbitrary, our negative conclusion about the causal abstraction relation was spurious. This rotation localizes the information about the first and second argument into separate dimensions. To understand this, observe that the weight matrix of the linear network rotates a two-dimensional vector by 20° and the rotation matrix rotates the representation by 340°. The two matrices are inverses. Because this network is linear, there is no activation function, and so rotating the hidden representation "undoes" the transformation of the input by the weight matrix. Under this non-standard basis, the first hidden dimension is equal to the first input argument and the second hidden dimension is equal to the second input argument. This reveals an essential aspect of distributed neural representations: there is a many-to-many mapping between neurons and concepts, and thus multiple high-level causal variables might be encoded in structures from overlapping groups of neurons (Rumelhart et al., 1986; McClelland et al., 1986). In particular, Smolensky (1986) proposes that viewing a neural representation under a basis that is not aligned with individual neurons can reveal the interpretable distributed structure of the neural representations.
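That the −20° rotation undoes the weight matrix is easy to verify; a short sketch assuming the running example's parameters (function names ours):

```python
import numpy as np

DEG = np.pi / 180.0
# Weight matrix of the linear network: rotates a row vector by 20 degrees.
W = np.array([[np.cos(20 * DEG), np.sin(20 * DEG)],
              [-np.sin(20 * DEG), np.cos(20 * DEG)]])

def rotate(v, degrees):
    """Rotate a 2-d row vector by the given angle (row-vector convention)."""
    t = degrees * DEG
    R = np.array([[np.cos(t), np.sin(t)],
                  [-np.sin(t), np.cos(t)]])
    return v @ R

# Rotating [H1, H2] by -20 degrees recovers the inputs: Y1 = x1, Y2 = x2.
for x in ([0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]):
    h = np.array(x) @ W    # hidden representation [H1, H2]
    y = rotate(h, -20)     # non-standard basis [Y1, Y2]
    assert np.allclose(y, x)
```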
Figure 1: A generic multi-source distributed interchange intervention. The base input and two source inputs create three total settings of a model. The top left (green) and right (blue) total model settings are determined by the two source inputs, and the middle total model setting (red) is determined by the base input. Three hidden units from each total setting are rotated with an orthogonal matrix R: X → Y. Then we intervene on the rotated representation for the base input and fix two dimensions to be the values they take on for each source input, respectively. Then we unrotate the representation with R⁻¹ and compute a counterfactual total model setting for the base input. In DAS, the orthogonal matrix is found with gradient descent using a high-level causal model to guide the search process.

To make good on this intuition we define a distributed intervention, which first transforms a set of variables to a vector space, then does interchange on orthogonal sub-spaces, before transforming back to the original representation space.

Definition 3.3
(Distributed Interchange Interventions). We begin with a causal model M with input variables S and source input settings {s_j}_{j=1}^k. Let N be a subset of variables in M, the target variables. Let Y be a vector space with subspaces {Y_j}_{j=0}^k that form an orthogonal decomposition, i.e., Y = Y_0 ⊕ Y_1 ⊕ … ⊕ Y_k. Let R be an invertible function R: N → Y. Write Proj_{Y_j} for the orthogonal projection operator of a vector in Y onto subspace Y_j. (Thus, Proj generalizes GetVals to arbitrary vector spaces.) A distributed interchange intervention yields a new model DII(M, R, {s_j}_{j=1}^k, {Y_j}_{j=0}^k), which is identical to M except that the mechanisms F_N (which yield values of N from a total setting) are replaced by:

F*_N(v) = R⁻¹( Proj_{Y_0}(R(F_N(v))) + Σ_{j=1}^k Proj_{Y_j}(R(F_N(M(s_j)))) )

Notice that in this definition the base setting is partially preserved through the intervention (in subspace Y_0), and hence this is a soft intervention on N that rewrites causal mechanisms while maintaining a causal dependence between parent and child.
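For a rotation R applied to hidden vectors, the replaced mechanism amounts to: rotate the base hidden vector, overwrite the targeted rotated coordinates with the rotated source values, and rotate back. A sketch of this special case (function and argument names are ours):

```python
import numpy as np

def dii(h_base, h_sources, R, subspace_dims):
    """Distributed interchange intervention on hidden vectors.

    h_base: hidden vector computed for the base input.
    h_sources: one hidden vector per source input.
    R: orthogonal matrix applied as v @ R (so v @ R.T inverts it).
    subspace_dims: list of index arrays; subspace_dims[j] gives the rotated
        coordinates aligned with high-level variable X_{j+1}. The remaining
        coordinates form Y_0 and keep their base values (a soft intervention).
    """
    y = np.asarray(h_base, dtype=float) @ R
    for idx, h_src in zip(subspace_dims, h_sources):
        y[idx] = (np.asarray(h_src, dtype=float) @ R)[idx]
    return y @ R.T  # rotate back to the neuron-aligned basis

# With the identity rotation, this reduces to an ordinary (localist)
# interchange intervention on the first neuron.
out = dii([1.0, 2.0], [[5.0, 6.0]], np.eye(2), [np.array([0])])
assert np.allclose(out, [5.0, 2.0])
```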
Under this new alignment, the high-level interchange intervention I(B, {[t, t]}, {{V1}}) = B_{V1 ← GetVals_{V1}(B([t, t]))} is aligned with the low-level distributed interchange intervention DII(N, R_{−20°}, {[1, 1]}, {Y1}), where R_{−20°} = [cos(−20°), −sin(−20°); sin(−20°), cos(−20°)], and the counterfactual output behaviors of B and N are equal: processing the base input [0, 1] under this distributed intervention yields H1 = 0.6, H2 = 1.28, and O = 0.08, so N predicts t, matching the high-level output t.

In what follows we will assume that X are already vector spaces (which is true for neural nets) and the functions R are rotation operators. In this case, the subspaces Y_j can be identified without loss of generality with those spanned by the first |Y_0| basis vectors for Y_0, the next |Y_1| basis vectors for Y_1, and so on. (The following methods would be well-defined for non-linear transformations, as long as they were invertible and differentiable, but efficient implementation becomes harder.)

3.5 Distributed Alignment Search
The question then arises of how to find good rotations. As we discussed above, previous causal abstraction analyses of neural networks have performed brute-force search through a discrete space of hand-picked alignments. In distributed alignment search (DAS), we find an alignment between one or more high-level variables and disjoint sub-spaces (but not necessarily subsets of neurons) of a large neural representation. We define a distributed interchange intervention training objective, use differentiable parameterizations of the space of orthogonal matrices (such as those provided by PyTorch), and then optimize the objective with stochastic gradient descent. Crucially, the low-level and high-level models are frozen during learning, so we are only changing the alignment. In the following definition we assume that a neural network specifies an output distribution for a given input, which can then be pushed forward to a distribution on output values of the high-level model via an alignment function τ.
We may similarly interpret even a deterministic high-level model as defining a (e.g., delta) distribution on output values. We make use of these distributions, after interchange intervention, to define a differentiable loss for the rotation matrix that aligns intermediate variables.

Definition 3.4 (Distributed Interchange Intervention Training Objective). Begin with a low-level neural network $\mathcal{L}$ with low-level input settings $\mathsf{Inputs}_{\mathcal{L}}$, a high-level algorithm $\mathcal{H}$ with high-level output settings $\mathsf{Out}_{\mathcal{H}}$, and an alignment $\tau$ for their input and output variables. Suppose we want to align intermediate high-level variables $X_j \in \mathsf{Vars}_{\mathcal{H}}$ with rotated subspaces $Y_j$ of a neural representation $\mathbf{N} \subset \mathsf{Vars}_{\mathcal{L}}$ with learned rotation matrix $R_\theta : \mathbf{N} \to \mathbf{Y}$. In general, we can define a training objective using any differentiable loss function $\mathsf{Loss}$ that quantifies the distance between two total high-level settings:
$$\sum_{\mathbf{b}, \mathbf{s}_1, \dots, \mathbf{s}_k \in \mathsf{Inputs}_{\mathcal{L}}} \mathsf{Loss}\Big(\mathsf{DII}(\mathcal{L}, R_\theta, \{\mathbf{s}_j\}_1^k, \{Y_j\}_0^k)(\mathbf{b}),\; \mathsf{I}(\mathcal{H}, \{\tau(\mathbf{s}_j)\}_1^k, \{X_j\}_1^k)(\tau(\mathbf{b}))\Big)$$
For our experiments, we compute the cross-entropy loss $\mathsf{CE}(\cdot,\cdot)$ between the high-level output distribution $\mathbb{P}(\mathsf{out}_{\mathcal{H}} \mid \mathcal{H}(\tau(\mathbf{b})))$ and the push-forward under $\tau$ of the low-level output distribution $\mathbb{P}^\tau(\mathsf{out}_{\mathcal{H}} \mid \mathcal{L}(\mathbf{b}))$.
The overall objective is:
$$\sum_{\mathbf{b}, \mathbf{s}_1, \dots, \mathbf{s}_k \in \mathsf{Inputs}_{\mathcal{L}}} \mathsf{CE}\Big(\mathbb{P}\big(\mathsf{out}_{\mathcal{H}} \mid \mathsf{I}(\mathcal{H}, \{\tau(\mathbf{s}_j)\}_1^k, \{X_j\}_1^k)(\tau(\mathbf{b}))\big),\; \mathbb{P}^\tau\big(\mathsf{out}_{\mathcal{H}} \mid \mathsf{DII}(\mathcal{L}, R_\theta, \{\mathbf{s}_j\}_1^k, \{Y_j\}_0^k)(\mathbf{b})\big)\Big)$$
While we still have discrete hyperparameters $(\mathbf{N}, |Y_0|, \dots, |Y_k|)$ (the target population of neurons and the dimensionality of the subspaces used for each high-level variable), we may use stochastic gradient descent to determine the rotation that minimizes the loss, thus yielding the best distributed alignment between $\mathcal{L}$ and $\mathcal{H}$.

3.6 Approximate Causal Abstraction

Perfect causal abstraction relationships are unlikely to arise for neural networks trained to solve complex empirical tasks. We use a graded notion of accuracy:

Definition 3.5 (Distributed Interchange Intervention Accuracy). Given low-level and high-level causal models $\mathcal{L}$ and $\mathcal{H}$ with alignment $(\Pi, \tau)$, rotation $R : \mathbf{N} \to \mathbf{Y}$, and orthogonal decomposition $\{Y_j\}_0^k$.
If we let $\mathsf{Inputs}_{\mathcal{L}}$ be the low-level input settings and $\{X_j\}_1^k$ be high-level intermediate variables, the interchange intervention accuracy (IIA) is as follows:
$$\sum_{\mathbf{b}, \mathbf{s}_1, \dots, \mathbf{s}_k \in \mathsf{Inputs}_{\mathcal{L}}} \frac{1}{|\mathsf{Inputs}_{\mathcal{L}}|^{k+1}} \Big[\tau\big(\mathsf{DII}(\mathcal{L}, R_\theta, \{\mathbf{s}_j\}_1^k, \{Y_j\}_0^k)(\mathbf{b})\big) = \mathsf{I}(\mathcal{H}, \{\tau(\mathbf{s}_j)\}_1^k, \{X_j\}_1^k)(\tau(\mathbf{b}))\Big]$$
IIA is the proportion of aligned interchange interventions that have equivalent high-level and low-level effects. In our example with $\mathcal{N}$ and $\mathcal{A}$, IIA is 100% and the high-level model is a perfect abstraction of the low-level model (Def. 3.1). When IIA is $\alpha < 100\%$, we rely on the graded notion of $\alpha$-on-average approximate causal abstraction (Geiger et al., 2023), which coincides with IIA.

3.7 General Experimental Setup

We illustrate the value of DAS by analyzing feed-forward networks trained on a hierarchical equality task and pretrained Transformer-based language models (Vaswani et al., 2017) fine-tuned on a natural language inference task. Our evaluation paradigm is as follows:

1. Train the neural network $\mathcal{N}$ to solve the task. In all experiments, the neural models achieve perfect accuracy on both training and testing data.

2. Create interchange intervention training datasets using a high-level causal model.
Each example consists of a base input, one or more source inputs, high-level causal variables targeted for intervention, and a counterfactual gold label that the network will output if the interchange intervention has the hypothesized effect on model behavior. This gold label is a counterfactual output of the high-level model we will align with the network. (See Appendix A.1 for details.)

3. Optimize an orthogonal matrix to learn a distributed alignment for each high-level model that maximizes IIA, using the training objective in Def. 3.4. We experiment with different hidden dimension sizes for our low-level model and different intervention site sizes (dimensionality of low-level subspaces) and locations (the layer where the intervention happens). (See Appendix A.2 for details.)

4. Evaluate a baseline that brute-force searches through a discrete space of alignments and selects the alignment with the highest IIA. We search the space of alignments by aligning each high-level variable with groups of neurons in disjoint sliding windows. (See Appendix A.3 for details.)

5. Evaluate the localist alignment "closest" to the learned distributed alignment. The rotation matrix for the localist alignment will be axis-aligned with the standard basis, possibly permuting and reflecting unit axes. (See Appendix A.4 for details.)

6. Determine whether each distributed representation aligned with a high-level variable can be decomposed into multiple representations that encode the identity of the input values to the variable's causal mechanism. We do this by learning a second rotation matrix that decomposes the learned distributed representation, holding the first rotation matrix fixed. (See Appendix A.5 for details.)

The codebase used to run these experiments is at https://github.com/atticusg/InterchangeInterventions/tree/zen.
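The optimization in step 3 can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's codebase: the toy two-layer network, the layer targeted, the subspace size `k`, and the random counterfactual labels are all stand-ins. Only the rotation is trained; both models stay frozen, and PyTorch's orthogonal parametrization keeps every gradient step on the orthogonal group.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in frozen low-level model: layer 1 (linear + ReLU) and a classifier head.
hidden = 8
layer1 = nn.Sequential(nn.Linear(4, hidden), nn.ReLU())
head = nn.Linear(hidden, 2)
for p in list(layer1.parameters()) + list(head.parameters()):
    p.requires_grad_(False)

# Learnable rotation R_theta, kept orthogonal by PyTorch's parametrization.
rotation = nn.utils.parametrizations.orthogonal(
    nn.Linear(hidden, hidden, bias=False))

k = 2  # dimensions aligned with one high-level variable (a hyperparameter)

def dii(base, source):
    """Distributed interchange intervention: rotate both hidden states, swap
    the first k coordinates of the base for the source's, rotate back."""
    r_b = rotation(layer1(base))
    r_s = rotation(layer1(source))
    mixed = torch.cat([r_s[:, :k], r_b[:, k:]], dim=-1)
    # nn.Linear computes h @ W^T; since W is orthogonal, h @ W^T @ W = h,
    # so multiplying by W undoes the rotation.
    return head(mixed @ rotation.weight)

# One gradient step on the cross-entropy objective against counterfactual
# labels that would come from the high-level model (random stand-ins here).
base, source = torch.randn(32, 4), torch.randn(32, 4)
gold = torch.randint(0, 2, (32,))
opt = torch.optim.SGD(rotation.parameters(), lr=0.01)
loss = nn.functional.cross_entropy(dii(base, source), gold)
opt.zero_grad()
loss.backward()
opt.step()
```

After the update, `rotation.weight` is still an orthogonal matrix, so the intervention remains a change of basis rather than an arbitrary linear map.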
We have replicated the hierarchical equality experiment using the Pyvene library at https://github.com/stanfordnlp/pyvene/blob/main/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb.

4 Hierarchical Equality Experiment

We now illustrate the power of DAS for analyzing networks designed to solve a hierarchical equality task. We concentrate on analyzing a trained feed-forward network. A basic equality task is to determine whether a pair of objects are the same ($x = y$). A hierarchical equality task is to determine whether a pair of pairs of objects have identical relations: $(w = x) = (y = z)$. Specifically, the input to the task is two pairs of objects, and the output is True if both pairs are equal or both pairs are unequal, and False otherwise. For example, $(A, A, B, B)$ and $(A, B, C, D)$ are both assigned True, while $(A, B, C, C)$ is assigned False.

4.1 Low-Level Neural Model

We train a three-layer feed-forward network with ReLU activations to perform the hierarchical equality task. Each input object is represented by a randomly initialized vector. Specifically, our model has the following architecture, where $k$ is the number of layers.
$$h_1 = \mathsf{ReLU}([x_1; x_2; x_3; x_4]W_1 + b_1) \qquad h_j = \mathsf{ReLU}(h_{j-1}W_j + b_j) \qquad y = \mathsf{softmax}(h_k W_k + b_k)$$

The input vectors are in $\mathbb{R}^n$, the biases are in $\mathbb{R}^{4n}$, and the weights are in $\mathbb{R}^{4n \times 4n}$. We evaluate our model on held-out random vectors unseen during training, as in Geiger et al. (2022a).

In Table 1, columns are grouped by high-level model: 'Both Eq.' = Both Equality Relations, 'Left Eq.' = Left Equality Relation, 'Identity' = Identity of First Argument, and 'Id. Subspace' = Identity Subspace of Left Equality (Layer 1 only).

| Hidden size | Interv. size | Both Eq. L1 | Both Eq. L2 | Both Eq. L3 | Left Eq. L1 | Left Eq. L2 | Left Eq. L3 | Identity L1 | Identity L2 | Identity L3 | Id. Subspace L1 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 16 | 1 | 0.88 | 0.51 | 0.50 | 0.85 | 0.54 | 0.50 | 0.51 | 0.52 | 0.50 | 0.51 |
| 16 | 2 | 0.97 | 0.54 | 0.50 | 0.85 | 0.55 | 0.50 | 0.50 | 0.52 | 0.51 | 0.50 |
| 16 | 8 | 1.00 | 0.57 | 0.50 | 0.90 | 0.56 | 0.50 | 0.52 | 0.53 | 0.51 | 0.51 |
| 32 | 2 | 0.93 | 0.63 | 0.49 | 0.92 | 0.65 | 0.50 | 0.52 | 0.55 | 0.52 | 0.50 |
| 32 | 4 | 0.97 | 0.63 | 0.49 | 0.94 | 0.65 | 0.50 | 0.51 | 0.55 | 0.52 | 0.51 |
| 32 | 16 | 0.99 | 0.67 | 0.53 | 0.99 | 0.65 | 0.50 | 0.49 | 0.55 | 0.52 | 0.51 |
| Brute-Force Search | - | 0.60 | 0.56 | 0.52 | 0.64 | 0.64 | 0.57 | 0.50 | 0.51 | 0.54 | - |
| Localist Alignment | - | 0.73 | 0.56 | 0.48 | 0.60 | 0.50 | 0.49 | 0.46 | 0.47 | 0.48 | - |

Table 1: Hierarchical equality alignment learning results.
The table can be read as follows: Layer 1, Layer 2, and Layer 3 indicate which layer of neurons is targeted; $|\mathbf{N}|$ is the number of neurons in a layer; $k$ is the number of neurons aligned with each intermediate variable, where our subspace model occupies $\lceil k/2 \rceil$ dimensions; and the values in each cell are interchange intervention accuracies for the learned alignment on training data. We report the best results from three runs with distinct random seeds for training the rotation matrix (the same frozen low-level model is used for each seed).

4.2 High-Level Models

We use DAS to evaluate whether trained neural networks have achieved the natural solution to the hierarchical equality task, in which the left and right equality relations are computed and then used to predict the final label (Figure 2).

Figure 2: A causal model that computes the hierarchical equality task, with inputs $w, x, y, z$ and variables $V_1 \leftarrow (w = x)$, $V_2 \leftarrow (y = z)$, and $V_3 \leftarrow (V_1 = V_2)$.

However, evaluating this high-level model alone is insufficient, as there are obviously many other high-level models of this task. To further contextualize our results, we also consider two alternatives: a high-level model in which only the equality relation of the first pair is represented, and a high-level model in which the lone intermediate variable encodes the identity of the first input object (leaving all computation for the final step). These alternative high-level models also solve the task perfectly.

4.3 Discussion

The IIA results achieved by the best alignment for each high-level model can be seen in Table 1. The best alignments found are with the 'Both Equality Relations' model that is widely assumed in the cognitive science literature.
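As a concrete illustration of the interventions being scored, the 'Both Equality Relations' causal model of Figure 2 and a high-level interchange intervention on $V_1$ can be written as a few lines of Python. This is a sketch of the high-level model only; the function names and the `v1_override` mechanism are illustrative, not from the paper's code.

```python
# The 'Both Equality Relations' high-level model (Figure 2):
# V1 <- (w = x), V2 <- (y = z), V3 <- (V1 = V2).
def high_level(w, x, y, z, v1_override=None):
    v1 = (w == x) if v1_override is None else v1_override
    v2 = (y == z)
    return v1 == v2  # V3, the task label

# Task labels from the paper's examples.
assert high_level('A', 'A', 'B', 'B') is True   # both pairs equal
assert high_level('A', 'B', 'C', 'D') is True   # both pairs unequal
assert high_level('A', 'B', 'C', 'C') is False

# Interchange intervention on V1: compute V1 on a source input,
# then run the base input with V1 fixed to that value.
def interchange_v1(base, source):
    v1_from_source = (source[0] == source[1])
    return high_level(*base, v1_override=v1_from_source)

# Base (A,A,B,B) normally outputs True; with V1 taken from
# source (A,B,C,C), the output flips to False.
assert interchange_v1(('A', 'A', 'B', 'B'), ('A', 'B', 'C', 'C')) is False
```

IIA measures how often the network, after the corresponding distributed intervention on the aligned subspace, reproduces exactly these counterfactual labels.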
For all causal models, DAS learns a more faithful alignment (higher IIA) than a brute-force search through localist alignments. This result is most pronounced for 'Both Equality Relations', where DAS learns perfect or near-perfect alignments under a number of settings, whereas the best brute-force alignment achieves only 0.60 and the best localist alignment achieves only 0.73. Finally, the distributed representation of left equality could not be decomposed into a representation of the first argument's identity. We see this in the very low performance of the 'Identity Subspace of Left Equality' results. This indicates that the models are truly learning to encode an abstract equality relation, rather than merely storing the identities of the inputs.

4.4 Analyzing a Randomly Initialized Network

To calibrate intuitions about our method, we evaluate the ability of DAS to optimize interchange intervention accuracy on a frozen, randomly initialized network that achieves chance accuracy (50%) on the hierarchical equality task. This investigates the degree to which random causal structures can be used to systematically manipulate the counterfactual behavior of the network. We evaluate networks with different hidden representation sizes while holding the four input vectors fixed at 4 dimensions, under the hypothesis that more hidden neurons create more random structure that DAS can search through. These results are summarized in Figure 3.

| Hidden size | Intervention size | Both Equality Relations (Layer 1) |
|---|---|---|
| 16 | 8 | 0.50 |
| 64 | 32 | 0.50 |
| 256 | 128 | 0.51 |
| 1028 | 512 | 0.55 |
| 4096 | 2048 | 0.64 |

Figure 3: DAS on a random network with a 16-dimension input. An oversized hidden dimension allows DAS to manipulate the model behavior by searching through a large space of random mechanisms.
Observe that, in small networks, there is no ability to increase interchange intervention accuracy. However, as we increase the size of the hidden representation to be orders of magnitude larger than the input dimension of 16, the interchange intervention accuracy increases. This confirms our hypothesis and serves as a check demonstrating that DAS cannot construct entirely new behaviors from random structure.

Figure 4: Monotonicity NLI task examples and high-level model. Left: two MoNLI examples.

| Sentence pair | Label |
|---|---|
| premise: A man is talking to someone in a taxi. / hypothesis: A man is talking to someone in a car. | entails |
| premise: The people are not playing sitars. / hypothesis: The people are not playing instruments. | neutral |

Right: a simple program that solves MoNLI.

    MoNLI(p, h):
        lexrel ← get-lexrel(p, h)
        neg ← contains-not(p, h)
        if neg: reverse(lexrel)
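The MoNLI high-level program can be sketched in runnable form. This is a toy rendering, not the paper's implementation: `get_lexrel` uses a tiny hard-coded hyponym lexicon, `contains_not` is a bare token check, and the aligned word pair is passed in explicitly rather than extracted from the sentences.

```python
# Toy stand-in lexicon: (hyponym, hypernym) pairs from the Figure 4 examples.
HYPONYMS = {('taxi', 'car'), ('sitars', 'instruments')}

def get_lexrel(premise_word, hypothesis_word):
    # 'entails' if the premise word is a hyponym of the hypothesis word.
    return 'entails' if (premise_word, hypothesis_word) in HYPONYMS else 'neutral'

def contains_not(premise, hypothesis):
    return 'not' in premise.split() and 'not' in hypothesis.split()

def reverse(label):
    # Negation flips the monotonicity of the lexical relation.
    return {'entails': 'neutral', 'neutral': 'entails'}[label]

def monli(premise, hypothesis, aligned_pair):
    lexrel = get_lexrel(*aligned_pair)
    if contains_not(premise, hypothesis):
        lexrel = reverse(lexrel)
    return lexrel

# The two examples from Figure 4.
assert monli("A man is talking to someone in a taxi.",
             "A man is talking to someone in a car.",
             ('taxi', 'car')) == 'entails'
assert monli("The people are not playing sitars.",
             "The people are not playing instruments.",
             ('sitars', 'instruments')) == 'neutral'
```

The two intermediate values `lexrel` and `neg` are the high-level causal variables that a DAS analysis would attempt to align with rotated subspaces of a fine-tuned language model's representations.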