No-U-Turn Sampler (NUTS)

class sampyl.NUTS(logp, start, step_size=0.25, adapt_steps=100, Emax=1000.0, target_accept=0.65, gamma=0.05, k=0.75, t0=10.0, **kwargs)

No-U-Turn sampler (Hoffman & Gelman, 2014) for sampling from a probability distribution defined by a log P(theta) function.

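For technical details, see the paper by Hoffman & Gelman (2014), "The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo", Journal of Machine Learning Research 15.

Parameters: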

  • logp – Function returning log P(X), the log-probability (up to an additive constant) of the distribution to sample from.
  • start – Dictionary giving the starting state of the sampler, with one entry per argument of logp, keyed by argument name.
  • grad_logp – (optional) Function or list of functions that compute grad log P(theta). Pass functions here if you don’t want to use autograd for the gradients. If logp has multiple parameters, grad_logp must be a list of gradient functions, one with respect to each parameter of logp.

    If your logp function returns both the log-probability and its gradient, set grad_logp = True.

  • scale – (optional) Dictionary with the same format as start, giving the scaling of the initial momentum in the Hamiltonian step.
  • step_size – (optional) float. Initial step size for the deterministic proposals.
  • adapt_steps – (optional) int. Number of initial steps used to adapt the step size toward the target acceptance rate.
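  • Emax – (optional) float. Maximum allowed energy error (Δmax in Hoffman & Gelman, 2014); trajectory building stops once the error exceeds this value.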
  • target_accept – (optional) float. Target acceptance rate.
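  • gamma – (optional) float. Controls the amount of shrinkage in the dual-averaging step size adaptation (γ in Hoffman & Gelman, 2014).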
  • k – (optional) float. Scales the speed of step size adaptation.
  • t0 – (optional) float. Slows initial step size adaptation.
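The parameters step_size, adapt_steps, target_accept, gamma, k and t0 line up with the dual-averaging step size adaptation in Hoffman & Gelman (their ε0, M_adapt, δ, γ, κ and t0). The sketch below implements that scheme as written in the paper; that sampyl follows it line for line is an assumption. The accept_rate callback is hypothetical: it stands in for the acceptance statistic a real NUTS step would report for a given step size.

```python
import math


def dual_averaging(accept_rate, step_size=0.25, adapt_steps=100,
                   target_accept=0.65, gamma=0.05, k=0.75, t0=10.0):
    """Adapt a step size toward target_accept by dual averaging.

    accept_rate(eps) is a hypothetical callback returning the acceptance
    statistic observed when sampling with step size eps.
    Returns the averaged step size to use once adaptation ends.
    """
    mu = math.log(10.0 * step_size)   # point that log(eps) is shrunk toward
    h_bar = 0.0                       # running average of (target - observed)
    log_eps = math.log(step_size)     # current iterate, starts at step_size
    log_eps_bar = 0.0                 # log of the averaged step size
    for m in range(1, adapt_steps + 1):
        alpha = accept_rate(math.exp(log_eps))
        w = 1.0 / (m + t0)            # t0 damps the earliest updates
        h_bar = (1.0 - w) * h_bar + w * (target_accept - alpha)
        # Too few acceptances (h_bar > 0) push log_eps down, and vice versa
        log_eps = mu - math.sqrt(m) / gamma * h_bar
        # Weighted average of iterates; k sets how fast the weight decays
        eta = m ** (-k)
        log_eps_bar = eta * log_eps + (1.0 - eta) * log_eps_bar
    return math.exp(log_eps_bar)
```

With a toy acceptance curve accept_rate(eps) = exp(-eps), the step size that hits target_accept = 0.65 is -log(0.65), about 0.43, and the adapted value settles close to it.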


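Example:

import sampyl

def logp(x, y):
    # Hypothetical body: independent standard normals in x and y
    return -0.5 * (x**2 + y**2)

x_start, y_start = 1.0, 1.0  # hypothetical starting values
start = {'x': x_start, 'y': y_start}
nuts = sampyl.NUTS(logp, start)
chain = nuts.sample(1000)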
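If you would rather not rely on autograd, grad_logp takes one gradient function per argument of logp. The sketch below writes them by hand for a hypothetical two-parameter Gaussian logp(x, y) and checks them against central finite differences before use. It assumes each gradient function receives the same arguments as logp, which the parameter description implies but does not state; the sampyl call appears only as a comment.

```python
import math


def logp(x, y):
    # Hypothetical target: x ~ N(1, 1) and y ~ N(-2, 0.5**2), independent
    return -0.5 * (x - 1.0) ** 2 - 0.5 * ((y + 2.0) / 0.5) ** 2


def dlogp_dx(x, y):
    # Partial derivative of logp with respect to x
    return -(x - 1.0)


def dlogp_dy(x, y):
    # Partial derivative of logp with respect to y
    return -(y + 2.0) / 0.25


def finite_diff(f, x, y, wrt, h=1e-6):
    """Central-difference estimate of df/dx (wrt=0) or df/dy (wrt=1)."""
    if wrt == 0:
        return (f(x + h, y) - f(x - h, y)) / (2 * h)
    return (f(x, y + h) - f(x, y - h)) / (2 * h)


# Compare analytic and numerical gradients at a few arbitrary points
for x, y in [(0.0, 0.0), (2.5, -1.0), (-1.0, 3.0)]:
    assert math.isclose(dlogp_dx(x, y), finite_diff(logp, x, y, 0), abs_tol=1e-5)
    assert math.isclose(dlogp_dy(x, y), finite_diff(logp, x, y, 1), abs_tol=1e-5)

# The checked functions could then be supplied one per logp argument:
# nuts = sampyl.NUTS(logp, start, grad_logp=[dlogp_dx, dlogp_dy])
```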

sample(num, burn=0, thin=1, n_chains=1, progress_bar=True)

Sample from P(X).

  • num – int. Number of samples to draw from P(X).
  • burn – (optional) int. Number of samples to discard from the beginning of the chain.
  • thin – (optional) float. Thin the samples by this factor.
  • n_chains – (optional) int. Number of chains to return. Each chain is given its own process and the OS decides how to distribute the processes.
  • progress_bar – (optional) boolean. Show the progress bar, default = True.

Returns: Record array with one field per argument of the logp function.
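Because the result is a NumPy record array, each logp argument can be read as a named field. The sketch below builds a small record array by hand as a stand-in for the output of nuts.sample with a hypothetical logp(x, y), then shows field access and what burn and thin do when expressed as ordinary slicing. That sample applies them exactly as chain[burn::thin] is an assumption made for illustration.

```python
import numpy as np

# Hypothetical stand-in for nuts.sample(...) on logp(x, y): 10 scalar draws each
chain = np.rec.fromarrays(
    [np.arange(10, dtype=float), np.arange(10, 20, dtype=float)],
    names="x,y",
)

print(chain.dtype.names)  # ('x', 'y'), one field per logp argument
print(chain.x[:3])        # [0. 1. 2.]

# Discard the first 4 draws, then keep every 2nd remaining draw
burn, thin = 4, 2
kept = chain[burn::thin]
print(kept.x)             # [4. 6. 8.]
print(kept.y.mean())      # 16.0
```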


Perform one NUTS step.
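One transition of the efficient No-U-Turn sampler (Algorithm 3 in Hoffman & Gelman) can be sketched as below. This is a self-contained NumPy illustration, not sampyl's code: it works on one flat parameter vector, uses an identity mass matrix, takes the gradient as an explicit function, and adds a max_depth cap, a common safeguard that the paper's pseudocode does not include. The delta_max argument plays the role that Emax appears to play above.

```python
import numpy as np


def leapfrog(theta, r, eps, grad_logp):
    # One leapfrog step of Hamiltonian dynamics (identity mass matrix)
    r = r + 0.5 * eps * grad_logp(theta)
    theta = theta + eps * r
    r = r + 0.5 * eps * grad_logp(theta)
    return theta, r


def no_u_turn(theta_minus, theta_plus, r_minus, r_plus):
    # True while neither end of the trajectory has started to double back
    d = theta_plus - theta_minus
    return np.dot(d, r_minus) >= 0 and np.dot(d, r_plus) >= 0


def build_tree(theta, r, log_u, v, j, eps, logp, grad_logp, delta_max, rng):
    if j == 0:
        # Base case: a single leapfrog step in direction v
        theta1, r1 = leapfrog(theta, r, v * eps, grad_logp)
        joint = logp(theta1) - 0.5 * np.dot(r1, r1)
        n1 = int(log_u <= joint)          # is the new point inside the slice?
        s1 = log_u < joint + delta_max    # stop on a huge energy error
        return theta1, r1, theta1, r1, theta1, n1, s1
    # Recursion: build two subtrees of height j - 1 in direction v
    tm, rm, tp, rp, theta1, n1, s1 = build_tree(
        theta, r, log_u, v, j - 1, eps, logp, grad_logp, delta_max, rng)
    if s1:
        if v == -1:
            tm, rm, _, _, theta2, n2, s2 = build_tree(
                tm, rm, log_u, v, j - 1, eps, logp, grad_logp, delta_max, rng)
        else:
            _, _, tp, rp, theta2, n2, s2 = build_tree(
                tp, rp, log_u, v, j - 1, eps, logp, grad_logp, delta_max, rng)
        # Choose between the subtrees' candidates in proportion to their sizes
        if n1 + n2 > 0 and rng.uniform() < n2 / (n1 + n2):
            theta1 = theta2
        s1 = s2 and no_u_turn(tm, tp, rm, rp)
        n1 = n1 + n2
    return tm, rm, tp, rp, theta1, n1, s1


def nuts_step(theta, eps, logp, grad_logp, rng, delta_max=1000.0, max_depth=10):
    r0 = rng.standard_normal(theta.shape)
    # Slice variable u ~ Uniform(0, exp(joint)), kept on the log scale
    log_u = logp(theta) - 0.5 * np.dot(r0, r0) - rng.exponential()
    tm, tp, rm, rp = theta, theta, r0, r0
    theta_new, n, s, j = theta, 1, True, 0
    while s and j < max_depth:
        v = rng.choice([-1, 1])           # extend the trajectory left or right
        if v == -1:
            tm, rm, _, _, theta1, n1, s1 = build_tree(
                tm, rm, log_u, v, j, eps, logp, grad_logp, delta_max, rng)
        else:
            _, _, tp, rp, theta1, n1, s1 = build_tree(
                tp, rp, log_u, v, j, eps, logp, grad_logp, delta_max, rng)
        if s1 and rng.uniform() < min(1.0, n1 / n):
            theta_new = theta1
        n += n1
        s = s1 and no_u_turn(tm, tp, rm, rp)
        j += 1                            # each iteration doubles the trajectory
    return theta_new
```

Repeatedly calling nuts_step on a standard normal target, with logp(t) = -0.5 * t·t and gradient -t, should produce draws whose mean is near 0 and variance near 1.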