import java.util.List;
import java.util.Random;
/**
 * Core of a single-neuron "perceptron" used for binary-style prediction.
 *
 * <p>The model output is {@code tanh(w . x + bias)}, a continuous value in [-1, 1] rather
 * than a hard step. Training applies the error-driven update
 * {@code w[j] += learningRate * (label - prediction) * x[j]} (and likewise for the bias)
 * once per sample per epoch, visiting the samples in a freshly shuffled order each epoch.
 * The rule omits tanh's derivative {@code 1 - prediction^2}, so it is the classic
 * perceptron/delta-style update rather than exact gradient descent on squared error.
 */
public class PercepAMFuncmachine {
    /** One weight per input feature; all zeros until initialised or trained. */
    private double[] weights;
    /** Additive offset applied before the activation; updated during training. */
    private double bias;
    /** Step size of each training update. */
    private double learningRate;
    /** Number of input features (and weights) this model was constructed for. */
    public int numInputsPublic;
    /** Default number of epochs, used by {@link #train(List, double[])}. */
    private int epochss;
    /**
     * Optional source of user-entered initial weights for {@link #initialWeightsData()}.
     * This is a fresh instance, so its {@code weightss} is null unless something populates
     * it (e.g. by calling {@code percUi.u()}) before {@code initialWeightsData()} runs.
     */
    PercUi percUi = new PercUi();

    /**
     * Creates a model whose weights are all initialised to zero.
     *
     * @param numInputsNuberOfWeights number of input features, i.e. number of weights
     * @param learningRate step size of each training update
     * @param bias initial bias
     * @param epochss default number of training epochs
     */
    public PercepAMFuncmachine(int numInputsNuberOfWeights, double learningRate, double bias, int epochss) {
        this.learningRate = learningRate;
        this.bias = bias;
        this.weights = new double[numInputsNuberOfWeights];
        this.numInputsPublic = numInputsNuberOfWeights;
        this.epochss = epochss;
    }

    /**
     * Loads initial weights from {@link #percUi} when it holds any. An unpopulated PercUi
     * has null weights; in that case the current (zero-initialised) weights are kept
     * instead of being replaced by null.
     *
     * @throws IllegalArgumentException if the supplied weights have the wrong length
     */
    public void initialWeightsData() {
        if (percUi.weightss != null) {
            initialWeightsData(percUi.weightss);
        }
    }

    /**
     * Replaces the weights with a copy of {@code initialWeights}.
     *
     * @param initialWeights one weight per input
     * @throws NullPointerException if {@code initialWeights} is null
     * @throws IllegalArgumentException if its length differs from the number of weights
     */
    public void initialWeightsData(double[] initialWeights) {
        if (initialWeights == null) {
            throw new NullPointerException("initialWeights");
        }
        if (initialWeights.length != weights.length) {
            throw new IllegalArgumentException(
                    "expected " + weights.length + " weights, got " + initialWeights.length);
        }
        // Copy so later changes to the caller's array cannot silently alter the model.
        weights = initialWeights.clone();
    }

    /** Squashes the weighted sum into [-1, 1]; smooth, unlike a hard step function. */
    private double activationFunction(double sum) {
        return Math.tanh(sum);
    }

    /**
     * Computes {@code tanh(sum_i weights[i] * inputs[i] + bias)}.
     *
     * @param inputs one value per weight
     * @return the model output, in [-1, 1]
     * @throws IllegalArgumentException if {@code inputs} does not have one value per weight
     */
    public double predict(double[] inputs) {
        if (inputs.length != weights.length) {
            throw new IllegalArgumentException(
                    "expected " + weights.length + " inputs, got " + inputs.length);
        }
        double sum = 0.0;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * inputs[i];
        }
        sum += bias;
        return activationFunction(sum);
    }

    /**
     * Trains for the default number of epochs given to the constructor.
     *
     * @see #train(List, double[], int)
     */
    public void train(List<double[]> trainingData, double[] labels) {
        train(trainingData, labels, epochss);
    }

    /**
     * Trains the weights and bias in place.
     *
     * <p>Each epoch visits every sample exactly once in a uniformly random order. The
     * caller's list and label array are read but never reordered or modified.
     *
     * @param trainingData samples, each with one value per weight
     * @param labels target for each sample; {@code labels[i]} belongs to sample {@code i}
     * @param epochs number of passes over the data (0 or less trains nothing)
     * @throws NullPointerException if {@code trainingData} or {@code labels} is null
     * @throws IllegalArgumentException if the label count differs from the sample count,
     *     or a sample has the wrong length
     */
    public void train(List<double[]> trainingData, double[] labels, int epochs) {
        if (trainingData == null || labels == null) {
            throw new NullPointerException("trainingData and labels must not be null");
        }
        if (labels.length != trainingData.size()) {
            throw new IllegalArgumentException(
                    "got " + labels.length + " labels for " + trainingData.size() + " samples");
        }
        // Snapshot into an array: O(1) indexed access for any List implementation, and
        // shuffling an index permutation leaves the caller's data untouched.
        double[][] samples = trainingData.toArray(new double[0][]);
        int[] order = new int[samples.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Random random = new Random();
        for (int epoch = 0; epoch < epochs; epoch++) {
            shuffle(order, random);
            for (int idx : order) {
                double[] inputs = samples[idx];
                double error = labels[idx] - predict(inputs);
                double step = learningRate * error;
                for (int j = 0; j < weights.length; j++) {
                    weights[j] += step * inputs[j];
                }
                bias += step;
            }
        }
    }

    /** Shuffles {@code a} in place with Fisher-Yates, so every permutation is equally likely. */
    private static void shuffle(int[] a, Random random) {
        for (int i = a.length - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            int tmp = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }
    }
}
感知机核心类,用于二分类预测。激活函数使用 tanh:输出为 -1 到 1 之间的连续值,避免了阶跃函数在大部分区域完全平坦(梯度为零)的问题;但注意输入绝对值很大时 tanh 同样会趋于平坦。
具体解析:
weights:权重,每一部分在计算中所占的重要性。
bias:偏置。用于平移模型输出,使决策边界不必经过原点,从而提升拟合能力、减少预测偏差。
learningRate:学习率。控制每次矫正的幅度:越小更新越平稳但学习越慢,过大则可能来回震荡甚至无法收敛。
numInputsPublic:输入个数。即每条数据包含的“部分得分”数量,也等于权重的个数。
epochss(PerceptronMiddle 中为 epoches):训练轮数,即完整遍历一遍训练数据并逐条矫正的次数。
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Scanner;
/**
 * Middle layer that wires a {@link PercepAMFuncmachine} to a synthetic task.
 *
 * <p>{@link #run()} does three things:
 * <ul>
 *   <li>generates {@code trainTimes} random samples;</li>
 *   <li>trains the model on them for {@code epoches} epochs;</li>
 *   <li>reads one test sample from the scanner and prints the model's prediction.</li>
 * </ul>
 */
public class PerceptronMiddle {
    private int numInputsNuberOfWeights, epoches, trainTimes;
    private double learningRate, bias;
    /** Source of the test sample read by {@link #run()}. */
    private Scanner scanner;

    /**
     * Creates a trainer that reads its test sample from a new Scanner on {@code System.in}.
     *
     * @throws IllegalArgumentException if fewer than two inputs or a negative sample count
     */
    public PerceptronMiddle(int numInputsNuberOfWeights, double learningRate, double bias, int epoches, int trainTimes) {
        this(numInputsNuberOfWeights, learningRate, bias, epoches, trainTimes, new Scanner(System.in));
    }

    /**
     * Creates a trainer that reads its test sample from {@code scanner}. Sharing one Scanner
     * avoids two Scanners read-ahead-buffering the same stream (e.g. {@code System.in}).
     *
     * @param numInputsNuberOfWeights features per sample; must be at least 2
     * @param learningRate step size of each training update
     * @param bias initial bias of the model
     * @param epoches number of training epochs
     * @param trainTimes number of synthetic training samples; must not be negative
     * @param scanner where the test sample is read from
     * @throws IllegalArgumentException if fewer than two inputs or a negative sample count
     * @throws NullPointerException if {@code scanner} is null
     */
    public PerceptronMiddle(int numInputsNuberOfWeights, double learningRate, double bias, int epoches, int trainTimes,
            Scanner scanner) {
        if (numInputsNuberOfWeights < 2) {
            throw new IllegalArgumentException(
                    "numInputsNuberOfWeights must be at least 2 (the synthetic label uses the first two inputs), got "
                            + numInputsNuberOfWeights);
        }
        if (trainTimes < 0) {
            throw new IllegalArgumentException("trainTimes must not be negative, got " + trainTimes);
        }
        if (scanner == null) {
            throw new NullPointerException("scanner");
        }
        this.numInputsNuberOfWeights = numInputsNuberOfWeights;
        this.learningRate = learningRate;
        this.bias = bias;
        this.epoches = epoches;
        this.trainTimes = trainTimes;
        this.scanner = scanner;
    }

    /** Trains a fresh model on synthetic data, then predicts one sample read from the scanner. */
    public void run() {
        PercepAMFuncmachine percepAMFuncmachine = new PercepAMFuncmachine(numInputsNuberOfWeights, learningRate, bias, epoches);
        percepAMFuncmachine.initialWeightsData();
        Random random = new Random();
        List<double[]> trainingData = new ArrayList<>(trainTimes);
        double[] labels = new double[trainTimes];
        for (int i = 0; i < trainTimes; i++) {
            double[] dataPoint = randomPoint(random);
            trainingData.add(dataPoint);
            labels[i] = syntheticLabel(dataPoint);
        }
        percepAMFuncmachine.train(trainingData, labels, epoches);
        System.out.println("Enter test data:");
        double[] testData = readTestData();
        double prediction = percepAMFuncmachine.predict(testData);
        System.out.println("Prediction: " + prediction);
    }

    /** Returns a sample whose features are independent uniform draws from [0, 1). */
    private double[] randomPoint(Random random) {
        double[] point = new double[numInputsNuberOfWeights];
        for (int j = 0; j < point.length; j++) {
            point[j] = random.nextDouble();
        }
        return point;
    }

    /**
     * Target of the synthetic task: {@code 2 * mean(x[0], x[1]) - 1 = x[0] + x[1] - 1}.
     * With features in [0, 1) this lies in [-1, 1), inside tanh's output range, so the
     * model can actually reach it. A sum-based {@code 2 * (x[0] + x[1]) - 1} would span
     * [-1, 3), mostly unreachable for a tanh output.
     */
    private static double syntheticLabel(double[] x) {
        return x[0] + x[1] - 1;
    }

    /** Reads one test sample (one double per input) from the scanner. */
    private double[] readTestData() {
        double[] testData = new double[numInputsNuberOfWeights];
        for (int i = 0; i < testData.length; i++) {
            testData[i] = scanner.nextDouble();
        }
        return testData;
    }
}
中间类,用于更严密的封装:根据传入的参数生成随机训练数据、训练 PercepAMFuncmachine,再从控制台读取一组测试数据并输出预测结果。外部直接填入数字参数就能使用。
import java.util.Scanner;
public class PercUi {
public int numInputsNuberOfWeights, epoches, trainTimes;
public double learningRate, bias;
public double[] weightss;
public void u() {
Scanner scanner = new Scanner(System.in);
System.out.println("the numInputsNuberOfWeights");
numInputsNuberOfWeights = scanner.nextInt();
System.out.println("the learningrate");
learningRate = scanner.nextDouble();
System.out.println("the bias");
bias = scanner.nextDouble();
System.out.println("the epoches");
epoches = scanner.nextInt();
System.out.println("the training times");
trainTimes = scanner.nextInt();
System.out.println("weights");
weightss = new double[numInputsNuberOfWeights];
for (int i = 0; i < numInputsNuberOfWeights; i++) {
weightss[i] = scanner.nextDouble();
}
}
public static void main(String[] args) {
PercUi percUi = new PercUi();
percUi.u();
PerceptronMiddle trainer = new PerceptronMiddle(percUi.numInputsNuberOfWeights, percUi.learningRate, percUi.bias, percUi.epoches, percUi.trainTimes);
trainer.run();
}
}
UI 界面(其实只是控制台交互):依次读取输入个数、学习率、偏置、训练轮数、训练样本数和初始权重,然后把除权重以外的参数交给 PerceptronMiddle 执行训练与预测(目前输入的初始权重并未被传递使用)。
这些代码实现了一个封装好的单神经元感知机,可用于简单的预测,也可作为搭建神经网络的基本单元。