From f735fa931d6501bbe52e357568ff33617862b401 Mon Sep 17 00:00:00 2001 From: Eric Lin Date: Fri, 21 Oct 2022 18:02:39 -0400 Subject: [PATCH] create LinearRegression --- .../math/Regression/LinearRegression.java | 42 +++++++++++++++++++ .../stuylib/util/plot/Playground.java | 38 +++++++++++------ 2 files changed, 68 insertions(+), 12 deletions(-) create mode 100644 src/com/stuypulse/stuylib/math/Regression/LinearRegression.java diff --git a/src/com/stuypulse/stuylib/math/Regression/LinearRegression.java b/src/com/stuypulse/stuylib/math/Regression/LinearRegression.java new file mode 100644 index 00000000..b96a2c7d --- /dev/null +++ b/src/com/stuypulse/stuylib/math/Regression/LinearRegression.java @@ -0,0 +1,42 @@ +package com.stuypulse.stuylib.math.Regression; + +import com.stuypulse.stuylib.math.Vector2D; + +// sum of x +public class LinearRegression { + + public Vector2D[] points; + + public LinearRegression(Vector2D... points){ + this.points = points; + } + + private Vector2D equation(){ + double sumX = 0; + double sumY = 0; + double sumXX = 0; + double sumXY = 0; + + for(Vector2D point : points) { + sumX += point.x; + sumY += point.y; + sumXX += Math.pow(point.x, 2); + sumXY += point.x * point.y; + + } + + double intercept = (sumY * sumXX - sumX * sumXY) / (sumXX - Math.pow(sumX, 2)); + double slope = (points.length * sumXY - sumX * sumY) / (points.length * sumXX - Math.pow(sumX,2)); + + return new Vector2D(slope, intercept); + } + + public void addRefPoint(double x, double y){ + + } + + public double predictedValue(double x){ + return equation().x * x + equation().y; + } + +} diff --git a/src/com/stuypulse/stuylib/util/plot/Playground.java b/src/com/stuypulse/stuylib/util/plot/Playground.java index 7e045404..b9e1c939 100644 --- a/src/com/stuypulse/stuylib/util/plot/Playground.java +++ b/src/com/stuypulse/stuylib/util/plot/Playground.java @@ -5,6 +5,7 @@ package com.stuypulse.stuylib.util.plot; import com.stuypulse.stuylib.math.Vector2D; +import com.stuypulse.stuylib.math.Regression.LinearRegression; import com.stuypulse.stuylib.math.interpolation.*; import com.stuypulse.stuylib.streams.*; import com.stuypulse.stuylib.streams.angles.AFuser; @@ -34,11 +35,11 @@ public interface Constants { int WIDTH = 800; int HEIGHT = 600; - double MIN_X = -1.0; - double MAX_X = 1.0; + double MIN_X = 0.0; + double MAX_X = 10.0; - double MIN_Y = -1.0; - double MAX_Y = 1.0; + double MIN_Y = 0.0; + double MAX_Y = 10.0; Settings SETTINGS = new Settings() @@ -69,6 +70,11 @@ public static Series make(String id, BStream series) { } public static void main(String[] args) throws InterruptedException { + + + + + Plot plot = new Plot(Constants.SETTINGS); VStream m = VStream.create(() -> plot.getMouse().sub(new Vector2D(0.5, 0.5)).mul(2)); @@ -95,14 +101,22 @@ public static void main(String[] args) throws InterruptedException { VStream delay = VStream.create(() -> delay_angle.get().getVector().mul(1.1)); VStream afuser = VStream.create(() -> afuser_angle.get().getVector().mul(1.05)); - plot.addSeries(Constants.make("Angle", mouse)) - .addSeries(Constants.make("Jerk", jerk)) - .addSeries(Constants.make("Rate", rate)) - .addSeries(Constants.make("LPF", lpf)) - .addSeries(Constants.make("HPF", hpf)) - .addSeries(Constants.make("afuser", afuser)) - .addSeries(Constants.make("delayed", delay)) - .addSeries(Constants.make("mouse", m)); + Vector2D[] refPoints = {new Vector2D(1,2), new Vector2D(2,5), new Vector2D(6, 10)}; + LinearRegression LinearRegression = new LinearRegression(refPoints); + plot.addSeries(new FuncSeries( + new Config("linear regression", 1000), + new Domain(0, 10), + x -> LinearRegression.predictedValue(x) + )); + + // plot.addSeries(Constants.make("Angle", mouse)) + // .addSeries(Constants.make("Jerk", jerk)) + // .addSeries(Constants.make("Rate", rate)) + // .addSeries(Constants.make("LPF", lpf)) + // .addSeries(Constants.make("HPF", hpf)) + // .addSeries(Constants.make("afuser", afuser)) + // .addSeries(Constants.make("delayed", delay)) + // .addSeries(Constants.make("mouse", m)); // // .addSeries(Constants.make("y=x", x -> x)) // .addSeries(