This is an RcppGSL implementation of the multi-multinomial sampler. However, it requires you to install GSL separately, which may not be very practical.
// [[Rcpp::depends(RcppGSL)]]
#include <RcppGSL.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include <unistd.h> // getpid
#include <memory> // std::unique_ptr
// Draw a single multinomial sample with GSL.
//   N - number of trials
//   p - category weights, one entry per category (handed to gsl_ran_multinomial as-is)
//   r - GSL random number generator to draw from
// Returns an integer vector holding one count per entry of p.
Rcpp::IntegerVector rmn(unsigned int N, Rcpp::NumericVector p, gsl_rng* r){
    const size_t n_categories = p.size();
    Rcpp::IntegerVector counts(n_categories);
    // GSL writes unsigned counts; the IntegerVector's int storage is used directly
    // as the output buffer.
    unsigned int* out = reinterpret_cast<unsigned int*>(counts.begin());
    gsl_ran_multinomial(r, n_categories, N, p.begin(), out);
    return counts;
}
// Sum of independent multinomial draws, one per entry of N.
// For each i, N[i] trials are distributed with rmn() using column i of P as the
// category weights, and the per-category counts are accumulated.
//   N - number of trials for each draw
//   P - matrix whose column i holds the category weights for draw i
//   r - GSL random number generator to draw from
// Returns a vector of length P.nrow() with the accumulated counts.
// Stops with an R error if N has more entries than P has columns, or if any entry
// of N is negative (R's integer NA is a negative sentinel, so NA is rejected too).
Rcpp::IntegerVector gsl_mmm_1(Rcpp::IntegerVector N, Rcpp::NumericMatrix P, gsl_rng* r){
    const int n_draws = static_cast<int>(N.size());
    if (n_draws > P.ncol()) {
        Rcpp::stop("N has %d entries but P has only %d columns", n_draws, P.ncol());
    }
    // Every rmn() result has one count per row of P, so the accumulator is sized by
    // P.nrow() rather than by the length of N.
    Rcpp::IntegerVector x(P.nrow());
    for (int i = 0; i < n_draws; ++i) {
        // rmn() takes an unsigned trial count; a negative value would silently wrap
        // around to a huge number of trials.
        if (N[i] < 0) {
            Rcpp::stop("trial counts must be non-negative and not NA (entry %d)", i + 1);
        }
        x += rmn(N[i], P(Rcpp::_, i), r);
    }
    return x;
}
// Multi-multinomial sampler exported to R.
// Column j of X_ supplies the trial counts passed to gsl_mmm_1() together with P;
// the vector it returns becomes column j of the result.
//   X_ - integer matrix of trial counts
//   P  - matrix of category weights; column i is used for row i of X_
// Returns an integer matrix with the same dimensions as X_.
// Stops with an R error if P and X_ have different numbers of rows, or if the GSL
// generator cannot be allocated.
// The GSL generator is seeded from R's RNG stream, so results follow set.seed().
// [[Rcpp::export]]
Rcpp::IntegerMatrix gsl_mmm(Rcpp::IntegerMatrix X_, Rcpp::NumericMatrix P){
    // The result reuses X_'s shape, so each draw must yield exactly X_.nrow() counts.
    if (P.nrow() != X_.nrow()) {
        Rcpp::stop("P must have as many rows as X (%d), but has %d", X_.nrow(), P.nrow());
    }

    // Owning handle: the generator is released on every exit path, including when a
    // draw throws.
    std::unique_ptr<gsl_rng, decltype(&gsl_rng_free)> rng(gsl_rng_alloc(gsl_rng_mt19937),
                                                          &gsl_rng_free);
    if (!rng) {
        Rcpp::stop("could not allocate GSL random number generator");
    }

    // Take the seed from R's RNG (its state is managed by the wrapper generated for
    // Rcpp::export) instead of unseeded rand() * getpid(), which ignored set.seed()
    // and could overflow a 32-bit long.
    const unsigned long seed = static_cast<unsigned long>(R::unif_rand() * 4294967295.0);
    gsl_rng_set(rng.get(), seed);

    Rcpp::IntegerMatrix X(X_.nrow(), X_.ncol());
    for (int j = 0; j < X.ncol(); ++j) {
        X(Rcpp::_, j) = gsl_mmm_1(X_(Rcpp::_, j), P, rng.get());
    }
    return X;
}
I also compared it with a pure R implementation and jbaums's version:
library(Rcpp)
library(microbenchmark)
sourceCpp("gsl.cpp")
# Example inputs: column i of P holds the probabilities paired with row i of X.
P = matrix(c(0.1, 0.2, 0.7,
             0.3, 0.3, 0.4,
             0.5, 0.3, 0.2), ncol = 3)
X = matrix(c(30, 40, 30,
             20, 40, 40), ncol = 2)
# Pure R multi-multinomial sampler.
# For every column i of X and every row j, X[j, i] trials are distributed over the
# categories with probabilities P[, j]; the draws are summed over j.
#   X: matrix of trial counts
#   P: matrix whose column j holds the probabilities used for row j of X
# Returns a matrix with the same dimensions as X.
mmm = function(X, P){
  n = ncol(X)
  p = nrow(X)
  total = matrix(0, p, n)
  # seq_len() (unlike 1:p or 1:n) yields no iterations for an empty dimension, so a
  # matrix with zero rows or columns returns an all-zero result of the same shape
  # instead of failing on an out-of-range subscript.
  for (j in seq_len(p)) {
    for (i in seq_len(n)) {
      total[, i] = total[, i] + rmultinom(1, X[j, i], P[, j])[, 1]
    }
  }
  total
}
# jbaums's multi-multinomial sampler.
# For every column of X, x[i] trials are drawn with probabilities P[, i] for each
# column i of P, and the draws are summed per category.
#   X: matrix of trial counts
#   P: matrix whose column i holds the probabilities used for row i of X
# Returns a matrix with nrow(P) rows and ncol(X) columns.
jbaums = function(X,P){
  # One list per column of X, holding one rmultinom() draw per column of P.
  draws = apply(X, 2, function(x) {
    lapply(seq_len(ncol(P)), function(i) {
      rmultinom(1, x[i], P[, i])
    })
  })
  # Bind each column's draws side by side and stack the resulting matrices into a
  # 3-d array with one slice per column of X.
  stacked = sapply(draws, function(x) do.call(cbind, x), simplify='array')
  # The slices lie along the third dimension, so the margin is 3. Using nrow(X) as
  # the margin only worked when X happened to have exactly three rows.
  apply(stacked, 3, rowSums)
}
# Time the three implementations on the same X and P.
microbenchmark(jbaums(X,P), mmm(X,P), gsl_mmm(X, P))
And this is the result:
> microbenchmark(jbaums(X,P), mmm(X,P), gsl_mmm(X, P))
Unit: microseconds
expr min lq median uq max neval
jbaums(X, P) 165.832 172.8420 179.185 187.2810 339.280 100
mmm(X, P) 60.071 63.5955 67.437 71.5775 92.963 100
gsl_mmm(X, P) 10.529 11.8800 13.671 14.6220 40.857 100
The GSL version is roughly 5x faster than my pure R version (comparing medians).