Cross Validation in Keras
16-10-2019
Question
Suppose I would like to train and test a model on the MNIST dataset in Keras.
The required data can be loaded as follows:
from keras.datasets import mnist
digits_data = mnist.load_data()
Is there any way in Keras to split this data into three sets, namely training_data, test_data, and cross_validation_data?
Solution
According to the Keras documentation, mnist.load_data() already returns the data split into training and test sets:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
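Since load_data() only returns two sets, a third (validation) set has to be carved out of the training data. Below is a minimal sketch, assuming the arrays were loaded as above; zero-filled arrays with MNIST's training shape stand in for X_train and y_train so it runs offline. Holding out the last fraction of samples mirrors how the validation_split argument of Keras's fit() picks its validation data (the last samples, before shuffling). The helper name split_last_fraction is hypothetical.

```python
import numpy as np

def split_last_fraction(X, y, val_fraction=0.1):
    """Hold out the last `val_fraction` of samples as a validation set.

    Hypothetical helper: it mimics the selection rule of Keras's
    `validation_split` (last samples, taken before any shuffling).
    """
    n_val = int(len(X) * val_fraction)
    split_at = len(X) - n_val
    return (X[:split_at], y[:split_at]), (X[split_at:], y[split_at:])

# Stand-in for (X_train, y_train) from mnist.load_data(): 60,000 28x28 images.
X_train = np.zeros((60000, 28, 28), dtype=np.uint8)
y_train = np.arange(60000) % 10

(X_tr, y_tr), (X_val, y_val) = split_last_fraction(X_train, y_train, val_fraction=0.1)
print(X_tr.shape, X_val.shape)  # (54000, 28, 28) (6000, 28, 28)
```

If the training data is ordered in some way (for example sorted by label), shuffle it once before taking the last fraction, since this rule does not shuffle for you.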
For cross-validation, you could follow this example from https://github.com/fchollet/keras/issues/1711:
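from sklearn.model_selection import StratifiedKFold

def load_data():
    # Load your data using this function; return features, labels and any extra info.
    raise NotImplementedError("load your data here")

def create_model():
    # Build and compile your Keras model using this function.
    raise NotImplementedError("create your model here")

def train_and_evaluate_model(model, data_train, labels_train, data_test, labels_test):
    # Fit on the training fold, then evaluate on the held-out fold.
    model.fit(data_train, labels_train)  # add epochs, batch_size, etc. as needed
    return model.evaluate(data_test, labels_test)

if __name__ == "__main__":
    n_folds = 10
    data, labels, header_info = load_data()
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True)
    # skf.split() yields (train_indices, test_indices) for each fold.
    for i, (train, test) in enumerate(skf.split(data, labels)):
        print("Running Fold", i + 1, "/", n_folds)
        model = None  # Clearing the NN.
        model = create_model()
        train_and_evaluate_model(model, data[train], labels[train], data[test], labels[test])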
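To see the fold loop run end to end without a Keras model, the sketch below swaps in scikit-learn's LogisticRegression as a stand-in for create_model() and uses a small synthetic dataset from make_classification; both are assumptions for illustration only. With a real Keras model, its fit() and evaluate() calls would take the place of fit() and score(). Each fold's score is collected so the mean and spread across folds can be reported.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

# Synthetic stand-in data: 500 samples, 20 features, 3 classes.
data, labels = make_classification(n_samples=500, n_features=20, n_informative=5,
                                   n_classes=3, random_state=0)

n_folds = 5
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=0)
scores = []
for i, (train, test) in enumerate(skf.split(data, labels)):
    # A fresh model per fold, so nothing learned on one fold leaks into the next.
    model = LogisticRegression(max_iter=1000)  # stand-in for create_model()
    model.fit(data[train], labels[train])
    scores.append(model.score(data[test], labels[test]))  # accuracy on held-out fold
    print(f"Fold {i + 1}/{n_folds}: accuracy = {scores[-1]:.3f}")

print(f"Mean accuracy: {np.mean(scores):.3f} +/- {np.std(scores):.3f}")
```

Stratification keeps the class proportions of each test fold close to those of the whole dataset, which matters most when some classes are rare.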
OTHER TIPS
Not in Keras itself. I normally just use scikit-learn's train_test_split function:
from sklearn.model_selection import train_test_split
train, test = train_test_split(data, train_size=0.8)
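train_test_split produces two sets per call, so the three sets asked about in the question can be obtained by calling it twice: first to hold out a test set, then to split the remainder into training and validation sets. A minimal sketch; the 60/20/20 proportions, the stratify option, and the zero-filled stand-in arrays are assumptions, and with MNIST you would pass the loaded arrays instead. Passing features and labels in the same call keeps them aligned.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in data: 1,000 flattened 28x28 "images" with labels 0-9.
X = np.zeros((1000, 784), dtype=np.float32)
y = np.arange(1000) % 10

# First split: hold out 20% of all samples as the test set.
X_rest, X_test, y_rest, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)

# Second split: 25% of the remaining 80% (i.e. 20% overall) becomes validation.
X_train, X_val, y_train, y_val = train_test_split(
    X_rest, y_rest, test_size=0.25, stratify=y_rest, random_state=42)

print(len(X_train), len(X_val), len(X_test))  # 600 200 200
```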
Keras also has scikit-learn wrappers that might be useful later on.
Licensed under: CC-BY-SA with attribution
Not affiliated with datascience.stackexchange