Online Training

A variation of SBI is Sequential Neural Posterior Estimation (SNPE), where the model is trained online, i.e., over multiple rounds. In each round, parameters are drawn from the current posterior estimate conditioned on the observation of interest, new simulations are generated from them, and the model is retrained on the enlarged simulation set. Because later rounds concentrate simulations in the region of parameter space consistent with the observation, this approach can be far more efficient with the simulation budget, especially when the prior is broad and the posterior is narrow.

The trade-off is that the model is no longer amortized: after training it is specialized to the chosen observation. A new model must therefore be trained for each new observation, which can be computationally expensive if many observations need to be analyzed.
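To make the multi-round loop concrete, here is a minimal sketch written directly against the sbi package (which synference uses as its backend). A toy two-parameter Gaussian simulator stands in for the galaxy simulator, and x_obs is a placeholder observation; the round and simulation counts mirror the run below.

import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform

# Toy setup: a 2-parameter uniform prior and a noisy identity simulator.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))

def simulator(theta):
    return theta + 0.1 * torch.randn_like(theta)

x_obs = torch.tensor([0.5, -0.3])  # the single observation we condition on

inference = SNPE(prior=prior)
proposal = prior  # round 1 simulates from the prior
for _ in range(4):
    theta = proposal.sample((1000,))
    x = simulator(theta)
    estimator = inference.append_simulations(theta, x, proposal=proposal).train()
    posterior = inference.build_posterior(estimator)
    # Later rounds simulate from the posterior conditioned on x_obs.
    proposal = posterior.set_default_x(x_obs)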

[1]:
from synference import SBI_Fitter, test_data_dir

# Load a pre-simulated model library shipped with the test data.
fitter = SBI_Fitter.init_from_hdf5(
    model_name="test", hdf5_path=f"{test_data_dir}/example_model_library.hdf5"
)

# Build the feature array (photometry) from the library.
fitter.create_feature_array();
2025-11-13 20:02:46,634 | synference | INFO     | ---------------------------------------------
2025-11-13 20:02:46,635 | synference | INFO     | Features: 8 features over 100 samples
2025-11-13 20:02:46,636 | synference | INFO     | ---------------------------------------------
2025-11-13 20:02:46,636 | synference | INFO     | Feature: Min - Max
2025-11-13 20:02:46,637 | synference | INFO     | ---------------------------------------------
2025-11-13 20:02:46,638 | synference | INFO     | JWST/NIRCam.F070W: 7.131974 - 42.758 AB
2025-11-13 20:02:46,638 | synference | INFO     | JWST/NIRCam.F090W: 7.108530 - 39.933 AB
2025-11-13 20:02:46,638 | synference | INFO     | JWST/NIRCam.F115W: 7.012560 - 38.354 AB
2025-11-13 20:02:46,639 | synference | INFO     | JWST/NIRCam.F150W: 6.969396 - 36.997 AB
2025-11-13 20:02:46,640 | synference | INFO     | JWST/NIRCam.F200W: 7.133157 - 35.470 AB
2025-11-13 20:02:46,640 | synference | INFO     | JWST/NIRCam.F277W: 7.670149 - 33.243 AB
2025-11-13 20:02:46,641 | synference | INFO     | JWST/NIRCam.F356W: 8.072730 - 32.490 AB
2025-11-13 20:02:46,641 | synference | INFO     | JWST/NIRCam.F444W: 8.353975 - 31.965 AB
2025-11-13 20:02:46,641 | synference | INFO     | ---------------------------------------------

Now we will recreate the simulator from the library data stored in the HDF5 file.

[2]:
fitter.recreate_simulator_from_library(
    override_library_path=f"{test_data_dir}/example_model_library.hdf5",
    override_grid_path="test_grid.hdf5",
);
2025-11-13 20:02:46,653 | synference | INFO     | Overriding internal library name from provided file path.
2025-11-13 20:02:46,988 | synference | INFO     | Simulator recreated from library at /home/runner/.local/share/Synthesizer/data/synference/example_model_library.hdf5.

Now we can choose an observation to condition our multiple rounds of online training on. Here, we simply pick one of the simulations from our library to act as the observation.

[3]:
index = 20
sample = fitter.feature_array[index]
true_params = fitter.fitted_parameter_array[index]

sample
[3]:
array([28.922905, 27.99003 , 27.05276 , 26.493101, 25.965355, 25.390919,
       24.99644 , 24.839458], dtype=float32)

Now we can run our online SBI model. To do this, we set learning_type to ‘online’, specify the number of rounds with ‘num_online_rounds’, and provide our chosen observation with ‘online_training_xobs’. We also set the number of simulations per round with ‘num_simulations’. The engine is set to ‘SNPE’ to use Sequential Neural Posterior Estimation; the SNLE and SNRE engines are also available for online training.

[4]:
fitter.run_single_sbi(
    online_training_xobs=sample,
    learning_type="online",
    engine="SNPE",
    num_simulations=1000,
    num_online_rounds=4,
    override_prior_ranges={"peak_age": (10, 1000)},
    evaluate_model=False,
    plot=False,
);
2025-11-13 20:02:47,005 | synference | INFO     | ---------------------------------------------
2025-11-13 20:02:47,006 | synference | INFO     | Prior ranges:
2025-11-13 20:02:47,007 | synference | INFO     | ---------------------------------------------
2025-11-13 20:02:47,007 | synference | INFO     | redshift: 0.00 - 4.98 [dimensionless]
2025-11-13 20:02:47,008 | synference | INFO     | log_mass: 8.01 - 11.99 [log10_Msun]
2025-11-13 20:02:47,009 | synference | INFO     | tau_v: 0.01 - 3.00 [mag]
2025-11-13 20:02:47,010 | synference | INFO     | tau: 0.11 - 1.98 [dimensionless]
2025-11-13 20:02:47,010 | synference | INFO     | peak_age: 10.00 - 1000.00 [Myr]
2025-11-13 20:02:47,011 | synference | INFO     | log10metallicity: -3.98 - -1.41 [log10(Zmet)]
2025-11-13 20:02:47,012 | synference | INFO     | ---------------------------------------------
2025-11-13 20:02:47,013 | synference | INFO     | Processing prior...
2025-11-13 20:02:47,016 | synference | INFO     | Using provided xobs for online training: (8,)
Wrapping GalaxySimulator for SBI...
2025-11-13 20:02:47,017 | synference | INFO     | Creating mdn network with SNPE engine and sbi backend.
2025-11-13 20:02:47,018 | synference | INFO     |      hidden_features: 50
2025-11-13 20:02:47,018 | synference | INFO     |      num_components: 4
2025-11-13 20:02:47,019 | synference | INFO     | Training on cpu.
INFO:root:MODEL INFERENCE CLASS: SNPE
INFO:root:The first round of inference will simulate from the given proposal or prior.
Running 1000 simulations.: 100%|██████████| 1000/1000 [00:42<00:00, 23.72it/s]
INFO:root:Running round 1 / 4
INFO:root:Training model 1 / 1.
 Neural network successfully converged after 224 epochs.
1106it [00:00, 213371.06it/s]
Running 1000 simulations.: 100%|██████████| 1000/1000 [00:42<00:00, 23.65it/s]
INFO:root:Running round 2 / 4
INFO:root:Training model 1 / 1.
Using SNPE-C with atomic loss
 Neural network successfully converged after 41 epochs.
1495it [00:00, 262165.92it/s]
Running 1000 simulations.: 100%|██████████| 1000/1000 [00:37<00:00, 26.74it/s]
INFO:root:Running round 3 / 4
INFO:root:Training model 1 / 1.
Using SNPE-C with atomic loss
 Neural network successfully converged after 15 epochs.
1432it [00:00, 247028.19it/s]
Running 1000 simulations.: 100%|██████████| 1000/1000 [00:33<00:00, 29.89it/s]
INFO:root:Running round 4 / 4
INFO:root:Training model 1 / 1.
Using SNPE-C with atomic loss
 Training neural network. Epochs trained: 14
INFO:root:It took 174.97990894317627 seconds to train models.
INFO:root:Saving model to /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/models/test
 Neural network successfully converged after 15 epochs.
2025-11-13 20:06:24,182 | synference | INFO     | Time to train model(s): 0:03:37.176911
2025-11-13 20:06:24,194 | synference | INFO     | Saved model parameters to /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/models/test/test_20251113_200247_params.pkl.

Now we can see how the model performs on the specific observation it was conditioned on.

[5]:
samples = fitter.sample_posterior(X_test=sample)

fitter.plot_posterior(
    X=sample,
    y=samples,
    num_samples=1000,
)
Sampling from posterior: 100%|██████████| 1/1 [00:00<00:00, 130.08it/s]
2025-11-13 20:06:24,216 | synference | INFO     | [  4.25622272  10.32305241   1.77526617   0.91543102 344.62731934
  -3.95824599]

1419it [00:00, 158286.15it/s]
INFO:root:Saving single posterior plot to /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/models/test/plots/test_0_plot_single_posterior.jpg...
[5]:
<seaborn.axisgrid.PairGrid at 0x7fe288708b20>
[Image: corner plot of the posterior for the conditioned observation]
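Since the observation was drawn from our own library, a quick sanity check is to compare posterior summaries against the true parameters we stored earlier. A minimal sketch, assuming samples has shape (num_samples, num_parameters) in the same order as true_params:

import numpy as np

# Posterior medians and 16th-84th percentile intervals vs. the true inputs.
arr = np.asarray(samples)
lo, med, hi = np.percentile(arr, [16, 50, 84], axis=0)
for i, truth in enumerate(np.asarray(true_params)):
    print(f"param {i}: {med[i]:.3f} (+{hi[i] - med[i]:.3f} / -{med[i] - lo[i]:.3f}), true {truth:.3f}")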