package edu.agh.eit.neural;

import edu.agh.eit.neural.events.NeuralNetworkEventListener;
import edu.agh.eit.neural.events.NeuralNetworkLearningEvent;
import edu.agh.eit.neural.functions.ActivationFunction;
import java.util.Iterator;
import java.util.Random;
import javax.swing.event.EventListenerList;

/**
 * A layered, fully connected feed-forward neural network trained with per-sample
 * back-propagation.
 *
 * <p>The network creates its own input and output layers; hidden layers are supplied by the
 * caller. Layers are wired in order with {@link NeuralLayer#connectWith}: input, hidden[0], ...,
 * hidden[n-1], output (or input directly to output when there are no hidden layers). Learning
 * progress is reported to registered {@link NeuralNetworkEventListener}s.
 */
public class NeuralNetwork {
    /** Number of epochs between two learning events fired by the multi-epoch {@code learn}. */
    private static final int EVENT_INTERVAL = 100;

    private final int inputs;
    private final int outputs;
    private final NeuralLayer inputLayer;
    private final NeuralLayer outputLayer;
    /** Hidden layers in input-to-output order, or {@code null} when the network has none. */
    private final NeuralLayer[] hiddenLayers;
    /** Root of the mean squared output error of the most recent epoch (see {@link #getMse()}). */
    private double mse = 0.0d;
    private NetworkState networkState = NetworkState.NEWLY_CREATED;
    private boolean useBias = true;
    protected EventListenerList learningListeners = new EventListenerList();

    /** Lifecycle of the network with respect to learning. */
    public enum NetworkState {
        NEWLY_CREATED,
        LEARNING_IN_PROGRESS,
        LEARNING_TERMINATED,
        LEARNED
    }

    /**
     * Creates a network and connects its layers.
     *
     * @param inputs number of input neurons; must be positive
     * @param outputs number of output neurons; must be positive
     * @param hiddenLayers hidden layers in input-to-output order, or {@code null}/empty for none.
     *     The array is copied; the layers themselves are connected and used directly.
     * @param activationFunction activation function used for the input and output layers
     * @throws IllegalArgumentException if {@code inputs} or {@code outputs} is not positive, or if
     *     any hidden layer is {@code null}
     */
    public NeuralNetwork(int inputs, int outputs, NeuralLayer[] hiddenLayers,
            ActivationFunction activationFunction) {
        if (inputs <= 0 || outputs <= 0) {
            throw new IllegalArgumentException(
                    "inputs and outputs must be positive, got inputs=" + inputs
                            + ", outputs=" + outputs);
        }
        // Defensive copy so later changes to the caller's array cannot desynchronize the wiring.
        NeuralLayer[] hidden =
                (hiddenLayers == null || hiddenLayers.length == 0) ? null : hiddenLayers.clone();
        if (hidden != null) {
            for (int k = 0; k < hidden.length; k++) {
                if (hidden[k] == null) {
                    throw new IllegalArgumentException("hidden layer " + k + " is null");
                }
            }
        }

        this.inputs = inputs;
        this.inputLayer = new NeuralLayer(inputs, activationFunction);
        this.outputs = outputs;
        this.outputLayer = new NeuralLayer(outputs, activationFunction);
        this.hiddenLayers = hidden;

        if (hidden == null) {
            inputLayer.connectWith(outputLayer);
            return;
        }
        inputLayer.connectWith(hidden[0]);
        for (int k = 0; k < hidden.length - 1; k++) {
            hidden[k].connectWith(hidden[k + 1]);
        }
        hidden[hidden.length - 1].connectWith(outputLayer);
    }

    /**
     * Assigns every synapse weight of the hidden and output layers a uniformly distributed random
     * value in {@code [min, max)}. Bias weights are randomized too when bias is enabled.
     *
     * @param min lower bound of the random range
     * @param max upper bound of the random range
     */
    public void randomizeWeight(double min, double max) {
        Random random = new Random();
        if (hiddenLayers != null) {
            for (NeuralLayer layer : hiddenLayers) {
                randomizeLayer(layer, random, min, max);
            }
        }
        randomizeLayer(outputLayer, random, min, max);
    }

    /** Randomizes the incoming synapse weights (and optionally bias weights) of one layer. */
    private void randomizeLayer(NeuralLayer layer, Random random, double min, double max) {
        double range = max - min;
        for (Neuron neuron : layer.getNeurons()) {
            for (Synapse synapse : neuron.getInputSynapses()) {
                synapse.setWeight(random.nextDouble() * range + min);
            }
            if (useBias) {
                neuron.setBiasWeight(random.nextDouble() * range + min);
            }
        }
    }

    /**
     * Runs a forward pass and returns the output-layer values.
     *
     * @param input one value per input neuron
     * @return a new array with one value per output neuron
     * @throws IllegalArgumentException if {@code input.length} differs from the number of inputs
     */
    public double[] compute(double[] input) {
        if (input.length != inputs) {
            throw new IllegalArgumentException(
                    "expected " + inputs + " input values, got " + input.length);
        }
        Neuron[] inputNeurons = inputLayer.getNeurons();
        for (int k = 0; k < input.length; k++) {
            inputNeurons[k].setOutput(input[k]);
        }
        if (hiddenLayers != null) {
            for (NeuralLayer layer : hiddenLayers) {
                for (Neuron neuron : layer.getNeurons()) {
                    neuron.compute();
                }
            }
        }
        Neuron[] outputNeurons = outputLayer.getNeurons();
        double[] result = new double[outputs];
        for (int k = 0; k < result.length; k++) {
            outputNeurons[k].compute();
            result[k] = outputNeurons[k].getOutput();
        }
        return result;
    }

    /**
     * Trains the network for a number of epochs.
     *
     * <p>A learning event is fired every {@value #EVENT_INTERVAL} epochs, starting with epoch 0.
     * A listener may stop training early by calling
     * {@code setNetworkState(NetworkState.LEARNING_TERMINATED)}; the state is then left as
     * {@code LEARNING_TERMINATED} instead of becoming {@code LEARNED}.
     *
     * @param samples training inputs, one row per sample
     * @param targets expected outputs, one row per sample
     * @param epochs number of passes over the training set
     * @param rate learning rate
     */
    public void learn(double[][] samples, double[][] targets, int epochs, double rate) {
        networkState = NetworkState.LEARNING_IN_PROGRESS;
        for (int epoch = 0; epoch < epochs; epoch++) {
            learn(samples, targets, rate);
            if (epoch % EVENT_INTERVAL == 0) {
                fireLearningEvent(new NeuralNetworkLearningEvent(this, epoch, getMse()));
            }
            // Listeners run synchronously above, so a termination request is visible here.
            if (networkState == NetworkState.LEARNING_TERMINATED) {
                return;
            }
        }
        networkState = NetworkState.LEARNED;
    }

    /**
     * Runs one epoch of per-sample back-propagation over the training set and updates
     * {@link #getMse()} with the root of the mean squared output error of this epoch.
     * An empty training set leaves the weights and the error value unchanged.
     *
     * @param samples training inputs, one row per sample
     * @param targets expected outputs, one row per sample (at least one value per output)
     * @param rate learning rate
     * @throws IllegalArgumentException if {@code samples} and {@code targets} differ in length
     */
    public void learn(double[][] samples, double[][] targets, double rate) {
        if (samples.length != targets.length) {
            throw new IllegalArgumentException("got " + samples.length + " samples but "
                    + targets.length + " target rows");
        }
        if (samples.length == 0) {
            return;
        }
        double squaredErrorSum = 0.0d;
        for (int s = 0; s < samples.length; s++) {
            double[] actual = compute(samples[s]);
            squaredErrorSum += assignOutputErrors(targets[s], actual);
            backPropagateErrors();
            // Weights are updated only after all errors were propagated with the old weights.
            updateLayer(outputLayer, rate);
            if (hiddenLayers != null) {
                for (NeuralLayer layer : hiddenLayers) {
                    updateLayer(layer, rate);
                }
            }
        }
        // Averaged over every output of every sample of the epoch; double avoids int overflow.
        mse = Math.sqrt(squaredErrorSum / ((double) samples.length * outputs));
    }

    /** Stores (target - actual) as each output neuron's error; returns the sum of squares. */
    private double assignOutputErrors(double[] target, double[] actual) {
        Neuron[] outputNeurons = outputLayer.getNeurons();
        double sum = 0.0d;
        for (int k = 0; k < outputNeurons.length; k++) {
            double error = target[k] - actual[k];
            outputNeurons[k].setError(error);
            sum += error * error;
        }
        return sum;
    }

    /**
     * Accumulates each hidden neuron's error as the weighted sum of the deltas of the neurons it
     * feeds, walking from the output layer backwards. Errors are never pushed into the input
     * layer, which has no trainable weights.
     */
    private void backPropagateErrors() {
        if (hiddenLayers == null) {
            return;
        }
        propagateToPreviousLayer(outputLayer);
        for (int k = hiddenLayers.length - 1; k > 0; k--) {
            propagateToPreviousLayer(hiddenLayers[k]);
        }
    }

    /** Adds weight * delta of every neuron in {@code layer} to the error of its source neuron. */
    private static void propagateToPreviousLayer(NeuralLayer layer) {
        for (Neuron neuron : layer.getNeurons()) {
            double delta = delta(neuron);
            for (Synapse synapse : neuron.getInputSynapses()) {
                Neuron source = synapse.getNeuron();
                source.setError(source.getError() + synapse.getWeight() * delta);
            }
        }
    }

    /** Applies the gradient step to a layer's incoming and bias weights, then clears its errors. */
    private void updateLayer(NeuralLayer layer, double rate) {
        for (Neuron neuron : layer.getNeurons()) {
            double step = rate * delta(neuron);
            for (Synapse synapse : neuron.getInputSynapses()) {
                synapse.setWeight(synapse.getWeight() + step * synapse.getNeuron().getOutput());
            }
            if (useBias) {
                neuron.setBiasWeight(neuron.getBiasWeight() + step);
            }
            neuron.setError(0.0d);
        }
    }

    /** Error of a neuron scaled by its activation derivative (evaluated at its output). */
    private static double delta(Neuron neuron) {
        return neuron.getError()
                * neuron.getActivationFunction().computeDerivative(neuron.getOutput());
    }

    /** Alias for {@link #getMse()}. */
    public double getError() {
        return getMse();
    }

    /** Returns whether bias weights are used, randomized and trained. */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * Enables or disables bias for every neuron of the hidden and output layers.
     *
     * @param useBias {@code true} to use bias weights
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
        if (hiddenLayers != null) {
            for (NeuralLayer layer : hiddenLayers) {
                for (Neuron neuron : layer.getNeurons()) {
                    neuron.setUseBias(useBias);
                }
            }
        }
        // Output neurons are randomized and trained with bias too, so keep them in sync.
        for (Neuron neuron : outputLayer.getNeurons()) {
            neuron.setUseBias(useBias);
        }
    }

    /** Registers a listener for learning progress events. */
    public void addLearningListener(NeuralNetworkEventListener listener) {
        this.learningListeners.add(NeuralNetworkEventListener.class, listener);
    }

    /** Unregisters a previously added learning listener. */
    public void removeLearningListener(NeuralNetworkEventListener listener) {
        this.learningListeners.remove(NeuralNetworkEventListener.class, listener);
    }

    /** Notifies learning listeners in registration order. */
    private void fireLearningEvent(NeuralNetworkLearningEvent event) {
        // EventListenerList stores (class, listener) pairs in a flat array.
        Object[] entries = this.learningListeners.getListenerList();
        for (int k = 0; k < entries.length; k += 2) {
            if (entries[k] == NeuralNetworkEventListener.class) {
                ((NeuralNetworkEventListener) entries[k + 1]).eventOccured(event);
            }
        }
    }

    /**
     * Returns the error of the most recent training epoch: the square root of the mean squared
     * difference between targets and outputs, averaged over all outputs of all samples.
     */
    public double getMse() {
        return this.mse;
    }

    /** Returns the current learning state. */
    public NetworkState getNetworkState() {
        return this.networkState;
    }

    /**
     * Sets the learning state. Setting {@link NetworkState#LEARNING_TERMINATED} from a learning
     * listener stops the multi-epoch {@code learn} after the current epoch.
     */
    public void setNetworkState(NetworkState networkState) {
        this.networkState = networkState;
    }
}
