The Simformer

Gloeckler et al. (2024) introduced the Simformer for ‘all-in-one simulation-based inference’.

They use a probabilistic diffusion model with a transformer architecture that learns the full joint distribution of parameters and data, allowing fast, amortized Bayesian inference without specifying beforehand which parameters are of interest. This makes the Simformer approach particularly well suited to missing data: the attention mechanism allows the model to sample from an arbitrary conditional distribution, simply excluding any missing values from the conditioning set.
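Conceptually, conditioning is controlled by a boolean mask over the concatenated vector of parameters and data. The sketch below is purely illustrative; the mask layout and variable names are assumptions made for exposition, not the synference or sbi API.

import numpy as np

# Hypothetical joint vector: 3 parameters followed by 8 photometric bands.
n_params, n_bands = 3, 8

# Observed photometry with one missing band (NaN).
photometry = np.array([25.1, 24.8, 24.5, np.nan, 23.9, 23.7, 23.5, 23.4])

# Condition only on the bands that were actually observed; the parameters
# and the missing band stay free, so the model samples them jointly.
condition_mask = np.zeros(n_params + n_bands, dtype=bool)
condition_mask[n_params:] = ~np.isnan(photometry)
print(condition_mask)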

The Simformer is currently implemented in two ways. The first is a separate class called Simformer_Fitter, which requires the user to install a fork of the original Simformer repo; that fork in turn requires quite specific versions of CUDA, PyTorch and JAX to work.

The second uses the new Simformer implementation in the sbi package, which is currently only available in a pull request but should be merged into the main branch soon. We will use this implementation here, as it is much easier to install and use. There are examples of the original approach in the examples/simformer folder of the synference repo.

There are some limitations to the current sbi Simformer implementation. In particular, it does not currently support serialization (due to the use of lambda functions, which Python's pickle cannot handle), so trained models cannot be saved and loaded like other synference SBI models.
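The lambda issue is a general Python limitation rather than anything sbi-specific: the built-in pickle module serializes functions by reference to an importable name, which a lambda does not have.

import pickle

square = lambda x: x ** 2  # anonymous function with no importable name

try:
    pickle.dumps(square)
except Exception as err:
    # pickle cannot look the lambda up by name, so serialization fails
    print(type(err).__name__, err)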

[1]:
from synference import SBI_Fitter, test_data_dir

fitter = SBI_Fitter.init_from_hdf5(
    model_name="test", hdf5_path=f"{test_data_dir}/example_model_library.hdf5"
)

We will create our training array as normal using the SBI_Fitter class.

[2]:
fitter.create_feature_array();
2025-11-13 20:04:28,265 | synference | INFO     | ---------------------------------------------
2025-11-13 20:04:28,266 | synference | INFO     | Features: 8 features over 100 samples
2025-11-13 20:04:28,268 | synference | INFO     | ---------------------------------------------
2025-11-13 20:04:28,269 | synference | INFO     | Feature: Min - Max
2025-11-13 20:04:28,271 | synference | INFO     | ---------------------------------------------
2025-11-13 20:04:28,273 | synference | INFO     | JWST/NIRCam.F070W: 7.131974 - 42.758 AB
2025-11-13 20:04:28,274 | synference | INFO     | JWST/NIRCam.F090W: 7.108530 - 39.933 AB
2025-11-13 20:04:28,277 | synference | INFO     | JWST/NIRCam.F115W: 7.012560 - 38.354 AB
2025-11-13 20:04:28,279 | synference | INFO     | JWST/NIRCam.F150W: 6.969396 - 36.997 AB
2025-11-13 20:04:28,280 | synference | INFO     | JWST/NIRCam.F200W: 7.133157 - 35.470 AB
2025-11-13 20:04:28,281 | synference | INFO     | JWST/NIRCam.F277W: 7.670149 - 33.243 AB
2025-11-13 20:04:28,283 | synference | INFO     | JWST/NIRCam.F356W: 8.072730 - 32.490 AB
2025-11-13 20:04:28,285 | synference | INFO     | JWST/NIRCam.F444W: 8.353975 - 31.965 AB
2025-11-13 20:04:28,286 | synference | INFO     | ---------------------------------------------

There are a few parameters to be aware of:

  1. sde_type: The type of SDE to use. Options are ‘ve’ (variance exploding), ‘vp’ (variance preserving) or ‘subvp’ (sub-variance preserving); the standard forms of these SDEs are sketched after this list. This option has no effect for the flow-based Simformer.

  2. simformer_type: ‘score’ or ‘flow’, i.e. whether to use a score-based or flow-based Simformer.

  3. learning_rate: The learning rate to use for training.

  4. model_kwargs: A dictionary of additional keyword arguments to pass to the Simformer model. These can include:

    • n_layers: The number of transformer layers to use.

    • num_heads: The number of attention heads to use.

    • dim_val: The dimension of the value vectors in the attention mechanism.

    • dim_id: The dimension of the identity vectors in the attention mechanism.

    • mlp_ratio: The ratio of the hidden dimension to the input dimension in the MLP layers.

    • hidden_features: The number of hidden features to use in the MLP layers.

    • time_embedding_dim: The dimension of the time embedding.
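For reference, these are the standard forward SDEs from Song et al. (2021) that the sde_type option names; this is general diffusion-model background rather than synference-specific behaviour:

$$\mathrm{d}\mathbf{x} = \sqrt{\tfrac{\mathrm{d}[\sigma^2(t)]}{\mathrm{d}t}}\,\mathrm{d}\mathbf{w} \quad \text{(ve)}$$

$$\mathrm{d}\mathbf{x} = -\tfrac{1}{2}\beta(t)\,\mathbf{x}\,\mathrm{d}t + \sqrt{\beta(t)}\,\mathrm{d}\mathbf{w} \quad \text{(vp)}$$

$$\mathrm{d}\mathbf{x} = -\tfrac{1}{2}\beta(t)\,\mathbf{x}\,\mathrm{d}t + \sqrt{\beta(t)\bigl(1 - e^{-2\int_0^t \beta(s)\,\mathrm{d}s}\bigr)}\,\mathrm{d}\mathbf{w} \quad \text{(subvp)}$$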

As with all other synference models, we can also set the training_batch_size, validation_fraction, stop_after_epochs and clip_max_norm parameters.

[ ]:
fitter.run_single_simformer(
    name_append="simformer_test",
    sde_type="ve",
    simformer_type="score",
    learning_rate=1e-5,
    training_batch_size=64,
    model_kwargs={
        "hidden_features": 128,
        "n_layers": 6,
        "dim_val": 64,
        "dim_id": 64,
        "mlp_ratio": 4,
        "time_embedding_dim": 32,
        "num_heads": 4,
    },
    load_existing_model=False,
    validation_fraction=0.1,
    stop_after_epochs=30,
    plot=False,  # Currently the LtU-ILI plotting doesn't work with the simformer
)
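After training, we would expect to draw posterior samples much as with other synference models. The sketch below is a hedged example: it assumes run_single_simformer returns a posterior object exposing the standard sbi .sample() interface conditioned on an observation x, which may not match the exact synference return type; x_obs and the sample count are placeholder values.

import numpy as np

# Hypothetical usage: reload the model trained above and assume an
# sbi-style posterior object is returned (check the synference docs
# for the actual return type).
posterior = fitter.run_single_simformer(
    name_append="simformer_test",
    simformer_type="score",
    load_existing_model=True,
    plot=False,
)

x_obs = np.array([25.1, 24.8, 24.5, 24.2, 23.9, 23.7, 23.5, 23.4])  # 8 bands, AB mag
samples = posterior.sample((1000,), x=x_obs)  # standard sbi posterior API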