Amortized Inference of Causal Models via Conditional Fixed-Point Iterations
- URL: http://arxiv.org/abs/2410.06128v3
- Date: Tue, 10 Jun 2025 22:20:54 GMT
- Title: Amortized Inference of Causal Models via Conditional Fixed-Point Iterations
- Authors: Divyat Mahajan, Jannes Gladrow, Agrin Hilmkil, Cheng Zhang, Meyer Scetbon,
- Abstract summary: We propose amortized inference of Structural Causal Models (SCMs) by training a single model on multiple datasets sampled from different SCMs. We first use a transformer-based architecture for amortized learning of dataset embeddings, and then extend the Fixed-Point Approach (FiP) to infer SCMs conditionally on their dataset embeddings. As a byproduct, our method can generate observational and interventional data from novel SCMs at inference time, without updating parameters.
- Score: 17.427722515310606
- License: http://creativecommons.org/licenses/by/4.0/
- Abstract: Structural Causal Models (SCMs) offer a principled framework to reason about interventions and support out-of-distribution generalization, which are key goals in scientific discovery. However, the task of learning SCMs from observed data poses formidable challenges, and often requires training a separate model for each dataset. In this work, we propose amortized inference of SCMs by training a single model on multiple datasets sampled from different SCMs. We first use a transformer-based architecture for amortized learning of dataset embeddings, and then extend the Fixed-Point Approach (FiP) (Scetbon et al.) to infer SCMs conditionally on their dataset embeddings. As a byproduct, our method can generate observational and interventional data from novel SCMs at inference time, without updating parameters. Empirical results show that our amortized procedure performs on par with baselines trained specifically for each dataset on both in- and out-of-distribution problems, and also outperforms them in scarce data regimes.
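The fixed-point view of an SCM mentioned in the abstract can be illustrated with a toy example. The sketch below is not the paper's FiP architecture (it has no transformer and no learned dataset embedding); it assumes a linear additive-noise SCM whose variables are already listed in topological order, so the weight matrix is strictly lower triangular. Under that assumption, iterating the map x -> W x + n as many times as there are variables reaches the unique fixed point, and a hard intervention is simulated by overriding one variable's mechanism. The function name `fixed_point_sample` and all numeric values are hypothetical.

```python
def fixed_point_sample(weights, noise, interventions=None):
    """Solve x = W x + n by fixed-point iteration (toy linear SCM).

    weights: d x d strictly lower-triangular list of lists
             (variables assumed to be in topological order).
    noise: list of d exogenous noise values.
    interventions: optional dict {index: value} for do(x_index = value).
    """
    interventions = interventions or {}
    d = len(noise)
    x = [0.0] * d
    # For a DAG with d nodes, d applications of the map are enough
    # to reach the fixed point from any starting point.
    for _ in range(d):
        new_x = []
        for i in range(d):
            if i in interventions:
                # Hard intervention: ignore parents and noise.
                new_x.append(float(interventions[i]))
            else:
                parents = sum(weights[i][j] * x[j] for j in range(d))
                new_x.append(parents + noise[i])
        x = new_x
    return x


# Hypothetical chain x0 -> x1 -> x2.
W = [
    [0.0, 0.0, 0.0],
    [2.0, 0.0, 0.0],
    [0.0, -1.0, 0.0],
]
noise = [1.0, 0.5, 0.25]

observational = fixed_point_sample(W, noise)
interventional = fixed_point_sample(W, noise, interventions={1: 3.0})

print(observational)   # [1.0, 2.5, -2.25]
print(interventional)  # [1.0, 3.0, -2.75]
```

Reusing the same noise vector for both calls shows how the intervention on x1 changes x2 downstream while leaving x0 untouched.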
Related papers
- Factor Analysis with Correlated Topic Model for Multi-Modal Data [0.0]
Multimodal factor analysis (FA) uncovers shared axes of variation underlying simple data modalities. FA is not suited for structured data modalities, such as text or single cell sequencing data. We introduce FACTM, a novel, multi-view and multi-structure Bayesian model that combines FA with correlated topic modeling and is optimized using variational inference.
arXiv Detail & Related papers (2025-04-26T13:02:53Z) - AdvKT: An Adversarial Multi-Step Training Framework for Knowledge Tracing [64.79967583649407]
Knowledge Tracing (KT) monitors students' knowledge states and simulates their responses to question sequences.
Existing KT models typically follow a single-step training paradigm, which leads to significant error accumulation.
We propose a novel Adversarial Multi-Step Training Framework for Knowledge Tracing (AdvKT) which focuses on the multi-step KT task.
arXiv Detail & Related papers (2025-04-07T03:31:57Z) - Meta-Statistical Learning: Supervised Learning of Statistical Inference [59.463430294611626]
This work demonstrates that the tools and principles driving the success of large language models (LLMs) can be repurposed to tackle distribution-level tasks. We propose meta-statistical learning, a framework inspired by multi-instance learning that reformulates statistical inference tasks as supervised learning problems.
arXiv Detail & Related papers (2025-02-17T18:04:39Z) - DISCO: DISCovering Overfittings as Causal Rules for Text Classification Models [6.369258625916601]
Post-hoc interpretability methods fail to capture the models' decision-making process fully.
Our paper introduces DISCO, a novel method for discovering global, rule-based explanations.
DISCO supports interactive explanations, enabling human inspectors to distinguish spurious causes in the rule-based output.
arXiv Detail & Related papers (2024-11-07T12:12:44Z) - An Information Criterion for Controlled Disentanglement of Multimodal Data [39.601584166020274]
Multimodal representation learning seeks to relate and decompose information inherent in multiple modalities.
Disentangled Self-Supervised Learning (DisentangledSSL) is a novel self-supervised approach for learning disentangled representations.
arXiv Detail & Related papers (2024-10-31T14:57:31Z) - Continual Learning for Multimodal Data Fusion of a Soft Gripper [1.0589208420411014]
A model trained on one data modality often fails when tested with a different modality.
We introduce a continual learning algorithm capable of incrementally learning different data modalities.
We evaluate the algorithm's effectiveness on a challenging custom multimodal dataset.
arXiv Detail & Related papers (2024-09-20T09:53:27Z) - Disperse-Then-Merge: Pushing the Limits of Instruction Tuning via Alignment Tax Reduction [75.25114727856861]
Large language models (LLMs) tend to suffer from deterioration at the latter stage of the supervised fine-tuning process.
We introduce a simple disperse-then-merge framework to address the issue.
Our framework outperforms various sophisticated methods such as data curation and training regularization on a series of standard knowledge and reasoning benchmarks.
arXiv Detail & Related papers (2024-05-22T08:18:19Z) - FiP: a Fixed-Point Approach for Causal Generative Modeling [20.88890689294816]
We propose a new and equivalent formalism that does not require DAGs to describe fixed-point problems on the causally ordered variables.
We show three important cases where they can be uniquely recovered given the topological ordering (TO).
arXiv Detail & Related papers (2024-04-10T12:29:05Z) - Heat Death of Generative Models in Closed-Loop Learning [63.83608300361159]
We study the learning dynamics of generative models that are fed back their own produced content in addition to their original training dataset.
We show that, unless a sufficient amount of external data is introduced at each iteration, any non-trivial temperature leads the model to degenerate.
arXiv Detail & Related papers (2024-04-02T21:51:39Z) - Images in Discrete Choice Modeling: Addressing Data Isomorphism in
Multi-Modality Inputs [77.54052164713394]
This paper explores the intersection of Discrete Choice Modeling (DCM) and machine learning.
We investigate the consequences of embedding high-dimensional image data that shares isomorphic information with traditional tabular inputs within a DCM framework.
arXiv Detail & Related papers (2023-12-22T14:33:54Z) - Causal Optimal Transport of Abstractions [8.642152250082368]
Causal abstraction (CA) theory establishes formal criteria for relating multiple structural causal models (SCMs) at different levels of granularity.
We propose COTA, the first method to learn abstraction maps from observational and interventional data without assuming complete knowledge of the underlying SCMs.
We extensively evaluate COTA on synthetic and real world problems, and showcase its advantages over non-causal, independent and aggregated COTA formulations.
arXiv Detail & Related papers (2023-12-13T12:54:34Z) - Discovering Mixtures of Structural Causal Models from Time Series Data [23.18511951330646]
We propose a general variational inference-based framework called MCD to infer the underlying causal models.
Our approach employs an end-to-end training process that maximizes an evidence-lower bound for the data likelihood.
We demonstrate that our method surpasses state-of-the-art benchmarks in causal discovery tasks.
arXiv Detail & Related papers (2023-10-10T05:13:10Z) - Learning Unified Distance Metric Across Diverse Data Distributions with Parameter-Efficient Transfer Learning [36.349282242221065]
A common practice in metric learning is to train and test an embedding model for each dataset. This dataset-specific approach fails to simulate real-world scenarios that involve multiple heterogeneous distributions of data. We explore a new metric learning paradigm, called Unified Metric Learning (UML), which learns a unified distance metric.
arXiv Detail & Related papers (2023-09-16T10:34:01Z) - MADS: Modulated Auto-Decoding SIREN for time series imputation [9.673093148930874]
We propose MADS, a novel auto-decoding framework for time series imputation, built upon implicit neural representations.
We evaluate our model on two real-world datasets, and show that it outperforms state-of-the-art methods for time series imputation.
arXiv Detail & Related papers (2023-07-03T09:08:47Z) - Mutual Exclusivity Training and Primitive Augmentation to Induce Compositionality [84.94877848357896]
Recent datasets expose the lack of the systematic generalization ability in standard sequence-to-sequence models.
We analyze this behavior of seq2seq models and identify two contributing factors: a lack of mutual exclusivity bias and the tendency to memorize whole examples.
We show substantial empirical improvements using standard sequence-to-sequence models on two widely-used compositionality datasets.
arXiv Detail & Related papers (2022-11-28T17:36:41Z) - SimSCOOD: Systematic Analysis of Out-of-Distribution Generalization in Fine-tuned Source Code Models [58.78043959556283]
We study the behaviors of models under different fine-tuning methodologies, including full fine-tuning and Low-Rank Adaptation (LoRA) fine-tuning methods.
Our analysis uncovers that LoRA fine-tuning consistently exhibits significantly better OOD generalization performance than full fine-tuning across various scenarios.
arXiv Detail & Related papers (2022-10-10T16:07:24Z) - Learning from aggregated data with a maximum entropy model [73.63512438583375]
We show how a new model, similar to a logistic regression, may be learned from aggregated data only by approximating the unobserved feature distribution with a maximum entropy hypothesis.
We present empirical evidence on several public datasets that the model learned this way can achieve performances comparable to those of a logistic model trained with the full unaggregated data.
arXiv Detail & Related papers (2022-10-05T09:17:27Z) - MRCLens: an MRC Dataset Bias Detection Toolkit [82.44296974850639]
We introduce MRCLens, a toolkit that detects whether biases exist before users train the full model.
For the convenience of introducing the toolkit, we also provide a categorization of common biases in MRC.
arXiv Detail & Related papers (2022-07-18T21:05:39Z) - Learning from Temporal Spatial Cubism for Cross-Dataset Skeleton-based Action Recognition [88.34182299496074]
Action labels are only available on a source dataset, but unavailable on a target dataset in the training stage.
We utilize a self-supervision scheme to reduce the domain shift between two skeleton-based action datasets.
By segmenting and permuting temporal segments or human body parts, we design two self-supervised learning classification tasks.
arXiv Detail & Related papers (2022-07-17T07:05:39Z) - On Continual Model Refinement in Out-of-Distribution Data Streams [64.62569873799096]
Real-world natural language processing (NLP) models need to be continually updated to fix the prediction errors in out-of-distribution (OOD) data streams.
Existing continual learning (CL) problem setups cannot cover such a realistic and complex scenario.
We propose a new CL problem formulation dubbed continual model refinement (CMR).
arXiv Detail & Related papers (2022-05-04T11:54:44Z) - DRFLM: Distributionally Robust Federated Learning with Inter-client Noise via Local Mixup [58.894901088797376]
Federated learning has emerged as a promising approach for training a global model using data from multiple organizations without leaking their raw data.
We propose a general framework to solve the above two challenges simultaneously.
We provide comprehensive theoretical analysis including robustness analysis, convergence analysis, and generalization ability.
arXiv Detail & Related papers (2022-04-16T08:08:29Z) - Causal Inference Through the Structural Causal Marginal Problem [17.91174054672512]
We introduce an approach to counterfactual inference based on merging information from multiple datasets.
We formalise this approach for categorical SCMs using the response function formulation and show that it reduces the space of allowed marginal and joint SCMs.
arXiv Detail & Related papers (2022-02-02T21:45:10Z) - Harmonization with Flow-based Causal Inference [12.739380441313022]
This paper presents a normalizing-flow-based method to perform counterfactual inference upon a structural causal model (SCM) to harmonize medical data.
We evaluate on multiple, large, real-world medical datasets to observe that this method leads to better cross-domain generalization compared to state-of-the-art algorithms.
arXiv Detail & Related papers (2021-06-12T19:57:35Z) - Continual Learning with Fully Probabilistic Models [70.3497683558609]
We present an approach for continual learning based on fully probabilistic (or generative) models of machine learning.
We propose a pseudo-rehearsal approach using a Gaussian Mixture Model (GMM) instance for both generator and classifier functionalities.
We show that GMR achieves state-of-the-art performance on common class-incremental learning problems at very competitive time and memory complexity.
arXiv Detail & Related papers (2021-04-19T12:26:26Z) - On Disentanglement in Gaussian Process Variational Autoencoders [3.403279506246879]
We study a recently introduced class of models that has been successful in different tasks on time series data.
Our model exploits the temporal structure of the data by modeling each latent channel with a GP prior and employing a structured variational distribution.
We provide evidence that we can learn meaningful disentangled representations on real-world medical time series data.
arXiv Detail & Related papers (2021-02-10T15:49:27Z) - Understanding Self-supervised Learning with Dual Deep Networks [74.92916579635336]
We propose a novel framework to understand contrastive self-supervised learning (SSL) methods that employ dual pairs of deep ReLU networks.
We prove that in each SGD update of SimCLR with various loss functions, the weights at each layer are updated by a covariance operator.
To further study what role the covariance operator plays and which features are learned in such a process, we model data generation and augmentation processes through a hierarchical latent tree model (HLTM).
arXiv Detail & Related papers (2020-10-01T17:51:49Z) - Partially Conditioned Generative Adversarial Networks [75.08725392017698]
Generative Adversarial Networks (GANs) let one synthesise artificial datasets by implicitly modelling the underlying probability distribution of a real-world training dataset.
With the introduction of Conditional GANs and their variants, these methods were extended to generating samples conditioned on ancillary information available for each sample within the dataset.
In this work, we argue that standard Conditional GANs are not suitable for such a task and propose a new Adversarial Network architecture and training strategy.
arXiv Detail & Related papers (2020-07-06T15:59:28Z)
This list is automatically generated from the titles and abstracts of the papers in this site.