Online Training

A variation of SBI is Sequential Neural Posterior Estimation (SNPE), in which the model is trained online, i.e., over multiple rounds. The first round simulates from the prior (or a supplied proposal). In each later round, simulations are drawn from the current posterior estimate for the chosen observation, and the model is updated with these new simulations. This concentrates the simulation budget on the region of parameter space consistent with the observation, so fewer simulations may be needed overall, especially when the prior is broad and the posterior is narrow.
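The round structure described above can be illustrated with a small toy, independent of synference. This is only a sketch under loud assumptions: it is not SNPE. It replaces the neural density estimator with a Gaussian fitted to the simulations closest to the observation, and it applies no correction for sampling from the proposal instead of the prior, so the widths it reports are not calibrated posterior widths. The simulator, the prior range, and the function names are all hypothetical.

```python
import numpy as np


def simulator(theta, rng):
    """Toy simulator (hypothetical): observable = parameter + Gaussian noise."""
    return theta + rng.normal(0.0, 0.5, size=theta.shape)


def run_rounds(x_obs, num_rounds=4, num_simulations=1000, seed=0):
    """Run several proposal-refinement rounds and return the final proposal
    mean together with the proposal width after each round."""
    rng = np.random.default_rng(seed)
    mean, std = None, None
    widths = []
    for round_idx in range(num_rounds):
        if round_idx == 0:
            # Round 1: draw parameters from a broad uniform prior.
            theta = rng.uniform(-10.0, 10.0, size=num_simulations)
        else:
            # Later rounds: draw from the current (Gaussian) estimate.
            theta = rng.normal(mean, std, size=num_simulations)
        x = simulator(theta, rng)
        # Keep the 10% of simulations whose output is closest to x_obs and
        # refit the Gaussian proposal to their parameters.
        keep = np.argsort(np.abs(x - x_obs))[: num_simulations // 10]
        mean, std = theta[keep].mean(), theta[keep].std()
        widths.append(std)
    return mean, widths


mean, widths = run_rounds(x_obs=2.0)
print(f"final proposal mean: {mean:.2f}")
print("proposal width per round:", [f"{w:.2f}" for w in widths])
```

Each round spends its simulations nearer to the observation than the one before, which is the source of the efficiency gain. A real sequential method additionally has to account for the fact that later simulations no longer come from the prior.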

However, the model is no longer amortized, as it is specialized to the specific observation after training. This means that a new model must be trained for each new observation, which can be computationally expensive if many observations need to be analyzed.
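As a rough illustration of this trade-off, the sketch below counts simulations only and ignores training time. The function names and the amortized budget of 100,000 simulations are illustrative assumptions, not values from synference. The 4 rounds of 1000 simulations match the settings used later in this tutorial.

```python
def sequential_budget(num_observations, num_online_rounds, num_simulations):
    """Total simulations when a separate online model is trained per observation."""
    return num_observations * num_online_rounds * num_simulations


def break_even_observations(amortized_budget, num_online_rounds, num_simulations):
    """Smallest number of observations at which the sequential total
    reaches the (assumed) one-off amortized simulation budget."""
    per_observation = num_online_rounds * num_simulations
    return -(-amortized_budget // per_observation)  # ceiling division


# One observation with the settings used in this tutorial.
print(sequential_budget(1, num_online_rounds=4, num_simulations=1000))  # 4000

# Assumed amortized budget of 100,000 simulations (hypothetical).
print(break_even_observations(100_000, num_online_rounds=4, num_simulations=1000))  # 25
```

Under these assumed numbers, online training uses fewer simulations for fewer than 25 observations. Beyond that, a single amortized model costs less.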

[1]:
from synference import SBI_Fitter, test_data_dir

fitter = SBI_Fitter.init_from_hdf5(
    model_name="test", hdf5_path=f"{test_data_dir}/example_model_library.hdf5"
)

fitter.create_feature_array();
2026-05-01 18:06:21,920 | synference | INFO     | ---------------------------------------------
2026-05-01 18:06:21,921 | synference | INFO     | Features: 8 features over 100 samples
2026-05-01 18:06:21,922 | synference | INFO     | ---------------------------------------------
2026-05-01 18:06:21,922 | synference | INFO     | Feature: Min - Max
2026-05-01 18:06:21,923 | synference | INFO     | ---------------------------------------------
2026-05-01 18:06:21,923 | synference | INFO     | JWST/NIRCam.F070W: 7.131974 - 42.758 AB
2026-05-01 18:06:21,924 | synference | INFO     | JWST/NIRCam.F090W: 7.108530 - 39.933 AB
2026-05-01 18:06:21,924 | synference | INFO     | JWST/NIRCam.F115W: 7.012560 - 38.354 AB
2026-05-01 18:06:21,925 | synference | INFO     | JWST/NIRCam.F150W: 6.969396 - 36.997 AB
2026-05-01 18:06:21,926 | synference | INFO     | JWST/NIRCam.F200W: 7.133157 - 35.470 AB
2026-05-01 18:06:21,926 | synference | INFO     | JWST/NIRCam.F277W: 7.670149 - 33.243 AB
2026-05-01 18:06:21,927 | synference | INFO     | JWST/NIRCam.F356W: 8.072730 - 32.490 AB
2026-05-01 18:06:21,927 | synference | INFO     | JWST/NIRCam.F444W: 8.353975 - 31.965 AB
2026-05-01 18:06:21,928 | synference | INFO     | ---------------------------------------------

Now we will recreate the simulator from the library data stored in the HDF5 file.

[2]:
fitter.recreate_simulator_from_library(
    override_library_path=f"{test_data_dir}/example_model_library.hdf5",
    override_grid_path="test_grid.hdf5",
);
2026-05-01 18:06:21,939 | synference | INFO     | Overriding internal library name from provided file path.
2026-05-01 18:06:22,180 | synference | INFO     | Simulator recreated from library at /home/runner/.local/share/Synthesizer/data/synference/example_model_library.hdf5.

Now we can choose an observation for our multiple rounds of online training. Here, we take one of the simulations from our library (index 20) as a mock observation, so that its true parameters are known.

[3]:
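index = 20  # fixed library index used as the mock observation
sample = fitter.feature_array[index]  # photometry in the 8 NIRCam bands (AB mag)
true_params = fitter.fitted_parameter_array[index]  # parameters of that simulation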

sample
[3]:
array([28.922905, 27.99003 , 27.05276 , 26.493101, 25.965355, 25.390919,
       24.99644 , 24.839458], dtype=float32)

Now we can run our online SBI model. To do this, we set learning_type to "online", specify the number of online rounds with num_online_rounds, and provide our chosen observation with online_training_xobs. The number of simulations per round is set with num_simulations, and override_prior_ranges is used here to set the prior range of peak_age to 10 - 1000 Myr. The engine is set to "SNPE" to use Sequential Neural Posterior Estimation, but SNLE and SNRE are also available for online training.

[4]:
fitter.run_single_sbi(
    online_training_xobs=sample,
    learning_type="online",
    engine="SNPE",
    num_simulations=1000,
    num_online_rounds=4,
    override_prior_ranges={"peak_age": (10, 1000)},
    evaluate_model=False,
    plot=False,
);
2026-05-01 18:06:22,197 | synference | INFO     | ---------------------------------------------
2026-05-01 18:06:22,198 | synference | INFO     | Prior ranges:
2026-05-01 18:06:22,199 | synference | INFO     | ---------------------------------------------
2026-05-01 18:06:22,200 | synference | INFO     | redshift: 0.00 - 4.98 [dimensionless]
2026-05-01 18:06:22,200 | synference | INFO     | log_mass: 8.01 - 11.99 [log10_Msun]
2026-05-01 18:06:22,201 | synference | INFO     | tau_v: 0.01 - 3.00 [mag]
2026-05-01 18:06:22,201 | synference | INFO     | tau: 0.11 - 1.98 [dimensionless]
2026-05-01 18:06:22,202 | synference | INFO     | peak_age: 10.00 - 1000.00 [Myr]
2026-05-01 18:06:22,202 | synference | INFO     | log10metallicity: -3.98 - -1.41 [log10(Zmet)]
2026-05-01 18:06:22,203 | synference | INFO     | ---------------------------------------------
2026-05-01 18:06:22,204 | synference | INFO     | Processing prior...
2026-05-01 18:06:22,208 | synference | INFO     | Using provided xobs for online training: (8,)
Wrapping GalaxySimulator for SBI...
2026-05-01 18:06:22,209 | synference | INFO     | Creating mdn network with SNPE engine and sbi backend.
2026-05-01 18:06:22,210 | synference | INFO     |      hidden_features: 50
2026-05-01 18:06:22,211 | synference | INFO     |      num_components: 4
2026-05-01 18:06:22,212 | synference | INFO     | Training on cpu.
INFO:root:MODEL INFERENCE CLASS: SNPE
INFO:root:The first round of inference will simulate from the given proposal or prior.
Running 1000 simulations.:   0%|          | 0/1000 [00:00<?, ?it/s]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/synference/sbi_runner.py:4936, in SBI_Fitter.run_single_sbi(self, train_test_fraction, random_seed, backend, engine, train_indices, test_indices, n_nets, model_type, hidden_features, num_components, num_transforms, training_batch_size, learning_rate, validation_fraction, stop_after_epochs, clip_max_norm, additional_model_args, save_model, verbose, prior_method, out_dir, plot, name_append, feature_scalar, target_scalar, set_self, learning_type, simulator, num_simulations, num_online_rounds, initial_training_from_library, override_prior_ranges, online_training_xobs, load_existing_model, use_existing_indices, evaluate_model, save_method, num_posterior_draws_per_sample, embedding_net, custom_config_yaml, sql_db_path)
   4935 # Train with normal output
-> 4936 posteriors, stats = trainer(loader, **extra_args)
   4938 if posteriors is None:

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/ili/inference/runner_sbi.py:403, in SBIRunnerSequential.__call__(self, loader, seed)
    400 logging.info(
    401     "The first round of inference will simulate from the given "
    402     "proposal or prior.")
--> 403 theta, x = loader.simulate(self.proposal)
    404 x = torch.Tensor(x).to(self.device)

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/ili/dataloaders/loaders.py:345, in SBISimulator.simulate(self, proposal)
    344 theta = proposal.sample((self.num_simulations,)).cpu()
--> 345 x = simulate_in_batches(self.simulator, theta)
    347 # Get device returns -1 for cpu, integers for CUDA tensors

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/sbi/simulators/simutils.py:88, in simulate_in_batches(simulator, theta, sim_batch_size, num_workers, seed, show_progress_bars)
     87 for batch in batches:
---> 88     simulation_outputs.append(simulator(batch))
     89     pbar.update(sim_batch_size)

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/synference/sbi_runner.py:4782, in SBI_Fitter.run_single_sbi.<locals>.run_simulator(params, return_type)
   4778     params = {
   4779         self.fitted_parameter_names[i]: params[i]
   4780         for i in range(len(self.fitted_parameter_names))
   4781     }
-> 4782 phot = simulator(params)
   4783 if return_type == "tensor":  #

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/synference/library.py:6001, in GalaxySimulator.__call__(self, params)
   6000 """Call the simulator with parameters to get photometry."""
-> 6001 return self.simulate(params)

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/synference/library.py:5788, in GalaxySimulator.simulate(self, params)
   5787 if "photo_fnu" in self.output_type:
-> 5788     fluxes = convert(outputs["photo_fnu"])
   5789     outputs["photo_fnu"] = copy.deepcopy(fluxes)

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/synference/library.py:5785, in GalaxySimulator.simulate.<locals>.convert(f)
   5784 def convert(f):
-> 5785     return -2.5 * np.log10(f.to(Jy).value) + 8.9

AttributeError: 'NoneType' object has no attribute 'to'

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[4], line 1
----> 1 fitter.run_single_sbi(
      2     online_training_xobs=sample,
      3     learning_type="online",
      4     engine="SNPE",
      5     num_simulations=1000,
      6     num_online_rounds=4,
      7     override_prior_ranges={"peak_age": (10, 1000)},
      8     evaluate_model=False,
      9     plot=False,
     10 );

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/synference/sbi_runner.py:4941, in SBI_Fitter.run_single_sbi(self, train_test_fraction, random_seed, backend, engine, train_indices, test_indices, n_nets, model_type, hidden_features, num_components, num_transforms, training_batch_size, learning_rate, validation_fraction, stop_after_epochs, clip_max_norm, additional_model_args, save_model, verbose, prior_method, out_dir, plot, name_append, feature_scalar, target_scalar, set_self, learning_type, simulator, num_simulations, num_online_rounds, initial_training_from_library, override_prior_ranges, online_training_xobs, load_existing_model, use_existing_indices, evaluate_model, save_method, num_posterior_draws_per_sample, embedding_net, custom_config_yaml, sql_db_path)
   4939             logger.warning("Exiting training as posteriors are None.")
   4940 except Exception as e:
-> 4941     raise RuntimeError(f"Error during SBI training: {str(e)}")
   4943 if set_self:
   4944     self._prior = prior

RuntimeError: Error during SBI training: 'NoneType' object has no attribute 'to'

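Now we can see how the model performs on the observation it was conditioned on. This step requires that the training cell above finished without error. In the recorded output it did not, so the sampling call below returns NaNs and the plotting call raises a ValueError.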

[5]:
samples = fitter.sample_posterior(X_test=sample)

fitter.plot_posterior(
    X=sample,
    y=samples,
    num_samples=1000,
)
Sampling from posterior:   0%|          | 0/1 [00:00<?, ?it/s]
2026-05-01 18:06:24,463 | synference | ERROR    | Error occurred while sampling for sample 0: 'NoneType' object has no attribute 'sample'
Sampling from posterior: 100%|██████████| 1/1 [00:00<00:00, 1436.90it/s]
2026-05-01 18:06:24,465 | synference | INFO     | [nan nan nan nan nan nan]

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 3
      1 samples = fitter.sample_posterior(X_test=sample)
----> 3 fitter.plot_posterior(
      4     X=sample,
      5     y=samples,
      6     num_samples=1000,
      7 )

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/synference/sbi_runner.py:7047, in SBI_Fitter.plot_posterior(self, ind, X, y, seed, num_samples, sample_method, sample_kwargs, plots_dir, plot_kwargs, posteriors, **kwargs)
   7044 # give xobs a batch dimension
   7045 x_obs = torch.tensor(np.array([X[ind]]), dtype=torch.float32, device=self.device)
-> 7047 fig = metric(
   7048     posterior=posteriors,
   7049     x_obs=x_obs,
   7050     theta_fid=draw,
   7051     plot_kws=plot_kwargs,
   7052     signature=f"{self.name}_{ind}_",
   7053     **kwargs,
   7054 )
   7055 if self.observation_type == "photometry":
   7056     text = "\n".join(
   7057         [
   7058             f"{self.feature_names[i]}: {X[ind][i]:.3f} {self.feature_units[i]}"
   7059             for i in range(len(X[ind]))
   7060         ]
   7061     )

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/ili/validation/metrics.py:191, in PlotSinglePosterior.__call__(self, posterior, x, theta, x_obs, theta_fid, signature, lower, upper, plot_kws, grid, name, **grid_kws)
    188     theta_fid = theta[ind]
    190 # sample from the posterior
--> 191 sampler = self._build_sampler(posterior)
    192 samples = sampler.sample(self.num_samples, x=x_obs, progress=True)
    193 ndim = samples.shape[-1]

File /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/site-packages/ili/validation/metrics.py:115, in _SampleBasedMetric._build_sampler(self, posterior)
    113         return DirectSampler(posterior)
    114     else:
--> 115         raise ValueError(
    116             'Direct sampling is only available for DirectPosteriors')
    117 elif self.sample_method == 'vi':
    118     return VISampler(posterior, **self.sample_params)

ValueError: Direct sampling is only available for DirectPosteriors