How does the function multinom from R package nnet compute the multinomial probability weights?

StackOverflow https://stackoverflow.com/questions/22905807

  •  28-06-2023

Question

I know the theoretical answer to the question in my title, which is discussed here or in this previous question on Stack Overflow. My problem is that, even allowing for numerical rounding, the probability weights I compute from the coefficients fitted by the R function multinom are quite different from the weights returned by the function itself (through predict(fit, newdata = dat, "probs")). I computed these weights numerically in both Java and R, and the two implementations agree with each other but differ from the values returned by predict.

Do you know how I can find the implementation of predict(..., "probs") for the R function multinom?


Solution

I first install nnet and open the help page for the nnet function. I see that the function creates an nnet object.

I try predict.nnet but nothing comes up. This means either the package is not loaded, the function doesn't exist, or it's hidden. methods("predict") reveals that the function is actually hidden, that is, not exported from its namespace (indicated by the *).

> methods("predict")
 [1] predict.ar*                predict.Arima*             predict.arima0*            predict.glm               
 [5] predict.HoltWinters*       predict.lm                 predict.loess*             predict.mlm               
 [9] predict.multinom*          predict.nls*               predict.nnet*              predict.poly              
[13] predict.ppr*               predict.prcomp*            predict.princomp*          predict.smooth.spline*    
[17] predict.smooth.spline.fit* predict.StructTS*    
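
The listing also shows predict.multinom*, which is the method that predict dispatches to for a multinom fit. As a minimal sketch (assuming nnet is installed), a hidden method like this can be retrieved with getS3method or getAnywhere from the utils package, without typing the namespace by hand:

```r
library(nnet)

# Look up the S3 method by generic and class; this works for non-exported
# methods as long as the namespace that registers them is loaded.
pm <- getS3method("predict", "multinom")

# getAnywhere() searches the loaded namespaces and records where it found the object.
ga <- getAnywhere("predict.multinom")
print(ga$where)

# The triple-colon operator reaches the same non-exported function directly.
same_function <- identical(pm, nnet:::predict.multinom)
print(same_function)

# Argument names of the method, useful for checking how "probs" is passed.
print(names(formals(pm)))
```

Printing pm (or ga) at the console shows the source, in the same way as the nnet::: call.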

Calling predict.nnet explicitly through the package namespace, with the ::: operator, reveals its code.

> nnet:::predict.nnet
function (object, newdata, type = c("raw", "class"), ...) 
{
    if (!inherits(object, "nnet")) 
        stop("object not of class \"nnet\"")
    type <- match.arg(type)
    if (missing(newdata)) 
        z <- fitted(object)
    else {
        if (inherits(object, "nnet.formula")) {
            newdata <- as.data.frame(newdata)
            rn <- row.names(newdata)
            Terms <- delete.response(object$terms)
            m <- model.frame(Terms, newdata, na.action = na.omit, 
                xlev = object$xlevels)
            if (!is.null(cl <- attr(Terms, "dataClasses"))) 
                .checkMFClasses(cl, m)
            keep <- match(row.names(m), rn)
            x <- model.matrix(Terms, m, contrasts = object$contrasts)
            xint <- match("(Intercept)", colnames(x), nomatch = 0L)
            if (xint > 0L) 
                x <- x[, -xint, drop = FALSE]
        }
        else {
            if (is.null(dim(newdata))) 
                dim(newdata) <- c(1L, length(newdata))
            x <- as.matrix(newdata)
            if (any(is.na(x))) 
                stop("missing values in 'x'")
            keep <- 1L:nrow(x)
            rn <- rownames(x)
        }
        ntr <- nrow(x)
        nout <- object$n[3L]
        .C(VR_set_net, as.integer(object$n), as.integer(object$nconn), 
            as.integer(object$conn), rep(0, length(object$wts)), 
            as.integer(object$nsunits), as.integer(0L), as.integer(object$softmax), 
            as.integer(object$censored))
        z <- matrix(NA, nrow(newdata), nout, dimnames = list(rn, 
            dimnames(object$fitted.values)[[2L]]))
        z[keep, ] <- matrix(.C(VR_nntest, as.integer(ntr), as.double(x), 
            tclass = double(ntr * nout), as.double(object$wts))$tclass, 
            ntr, nout)
        .C(VR_unset_net)
    }
    switch(type, raw = z, class = {
        if (is.null(object$lev)) stop("inappropriate fit for class")
        if (ncol(z) > 1L) object$lev[max.col(z)] else object$lev[1L + 
            (z > 0.5)]
    })
}
<bytecode: 0x0000000009305fd8>
<environment: namespace:nnet>    
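
To compare hand-computed weights with what predict returns, something along these lines may help. It is only a sketch under stated assumptions: the iris model is an illustrative stand-in for the data in the question, the first factor level is taken as the reference class with all its coefficients fixed at zero, and the response has at least three classes so that coef(fit) is a matrix with one row per non-reference class.

```r
library(nnet)

# Illustrative fit (not the data from the question): three classes, one predictor.
fit <- multinom(Species ~ Sepal.Width, data = iris, trace = FALSE)

# One row of coefficients per non-reference class; the reference class
# (first factor level) is implicitly all zeros.
B <- coef(fit)

# Design matrix with the intercept column, in the same column order as coef(fit).
X <- model.matrix(~ Sepal.Width, data = iris)

# Linear predictors, with a zero column prepended for the reference class.
eta <- cbind(0, X %*% t(B))

# Softmax across classes, row by row.
manual <- exp(eta) / rowSums(exp(eta))
colnames(manual) <- fit$lev

# Probabilities as returned by the package.
auto <- predict(fit, newdata = iris, type = "probs")

# Largest absolute difference between the two computations.
max_abs_diff <- max(abs(manual - auto))
print(max_abs_diff)
```

If the hand-computed values differ noticeably, checking which class is used as the reference and whether the design matrix columns line up with the columns of coef(fit) is a reasonable first step.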
Licensed under: CC-BY-SA with attribution