Model Validation

Model validation is a crucial step in the machine learning workflow. It ensures that a model performs well and helps identify any biases or overfitting.

Synference supports a variety of model validation techniques, including:

  • Train/Test Split: Dividing the dataset into a training set and a test set to evaluate model performance on unseen data.

  • Validation Plots: Visualizations such as training loss, model coverage, and prediction distributions to assess model behavior.

  • Prior Predictive Checks: Simulating data from parameters drawn from the prior distributions and checking that the simulated data align with domain knowledge.

  • Evaluation Metrics: Calculating metrics such as Mean Squared Error (MSE), Root Mean Squared Error (RMSE), Mean Absolute Error (MAE), R-squared, and others to quantify model performance.
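The evaluation metrics listed above can be computed directly from a held-out test set. The following is a minimal NumPy sketch, assuming you already have the true parameter values for the test set and the model's point estimates (for example, posterior medians). The helper regression_metrics is hypothetical and is not part of Synference.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Return MSE, RMSE, MAE and R-squared for true vs. predicted values.

    Hypothetical helper for illustration; not a Synference function.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_true - y_pred

    mse = np.mean(residuals**2)
    rmse = np.sqrt(mse)
    mae = np.mean(np.abs(residuals))
    # R^2 = 1 - SS_res / SS_tot, where SS_tot is the spread of y_true about its mean
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot
    return {"mse": mse, "rmse": rmse, "mae": mae, "r2": r2}


# Illustrative values only: true test-set parameters vs. model point estimates
print(regression_metrics([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8]))
```

Use the test split for these numbers: metrics computed on the training data will overstate performance if the model has overfit.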

Splitting Validation Data by SNR

If your model incorporates a noise model, it is often useful to validate the model’s performance across different signal-to-noise ratio (SNR) regimes. This helps identify whether the model performs consistently across varying levels of observational uncertainty.

To split your validation data by SNR, pass the SNR bins to plot_parameter_deviations through its snr_bins argument:

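# fitter is an already trained Synference fitter object
fitter.plot_parameter_deviations(
    snr_bins=[5, 10, 1000]  # SNR bin boundaries; a repeated value would give an empty bin
)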
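To make the idea concrete, the sketch below bins validation samples by SNR and computes the root-mean-square deviation between predicted and true parameter values in each bin. It is not Synference's implementation: the helper rmse_by_snr_bin is hypothetical, and it assumes that the bin list holds increasing bin edges.

```python
import numpy as np


def rmse_by_snr_bin(snr, true_values, predicted_values, snr_bins):
    """Root-mean-square deviation of predictions, computed separately per SNR bin.

    Hypothetical helper. snr_bins is treated as a list of increasing bin edges:
    bin i covers snr_bins[i] <= SNR < snr_bins[i + 1]. Samples outside the
    outermost edges are ignored; empty bins return NaN.
    """
    snr = np.asarray(snr, dtype=float)
    deviations = np.asarray(predicted_values, dtype=float) - np.asarray(true_values, dtype=float)
    edges = np.asarray(snr_bins, dtype=float)
    # np.digitize returns i such that edges[i-1] <= x < edges[i]; shift to 0-based bins
    bin_index = np.digitize(snr, edges) - 1

    results = {}
    for i in range(len(edges) - 1):
        mask = bin_index == i
        label = (float(edges[i]), float(edges[i + 1]))
        results[label] = float(np.sqrt(np.mean(deviations[mask] ** 2))) if mask.any() else float("nan")
    return results


# Illustrative data: one parameter, true value 0 for every validation sample
snr = [6, 8, 15, 50, 500]
pred = [0.4, -0.2, 0.3, 0.1, -0.05]
print(rmse_by_snr_bin(snr, [0.0] * 5, pred, snr_bins=[5, 10, 1000]))
```

If the deviation grows sharply in the lowest SNR bin, the model is likely less reliable for noisy observations than the overall metrics suggest.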

Available Model Validation Functions

Here are brief descriptions of the available model validation functions in Synference:

  1. calculate_TARP: Calculate the Tests of Accuracy with Random Points (TARP) metric.

  2. calculate_PIT: Calculate the Probability Integral Transform (PIT) values.

  3. plot_loss: Plot the training and validation loss over epochs.

  5. plot_coverage: Plot the coverage of the model's posterior credible intervals.

  5. plot_posterior: Plot the posterior distribution for a given observation.

  6. plot_parameter_deviations: Plot the deviations of predicted parameters from true values.

  7. plot_diagnostics: Produce a range of diagnostic plots for model validation.

  8. evaluate_model: Evaluate the model using specified metrics.

  9. fit_observation_using_sampler: Fit an observation using nested sampling, for comparison with the simulation-based inference (SBI) model.

  10. test_in_distribution: Test whether your observations are in-distribution with respect to your training data, or whether your training data are in-distribution with respect to your observations.

  11. test_in_distribution_pyod: Perform the same in-distribution tests as test_in_distribution, but using the PyOD package.

  12. detect_misspecification: Detect model misspecification using the MarginalTrainer from the sbi package. See the sbi documentation for more details.

  13. lc2st: Perform the local classifier two-sample test (L-C2ST) to check the validity of the posterior estimate at a given observation.
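The internals of calculate_PIT and plot_coverage are not described here. As a hedged illustration of the quantities they are named after, the NumPy sketch below computes PIT values and the empirical coverage of central credible intervals for a single parameter from posterior samples. The helpers pit_values and empirical_coverage are hypothetical. For a well-calibrated posterior, PIT values are uniform on [0, 1] and the empirical coverage matches the nominal credibility level.

```python
import numpy as np


def pit_values(posterior_samples, true_values):
    """PIT value per test case: fraction of posterior samples <= the true value.

    posterior_samples: shape (n_tests, n_samples) for a single parameter.
    true_values: shape (n_tests,).
    """
    samples = np.asarray(posterior_samples, dtype=float)
    truth = np.asarray(true_values, dtype=float)[:, None]
    return np.mean(samples <= truth, axis=1)


def empirical_coverage(posterior_samples, true_values, levels):
    """Fraction of test cases whose true value lies in the central credible interval."""
    samples = np.asarray(posterior_samples, dtype=float)
    truth = np.asarray(true_values, dtype=float)
    coverage = []
    for level in levels:
        lo = np.quantile(samples, 0.5 - level / 2, axis=1)
        hi = np.quantile(samples, 0.5 + level / 2, axis=1)
        coverage.append(np.mean((truth >= lo) & (truth <= hi)))
    return np.array(coverage)


# Toy calibrated case: each true value and its "posterior" samples are drawn
# from the same standard normal, so the posterior is calibrated by construction.
rng = np.random.default_rng(0)
n_tests, n_samples = 2000, 1000
truth = rng.standard_normal(n_tests)
posterior = rng.standard_normal((n_tests, n_samples))

pit = pit_values(posterior, truth)
levels = np.array([0.5, 0.68, 0.95])
print(np.histogram(pit, bins=10, range=(0, 1))[0])  # roughly flat if calibrated
print(empirical_coverage(posterior, truth, levels))  # close to levels if calibrated
```

A U-shaped PIT histogram, or coverage below the nominal level, usually indicates an overconfident (too narrow) posterior; a peaked histogram or excess coverage indicates an underconfident one.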
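TARP estimates coverage without building per-parameter intervals: for each test case it draws a random reference point and measures what fraction of posterior samples lie closer to that point than the true parameter does. The sketch below follows that recipe in NumPy under stated assumptions (Euclidean distance, reference points drawn uniformly within user-supplied bounds such as the prior range). It illustrates the metric and is not Synference's calculate_TARP; tarp_expected_coverage is a hypothetical name. A calibrated posterior gives an expected coverage probability ECP(α) close to α.

```python
import numpy as np


def tarp_expected_coverage(posterior_samples, true_values, alphas, ref_low, ref_high, rng):
    """TARP expected coverage probability (ECP) at each credibility level in alphas.

    posterior_samples: shape (n_tests, n_samples, n_params)
    true_values: shape (n_tests, n_params)
    ref_low, ref_high: bounds for the uniform reference-point distribution (assumed).
    """
    samples = np.asarray(posterior_samples, dtype=float)
    truth = np.asarray(true_values, dtype=float)
    n_tests, _, n_params = samples.shape

    # One random reference point per test case, drawn independently of the data
    refs = rng.uniform(ref_low, ref_high, size=(n_tests, n_params))

    # Euclidean distance of every posterior sample, and of the truth, to the reference point
    sample_dist = np.linalg.norm(samples - refs[:, None, :], axis=-1)
    true_dist = np.linalg.norm(truth - refs, axis=-1)

    # f_i: fraction of posterior samples closer to the reference than the truth is
    f = np.mean(sample_dist < true_dist[:, None], axis=1)
    # ECP(alpha): fraction of test cases with f_i < alpha
    return np.array([np.mean(f < a) for a in alphas])


# Toy calibrated case in two dimensions: truths and posterior samples share one distribution
rng = np.random.default_rng(1)
truth = rng.standard_normal((2000, 2))
posterior = rng.standard_normal((2000, 500, 2))
alphas = np.linspace(0.1, 0.9, 5)
print(tarp_expected_coverage(posterior, truth, alphas, -3.0, 3.0, rng))  # close to alphas
```

Plotting ECP against α gives a coverage curve; deviations from the diagonal indicate miscalibration of the joint posterior rather than of individual parameters.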