synference.simformer

Functions for simformer tasks and model saving/loading.

Functions

synference.simformer.load_full_model(dir_path, model_id, simulator=None)

Load a full model from the specified directory and model ID.

Classes

class synference.simformer.GalaxyPhotometryTask(name='galaxy_photometry_task', backend='jax', prior_dict=None, param_names_ordered=None, run_simulator_fn=None, num_filters=None, test_X_data=None, test_theta_data=None, attention_mask_type='full')

A Simformer InferenceTask for the GalaxySimulator.

get_base_mask_fn()

Defines the base attention mask for the transformer.

Return type:

Callable
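A transformer typically applies an attention mask by giving disallowed query-key pairs zero weight in the softmax. The sketch below assumes a boolean mask where entry [i, j] is True when node i may attend to node j, builds the all-True mask implied by the default attention_mask_type='full', and applies it to a matrix of attention scores. Both helper names are hypothetical, and the real get_base_mask_fn returns a Callable rather than an array.

```python
import numpy as np


def full_mask(theta_dim: int, x_dim: int) -> np.ndarray:
    """All-True (n, n) mask: every node may attend to every node (hypothetical helper)."""
    n = theta_dim + x_dim
    return np.ones((n, n), dtype=bool)


def masked_attention_weights(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Row-wise softmax over keys, with masked-out (query, key) pairs given zero weight.

    Assumes every row of `mask` allows at least one key; an all-False row
    would produce NaNs.
    """
    # Disallowed pairs get a score of -inf, so exp() maps them to exactly 0.
    scores = np.where(mask, scores, -np.inf)
    # Subtract the row maximum for numerical stability before exponentiating.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)
```

With the full mask, every node spreads its attention over all theta and x nodes; switching a single entry to False removes that dependency without changing the rest of the row.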

get_data(num_samples, **kwargs)

Generates data for the task by sampling parameters from the prior and simulating the corresponding photometry.

Parameters:
  • num_samples (int) – The number of samples to generate.

  • **kwargs – Additional keyword arguments for the prior sampling.

Return type:

Dict[str, Array]

Returns:

A dictionary with keys ‘theta’ and ‘x’, containing the sampled parameters and simulated photometry, respectively.
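A minimal NumPy sketch of the sample-then-simulate pattern described above, assuming independent uniform priors given as a name -> (low, high) dict and a simulator that maps a (num_samples, theta_dim) array to a (num_samples, x_dim) array. The function signature is hypothetical and does not reproduce the method's actual arguments.

```python
import numpy as np


def get_data(num_samples, prior_ranges, simulate, rng=None):
    """Sample theta from uniform priors, simulate x, and return both (hypothetical sketch).

    prior_ranges: dict mapping parameter name -> (low, high).
    simulate: callable mapping (num_samples, theta_dim) -> (num_samples, x_dim).
    """
    # default_rng accepts None, an int seed, or an existing Generator.
    rng = np.random.default_rng(rng)
    lows = np.array([lo for lo, _ in prior_ranges.values()], dtype=float)
    highs = np.array([hi for _, hi in prior_ranges.values()], dtype=float)
    # The (theta_dim,) bounds broadcast across the leading sample axis.
    theta = rng.uniform(lows, highs, size=(num_samples, lows.shape[0]))
    x = simulate(theta)
    return {"theta": theta, "x": x}
```

For example, get_data(100, {"mass": (8.0, 12.0), "age": (0.1, 1.0)}, toy_simulator) returns arrays of shape (100, 2) under 'theta' and whatever shape toy_simulator produces under 'x'; toy_simulator here stands in for the real photometry simulator.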

get_node_id()

Returns an array identifying the nodes (dimensions) of theta and x.

Return type:

Array
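The docstring treats each dimension of theta and x as one node. Assuming node ids are simply consecutive integers with the theta dimensions first, a sketch of such an identifier array is below; the exact numbering used by the library is not stated here.

```python
import numpy as np


def get_node_id(theta_dim: int, x_dim: int) -> np.ndarray:
    """One integer id per node (hypothetical sketch).

    Ids 0..theta_dim-1 label the parameters; ids theta_dim..theta_dim+x_dim-1
    label the photometric observations.
    """
    return np.arange(theta_dim + x_dim)
```

Slicing the result at theta_dim then separates the parameter nodes from the observation nodes.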

get_prior()

Returns the prior distribution object.

get_simulator()

Gets a batched simulator function.

Returns:

A callable that takes a batch of thetas and returns a batch of xs.
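A batched simulator can be built by lifting a single-sample simulator over the leading batch axis. The sketch below does this with a plain loop in NumPy, assuming the per-sample simulator takes a 1-D theta and returns a 1-D x; with the task's default backend='jax', jax.vmap would be the vectorised equivalent. The helper name is hypothetical.

```python
from typing import Callable

import numpy as np


def make_batched_simulator(
    simulate_one: Callable[[np.ndarray], np.ndarray],
) -> Callable[[np.ndarray], np.ndarray]:
    """Wrap a per-sample simulator so it accepts a (batch, theta_dim) array."""

    def simulate_batch(thetas: np.ndarray) -> np.ndarray:
        # Promote a single 1-D theta to a batch of one.
        thetas = np.atleast_2d(thetas)
        # Simulate each row and stack the outputs into (batch, x_dim).
        return np.stack([np.asarray(simulate_one(t)) for t in thetas], axis=0)

    return simulate_batch
```

The loop keeps the sketch backend-agnostic; a JAX version would trade it for a single jax.vmap call so the whole batch is traced and compiled once.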

get_theta_dim()

Returns the dimension of the theta vector.

Return type:

int

get_x_dim()

Returns the dimension of the x vector (photometry).

Return type:

int

class synference.simformer.GalaxyPrior(prior_ranges, param_order)

A prior distribution for galaxy parameters.

This class uses uniform distributions for each parameter defined in prior_ranges. It can sample from the prior and compute log probabilities.

log_prob(theta)

Calculates the log probability of theta under the prior.

Parameters:

theta (Tensor) – A PyTorch tensor of shape (num_samples, theta_dim).

Return type:

Tensor

Returns:

A PyTorch tensor of shape (num_samples,) containing log probabilities.
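For a product of independent uniform priors, the log density is the same constant, minus the sum of log(high_i - low_i), everywhere inside the parameter box, and -inf outside it. A NumPy sketch of that calculation follows, assuming the ranges arrive as arrays of lower and upper bounds; the class itself works on PyTorch tensors and takes its ranges through prior_ranges, and the helper name is hypothetical.

```python
import numpy as np


def uniform_log_prob(theta, lows, highs):
    """Log density of independent uniform priors (hypothetical sketch).

    theta: (num_samples, theta_dim); lows, highs: (theta_dim,).
    Returns an array of shape (num_samples,).
    """
    theta = np.asarray(theta, dtype=float)
    lows = np.asarray(lows, dtype=float)
    highs = np.asarray(highs, dtype=float)
    # A sample has non-zero density only if every coordinate lies in its range.
    inside = np.all((theta >= lows) & (theta <= highs), axis=1)
    # Density inside the box is 1 / prod(high - low), so log density is -sum(log widths).
    log_volume = np.sum(np.log(highs - lows))
    return np.where(inside, -log_volume, -np.inf)
```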

sample(sample_shape, sample_lhc=False, rng=None)

Generates samples from the prior.

Parameters:
  • sample_shape (Tuple[int]) – A tuple containing the number of samples, e.g., (num_samples,).

  • sample_lhc – If True, samples using Latin Hypercube sampling.

  • rng – Optional random number generator for reproducibility.

Return type:

Tensor

Returns:

A PyTorch tensor of shape (num_samples, theta_dim).
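Latin hypercube sampling splits each parameter range into num_samples equal strata and places exactly one sample in every stratum of every dimension, which covers the prior more evenly than independent uniform draws. The NumPy sketch below shows the standard construction, assuming box bounds given as arrays; it is not necessarily how GalaxyPrior.sample implements sample_lhc=True, and the function name is hypothetical. SciPy also provides a ready-made version as scipy.stats.qmc.LatinHypercube.

```python
import numpy as np


def latin_hypercube(num_samples, lows, highs, rng=None):
    """Latin hypercube sample in the box [lows, highs) (hypothetical sketch).

    Returns an array of shape (num_samples, theta_dim).
    """
    rng = np.random.default_rng(rng)
    lows = np.asarray(lows, dtype=float)
    highs = np.asarray(highs, dtype=float)
    dim = lows.shape[0]
    # Column j is an independent random permutation of stratum indices 0..n-1,
    # so each stratum of each dimension is used exactly once.
    strata = np.stack([rng.permutation(num_samples) for _ in range(dim)], axis=1)
    # Jitter uniformly within each stratum to get points in [0, 1).
    unit = (strata + rng.uniform(size=(num_samples, dim))) / num_samples
    # Map the unit hypercube onto the prior box.
    return lows + unit * (highs - lows)
```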

class synference.simformer.InferenceTask

Dummy InferenceTask class for compatibility.

class synference.simformer.UncertainityModelTask(magnitudes, log_uncertainties)

Conditional uncertainty model task for galaxy magnitudes and log-uncertainties.

get_base_mask_fn()

Defines the base attention mask, in which the log-uncertainty ‘x’ depends on the magnitude ‘theta’.
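One way to express "x depends on theta" as an attention mask is a block structure in which magnitude nodes see only magnitudes while log-uncertainty nodes see every node. The sketch assumes entry [i, j] is True when node i may attend to node j and that magnitudes occupy the first theta_dim nodes. The exact structure used by UncertainityModelTask is not documented here, so treat this, and the helper name, as a hypothetical illustration.

```python
import numpy as np


def conditional_mask(theta_dim: int, x_dim: int) -> np.ndarray:
    """Boolean (n, n) mask with x conditioned on theta (hypothetical sketch)."""
    n = theta_dim + x_dim
    mask = np.zeros((n, n), dtype=bool)
    # Magnitude (theta) nodes attend only to other magnitude nodes.
    mask[:theta_dim, :theta_dim] = True
    # Log-uncertainty (x) nodes attend to the magnitudes they depend on...
    mask[theta_dim:, :theta_dim] = True
    # ...and to each other.
    mask[theta_dim:, theta_dim:] = True
    return mask
```

Leaving the upper-right block False is what encodes the direction of the dependency: information flows from magnitudes to uncertainties but not back.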

get_data(num_samples, rng=None)

Returns a random subset of the provided catalog data.

Return type:

dict[str, Array]
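A sketch of drawing a random subset of catalog rows follows. It assumes magnitudes and log_uncertainties are aligned arrays with one row per object, samples rows without replacement, and reuses the ‘theta’/‘x’ keys of GalaxyPhotometryTask.get_data; all three are assumptions, and the function name is hypothetical.

```python
import numpy as np


def sample_catalog_subset(magnitudes, log_uncertainties, num_samples, rng=None):
    """Return num_samples random catalog rows as {'theta': ..., 'x': ...} (hypothetical sketch)."""
    rng = np.random.default_rng(rng)
    magnitudes = np.asarray(magnitudes)
    log_uncertainties = np.asarray(log_uncertainties)
    # Pick distinct row indices so no object is drawn twice.
    idx = rng.choice(magnitudes.shape[0], size=num_samples, replace=False)
    # Index both arrays with the same rows to keep magnitude/uncertainty pairs aligned.
    return {"theta": magnitudes[idx], "x": log_uncertainties[idx]}
```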

get_theta_dim()

Returns the dimension of the theta vector (magnitude).

Return type:

int

get_x_dim()

Returns the dimension of the x vector (log-uncertainty).

Return type:

int