pysersic.multiband
Classes
- BaseFitter – Base class for Pysersic Fitters
- FitMulti – Class used to fit multiple sources within a single image
- FitSingle – Class used to fit a single source
- BaseMultiBandFitter – Base class for multi-band fitters; new subclasses only need to add a sample_param_at_bands function specifying how parameters are linked between bands.
- FitMultiBandPoly – Base class for multi-band fitters; new subclasses only need to add a sample_param_at_bands function specifying how parameters are linked between bands.
- FitMultiBandBSpline – Base class for multi-band fitters; new subclasses only need to add a sample_param_at_bands function specifying how parameters are linked between bands.
Functions
- update_prior_suffix – Change the suffix of a pysersic prior
Module Contents
- class pysersic.multiband.BaseFitter(data: jax.typing.ArrayLike, rms: jax.typing.ArrayLike, psf: jax.typing.ArrayLike, mask: jax.typing.ArrayLike | None = None, loss_func: Callable | None = gaussian_loss, renderer: pysersic.rendering.BaseRenderer | None = HybridRenderer, renderer_kwargs: dict | None = {})
Bases: abc.ABC
Base class for Pysersic Fitters
- set_loss_func(loss_func: Callable) None
Set the loss function to be used for inference
- Parameters:
loss_func (Callable) – Function that evaluates the loss of the model image against the data; see utils/loss.py for some examples.
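The default loss_func is gaussian_loss. As a rough illustration of what a Gaussian loss measures, the pure-Python sketch below computes the negative log-likelihood of the data given a model image under independent Gaussian noise with per-pixel standard deviation rms. It is not pysersic's implementation (which runs inside a numpyro model and will look different); the function name gaussian_nll, the flattened-pixel inputs, and the convention that a True mask entry excludes a pixel are all assumptions.

```python
import math


def gaussian_nll(model, data, rms, mask=None):
    """Negative log-likelihood of `data` given `model` under independent
    Gaussian noise with standard deviation `rms` per pixel.

    Hypothetical helper: inputs are flattened pixel sequences, and a True
    entry in `mask` (assumed convention) removes that pixel from the sum.
    """
    total = 0.0
    for i, (m, d, s) in enumerate(zip(model, data, rms)):
        if mask is not None and mask[i]:
            continue  # masked pixel: contributes nothing to the loss
        r = (d - m) / s  # residual in units of the noise
        total += 0.5 * r * r + math.log(s) + 0.5 * math.log(2 * math.pi)
    return total
```

A worse model raises the loss through the 0.5 * r**2 term, while the log(rms) and log(2*pi) terms are constants that only matter when comparing different noise maps.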
- set_prior(parameter: str, distribution: numpyro.distributions.Distribution) None
Set the prior for a specific parameter
- Parameters:
parameter (str) – Parameter to be set
distribution (numpyro.distributions.Distribution) – Numpyro distribution object corresponding to the prior
- sample(rkey: jax.random.PRNGKey, num_samples: int = 1000, num_warmup: int = 1000, num_chains: int = 2, init_strategy: Callable | None = infer.init_to_sample, sampler_kwargs: dict | None = {}, mcmc_kwargs: dict | None = {}, return_model: bool | None = True, reparam_func: Callable | None = identity) pandas.DataFrame
Perform inference using a NUTS sampler
- Parameters:
num_samples (int, optional) – Number of samples to draw, by default 1000
num_warmup (int, optional) – Number of warmup samples, by default 1000
num_chains (int, optional) – Number of chains to run, by default 2
init_strategy (Optional[Callable], optional) – Initialization strategy for the sampler, by default infer.init_to_sample. See numpyro.infer.initialization for more options
sampler_kwargs (Optional[dict], optional) – Arguments to pass to the numpyro NUTS kernel
mcmc_kwargs (Optional[dict], optional) – Arguments to pass to the numpyro MCMC sampler
return_model (Optional[bool]) – Whether to return the model images; adds a small memory/time overhead, by default True
rkey (jax.random.PRNGKey) – PRNG key to use
- Returns:
ArviZ summary of posterior
- Return type:
pandas.DataFrame
- _train_SVI(autoguide: numpyro.infer.autoguide.AutoContinuous, method: str, rkey: jax.random.PRNGKey, ELBO_loss: Callable | None = infer.Trace_ELBO(5), lr_init: int | None = 0.01, num_round: int | None = 3, SVI_kwargs: dict | None = {}, train_kwargs: dict | None = {}, return_model: bool | None = True, num_sample: int | None = 1000) pandas.DataFrame
Internal function to perform inference using stochastic variational inference.
- Parameters:
autoguide (numpyro.infer.autoguide.AutoContinuous) – Function to build guide
method (str) – Name of the method being used; used when saving results
ELBO_loss (Optional[Callable], optional) – Loss function to use, by default infer.Trace_ELBO(5); see numpyro.infer.elbo for more options
lr_init (Optional[float], optional) – Initial learning rate, by default 1e-2
num_round (Optional[int], optional) – Number of training rounds; the learning rate decreases each round, by default 3
SVI_kwargs (Optional[dict], optional) – Additional arguments to pass to numpyro.infer.SVI, by default {}
train_kwargs (Optional[dict], optional) – Additional arguments to pass to utils.train_numpyro_svi_early_stop, by default {}
return_model (Optional[bool]) – Whether to return the model images; adds a small memory/time overhead, by default True
num_sample (Optional[int]) – Number of samples to draw from the trained SVI posterior, by default 1000
rkey (jax.random.PRNGKey) – PRNG key
- Returns:
ArviZ summary of posterior
- Return type:
pandas.DataFrame
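The round-based schedule behind num_round and lr_init can be illustrated on a toy problem. In the sketch below, plain gradient descent on a 1-D function stands in for the SVI optimizer, and the per-round decay factor (divide by 10) and steps_per_round are assumptions; this page only says that the learning rate decreases each round. The function train_in_rounds is hypothetical.

```python
def train_in_rounds(grad, x0, lr_init=0.01, num_round=3, steps_per_round=500, decay=0.1):
    """Minimize a 1-D objective by gradient descent in `num_round` rounds.

    The learning rate starts at `lr_init` and is multiplied by `decay` after
    each round. `decay` and `steps_per_round` are assumptions for this
    sketch, not values taken from pysersic.
    """
    x = float(x0)
    lr = lr_init
    lrs_used = []
    for _ in range(num_round):
        for _ in range(steps_per_round):
            x -= lr * grad(x)  # one gradient step at the current rate
        lrs_used.append(lr)
        lr *= decay  # smaller steps in the next round
    return x, lrs_used


# Minimize (x - 3)**2, whose gradient is 2 * (x - 3)
x_opt, rates = train_in_rounds(lambda x: 2 * (x - 3.0), 0.0)
```

Large steps early make fast progress, and smaller steps later reduce the jitter that noisy (stochastic) gradients cause near the optimum.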
- find_MAP(rkey: jax.random.PRNGKey, return_model: bool | None = True, purge_extra: bool | None = True)
Find the “best-fit” parameters as the maximum a-posteriori and return a dictionary with values for the parameters.
- Parameters:
return_model (Optional[bool]) – Whether to return the model images; adds a small memory/time overhead, by default True
rkey (jax.random.PRNGKey) – rng key
- Returns:
dictionary with fit parameters and their values.
- Return type:
dict
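To make "maximum a-posteriori" concrete, the sketch below finds the MAP of a toy one-parameter model (Gaussian data with known noise and a Gaussian prior on the mean) by gradient ascent on the log posterior, and checks it against the closed-form answer that exists for this conjugate model. It is an illustration of the concept, not how pysersic optimizes; all function names are hypothetical, and lr must be small enough (lr times the posterior precision below 2) for the ascent to converge.

```python
def log_posterior(mu, data, sigma, mu0, tau):
    """Unnormalized log posterior for the mean `mu` of Gaussian data with
    known noise `sigma` and a Gaussian prior N(mu0, tau**2) on `mu`."""
    log_like = sum(-0.5 * ((x - mu) / sigma) ** 2 for x in data)
    log_prior = -0.5 * ((mu - mu0) / tau) ** 2
    return log_like + log_prior


def find_map_1d(data, sigma, mu0, tau, lr=0.05, steps=2000):
    """Maximize log_posterior by plain gradient ascent, starting at the prior mean."""
    mu = mu0
    for _ in range(steps):
        grad = sum((x - mu) / sigma**2 for x in data) - (mu - mu0) / tau**2
        mu += lr * grad
    return mu


def map_closed_form(data, sigma, mu0, tau):
    """Exact MAP for this conjugate toy model, used as a cross-check."""
    precision = len(data) / sigma**2 + 1 / tau**2
    return (sum(data) / sigma**2 + mu0 / tau**2) / precision
```

For data [1.0, 2.0, 3.0] with sigma = 1 and a N(0, 1) prior, the prior pulls the MAP to 1.5, below the sample mean of 2.0.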
- estimate_posterior(rkey: jax.random.PRNGKey, method: str = 'laplace', return_model: bool = True, num_sample: int | None = 1000) pandas.DataFrame
Estimate the posterior using a method other than MCMC sampling. This is generally faster than MCMC but can be less accurate. Current options are:
- ‘laplace’
Uses the Laplace approximation: find the MAP, then approximate the posterior with a Gaussian whose covariance matrix is computed from the Hessian of the log posterior at the MAP. Generally the fastest method, but it can run into numerical problems, especially when fitting several (more than about 5) sources, since this involves inverting a large matrix.
- ‘svi-mvn’
Uses variational inference to fit a multivariate normal distribution to the posterior. In practice this should give results similar to ‘laplace’, since both assume the posterior is a multivariate Gaussian, but it is trained differently and is generally less susceptible to the numerical issues discussed above.
- ‘svi-flow’
Uses a normalizing flow (currently a BNAF, https://arxiv.org/abs/1904.04676) to approximate the posterior. This is more flexible than the two methods above, since the flow can capture non-Gaussian behavior. However, it is slower to train, and the optimization can be inconsistent, so use and interpret the results with caution. It is best to cross-check against sample on test cases.
- Parameters:
method (str, optional) – method to use, by default ‘laplace’
return_model (Optional[bool]) – Whether to return the model images; adds a small memory/time overhead, by default True
num_sample (Optional[int]) – Number of samples to draw from the trained SVI posterior, by default 1000
rkey (jax.random.PRNGKey) – rng key
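The ‘laplace’ option can be sketched in one dimension: the approximating Gaussian is centred on the MAP and its variance is the inverse of the second derivative of the negative log posterior there. The toy model below (Gaussian data, known noise, Gaussian prior) has an exactly Gaussian posterior, so the approximation recovers the true variance. With many parameters the second derivative becomes a Hessian matrix that has to be inverted, which is where the numerical problems mentioned above come from. The finite-difference Hessian and all names are assumptions of this sketch, not pysersic's code (which can use automatic differentiation instead).

```python
def neg_log_posterior(mu, data, sigma, mu0, tau):
    """Negative unnormalized log posterior of the toy model: Gaussian data with
    known noise `sigma` and a N(mu0, tau**2) prior on the mean `mu`."""
    return (sum(0.5 * ((x - mu) / sigma) ** 2 for x in data)
            + 0.5 * ((mu - mu0) / tau) ** 2)


def laplace_1d(f, mu_map, h=1e-3):
    """Laplace approximation around the MAP of a 1-D negative log posterior `f`.

    Returns (mean, variance) of the approximating Gaussian: the mean is the
    MAP and the variance is the inverse of the second derivative of `f`
    there, estimated with a central finite difference of step `h`.
    """
    hess = (f(mu_map + h) - 2.0 * f(mu_map) + f(mu_map - h)) / h**2
    if hess <= 0:
        raise ValueError("Hessian is not positive; mu_map is not a local minimum of f")
    return mu_map, 1.0 / hess
```

For data [1.0, 2.0, 3.0] with sigma = 1 and a N(0, 1) prior, the MAP is 1.5 and the second derivative is 3 + 1 = 4, so the Laplace variance is 0.25.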
- abstract build_model(return_model: bool = True)
- class pysersic.multiband.FitMulti(data: jax.typing.ArrayLike, rms: jax.typing.ArrayLike, psf: jax.typing.ArrayLike, prior: pysersic.priors.PySersicMultiPrior, mask: jax.typing.ArrayLike | None = None, loss_func: Callable | None = gaussian_loss, renderer: pysersic.rendering.BaseRenderer | None = HybridRenderer, renderer_kwargs: dict | None = {})
Bases: BaseFitter
Class used to fit multiple sources within a single image
- build_model(return_model: bool = True) Callable
Generate Numpyro model for the specified image, profile and priors
- Returns:
model – Function specifying the current model in Numpyro, can be passed to inference algorithms
- Return type:
Callable
- find_MAP(return_model: bool | None = True, rkey: jax.random.PRNGKey | None = jax.random.PRNGKey(3), purge_extra: bool | None = True)
Find the “best-fit” parameters as the maximum a-posteriori and return a dictionary with values for the parameters.
- Parameters:
return_model (Optional[bool], optional) – whether to return the model image, adds a small time and memory overhead, by default True
rkey (Optional[jax.random.PRNGKey], optional) – rng key, by default jax.random.PRNGKey(3)
- Returns:
dictionary with fit parameters and their values.
- Return type:
dict
- class pysersic.multiband.FitSingle(data: jax.typing.ArrayLike, rms: jax.typing.ArrayLike, psf: jax.typing.ArrayLike, prior: pysersic.priors.PySersicSourcePrior, mask: jax.typing.ArrayLike | None = None, loss_func: Callable | None = gaussian_loss, renderer: pysersic.rendering.BaseRenderer | None = HybridRenderer, renderer_kwargs: dict | None = {})
Bases: BaseFitter
Class used to fit a single source
- build_model(return_model: bool = True) Callable
Generate Numpyro model for the specified image, profile and priors
- Returns:
model – Function specifying the current model in Numpyro, can be passed to inference algorithms
- Return type:
Callable
- pysersic.multiband.update_prior_suffix(prior: BasePrior, new_suffix: str) BasePrior
Change the suffix of a pysersic prior
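The idea of a parameter suffix (for example, one per band, so that "flux_g" and "flux_r" are distinct parameters in a joint model) can be illustrated on plain strings. The sketch below does not touch pysersic's prior classes; change_suffix is a hypothetical helper, and the naming scheme "name + suffix" is an assumption about how suffixes are applied.

```python
def change_suffix(param_names, old_suffix, new_suffix):
    """Replace `old_suffix` with `new_suffix` on every parameter name.

    Hypothetical helper for illustration; raises ValueError if a name does
    not carry the expected suffix. An empty `old_suffix` simply appends.
    """
    renamed = []
    for name in param_names:
        if not name.endswith(old_suffix):
            raise ValueError(f"{name!r} does not end with {old_suffix!r}")
        base = name[: len(name) - len(old_suffix)]  # strip the old suffix
        renamed.append(base + new_suffix)
    return renamed
```

Raising on a mismatched name, rather than silently appending, avoids producing names such as "flux_g_r" that would not match any parameter.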
- class pysersic.multiband.BaseMultiBandFitter(fitter_list: List[pysersic.pysersic.FitSingle] | List[pysersic.pysersic.FitMulti], wavelengths: jax.Array, linked_params: List[str], const_params: List[str] | None = [], band_names: List[str] | None = None, linked_params_range: dict | None = {}, wv_to_save: jax.Array | None = None, rescale_unlinked_priors: bool | None = False)
Bases: pysersic.pysersic.BaseFitter
Base class for multi-band fitters; new subclasses only need to add a sample_param_at_bands function specifying how parameters are linked between bands.
- abstract sample_param_at_bands(name: str) jax.Array
Function used to sample linked parameters at each band
- Parameters:
name (str) – parameter name
- Returns:
params_at_bands – parameter values sampled at bands
- Return type:
jax.Array
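The contract of sample_param_at_bands (take a parameter name, return one value per band) can be illustrated without pysersic. The sketch below uses a hypothetical stand-in base class, ToyMultiBandFitter, that mimics only the abstract hook. In the real fitter the method is called while building the numpyro model and would sample its hyperparameters there; here the hyperparameters are passed in as fixed numbers. LinearLinkFitter, which links a parameter as a straight line in wavelength, and the "_intercept"/"_slope" naming are likewise hypothetical.

```python
from abc import ABC, abstractmethod


class ToyMultiBandFitter(ABC):
    """Stand-in for BaseMultiBandFitter; only the linking hook is modelled."""

    def __init__(self, wavelengths, hyperparams):
        self.wavelengths = list(wavelengths)
        self.hyperparams = dict(hyperparams)  # fixed values instead of sampled ones

    @abstractmethod
    def sample_param_at_bands(self, name):
        """Return one value of parameter `name` per band."""


class LinearLinkFitter(ToyMultiBandFitter):
    """Links each parameter across bands as a straight line in wavelength."""

    def sample_param_at_bands(self, name):
        a = self.hyperparams[name + "_intercept"]
        b = self.hyperparams[name + "_slope"]
        return [a + b * wv for wv in self.wavelengths]
```

Because the base class declares the method abstract, forgetting to implement it in a subclass fails at instantiation time with a TypeError rather than later inside the model.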
- build_model(return_model: bool = False) callable
Build the numpyro model for multi-band inference
- Parameters:
return_model (bool, optional) – whether or not to save and return observed images, by default False
- Returns:
numpyro model
- Return type:
callable
- class pysersic.multiband.FitMultiBandPoly(fitter_list: List[pysersic.pysersic.FitSingle], wavelengths: jax.Array, linked_params: List[str], const_params: List[str] | None = [], band_names: List[str] | None = None, linked_params_range: dict | None = {}, wv_to_save: jax.Array | None = None, rescale_unlinked_priors: bool | None = False, poly_order: int | None = 2)
Bases: BaseMultiBandFitter
Base class for multi-band fitters; new subclasses only need to add a sample_param_at_bands function specifying how parameters are linked between bands.
- restrict_func(x, hi, low)
- sample_param_at_bands(name)
Function used to sample linked parameters at each band
- Parameters:
name (str) – parameter name
- Returns:
params_at_bands – parameter values sampled at bands
- Return type:
jax.Array
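Judging by its name and the poly_order argument, FitMultiBandPoly links a parameter across bands through a polynomial in wavelength. The sketch below shows only the evaluation step, given a set of coefficients; how pysersic samples or constrains those coefficients (for example through restrict_func) is not described on this page, and poly_at_bands is a hypothetical name.

```python
def poly_at_bands(coeffs, wavelengths):
    """Evaluate sum_i coeffs[i] * wv**i at each wavelength using Horner's rule.

    A polynomial of order `poly_order` has poly_order + 1 coefficients.
    """
    values = []
    for wv in wavelengths:
        acc = 0.0
        for c in reversed(coeffs):  # highest power first
            acc = acc * wv + c
        values.append(acc)
    return values
```

A polynomial with poly_order + 1 coefficients can match any set of values at that many distinct wavelengths, so with the default poly_order = 2 the link only constrains the per-band values when there are more than three bands.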
- class pysersic.multiband.FitMultiBandBSpline(fitter_list: List[pysersic.pysersic.FitSingle], wavelengths: jax.Array, linked_params: List[str], const_params: List[str] | None = [], band_names: List[str] | None = None, linked_params_range: dict | None = {}, wv_to_save: jax.Array | None = None, rescale_unlinked_priors: bool | None = False, N_knots: int | None = 4, spline_k: int | None = 2, pad_knots: bool | None = True)
Bases: BaseMultiBandFitter
Base class for multi-band fitters; new subclasses only need to add a sample_param_at_bands function specifying how parameters are linked between bands.
- sample_param_at_bands(name)
Function used to sample linked parameters at each band
- Parameters:
name (str) – parameter name
- Returns:
params_at_bands – parameter values sampled at bands
- Return type:
jax.Array
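FitMultiBandBSpline presumably links parameters through a B-spline in wavelength. The pure-Python sketch below evaluates a B-spline with the Cox-de Boor recursion. It assumes that spline_k is the polynomial degree and that pad_knots produces a clamped knot vector (boundary knots repeated so the spline reaches the end wavelengths); neither detail is confirmed on this page. The defaults N_knots = 4 and spline_k = 2 come from the signature above, and all helper names are hypothetical.

```python
def bspline_basis(i, k, t, x):
    """Cox-de Boor recursion: i-th B-spline basis function of degree k on knots t at x."""
    if k == 0:
        if t[i] <= x < t[i + 1]:
            return 1.0
        # Close the last non-empty interval on the right so x == t[-1] is covered
        if x == t[-1] and t[i] < t[i + 1] == t[-1]:
            return 1.0
        return 0.0
    left = 0.0
    if t[i + k] != t[i]:  # skip terms with a zero-width denominator
        left = (x - t[i]) / (t[i + k] - t[i]) * bspline_basis(i, k - 1, t, x)
    right = 0.0
    if t[i + k + 1] != t[i + 1]:
        right = (t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1]) * bspline_basis(i + 1, k - 1, t, x)
    return left + right


def clamped_knots(lo, hi, n_knots, k):
    """n_knots evenly spaced knots on [lo, hi], with each end repeated k extra times."""
    interior = [lo + (hi - lo) * j / (n_knots - 1) for j in range(1, n_knots - 1)]
    return [lo] * (k + 1) + interior + [hi] * (k + 1)


def spline_at_bands(coeffs, knots, k, wavelengths):
    """Evaluate sum_i coeffs[i] * B_{i,k}(wv) at each wavelength."""
    n_basis = len(knots) - k - 1
    if len(coeffs) != n_basis:
        raise ValueError(f"expected {n_basis} coefficients, got {len(coeffs)}")
    return [sum(c * bspline_basis(i, k, knots, wv) for i, c in enumerate(coeffs))
            for wv in wavelengths]
```

With 4 knots and degree 2 this gives 4 + 2 - 1 = 5 basis functions. Because the basis functions sum to one on [lo, hi], setting every coefficient to the same value c yields c at every band, which is a handy sanity check.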