This page contains an example of the Gauss hypergeometric (GH) prior model for count data, implemented in Stan. For details on the method, please see the following paper:
Inference on High-Dimensional Sparse Count Data: Datta, J and Dunson, D. (2016+)
First we load the necessary libraries:
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
library(ggplot2)
theme_set(theme_bw())
library(plyr)
library(dplyr)
library(reshape2)
The GH model can be implemented in Stan in a straightforward way using the following:
# Set up the Stan sampler for the Gauss hypergeometric (GH) prior model.
#
# The model string below is compiled into `stan.gh.fit`. It declares:
#   * Y[i] ~ Poisson(theta[i]) for i = 1, ..., J;
#   * theta[i] ~ Gamma(shape = a, rate = kappa[i] / (1 - kappa[i]));
#   * an unnormalised log prior density on kappa[i] in (0, 1) of
#       (a - 1) log(kappa) + (b - 1) log(1 - kappa) - gamma log(1 - phi kappa).
{
library(plyr)
library(rstan)
library(parallel)
library(rbenchmark)
#set_cppo("fast")
# NOTE(review): `alpha` is declared in the data block but is never used in the
# model block; the gamma prior on theta uses `a` as its shape. Confirm whether
# the shape was meant to be `alpha` before relying on the value passed in.
stan.gh.code <- "
data{
int<lower=0> J;
int<lower=0> Y[J];
real<lower=0> alpha;
real<lower=0> a;
real<lower=0> b;
real<lower=0> gamma;
real<lower=0> phi;
}
parameters{
real<lower=0,upper=1> kappa[J];
real<lower=0> theta[J];
}
model{
for(i in 1:J) {
// log kernel of the GH prior on kappa[i]; target += replaces the
// deprecated (and since removed) increment_log_prob() statement
target += (a-1)*log(kappa[i])+(b-1)*log(1-kappa[i])-gamma*log(1-phi*kappa[i]);
theta[i] ~ gamma(a, kappa[i]/(1-kappa[i]));
Y[i] ~ poisson(theta[i]);
}
}
"
stan.gh.fit <- stan_model(model_code = stan.gh.code, model_name = "GH")
}
Note that we have fixed the \(\gamma\) parameter here, but this can easily be changed by moving “gamma” from the data block to the parameters block and including a “gamma ~ p(.)” statement in the model, with an appropriate distribution p(.) on the positive real line.
Now, generate \(n = 200\) observations from a ‘two-groups’ model where the underlying parameter vector \(\theta\) is sparse, with a small proportion \(\omega\) of the \(\theta_i\) drawn from a heavy-tailed distribution (or fixed at a relatively large positive constant) and the remaining parameters zero or very close to zero. For ease of illustration we pick the noise and signal observations to have mean 0.1 and 10 respectively, but in general a smaller separation should work. (In the code below, `w = 0.9` is the proportion of noise observations, i.e. \(\omega = 1 - w = 0.1\).)
# Simulation settings and data for the sparse two-groups Poisson example.
stan.iters <- 10000
n.chains <- 2
seed.val <- 786
set.seed(seed.val)

n <- 200
w <- 0.9  # proportion of noise observations; the remaining 1 - w are signals

# idx flags signal observations (1) against noise observations (0): the first
# round(n * w) observations are noise, the rest are signal.
idx <- as.numeric(seq_len(n) > round(n * w))

# Poisson means: 0.1 for the noise observations, 10 for the signals.
lambdasparse <- ifelse(idx == 1, 10, 0.1)
y <- rpois(n, lambdasparse)

# Plug-in value for gamma: the average of the two k-means cluster centres of y.
gamma <- mean(kmeans(y, centers = 2)$centers)
alpha <- 0.01
a <- 0.5
b <- 0.5

# Data list passed to the Stan model; names match its data block.
gh.data <- list(
  "J" = n, "Y" = y, "alpha" = alpha, "a" = a, "b" = b,
  "gamma" = gamma, "phi" = 0.99
)
The choice of \(\gamma\) can be done in the way shown above, or it can be used as a hyper-parameter in the full Bayes hierarchy. Some empirical support for this choice of \(\gamma\) is given in the simulation experiments page.
Stan is run with 2 chains of 10^{4} iterations each.
# Draw posterior samples from the compiled GH model and summarise them.
#
# Produces:
#   gh.res                         the stanfit object
#   gh.theta.smpls, gh.kappa.smpls posterior draws (one column per observation)
#   gh.theta.mean, gh.kappa.mean   posterior means per observation
#   gh.sample.data                 long-format draws of theta and kappa
#   gh.sample.data.2               per-observation mean and 2.5% / 97.5%
#                                  quantiles, i.e. a 95% credible interval
{
gh.res <- sampling(stan.gh.fit,
                   data = gh.data,
                   iter = stan.iters,
                   warmup = floor(stan.iters / 2),
                   thin = 2,
                   pars = c("kappa", "theta"),
                   init = 0,
                   seed = seed.val,
                   # use the configured number of chains, not a hard-coded 1
                   chains = n.chains)

# Namespace-qualified so another attached package's extract() cannot mask it.
gh.theta.smpls <- rstan::extract(gh.res, pars = c("theta"), permuted = TRUE)[[1]]
gh.kappa.smpls <- rstan::extract(gh.res, pars = c("kappa"), permuted = TRUE)[[1]]
gh.theta.mean <- colMeans(gh.theta.smpls)
gh.kappa.mean <- colMeans(gh.kappa.smpls)

# Long format: one row per (draw, observation, parameter).
gh.sample.data <- melt(rstan::extract(gh.res, permuted = TRUE))
colnames(gh.sample.data) <- c("iteration", "component", "value", "variable")
gh.sample.data <- gh.sample.data %>%
  filter(variable %in% c("theta", "kappa"))

# The lower limit is the 2.5% quantile (the original 0.225 was a typo), so
# [lower, upper] is the central 95% credible interval.
gh.sample.data.2 <- gh.sample.data %>%
  group_by(component, variable) %>%
  summarise(upper = quantile(value, probs = 0.975),
            lower = quantile(value, probs = 0.025),
            middle = mean(value)) %>%
  ungroup()
}
We first plot the point estimates, i.e. the posterior mean of \(\theta\) and \(\kappa\) under the Gauss hypergeometric prior and the observations.
# Stack the observations, the posterior means of theta and the posterior means
# of kappa into one long data frame, indexed by observation number x.
post.data <- rbind(
  data.frame(type = "observation", values = y, x = seq_len(n)),
  data.frame(type = "posterior mean", values = gh.theta.mean, x = seq_len(n)),
  data.frame(type = "shrinkage", values = gh.kappa.mean, x = seq_len(n))
)

# One facet per series, each with its own y scale; the point colour is
# inherited from the plot-level aesthetics.
ggplot(post.data, aes(x = x, y = values, group = type, colour = type)) +
  geom_point(size = 1) +
  ylab("Mean/Shrinkage") +
  xlab(expression(y)) +
  facet_grid(type ~ ., scales = "free_y")
Note that the small non-zero observations are shrunk to zero as \(\kappa \approx 1\) for these observations.
Next we show the 95% credible interval for the shrinkage weights \(\kappa\) and \(\theta\) for each observation:
# Plot the interval from `lower` to `upper` around `middle` for every
# observation, with one facet per parameter (kappa, theta).
#
# Each interval is coloured by the signal indicator of its own observation.
# The summary rows are ordered by component and then by variable, so the
# indicator has to be looked up through `component`; rep(idx, 2) would pair
# rows with the wrong observations.
ggplot(gh.sample.data.2,
       aes(x = component, y = middle, group = component,
           colour = idx[as.integer(component)])) +
  theme_bw() +
  geom_pointrange(aes(ymin = lower, ymax = upper), size = 0.5) +
  facet_grid(variable ~ ., scales = "free_y") +
  theme(legend.position = "none") +
  ylab("") +
  xlab("")