How to build a neural network in Java
Artificial neural networks are the foundation of deep learning and one of the pillars of modern-day AI. The best way to really get a grip on how they work is to build one. This article is a hands-on introduction to building and training a neural network in Java.
See my previous article, Styles of machine learning: Intro to neural networks, for an overview of how artificial neural networks operate. Our example in this article is by no means a production-grade system; instead, it shows all the main components in a demo designed to be easy to understand.
A basic neural network
A neural network is a graph of nodes called neurons. The neuron is the basic unit of computation. It receives inputs and processes them with a simple algorithm: each input is multiplied by a per-input weight, a per-node bias is added, and the sum is run through a final processing function known as the activation function. You can see a two-input neuron illustrated in Figure 1.
This model has a wide range of variability, but we’ll use this exact configuration for the demo.
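To make the math concrete, suppose (with illustrative numbers of my own, not taken from the demo) weight1 = 0.5, weight2 = -0.4, and bias = 0.1, with inputs (1.0, 2.0). The pre-activation value is (0.5 * 1.0) + (-0.4 * 2.0) + 0.1 = -0.2, which the activation function (the sigmoid we’ll use shortly) squashes to roughly 0.45.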
Our first step is to model a Neuron class that will hold these values. You can see the Neuron class in Listing 1. Note that this is a first version of the class; it will change as we add functionality.
Listing 1. A simple Neuron class
import java.util.Random;

class Neuron {
  // The bounded nextDouble(origin, bound) overload requires Java 17+
  Random random = new Random();
  private Double bias = random.nextDouble(-1, 1);
  public Double weight1 = random.nextDouble(-1, 1);
  private Double weight2 = random.nextDouble(-1, 1);

  public double compute(double input1, double input2) {
    double preActivation = (this.weight1 * input1) + (this.weight2 * input2) + this.bias;
    double output = Util.sigmoid(preActivation);
    return output;
  }
}
You can see that the Neuron class is quite simple, with three members: bias, weight1, and weight2. Each member is initialized to a random double between -1 and 1.
When we compute the output for the neuron, we follow the algorithm shown in Figure 1: multiply each input by its weight and add the bias: input1 * weight1 + input2 * weight2 + bias. This gives us the unprocessed calculation (i.e., preActivation) that we run through the activation function. In this case, we use the sigmoid activation function, which compresses values into a 0 to 1 range. Listing 2 shows the Util.sigmoid() static method.
Listing 2. Sigmoid activation function
public class Util {
  public static double sigmoid(double in) {
    return 1 / (1 + Math.exp(-in));
  }
}
Now that we’ve seen how neurons work, let’s put some neurons into a network. We’ll use a Network class with a list of neurons, as shown in Listing 3.
Listing 3. The neural network class
import java.util.Arrays;
import java.util.List;

class Network {
  List<Neuron> neurons = Arrays.asList(
    new Neuron(), new Neuron(), new Neuron(), /* input nodes */
    new Neuron(), new Neuron(),               /* hidden nodes */
    new Neuron());                            /* output node */
}
Although the list of neurons is one-dimensional, we’ll connect them during usage so that they form a network. The first three neurons are inputs, the next two are hidden, and the last one is the output node.
Make a prediction
Now, let’s use the network to make a prediction. We’re going to use a simple data set of two input integers and an answer in the range of 0 to 1. My example uses a weight-height combination to guess a person’s gender, based on the simplistic assumption that greater weight and height indicate a person is male. We could use the same formula for any two-factor, single-output probability. We can think of the input as a vector, and therefore the overall function of the neurons as transforming a vector to a scalar value.
The prediction phase of the network looks like Listing 4.
Listing 4. Network prediction
public Double predict(Integer input1, Integer input2) {
  return neurons.get(5).compute(
    neurons.get(4).compute(
      neurons.get(2).compute(input1, input2),
      neurons.get(1).compute(input1, input2)
    ),
    neurons.get(3).compute(
      neurons.get(1).compute(input1, input2),
      neurons.get(0).compute(input1, input2)
    )
  );
}
Listing 4 shows that the two inputs are fed into the first three neurons, whose outputs are then piped into neurons 4 and 5, which in turn feed the output neuron. This process is known as feedforward.
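If the nested calls are hard to read, here is the same wiring unrolled into explicit layers. This is purely my own re-framing for clarity (predictLayered is a hypothetical name); the demo itself uses the nested form above.

// Hypothetical layered rewrite of predict(), inside the same Network class.
public Double predictLayered(Integer input1, Integer input2) {
  // Input layer: three neurons, each seeing both raw inputs.
  double a = neurons.get(0).compute(input1, input2);
  double b = neurons.get(1).compute(input1, input2);
  double c = neurons.get(2).compute(input1, input2);
  // Hidden layer: two neurons combining the input layer's outputs.
  double h1 = neurons.get(3).compute(b, a);
  double h2 = neurons.get(4).compute(c, b);
  // Output layer: a single neuron produces the scalar prediction.
  return neurons.get(5).compute(h2, h1);
}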
Now, we could ask the network to make a prediction, as shown in Listing 5.
Listing 5. Get a prediction
Network network = new Network();
Double prediction = network.predict(115, 66);
System.out.println("prediction: " + prediction);
We’d get something, for sure, but it would be the result of the random weights and biases. For a real prediction, we need to first train the network.
Train the network
Training a neural network follows a process known as backpropagation, which I will introduce in more depth in my next article. Backpropagation is basically pushing changes backward through the network to make the output move toward a desired target.
We can perform backpropagation using function differentiation, but for our example, we’re going to do something different. We will give every neuron the capacity to “mutate.” On each round of training (known as an epoch), we pick a different neuron to make a small, random adjustment to one of its properties (weight1, weight2, or bias) and then check to see if the results improved. If the results improved, we’ll keep that change with a remember() method. If the results worsened, we’ll abandon the change with a forget() method.
We’ll add class members (old* versions of the weights and bias) to track the changes. You can see the mutate(), remember(), and forget() methods in Listing 6.
Listing 6. mutate(), remember(), forget()
class Neuron {
  Random random = new Random();
  private Double oldBias = random.nextDouble(-1, 1), bias = random.nextDouble(-1, 1);
  public Double oldWeight1 = random.nextDouble(-1, 1), weight1 = random.nextDouble(-1, 1);
  private Double oldWeight2 = random.nextDouble(-1, 1), weight2 = random.nextDouble(-1, 1);

  public void mutate() {
    int propertyToChange = random.nextInt(0, 3);
    Double changeFactor = random.nextDouble(-1, 1);
    if (propertyToChange == 0) {
      this.bias += changeFactor;
    } else if (propertyToChange == 1) {
      this.weight1 += changeFactor;
    } else {
      this.weight2 += changeFactor;
    }
  }

  public void forget() {
    bias = oldBias;
    weight1 = oldWeight1;
    weight2 = oldWeight2;
  }

  public void remember() {
    oldBias = bias;
    oldWeight1 = weight1;
    oldWeight2 = weight2;
  }
}
Pretty simple: The mutate() method picks a property at random and a value between -1 and 1 at random, and then changes that property. The forget() method rolls the change back to the old value. The remember() method copies the new value into the old* backup fields, making it the new fallback.
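To make the lifecycle concrete, here is a small usage sketch of my own (the scenario is hypothetical; the method names are from Listing 6):

Neuron n = new Neuron();
n.mutate();     // nudge one random property
// ... evaluate the network's loss with the mutated neuron ...
n.forget();     // loss got worse: roll back to the old values
n.mutate();     // try another nudge
// ... evaluate again ...
n.remember();   // loss improved: make the new values the fallback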
Now, to make use of our Neuron’s new capabilities, we add a train() method to Network, as shown in Listing 7.
Listing 7. The Network.train() method
public void train(List<List<Integer>> data, List<Double> answers) {
  Double bestEpochLoss = null;
  for (int epoch = 0; epoch < 1000; epoch++) {
    // adapt a neuron for this epoch
    Neuron epochNeuron = neurons.get(epoch % 6);
    epochNeuron.mutate();

    List<Double> predictions = new ArrayList<Double>();
    for (int i = 0; i < data.size(); i++) {
      predictions.add(i, this.predict(data.get(i).get(0), data.get(i).get(1)));
    }

    Double thisEpochLoss = Util.meanSquareLoss(answers, predictions);

    if (bestEpochLoss == null) {
      bestEpochLoss = thisEpochLoss;
      epochNeuron.remember();
    } else {
      if (thisEpochLoss < bestEpochLoss) {
        bestEpochLoss = thisEpochLoss;
        epochNeuron.remember();
      } else {
        epochNeuron.forget();
      }
    }
  }
}
The train() method iterates one thousand times over the data and answers Lists in the argument. These are training sets of the same size; data holds input values and answers holds their known, good answers. On each epoch, the method mutates a random neuron, computes predictions for the whole data set, and measures how well the network guessed compared to the known, correct answers, keeping the change only if it produced a better prediction.
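One refinement worth experimenting with is scaling each mutation by a learning factor, so you can control how aggressive the random adjustments are. This is a sketch of my own, not part of the demo; it assumes a mutate(double learnFactor) overload on Neuron:

// Hypothetical variant of mutate(): scale the random change.
// A smaller learnFactor yields gentler, more conservative mutations.
public void mutate(double learnFactor) {
  int propertyToChange = random.nextInt(0, 3);
  double changeFactor = random.nextDouble(-1, 1) * learnFactor;
  if (propertyToChange == 0) {
    this.bias += changeFactor;
  } else if (propertyToChange == 1) {
    this.weight1 += changeFactor;
  } else {
    this.weight2 += changeFactor;
  }
}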
Check the results
We can check the results using the mean squared error (MSE) formula, a common way to measure how far a set of predictions diverges from the correct answers in a neural network. You can see our MSE function in Listing 8.
Listing 8. MSE function
public static Double meanSquareLoss(List<Double> correctAnswers, List<Double> predictedAnswers) {
  double sumSquare = 0;
  for (int i = 0; i < correctAnswers.size(); i++) {
    double error = correctAnswers.get(i) - predictedAnswers.get(i);
    sumSquare += (error * error);
  }
  return sumSquare / (correctAnswers.size());
}
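As a quick sanity check (with numbers of my own): for correct answers [1.0, 0.0] and predictions [0.9, 0.2], the errors are 0.1 and -0.2, so the MSE is (0.01 + 0.04) / 2 = 0.025. A perfect network would score 0.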
Fine-tune the system
Now all that remains is to put some training data into the network and try it out with more predictions. Listing 9 shows how we provide training data.
Listing 9. Training data
List<List<Integer>> data = new ArrayList<List<Integer>>();
data.add(Arrays.asList(115, 66));
data.add(Arrays.asList(175, 78));
data.add(Arrays.asList(205, 72));
data.add(Arrays.asList(120, 67));
List<Double> answers = Arrays.asList(1.0, 0.0, 0.0, 1.0);
Network network = new Network();
network.train(data, answers);
In Listing 9, our training data is a list of two-dimensional integer sets (we can think of them as weight and height) paired with a list of answers (with 1.0 meaning female and 0.0 meaning male).
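One caveat worth noting: raw inputs as large as 205 push the sigmoid deep into saturation, which can slow training. A common remedy, not used in the demo, is to scale inputs into a small range first. A minimal min-max sketch, with bounds I’ve assumed rather than derived from the data:

// Hypothetical min-max normalization to [0, 1]; the bounds are assumptions.
// The same scaling must be applied at prediction time.
static double normalize(double value, double min, double max) {
  return (value - min) / (max - min);
}

// e.g., scale weight (assumed 100-220 lbs) and height (assumed 60-80 in):
double w = normalize(175, 100, 220);  // 0.625
double h = normalize(78, 60, 80);     // 0.9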
If we add a bit of logging to the training algorithm, running it will give output similar to Listing 10.
Listing 10. Logging the trainer
// Logging:
if (epoch % 10 == 0) System.out.println(String.format("Epoch: %s | bestEpochLoss: %.15f | thisEpochLoss: %.15f", epoch, bestEpochLoss, thisEpochLoss));
// output:
Epoch: 910 | bestEpochLoss: 0.034404863820424 | thisEpochLoss: 0.034437939546120
Epoch: 920 | bestEpochLoss: 0.033875954196897 | thisEpochLoss: 0.431451026477016
Epoch: 930 | bestEpochLoss: 0.032509260025490 | thisEpochLoss: 0.032509260025490
Epoch: 940 | bestEpochLoss: 0.003092720117159 | thisEpochLoss: 0.003098025397281
Epoch: 950 | bestEpochLoss: 0.002990128276146 | thisEpochLoss: 0.431062364628853
Epoch: 960 | bestEpochLoss: 0.001651762688346 | thisEpochLoss: 0.001651762688346
Epoch: 970 | bestEpochLoss: 0.001637709485751 | thisEpochLoss: 0.001636810460399
Epoch: 980 | bestEpochLoss: 0.001083365453009 | thisEpochLoss: 0.391527869500699
Epoch: 990 | bestEpochLoss: 0.001078338540452 | thisEpochLoss: 0.001078338540452
Listing 10 shows the loss (error divergence from exactly right) slowly declining; that is, it’s getting closer to making accurate predictions. All that remains is to see how well our model predicts with real data, as shown in Listing 11.
Listing 11. Predicting
System.out.println("");
System.out.println(String.format(" male, 167, 73: %.10f", network.predict(167, 73)));
System.out.println(String.format("female, 105, 67: %.10", network.predict(105, 67)));
System.out.println(String.format("female, 120, 72: %.10f | network1000: %.10f", network.predict(120, 72)));
System.out.println(String.format(" male, 143, 67: %.10f | network1000: %.10f", network.predict(143, 67)));
System.out.println(String.format(" male', 130, 66: %.10f | network: %.10f", network.predict(130, 66)));
In Listing 11, we take our trained network and feed it some data, outputting the predictions. We get something like Listing 12.
Listing 12. Trained predictions
male, 167, 73: 0.0279697143
female, 105, 67: 0.9075809407
female, 120, 72: 0.9075808235
male, 143, 67: 0.0305401413
male, 130, 66: 0.9009811922
In Listing 12, we see the network has done a pretty good job with most value pairs (aka vectors). It gives the female data sets an estimate around .907, which is pretty close to 1. Two males show .027 and .030, approaching 0. The outlier male data set (130, 66) is seen as probably female, though with slightly less confidence at .900.
Conclusion
There are a number of ways to adjust the dials on this system. For one, the number of epochs in a training run is a major factor. The more epochs, the more tuned to the data the model becomes. Running more epochs can improve accuracy on live data that conforms to the training sets, but it can also result in over-training: a model that confidently predicts wrong outcomes for edge cases.
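One simple guard against over-training is to hold out a few labeled pairs the network never trains on and compare the holdout loss to the training loss. A rough sketch, assuming the Network and Util classes above (the holdout values are mine):

// Hypothetical holdout check; the data points are illustrative.
List<List<Integer>> holdoutData = new ArrayList<>();
holdoutData.add(Arrays.asList(190, 74));  // male
holdoutData.add(Arrays.asList(110, 62));  // female
List<Double> holdoutAnswers = Arrays.asList(0.0, 1.0);

List<Double> holdoutPredictions = new ArrayList<>();
for (List<Integer> row : holdoutData) {
  holdoutPredictions.add(network.predict(row.get(0), row.get(1)));
}
// A holdout loss much worse than the training loss suggests over-training.
System.out.println("holdout loss: "
    + Util.meanSquareLoss(holdoutAnswers, holdoutPredictions));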
Visit my GitHub repository for the complete code for this tutorial, along with some extra bells and whistles.