
Paper deep dive

Inducing Causal Structure for Interpretable Neural Networks

Atticus Geiger, Zhengxuan Wu, Hanson Lu, Josh Rozner, Elisa Kreiss, Thomas Icard, Noah D. Goodman, Christopher Potts

Year: 2022 · Venue: ICML 2022 · Area: Mechanistic Interp. · Type: Empirical · Embeddings: 67

Models: BERT, MLP, ResNet

Intelligence

Status: succeeded | Model: google/gemini-3.1-flash-lite-preview | Prompt: intel-v1 | Confidence: 97%

Last extracted: 3/12/2026, 8:09:57 PM

Summary

The paper introduces Interchange Intervention Training (IIT), a method to train neural networks to realize the causal structure of a high-level symbolic causal model. By aligning neural representations with causal variables and performing interchange interventions, the method forces the neural network to match the counterfactual behavior of the causal model, thereby improving interpretability and systematic generalization.

Entities (6)

Interchange Intervention Training · method · 100%
MNIST-PVR · dataset/task · 99%
MQNLI · dataset/task · 99%
ReaSCAN · dataset/task · 99%
Causal Abstraction · concept · 95%
ResNet · neural-network-architecture · 95%

Relation Signals (4)

Interchange Intervention Training evaluated on MNIST-PVR

confidence 100% · We evaluate IIT on a structural vision task (MNIST-PVR)

Interchange Intervention Training evaluated on ReaSCAN

confidence 100% · a navigational language task (ReaSCAN)

Interchange Intervention Training evaluated on MQNLI

confidence 100% · a natural language inference task (MQNLI)

Interchange Intervention Training improves Interpretability

confidence 95% · IIT achieves the best results and produces neural models that are more interpretable

Cypher Suggestions (2)

Find all tasks where IIT was evaluated · confidence 90% · unvalidated

MATCH (m:Method {name: 'Interchange Intervention Training'})-[:EVALUATED_ON]->(t:Task) RETURN t.name

Identify the relationship between IIT and interpretability · confidence 90% · unvalidated

MATCH (m:Method {name: 'Interchange Intervention Training'})-[r:IMPROVES]->(p:Property {name: 'Interpretability'}) RETURN r

Abstract

In many areas, we have well-founded insights about causal structure that would be useful to bring into our trained models while still allowing them to learn in a data-driven fashion. To achieve this, we present the new method of interchange intervention training (IIT). In IIT, we (1) align variables in a causal model (e.g., a deterministic program or Bayesian network) with representations in a neural model and (2) train the neural model to match the counterfactual behavior of the causal model on a base input when aligned representations in both models are set to be the value they would be for a source input. IIT is fully differentiable, flexibly combines with other objectives, and guarantees that the target causal model is a causal abstraction of the neural model when its loss is zero. We evaluate IIT on a structural vision task (MNIST-PVR), a navigational language task (ReaSCAN), and a natural language inference task (MQNLI). We compare IIT against multi-task training objectives and data augmentation. In all our experiments, IIT achieves the best results and produces neural models that are more interpretable in the sense that they more successfully realize the target causal model.

Tags

ai-safety (imported, 100%) · empirical (suggested, 88%) · mechanistic-interp (suggested, 92%)

Links

Open PDF directly →

Full Text

66,964 characters extracted from source content.


Inducing Causal Structure for Interpretable Neural Networks Atticus Geiger * 1 Zhengxuan Wu * 1 Hanson Lu * 1 Josh Rozner 1 Elisa Kreiss 1 Thomas Icard 1 Noah D. Goodman 1 Christopher Potts 1 Abstract In many areas, we have well-founded insights about causal structure that would be useful to bring into our trained models while still allowing them to learn in a data-driven fashion. To achieve this, we present the new method ofinterchange intervention training(IIT). In IIT, we (1) align variables in a causal model (e.g., a deterministic program or Bayesian network) with representa- tions in a neural model and (2) train the neural model to match the counterfactual behavior of the causal model on a base input when aligned representations in both models are set to be the value they would be for a source input. IIT is fully differentiable, flexibly combines with other objectives, and guarantees that the target causal model is acausal abstractionof the neural model when its loss is zero. We evaluate IIT on a struc- tural vision task (MNIST-PVR), a navigational language task (ReaSCAN), and a natural language inference task (MQNLI). We compare IIT against multi-task training objectives and data augmen- tation. In all our experiments, IIT achieves the best results and produces neural models that are more interpretable in the sense that they more successfully realize the target causal model. 1. Introduction In many domains, we have well-founded insights about causal structure that we can express in symbolic terms, ranging from commonsense intuitions about how the world works to advanced scientific knowledge. These insights have the potential to make up for gaps in available data, or more generally to provide useful inductive biases. Can we bring these insights into our models while still allowing them to learn in a data-driven fashion? * Equal contribution 1 Stanford University, Stanford, California. Correspondence to: Atticus Geiger<atticusg@stanford.edu>, Zhengxuan Wu<wuzhengx@stanford.edu>. Proceedings of the39 th International Conference on Machine Learning, Baltimore, Maryland, USA, PMLR 162, 2022. Copy- right 2022 by the author(s). In this paper, we presentinterchange intervention training (IIT), a new method that trains a neural network to realize the abstract structure of a causal model. In IIT, we (1) align the variables in a causal modelCwith the representations in a neural modelNand (2) trainNto have the counter- factual behavior ofCby performing alignedinterchange interventions(swapping of internal states created for differ- ent inputs) onNusingC’s counterfactual output as the gold label for the counterfactual prediction ofN. IIT objectives are differentiable and guarantee that, when the loss is zero, the target causal model is acausal abstractionof the neural network in the sense of Beckers & Halpern (2019). IIT is an extension of the causal abstraction analysis of Geiger et al. (2021), which can be placed under the broader rubric ofstructural evaluationsof neural models, which includes probing and many kinds of feature attribution. Our central point of differentiation from this prior work is that we go beyond passive study of static models, by pushing them to learn specific causal structures as part of optimization. This allows for a productive interplay between model anal- ysis and model improvement: we not only assess whether models have systematic, interpretable internal structure but also push them to acquire such structure. 
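Before the formal treatment in Section 3, the core operation can be illustrated with a toy sketch (plain Python written for this summary, not code from the paper): a symbolic causal model with intermediate variables V1 = b1, V2 = b2 and output O = V1 AND V2, in the spirit of Figure 1 later in the text. An interchange intervention runs the model on a base input but fixes one variable to the value it would take for a source input.

    def get_vals(inputs, variable):
        # GETVALS for the toy model: V1 and V2 simply copy the two boolean inputs.
        return inputs[0] if variable == "V1" else inputs[1]

    def run_model(inputs, intervene=None):
        # Structural equations V1 = b1, V2 = b2, O = V1 AND V2, with optional overrides.
        intervene = intervene or {}
        v1 = intervene.get("V1", inputs[0])
        v2 = intervene.get("V2", inputs[1])
        return v1 and v2

    def interchange_intervention(base, source, variable):
        # Output on `base` with `variable` fixed to the value it takes for `source`.
        return run_model(base, intervene={variable: get_vals(source, variable)})

    # base = (True, True) alone yields True, but intervening on V1 with
    # source = (False, True) yields the counterfactual output False.
    print(interchange_intervention((True, True), (False, True), "V1"))  # False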
We evaluate IIT in three contexts: (1) ResNet trained on a vision task where one part of an image “points” to another (MNIST-PVR), (2) a CNN-LSTM model trained to produce action sequences in a grid world given a natural language command (ReaSCAN), and (3) a pretrained BERT model fine-tuned to label the semantic relation between two sen- tences (MQNLI). For each context, we define a high-level causal model that capture aspects of the task. We then align high-level causal variables to low-level neural representa- tions to define IIT training objectives. For the three case studies, we report two kinds of evaluation: traditional behavioral evaluations using systematic general- ization tasks that assess whether a model has learned a truly general solution, and structural evaluations that directly as- sess the interpretability of our models by evaluating whether they realize the target causal model. We compare IIT against multi-task training objectives and data augmentation meth- ods defined to make use of our causal models, finding that IIT leads to models that both perform better on systematic arXiv:2112.00826v2 [cs.LG] 20 Jul 2022 Inducing Causal Structure for Interpretable Neural Networks generalization benchmarks and are more interpretable. 1 2. Related Work ProbesProbes are supervised or unsupervised models that can be used to gain an understanding of what is encoded in the internal representations of neural networks (Hupkes et al., 2018; Peters et al., 2018; Tenney et al., 2019; Clark et al., 2019). Probes have yielded important insights about what models learn to encode. However, probes are funda- mentally limited in a way that is central to our present goals: there is no guarantee that probed information plays a causal role in the network’s behavior (Ravichander et al., 2020; Elazar et al., 2020; Geiger et al., 2021; 2020). Feature AttributionIn contrast to probes, gradient-based feature attribution methods (Zeiler & Fergus, 2014; Sprin- genberg et al., 2014; Shrikumar et al., 2016; Binder et al., 2016) generally do measure causal properties (Chattopad- hyay et al., 2019). For example, Geiger et al. (2021) note that the integrated gradients method of Sundararajan et al. (2017) computes theindividual causal effectof neurons (Im- bens & Rubin, 2015). In comparison with our proposal, the main limitation of these methods is that (by definition) they passively study trained networks rather than allowing for active improvements of them (though see Erion et al. 2021 for a path from attribution to improved optimization). Intervention-Based AnalysesIn intervention-based analy- sis, one actively changes the values of model representations in systematic ways and studies the effects. Such interven- tions can be applied to input representations in order to measure the effect on the output representation (Feder et al., 2021; Pryzant et al., 2021), or on network internal represen- tations to characterize how these representations mediate the causal relationships between inputs and outputs (Giulianelli et al., 2018; Bau et al., 2019; Vig et al., 2020; Soulos et al., 2020; Ravfogel et al., 2020; Elazar et al., 2020; Besserve et al., 2020; Geiger et al., 2020; Csord ́ as et al., 2021; Geiger et al., 2021; Meng et al., 2022). In the context of neural network analysis, this provides a powerful tool-kit for un- derstanding a model’s causal structure, since an enormous number of diverse and finely controlled intervention ex- periments can be performed. 
We build on these methods, extending them to the optimization process. Multi-Task TrainingMulti-task training is the practice of jointly training a model against a set of learning tasks to im- prove data efficiency and increase model robustness (Ruder, 2017; Zhang & Yang, 2017; Crawshaw, 2020). This can be thought of in terms of supervised probing. In standard supervised probing, one trains the probe using internal rep- resentations from the target model while keeping the target 1 We release our code athttps://github.com/frankaging/ Interchange-Intervention-Training. model frozen. In multi-task training, we allow the target model’s parameters to be changed by the probing process. This provides a natural point of comparison with our pro- posal for IIT, where we use our target symbolic causal model to define multi-task training objectives. Data AugmentationData augmentation is the practice of enhancing training sets by modifying existing examples to generate new ones (Perez & Wang, 2017; Shorten & Khosh- goftaar, 2019; Kaushik et al., 2019; Liu et al., 2021). For us, data augmentation is another natural comparison point because we can use a target symbolic causal model to gen- erate additional data. Crucially, IIT involves interchanging internal network representations, while data augmentation methods only involve the creation of inputs. 3. Interchange Intervention Training Our goal is to train a neural network to have an internal causal structure that realizes a high-level causal model. To concretize this goal, we draw on two strands of work on causality: (1) formal interventionist theories of causality (Spirtes et al., 2001; Pearl, 2009), in which causal processes are associated with the effect of interventions, and (2) theo- ries of abstraction (Beckers & Halpern, 2019; Beckers et al., 2020; Chalupka et al., 2016; Rubenstein et al., 2017), where relationships between two causal processes are determined by the presence of systematic correspondences between the effects of interventions. The key insight is that having a par- ticular causal structure is a matter of satisfying a number of counterfactual statements about the effect of interventions (Hitchcock, 2001). The present section defines this process formally, and Figure 1 illustrates all the concepts with a self-contained example. Structural Causal Models We introduce a minimal nota- tion for structural causal models here. We define a structural causal modelMto consist of variablesV, and, for each variableV∈ V, a set of valuesVal(V), a set of parents PA V , and a structural equationF V that sets the value ofV based on the setting of its parents. We denote the set of variables with no parents asV In and those with no children V Out . A structural causal modelM= (V,PA,Val,F)can represent both symbolic computations and neural networks. Given a setting of aninput∈Val(V In )and variables V⊆ V, we defineGETVALS(M,input,V)∈Val(V) to be the setting ofVdetermined by the settinginputand modelM. For example,Vcould correspond to a layer in a neural network, andGETVALS(M,input,V)then de- notes the particular values thatVtakes on when the model Mprocessesinput. For a set of variablesVand a setting for those variables v∈Val(V), we defineM V←v to be the causal model identical toM, except that the structural equations forV Inducing Causal Structure for Interpretable Neural Networks are set to constant valuesv. Because we overwrite neurons withvin-place, gradients can back-propagate throughv. 
This corresponds closely to the do operator of Pearl (2009), which characterizes interventions on models in the service of exploring hypothetical or counterfactual states.

Interchange Interventions. With the above definitions in place, we can straightforwardly characterize the interchange interventions of Geiger et al. (2020), in which a model M is used to process two different inputs, source and base, and then a particular internal state obtained by processing source is used in place of the corresponding internal state obtained by processing base. For a given set of variables V, M_{V ← GETVALS(M, source, V)} is a version of M with the values of V set to those obtained by processing source. In addition, GETVALS(M, base, V_Out) is the setting of the outputs V_Out obtained by processing base with model M. Putting these two steps together, we obtain the interchange intervention:

    INTINV(M, base, source, V) := GETVALS(M_{V ← GETVALS(M, source, V)}, base, V_Out)    (1)

In short, the interchange intervention provides the output of the model M for the input base, except that the variables V are set to the values they would have if source were the input.

Causal Abstraction Relationships. Suppose we have a high-level model M_H and a low-level model M_L with identical input spaces and a predetermined mapping κ of output values from the low level to the high level (for example, if the low-level model produces a probability distribution over output classes, then κ could be the argmax function, which selects the highest-probability class). Further suppose we have an alignment Π mapping intermediate variables in V_H to non-overlapping subsets of variables in V_L. Consider some intermediate variable V_H and define M*_H to be M_H with every variable marginalized other than V_In, V_Out, and V_H. We can use the definition of interchange interventions to define what it means for M_L and M*_H to be in a causal abstraction relationship, namely, for all b, s ∈ V_In:

    INTINV(M*_H, b, s, V_H) = κ( INTINV(M_L, b, s, Π(V_H)) )    (2)

This is in fact a constructive abstraction relationship in the sense of Beckers & Halpern (2019), in which aligned interventions on the low-level model and the high-level model have the same effect. It is especially suited to situations in which we seek to relate small symbolic models to large neural models with high-dimensional representations.

Abstraction and Interpretability. Causal abstraction analysis is not a story about the reasoning a neural network might use to achieve its behavior, but instead an intervention-based method that determines how it does, in fact, achieve its behavior. We can interpret the semantic content of neural representations using the high-level variables they are aligned with, and understand how those neural representations are composed using the high-level parenthood relation. Simply put, when a high-level causal model is an abstraction of a neural network, it is a faithful interpretation (Lipton, 2018; Jacovi & Goldberg, 2020) of the network.

Interchange Intervention Accuracy. To quantify partial success at causal abstraction, we measure the percentage of aligned interchange interventions that produce the same output, reporting this as the interchange intervention accuracy (INTINVACC):

    INTINVACC(M_H, M_L, V_H, Π) := (1 / |Val(V_In)|^2) Σ_{b, s ∈ Val(V_In)} I[ INTINV(M*_H, b, s, V_H) = κ( INTINV(M_L, b, s, Π(V_H)) ) ]    (3)

When every pair of inputs b and s is considered and INTINVACC is 1, the two models are in the causal abstraction relationship. However, we often only approximate this by evaluating a set of randomly sampled pairs of inputs, due to the enormous space of input pairs.

INTINVACC provides a natural metric for quantifying the interpretability of a neural network in the following sense: when INTINVACC is 1, the causal model is an explanation of how the network behaves, providing a clear window into the network itself. In practice, we rarely observe perfect INTINVACC in complex networks, but we can still say that the higher the value of INTINVACC, the more license we have to reason about the high-level causal model instead of reasoning directly about the low-level network. The causal model provides an interpretable proxy for the network itself.

IIT Loss Functions. The definition of IIT for high-level models with one intermediate variable falls out directly from the causal abstraction definition:

    Σ_{b, s ∈ V_In} LOSS( INTINV(C, b, s, V), INTINV(N_θ, b, s, Π(V)) )    (4)

where C is the high-level causal model, V is a high-level variable, N_θ is the low-level neural network with learned parameters θ, Π(V) is the set of low-level variables (neurons) aligned with V, and LOSS is some loss function. Observe that we do not apply the output map κ, because the loss function takes in the logits directly.

The crucial feature of an IIT update is that the interchange intervention intertwines two computation graphs, one generated by the forward pass for the base input and one by the forward pass for the source input. This means that when backpropagation is performed with the IIT loss objective, updates are applied as they are in regular training, starting from the output representation and proceeding towards the input representations. However, when the intervention site is reached, this process bifurcates, and weights receive two updates, once from N_θ processing the input base and once from N_θ processing source. In our toy example (figure 1c), the network is too small to observe this double update, but the networks in our three case studies are not. (See figure 2, which exemplifies such a process.)

Figure 1. Interchange intervention training example. (a) A linear network with unspecified weights (left) and a symbolic causal model that computes boolean conjunction (right), with an alignment between the two denoted by dashed lines. The causal model is an abstraction of the network when, for both V_1 and V_2, aligned interchange interventions on the network and the causal model produce the same output on all 16 ordered pairs of inputs (the aligned intervention pair (b, s) in general differs from (s, b)). (b) A network with initial parameters W_1 = [0.45, 0.05], W_2 = [0.05, 0.5], output bias b = -1, and output weights w = [1, 1]; input values are 0 for False and 1 for True. With these weights, the network has perfect behavioral accuracy, predicting True iff both inputs are 1, but its interchange intervention accuracy is 81.25%: across the two high-level variables V_1 and V_2, there are six ordered pairs of inputs on which aligned interchange interventions make the causal model and the neural network produce different outputs (figure 1c shows one such pair). (c) An interchange intervention training update, where the intervened network is trained to predict the intervened output of the causal model; the intervention puts the network in a state that could not be reached with any input representation. (d) Pseudocode for interchange intervention training:

    Require: high-level and low-level models M_H and M_L with variables V_H and V_L,
             an alignment Π mapping each V_H ∈ V_H to a subset V_L ⊆ V_L, training dataset D
    M_H.eval(); M_L.train()
    while not converged:
        for (b, s) in D × D:                      // base and source
            sample a high-level variable V_H
            V_L = Π(V_H)                          // aligned low-level variables
            with no_grad:
                a_H = GETVALS(M_H, s, V_H)
                o_H = GETVALS(M_H[V_H ← a_H], b, V_Out)    // label
            a_L = GETVALS(M_L, s, V_L)
            o_L = GETVALS(M_L[V_L ← a_L], b, V_Out)        // prediction
            L = LOSS(o_H, o_L) + L_Others         // combine with other losses
            L.backward()
            update model parameters with gradients

(e) The network from figure 1b after the IIT training update in figure 1c has been applied, resulting in 100% interchange intervention accuracy (though still nonzero loss) while maintaining the same behavior; the new parameters are W_1 = [0.5012, 0.05], W_2 = [0.05, 0.5512], bias b = -0.9488, and output weights w = [1.0231, 1.0256]. In sum, the network N_∧ performs boolean conjunction with perfect accuracy, i.e., it agrees with C_∧ on the four possible inputs (figure 1b), yet C_∧ is not a causal abstraction of N_∧ under this alignment, because some aligned interchange interventions lead N_∧ and C_∧ to produce different outputs; the interchange intervention accuracy (Eqn. 3) is 81.25%. After a single interchange intervention training update (figures 1c and 1d), all aligned interchange interventions produce the same output (the interchange intervention accuracy is now 1), so C_∧ has become a causal abstraction of N_∧ (figure 1e).

An important property of IIT is that, if Eqn. 4 is zero, then C and N_θ stand in the causal abstraction relation of Eqn. 2. See Appendix A for a brief proof of this result. (The reverse does not hold: C can be a causal abstraction of N_θ without the loss being zero, as in figure 1e. This is a desirable property of the method, since we do not expect our loss functions to be zero in general.)

Example. Figure 1 provides an example of interchange intervention training, in which a causal model C_∧ of boolean conjunction is aligned with a one-layer linear network N_∧^θ, where θ = (W_1, W_2, b, w), as in figure 1b. At the start, N_∧^θ is perfect in terms of its input–output behavior but does not conform to the counterfactual behavior of C_∧.
In other words, the regular behavioral learning objective is met, but the interchange intervention training objective is not; interchange intervention accuracy (Eqn. 3) is81.25. One interchange intervention training update (figure 1c) re- sults in a network that satisfies both objectives (figure 1e): N θ ∧ now stands in the causal abstraction relation toC ∧ (in- terchange intervention accuracy is now1). 4. MNIST Pointer-Value Retrieval Our first benchmark is MNIST Pointer-Value Retrieval (MNIST-PVR; Zhang et al. 2021), a visual reasoning task constructed using the MNIST dataset (LeCun et al., 2010). An inputi= (i TL ,i TR ,i BL ,i BR )consists of four MNIST images (handwritten digits) arranged in a grid. The top left imagei TL acts as a pointer that picks out one of the three other images. Symbolic Causal StructureOur target causal model will abstract away from the details of how to identify the hand- written digit in an image, focusing just on the reasoning about pointers. Formally, we define a causal modelC PVR = (V,PA,Val,F)that computes the label for each of the four MNIST images using an oracleO MNIST with a look-up table to select the correct label based on the pointer. The vari- ables areV=I TL ,I TR ,I BL ,I BR ,Y TL ,Y TR ,Y BL ,Y BR ,O and the values assigned byValare the MNIST training images for the four input variablesI TL ,I TR ,I BL ,I BR , and the set of numbers 0–9 for all other variables. The parents are defined such thatPA I w =∅andPA Y w =I w for all w∈ TR,TL,BR,BL, andPA O =Y TL ,Y TR ,Y BL ,Y BR . The structural equations are F Y TL (i TL ) =O MNIST (i TL ) F Y TR (i TR ) =O MNIST (i TR ) F Y BL (i BL ) =O MNIST (i BL ) F Y BR (i BR ) =O MNIST (i BR ) F O (y TL ,y TR ,y BL ,y BR ) =      y TR y TL ∈0,1,2,3 y BL y TL ∈4,5,6 y BR y TL ∈7,8,9 Systematic Generalization The train/test split designed by Zhang et al. (2021) creates a distributional shift between the training and testing data by removing training exam- ples where eitherO MNIST (i TR )∈1,2,3,O MNIST (i BL )∈ 4,5,6, orO MNIST (i BR )∈ 0,7,8,9. This evaluates where models can systematically generalize, learning the general structure of the problem rather than memorizing many special cases. Neural NetworkWe trained ResNet18 from PyTorch vi- sion. This is the deep residual network (He et al., 2016) baseline used by Zhang et al. (2021) on the MNIST-PVR dataset, and we adopt their hyperparameters. We call this modelN θ PVR , whereθabbreviates the parameters. AlignmentsIn our experiments, we align the neural repre- sentations ofN θ PVR with the symbolic variables ofC PVR by partitioning the layer resulting from the first application of max-pooling into quadrantsQ TL ,Q TR ,Q BL ,Q BR which are aligned with the variablesY TL ,Y TR ,Y BL ,Y BR . In initial ex- perimentation, we found that the layers must be partitioned such that each quadrant is directly above its corresponding input. This is likely due to the locality of convolution op- erators. We also found that aligning layers closer to the classifier head was ineffective. Interchange Intervention TrainingFor each intermediate variableY w ∈ Y TL ,Y TR ,Y BL ,Y BR , we introduce an IIT objective that optimizes forN θ PVR implementingC w PVR the submodel ofC PVR where the three intermediate variables that aren’tY w are marginalized out: ∑ b,s∈MNIST-PVR CE ( INTINV(C w PVR ,b,s,Y w ), INTINV(N θ PVR ,b,s,Q w )) ) (5) whereCEis the cross-entropy loss andMNIST-PVRis the dataset. We visualize an IIT update toN θ PVR in figure 2. 
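To make the mechanics of an update like Eqn. 5 concrete, here is a minimal sketch of a single IIT step in PyTorch. It assumes activations shaped [batch, features] and a single aligned slice, and the names model, layer, align_slice, and causal_model.counterfactual_output are hypothetical helpers standing in for the aligned network layer and the high-level causal model's counterfactual output; this is an illustration of the idea, not the authors' released implementation, and it omits sampling among multiple aligned variables and combining with other objectives.

    import torch.nn.functional as F

    def iit_update(model, layer, align_slice, causal_model, base, source, optimizer):
        # One IIT step: patch the aligned slice of `layer`'s activations computed on
        # `source` into the forward pass on `base`, then train the intervened output
        # to match the causal model's counterfactual label.
        captured = {}

        def capture(module, inputs, output):
            captured["source"] = output  # keep the graph so gradients reach the source pass

        def patch(module, inputs, output):
            patched = output.clone()
            patched[:, align_slice] = captured["source"][:, align_slice]
            return patched  # a returned value replaces the module's output

        handle = layer.register_forward_hook(capture)
        model(source)                     # forward pass on the source input
        handle.remove()

        handle = layer.register_forward_hook(patch)
        logits = model(base)              # forward pass on base with the slice overwritten
        handle.remove()

        # The high-level causal model supplies the counterfactual gold label
        # (hypothetical helper standing in for INTINV on the causal model).
        gold = causal_model.counterfactual_output(base, source)

        loss = F.cross_entropy(logits, gold)
        optimizer.zero_grad()
        loss.backward()                   # gradients flow through both computation graphs
        optimizer.step()
        return loss.item()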
Typed Interchange Intervention TrainingWe make fur- ther use of the causal model by observing that the intermedi- ate variablesY TL ,Y TR ,Y BL ,Y BR can be treated as the same type. They all share a value space, as do the neural represen- tationsQ TL ,Q TR ,Q BL ,Q BR . This means we can perform interchange interventions between different variables and Inducing Causal Structure for Interpretable Neural Networks 99 9 7 7 7 0 7 2 2 LOGITSLOGITS Figure 2.An illustration of an IIT update where a neural network (right) is trained to realize a causal model (left) that solves the PVR- MNIST task. Solid lines are feed-forward connections, dashed lines are interchange interventions, red lines are the flow of backpropagation. Observe that when backpropagation reaches the interchange intervention, it flows into both the source input’s computation graph and the base input’s graph, updating the weights below the interchange intervention twice. extend our training objective to these interventions as well: T-INTINV(M,b,s,V,V ′ ) def = GETVALS(M V ′ ←GETVALS(M,s,V) ,b,V Out )(6) ∑ w,w ′ ∈TL, TR, BL, BR b,s∈PVR-MNIST CE ( T-INTINV(C PVR ,b,s,Y w ,Y w ′ ), T-INTINV(N θ PVR ,b,s,Q w ,Q w ′ ) ) (7) Multi-Task Objectives To compare against multi-task ob- jectives, we train models to predict the value of intermediate variables from the aligned neural representations, backprop- agating into the weights of the target model. Specifically, we train four linear classifiersP φ w on the loss ∑ input∈PVR-MNIST w∈TL,TR,BL,BR CE(P φ w (GETVALS(N θ PVR ,input,Q w )), GETVALS(C PVR ,input,Y w ))(8) where the trained parameters areθ, the parameters of ResNet andφ w , the parameters of the linear classifiers. Data AugmentationWe perform data augmentation by ran- domly sampling two examples and swapping a random quad- rant of the base input with a random quadrant of the source input to produce a new example that is then labeled with C PVR . This procedure is guided by the same causal structure used by our other models, but it is by definition restricted to input manipulations. ResultsOur results are in Table 1. The behavioral accuracy is the standard metric, while the interchange intervention accuracy captures whether the symbolic causal model is an abstraction of the neural network. Training Behavioral Accuracy IIT Accuracy TrainTestTrainTest STANDARD99.100.00 88.80 20.60 IIT99.60 93.93 99.00 94.85 MULTI99.640.00 89.35 20.50 IIT + MULTI99.6096.0199.1096.64 AUGMENT99.40 90.90 98.90 92.00 NOTYPING99.410.09 99.47 16.88 Table 1.Results forN θ PVR (ResNet18) trained on the PVR-MNIST dataset. Behavioral accuracy is the percentage of inputs thatN θ PVR agrees withC PVR on. Interchange intervention accuracy quantifies the extent to which the interpretable causal model is a proxy for the network (section 3). IIT delivers the best results, especially when combined with multi-task objectives. Neither the standard nor multi-task models learned the be- havioral objective in a way that generalizes, with total failure on the testing data (0%). On the other hand, IIT solves the generalization task (93.93%). However, multi-task training does synergize with IIT, producing the model with the best performance (96.01%). Data augmentation lessens the dis- tributional shift; however, the distributions remain skewed and model performance does not exceed 90.90%. Our interchange intervention test set accuracies tell a similar story. 
Neither the standard nor multi-task models learned the IIT objective in a way that generalizes, with total failure on the testing data (20.60% and 20.50%, respectively). On the other hand, IIT learns a general solution to the inter- change intervention objectives, achieving accuracy on the Inducing Causal Structure for Interpretable Neural Networks I Command I World T Size T Color T Shape P a P t P x ∆ P y ∆ a 1 a 2 ... a n COMMANDGRID WORLD Bi-LSTM CNN e 1 ... e Shape ... e n e c a 0 h 0 a 1 e c a 1 h 1 a 2 ... ... ... e c a n−1 h n a n (a) The causal model that solves ReaSCAN (left) and the neu- ral CNN-LSTM model trained on ReaSCAN (right). Dashed lines align variables in the causal model with neural repre- sentations. Training Behavioral Exact Match % Novel colorNovel sizeNovel directionNovel length STANDARD55.98 (6.31)41.67 (6.24)0.00 (0.00)5.72 (3.44) MULTI76.91 (5.02)39.46 (7.68)0.00 (0.00)9.05 (5.28) IIT74.12 (6.00)65.65 (4.26)0.26 (0.14)10.20 (6.08) IIT+ MULTI80.37(0.88)74.84(0.04)14.72(3.54)25.82(0.37) Interchange Intervention Exact Match % STANDARD44.26 (2.76)35.57 (2.64)0.00 (0.00)0.30 (0.21) MULTI68.42 (0.20)45.83 (2.45)0.00 (0.00)0.19 (0.05) IIT70.63 (9.33)65.18 (2.84)5.24 (3.07)4.75 (2.06) IIT+ MULTI70.73(6.86)75.34(0.91)11.79(2.57)8.49(1.53) (b) Results for the CNN-LSTM on the ReaSCAN systematic generalization tasks. Only models that use IIT are able to consistently get traction on these tasks, and once again we see that IIT combines effectively with multi-task objectives, in both standard behavioral evaluations and evaluations that seek to quantify the extent to which the high-level causal model serves as an interpretable proxy for the network. Figure 3. test data (94.85%). Again, multi-task training synergizes with IIT, producing the model with the best performance on the IIT objective (96.64%). The causal modelC PVR is a near perfect abstraction of our best model, meaning the seemingly opaque and complex network dynamics have an interpretable and faithful abstract structure given byC PVR . We can see that Resnet has an inherently modular architec- ture from the fact that standard training produces a model with quite high (88.80%) interchange intervention accuracy on the training data. However, without any structural train- ing objectives, ResNet does not generalize this modular solution to test data (20.60%). We believe this modularity is the result of convolutions being operations that preserve lo- cality of information across layers. When the distributional shift between training and testing is lessened by data aug- mentation, the ResNet model produces a model with near perfect (98.90%) interchange intervention accuracy on the training data, which generalizes better to test data (92.00%) (but is still out performed by IIT). Without our typed IIT objectives, behavioral and interchange intervention accuracy plummets on the test data. Typing our variables is crucial for generalization. 5. Navigation and Language (ReaSCAN) Our second benchmark is ReaSCAN (Wu et al., 2021), a synthetic command-based navigation task that builds off the SCAN (Lake & Baroni, 2018) and gSCAN (Ruis et al., 2020) benchmarks. The goal is to predict an action sequence for the agent to reach the referred target and operate on it given a command and a grid world. For simplicity, we experiment with the simplest command structure included in ReaSCAN, which excludes any relative clauses. 
Symbolic Causal StructureOur causal modelC ReaSCAN = (V,PA,Val,F)(see figure 3a bottom) is an oracle solver for ReaSCAN that (1) parses the language command, identify- ingsize,color, andshapeproperties of the target shape, (2) computes the location of the target object from these prop- erties and the grid world, (3) calculates the horizontal and vertical distances from the agent to the target, and, finally, (4) emits an action sequence that brings the agent to the target (We condense the action sequence to a single output variable). Formally, we define variables and values V=I Com ,I World ,T Size ,T Color ,T Shape ,P t ,P a ,P x ∆ ,P y ∆ ,O Val(T Shape ) =circle,square,cylinder Val(T Color ) =red,green,blue,yellow Val(T Size ) =small,big Val(P t ) =Val(P a ) =Val(P x/y ∆ ) =−5,...,5 with the valuesVal(I Com ),Val(I world ), andVal(O)being equal to the command space, world space, and action se- quence space. The parents are defined according to the topology of directed arrows pointing from parents to chil- dren in figure 3a. The structural equations for object properties,F T Size (i Com ), F T Color (i Com ),andF T Shape (i Com ),are determined by parsing and interpreting the input language com- mand.The structural equations for position look-ups F P t (t Size ,t Color ,t Shape ,i World )andF P a (i World )determine the target object and agent location from the target object proper- ties and the input world. The position deltasF P x ∆ (p t ,p a )and F P y ∆ (p t ,p a ) are determined to be the horizontal and vertical distance between the target object and agent, respectively. Finally, the equation for the outputF O (P x ∆ ,P y ∆ ,i command ) is the action sequence that takes the agent to the target object in the correct manner of movement, as determined by the Inducing Causal Structure for Interpretable Neural Networks vertical and horizontal distances between the two and the adverb in the command. Systematic Generalization ReaSCAN includes testing ex- amples that are systematically different from training exam- ples. Performance on those test sets provides insights into a model’s capabilities to generalize to unseen composites of seen concepts in a zero-shot fashion. In this experiment, we generate four unseen testing splits investigating two dis- tinct generalization patterns by adapting ReaSCAN’s data generation framework. We investigate two splits focusing on novel attribute compositions in input commands (Novel colorandNovel size), and two splits focusing on novel compositions in output action sequences (Novel direction andNovel length). See Appendix B for details on splits. Neural NetworkWe use the original baseline model for ReaSCAN (Wu et al., 2021) as our neural modelN θ CNN-LSTM . N θ CNN-LSTM is a multimodal sequence-to-sequence model which takes in a command and a grid world, and predicts an action sequence as shown in figure 3a. We include details about the model and experimental set-up in Appendix B. AlignmentsIn our experiments, we align neural represen- tations ofN θ CNN-LSTM with the variablesT Size ,T Color ,T Shape , P x ∆ , andP y ∆ , inC ReaSCAN . We choose the neural represen- tatione Shape output by the LSTM encoder above the noun token (e.g., “circle”), which has 75 dimensions, to be evenly partitioned into three chunks of 25 dimensions, which are aligned with the target propertiesT Size ,T Color , andT Shape . 
For the position deltas, we choose the initial hidden represen- tationh 0 of the decoder LSTM, which has 100 dimensions, to be sliced into two evenly partitioned 50 dimension chunks where the first chunk represents the position difference by rowP y ∆ , and the second chunk represents the position differ- ence by columnP x ∆ . We train the network to derive from the world and the command the horizontal and vertical distances between the target and agent, storing the horizontal distance in one half ofh 0 and the vertical distance in the other. Interchange Intervention TrainingFor each variableV inC ReaSCAN aligned with neuronsN V inN θ CNN-LSTM , we introduce an IIT objective that optimizes forN θ CNN-LSTM implementing the marginalized submodelC V ReaSCAN : ∑ b,s∈REASCAN CE Action ( INTINV(N θ CNN-LSTM ,b,s,N V ), INTINV(C V ReaSCAN ,b,s,V) ) (9) whereCE Action is the cross-entropy loss over each action token prediction over the complete action sequence. Multi-task ObjectivesSimilar to MNIST-PVR, we train small models to predict the position offsets between the target and the agent from the aligned neural representations. Specifically, for eachV∈ T Size ,T Color ,T Shape ,P x ∆ ,P y ∆ , we train a single-layer linear classifierP φ V on the loss ∑ i∈ReaSCAN CE Position ( P φ V (GETVALS(N θ CNN-LSTM ,i,N V ), GETVALS(C ReaSCAN ,i,V) ) (10) where the trained parameters areθ, the parameters of the CNN-LSTM, andφ V , the parameters of the classifiers. Results Our results are shown in Table 3b. We use exact matches of action sequences as our evaluation metric for the behavioral and interchange intervention tasks. We begin with our results on the behavioral task. Standard training produces models that fail to generalize across all four tasks. IIT alone out-performs multi-task training on novel sizes and lengths, and performs similarly on novel colors and lengths. Again, we observe that IIT and multi- task synergize, producing the models that best generalize across all tasks. Overall, IIT is essential to this systematic generalization task. Our interchange intervention accuracy results suggest that IIT delivers models that best conform to the interpretable causal model. Without any IIT objectives, both the standard and multi-task models achieve non-zero interchange inter- vention accuracy only for the two easier splits: novel colors and novel size. IIT achieves significant improvements over these two tasks and gets traction on the two more difficult ones, novel direction and novel length. And, once again, combining IIT with multi-task training delivers the best model by wide margins on all four tasks. 6. Natural Language Inference (MQNLI) Our final benchmark is MQNLI (Geiger et al., 2019), a synthetic natural language inference dataset where the task is to label the semantic relation between two sentences as enailment,contradiction, orneutral. Here is an example: εeveryεbakerε εhappily eatsεsome stale bread contradiction εsome angry baker does notεeatεsomeεbread where anεdenotes the absence of a word and is used to align corresponding words in the two sentences. Geiger et al. (2020) fine-tuned a BERT model on MQNLI, achieving state-of-the-art results (≈90% test accuracy). Their interchange intervention analysis revealed that this model learns to partially represent the relation between aligned subphrases in the two sentences (e.g.stale bread andεbreadin the example above) We hypothesize that if weteachBERT to fully represent this information using IIT, the task will be solved perfectly. 
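The interchange intervention accuracies reported in these experiments are necessarily estimates: as noted in Section 3, the space of (base, source) pairs is far too large to enumerate, so accuracy is computed over sampled pairs. A minimal sketch of such an estimator, assuming hypothetical helpers intervened_output (the network's output under an aligned interchange intervention, as in the update sketch above) and causal_model.counterfactual_output:

    import random

    def estimate_intinv_acc(intervened_output, causal_model, dataset, num_pairs=1000):
        # Fraction of sampled (base, source) pairs on which the intervened network
        # output matches the high-level causal model's counterfactual output
        # (Eqn. 3, approximated by sampling instead of enumerating all pairs).
        hits = 0
        for _ in range(num_pairs):
            base, source = random.choice(dataset), random.choice(dataset)
            prediction = intervened_output(base, source)   # e.g. argmax of intervened logits
            gold = causal_model.counterfactual_output(base, source)
            hits += int(prediction == gold)
        return hits / num_pairs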
Symbolic causal structure Geiger et al. (2020) define a causal model that computes Inducing Causal Structure for Interpretable Neural Networks the relation between aligned phrases in order to compute the relation between two sentences. We narrow our focus to the submodelC QP Obj NatLog , which (1) contains a single interme- diate variableQP Obj that computes the relation between the quantified verb phrase(= adverb + verb + quantified object noun phrase) of each sentence, and (2) uses this to infer the relation between the two sentences. To label the MQNLI example above,C QP Obj NatLog would compute thatQP Obj =@ because “happily eatsεsome stale bread”entails“eatε someεbread”. Then, this information is used infer that the relation between the sentences iscontradiction. Systematic GeneralizationThe train-test split of MQNLI is constructed to be as difficult as possible while still being solved by a compositional memorization-based learning model. This makes the task hard, but fair. Neural NetworkWe fine-tune a pretrained BERT model N θ NLI on MQNLI. The architecture consists of 12 trans- former layers that create a neural representations for each token in the input; the grid of neural representations has a column for each token and a row for each layer. AlignmentsIn our experiments, we alignQP Obj with the neural representations above the verb in the first sentence from BERT layers0, 2, 4, 6, 8, 10. Interchange intervention trainingForQP Obj inC QP Obj NatLog aligned with neuronsNinN θ NLI , we introduce an IIT objec- tive that optimizes forN NLI implementing the marginalized submodelC QP Obj NatLog : ∑ b,s∈MQNLI CE ( INTINV(N θ NLI ,b,s,N), INTINV(C QP Obj NatLog ,b,s,QP Obj ) ) (11) Multi-task Objectives We train a linear classifierP φ to predict the value of QP Obj with loss: ∑ i∈MQNLI CE ( P φ (GETVALS(N θ NLI ,i,N), GETVALS(C QP Obj NatLog ,i,QP Obj ) ) (12) where the trained parameters areθ, the parameters of BERT, andφ, the parameters of the classifier. Data AugmentationWe perform data augmentation by ran- domly sampling two examples and replacing the quantified verb phrase from the first example with those from the sec- ond in order to produce a new example that is then labeled withC QP Obj NatLog . This procedure is guided by the same causal structure used by our other models, but it is by definition restricted to input manipulations. ResultsWe compute the accuracy on the basic behavioral task of predicting MQNLI labels, and interchange interven- tion (IIT) accuracy, where we compute the percentage of Figure 4.Performance of a pretrained BERT natural language in- ference model fine-tuned on the MQNLI dataset with the causal modelC QP Obj NatLog from Geiger et al. (2020). We report the results on the evaluation set. While data augmentation leads to consis- tently excellent behavior accuracy (left) panel, it has very low interchange intervention accuracy. In other words, IIT is necessary for an interpretable model with high-performance. cases where performing an intervention on the neural model produced a same change in output as the submodelC QP Obj NatLog . Our results, shown in figure 4, demonstrate that IIT training onC QP Obj NatLog solves MQNLI with near perfect accuracy and IIT accuracy (≈100%). Furthermore, we see that aligning QP Obj to the first few layers of BERT results in lower ac- curacy on MQNLI, with layer 6 and onward resulting in near perfect accuracy and interchange intervention accuracy. 
WhenQP Obj is aligned with layer 6 or later, we again see multi-task training synergizing with IIT to produce the best models. Data augmentation results in models with perfect behavioral performance. This is unsurprising, as data augmentation removes the out-of-domain generalization problem. IIT is needed to produce a model with an interpretable solution. 7. Conclusion We introduced interchange intervention training as a method to imbue neural networks with interpretable, systematic causal structure, and we conducted three case studies with IIT: a vision task (MNIST-PVR), a grounded language un- derstanding task (ReaSCAN), and a natural language infer- ence task (MQNLI). In all settings, models trained with IIT perform best in standard (but very challenging) behavioral evaluations and prove to be the most interpretable in the sense that they conform best to our high-level causal models of the tasks. In addition, our results show that IIT is easily combined with multi-task objectives that further strengthen the results. These initial findings suggest that IIT is a flex- ible and powerful way to bring high-level insights about causal structure into a data-driven learning process. Inducing Causal Structure for Interpretable Neural Networks References Bau, D., Zhu, J., Strobelt, H., Zhou, B., Tenenbaum, J. B., Freeman, W. T., and Torralba, A. GAN dissec- tion: Visualizing and understanding generative adver- sarial networks. In7th International Conference on Learning Representations, ICLR 2019, New Orleans, LA, USA, May 6-9, 2019. OpenReview.net, 2019. URL https://openreview.net/forum?id=HygX2C5FX. Beckers, S. and Halpern, J. Y. Abstracting causal mod- els.Proceedings of the AAAI Conference on Artificial Intelligence, 33(01):2678–2685, Jul. 2019. doi: 10.1609/ aaai.v33i01.33012678. URLhttps://ojs.aaai.org/ index.php/AAAI/article/view/4117. Beckers, S., Eberhardt, F., and Halpern, J. Y. Approximate causal abstractions. In Adams, R. P. and Gogate, V. (eds.), Proceedings of The 35th Uncertainty in Artificial Intel- ligence Conference, volume 115 ofProceedings of Ma- chine Learning Research, p. 606–615, Tel Aviv, Israel, 22–25 Jul 2020. PMLR. URLhttp://proceedings. mlr.press/v115/beckers20a.html. Besserve, M., Mehrjou, A., Sun, R., and Sch ̈ olkopf, B. Counterfactuals uncover the modular structure of deep generative models.In8th International Conference on Learning Representations, ICLR 2020, Addis Ababa, Ethiopia, April 26-30, 2020. OpenReview.net, 2020. URL https://openreview.net/forum?id=SJxDDpEKvH. Binder, A., Montavon, G., Bach, S., M ̈ uller, K., and Samek, W. Layer-wise relevance propagation for neural networks with local renormalization layers.CoRR, abs/1604.00825, 2016. URLhttp://arxiv.org/abs/1604.00825. Chalupka, K., Eberhardt, F., and Perona, P. Multi-level cause-effect systems. In Gretton, A. and Robert, C. C. (eds.),Proceedings of the 19th International Conference on Artificial Intelligence and Statistics, volume 51 of Proceedings of Machine Learning Research, p. 361– 369, Cadiz, Spain, 09–11 May 2016. PMLR. URLhttp: //proceedings.mlr.press/v51/chalupka16.html. Chattopadhyay, A., Manupriya, P., Sarkar, A., and Bal- asubramanian, V. N. Neural network attributions: A causal perspective.In Chaudhuri, K. and Salakhut- dinov, R. (eds.),Proceedings of the 36th Interna- tional Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, p. 981– 990, Long Beach, California, USA, 09–15 Jun 2019. PMLR. URLhttp://proceedings.mlr.press/v97/ chattopadhyay19a.html. 
Clark, K., Khandelwal, U., Levy, O., and Manning, C. D. What does BERT look at? an analysis of BERT’s at- tention. InProceedings of the 2019 ACL Workshop BlackboxNLP: Analyzing and Interpreting Neural Net- works for NLP, p. 276–286, Florence, Italy, August 2019. Association for Computational Linguistics. doi: 10.18653/v1/W19-4828. URLhttps://w.aclweb. org/anthology/W19-4828. Crawshaw, M. Multi-task learning with deep neural net- works: A survey.CoRR, abs/2009.09796, 2020. URL https://arxiv.org/abs/2009.09796. Csord ́ as, R., van Steenkiste, S., and Schmidhuber, J. Are neural nets modular? inspecting functional modularity through differentiable weight masks. In9th Interna- tional Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021. Open- Review.net, 2021.URLhttps://openreview.net/ forum?id=7uVcpu-gMD. Elazar, Y., Ravfogel, S., Jacovi, A., and Goldberg, Y. Am- nesic probing: Behavioral explanation with amnesic coun- terfactuals. InProceedings of the 2020 EMNLP Work- shop BlackboxNLP: Analyzing and Interpreting Neural Networks for NLP. Association for Computational Lin- guistics, November 2020. doi: 10.18653/v1/W18-5426. Erion, G., Janizek, J. D., Sturmfels, P., Lundberg, S. M., and Lee, S.-I. Improving performance of deep learning models with axiomatic attribution priors and expected gradients.Nature Machine Intelligence, 3(7):620–631, 2021. doi: 10.1038/s42256-021-00343-w. URLhttps: //doi.org/10.1038/s42256-021-00343-w. Feder, A., Oved, N., Shalit, U., and Reichart, R. CausaLM: Causal Model Explanation Through Counterfactual Lan- guage Models.Computational Linguistics, p. 1–54, 05 2021. ISSN 0891-2017. doi: 10.1162/colia00404. URL https://doi.org/10.1162/colia00404. Fukushima, K. and Miyake, S. Neocognitron: A self- organizing neural network model for a mechanism of visual pattern recognition. InCompetition and coopera- tion in neural nets, p. 267–285. Springer, 1982. Geiger, A., Cases, I., Karttunen, L., and Potts, C. Pos- ing fair generalization tasks for natural language infer- ence. InProceedings of the 2019 Conference on Em- pirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Lan- guage Processing (EMNLP-IJCNLP), p. 4475–4485, Stroudsburg, PA, November 2019. Association for Com- putational Linguistics. doi: 10.18653/v1/D19-1456. URL https://w.aclweb.org/anthology/D19-1456. Geiger, A., Richardson, K., and Potts, C. Neural nat- ural language inference models partially embed theo- ries of lexical entailment and negation. InProceed- ings of the Third BlackboxNLP Workshop on Analyz- Inducing Causal Structure for Interpretable Neural Networks ing and Interpreting Neural Networks for NLP, p. 163– 173, Online, November 2020. Association for Computa- tional Linguistics. doi: 10.18653/v1/2020.blackboxnlp-1. 16. URLhttps://w.aclweb.org/anthology/2020. blackboxnlp-1.16. Geiger, A., Lu, H., Icard, T., and Potts, C. Causal ab- stractions of neural networks. InAdvances in Neural Information Processing Systems, 2021. URLhttps: //arxiv.org/abs/2109.08994. Giulianelli, M., Harding, J., Mohnert, F., Hupkes, D., and Zuidema, W. Under the hood: Using diagnostic classi- fiers to investigate and improve how language models track agreement information. InProceedings of the 2018 EMNLP Workshop BlackboxNLP: Analyzing and Inter- preting Neural Networks for NLP, p. 240–248, Brus- sels, Belgium, November 2018. Association for Compu- tational Linguistics. doi: 10.18653/v1/W18-5426. URL https://w.aclweb.org/anthology/W18-5426. 
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), p. 770–778, 2016. doi: 10.1109/CVPR.2016.90. Hewitt, J. and Liang, P. Designing and interpreting probes with control tasks. InProceedings of the 2019 Con- ference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), p. 2733–2743, Hong Kong, China, November 2019. Association for Computational Linguistics. doi: 10. 18653/v1/D19-1275. URLhttps://w.aclweb.org/ anthology/D19-1275. Hitchcock, C. The intransitivity of causation revealed in equations and graphs.Journal of Philosophy, 98(6):273– 299, 2001. Hudson, D. A. and Manning, C. D. Compositional atten- tion networks for machine reasoning. InInternational Conference on Learning Representations, 2018. Hupkes, D., Veldhoen, S., and Zuidema, W. H. Visu- alisation and ’diagnostic classifiers’ reveal how recur- rent and recursive neural networks process hierarchi- cal structure.J. Artif. Intell. Res., 61:907–926, 2018. doi: 10.1613/jair.1.11196. URLhttps://doi.org/10. 1613/jair.1.11196. Imbens, G. W. and Rubin, D. B.Causal inference in statis- tics, social, and biomedical sciences. Cambridge Univer- sity Press, 2015. Jacovi, A. and Goldberg, Y.Towards faithfully inter- pretable NLP systems: How should we define and eval- uate faithfulness?InProceedings of the 58th An- nual Meeting of the Association for Computational Lin- guistics, p. 4198–4205, Online, July 2020. Associa- tion for Computational Linguistics. doi: 10.18653/v1/ 2020.acl-main.386. URLhttps://w.aclweb.org/ anthology/2020.acl-main.386. Kaushik, D., Hovy, E. H., and Lipton, Z. C. Learning the difference that makes a difference with counterfactually- augmented data.CoRR, abs/1909.12434, 2019. URL http://arxiv.org/abs/1909.12434. Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. InICLR (Poster), 2015. Lake, B. and Baroni, M. Generalization without systematic- ity: On the compositional skills of sequence-to-sequence recurrent networks. InInternational Conference on Ma- chine Learning, p. 2873–2882. PMLR, 2018. LeCun, Y., Cortes, C., and Burges, C.Mnist hand- written digit database.ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist, 2, 2010. Lipton, Z. C. The mythos of model interpretability: In machine learning, the concept of interpretability is both important and slippery.Queue, 16(3):31–57, jun 2018. ISSN 1542-7730. doi: 10.1145/3236386.3241340. URL https://doi.org/10.1145/3236386.3241340. Liu, Q., Kusner, M., and Blunsom, P. Counterfactual data augmentation for neural machine translation. InPro- ceedings of the 2021 Conference of the North Ameri- can Chapter of the Association for Computational Lin- guistics: Human Language Technologies, p. 187–197, Online, June 2021. Association for Computational Lin- guistics. doi: 10.18653/v1/2021.naacl-main.18. URL https://aclanthology.org/2021.naacl-main.18. Meng, K., Bau, D., Andonian, A., and Belinkov, Y. Locating and editing factual associations in gpt, 2022. URLhttps: //arxiv.org/abs/2202.05262. Pearl, J.Causality: Models, Reasoning and Inference. Cam- bridge University Press, USA, 2nd edition, 2009. ISBN 052189560X. Perez, L. and Wang, J. The effectiveness of data augmenta- tion in image classification using deep learning.CoRR, abs/1712.04621, 2017. URLhttp://arxiv.org/abs/ 1712.04621. Peters, M., Neumann, M., Zettlemoyer, L., and Yih, W.-t. 
Dissecting contextual word embeddings: Architecture Inducing Causal Structure for Interpretable Neural Networks and representation. InProceedings of the 2018 Confer- ence on Empirical Methods in Natural Language Pro- cessing, p. 1499–1509, Brussels, Belgium, October- November 2018. Association for Computational Lin- guistics. doi: 10.18653/v1/D18-1179. URLhttps: //w.aclweb.org/anthology/D18-1179. Pryzant, R., Card, D., Jurafsky, D., Veitch, V., and Sridhar, D. Causal effects of linguistic properties. InNAACL, 2021. Ravfogel, S., Elazar, Y., Gonen, H., Twiton, M., and Gold- berg, Y. Null it out: Guarding protected attributes by iterative nullspace projection. InProceedings of the 58th Annual Meeting of the Association for Computational Linguistics, p. 7237–7256, Online, July 2020. Associ- ation for Computational Linguistics. doi: 10.18653/v1/ 2020.acl-main.647. URLhttps://w.aclweb.org/ anthology/2020.acl-main.647. Ravichander, A., Belinkov, Y., and Hovy, E. Probing the probing paradigm: Does probing accuracy entail task relevance?, 2020. Rubenstein, P. K., Weichwald, S., Bongers, S., Mooij, J. M., Janzing, D., Grosse-Wentrup, M., and Sch ̈ olkopf, B.Causal consistency of structural equation mod- els. InProceedings of the 33rd Conference on Uncer- tainty in Artificial Intelligence (UAI). Association for Uncertainty in Artificial Intelligence (AUAI), August 2017. URLhttp://auai.org/uai2017/proceedings/ papers/11.pdf. *equal contribution. Ruder, S. An overview of multi-task learning in deep neural networks.CoRR, abs/1706.05098, 2017. URLhttp: //arxiv.org/abs/1706.05098. Ruis, L., Andreas, J., Baroni, M., Bouchacourt, D., and Lake, B. M. A benchmark for systematic generalization in grounded language understanding.Advances in Neural Information Processing Systems, 33, 2020. Schuster, M. and Paliwal, K. K. Bidirectional recurrent neural networks.IEEE transactions on Signal Processing, 45(11):2673–2681, 1997. Shorten, C. and Khoshgoftaar, T. M. A survey on image data augmentation for deep learning.Journal of Big Data, 6:1–48, 2019. Shrikumar, A., Greenside, P., Shcherbina, A., and Kun- daje, A. Not just a black box: Learning important fea- tures through propagating activation differences.CoRR, abs/1605.01713, 2016. URLhttp://arxiv.org/abs/ 1605.01713. Soulos, P., McCoy, R. T., Linzen, T., and Smolensky, P. Discovering the compositional structure of vector rep- resentations with role learning networks. InProceed- ings of the Third BlackboxNLP Workshop on Analyz- ing and Interpreting Neural Networks for NLP, p. 238– 254, Online, November 2020. Association for Computa- tional Linguistics. doi: 10.18653/v1/2020.blackboxnlp-1. 23. URLhttps://w.aclweb.org/anthology/2020. blackboxnlp-1.23. Spirtes, P., Glymour, C. N., and Scheines, R.Causation, Prediction, and Search. MIT Press, 2nd edition, 2001. Springenberg, J., Dosovitskiy, A., Brox, T., and Riedmiller, M. Striving for simplicity: The all convolutional net. CoRR, 12 2014. Sundararajan, M., Taly, A., and Yan, Q. Axiomatic attri- bution for deep networks. In Precup, D. and Teh, Y. W. (eds.),Proceedings of the 34th International Conference on Machine Learning, volume 70 ofProceedings of Ma- chine Learning Research, p. 3319–3328, International Convention Centre, Sydney, Australia, 06–11 Aug 2017. PMLR. URLhttp://proceedings.mlr.press/v70/ sundararajan17a.html. Tenney, I., Das, D., and Pavlick, E.BERT rediscov- ers the classical NLP pipeline. InProceedings of the 57th Annual Meeting of the Association for Computa- tional Linguistics, p. 
Vig, J., Gehrmann, S., Belinkov, Y., Qian, S., Nevo, D., Singer, Y., and Shieber, S. Causal mediation analysis for interpreting neural NLP: The case of gender bias, 2020.

Wu, Z., Kreiss, E., Ong, D. C., and Potts, C. ReaSCAN: Compositional reasoning in language grounding. NeurIPS 2021 Datasets and Benchmarks Track, 2021. URL https://arxiv.org/abs/2109.08994.

Zeiler, M. D. and Fergus, R. Visualizing and understanding convolutional networks. In Fleet, D., Pajdla, T., Schiele, B., and Tuytelaars, T. (eds.), Computer Vision – ECCV 2014, p. 818–833, Cham, 2014. Springer International Publishing. ISBN 978-3-319-10590-1.

Zhang, C., Raghu, M., Kleinberg, J. M., and Bengio, S. Pointer value retrieval: A new benchmark for understanding the limits of neural network generalization. CoRR, abs/2107.12580, 2021. URL https://arxiv.org/abs/2107.12580.

Zhang, Y. and Yang, Q. A survey on multi-task learning. CoRR, abs/1707.08114, 2017. URL http://arxiv.org/abs/1707.08114.

A. Zero Loss Entails Causal Abstraction

Claim. Suppose we have a loss function $\mathrm{LOSS}$ that outputs a non-negative value. If $\mathrm{LOSS}(x, y) = 0 \Rightarrow x = \kappa(y)$, then the interchange intervention loss being zero guarantees that the causal model $\mathcal{C}$ is a causal abstraction of the neural network $\mathcal{N}_\theta$.

Proof. Suppose that

$$\sum_{b,s \in V_{\mathsf{In}}} \mathrm{LOSS}\Big(\mathrm{INTINV}(\mathcal{C}, b, s, V),\; \mathrm{INTINV}(\mathcal{N}_\theta, b, s, \Pi(V))\Big) = 0 \tag{13}$$

Because our loss function outputs non-negative numbers, we know that, if the sum in Eqn. 4 is 0, then each addend in the sum is 0:

$$\forall b,s \in V_{\mathsf{In}}:\; \mathrm{LOSS}\Big(\mathrm{INTINV}(\mathcal{C}, b, s, V),\; \mathrm{INTINV}(\mathcal{N}_\theta, b, s, \Pi(V))\Big) = 0 \tag{14}$$

Because our loss function is such that $\mathrm{LOSS}(x, y) = 0 \Rightarrow x = \kappa(y)$, we conclude:

$$\forall b,s \in V_{\mathsf{In}}:\; \mathrm{INTINV}(\mathcal{C}, b, s, V) = \kappa\big(\mathrm{INTINV}(\mathcal{N}_\theta, b, s, \Pi(V))\big) \tag{15}$$

This is exactly the condition for abstraction in Eqn. 2.

B. ReaSCAN Dataset Generation

Table 2 shows dataset statistics. For the novel color and novel size splits, we train only a single model on the same training set but evaluate on different test sets, as discussed in Section 5. For the novel color, novel size, and novel length splits, we use the ReaSCAN framework (our implementation is adapted from ReaSCAN's public code repository, https://github.com/frankaging/Reason-SCAN) to generate Simple commands without any relative clause, as discussed in its original paper (Wu et al., 2021). For the novel color and novel size splits, the allowed verbs are "walk to", "push", and "pull", and the allowed adverbs are "while zigzagging", "while spinning", "cautiously", and "hesitantly". For the novel direction and novel length splits, the only allowed verb is "walk to" and adverbs are disallowed, as we are focusing on action-length generalization rather than command generalization. Our split B1 is derived from gSCAN's novel direction testing split (Ruis et al., 2020), as the ReaSCAN framework cannot partition splits by relative agent-to-target direction. We set 200 grid worlds per command for the novel color and novel size splits, and 1200 grid worlds per command for the novel length split, as the space of allowed command patterns is much smaller for this split because all verbs other than "walk to" are excluded.

Table 2. Statistics of all splits in our ReaSCAN dataset.

Split                  #Train   #Dev    #Test   #Zero-shot
A1: novel color        76,102   3,816   3,774   7,195
A2: novel size         76,102   3,816   3,774   7,227
B1: novel direction    34,343   1,201   357     8,282
B2: novel length       52,662   4,250   4,250   1,338
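For concreteness, the per-split generation settings above can be summarized as a small configuration. The sketch below is purely illustrative and restates the choices reported in this appendix; it is not the actual ReaSCAN framework API, and the dictionary and key names are our own.

```python
# Illustrative restatement of the per-split generation settings described above.
# These dictionaries are NOT the ReaSCAN framework's API; names are hypothetical.
SPLIT_CONFIGS = {
    "A1_novel_color": {
        "allowed_verbs": ["walk to", "push", "pull"],
        "allowed_adverbs": ["while zigzagging", "while spinning",
                            "cautiously", "hesitantly"],
        "grid_worlds_per_command": 200,
    },
    "A2_novel_size": {
        "allowed_verbs": ["walk to", "push", "pull"],
        "allowed_adverbs": ["while zigzagging", "while spinning",
                            "cautiously", "hesitantly"],
        "grid_worlds_per_command": 200,
    },
    "B2_novel_length": {
        "allowed_verbs": ["walk to"],
        "allowed_adverbs": [],  # adverbs disallowed for length generalization
        "grid_worlds_per_command": 1200,
    },
    # B1 (novel direction) is derived from gSCAN's novel-direction split
    # (Ruis et al., 2020) rather than generated with the ReaSCAN framework.
}
```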
The ReaSCAN dataset generation procedure leads to some artifacts, which are discussed in detail in its original paper. These are not especially relevant for our experiments. The data generation process takes approximately 30 minutes on a multi-CPU cluster. Although we generate our own datasets from an existing data generation engine, our training paradigm can also be applied to existing datasets.

Experiment Set-up. For our CNN-LSTM, we adapt code from the original repository (https://github.com/LauraRuis/multimodal_seq2seq_gSCAN). For all training objectives, we optimize a cross-entropy loss using Adam with default parameters (Kingma & Ba, 2015). The learning rate starts at 1e-4 and is decayed by a factor of 0.9 every 20,000 steps. We train the model for a fixed number of epochs (100,000) before stopping. The best model is picked by performance on a smaller development set of 2,000 examples, which is consistent with the training pipeline proposed by Ruis et al. (2020) for gSCAN. Training takes about 1 day on a standard GeForce RTX 2080 Ti GPU with 11GB of memory. To foster reproducibility, we release our adapted evaluation scripts in our code repository. We repeat each experiment with three distinct random seeds to ensure a fair comparison.

Training Procedure. We release implementations of our neural models with our symbolic causal structures in our code repository. Our released symbolic causal structure for solving ReaSCAN is not unique, and may not be the optimal one for improving generalizability. Additionally, our variable mappings between the two models are not unique. Ideally, a chosen causal variable could be mapped onto any hidden state in the neural model; however, we find that the specific mapping chosen substantially affects model performance and generalizability. In contrast to a standard training pipeline, which takes a single input, IIT takes pairs of examples as input. We found that how these pairs are formed affects performance; we leave a study of the effects of example pairing on model performance to future research.

Generalization Splits. To evaluate the generalization power of models, ReaSCAN includes test examples that are systematically different from the training examples. Specifically, ReaSCAN generates unseen test patterns to assess whether models can generalize to unseen composites of seen concepts in a zero-shot setting. We now describe each split in detail.

Novel Color Attribute (Novel color) allows models to see "yellow circle" (6,127 examples) and "red square" (6,111 examples) during training, but holds out all commands containing "yellow square"; at test time, models are evaluated on commands containing "yellow square".

Novel Size Attribute (Novel size) holds out all commands referring to small cylinders in any color, meaning that models have not seen commands containing phrases such as "small cylinder" or "small yellow cylinder" during training. On the other hand, models have seen commands containing "big cylinder" (2,020 examples) or "small square" (2,093 examples). At test time, models need to generalize to the held-out examples.

Novel Direction (Novel direction) holds out any command and grid-world pair where the referred target is initially located to the south west (SW) of the agent. At test time, models need to generate action sequences to reach targets located SW of the agent.
As our agent always faces to the right (i.e., east) at the beginning, models need to generate action sequences containing three "turn left" actions in order to reach any target positioned SW of the agent.

Novel Action Sequence Length (Novel length) holds out any command and grid-world pair that requires models to predict an action sequence containing more than 10 actions. At test time, models need to generalize to examples that require 11, 12, or 13 actions to reach the target.

CNN-LSTM. The encoder contains two parts: a convolutional network (CNN) (Fukushima & Miyake, 1982) for encoding the grid world, and a bi-directional LSTM (Schuster & Paliwal, 1997; Hudson & Manning, 2018) for encoding the command. The decoder is an LSTM with cross-modality attention weights over the command and the grid world.

C. Experiment Details of Multiply Quantified Natural Language Inference (MQNLI)

Dataset. For our experiments, we use a train set with 500K examples, a dev set with 60K examples, and a test set with 10K examples, which is the most difficult generalization scheme of Geiger et al. (2019). Unlike Geiger et al. (2020), we do not include extra augmented examples consisting of subphrases with labeled relations; we include only complete sentences. Only the $QP_{Obj}$ labels are used to introduce the IIT and multi-task training objectives.

Models. For BERT, we use the same model architecture as the uncased-base variant, with 12 layers and a hidden size of 768. For the linear classifier that predicts the value of $QP_{Obj}$ from a neural representation, we use a model similar to the probes used in Geiger et al. (2020). It is a single-layer softmax classifier, $y \propto \mathrm{softmax}(Ah + b)$, where $h$ is a hidden representation and $y$ is the predicted probability distribution over the classes of $QP_{Obj}$. Following Hewitt & Liang (2019), to control the dimensionality of $A$, we factorize it as $A = LR$, where $L \in \mathbb{R}^{|QP_{Obj}| \times \ell}$ and $R \in \mathbb{R}^{\ell \times d}$, with $d$ the dimensionality of $h$. We choose $\ell = 32$. (A minimal sketch of this factorized probe is given below.)

Training Procedure. We use a batch size of 32, a learning rate of 5.0 × 10⁻⁵, and AdamW optimization. We train for a maximum of 5 epochs. We warm up the learning rate linearly from 0 to the specified value over the first 50% of the steps of the first epoch, and then linearly decrease it to 0 over the remainder of training.

To construct each batch sample during training, we randomly pick two examples from the MQNLI dataset. We construct the augmented input by replacing the quantified verb phrase in the first example with that of the second. For counterfactual training, we use the antra package, which builds on the implementation in Geiger et al. (2020); we construct an intervention with the first example as the base and the second example as the intervention source. For multi-task training, we use only the hidden representation of the first example as input to the linear classifier.

We sample 50,000 base examples from the MQNLI train set. For each base example we sample 20 intervention sources, and try to ensure that 10 of them can form impactful interventions with the base example, i.e., the logical model $\mathcal{C}^{QP_{Obj}}_{\mathrm{NatLog}}$ computes that performing the intervention changes the final label. We use the same 50,000 × 20 = 1M pairs for each training epoch. We use the logical model $\mathcal{C}^{QP_{Obj}}_{\mathrm{NatLog}}$ as our oracle to obtain all of the labels: the base output, the counterfactual output, the augmented output, and the value of $QP_{Obj}$ used for the linear classifier.
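The following is a minimal sketch of the factorized probe described in the Models paragraph above, assuming a PyTorch setting. The class name and constructor arguments are illustrative; only the rank ℓ = 32 and the BERT-base hidden size of 768 come from the text.

```python
import torch
import torch.nn as nn


class LowRankProbe(nn.Module):
    """Single-layer softmax classifier y ∝ softmax(Ah + b) with A = LR.

    Following Hewitt & Liang (2019), A is factorized into L (num_classes × ℓ)
    and R (ℓ × d) to control its dimensionality; the text uses ℓ = 32 and
    d = 768 (BERT-base hidden size). Names here are illustrative.
    """

    def __init__(self, num_classes: int, hidden_dim: int = 768, rank: int = 32):
        super().__init__()
        self.R = nn.Linear(hidden_dim, rank, bias=False)   # R ∈ R^{ℓ×d}
        self.L = nn.Linear(rank, num_classes, bias=True)   # L ∈ R^{|QP_Obj|×ℓ}, plus bias b

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Returns logits; a softmax (or cross-entropy, which applies it
        # internally) yields the class distribution over QP_Obj values.
        return self.L(self.R(h))
```

Training this probe with a cross-entropy loss over its logits corresponds to the multi-task objective over $QP_{Obj}$ described above.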
For each of the experiment settings (base, IIT, IIT+multitask, augment, multitask), we assign a weight of 1.0 to the base objective and to each additional objective; e.g., IIT+multitask uses a weight of 1.0 for each of the base, IIT, and multitask objectives. (A sketch of this weighting follows the description of Figure 5 below.)

Figure 5. ReaSCAN examples with varying command patterns. In the original figure, the navigation commands and the target action sequences are shown in grey boxes and green boxes, respectively:

- "push the small green cylinder hesitantly": 'turn left', 'turn left', 'walk', 'stay', 'walk', 'stay', 'walk', 'stay'
- "pull to the red cylinder while zigzagging": 'walk', 'turn left', 'walk', 'turn right', 'walk', 'turn left', 'walk', 'walk', 'pull', 'pull'
- "walk the cylinder": 'turn left', 'turn left', 'walk'
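To make the objective weighting above concrete, the sketch below combines a base cross-entropy loss, an IIT counterfactual loss, and a multi-task probe loss with equal weights of 1.0, as in the IIT+multitask setting. The model, probe, and interchange-intervention helper are stand-ins rather than the released implementation; in particular, `interchange_forward` is a hypothetical function that re-runs the network on the base input while overwriting the aligned hidden representation with the one computed from the source input.

```python
import torch.nn.functional as F

# Hypothetical stand-ins: `model` maps an input to (logits, aligned_hidden),
# `probe` is the linear classifier over the aligned representation, and
# `interchange_forward(model, base_input, source_hidden)` returns the logits
# of the base run with the aligned hidden state swapped in from the source.
def training_step(model, probe, interchange_forward, batch, optimizer):
    base, source = batch["base"], batch["source"]

    # Base objective: ordinary supervised cross-entropy on the base input.
    base_logits, base_hidden = model(base["input"])
    loss_base = F.cross_entropy(base_logits, base["label"])

    # IIT objective: match the causal model's counterfactual label after the
    # interchange intervention (source representation swapped into the base run).
    _, source_hidden = model(source["input"])
    cf_logits = interchange_forward(model, base["input"], source_hidden)
    loss_iit = F.cross_entropy(cf_logits, batch["counterfactual_label"])

    # Multi-task objective: probe predicts the causal variable's value
    # (e.g., QP_Obj) from the base input's aligned representation.
    probe_logits = probe(base_hidden)
    loss_multitask = F.cross_entropy(probe_logits, base["variable_label"])

    # Each objective receives a weight of 1.0, as described above.
    loss = 1.0 * loss_base + 1.0 * loss_iit + 1.0 * loss_multitask

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```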