Machine Learning - deeplearning4j

Overview

Create a neural network with deeplearning4j library.

Machine Learning

Given a labelled data set, a machine learning algorithm determines a mathematical function based on the relationship between the labels and the data; this function can later be used to predict the label for any new input data. Neural networks are computational models that consist of interconnected layers of nodes. A neural network can derive information from new data, even if it has not seen those particular data items before.

The Iris flower classification problem states that, given the measurements of a flower, we should be able to predict which type of flower it is.

Flower classification

| Class Number (Type) | Class |
| --- | --- |
| 0 | Iris Setosa |
| 1 | Iris Versicolour |
| 2 | Iris Virginica |

Input Data

| Sepal Length | Sepal Width | Petal Length | Petal Width | Class (Type) |
| --- | --- | --- | --- | --- |
| 5.1 | 3.5 | 1.4 | 0.2 | 0 (Iris Setosa) |

Training Step - First we train our model by using an existing data set of flower measurements.

  1. Load data
  2. Normalize data
  3. Split data set into training and test data
  4. Configure model - Creates neural network
  5. Train model
  6. Evaluate Model
  7. Export Model

Prediction Step - Using the model we created above we predict which flower type an input belongs to.

  1. Load Model
  2. Format Data
  3. Normalize Data
  4. Feed Data
  5. Get Label

Code

  1package com.demo.neural;
  2
  3import java.io.File;
  4import java.io.IOException;
  5import java.util.Arrays;
  6import java.util.List;
  7
  8import lombok.extern.slf4j.Slf4j;
  9import org.datavec.api.records.reader.RecordReader;
 10import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
 11import org.datavec.api.split.FileSplit;
 12import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 13import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 14import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 15import org.deeplearning4j.nn.conf.layers.DenseLayer;
 16import org.deeplearning4j.nn.conf.layers.OutputLayer;
 17import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 18import org.deeplearning4j.nn.weights.WeightInit;
 19import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
 20import org.nd4j.evaluation.classification.Evaluation;
 21import org.nd4j.linalg.activations.Activation;
 22import org.nd4j.linalg.api.buffer.DataBuffer;
 23import org.nd4j.linalg.api.ndarray.INDArray;
 24import org.nd4j.linalg.cpu.nativecpu.NDArray;
 25import org.nd4j.linalg.cpu.nativecpu.buffer.FloatBuffer;
 26import org.nd4j.linalg.dataset.SplitTestAndTrain;
 27import org.nd4j.linalg.dataset.api.DataSet;
 28import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 29import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
 30import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
 31import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;
 32import org.nd4j.linalg.learning.config.Sgd;
 33import org.nd4j.linalg.lossfunctions.LossFunctions;
 34import org.springframework.boot.CommandLineRunner;
 35import org.springframework.boot.SpringApplication;
 36import org.springframework.boot.autoconfigure.SpringBootApplication;
 37import org.springframework.context.annotation.Bean;
 38
 39@Slf4j
 40@SpringBootApplication
 41public class Main {
 42
 43    //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
 44    final int LABEL_INDEX = 4;
 45    //3 classes (types of iris flowers) in the iris data set. Classes have integer values
 46    // 0, 1 or 2
 47    final int NUM_CLASS = 3;
 48    //Iris data set: 150 examples total. We are loading all of them into one DataSet
 49    // (not recommended for large data sets)
 50    final int BATCH_SIZE = 150;
 51
 52    final List<String> flowerType = Arrays.asList("Iris Setosa", "Iris Versicolour", "Iris Virginica");
 53
 54    public static void main(String[] args) {
 55        SpringApplication.run(Main.class, args);
 56    }
 57
 58    @Bean
 59    public CommandLineRunner sendData() {
 60        return args -> {
 61            generateModel();
 62            log.info("Flower type is {}", predictForInput(new float[]{5.1f, 3.5f, 1.4f, 0.2f}));
 63            log.info("Flower type is {}", predictForInput(new float[]{6.5f, 3.0f, 5.5f, 1.8f}));
 64        };
 65    }
 66
 67    private DataSet loadDataSet() throws IOException, InterruptedException {
 68        int numLinesToSkip = 0;
 69        char delimiter = ',';
 70        RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
 71        recordReader.initialize(new FileSplit(new File("src/main/resources/iris.txt")));
 72
 73        //RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
 74        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, BATCH_SIZE, LABEL_INDEX, NUM_CLASS);
 75        DataSet allData = iterator.next();
 76        allData.shuffle();
 77        return allData;
 78    }
 79
 80    private void generateModel() throws IOException, InterruptedException {
 81        DataSet allData = loadDataSet();
 82        //Use 65% of data for training
 83        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
 84
 85        DataSet trainingData = testAndTrain.getTrain();
 86        DataSet testData = testAndTrain.getTest();
 87
 88        //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
 89        DataNormalization normalizer = new NormalizerStandardize();
 90        //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
 91        normalizer.fit(trainingData);
 92        //Apply normalization to the training data
 93        normalizer.transform(trainingData);
 94        //Apply normalization to the test data. This is using statistics calculated from the *training* set
 95        normalizer.transform(testData);
 96
 97        final int numInputs = 4;
 98        int outputNum = 3;
 99        long seed = 6;
100
101        log.info("Build model....");
102        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
103                .seed(seed)
104                .activation(Activation.TANH)
105                .weightInit(WeightInit.XAVIER)
106                .updater(new Sgd(0.1))
107                .l2(1e-4)
108                .list()
109                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
110                        .build())
111                .layer(new DenseLayer.Builder().nIn(3).nOut(3)
112                        .build())
113                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
114                        .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
115                        .nIn(3).nOut(outputNum).build())
116                .build();
117
118        //run the model
119        MultiLayerNetwork model = new MultiLayerNetwork(conf);
120        model.init();
121        //record score once every 100 iterations
122        model.setListeners(new ScoreIterationListener(100));
123
124        for (int i = 0; i < 1000; i++) {
125            model.fit(trainingData);
126        }
127
128        //evaluate the model on the test set
129        Evaluation eval = new Evaluation(3);
130        INDArray output = model.output(testData.getFeatures());
131        eval.eval(testData.getLabels(), output);
132        log.info(eval.stats());
133        saveModelAndNormalizer(model, normalizer);
134
135    }
136
137    private void saveModelAndNormalizer(MultiLayerNetwork model, DataNormalization normalizer) throws IOException {
138        log.info("Saving model & normalizer!");
139        File modelFile = new File("model.file");
140        model.save(modelFile, false);
141
142        File normalizerFile = new File("normalize.file");
143        NormalizerSerializer.getDefault().write(normalizer, normalizerFile);
144    }
145
146    private String predictForInput(float[] input) throws Exception {
147        log.info("Loading model & normalizer!");
148        File modelFile = new File("model.file");
149        MultiLayerNetwork model = MultiLayerNetwork.load(modelFile, false);
150        File normalizerFile = new File("normalize.file");
151        DataNormalization normalizer = NormalizerSerializer.getDefault().restore(normalizerFile);
152
153        DataBuffer dataBuffer = new FloatBuffer(input);
154        NDArray ndArray = new NDArray(1, 4);
155        ndArray.setData(dataBuffer);
156
157        normalizer.transform(ndArray);
158        INDArray result = model.output(ndArray, false);
159        getBestPredicationIndex(result);
160
161        return flowerType.get(getBestPredicationIndex(result));
162    }
163
164    private int getBestPredicationIndex(INDArray predictions) {
165        int maxIndex = 0;
166        for (int i = 0; i < 3; i++) {
167            if (predictions.getFloat(i) > predictions.getFloat(maxIndex)) {
168                maxIndex = i;
169            }
170        }
171        return maxIndex;
172    }
173
174}

Setup

1# Project102
2
3deeplearning4j - Supervised classification (Neural Networks)

References

https://deeplearning4j.konduit.ai/

comments powered by Disqus