I first install the `nnet` package and open the help page for the `nnet` function. I see that the function creates an object of class `nnet`.
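I try typing `predict.nnet` at the console, but nothing comes up. This means one of three things: the package is not loaded, the function doesn't exist, or it is hidden (not exported from the package namespace). Running `methods("predict")` reveals that the method is actually hidden, as indicated by the `*` after its name: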
> methods("predict")
[1] predict.ar* predict.Arima* predict.arima0* predict.glm
[5] predict.HoltWinters* predict.lm predict.loess* predict.mlm
[9] predict.multinom* predict.nls* predict.nnet* predict.poly
[13] predict.ppr* predict.prcomp* predict.princomp* predict.smooth.spline*
[17] predict.smooth.spline.fit* predict.StructTS*
Referencing the function explicitly through the package namespace with the triple-colon operator `:::` (without calling it) prints its code:
> nnet:::predict.nnet
function (object, newdata, type = c("raw", "class"), ...)
{
if (!inherits(object, "nnet"))
stop("object not of class \"nnet\"")
type <- match.arg(type)
if (missing(newdata))
z <- fitted(object)
else {
if (inherits(object, "nnet.formula")) {
newdata <- as.data.frame(newdata)
rn <- row.names(newdata)
Terms <- delete.response(object$terms)
m <- model.frame(Terms, newdata, na.action = na.omit,
xlev = object$xlevels)
if (!is.null(cl <- attr(Terms, "dataClasses")))
.checkMFClasses(cl, m)
keep <- match(row.names(m), rn)
x <- model.matrix(Terms, m, contrasts = object$contrasts)
xint <- match("(Intercept)", colnames(x), nomatch = 0L)
if (xint > 0L)
x <- x[, -xint, drop = FALSE]
}
else {
if (is.null(dim(newdata)))
dim(newdata) <- c(1L, length(newdata))
x <- as.matrix(newdata)
if (any(is.na(x)))
stop("missing values in 'x'")
keep <- 1L:nrow(x)
rn <- rownames(x)
}
ntr <- nrow(x)
nout <- object$n[3L]
.C(VR_set_net, as.integer(object$n), as.integer(object$nconn),
as.integer(object$conn), rep(0, length(object$wts)),
as.integer(object$nsunits), as.integer(0L), as.integer(object$softmax),
as.integer(object$censored))
z <- matrix(NA, nrow(newdata), nout, dimnames = list(rn,
dimnames(object$fitted.values)[[2L]]))
z[keep, ] <- matrix(.C(VR_nntest, as.integer(ntr), as.double(x),
tclass = double(ntr * nout), as.double(object$wts))$tclass,
ntr, nout)
.C(VR_unset_net)
}
switch(type, raw = z, class = {
if (is.null(object$lev)) stop("inappropriate fit for class")
if (ncol(z) > 1L) object$lev[max.col(z)] else object$lev[1L +
(z > 0.5)]
})
}
<bytecode: 0x0000000009305fd8>
<environment: namespace:nnet>