Machine Learning - deeplearning4j
Overview
Create a neural network with the deeplearning4j library.
Machine Learning
Given a labelled data set, a machine learning algorithm determines a mathematical function based on the relationship between the labels and the data; this function can later be used to predict the label for new input data. Neural networks are computational models that consist of interconnected layers of nodes. A trained neural network can derive information from new data, even if it has never seen those particular data items before.
The Iris flower classification problem states that, given the measurements of a flower, we should be able to predict which type of iris it is.
Flower classification
| Class Number (Type) | Class |
|---|---|
| 0 | Iris Setosa |
| 1 | Iris Versicolour |
| 2 | Iris Virginica |
Input Data
| Sepal Length (cm) | Sepal Width (cm) | Petal Length (cm) | Petal Width (cm) | Class (Type) |
|---|---|---|---|---|
| 5.1 | 3.5 | 1.4 | 0.2 | 0 (Iris Setosa) |
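Each row of iris.txt is a plain comma-separated line: the four measurements followed by the class index, with no header row (numLinesToSkip is 0 in the loader below). A few illustrative rows in that format:

```
5.1,3.5,1.4,0.2,0
7.0,3.2,4.7,1.4,1
6.3,3.3,6.0,2.5,2
```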
Training Step - First, we train our model using an existing data set of flower measurements.
- Load data
- Split the data set into training and test data
- Normalize data (statistics are computed from the training split only; see the excerpt after this list)
- Configure model - create the neural network
- Train model
- Evaluate model
- Export model
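The order of the split and normalization steps matters: the normalizer's mean/standard-deviation statistics must come from the training split only, and the same statistics are then reused for the test split and, later, for prediction inputs. A condensed excerpt of that pattern, taken from the full listing in the Code section:

```java
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); // 65% training, 35% test
DataNormalization normalizer = new NormalizerStandardize();       // mean 0, unit variance

normalizer.fit(testAndTrain.getTrain());        // statistics from the training data only
normalizer.transform(testAndTrain.getTrain());  // normalize the training data
normalizer.transform(testAndTrain.getTest());   // normalize the test data with the same statistics
```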
Prediction Step - Using the model created above, we predict which flower type a given input belongs to (a condensed excerpt of this flow follows the list).
- Load model
- Format data
- Normalize data
- Feed data
- Get label
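These steps map directly onto predictForInput in the listing below: the saved model and normalizer are restored from disk, the input is normalized with the training statistics, and the index of the highest output probability selects the flower type. Condensed from that method:

```java
MultiLayerNetwork model = MultiLayerNetwork.load(new File("model.file"), false);
DataNormalization normalizer = NormalizerSerializer.getDefault().restore(new File("normalize.file"));

normalizer.transform(features);                        // features: a 1x4 INDArray of raw measurements
INDArray probabilities = model.output(features, false);
// the column with the highest probability is the predicted class index
```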
Code
package com.demo.neural;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cpu.nativecpu.NDArray;
import org.nd4j.linalg.cpu.nativecpu.buffer.FloatBuffer;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.springframework.boot.CommandLineRunner;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

@Slf4j
@SpringBootApplication
public class Main {

    //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    final int LABEL_INDEX = 4;
    //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
    final int NUM_CLASS = 3;
    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)
    final int BATCH_SIZE = 150;

    final List<String> flowerType = Arrays.asList("Iris Setosa", "Iris Versicolour", "Iris Virginica");

    public static void main(String[] args) {
        SpringApplication.run(Main.class, args);
    }

    @Bean
    public CommandLineRunner sendData() {
        return args -> {
            generateModel();
            log.info("Flower type is {}", predictForInput(new float[]{5.1f, 3.5f, 1.4f, 0.2f}));
            log.info("Flower type is {}", predictForInput(new float[]{6.5f, 3.0f, 5.5f, 1.8f}));
        };
    }

    private DataSet loadDataSet() throws IOException, InterruptedException {
        int numLinesToSkip = 0;
        char delimiter = ',';
        RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
        recordReader.initialize(new FileSplit(new File("src/main/resources/iris.txt")));

        //RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in the neural network
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, BATCH_SIZE, LABEL_INDEX, NUM_CLASS);
        DataSet allData = iterator.next();
        allData.shuffle();
        return allData;
    }

    private void generateModel() throws IOException, InterruptedException {
        DataSet allData = loadDataSet();
        //Use 65% of the data for training
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);

        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
        DataNormalization normalizer = new NormalizerStandardize();
        //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
        normalizer.fit(trainingData);
        //Apply normalization to the training data
        normalizer.transform(trainingData);
        //Apply normalization to the test data. This uses statistics calculated from the *training* set
        normalizer.transform(testData);

        final int numInputs = 4;
        int outputNum = 3;
        long seed = 6;

        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(0.1))
                .l2(1e-4)
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
                        .build())
                .layer(new DenseLayer.Builder().nIn(3).nOut(3)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
                        .nIn(3).nOut(outputNum).build())
                .build();

        //Run the model
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        //Record the score once every 100 iterations
        model.setListeners(new ScoreIterationListener(100));

        for (int i = 0; i < 1000; i++) {
            model.fit(trainingData);
        }

        //Evaluate the model on the test set
        Evaluation eval = new Evaluation(3);
        INDArray output = model.output(testData.getFeatures());
        eval.eval(testData.getLabels(), output);
        log.info(eval.stats());
        saveModelAndNormalizer(model, normalizer);
    }

    private void saveModelAndNormalizer(MultiLayerNetwork model, DataNormalization normalizer) throws IOException {
        log.info("Saving model & normalizer!");
        File modelFile = new File("model.file");
        model.save(modelFile, false);

        File normalizerFile = new File("normalize.file");
        NormalizerSerializer.getDefault().write(normalizer, normalizerFile);
    }

    private String predictForInput(float[] input) throws Exception {
        log.info("Loading model & normalizer!");
        File modelFile = new File("model.file");
        MultiLayerNetwork model = MultiLayerNetwork.load(modelFile, false);
        File normalizerFile = new File("normalize.file");
        DataNormalization normalizer = NormalizerSerializer.getDefault().restore(normalizerFile);

        //Wrap the raw measurements in a 1x4 row vector
        DataBuffer dataBuffer = new FloatBuffer(input);
        NDArray ndArray = new NDArray(1, 4);
        ndArray.setData(dataBuffer);

        //Apply the same normalization statistics that were used during training
        normalizer.transform(ndArray);
        INDArray result = model.output(ndArray, false);

        return flowerType.get(getBestPredictionIndex(result));
    }

    //Returns the index of the class with the highest predicted probability
    private int getBestPredictionIndex(INDArray predictions) {
        int maxIndex = 0;
        for (int i = 0; i < NUM_CLASS; i++) {
            if (predictions.getFloat(i) > predictions.getFloat(maxIndex)) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }

}
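One remark on predictForInput: it builds the 1x4 input with the CPU-backend classes NDArray and FloatBuffer directly. A backend-neutral sketch of an alternative (not what the listing above uses; the class and method names here are made up for illustration) would go through the Nd4j factory and read the prediction with argMax:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class PredictionHelpers {

    // Build a 1x4 row vector from the raw measurements via the backend-neutral Nd4j factory
    static INDArray toFeatureRow(float[] measurements) {
        return Nd4j.create(new float[][]{measurements});
    }

    // Index of the most probable class, taken along the column dimension of the softmax output
    static int bestClassIndex(INDArray probabilities) {
        return probabilities.argMax(1).getInt(0);
    }
}
```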
Setup
# Project102

deeplearning4j - Supervised classification (Neural Networks)
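For completeness, a minimal sketch of the Maven dependencies such a project typically needs (in addition to the Spring Boot starter and Lombok used in the listing). The artifact coordinates are the standard deeplearning4j ones, but the version shown is a placeholder and should match whatever the project actually uses:

```xml
<dependencies>
    <!-- deeplearning4j core (brings in the DataVec record readers used above) -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-M2.1</version> <!-- placeholder: use your project's version -->
    </dependency>
    <!-- CPU (native) backend for ND4J -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-M2.1</version> <!-- placeholder: use your project's version -->
    </dependency>
</dependencies>
```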