用 Java 编写用于二分预测的简易感知机(单个神经元)

import java.util.List;
import java.util.Random;

/**
 * Single-neuron perceptron with a tanh activation, used for binary prediction.
 *
 * <p>The model output is {@code tanh(w · x + bias)}, which lies in (-1, 1); training labels are
 * therefore expected to lie in [-1, 1] (for example -1 / +1 for the two classes).
 */
public class PercepAMFuncmachine {

    /** One weight per input feature; zero-initialised by the constructor. */
    private double[] weights;
    private double bias;
    private double learningRate;
    /** Number of input features, i.e. the length of the weight vector. */
    public int numInputsPublic;
    /** Default epoch count used by {@link #train(List, double[])}. */
    private int epochss;

    /**
     * Kept for backward compatibility with {@link #initialWeightsData()}. This is a brand-new
     * {@code PercUi}, not the instance whose {@code u()} read the user's input, so its
     * {@code weightss} is {@code null} unless something assigns it explicitly.
     */
    PercUi percUi = new PercUi();

    /**
     * Creates a perceptron whose weights all start at zero.
     *
     * @param numInputsNuberOfWeights number of input features (and weights); must be >= 0
     * @param learningRate step size used when training
     * @param bias initial bias term
     * @param epochss default number of training epochs
     * @throws IllegalArgumentException if {@code numInputsNuberOfWeights} is negative
     */
    public PercepAMFuncmachine(int numInputsNuberOfWeights, double learningRate, double bias, int epochss) {
        if (numInputsNuberOfWeights < 0) {
            throw new IllegalArgumentException(
                    "numInputsNuberOfWeights must be >= 0, got " + numInputsNuberOfWeights);
        }
        this.learningRate = learningRate;
        this.bias = bias;
        this.weights = new double[numInputsNuberOfWeights];
        numInputsPublic = numInputsNuberOfWeights;
        this.epochss = epochss;
    }

    /**
     * Loads initial weights from {@link #percUi} if it holds any; otherwise keeps the current
     * weights. (Previously this overwrote the weights with {@code null}.)
     *
     * @throws IllegalArgumentException if {@code percUi.weightss} has the wrong length
     */
    public void initialWeightsData() {
        initialWeightsData(percUi.weightss);
    }

    /**
     * Replaces the weights with a copy of {@code initialWeights}. A {@code null} argument leaves
     * the current weights untouched.
     *
     * @param initialWeights one weight per input, or {@code null}
     * @throws IllegalArgumentException if the length differs from the number of inputs
     */
    public void initialWeightsData(double[] initialWeights) {
        if (initialWeights == null) {
            return;
        }
        if (initialWeights.length != weights.length) {
            throw new IllegalArgumentException(
                    "expected " + weights.length + " weights but got " + initialWeights.length);
        }
        // Copy so that training does not silently mutate the caller's array.
        weights = initialWeights.clone();
    }

    private double activationFunction(double sum) {
        return Math.tanh(sum);
    }

    /**
     * Computes {@code tanh(sum(weights[i] * inputs[i]) + bias)}.
     *
     * @param inputs one value per input feature
     * @return the activation, in (-1, 1)
     * @throws IllegalArgumentException if {@code inputs.length} differs from the number of weights
     */
    public double predict(double[] inputs) {
        if (inputs.length != weights.length) {
            throw new IllegalArgumentException(
                    "expected " + weights.length + " inputs but got " + inputs.length);
        }
        double sum = 0.0;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * inputs[i];
        }
        sum += bias;
        return activationFunction(sum);
    }

    /**
     * Trains for the epoch count given to the constructor.
     *
     * @see #train(List, double[], int, Random)
     */
    public void train(List<double[]> trainingData, double[] labels) {
        train(trainingData, labels, epochss, new Random());
    }

    /**
     * Trains for {@code epochs} epochs using a fresh {@link Random} for shuffling.
     *
     * @see #train(List, double[], int, Random)
     */
    public void train(List<double[]> trainingData, double[] labels, int epochs) {
        train(trainingData, labels, epochs, new Random());
    }

    /**
     * Trains with per-sample updates {@code w += lr * (label - prediction) * x} and
     * {@code bias += lr * (label - prediction)}, visiting the samples in a fresh uniformly random
     * order each epoch. The caller's list and label array are NOT reordered.
     *
     * @param trainingData one input vector per sample
     * @param labels target for each sample, same order as {@code trainingData}
     * @param epochs number of passes over the data; values <= 0 perform no training
     * @param random source of randomness for the per-epoch shuffle
     * @throws IllegalArgumentException if the number of labels differs from the number of samples,
     *     or a sample has the wrong number of inputs
     */
    public void train(List<double[]> trainingData, double[] labels, int epochs, Random random) {
        int sampleCount = trainingData.size();
        if (labels.length != sampleCount) {
            throw new IllegalArgumentException(
                    "got " + sampleCount + " samples but " + labels.length + " labels");
        }
        // Shuffle indices instead of the data itself, so immutable lists work and the caller's
        // sample/label pairing is left untouched.
        int[] order = new int[sampleCount];
        for (int i = 0; i < sampleCount; i++) {
            order[i] = i;
        }
        for (int epoch = 0; epoch < epochs; epoch++) {
            shuffle(order, random);
            for (int index : order) {
                double[] inputs = trainingData.get(index);
                double error = labels[index] - predict(inputs);
                for (int j = 0; j < weights.length; j++) {
                    weights[j] += learningRate * error * inputs[j];
                }
                bias += learningRate * error;
            }
        }
    }

    /** In-place Fisher–Yates shuffle: every permutation is equally likely. */
    private static void shuffle(int[] values, Random random) {
        for (int i = values.length - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            int tmp = values[i];
            values[i] = values[j];
            values[j] = tmp;
        }
    }
}

感知机核心类(PercepAMFuncmachine),用于二分预测。激活函数使用 tanh:它平滑可导,输出位于 (-1, 1) 区间,相比阶跃函数能提供连续的误差信号;但要注意,当加权和的绝对值很大时,tanh 同样会饱和而变得平坦。

具体解析:

weights:权重,表示每个输入(每一部分的得分)在加权求和中所占的比重。

bias:偏置(偏置项)。它是加在加权和之外的一个常数,使决策边界不必经过原点,从而提升模型的拟合能力。

learningRate:学习率,即每次校正权重时的步长。学习率越小,训练越稳定,但收敛越慢;学习率过大则可能来回震荡甚至发散。

numInputsPublic:输入的个数,即权重数组的长度;每个输入对应“每一部分的得分”。

epoches:训练轮数(epoch 数),即完整遍历一遍训练集、逐条进行校正的次数。

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Scanner;

/**
 * Convenience driver for the whole workflow. It builds a {@link PercepAMFuncmachine}, trains it
 * on randomly generated samples, then reads one test sample from standard input and prints the
 * prediction.
 */
public class PerceptronMiddle {

    private int numInputsNuberOfWeights, epoches, trainTimes;
    private double learningRate, bias;
    // NOTE(review): PercUi.u() opens its own Scanner on System.in as well. Each Scanner may buffer
    // ahead, so with piped input this one can miss tokens; consider sharing a single Scanner.
    private Scanner scanner = new Scanner(System.in);

    /**
     * @param numInputsNuberOfWeights number of input features; must be >= 1
     * @param learningRate step size passed to the perceptron
     * @param bias initial bias passed to the perceptron
     * @param epoches number of training epochs
     * @param trainTimes number of random training samples to generate; must be >= 0
     * @throws IllegalArgumentException if a count is out of range
     */
    public PerceptronMiddle(int numInputsNuberOfWeights, double learningRate, double bias, int epoches, int trainTimes) {
        if (numInputsNuberOfWeights < 1) {
            throw new IllegalArgumentException(
                    "numInputsNuberOfWeights must be >= 1, got " + numInputsNuberOfWeights);
        }
        if (trainTimes < 0) {
            throw new IllegalArgumentException("trainTimes must be >= 0, got " + trainTimes);
        }
        this.numInputsNuberOfWeights = numInputsNuberOfWeights;
        this.learningRate = learningRate;
        this.bias = bias;
        this.epoches = epoches;
        this.trainTimes = trainTimes;
    }

    /**
     * Trains on {@code trainTimes} uniformly random samples in [0, 1), labelled by
     * {@link #labelFor}, then predicts one sample read from standard input.
     */
    public void run() {
        PercepAMFuncmachine percepAMFuncmachine = new PercepAMFuncmachine(numInputsNuberOfWeights, learningRate, bias, epoches);

        // NOTE(review): this reads weights from the perceptron's own PercUi field, not from the
        // PercUi whose u() collected the user's weights; those values are never forwarded here.
        percepAMFuncmachine.initialWeightsData();

        Random random = new Random();
        List<double[]> trainingData = new ArrayList<>(trainTimes);
        double[] labels = new double[trainTimes];

        for (int i = 0; i < trainTimes; i++) {
            double[] dataPoint = new double[numInputsNuberOfWeights];
            for (int j = 0; j < numInputsNuberOfWeights; j++) {
                dataPoint[j] = random.nextDouble();
            }
            trainingData.add(dataPoint);
            labels[i] = labelFor(dataPoint);
        }

        percepAMFuncmachine.train(trainingData, labels, epoches);

        double[] testData = new double[numInputsNuberOfWeights];
        System.out.println("Enter test data:");

        for (int i = 0; i < numInputsNuberOfWeights; i++) {
            testData[i] = scanner.nextDouble();
        }

        double prediction = percepAMFuncmachine.predict(testData);
        System.out.println("Prediction: " + prediction);
    }

    /**
     * Binary target for the synthetic task: +1 if {@code 2 * (x0 + x1) - 1 >= 0}, otherwise -1.
     * It keeps the decision boundary of the former continuous target, but stays inside tanh's
     * range. With a single feature, {@code x1} is treated as 0.
     */
    private static double labelFor(double[] dataPoint) {
        double featureSum = dataPoint[0] + (dataPoint.length > 1 ? dataPoint[1] : 0.0);
        return 2 * featureSum - 1 >= 0 ? 1.0 : -1.0;
    }
}

中间类(PerceptronMiddle),用于进一步封装:内部完成随机训练数据的生成、模型训练和测试数据的预测,外部只需传入参数即可使用。

import java.util.Scanner;

/**
 * Console front end. It reads the perceptron's hyper-parameters and initial weights from
 * standard input, then runs {@link PerceptronMiddle} with them.
 */
public class PercUi {

    public int numInputsNuberOfWeights, epoches, trainTimes;
    public double learningRate, bias;
    /** Initial weights typed by the user; {@code null} until {@link #u()} has run. */
    public double[] weightss;

    /**
     * Prompts for and reads every parameter from standard input. Out-of-range or non-numeric
     * entries are rejected and re-requested instead of crashing the program.
     *
     * <p>Accepted ranges:
     * <ul>
     *   <li>input count: >= 1</li>
     *   <li>epochs and training times: >= 0</li>
     *   <li>learning rate: finite and > 0</li>
     *   <li>bias and weights: finite</li>
     * </ul>
     *
     * @throws java.util.NoSuchElementException if standard input ends before all values are read
     */
    public void u() {
        Scanner scanner = new Scanner(System.in);
        System.out.println("the numInputsNuberOfWeights");
        numInputsNuberOfWeights = readInt(scanner, 1);

        System.out.println("the learningrate");
        learningRate = readDouble(scanner, true);
        System.out.println("the bias");
        bias = readDouble(scanner, false);
        System.out.println("the epoches");
        epoches = readInt(scanner, 0);
        System.out.println("the training times");
        trainTimes = readInt(scanner, 0);

        System.out.println("weights");

        weightss = new double[numInputsNuberOfWeights];

        for (int i = 0; i < numInputsNuberOfWeights; i++) {
            weightss[i] = readDouble(scanner, false);
        }
    }

    /** Reads integers until one is >= {@code min}, discarding non-integer tokens. */
    private static int readInt(Scanner scanner, int min) {
        while (true) {
            if (scanner.hasNextInt()) {
                int value = scanner.nextInt();
                if (value >= min) {
                    return value;
                }
            } else {
                scanner.next(); // discard the bad token; throws at end of input, ending the loop
            }
            System.out.println("please enter an integer >= " + min);
        }
    }

    /** Reads numbers until a finite one (and, if {@code positiveOnly}, one > 0) is entered. */
    private static double readDouble(Scanner scanner, boolean positiveOnly) {
        while (true) {
            if (scanner.hasNextDouble()) {
                double value = scanner.nextDouble();
                if (Double.isFinite(value) && (!positiveOnly || value > 0)) {
                    return value;
                }
            } else {
                scanner.next(); // discard the bad token; throws at end of input, ending the loop
            }
            System.out.println(positiveOnly ? "please enter a number > 0" : "please enter a finite number");
        }
    }

    public static void main(String[] args) {
        PercUi percUi = new PercUi();
        percUi.u();
        PerceptronMiddle trainer = new PerceptronMiddle(percUi.numInputsNuberOfWeights, percUi.learningRate, percUi.bias, percUi.epoches, percUi.trainTimes);
        trainer.run();
    }
}

UI 界面(实际上只有控制台输入输出)。

这些代码实现了一个封装较完整的感知机,可用于简易的二分预测,也可作为搭建神经网络的基本单元。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值