synference.simformer
Functions for simformer tasks and model saving/loading.
Functions
- synference.simformer.load_full_model(dir_path, model_id, simulator=None)
Load a full model from the specified directory and model ID.
Classes
- class synference.simformer.GalaxyPhotometryTask(name='galaxy_photometry_task', backend='jax', prior_dict=None, param_names_ordered=None, run_simulator_fn=None, num_filters=None, test_X_data=None, test_theta_data=None, attention_mask_type='full')
A Simformer InferenceTask for the GalaxySimulator.
- get_base_mask_fn()
Defines the base attention mask for the transformer.
- Return type:
Callable
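With attention_mask_type='full', the base mask presumably lets every node attend to every other node. The sketch below illustrates only that case; make_full_mask_fn, theta_dim and the boolean (num_nodes, num_nodes) convention are hypothetical, and the signature of the callable that synference actually returns is not documented here.

```python
import numpy as np
from typing import Callable


def make_full_mask_fn(theta_dim: int, num_filters: int) -> Callable[[], np.ndarray]:
    """Hypothetical stand-in for get_base_mask_fn() with attention_mask_type='full'.

    Nodes are the theta_dim parameters followed by the num_filters photometric
    bands; in a 'full' mask every node may attend to every other node.
    """
    num_nodes = theta_dim + num_filters

    def base_mask_fn() -> np.ndarray:
        # True at (i, j) means node i is allowed to attend to node j.
        return np.ones((num_nodes, num_nodes), dtype=bool)

    return base_mask_fn


mask_fn = make_full_mask_fn(theta_dim=3, num_filters=5)
mask = mask_fn()
print(mask.shape)  # (8, 8)
```

A more restrictive mask type would set some entries to False, for example to encode which parameters can influence which bands; which mask types synference supports besides 'full' is not stated in this reference.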
- get_data(num_samples, **kwargs)
Returns data for the task.
- Parameters:
num_samples (int) – The number of samples to generate.
**kwargs – Additional keyword arguments for the prior sampling.
- Return type:
Dict[str, Array]
- Returns:
A dictionary with keys ‘theta’ and ‘x’, containing the sampled parameters and simulated photometry, respectively.
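get_data pairs parameters drawn from the prior with photometry simulated from them. The sketch below shows that pattern using NumPy arrays in place of JAX arrays; get_data_sketch, the toy linear simulator and the parameter names log_mass and redshift are all hypothetical. The real method presumably draws on the task's own prior and simulator (prior_dict, run_simulator_fn) rather than taking them as arguments.

```python
import numpy as np
from typing import Callable, Dict, List, Tuple


def get_data_sketch(
    num_samples: int,
    prior_ranges: Dict[str, Tuple[float, float]],
    param_order: List[str],
    simulator: Callable[[np.ndarray], np.ndarray],
    rng: np.random.Generator,
) -> Dict[str, np.ndarray]:
    """Hypothetical sketch of get_data(): sample theta, simulate x, return both."""
    lows = np.array([prior_ranges[name][0] for name in param_order])
    highs = np.array([prior_ranges[name][1] for name in param_order])
    # Draw num_samples parameter vectors uniformly within the prior ranges.
    theta = rng.uniform(lows, highs, size=(num_samples, len(param_order)))
    # Map each parameter vector to simulated photometry, shape (num_samples, num_filters).
    x = simulator(theta)
    return {"theta": theta, "x": x}


weights = np.arange(8.0).reshape(2, 4)


def toy_simulator(theta: np.ndarray) -> np.ndarray:
    # Toy linear map from 2 parameters to 4 "fluxes" (purely illustrative).
    return theta @ weights


data = get_data_sketch(
    num_samples=100,
    prior_ranges={"log_mass": (8.0, 12.0), "redshift": (0.0, 5.0)},
    param_order=["log_mass", "redshift"],
    simulator=toy_simulator,
    rng=np.random.default_rng(0),
)
print(data["theta"].shape, data["x"].shape)  # (100, 2) (100, 4)
```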
- get_node_id()
Returns an array identifying the nodes (dimensions) of theta and x.
- Return type:
Array
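In Simformer-style models each node id usually selects a per-variable embedding, which is how the transformer tells the parameters and bands apart. A minimal sketch, assuming the ids are simply consecutive integers with the theta nodes first and the num_filters photometry nodes after them; get_node_id_sketch is hypothetical and the actual ordering in synference may differ.

```python
import numpy as np


def get_node_id_sketch(theta_dim: int, num_filters: int) -> np.ndarray:
    """Hypothetical sketch: one integer id per node, parameters first, then bands."""
    return np.arange(theta_dim + num_filters)


node_id = get_node_id_sketch(theta_dim=3, num_filters=5)
is_theta = node_id < 3  # True for the parameter nodes, False for the band nodes
print(node_id)  # [0 1 2 3 4 5 6 7]
```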
- class synference.simformer.GalaxyPrior(prior_ranges, param_order)
A prior distribution for galaxy parameters.
This class uses uniform distributions for each parameter defined in prior_ranges. It can sample from the prior and compute log probabilities.
- log_prob(theta)
Calculates the log probability of theta under the prior.
- Parameters:
theta (Tensor) – A PyTorch tensor of shape (num_samples, theta_dim).
- Return type:
Tensor
- Returns:
A PyTorch tensor of shape (num_samples,) containing log probabilities.
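For independent uniform priors the log probability is the same constant, -Σᵢ log(highᵢ - lowᵢ), everywhere inside the prior box and -inf outside it. The sketch below illustrates that calculation in NumPy so it runs without PyTorch; uniform_log_prob, lows and highs are hypothetical names, and treating both bounds as inclusive is an assumption rather than documented behaviour.

```python
import numpy as np


def uniform_log_prob(theta: np.ndarray, lows: np.ndarray, highs: np.ndarray) -> np.ndarray:
    """Log density of independent uniform priors, one value per row of theta.

    Inside the box the density is the constant 1 / prod(highs - lows);
    outside it the log probability is -inf.
    """
    theta = np.atleast_2d(theta)
    inside = np.all((theta >= lows) & (theta <= highs), axis=1)
    log_density = -np.sum(np.log(highs - lows))
    return np.where(inside, log_density, -np.inf)


lows, highs = np.array([0.0, 0.0]), np.array([1.0, 2.0])
theta = np.array([[0.5, 1.0],   # inside the prior box
                  [1.5, 1.0]])  # first parameter out of range
out = uniform_log_prob(theta, lows, highs)
# out[0] is -log(2) (box volume 1 * 2 = 2); out[1] is -inf.
```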
- sample(sample_shape, sample_lhc=False, rng=None)
Generates samples from the prior.
- Parameters:
sample_shape (Tuple[int]) – A tuple containing the number of samples, e.g., (num_samples,).
sample_lhc – If True, samples using Latin Hypercube sampling.
rng – Optional random number generator for reproducibility.
- Return type:
Tensor
- Returns:
A PyTorch tensor of shape (num_samples, theta_dim).
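Latin hypercube sampling spreads samples more evenly than independent draws: each parameter's range is split into num_samples equal-width strata and every stratum receives exactly one sample. The sketch below shows that idea for uniform priors; sample_prior, lows and highs are hypothetical, it returns NumPy arrays instead of the documented PyTorch tensor, and it takes num_samples as an int rather than a sample_shape tuple. It is not synference's implementation.

```python
import numpy as np
from typing import Optional, Union


def sample_prior(
    num_samples: int,
    lows: np.ndarray,
    highs: np.ndarray,
    sample_lhc: bool = False,
    rng: Optional[Union[int, np.random.Generator]] = None,
) -> np.ndarray:
    """Draw (num_samples, dim) points from independent uniform priors."""
    rng = np.random.default_rng(rng)  # accepts None, a seed, or an existing Generator
    dim = len(lows)
    if sample_lhc:
        # Latin hypercube: split the unit interval into num_samples strata per
        # dimension, draw one point inside each stratum, then shuffle the strata
        # independently per dimension so each stratum is used exactly once.
        u = (np.arange(num_samples)[:, None] + rng.random((num_samples, dim))) / num_samples
        for j in range(dim):
            u[:, j] = rng.permutation(u[:, j])
    else:
        # Plain Monte Carlo: independent uniform draws.
        u = rng.random((num_samples, dim))
    # Rescale unit-cube samples to the prior ranges.
    return lows + u * (highs - lows)


lows, highs = np.array([8.0, 0.0]), np.array([12.0, 5.0])
samples = sample_prior(10, lows, highs, sample_lhc=True, rng=42)
print(samples.shape)  # (10, 2)
```

Passing the same integer seed as rng reproduces the same draws. SciPy also ships a ready-made generator, scipy.stats.qmc.LatinHypercube, whose unit-cube output can be rescaled the same way.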