using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace ConsoleApp4
{
    class Program
    {
        // One shared generator for the whole run. Constructing a new Random per call (or reseeding
        // one from DateTime.Now.Ticks) yields identical sequences whenever the calls land within the
        // clock's resolution. That made all initial weights equal and repeated the same mini-batch
        // for many consecutive epochs.
        private static readonly Random Rng = new Random();

        /// <summary>
        /// Fits y = theta1*x1 + theta2*x2 + b to a small hard-coded data set using mini-batch gradient
        /// descent. Prints the cost every 100 epochs, stops early once the cost changes by at most
        /// epsilon between epochs, and finally prints the fitted equation.
        /// </summary>
        static void Main(string[] args)
        {
            List<float[]> inputs_x = new List<float[]>
            {
                new float[] { 0.9f, 0.6f },
                new float[] { 2f, 2.5f },
                new float[] { 2.6f, 2.3f },
                new float[] { 2.7f, 1.9f },
            };
            List<float> inputs_y = new List<float> { 2.5f, 2.5f, 3.5f, 4.2f };

            // Layout: weights[0] = b (bias), weights[1] = theta1, weights[2] = theta2.
            float[] weights = new float[3];
            for (var i = 0; i < weights.Length; i++)
            {
                weights[i] = (float)Rng.NextDouble();
            }

            int epoch = 30000;
            float epsilon = 0.00001f;
            float lr = 0.01f;
            float lastCost = 0;
            for (var epoch_i = 0; epoch_i <= epoch; epoch_i++)
            {
                // Randomly sample a mini-batch of inputs (with replacement) for this step.
                var batch = GetRandomBatch(inputs_x, inputs_y, 2);
                float[] step = AccumulateGradient(batch, weights);
                for (var i = 0; i < weights.Length; i++)
                {
                    weights[i] += lr * step[i];
                }

                // Measure the cost on the full training set, not on the random batch. Comparing the
                // costs of two different random batches mostly measures sampling noise, so the
                // convergence test below would fire (or not) by chance.
                float cost = ComputeCost(inputs_x, inputs_y, weights);
                if (Math.Abs(cost - lastCost) <= epsilon)
                {
                    Console.WriteLine(string.Format("EPOCH {0}", epoch_i));
                    Console.WriteLine(string.Format("LAST MSE {0}", lastCost));
                    Console.WriteLine(string.Format("MSE {0}", cost));
                    break;
                }
                lastCost = cost;
                if (epoch_i % 100 == 0 || epoch_i == epoch)
                {
                    Console.WriteLine(string.Format("MSE {0}", cost));
                }
            }
            print(weights[1], weights[2], weights[0]);
            Console.ReadLine();
        }

        /// <summary>
        /// Sums, over the batch, the residual (target - prediction) multiplied by the partial
        /// derivative of the model with respect to each weight. Adding lr times the result to the
        /// weights is one gradient-descent step on the half squared error.
        /// Internal (rather than private) so it can be exercised directly.
        /// </summary>
        /// <param name="batch">Pairs of (input row [x1, x2], target y).</param>
        /// <param name="weights">Current weights laid out as [b, theta1, theta2].</param>
        /// <returns>A new array of length weights.Length holding the accumulated step direction.</returns>
        internal static float[] AccumulateGradient(List<Tuple<float[], float>> batch, float[] weights)
        {
            float[] step = new float[weights.Length];
            foreach (var x_y in batch)
            {
                float x1 = x_y.Item1[0];
                float x2 = x_y.Item1[1];
                float diffWithTargetY = x_y.Item2 - fun(x1, x2, weights[1], weights[2], weights[0]);
                step[0] += diffWithTargetY * dy_b(x1, x2);
                step[1] += diffWithTargetY * dy_theta1(x1, x2);
                step[2] += diffWithTargetY * dy_theta2(x1, x2);
            }
            return step;
        }

        /// <summary>
        /// Mean over all samples of (y - prediction)^2 / 2, i.e. half the mean squared error.
        /// Internal (rather than private) so it can be exercised directly.
        /// </summary>
        /// <param name="inputs_x">Input rows [x1, x2].</param>
        /// <param name="inputs_y">Targets, parallel to <paramref name="inputs_x"/>.</param>
        /// <param name="weights">Weights laid out as [b, theta1, theta2].</param>
        /// <exception cref="ArgumentException">The inputs are empty or their lengths differ.</exception>
        internal static float ComputeCost(List<float[]> inputs_x, List<float> inputs_y, float[] weights)
        {
            if (inputs_x.Count == 0)
            {
                throw new ArgumentException("inputs_x must not be empty.", "inputs_x");
            }
            if (inputs_y.Count != inputs_x.Count)
            {
                throw new ArgumentException("inputs_x and inputs_y must have the same length.", "inputs_y");
            }

            float totalErrorCost = 0f;
            for (int i = 0; i < inputs_x.Count; i++)
            {
                float[] row = inputs_x[i];
                float diffWithTargetY = inputs_y[i] - fun(row[0], row[1], weights[1], weights[2], weights[0]);
                totalErrorCost += diffWithTargetY * diffWithTargetY / 2;
            }
            return totalErrorCost / inputs_x.Count;
        }

        /// <summary>
        /// Draws <paramref name="maxCount"/> samples uniformly at random, with replacement, pairing
        /// each input row with its target. A non-positive maxCount yields an empty list.
        /// </summary>
        /// <exception cref="ArgumentException">The inputs are empty or their lengths differ.</exception>
        private static List<Tuple<float[], float>> GetRandomBatch(List<float[]> inputs_x, List<float> inputs_y, int maxCount)
        {
            if (inputs_x.Count == 0)
            {
                throw new ArgumentException("inputs_x must not be empty.", "inputs_x");
            }
            if (inputs_y.Count != inputs_x.Count)
            {
                throw new ArgumentException("inputs_x and inputs_y must have the same length.", "inputs_y");
            }

            List<Tuple<float[], float>> lst = new List<Tuple<float[], float>>(Math.Max(maxCount, 0));
            for (int count = 0; count < maxCount; count++)
            {
                int rndIndex = Rng.Next(inputs_x.Count);
                lst.Add(Tuple.Create(inputs_x[rndIndex], inputs_y[rndIndex]));
            }
            return lst;
        }

        /// <summary>Prints the fitted model as "y=theta1*x1+theta2*x2+b".</summary>
        private static void print(float theta1, float theta2, float b)
        {
            Console.WriteLine(string.Format("y={0}*x1+{1}*x2+{2}", theta1, theta2, b));
        }

        /// <summary>The linear model: theta1*x1 + theta2*x2 + b.</summary>
        private static float fun(float x1, float x2, float theta1, float theta2, float b)
        {
            return theta1 * x1 + theta2 * x2 + b;
        }

        /// <summary>Partial derivative of the model with respect to theta1 (equals x1).</summary>
        private static float dy_theta1(float x1, float x2)
        {
            return x1;
        }

        /// <summary>Partial derivative of the model with respect to theta2 (equals x2).</summary>
        private static float dy_theta2(float x1, float x2)
        {
            return x2;
        }

        /// <summary>Partial derivative of the model with respect to the bias b (always 1).</summary>
        private static float dy_b(float x1, float x2)
        {
            return 1;
        }
    }
}