add cond-fip (#114)
meyerscetbon authored Dec 16, 2024
1 parent 26e0c2c commit 220e47d
Showing 19 changed files with 11,372 additions and 2,446 deletions.
5,270 changes: 2,825 additions & 2,445 deletions poetry.lock


2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "causica"
version = "0.4.4"
version = "0.4.5"
description = ""
readme = "README.md"
authors = ["Microsoft Research - Causica"]
61 changes: 61 additions & 0 deletions research_experiments/cond_fip/README.md
@@ -0,0 +1,61 @@
# Zero-Shot Learning of Causal Models (Cond-FiP)
[![Static Badge](https://img.shields.io/badge/paper-CondFiP-brightgreen?style=plastic&label=Paper&labelColor=yellow)](https://arxiv.org/pdf/2410.06128)

This repo implements Cond-FiP, proposed in the paper "Zero-Shot Learning of Causal Models".

Cond-FiP is a transformer-based approach for inferring Structural Causal Models (SCMs) in a zero-shot manner. Rather than learning a specific SCM for each dataset, we enable the Fixed-Point Approach (FiP) proposed in [Scetbon et al. (2024)](https://openreview.net/pdf?id=JpzIGzru5F) to infer generative SCMs conditionally on their empirical representations. More specifically, we amortize the learning of a conditional version of FiP that infers generative SCMs from observations and causal structures, trained on synthetically generated datasets.

Cond-FiP is composed of two models: (1) a dataset Encoder that produces embeddings given the empirical representations of SCMs, and (2) a Decoder that, conditioned on the dataset embedding, infers the generative functional mechanisms of the associated SCM.
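
The following is a minimal, hypothetical PyTorch sketch of this Encoder/Decoder interface, meant only to illustrate the data flow described above. It is not the actual Cond-FiP implementation: the class names, layer sizes, and pooling choices are illustrative assumptions.

```python
# Hypothetical sketch of the Encoder/Decoder split (not the actual Cond-FiP code).
import torch
import torch.nn as nn


class DatasetEncoder(nn.Module):
    """Maps a dataset of shape (n_samples, d_nodes) to a per-node dataset embedding."""

    def __init__(self, d_model: int = 256, num_heads: int = 8, num_layers: int = 2):
        super().__init__()
        self.value_embed = nn.Linear(1, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=num_heads, batch_first=True)
        self.node_attn = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        # data: (batch, n_samples, d_nodes) -> embedding: (batch, d_nodes, d_model)
        h = self.value_embed(data.unsqueeze(-1))  # embed each scalar observation
        h = h.mean(dim=1)                         # pool over samples
        return self.node_attn(h)                  # attend across nodes


class ConditionalDecoder(nn.Module):
    """Predicts node values for a query sample, conditioned on the dataset embedding."""

    def __init__(self, d_model: int = 256):
        super().__init__()
        self.node_proj = nn.Linear(1, d_model)
        self.head = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.ReLU(), nn.Linear(d_model, 1))

    def forward(self, x: torch.Tensor, dataset_embedding: torch.Tensor) -> torch.Tensor:
        # x: (batch, d_nodes) -> predicted mechanism outputs: (batch, d_nodes)
        h = self.node_proj(x.unsqueeze(-1))
        h = torch.cat([h, dataset_embedding], dim=-1)
        return self.head(h).squeeze(-1)


# Embed a batch of datasets once, then decode any query sample conditioned on the embedding.
datasets = torch.randn(4, 400, 5)  # (batch of datasets, samples per dataset, nodes)
encoder, decoder = DatasetEncoder(), ConditionalDecoder()
embedding = encoder(datasets)
prediction = decoder(torch.randn(4, 5), embedding)
```

The intended takeaway is that the dataset embedding is computed once per dataset and conditions every subsequent prediction, which is what allows inferring the SCM of a new dataset without retraining.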

## Dependencies
We use [Poetry](https://python-poetry.org/) to manage the project dependencies, which are specified in the [pyproject.toml](pyproject.toml) file. To install Poetry, run:

```console
curl -sSL https://install.python-poetry.org | python3 -
```
To install the environment, run `poetry install` from the `cond_fip` project directory.
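
For example, from the root of the repository (the project lives under `research_experiments/cond_fip` in this commit):
```console
cd research_experiments/cond_fip
poetry install
```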


## Run experiments
In the [launchers](src/cond_fip/launchers) directory, we provide scripts to train both the Encoder and the Decoder.


### Amortized Learning of the Encoder
To train the Encoder on the synthetically generated datasets of [AVICI](https://arxiv.org/abs/2205.12934), run the following command:
```console
python -m cond_fip.launchers.train_encoder
```
The model as well as the config file will be saved in `src/cond_fip/outputs`.


### Amortized Learning of Cond-FiP
To train the Decoder on the synthetically generated datasets of [AVICI](https://arxiv.org/abs/2205.12934), run the following command:
```console
python -m cond_fip.launchers.train_cond_fip \
    --run_id <name_of_the_directory_containing_the_trained_encoder_model>
```
The model as well as the config file will be saved in `src/cond_fip/outputs`. This command assumes that an Encoder model has been trained and saved in a directory located at `src/cond_fip/outputs/<name_of_the_directory_containing_the_trained_encoder_model>`.
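
For instance, using the Encoder run name that appears in the provided training config:
```console
python -m cond_fip.launchers.train_cond_fip \
    --run_id amortized_encoder_training_2024-07-02_19-09-00
```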

### Test Cond-FiP on a new Dataset
To test a trained Cond-FiP model, we also provide a [launcher file](src/cond_fip/launchers/inference_cond_fip.py) that enables inferring SCMs with Cond-FiP on new datasets.

To use this file, one needs to provide the path to the data in the [config file](src/cond_fip/config/numpy_tensor_data_module.yaml) by replacing the value of `data_dir`.
The data must follow a specific format. Example datasets can be generated by running:

```console
python -m fip.data_generation.avici_data --func_type linear --graph_type er --noise_type gaussian --dist_case in --seed 1 --data_dim 5 --num_interventions 5
```
The data will be stored in `./data`.
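
For reference, the default `data_dir` value in the provided config is consistent with this example (ER graph, linear mechanisms, Gaussian noise, in-distribution, 5 nodes, seed 1); replace it with the path to your own data if needed:
```yaml
# excerpt from src/cond_fip/config/numpy_tensor_data_module.yaml
data_dir: "./data/er_linear_gaussian_in/total_nodes_5/seed_1/"
```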

To test a pre-trained Cond-FiP model on a specific dataset, one simply needs to run:
```console
python -m cond_fip.launchers.inference_cond_fip \
    --run_id <name_of_the_directory_containing_the_pre_trained_model> \
    --path_data <path_to_the_data>
```

This command assumes that a pre-trained Cond-FiP model has been saved in a directory located at `src/cond_fip/outputs/<name_of_the_directory_containing_the_pre_trained_model>`, and that the data has been saved at `<path_to_the_data>`.
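
For example, using the run name from the provided inference config together with the example dataset generated above:
```console
python -m cond_fip.launchers.inference_cond_fip \
    --run_id amortized_enc_dec_training_2024-09-09_13-51-00 \
    --path_data ./data/er_linear_gaussian_in/total_nodes_5/seed_1/
```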


6,270 changes: 6,270 additions & 0 deletions research_experiments/cond_fip/poetry.lock


53 changes: 53 additions & 0 deletions research_experiments/cond_fip/pyproject.toml
@@ -0,0 +1,53 @@
[tool.poetry]
name = "cond_fip"
version = "0.1.0"
description = "Zero-Shot Learning of Causal Models"
readme = "README.md"
authors = ["Meyer Scetbon", "Divyat Mahajan"]
packages = [
{ include = "cond_fip", from = "src" },
]
license = "MIT"

[tool.poetry.dependencies]
python = "~3.10"
fip = { path = "../fip"}

[tool.poetry.group.dev.dependencies]
black = {version="^22.6.0", extras=["jupyter"]}
isort = "^5.10.1"
jupyter = "^1.0.0"
jupytext = "^1.13.8"
mypy = "^1.0.0"
pre-commit = "^2.19.0"
pylint = "^2.14.4"
pytest = "^7.1.2"
pytest-cov = "^3.0.0"
seaborn = "^0.12.2"
types-python-dateutil = "^2.8.18"
types-requests = "^2.31.0.10"
ema-pytorch= "^0.6.0"


[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.black]
line-length = 120

[tool.isort]
line_length = 120
profile = "black"
py_version = 310
known_first_party = ["cond_fip"]

# Keep import sorts by code jupytext percent block (https://github.com/PyCQA/isort/issues/1338)
treat_comments_as_code = ["# %%"]

[tool.pytest.ini_options]
addopts = "--durations=200"
junit_family = "xunit1"



@@ -0,0 +1,11 @@
seed_everything: 2048

model:
class_path: cond_fip.tasks.cond_fip_inference.CondFiPInference
init_args:
enc_dec_model_path: ./src/cond_fip/outputs/amortized_enc_dec_training_2024-09-09_13-51-00/outputs/best_model.ckpt

trainer:
logger: MLFlowLogger
accelerator: gpu
devices: 1
@@ -0,0 +1,63 @@
seed_everything: 2048

model:
class_path: cond_fip.tasks.cond_fip_training.CondFiPTraining
init_args:
encoder_model_path: ./src/cond_fip/outputs/amortized_encoder_training_2024-07-02_19-09-00/outputs/best_model.ckpt

learning_rate: 1e-4
beta1: 0.9
beta2: 0.95
weight_decay: 1e-10

use_scheduler: true
linear_warmup_steps: 1000
scheduler_steps: 10_000

d_model: 256
num_heads: 8
num_layers: 4
d_ff: 512
dropout: 0.1
dim_key: 64
num_layers_dataset: 2

distributed: false
with_true_target: true
final_pair_only: true

with_ema: true
ema_beta: 0.99
ema_update_every: 10

trainer:
max_epochs: 7000
logger: MLFlowLogger
accelerator: gpu
check_val_every_n_epoch: 1
log_every_n_steps: 10
accumulate_grad_batches: 16
log_dir: "./src/cond_fip/logging_enc_dec/"
inference_mode: false
devices: 1
num_nodes: 1

early_stopping_callback:
monitor: "val_loss"
min_delta: 0.0001
patience: 500
verbose: False
mode: "min"

best_checkpoint_callback:
dirpath: "./src/cond_fip/logging_enc_dec/"
filename: "best_model"
save_top_k: 1
mode: "min"
monitor: "val_loss"
every_n_epochs: 1

last_checkpoint_callback:
save_last: true
filename: "last_model"
save_top_k: 0 # only the last checkpoint is saved
@@ -0,0 +1,59 @@
seed_everything: 2048

model:
class_path: cond_fip.tasks.encoder_training.EncoderTraining
init_args:

learning_rate: 1e-4
beta1: 0.9
beta2: 0.95
weight_decay: 5e-4

use_scheduler: true
linear_warmup_steps: 1000
scheduler_steps: 10_000

d_model: 256
num_heads: 8
num_layers: 4
d_ff: 512
dropout: 0.0
dim_key: 32
d_hidden_head: 1024

distributed: false

with_ema: true
ema_beta: 0.99
ema_update_every: 10

trainer:
max_epochs: 5000
logger: MLFlowLogger
accelerator: gpu
check_val_every_n_epoch: 1
log_every_n_steps: 10
log_dir: "./src/cond_fip/logging_enc/"
inference_mode: false
devices: 1
num_nodes: 1

early_stopping_callback:
monitor: "val_loss"
min_delta: 0.0001
patience: 500
verbose: False
mode: "min"

best_checkpoint_callback:
dirpath: "./src/cond_fip/logging_enc/"
filename: "best_model"
save_top_k: 1
mode: "min"
monitor: "val_loss"
every_n_epochs: 1

last_checkpoint_callback:
save_last: true
filename: "last_model"
save_top_k: 0 # only the last checkpoint is saved
12 changes: 12 additions & 0 deletions research_experiments/cond_fip/src/cond_fip/config/numpy_tensor_data_module.yaml
@@ -0,0 +1,12 @@
class_path: fip.data_modules.numpy_tensor_data_module.NumpyTensorDataModule
init_args:
data_dir : "./data/er_linear_gaussian_in/total_nodes_5/seed_1/"
train_batch_size: 400
test_batch_size: 400
standardize: false
with_true_graph: true
split_data_noise: true
dod: true
num_workers: 23
shuffle: false
num_interventions: 1
@@ -0,0 +1,46 @@
class_path: fip.data_modules.synthetic_data_module.SyntheticDataModule
init_args:
sem_samplers:
class_path: fip.data_generation.sem_factory.SemSamplerFactory
init_args:
node_nums: [20]
noises: ['gaussian']
graphs: ['er', 'sf_in', 'sf_out']
funcs: ['linear', 'rff']
config_gaussian:
low: 0.2
high: 2.0
config_er:
edges_per_node: [1,2,3]
config_sf:
edges_per_node: [1,2,3]
attach_power: [1.]
config_linear:
weight_low: 1.
weight_high: 3.
bias_low: -3.
bias_high: 3.
config_rff:
num_rf: 100
length_low: 7.
length_high: 10.
out_low: 10.
out_high: 20.
bias_low: -3.
bias_high: 3.
train_batch_size: 4
test_batch_size: 4
sample_dataset_size: 400
standardize: true
num_samples_used: 400
num_workers: 23
pin_memory: true
persistent_workers: true
prefetch_factor: 2
factor_epoch: 32
num_sems: 0
shuffle: true
num_interventions: 2
num_intervention_samples: 100
proportion_treatment: 0.
sample_counterfactuals: false
20 changes: 20 additions & 0 deletions research_experiments/cond_fip/src/cond_fip/entrypoint_test.py
@@ -0,0 +1,20 @@
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.cli import LightningCLI


def main():
cli = LightningCLI(
model_class=pl.LightningModule,
datamodule_class=pl.LightningDataModule,
trainer_class=Trainer,
subclass_mode_data=True,
subclass_mode_model=True,
save_config_kwargs={"overwrite": True},
run=False,
)
cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == "__main__":
main()