diff --git a/abcpmc/sampler.py b/abcpmc/sampler.py index 791a6f1..e324a4d 100644 --- a/abcpmc/sampler.py +++ b/abcpmc/sampler.py @@ -169,11 +169,13 @@ class Sampler(object): particle_proposal_cls = ParticleProposal particle_proposal_kwargs = {} - def __init__(self, N, Y, postfn, dist, threads=1, pool=None): + def __init__(self, N, Y, postfn, dist, threads=1, pool=None, postfn_kwargs={}, dist_kwargs={}): self.N = N self.Y = Y self.postfn = postfn + self.postfn_kwargs = postfn_kwargs # keyword arguments for postfn self.dist = dist + self.dist_kwargs = dist_kwargs # keyword arguments for distance metric self._random = np.random.mtrand.RandomState() if pool is not None: @@ -264,7 +266,9 @@ class _RejectionSamplingWrapper(object): # @DontTrace def __init__(self, sampler, eps, prior): self.postfn = sampler.postfn + self.postfn_kwargs = sampler.postfn_kwargs self.distfn = sampler.dist + self.distfn_kwargs = sampler.dist_kwargs self._random = sampler._random self.Y = sampler.Y self.eps = np.asarray(eps) @@ -280,8 +284,8 @@ def __call__(self, i): cnt = 1 while True: thetai = self.prior() - X = self.postfn(thetai) - p = np.asarray(self.distfn(X, self.Y)) + X = self.postfn(thetai, **self.postfn_kwargs) + p = np.asarray(self.distfn(X, self.Y, **self.distfn_kwargs)) if np.all(p <= self.eps): break cnt+=1