Basic SBI Model Training¶
In this tutorial, we will walk through the process of training a simulation-based inference (SBI) model using the synference package. We will assume we already have a library of simulations and corresponding parameters.
First, let’s consider the training process in general. The main steps involved in training an SBI model are:
Prepare the Simulation Data: Gather a set of simulations and their corresponding parameters.
Choose a Model Architecture: Select an appropriate neural network architecture for the SBI model.
Define the Training Procedure: Set up the training loop, loss function, and optimization algorithm.
Train the Model: Run the training process and monitor performance.
Evaluate the Model: Assess the trained model’s performance on a validation set.
Now let’s look at how to implement these steps using synference.
[1]:
from synference import SBI_Fitter, test_data_dir
If you have followed the library generation tutorials, you should have an HDF5 file called test_model_library.hdf5 in your ‘libraries/’ directory. If you don’t have this file, please refer to the Library Generation tutorial. Here we instead load example_model_library.hdf5 from synference’s test_data_dir, so that the example can be run without generating a library first.
We can use this file directly to create an SBI_Fitter instance. SBI_Fitter is the class that handles training and evaluating SBI models in synference.
[2]:
fitter = SBI_Fitter.init_from_hdf5(
model_name="test", hdf5_path=f"{test_data_dir}/example_model_library.hdf5"
)
Feature and Parameter Arrays¶
Now this fitter has loaded the generated observations and parameters from the HDF5 file. Note that the data is not yet normalized or set up with the correct features for training. We will handle that in the next steps.
We can see the names of the observations.
[3]:
print(fitter.raw_observation_names)
['JWST/NIRCam.F070W' 'JWST/NIRCam.F090W' 'JWST/NIRCam.F115W'
'JWST/NIRCam.F150W' 'JWST/NIRCam.F200W' 'JWST/NIRCam.F277W'
'JWST/NIRCam.F356W' 'JWST/NIRCam.F444W']
The names of the parameters:
[4]:
print(fitter.parameter_names)
['redshift' 'log_mass' 'tau_v' 'tau' 'peak_age' 'log10metallicity']
and their associated units:
[5]:
print(fitter.parameter_units)
['dimensionless' 'log10_Msun' 'mag' 'dimensionless' 'Myr' 'log10(Zmet)']
The parameter values themselves are stored in the parameter_array attribute; here we print just the first 10 rows:
[6]:
print(fitter.parameter_array[:10])
[[ 2.05416274e+00 1.04035606e+01 1.22474718e+00 9.36861515e-01
1.69158648e+03 -3.13879132e+00]
[ 2.00016856e+00 1.04899187e+01 1.28473628e+00 1.60899329e+00
3.02503141e+03 -3.87390995e+00]
[ 2.18753195e+00 1.15228844e+01 2.18956232e+00 1.20334947e+00
1.86831172e+03 -3.46571136e+00]
[ 2.11274648e+00 9.81967258e+00 5.74459910e-01 8.06611896e-01
2.01560364e+03 -1.89818966e+00]
[ 2.44652700e+00 9.17687321e+00 1.38670945e+00 1.77724314e+00
3.07231417e+02 -2.88550878e+00]
[ 6.84367716e-01 1.14148254e+01 2.61994505e+00 3.89508367e-01
3.59630292e+03 -1.86997616e+00]
[ 1.49105906e-01 1.00042429e+01 1.95352912e+00 1.27394879e+00
1.13156527e+04 -2.40571117e+00]
[ 1.71895242e+00 1.17945662e+01 1.94888246e+00 2.07960412e-01
1.57720040e+03 -3.86293936e+00]
[ 2.84610391e+00 9.41966534e+00 2.85131145e+00 8.91395509e-01
1.19775692e+03 -3.77944398e+00]
[ 2.45635724e+00 1.03770952e+01 1.80910027e+00 2.73230791e-01
2.36855774e+03 -1.71073520e+00]]
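The same applies to the observations. Note that here we slice the columns of raw_observation_grid, printing the first 10 entries for each of the 8 bands: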
[7]:
print(fitter.raw_observation_grid[:, :10])
[[1.55361035e+02 6.62787615e+01 2.93631144e+02 6.01755184e+01
6.08683670e+00 2.52897265e+03 6.31139334e+03 5.55954909e+02
3.01451722e-01 7.46139935e-02]
[2.40530441e+02 1.13258320e+02 5.88940061e+02 7.87135295e+01
9.31069105e+00 6.79137932e+03 1.07562514e+04 1.35413814e+03
7.49976727e-01 5.30070034e-01]
[5.07742016e+02 2.91263414e+02 1.41441540e+03 1.44744857e+02
1.48766062e+01 1.40988387e+04 1.63680472e+04 5.82525117e+03
1.77367012e+00 5.52317547e+00]
[1.06025886e+03 5.80749330e+02 3.79225364e+03 3.03962040e+02
3.73309422e+01 2.65854161e+04 2.33425604e+04 1.08011228e+04
5.90461870e+00 3.40313785e+01]
[1.49439723e+03 8.83105895e+02 6.47315587e+03 4.27806729e+02
5.66841099e+01 4.68630971e+04 3.13794258e+04 1.76892096e+04
1.37485025e+01 1.01136398e+02]
[1.90476605e+03 1.16731581e+03 9.31606212e+03 6.26340520e+02
8.70344677e+01 8.13157865e+04 2.26920392e+04 2.70938604e+04
2.74803669e+01 2.37788360e+02]
[2.36584297e+03 1.41584377e+03 1.28469681e+04 8.12406290e+02
1.16919381e+02 7.99288041e+04 1.74361191e+04 3.40880141e+04
3.78438853e+01 3.99970276e+02]
[2.70394000e+03 1.58965237e+03 1.54602400e+04 9.96181433e+02
1.33611502e+02 5.93842217e+04 1.33736683e+04 4.01845570e+04
5.17631635e+01 5.66103013e+02]]
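The printed grid has 8 rows of 10 values, so raw_observation_grid appears to be laid out with one row per band and one column per sample, the transpose of the sample-by-parameter layout of parameter_array. The minimal numpy sketch below uses a random stand-in array (not real library data) to show the column slicing and how to obtain one row per sample; that column i of the real grid describes the same model as row i of parameter_array is an assumption.

```python
import numpy as np

# Random stand-in for fitter.raw_observation_grid (NOT real library data):
# 8 bands (rows) x 100 samples (columns), as suggested by the printed output.
rng = np.random.default_rng(0)
n_bands, n_samples = 8, 100
raw_grid = rng.uniform(1e-2, 1e4, size=(n_bands, n_samples))

# Slicing the columns, as in the cell above, keeps all bands for the first 10 samples.
first_ten = raw_grid[:, :10]
print(first_ten.shape)  # (8, 10)

# Transposing gives one row per sample, the same orientation as parameter_array.
# ASSUMPTION: column i of the grid describes the same model as row i of parameter_array.
per_sample = raw_grid.T
print(per_sample.shape)  # (100, 8)
```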
The first step is to turn this raw library of photometric observations into a set of features that can be used for training. This is done with the fitter.create_feature_array method.
This method handles the following tasks:
Normalizing the observations (e.g., converting magnitudes to fluxes, normalizing by a reference band, etc.)
Creating features from the observations (e.g., colors, ratios, etc.)
Removing photometric bands in the library from the feature array that are not present in the observations.
Handling missing data (e.g., setting features to NaN if any of the required bands are missing)
Adding additional features (e.g., redshift) from the parameter array to the feature array.
Adding realistic noise to the features based on a provided noise model (see the Noise Models tutorial for more details).
Adding photometric uncertainties to the feature array.
However, the default configuration of this method does not perform all of these steps: all photometric bands are kept, no additional features are added, and no noise is added. The default normalization only converts the raw photometry to AB magnitudes.
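As a sketch of that default conversion, the standard AB definition m = -2.5 log10(f_ν / 3631 Jy) can be applied with numpy. The tutorial does not state the unit of raw_observation_grid, so treating the raw values as nJy here is an assumption, and flux_njy_to_ab_mag is a hypothetical helper rather than synference’s own code.

```python
import numpy as np

# AB zero point, 3631 Jy, expressed in nJy (ASSUMED unit of the raw fluxes).
AB_ZERO_POINT_NJY = 3631.0e9

def flux_njy_to_ab_mag(flux_njy):
    """Convert flux densities in nJy to AB magnitudes: m = -2.5 log10(f / 3631 Jy)."""
    flux_njy = np.asarray(flux_njy, dtype=float)
    return -2.5 * np.log10(flux_njy / AB_ZERO_POINT_NJY)

# First three values of the first row of the printed grid, treated as nJy for illustration only.
fluxes = np.array([155.361035, 66.2787615, 293.631144])
print(np.round(flux_njy_to_ab_mag(fluxes), 2))
```

A factor of 100 in flux corresponds to exactly 5 magnitudes, and brighter sources have smaller magnitudes.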
Calling the method prints information about the features it creates:
[8]:
fitter.create_feature_array();
2026-05-01 18:06:43,587 | synference | INFO | ---------------------------------------------
2026-05-01 18:06:43,588 | synference | INFO | Features: 8 features over 100 samples
2026-05-01 18:06:43,589 | synference | INFO | ---------------------------------------------
2026-05-01 18:06:43,590 | synference | INFO | Feature: Min - Max
2026-05-01 18:06:43,591 | synference | INFO | ---------------------------------------------
2026-05-01 18:06:43,592 | synference | INFO | JWST/NIRCam.F070W: 7.131974 - 42.758 AB
2026-05-01 18:06:43,592 | synference | INFO | JWST/NIRCam.F090W: 7.108530 - 39.933 AB
2026-05-01 18:06:43,593 | synference | INFO | JWST/NIRCam.F115W: 7.012560 - 38.354 AB
2026-05-01 18:06:43,593 | synference | INFO | JWST/NIRCam.F150W: 6.969396 - 36.997 AB
2026-05-01 18:06:43,594 | synference | INFO | JWST/NIRCam.F200W: 7.133157 - 35.470 AB
2026-05-01 18:06:43,595 | synference | INFO | JWST/NIRCam.F277W: 7.670149 - 33.243 AB
2026-05-01 18:06:43,595 | synference | INFO | JWST/NIRCam.F356W: 8.072730 - 32.490 AB
2026-05-01 18:06:43,596 | synference | INFO | JWST/NIRCam.F444W: 8.353975 - 31.965 AB
2026-05-01 18:06:43,596 | synference | INFO | ---------------------------------------------
We will proceed with the default configuration for now. More advanced configurations will be covered in later tutorials. Using different normalizations/units or adding additional features can have a significant impact on the performance of the trained SBI model.
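Adding extra features is one of the alternatives mentioned above. As a hypothetical illustration (make_color_features is not a synference function, and defining colors as differences between neighboring bands is an assumption), a feature array of colors with redshift appended could be assembled from AB magnitudes like this:

```python
import numpy as np

def make_color_features(mags, redshift=None):
    """Build neighboring-band colors, optionally appending redshift as a feature.

    mags: (num_samples, num_bands) array of AB magnitudes.
    redshift: optional (num_samples,) array.
    Returns (num_samples, num_bands - 1), plus one column if redshift is given.
    """
    mags = np.asarray(mags, dtype=float)
    # Color between each pair of neighboring bands, e.g. F070W - F090W.
    colors = mags[:, :-1] - mags[:, 1:]
    if redshift is None:
        return colors
    return np.column_stack([colors, np.asarray(redshift, dtype=float)])

# Made-up magnitudes for two samples observed in three bands.
mags = np.array([[25.0, 24.5, 24.0],
                 [27.0, 26.0, 25.5]])
features = make_color_features(mags, redshift=[1.2, 3.4])
print(features)
```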
Before we do any fitting, we can inspect the feature and parameter arrays to see the distribution of the data.
First, we can look at the feature array to see the distribution of the photometry given our model and feature array configuration. The figure below shows a histogram of each feature in the feature array.
[9]:
fitter.plot_histogram_feature_array(bins=20);
2026-05-01 18:06:44,235 | synference | INFO | saving /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots//feature_histogram.png
Second, we can look at the parameter array to see the distribution of the parameters given our model and parameter configuration. The figure below shows a histogram of each parameter in the parameter array. We can see that the parameters are uniformly distributed, as expected from our library generation configuration.
[10]:
fitter.plot_histogram_parameter_array();
2026-05-01 18:06:46,131 | synference | INFO | saving /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots//param_histogram.png
Training an SBI Model¶
SBI model training is handled with the fitter.run_single_sbi method. This method handles the following tasks:
Creating a prior from the parameter array.
Setting up the neural density estimator (NDE) for the SBI model.
Training the SBI model.
Saving the trained model to disk.
Plotting diagnostics of the trained model.
We will cover the various options for different SBI configurations in later tutorials. For now, we will proceed with the default configuration.
synference is built on top of the LtU-ILI package, which uses sbi and lampe for the underlying SBI functionality. The run below does not specify a model type and builds a mixture density network (MDN) with the NPE engine and sbi backend, so the MDN from the sbi package appears to be the default NDE. The default prior is a uniform prior over the range of each parameter in the parameter array.
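A uniform prior over the range of each parameter is a box-uniform distribution. The minimal numpy sketch below shows what such a prior computes; the helper names are hypothetical, and synference builds its actual prior through LtU-ILI/sbi rather than with this code.

```python
import numpy as np

def uniform_prior_bounds(params):
    """Lower/upper bound of each parameter from a (num_samples, num_params) array."""
    params = np.asarray(params, dtype=float)
    return params.min(axis=0), params.max(axis=0)

def uniform_log_prob(theta, low, high):
    """Log density of a box-uniform prior; -inf outside the box."""
    theta = np.atleast_2d(np.asarray(theta, dtype=float))
    inside = np.all((theta >= low) & (theta <= high), axis=1)
    log_volume = np.sum(np.log(high - low))
    return np.where(inside, -log_volume, -np.inf)

def sample_uniform_prior(low, high, n, rng):
    """Draw n parameter vectors uniformly within the box."""
    return rng.uniform(low, high, size=(n, len(low)))

# Toy parameter array with two parameters (e.g. redshift, log_mass); values are made up.
params = np.array([[0.5, 9.0], [4.5, 11.5], [2.0, 8.0]])
low, high = uniform_prior_bounds(params)
rng = np.random.default_rng(0)
draws = sample_uniform_prior(low, high, 5, rng)
print(uniform_log_prob(draws, low, high))
```

Every point inside the box has the same density, 1 / (product of the parameter ranges), which is why the prior carries no preference within the library’s bounds.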
The primary arguments to the fitter.run_single_sbi method are:
train_test_fraction: The fraction of the data to use for training; the rest is held out as a test set. The default is 0.8.
validation_fraction: The fraction of the training data to use for validation during training. The default is 0.2.
backend: The backend to use for training, either sbi or lampe. The default is sbi.
hidden_features: The number of hidden features in the NDE. The default is 50.
num_components/transforms: The number of mixture components or flow transforms in the NDE. The default is 4.
training_batch_size: The batch size for training. The default is 64.
stop_after_epochs: The number of epochs without improvement in the validation loss after which training stops. The default is 15.
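With train_test_fraction=0.8 and validation_fraction=0.2, a 100-sample library is split into 80 training and 20 test samples, and 16 of the 80 training samples are held out to monitor the validation loss. The sketch below reproduces that arithmetic with a random permutation; it assumes validation_fraction is applied to the training portion, as the argument description says, and split_indices is a hypothetical helper, not synference’s splitting code.

```python
import numpy as np

def split_indices(n_samples, train_test_fraction=0.8, validation_fraction=0.2, seed=0):
    """Illustrative fit/validation/test split of sample indices."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_samples)
    n_train = int(round(train_test_fraction * n_samples))
    train_idx, test_idx = idx[:n_train], idx[n_train:]
    # A further fraction of the training set is held out to monitor validation loss.
    n_val = int(round(validation_fraction * n_train))
    val_idx, fit_idx = train_idx[:n_val], train_idx[n_val:]
    return fit_idx, val_idx, test_idx

fit_idx, val_idx, test_idx = split_indices(100)
print(len(fit_idx), len(val_idx), len(test_idx))  # 64 16 20
```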
There are further options to turn plotting, model saving, validation, etc. on or off; see the docstring for more details.
Now we run the training, which prints quite a lot of output. We set name_append to ‘test_1’ so that the trained model is saved with a unique name; if left as the default, a timestamp is used. We also fix random_seed=42 and override the default network size with hidden_features=256 and num_components=64.
[11]:
posterior_model, stats = fitter.run_single_sbi(
name_append="test_1", random_seed=42, hidden_features=256, num_components=64
)
2026-05-01 18:06:47,355 | synference | INFO | Splitting dataset with 100 samples into trainingand testing sets with 0.80 train fraction.
2026-05-01 18:06:47,357 | synference | INFO | ---------------------------------------------
2026-05-01 18:06:47,358 | synference | INFO | Prior ranges:
2026-05-01 18:06:47,360 | synference | INFO | ---------------------------------------------
2026-05-01 18:06:47,362 | synference | INFO | redshift: 0.00 - 4.98 [dimensionless]
2026-05-01 18:06:47,362 | synference | INFO | log_mass: 8.01 - 11.99 [log10_Msun]
2026-05-01 18:06:47,363 | synference | INFO | tau_v: 0.01 - 3.00 [mag]
2026-05-01 18:06:47,365 | synference | INFO | tau: 0.11 - 1.98 [dimensionless]
2026-05-01 18:06:47,365 | synference | INFO | peak_age: 9.03 - 11315.65 [Myr]
2026-05-01 18:06:47,366 | synference | INFO | log10metallicity: -3.98 - -1.41 [log10(Zmet)]
2026-05-01 18:06:47,366 | synference | INFO | ---------------------------------------------
2026-05-01 18:06:47,368 | synference | INFO | Processing prior...
2026-05-01 18:06:47,371 | synference | INFO | Creating mdn network with NPE engine and sbi backend.
2026-05-01 18:06:47,371 | synference | INFO | hidden_features: 256
2026-05-01 18:06:47,372 | synference | INFO | num_components: 64
2026-05-01 18:06:47,373 | synference | INFO | Training on cpu.
INFO:root:MODEL INFERENCE CLASS: NPE
INFO:root:Training model 1 / 1.
Training neural network. Epochs trained: 121
INFO:root:It took 5.473414897918701 seconds to train models.
INFO:root:Saving model to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test
Neural network successfully converged after 130 epochs.
2026-05-01 18:06:52,862 | synference | INFO | Time to train model(s): 0:00:05.506787
2026-05-01 18:06:52,872 | synference | INFO | Saved model parameters to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/test_test_1_params.pkl.
2026-05-01 18:06:53,082 | synference | INFO | [ 5.59471250e-01 1.16237526e+01 2.68136740e-01 1.74207807e+00
9.19759094e+02 -3.81416941e+00]
1343it [00:00, 242883.33it/s]
INFO:root:Saving single posterior plot to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/test_1/test_18_plot_single_posterior.jpg...
2026-05-01 18:07:05,627 | synference | INFO | shapes: X:(20, 8), y:(20, 6)
100%|██████████| 20/20 [00:00<00:00, 354.65it/s]
INFO:root:Saving posterior samples to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/test_1/posterior_samples.npy...
INFO:root:Saving coverage plot to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/test_1/plot_coverage.jpg...
INFO:root:Saving ranks histogram to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/test_1/ranks_histogram.jpg...
INFO:root:Mean logprob: -1.4854e+01Median logprob: -1.4163e+01
INFO:root:Saving true logprobs to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/test_1/true_logprobs.npy...
INFO:root:Saving true logprobs plot to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/test_1/plot_true_logprobs.jpg...
INFO:matplotlib.mathtext:Substituting symbol E from STIXNonUnicode
100%|██████████| 100/100 [00:00<00:00, 735.81it/s]
2026-05-01 18:07:06,922 | synference | INFO | Evaluating model...
2026-05-01 18:07:06,923 | synference | WARNING | Transposing samples to match shape (num_objects, num_samples, num_parameters).
100%|██████████| 200/200 [00:00<00:00, 716.99it/s]
Log prob: 100%|██████████| 20/20 [00:00<00:00, 103.96it/s]
2026-05-01 18:07:07,404 | synference | INFO | ============================================================
2026-05-01 18:07:07,405 | synference | INFO | MODEL PERFORMANCE METRICS
2026-05-01 18:07:07,406 | synference | INFO | ============================================================
2026-05-01 18:07:07,406 | synference | INFO | Full Model Metrics:
2026-05-01 18:07:07,407 | synference | INFO | ----------------------------------------
2026-05-01 18:07:07,407 | synference | INFO | TARP..................... 0.120250
2026-05-01 18:07:07,407 | synference | INFO | LOG DPIT MAX............. 0.533779
2026-05-01 18:07:07,408 | synference | INFO | MEAN LOG PROB............ -16.228356
2026-05-01 18:07:07,408 | synference | INFO | Parameter-Specific Metrics:
2026-05-01 18:07:07,409 | synference | INFO | ----------------------------------------
2026-05-01 18:07:07,409 | synference | INFO | Metric redshift log_mass tau_v tau peak_age log10metallicity
2026-05-01 18:07:07,409 | synference | INFO | --------------------------------------------------------------------------------------
2026-05-01 18:07:07,410 | synference | INFO | MSE 1.655640 0.635541 0.537653 0.336472 2499187.973318 0.622132
2026-05-01 18:07:07,410 | synference | INFO | RMSE 1.286717 0.797208 0.733248 0.580062 1580.882024 0.788754
2026-05-01 18:07:07,411 | synference | INFO | MEAN AE 0.982032 0.702808 0.621581 0.495737 1418.338433 0.688818
2026-05-01 18:07:07,411 | synference | INFO | MEDIAN AE 0.657609 0.597500 0.597761 0.442729 1248.427232 0.736772
2026-05-01 18:07:07,412 | synference | INFO | R SQUARED 0.999957 0.999982 0.999986 0.999991 -0.094400 0.999985
2026-05-01 18:07:07,412 | synference | INFO | RMSE NORM 0.002003 0.001241 0.001141 0.000903 2.460965 0.001228
2026-05-01 18:07:07,413 | synference | INFO | MEAN AE NORM 0.001529 0.001094 0.000968 0.000772 2.207933 0.001072
2026-05-01 18:07:07,416 | synference | INFO | ============================================================
INFO:matplotlib.mathtext:Substituting symbol E from STIXNonUnicode
The first part of the output shows the dataset being split into training and testing sets, followed by the creation of the prior from the parameter library and the range of each parameter.
The next part shows the creation of the neural density estimator (NDE), here a mixture density network (MDN) with 64 components and 256 hidden features, as requested via num_components and hidden_features. The model is created using the sbi package, which is built on top of PyTorch.
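The log above reports an mdn network with num_components: 64. An MDN represents the posterior p(θ | x) as a weighted sum of Gaussians whose weights, means, and widths are produced by a neural network from the observation x. The sketch below evaluates such a mixture’s log density for fixed, made-up component values in one dimension; sbi’s MDN works on the full multi-dimensional parameter vector, so this is a simplified illustration only.

```python
import numpy as np

def mixture_log_prob(theta, weights, means, stds):
    """Log density of a 1D Gaussian mixture at points theta.

    In an MDN, weights/means/stds would be network outputs given x;
    here they are fixed numbers for illustration.
    """
    theta = np.asarray(theta, dtype=float)[..., None]          # (..., 1)
    weights, means, stds = (np.asarray(a, dtype=float) for a in (weights, means, stds))
    log_norm = -0.5 * np.log(2 * np.pi) - np.log(stds)
    log_comp = log_norm - 0.5 * ((theta - means) / stds) ** 2  # (..., K)
    a = np.log(weights) + log_comp
    a_max = a.max(axis=-1, keepdims=True)
    # Log-sum-exp over the K components for numerical stability.
    return (a_max + np.log(np.exp(a - a_max).sum(axis=-1, keepdims=True)))[..., 0]

# Made-up two-component mixture for a single parameter (e.g. redshift).
weights, means, stds = [0.7, 0.3], [1.0, 3.0], [0.3, 0.5]
grid = np.linspace(-2.0, 6.0, 8001)
log_p = mixture_log_prob(grid, weights, means, stds)
print(grid[np.argmax(log_p)])  # location of the highest-density point
```

More components let the mixture represent skewed or multi-modal posteriors, at the cost of more network outputs to learn.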
Then the actual training happens: the number of trained epochs increments until the model stops improving on the validation set. Training stops once the validation loss has not improved for 15 epochs, the default value of stop_after_epochs.
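Patience-based early stopping of the kind controlled by stop_after_epochs can be sketched in plain Python as below. This is a generic illustration with made-up loss values, not the sbi training loop itself.

```python
def epochs_until_stop(val_losses, stop_after_epochs=15):
    """Return (epochs_trained, best_epoch) under patience-based early stopping."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New best validation loss: reset the patience counter.
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= stop_after_epochs:
            return epoch, best_epoch
    return len(val_losses), best_epoch

# Made-up validation losses: improvement for 10 epochs, then a plateau.
losses = [float(v) for v in range(10, 0, -1)] + [1.5] * 20
print(epochs_until_stop(losses))  # (25, 10)
```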
The trained model is pickled and saved to the output directory for this model, models/<model_name>/ (here models/test/). A summary of the trained model is saved as a .json file in the same directory, and the fitter’s configuration, which records the feature and parameter setup used for training, is also pickled and saved there. A model can be reloaded later using the fitter.load_model_from_pkl method. To save in a different format, change the save_method argument to, e.g., torch or hickle.
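The save-and-reload pattern described above (a pickled configuration plus a JSON summary) can be illustrated with Python’s standard library. The file names and dictionary contents here are hypothetical, and this is not synference’s own saving code, which may store additional information.

```python
import json
import pickle
import tempfile
from pathlib import Path

# Hypothetical configuration and summary; synference's real files may differ.
config = {"model_name": "test", "name_append": "test_1",
          "parameter_names": ["redshift", "log_mass", "tau_v"]}
summary = {"backend": "sbi", "engine": "NPE"}

out_dir = Path(tempfile.mkdtemp()) / "models" / "test"
out_dir.mkdir(parents=True)

# Pickle the configuration and write the summary as human-readable JSON.
with open(out_dir / "example_config.pkl", "wb") as f:
    pickle.dump(config, f)
(out_dir / "example_summary.json").write_text(json.dumps(summary, indent=2))

# Reload both files later, e.g. in a new session.
with open(out_dir / "example_config.pkl", "rb") as f:
    restored_config = pickle.load(f)
restored_summary = json.loads((out_dir / "example_summary.json").read_text())
print(restored_config["parameter_names"])
```

Only unpickle files from sources you trust, since loading a pickle can execute arbitrary code.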
Once the model is trained, the validation diagnostics are run. These include:
A posterior corner plot for a random observation from the test set.
A loss plot which shows the training and validation loss over epochs.
A coverage plot which shows how well the credible intervals of the posterior match the true parameters.
A ranks histogram which shows how well the posterior samples match the true parameters.
A log-probability plot which shows the log probability of the true parameters under the posterior.
A true vs. predicted plot which shows the true parameters against the maximum a posteriori (MAP) estimate from the posterior.
These plots are shown in the output, and also saved to the plots/ directory in the output folder for this model.
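The coverage plot and ranks histogram are both built from posterior samples of test-set objects and their true parameters. A minimal numpy sketch of the underlying quantities follows, using toy Gaussian draws in place of real posteriors; the helper names are hypothetical, and synference/LtU-ILI compute these diagnostics with their own code. For a well-calibrated posterior, the ranks are uniformly distributed and the empirical coverage matches the nominal credibility.

```python
import numpy as np

def rank_statistics(samples, truths):
    """Rank of each true value among its posterior samples.

    samples: (num_objects, num_samples, num_parameters)
    truths:  (num_objects, num_parameters)
    Returns ranks in [0, num_samples] with shape (num_objects, num_parameters).
    """
    return (samples < truths[:, None, :]).sum(axis=1)

def empirical_coverage(samples, truths, credibility=0.68):
    """Fraction of objects whose truth lies in the central credible interval."""
    tail = 50.0 * (1.0 - credibility)  # percent of probability in each tail
    lo = np.percentile(samples, tail, axis=1)
    hi = np.percentile(samples, 100.0 - tail, axis=1)
    return ((truths >= lo) & (truths <= hi)).mean(axis=0)

# Toy, uninformative-but-calibrated case: truths and "posterior" samples
# are both drawn from the same standard normal prior.
rng = np.random.default_rng(0)
truths = rng.normal(size=(1000, 2))
samples = rng.normal(size=(1000, 500, 2))
ranks = rank_statistics(samples, truths)
print(empirical_coverage(samples, truths))  # close to 0.68 for each parameter
```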
Loading a Trained Model¶
We can load a trained model into an existing SBI_Fitter instance using the fitter.load_model_from_pkl method, which takes the path to the pickled model file as an argument.
If only one model is present in the directory, we can simply provide the directory path and the method will find the model file automatically. If multiple models are present, we can provide the full path to the model file.
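The directory-or-file behavior described above can be illustrated with pathlib. find_posterior_file is a hypothetical helper, not synference’s implementation, and the *_posterior.pkl naming pattern is inferred from the test_test_1_posterior.pkl file name used in this tutorial.

```python
from pathlib import Path

def find_posterior_file(path):
    """Resolve a model path: a file is returned as-is; a directory must
    contain exactly one *_posterior.pkl file."""
    path = Path(path)
    if path.is_file():
        return path
    candidates = sorted(path.glob("*_posterior.pkl"))
    if len(candidates) != 1:
        raise ValueError(
            f"Expected exactly one *_posterior.pkl in {path}, found {len(candidates)}; "
            "pass the full file path instead."
        )
    return candidates[0]

# Example usage (hypothetical path):
# model_file = find_posterior_file("models/test")
```

Matching on the _posterior.pkl suffix keeps other files in the same directory, such as the _params.pkl file saved during training, from being picked up.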
[12]:
fitter.load_model_from_pkl("test/test_test_1_posterior.pkl");
2026-05-01 18:07:10,159 | synference | INFO | Loaded model from /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/test_test_1_posterior.pkl.
2026-05-01 18:07:10,160 | synference | INFO | Device: cpu
Alternatively, we can create a new SBI_Fitter instance and load the model into that instance, using the class method load_saved_model. This method takes the path to the pickled model file as an argument, and returns a new SBI_Fitter instance with the model loaded.
[13]:
new_fitter = SBI_Fitter.load_saved_model("test/test_test_1_posterior.pkl")
2026-05-01 18:07:10,203 | synference | INFO | Loaded model from /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/test_test_1_posterior.pkl.
2026-05-01 18:07:10,203 | synference | INFO | Device: cpu
Plotting model loss¶
We can plot the model loss using the fitter.plot_loss method. This method will create a plot of the training and validation loss over epochs, and save it to the plots/ directory in the output folder for this model. By default, it will not overwrite existing plots, but you can change this with the overwrite argument.
[14]:
fitter.plot_loss(overwrite=True);
Plotting validation metrics¶
Although the diagnostics are generated automatically during training, we can also plot the validation metrics of a trained model using the fitter.plot_diagnostics method. You can provide your own validation set; by default, it uses the test set from the last training run. By default, it will not overwrite existing plots in the plots/ directory, but you can change this with the overwrite argument.
[15]:
fitter.plot_diagnostics();
2026-05-01 18:07:10,576 | synference | INFO | [ 1.34024525 9.02346897 1.03011668 0.41331279 247.98370468
-2.60618162]
1147it [00:00, 174149.02it/s]
INFO:root:Saving single posterior plot to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/test_17_plot_single_posterior.jpg...
2026-05-01 18:07:23,193 | synference | INFO | shapes: X:(20, 8), y:(20, 6)
100%|██████████| 20/20 [00:00<00:00, 435.72it/s]
INFO:root:Saving posterior samples to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/posterior_samples.npy...
INFO:root:Saving coverage plot to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/plot_coverage.jpg...
INFO:root:Saving ranks histogram to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/ranks_histogram.jpg...
INFO:root:Mean logprob: -1.4580e+01Median logprob: -1.4027e+01
INFO:root:Saving true logprobs to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/true_logprobs.npy...
INFO:root:Saving true logprobs plot to /opt/hostedtoolcache/Python/3.10.20/x64/lib/python3.10/models/test/plots/plot_true_logprobs.jpg...
INFO:matplotlib.mathtext:Substituting symbol E from STIXNonUnicode
100%|██████████| 100/100 [00:00<00:00, 688.50it/s]
INFO:matplotlib.mathtext:Substituting symbol E from STIXNonUnicode
Getting model metrics¶
We can print and save metrics of the trained model using the fitter.evaluate_model method. This method will print the metrics to the console, and also save them to a .json file in the output directory for this model. The metrics include:
TARP (Tests of Accuracy with Random Points)
Log DPIT (Logarithmic Deviation of the Probability Integral Transform)
Mean Log Probability
Parameter-specific metrics (MSE, RMSE, mean absolute error, median absolute error, R-squared, and normalized RMSE and mean absolute error)
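The parameter-specific metrics compare a point estimate from each posterior with the true parameter values. The sketch below applies the textbook definitions of MSE, RMSE, mean and median absolute error, and R-squared to toy numbers; which point estimate synference uses and how it defines and normalizes these metrics are not shown in this tutorial and may differ, so point_estimate_metrics is a hypothetical helper.

```python
import numpy as np

def point_estimate_metrics(y_true, y_pred):
    """Textbook per-parameter metrics; arrays have shape (num_objects, num_parameters)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    mse = np.mean(err ** 2, axis=0)
    ss_res = np.sum(err ** 2, axis=0)                             # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2, axis=0)  # total sum of squares
    return {
        "MSE": mse,
        "RMSE": np.sqrt(mse),
        "MEAN AE": np.mean(np.abs(err), axis=0),
        "MEDIAN AE": np.median(np.abs(err), axis=0),
        "R SQUARED": 1.0 - ss_res / ss_tot,
    }

# Toy example: three objects, one parameter; only the last prediction is off by 1.
metrics = point_estimate_metrics([[1.0], [2.0], [3.0]], [[1.0], [2.0], [4.0]])
print(metrics["R SQUARED"])  # [0.5]
```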
[16]:
fitter.evaluate_model();
Sampling from posterior: 100%|██████████| 20/20 [00:00<00:00, 440.52it/s]
100%|██████████| 200/200 [00:00<00:00, 698.91it/s]
Log prob: 100%|██████████| 20/20 [00:00<00:00, 127.76it/s]
2026-05-01 18:07:27,397 | synference | INFO | ============================================================
2026-05-01 18:07:27,397 | synference | INFO | MODEL PERFORMANCE METRICS
2026-05-01 18:07:27,399 | synference | INFO | ============================================================
2026-05-01 18:07:27,399 | synference | INFO | Full Model Metrics:
2026-05-01 18:07:27,401 | synference | INFO | ----------------------------------------
2026-05-01 18:07:27,401 | synference | INFO | TARP..................... 0.070750
2026-05-01 18:07:27,402 | synference | INFO | LOG DPIT MAX............. 0.538878
2026-05-01 18:07:27,402 | synference | INFO | MEAN LOG PROB............ -15.547252
2026-05-01 18:07:27,403 | synference | INFO | Parameter-Specific Metrics:
2026-05-01 18:07:27,403 | synference | INFO | ----------------------------------------
2026-05-01 18:07:27,404 | synference | INFO | Metric redshift log_mass tau_v tau peak_age log10metallicity
2026-05-01 18:07:27,404 | synference | INFO | --------------------------------------------------------------------------------------
2026-05-01 18:07:27,405 | synference | INFO | MSE 1.643753 0.625900 0.543064 0.337336 2477278.598774 0.627021
2026-05-01 18:07:27,406 | synference | INFO | RMSE 1.282089 0.791138 0.736929 0.580806 1573.937292 0.791846
2026-05-01 18:07:27,406 | synference | INFO | MEAN AE 0.981237 0.696740 0.623029 0.495283 1417.745846 0.695092
2026-05-01 18:07:27,407 | synference | INFO | MEDIAN AE 0.675242 0.557023 0.586161 0.442108 1308.255525 0.747703
2026-05-01 18:07:27,407 | synference | INFO | R SQUARED 0.999957 0.999982 0.999986 0.999991 -0.084806 0.999985
2026-05-01 18:07:27,408 | synference | INFO | RMSE NORM 0.001996 0.001232 0.001147 0.000904 2.450154 0.001233
2026-05-01 18:07:27,408 | synference | INFO | MEAN AE NORM 0.001527 0.001085 0.000970 0.000771 2.207011 0.001082
2026-05-01 18:07:27,409 | synference | INFO | ============================================================
Posterior Samples¶
We can sample the posterior for a given observation using the fitter.sample_posterior method. This method takes an observation, or a set of observations, as an argument and returns samples from the posterior distribution. If no observation is provided, it draws posterior samples for all observations in the test set. The result has shape (num_objects, num_samples, num_parameters); below, that is 1000 samples of the 6 parameters for each of the 20 test-set observations.
[17]:
fitter.sample_posterior()
Sampling from posterior: 100%|██████████| 20/20 [00:00<00:00, 439.67it/s]
[17]:
array([[[ 2.32180476e+00, 1.00106554e+01, 1.52611160e+00,
1.77157879e+00, 1.82333374e+03, -2.10652447e+00],
[ 3.68806100e+00, 9.98847485e+00, 1.61486506e+00,
8.35572958e-01, 3.18698730e+01, -2.11124182e+00],
[ 1.91590309e+00, 9.29074001e+00, 2.57524562e+00,
1.08890784e+00, 1.01842651e+03, -2.61202860e+00],
...,
[ 2.65891457e+00, 9.48077679e+00, 1.26335382e+00,
1.47971308e+00, 9.58887146e+02, -3.09854603e+00],
[ 2.25324917e+00, 8.82637596e+00, 1.84927106e+00,
9.85716939e-01, 5.48782837e+02, -3.31030941e+00],
[ 2.00221682e+00, 1.04697247e+01, 1.71417546e+00,
6.53493643e-01, 1.45260425e+03, -3.00743222e+00]],
[[ 1.88346219e+00, 8.36991596e+00, 1.81685328e+00,
1.62378883e+00, 2.61582031e+02, -2.25934792e+00],
[ 3.99796915e+00, 9.32114410e+00, 1.33550692e+00,
7.74554014e-01, 1.47964465e+03, -3.09684515e+00],
[ 2.66095734e+00, 1.03075485e+01, 2.15192413e+00,
9.01264548e-01, 2.73966577e+03, -2.25946665e+00],
...,
[ 3.50039148e+00, 9.31750298e+00, 1.77870750e+00,
1.90578723e+00, 1.78029199e+03, -2.41232800e+00],
[ 2.19935727e+00, 9.73857212e+00, 1.78818238e+00,
8.79107058e-01, 1.55718481e+03, -2.61633182e+00],
[ 4.68812418e+00, 8.99368095e+00, 1.07337666e+00,
6.57473981e-01, 5.67124878e+02, -3.00959945e+00]],
[[ 1.72945786e+00, 1.09432764e+01, 2.92831755e+00,
1.23242080e-01, 3.23465527e+03, -1.98474383e+00],
[ 1.04873753e+00, 9.53966713e+00, 1.55538392e+00,
9.76665020e-01, 3.40340332e+03, -3.16616249e+00],
[ 1.90054083e+00, 1.12026911e+01, 1.61182201e+00,
1.64780664e+00, 9.63705444e+02, -2.70382333e+00],
...,
[ 2.02669811e+00, 1.00716934e+01, 7.49588788e-01,
7.62486875e-01, 1.30515381e+02, -2.47872734e+00],
[ 1.90197444e+00, 1.00641356e+01, 5.62777638e-01,
7.88214087e-01, 3.85209106e+02, -2.92040586e+00],
[ 1.44302738e+00, 1.09962435e+01, 2.67149925e+00,
1.75669634e+00, 1.30243359e+03, -2.28525567e+00]],
...,
[[ 1.43647754e+00, 1.03602095e+01, 1.48316634e+00,
1.12074995e+00, 3.66913379e+03, -3.95500922e+00],
[ 4.65780830e+00, 9.05846596e+00, 5.36590695e-01,
7.29956806e-01, 1.74674683e+03, -1.70626330e+00],
[ 2.73609853e+00, 1.00503082e+01, 5.50137997e-01,
7.18721509e-01, 3.87392944e+02, -3.31811190e+00],
...,
[ 3.72873163e+00, 9.43251896e+00, 1.10690629e+00,
6.09999537e-01, 1.40036401e+03, -1.70190930e+00],
[ 3.40128756e+00, 9.56472778e+00, 8.65173459e-01,
1.03878367e+00, 3.26625195e+03, -2.22205114e+00],
[ 4.05167150e+00, 1.05771141e+01, 1.56531119e+00,
6.82531357e-01, 9.53278442e+02, -3.30360198e+00]],
[[ 2.64899254e+00, 1.09797287e+01, 1.36497819e+00,
1.17027855e+00, 1.75180444e+03, -2.99954295e+00],
[ 2.49411464e+00, 1.17604332e+01, 1.13658726e+00,
1.41343296e+00, 1.37542554e+03, -2.70664263e+00],
[ 1.04854596e+00, 1.07168121e+01, 2.09446287e+00,
1.31069255e+00, 2.48833008e+03, -3.33505416e+00],
...,
[ 2.45913053e+00, 1.09919262e+01, 7.26088464e-01,
1.14350319e+00, 2.13377783e+03, -3.36903763e+00],
[ 5.50764799e-01, 1.11392097e+01, 3.16190720e-01,
1.42294228e+00, 1.94386023e+03, -1.51735556e+00],
[ 2.89964676e-03, 9.05051613e+00, 7.43790507e-01,
7.02521324e-01, 3.37239600e+03, -3.44996953e+00]],
[[ 3.89272213e-01, 1.02992315e+01, 9.13109183e-01,
1.92569947e+00, 1.50332642e+03, -2.73387575e+00],
[ 1.33064032e-01, 1.08225365e+01, 2.39495826e+00,
1.31859303e+00, 2.77387939e+02, -1.53810120e+00],
[ 2.73090839e+00, 1.11313858e+01, 1.21820283e+00,
1.41048288e+00, 1.33004456e+03, -3.24739742e+00],
...,
[ 8.95622969e-01, 9.71197891e+00, 1.31146443e+00,
1.37779570e+00, 1.01892070e+04, -3.92445731e+00],
[ 3.32483530e+00, 8.74147415e+00, 2.32919550e+00,
1.43509173e+00, 9.54132996e+02, -2.79255795e+00],
[ 3.53684139e+00, 1.12156153e+01, 1.64917552e+00,
1.60190797e+00, 1.33316113e+03, -2.43094301e+00]]],
shape=(20, 1000, 6))
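A common way to summarize samples of shape (num_objects, num_samples, num_parameters), such as the (20, 1000, 6) array above, is a median with a 68% interval per object and parameter. The numpy sketch below does this on random stand-in samples of the same shape (not real posterior samples); summarize_samples is a hypothetical helper, and in practice you would pass it the array returned by fitter.sample_posterior().

```python
import numpy as np

def summarize_samples(samples, parameter_names):
    """Median and 16th/84th percentiles for each object and parameter.

    samples: (num_objects, num_samples, num_parameters)
    Returns {name: (median, p16, p84)}, each an array of length num_objects.
    """
    p16, p50, p84 = np.percentile(samples, [16, 50, 84], axis=1)
    return {name: (p50[:, i], p16[:, i], p84[:, i])
            for i, name in enumerate(parameter_names)}

# Random stand-in with the same shape as the array above (not real posterior samples).
rng = np.random.default_rng(1)
samples = rng.normal(size=(20, 1000, 6))
names = ["redshift", "log_mass", "tau_v", "tau", "peak_age", "log10metallicity"]
summary = summarize_samples(samples, names)
median, p16, p84 = summary["log_mass"]
print(median.shape)  # (20,)
```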
Next Steps¶
In the next tutorials, we will cover more advanced configurations for training SBI models, including different feature and parameter configurations, different NDEs, and different prior proposals. We will also cover how to use the trained models for inference on real data.