SGD
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace ConsoleApp4
{
    /// <summary>
    /// Console demo that fits the linear model y = theta1*x1 + theta2*x2 + b to four
    /// hard-coded samples using mini-batch stochastic gradient descent (SGD).
    /// </summary>
    class Program
    {
        // A single generator shared by the whole run. Creating a new Random per call
        // (seeded from the clock) can hand back the same sequence when calls happen
        // within one clock tick, which makes consecutive "random" batches identical.
        private static readonly Random s_random = new Random();

        /// <summary>
        /// Trains the model and prints the mean squared error every 100 epochs,
        /// followed by the learned equation. Waits for a line on stdin before exiting.
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        static void Main(string[] args)
        {
            List<float[]> inputs_x = new List<float[]>
            {
                new float[] { 0.9f, 0.6f },
                new float[] { 2f, 2.5f },
                new float[] { 2.6f, 2.3f },
                new float[] { 2.7f, 1.9f }
            };
            List<float> inputs_y = new List<float> { 2.5f, 2.5f, 3.5f, 4.2f };

            // Layout: weights[0] = b (bias), weights[1] = theta1, weights[2] = theta2.
            float[] weights = new float[3];
            for (var i = 0; i < weights.Length; i++)
            {
                weights[i] = (float)s_random.NextDouble();
            }

            int epoch = 30000;
            float epsilon = 0.00001f;
            float lr = 0.01f;
            float lastCost = 0;

            for (var epoch_i = 0; epoch_i <= epoch; epoch_i++)
            {
                // Draw a random mini-batch of inputs.
                var batch = GetRandomBatch(inputs_x, inputs_y, 2);

                // Accumulate (target - prediction) * d(prediction)/d(weight) over the batch.
                float[] weights_in_poch = new float[weights.Length];
                foreach (var x_y in batch)
                {
                    var x1 = x_y.Item1[0];
                    var x2 = x_y.Item1[1];
                    float diffWithTargetY = Residual(x_y, weights);

                    weights_in_poch[0] += diffWithTargetY * dy_b(x1, x2);
                    weights_in_poch[1] += diffWithTargetY * dy_theta1(x1, x2);
                    weights_in_poch[2] += diffWithTargetY * dy_theta2(x1, x2);
                }

                // The residual is (target - prediction), so adding moves each weight
                // in the direction that reduces the squared error.
                for (var i = 0; i < weights.Length; i++)
                {
                    weights[i] += lr * weights_in_poch[i];
                }

                // Half squared error of the updated weights, averaged over the same batch.
                float totalErrorCost = 0f;
                foreach (var x_y in batch)
                {
                    float diffWithTargetY = Residual(x_y, weights);
                    totalErrorCost += diffWithTargetY * diffWithTargetY / 2;
                }

                float cost = totalErrorCost / batch.Count;

                // NOTE: this compares the cost of the current mini-batch with the cost of
                // the previous (different) mini-batch, so the stop criterion is noisy.
                if (System.Math.Abs(cost - lastCost) <= epsilon)
                {
                    Console.WriteLine("EPOCH {0}", epoch_i);
                    Console.WriteLine("LAST MSE {0}", lastCost);
                    Console.WriteLine("MSE {0}", cost);
                    break;
                }

                lastCost = cost;

                if (epoch_i % 100 == 0 || epoch_i == epoch)
                {
                    Console.WriteLine("MSE {0}", cost);
                }
            }

            print(weights[1], weights[2], weights[0]);
            Console.ReadLine();
        }

        /// <summary>
        /// Returns target minus prediction for one sample under the given weights.
        /// </summary>
        /// <param name="sample">Item1 holds { x1, x2 }; Item2 holds the target y.</param>
        /// <param name="weights">{ b, theta1, theta2 }.</param>
        private static float Residual(Tuple<float[], float> sample, float[] weights)
        {
            return sample.Item2 - fun(sample.Item1[0], sample.Item1[1], weights[1], weights[2], weights[0]);
        }

        /// <summary>
        /// Samples <paramref name="maxCount"/> (x, y) pairs uniformly at random, with
        /// replacement, from the parallel lists <paramref name="inputs_x"/> and
        /// <paramref name="inputs_y"/>.
        /// </summary>
        /// <param name="inputs_x">Feature vectors.</param>
        /// <param name="inputs_y">Targets; element i belongs to inputs_x[i].</param>
        /// <param name="maxCount">Number of samples to draw; values &lt;= 0 yield an empty list.</param>
        /// <returns>A new list of sampled pairs (the same pair may appear more than once).</returns>
        /// <exception cref="ArgumentNullException">Either list is null.</exception>
        /// <exception cref="ArgumentException">
        /// The lists differ in length, or samples are requested from empty lists.
        /// </exception>
        private static List<Tuple<float[], float>> GetRandomBatch(List<float[]> inputs_x, List<float> inputs_y, int maxCount)
        {
            if (inputs_x == null)
            {
                throw new ArgumentNullException("inputs_x");
            }

            if (inputs_y == null)
            {
                throw new ArgumentNullException("inputs_y");
            }

            if (inputs_x.Count != inputs_y.Count)
            {
                throw new ArgumentException("inputs_x and inputs_y must have the same number of elements.");
            }

            if (maxCount > 0 && inputs_x.Count == 0)
            {
                throw new ArgumentException("Cannot sample from an empty data set.", "inputs_x");
            }

            List<Tuple<float[], float>> lst = new List<Tuple<float[], float>>(Math.Max(maxCount, 0));
            for (int count = 0; count < maxCount; count++)
            {
                int rndIndex = s_random.Next(inputs_x.Count);
                lst.Add(Tuple.Create<float[], float>(inputs_x[rndIndex], inputs_y[rndIndex]));
            }

            return lst;
        }

        /// <summary>Writes the learned equation to the console.</summary>
        private static void print(float theta1, float theta2, float b)
        {
            Console.WriteLine("y={0}*x1+{1}*x2+{2}", theta1, theta2, b);
        }

        /// <summary>The linear model: theta1*x1 + theta2*x2 + b.</summary>
        private static float fun(float x1, float x2, float theta1, float theta2, float b)
        {
            return theta1 * x1 + theta2 * x2 + b;
        }

        /// <summary>Partial derivative of <see cref="fun"/> with respect to theta1.</summary>
        private static float dy_theta1(float x1, float x2)
        {
            return x1;
        }

        /// <summary>Partial derivative of <see cref="fun"/> with respect to theta2.</summary>
        private static float dy_theta2(float x1, float x2)
        {
            return x2;
        }

        /// <summary>Partial derivative of <see cref="fun"/> with respect to b.</summary>
        private static float dy_b(float x1, float x2)
        {
            return 1;
        }
    }
}
自省推动进步,视野决定未来。
心怀远大理想。
为了家庭幸福而努力。
商业合作请看此处:https://www.magicube.ai
心怀远大理想。
为了家庭幸福而努力。
商业合作请看此处:https://www.magicube.ai