Sampling from a multivariate von Mises-Fisher distribution in Python
16-10-2019
Question
I am looking for a simple way to sample from a multivariate von Mises-Fisher distribution in Python. I have looked in scipy.stats and in numpy, but only found the univariate von Mises distribution. Is there any code available? I have not found any yet.
-- Edit: Apparently, Wood (1994) designed an algorithm for sampling from the vMF distribution, according to this link, but I can't find the paper.
Solution 2
Thanks for your help. I finally got my code working, and found some bibliography along the way.
I got my hands on Directional Statistics (Mardia and Jupp, 1999) and on the Ulrich-Wood algorithm for sampling. I am posting here what I understood from them, i.e. my Python code, with a 'movMF' flavour.
The rejection sampling scheme, which draws $w$, the cosine between a sample and the mean direction $\mu$:
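```python
import numpy as np
import scipy as sc
import scipy.stats

def rW(n, kappa, m):
    # Wood's (1994) rejection sampler for w = <x, mu>; m is the dimension
    # of the ambient space (samples lie on the unit sphere in R^m)
    dim = m - 1
    b = dim / (np.sqrt(4 * kappa * kappa + dim * dim) + 2 * kappa)
    x = (1 - b) / (1 + b)
    c = kappa * x + dim * np.log(1 - x * x)

    y = []
    for i in range(0, n):
        done = False
        while not done:
            # propose w from a transformed Beta draw
            z = sc.stats.beta.rvs(dim / 2, dim / 2)
            w = (1 - (1 + b) * z) / (1 - (1 - b) * z)
            u = sc.stats.uniform.rvs()
            # accept if log(u) <= kappa*w + dim*log(1 - x*w) - c
            if kappa * w + dim * np.log(1 - x * w) - c >= np.log(u):
                done = True
        y.append(w)
    return y
```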
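For reference, here is the rejection step in `rW` written out as formulas, with $p$ = `m` the ambient dimension (so `dim` $= p-1$). This is a summary of the code above rather than a transcription of Wood's paper; the accepted $W$ has density proportional to $e^{\kappa w}(1-w^2)^{(p-3)/2}$ on $[-1, 1]$, which is the marginal distribution of $\mu^\top X$ under the vMF.

```latex
\begin{aligned}
b &= \frac{p-1}{\sqrt{4\kappa^2 + (p-1)^2} + 2\kappa}, \qquad
x_0 = \frac{1-b}{1+b}, \qquad
c = \kappa x_0 + (p-1)\log\left(1 - x_0^2\right),\\
Z &\sim \mathrm{Beta}\left(\tfrac{p-1}{2}, \tfrac{p-1}{2}\right), \qquad
W = \frac{1 - (1+b)Z}{1 - (1-b)Z}, \qquad
U \sim \mathrm{Uniform}(0, 1),\\
&\text{accept } W \text{ if } \kappa W + (p-1)\log(1 - x_0 W) - c \ge \log U .
\end{aligned}
```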
Then, each sample is $x = \sqrt{1-w^2}\, v + w \mu$, where $w$ is a draw from the rejection sampling scheme and $v$ is a unit vector drawn uniformly from the sphere orthogonal to $\mu$ (drawing $v$ from the whole sphere would not even give a unit vector).
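```python
def rvMF(n, theta):
    # theta = kappa * mu: its norm is the concentration, its direction the mean
    dim = len(theta)
    kappa = np.linalg.norm(theta)
    mu = theta / kappa
    # one cosine w = <x, mu> per sample
    ws = rW(n, kappa, dim)
    result = []
    for w in ws:
        # uniform unit vector on the sphere orthogonal to mu
        v = np.random.randn(dim)
        v = v - v.dot(mu) * mu
        v = v / np.linalg.norm(v)
        result.append(np.sqrt(1 - w**2) * v + w * mu)
    return result
```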
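The step that is easy to get wrong is drawing $v$: it has to be uniform on the unit sphere orthogonal to $\mu$, otherwise $\sqrt{1-w^2}\, v + w\mu$ is not a unit vector and its cosine to $\mu$ is not $w$. Below is a minimal sketch, assuming only NumPy; the helper name `tangent_direction` and the fixed value `w = 0.9` are hypothetical stand-ins for illustration. It projects a Gaussian vector onto the orthogonal complement of $\mu$ and normalises it, then compares against a direction drawn from the whole sphere.

```python
import numpy as np

def tangent_direction(mu, rng):
    """Hypothetical helper: a unit vector drawn uniformly from the sphere orthogonal to mu."""
    g = rng.standard_normal(mu.shape[0])
    # Remove the component along mu; the remainder is an isotropic Gaussian
    # in the (p-1)-dimensional subspace orthogonal to mu.
    g = g - g.dot(mu) * mu
    return g / np.linalg.norm(g)

rng = np.random.default_rng(0)
mu = np.array([1.0, -1.0, 1.0]) / np.sqrt(3.0)
w = 0.9  # stand-in for one draw from rW

v = tangent_direction(mu, rng)
x = np.sqrt(1 - w**2) * v + w * mu

# Drawing v from the whole sphere instead
v_bad = rng.standard_normal(3)
v_bad = v_bad / np.linalg.norm(v_bad)
x_bad = np.sqrt(1 - w**2) * v_bad + w * mu

print(np.linalg.norm(x), x.dot(mu))  # norm 1 and cosine w, up to rounding
print(np.linalg.norm(x_bad))         # generally not 1
```

Projecting a Gaussian works because an isotropic Gaussian stays isotropic in any subspace, so normalising it gives a uniform direction there.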
And here is an example of actually sampling with this code:
```python
import numpy as np
import scipy as sc
import scipy.stats

n = 10
kappa = 100000
direction = np.array([1, -1, 1])
direction = direction / np.linalg.norm(direction)
res_sampling = rvMF(n, kappa * direction)
```
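As a sanity check, and a faster alternative to the per-sample loops, here is a hedged sketch of the same Wood (1994) scheme vectorised with NumPy's `Generator` API. `sample_vmf` is a hypothetical name, not part of NumPy or SciPy. For $p = 3$ the mean cosine has the closed form $E[\mu^\top X] = \coth\kappa - 1/\kappa$, which makes an easy check; a moderate $\kappa$ such as 5 is a more informative test than $\kappa = 100000$, where every sample sits almost exactly on $\mu$.

```python
import numpy as np

def sample_vmf(mu, kappa, n, rng):
    """Sketch: n draws from vMF(mu, kappa) on the unit sphere in R^p (Wood, 1994)."""
    mu = np.asarray(mu, dtype=float)
    mu = mu / np.linalg.norm(mu)
    p = mu.shape[0]
    dim = p - 1
    b = dim / (np.sqrt(4 * kappa**2 + dim**2) + 2 * kappa)
    x0 = (1 - b) / (1 + b)
    c = kappa * x0 + dim * np.log(1 - x0**2)

    # Rejection step for w = <x, mu>, run in batches until n draws are accepted
    w = np.empty(0)
    while w.size < n:
        m = n - w.size
        z = rng.beta(dim / 2, dim / 2, size=m)
        cand = (1 - (1 + b) * z) / (1 - (1 - b) * z)
        u = rng.uniform(size=m)
        ok = kappa * cand + dim * np.log(1 - x0 * cand) - c >= np.log(u)
        w = np.concatenate([w, cand[ok]])

    # Uniform unit directions orthogonal to mu
    v = rng.standard_normal((n, p))
    v -= np.outer(v @ mu, mu)
    v /= np.linalg.norm(v, axis=1, keepdims=True)

    return np.sqrt(np.clip(1 - w**2, 0, None))[:, None] * v + w[:, None] * mu

rng = np.random.default_rng(42)
kappa = 5.0
mu = np.array([1.0, -1.0, 1.0])
x = sample_vmf(mu, kappa, 100_000, rng)
mu_hat = mu / np.linalg.norm(mu)

# For p = 3, E[<X, mu>] = coth(kappa) - 1/kappa
print(np.mean(x @ mu_hat), 1 / np.tanh(kappa) - 1 / kappa)
```

The batched loop only redraws the rejected proposals, so for moderate or large $\kappa$ it usually finishes in one or two passes.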
Other tips
It looks like you can sample from the von Mises-Fisher distribution with the movMF R package. Have you thought about calling R from within Python using the rpy2 package? I haven't tried this myself, but could you try the following?
```python
import rpy2.robjects as ro
from rpy2.robjects.packages import importr

# import the movMF R package (it must already be installed in R)
movMF = importr('movMF')

# theta = kappa * mu, built in R: kappa = 3, mu = (1, -1) / sqrt(2)
theta = ro.r('3 * c(1, -1) / sqrt(2)')

# call the rmovMF sampling function from the R package: 10 draws
print(movMF.rmovMF(10, theta))
```