pysersic.multiband

Classes

BaseFitter

Base class for Pysersic Fitters

FitMulti

Class used to fit multiple sources within a single image

FitSingle

Class used to fit a single source

BaseMultiBandFitter

Base class for multi-band fitters. New subclasses only need to implement a sample_param_at_bands function that specifies how parameters are linked between bands.

FitMultiBandPoly

Multi-band fitter that links parameters between bands with a polynomial in wavelength.

FitMultiBandBSpline

Multi-band fitter that links parameters between bands with a B-spline in wavelength.

Functions

update_prior_suffix(→ BasePrior)

Change the suffix of a pysersic prior

Module Contents

class pysersic.multiband.BaseFitter(data: jax.typing.ArrayLike, rms: jax.typing.ArrayLike, psf: jax.typing.ArrayLike, mask: jax.typing.ArrayLike | None = None, loss_func: Callable | None = gaussian_loss, renderer: pysersic.rendering.BaseRenderer | None = HybridRenderer, renderer_kwargs: dict | None = {})

Bases: abc.ABC

Base class for Pysersic Fitters

set_loss_func(loss_func: Callable) None

Set the loss function to be used for inference

Parameters:

loss_func (Callable) – Function that evaluates the loss; see utils/loss.py for some examples.
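As a rough illustration of what a loss function evaluates, the sketch below computes a masked Gaussian negative log-likelihood of an image with NumPy. It is not one of the functions in utils/loss.py and does not follow their call signature; the name gaussian_neg_log_likelihood and the convention that mask is True for excluded pixels are assumptions made for this sketch.

```python
import numpy as np


def gaussian_neg_log_likelihood(model, data, rms, mask):
    """Masked Gaussian negative log-likelihood of an image (illustrative sketch).

    Assumed convention for this sketch: mask is True for pixels to EXCLUDE.
    """
    model = np.asarray(model, dtype=float)
    data = np.asarray(data, dtype=float)
    rms = np.asarray(rms, dtype=float)
    good = ~np.asarray(mask, dtype=bool)  # pixels that contribute to the loss
    resid = (data[good] - model[good]) / rms[good]  # residuals in units of the noise
    # -log N(data | model, rms^2), summed over the unmasked pixels
    return 0.5 * np.sum(resid**2 + np.log(2.0 * np.pi * rms[good] ** 2))
```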

set_prior(parameter: str, distribution: numpyro.distributions.Distribution) None

Set the prior for a specific parameter

Parameters:
  • parameter (str) – Parameter to be set

  • distribution (numpyro.distributions.Distribution) – Numpyro distribution object corresponding to the prior

sample(rkey: jax.random.PRNGKey, num_samples: int = 1000, num_warmup: int = 1000, num_chains: int = 2, init_strategy: Callable | None = infer.init_to_sample, sampler_kwargs: dict | None = {}, mcmc_kwargs: dict | None = {}, return_model: bool | None = True, reparam_func: Callable | None = identity) pandas.DataFrame

Perform inference using a NUTS sampler

Parameters:
  • num_samples (int, optional) – Number of samples to draw, by default 1000

  • num_warmup (int, optional) – Number of warmup samples, by default 1000

  • num_chains (int, optional) – Number of chains to run, by default 2

  • init_strategy (Optional[Callable], optional) – Initialization strategy for the sampler, by default infer.init_to_sample. See numpyro.infer.initialization for more options

  • sampler_kwargs (Optional[dict], optional) – Arguments to pass to the numpyro NUTS kernel

  • mcmc_kwargs (Optional[dict], optional) – Arguments to pass to the numpyro MCMC sampler

  • return_model (Optional[bool]) – Whether to return the model images (adds a small memory/time overhead), by default True

  • rkey (jax.random.PRNGKey) – PRNG key to use

Returns:

ArviZ summary of the posterior

Return type:

pandas.DataFrame

_train_SVI(autoguide: numpyro.infer.autoguide.AutoContinuous, method: str, rkey: jax.random.PRNGKey, ELBO_loss: Callable | None = infer.Trace_ELBO(5), lr_init: int | None = 0.01, num_round: int | None = 3, SVI_kwargs: dict | None = {}, train_kwargs: dict | None = {}, return_model: bool | None = True, num_sample: int | None = 1000) pandas.DataFrame

Internal function to perform inference using stochastic variational inference.

Parameters:
  • autoguide (numpyro.infer.autoguide.AutoContinuous) – Autoguide used to build the variational guide

  • method (str) – Name of the method being used; used when saving results

  • ELBO_loss (Optional[Callable], optional) – Loss function to use, by default infer.Trace_ELBO(5); see numpyro.infer.elbo for more options

  • lr_init (Optional[float], optional) – Initial learning rate, by default 1e-2

  • num_round (Optional[int], optional) – Number of training rounds; the learning rate decreases each round. By default 3

  • SVI_kwargs (Optional[dict], optional) – Additional arguments to pass to numpyro.infer.SVI, by default {}

  • train_kwargs (Optional[dict], optional) – Additional arguments to pass to utils.train_numpyro_svi_early_stop, by default {}

  • return_model (Optional[bool]) – Whether to return the model images (adds a small memory/time overhead), by default True

  • num_sample (Optional[int]) – Number of samples to draw from the trained SVI posterior, by default 1000

  • rkey (jax.random.PRNGKey) – PRNG key

Returns:

ArviZ summary of the posterior

Return type:

pandas.DataFrame

find_MAP(rkey: jax.random.PRNGKey, return_model: bool | None = True, purge_extra: bool | None = True)

Find the “best-fit” parameters as the maximum a-posteriori and return a dictionary with values for the parameters.

Parameters:
  • return_model (Optional[bool]) – Whether to return the model images (adds a small memory/time overhead), by default True

  • rkey (jax.random.PRNGKey) – rng key

Returns:

Dictionary of fit parameters and their values.

Return type:

dict
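The MAP is the parameter value that minimizes the negative log posterior (negative log-likelihood plus negative log prior). The sketch below finds it by plain gradient descent for a toy one-parameter model, Gaussian data with a Gaussian prior on the mean, and checks the result against the closed-form answer. The helper names and the choice of optimizer are assumptions for illustration; the sketch does not reproduce the optimizer that find_MAP uses.

```python
import numpy as np


def neg_log_posterior_grad(mu, y, sigma, mu0, tau):
    """Gradient of -log p(mu | y) for y_i ~ N(mu, sigma^2) and prior mu ~ N(mu0, tau^2)."""
    likelihood_term = -np.sum(y - mu) / sigma**2
    prior_term = (mu - mu0) / tau**2
    return likelihood_term + prior_term


def find_map_gd(y, sigma, mu0, tau, lr=0.01, n_steps=5000):
    """Minimize the negative log posterior by plain gradient descent."""
    y = np.asarray(y, dtype=float)
    mu = float(mu0)  # start at the prior mean
    for _ in range(n_steps):
        mu -= lr * neg_log_posterior_grad(mu, y, sigma, mu0, tau)
    return mu


def map_closed_form(y, sigma, mu0, tau):
    """Exact MAP for this conjugate Gaussian model, used as a check."""
    y = np.asarray(y, dtype=float)
    precision = y.size / sigma**2 + 1.0 / tau**2
    return (y.sum() / sigma**2 + mu0 / tau**2) / precision
```

With y = [1, 2, 3, 4, 5], sigma = 1, mu0 = 0 and tau = 2, the MAP is 15/5.25 ≈ 2.857, pulled below the sample mean of 3 by the prior.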

estimate_posterior(rkey: jax.random.PRNGKey, method: str = 'laplace', return_model: bool = True, num_sample: int | None = 1000) pandas.DataFrame

Estimate the posterior using a method other than MCMC sampling. This is generally faster than MCMC but can be less accurate. Current options are:

  • ‘laplace’
    • Uses the Laplace approximation, which finds the MAP and then approximates the posterior with a Gaussian centered there. The covariance matrix is the inverse of the Hessian of the negative log posterior at the MAP. Generally the fastest method, but it can run into numerical problems, especially when fitting several (more than about 5) sources, since this involves inverting a large matrix.

  • ‘svi-mvn’
    • Uses variational inference to fit a multivariate normal distribution to the posterior. In practice this should give results similar to ‘laplace’, since both assume the posterior is a multivariate Gaussian, but it is trained differently and is generally less susceptible to the numerical issues discussed above.

  • ‘svi-flow’
    • Uses a neural flow (currently a BNAF, https://arxiv.org/abs/1904.04676) to approximate the posterior. This is more flexible than the two methods above, since the flow can capture non-Gaussian behavior. However, it is slower to train, and the optimization can be inconsistent, so use and interpret the results with caution. It is best to cross-check against sample on test cases.
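The ‘laplace’ option can be sketched in NumPy as follows: given a MAP point, estimate the Hessian of the negative log posterior by central finite differences, invert it to get the covariance, and draw samples from the resulting multivariate normal. The finite-difference Hessian, the function names, and the assumption that x_map has already been found (for example by a MAP step like find_MAP) are choices made for this sketch, not a description of pysersic's implementation.

```python
import numpy as np


def numerical_hessian(f, x, eps=1e-4):
    """Central-difference Hessian of a scalar function f at the point x."""
    x = np.asarray(x, dtype=float)
    n = x.size
    hess = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            e_i = np.zeros(n)
            e_i[i] = eps
            e_j = np.zeros(n)
            e_j[j] = eps
            hess[i, j] = (
                f(x + e_i + e_j) - f(x + e_i - e_j) - f(x - e_i + e_j) + f(x - e_i - e_j)
            ) / (4.0 * eps**2)
    return hess


def laplace_approximation(neg_log_post, x_map, num_sample=1000, seed=0):
    """Approximate the posterior as N(x_map, H^-1), with H the Hessian of the
    negative log posterior at the MAP, and draw samples from it."""
    x_map = np.asarray(x_map, dtype=float)
    hess = numerical_hessian(neg_log_post, x_map)
    cov = np.linalg.inv(hess)  # the inversion that gets fragile with many parameters
    rng = np.random.default_rng(seed)
    samples = rng.multivariate_normal(x_map, cov, size=num_sample)
    return cov, samples
```

The np.linalg.inv call is the step that becomes fragile as the number of parameters grows, which is the numerical issue noted above for fits with several sources.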

Parameters:
  • method (str, optional) – Method to use, by default ‘laplace’

  • return_model (Optional[bool]) – Whether to return the model images (adds a small memory/time overhead), by default True

  • num_sample (Optional[int]) – Number of samples to draw from the trained approximate posterior, by default 1000

  • rkey (jax.random.PRNGKey) – rng key

abstract build_model(return_model: bool = True)

Build the NumPyro model; abstract, so each fitter subclass must implement it.

class pysersic.multiband.FitMulti(data: jax.typing.ArrayLike, rms: jax.typing.ArrayLike, psf: jax.typing.ArrayLike, prior: pysersic.priors.PySersicMultiPrior, mask: jax.typing.ArrayLike | None = None, loss_func: Callable | None = gaussian_loss, renderer: pysersic.rendering.BaseRenderer | None = HybridRenderer, renderer_kwargs: dict | None = {})

Bases: BaseFitter

Class used to fit multiple sources within a single image

build_model(return_model: bool = True) Callable

Generate the NumPyro model for the specified image, profile and priors

Returns:

model – Function specifying the current model in NumPyro; it can be passed to inference algorithms

Return type:

Callable

find_MAP(return_model: bool | None = True, rkey: jax.random.PRNGKey | None = jax.random.PRNGKey(3), purge_extra: bool | None = True)

Find the “best-fit” parameters as the maximum a-posteriori and return a dictionary with values for the parameters.

Parameters:
  • return_model (Optional[bool], optional) – Whether to return the model image (adds a small time and memory overhead), by default True

  • rkey (Optional[jax.random.PRNGKey], optional) – rng key, by default jax.random.PRNGKey(3)

Returns:

Dictionary of fit parameters and their values.

Return type:

dict

class pysersic.multiband.FitSingle(data: jax.typing.ArrayLike, rms: jax.typing.ArrayLike, psf: jax.typing.ArrayLike, prior: pysersic.priors.PySersicSourcePrior, mask: jax.typing.ArrayLike | None = None, loss_func: Callable | None = gaussian_loss, renderer: pysersic.rendering.BaseRenderer | None = HybridRenderer, renderer_kwargs: dict | None = {})

Bases: BaseFitter

Class used to fit a single source

build_model(return_model: bool = True) Callable

Generate the NumPyro model for the specified image, profile and priors

Returns:

model – Function specifying the current model in NumPyro; it can be passed to inference algorithms

Return type:

Callable

pysersic.multiband.update_prior_suffix(prior: BasePrior, new_suffix: str) BasePrior

Change the suffix of a pysersic prior

Parameters:
  • prior (BasePrior) – Either a source prior (PySersicSourcePrior) or a multi-source prior (PySersicMultiPrior)

  • new_suffix (str) – New suffix for the variables

Returns:

Prior with updated suffix

Return type:

BasePrior
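NumPyro requires every sample site in a model to have a unique name, so when the same kind of prior is reused (presumably once per band in the multi-band fitters below), the parameter names must be made distinct. The string-only sketch below illustrates that kind of renaming; with_suffix and the example names flux_0, r_eff_0, _g and _r are hypothetical, and the sketch neither operates on pysersic prior objects nor reproduces update_prior_suffix's internals.

```python
def with_suffix(param_names, new_suffix, old_suffix=""):
    """Map each parameter name to a renamed version ending in new_suffix.

    Hypothetical helper: it only manipulates strings and does not touch
    pysersic prior objects. If a name ends in old_suffix, that suffix is
    replaced; otherwise new_suffix is simply appended.
    """
    renamed = {}
    for name in param_names:
        if old_suffix and name.endswith(old_suffix):
            base = name[: -len(old_suffix)]
        else:
            base = name
        renamed[name] = base + new_suffix
    return renamed


# Hypothetical per-band renaming so that sample-site names never collide
band_g = with_suffix(["flux_0", "r_eff_0"], "_g", old_suffix="_0")
band_r = with_suffix(["flux_0", "r_eff_0"], "_r", old_suffix="_0")
```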

class pysersic.multiband.BaseMultiBandFitter(fitter_list: List[pysersic.pysersic.FitSingle] | List[pysersic.pysersic.FitMulti], wavelengths: jax.Array, linked_params: List[str], const_params: List[str] | None = [], band_names: List[str] | None = None, linked_params_range: dict | None = {}, wv_to_save: jax.Array | None = None, rescale_unlinked_priors: bool | None = False)

Bases: pysersic.pysersic.BaseFitter

Base class for multi-band fitters. New subclasses only need to implement a sample_param_at_bands function that specifies how parameters are linked between bands.

abstract sample_param_at_bands(name: str) jax.Array

Function used to sample linked parameters at each band

Parameters:

name (str) – parameter name

Returns:

params_at_bands – parameter values sampled at bands

Return type:

jax.Array
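To illustrate the subclassing contract, the sketch below pairs a minimal abstract base with one concrete sample_param_at_bands. StandInMultiBandFitter, LinearLinkFitter and linked_values are hypothetical: the stand-in only mimics the contract described here and is not BaseMultiBandFitter, and the endpoint values are fixed numbers where a real model would draw them with numpyro.sample.

```python
import abc

import numpy as np


class StandInMultiBandFitter(abc.ABC):
    """Hypothetical stand-in that mimics only the sample_param_at_bands contract."""

    def __init__(self, wavelengths, linked_params):
        self.wavelengths = np.asarray(wavelengths, dtype=float)
        self.linked_params = list(linked_params)

    @abc.abstractmethod
    def sample_param_at_bands(self, name: str) -> np.ndarray:
        """Return the value of parameter `name` at every band."""

    def linked_values(self):
        # A model-building step would call this once per linked parameter
        return {name: self.sample_param_at_bands(name) for name in self.linked_params}


class LinearLinkFitter(StandInMultiBandFitter):
    """Hypothetical subclass: each linked parameter varies linearly with wavelength."""

    def __init__(self, wavelengths, linked_params, endpoints):
        super().__init__(wavelengths, linked_params)
        # endpoints[name] = (value at shortest wavelength, value at longest wavelength);
        # in a real NumPyro model these would be sampled, not fixed
        self.endpoints = endpoints

    def sample_param_at_bands(self, name):
        start, end = self.endpoints[name]
        wv = self.wavelengths
        frac = (wv - wv.min()) / (wv.max() - wv.min())  # 0 at shortest, 1 at longest wavelength
        return start + frac * (end - start)
```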

build_model(return_model: bool = False) callable

Build the NumPyro model for multi-band inference

Parameters:

return_model (bool, optional) – Whether or not to save and return the observed images, by default False

Returns:

NumPyro model

Return type:

callable

class pysersic.multiband.FitMultiBandPoly(fitter_list: List[pysersic.pysersic.FitSingle], wavelengths: jax.Array, linked_params: List[str], const_params: List[str] | None = [], band_names: List[str] | None = None, linked_params_range: dict | None = {}, wv_to_save: jax.Array | None = None, rescale_unlinked_priors: bool | None = False, poly_order: int | None = 2)

Bases: BaseMultiBandFitter

Multi-band fitter that links parameters between bands with a polynomial in wavelength.

restrict_func(x, hi, low)

sample_param_at_bands(name)

Function used to sample linked parameters at each band

Parameters:

name (str) – parameter name

Returns:

params_at_bands – parameter values sampled at bands

Return type:

jax.Array
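A minimal sketch of polynomial linking, assuming each linked parameter is expressed as a polynomial in wavelength and evaluated at the band wavelengths. The function name poly_param_at_bands, the centering on a reference wavelength, and the coefficient layout are assumptions; in a fit the coefficients would be sampled rather than passed in.

```python
import numpy as np


def poly_param_at_bands(coeffs, wavelengths, wv_ref=None):
    """Evaluate a linked parameter at each band as a polynomial in wavelength.

    coeffs[j] multiplies (wavelength - wv_ref)**j, so a polynomial of order
    poly_order needs poly_order + 1 coefficients. Centering on wv_ref is a
    choice made for this sketch to keep the terms well conditioned.
    """
    wv = np.asarray(wavelengths, dtype=float)
    if wv_ref is None:
        wv_ref = wv.mean()
    return np.polynomial.polynomial.polyval(wv - wv_ref, np.asarray(coeffs, dtype=float))
```

Under this parametrization, the default poly_order of 2 corresponds to three coefficients per linked parameter.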

class pysersic.multiband.FitMultiBandBSpline(fitter_list: List[pysersic.pysersic.FitSingle], wavelengths: jax.Array, linked_params: List[str], const_params: List[str] | None = [], band_names: List[str] | None = None, linked_params_range: dict | None = {}, wv_to_save: jax.Array | None = None, rescale_unlinked_priors: bool | None = False, N_knots: int | None = 4, spline_k: int | None = 2, pad_knots: bool | None = True)

Bases: BaseMultiBandFitter

Multi-band fitter that links parameters between bands with a B-spline in wavelength.

sample_param_at_bands(name)

Function used to sample linked parameters at each band

Parameters:

name (str) – parameter name

Returns:

params_at_bands – parameter values sampled at bands

Return type:

jax.Array
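A minimal NumPy sketch of B-spline linking: the parameter at each band is a weighted sum of degree-k B-spline basis functions of wavelength, evaluated with the Cox-de Boor recursion. Evenly spaced knots spanning the band wavelengths, clamped (repeated) boundary knots, and reading spline_k as the spline degree are assumptions of this sketch; it does not claim to reproduce how FitMultiBandBSpline places knots or what pad_knots does. All function names are hypothetical.

```python
import numpy as np


def clamped_knots(lo, hi, n_knots, k):
    """Knot vector with n_knots evenly spaced knots on [lo, hi] and each boundary
    knot repeated so it has multiplicity k + 1 (a clamped B-spline)."""
    interior = np.linspace(lo, hi, n_knots)
    return np.concatenate([np.full(k, lo), interior, np.full(k, hi)])


def bspline_basis(x, knots, k):
    """Evaluate every degree-k B-spline basis function at x (Cox-de Boor recursion).

    Returns an array of shape (len(x), len(knots) - k - 1).
    """
    x = np.atleast_1d(np.asarray(x, dtype=float))
    t = np.asarray(knots, dtype=float)
    # Degree 0: indicator of the half-open knot span [t_i, t_{i+1})
    basis = np.zeros((x.size, t.size - 1))
    for i in range(t.size - 1):
        basis[:, i] = (t[i] <= x) & (x < t[i + 1])
    # Put the right endpoint in the last non-empty span so it is not dropped
    last_span = np.nonzero(t[:-1] < t[1:])[0][-1]
    basis[x == t[-1], last_span] = 1.0
    # Raise the degree one step at a time
    for d in range(1, k + 1):
        new_basis = np.zeros((x.size, t.size - 1 - d))
        for i in range(t.size - 1 - d):
            left_den = t[i + d] - t[i]
            right_den = t[i + d + 1] - t[i + 1]
            left = (x - t[i]) / left_den * basis[:, i] if left_den > 0 else 0.0
            right = (t[i + d + 1] - x) / right_den * basis[:, i + 1] if right_den > 0 else 0.0
            new_basis[:, i] = left + right
        basis = new_basis
    return basis


def bspline_param_at_bands(coeffs, wavelengths, n_knots=4, spline_k=2):
    """Linked parameter at each band as a B-spline in wavelength (sketch)."""
    wv = np.asarray(wavelengths, dtype=float)
    knots = clamped_knots(wv.min(), wv.max(), n_knots, spline_k)
    return bspline_basis(wv, knots, spline_k) @ np.asarray(coeffs, dtype=float)
```

With N_knots = 4 and spline_k = 2, as in the defaults above, this knot layout gives N_knots + spline_k - 1 = 5 coefficients per linked parameter.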