diff --git a/src/machine_learning/perceptron/Perceptron.java b/src/machine_learning/perceptron/Perceptron.java index b2e29ba..a69ed94 100644 --- a/src/machine_learning/perceptron/Perceptron.java +++ b/src/machine_learning/perceptron/Perceptron.java @@ -15,7 +15,7 @@ public class Perceptron { if (weight.scalar(x) <= 0) { - weight = weight.add(x); + weight = weight.add(x.divide(x.euclid())); } } @@ -59,7 +59,7 @@ public class Perceptron for (var x : vectors) { - actualClass = scalarForThreshholdPerceptron(weight, x) > weight.get(weight.dimension()-1) ? 1 : 0; + actualClass = weight.scalar(x) > 0 ? 1 : 0; if (actualClass != expectedClass) { @@ -70,7 +70,7 @@ public class Perceptron return true; } - private double scalarForThreshholdPerceptron(Vector a, Vector b) + private double scalarForThresholdPerceptron(Vector a, Vector b) { return IntStream.range(0, a.dimension()-1) diff --git a/src/machine_learning/perceptron/Vector.java b/src/machine_learning/perceptron/Vector.java index 62ecc26..9bfe96e 100644 --- a/src/machine_learning/perceptron/Vector.java +++ b/src/machine_learning/perceptron/Vector.java @@ -34,8 +34,8 @@ public class Vector { return new Vector(IntStream.range(0, this.dimension()) - .mapToObj(i -> this.values.get(i) + b.values.get(i)) - .collect(Collectors.toCollection(ArrayList::new)) + .mapToObj(i -> this.get(i) + b.get(i)) + .collect(Collectors.toList()) ); } @@ -44,8 +44,8 @@ public class Vector { return new Vector(IntStream.range(0, this.dimension()) - .mapToObj(i -> this.values.get(i) - b.values.get(i)) - .collect(Collectors.toCollection(ArrayList::new)) + .mapToObj(i -> this.get(i) - b.get(i)) + .collect(Collectors.toList()) ); } @@ -53,10 +53,30 @@ public class Vector { return IntStream.range(0, this.dimension()) - .mapToDouble(i -> this.values.get(i) * b.values.get(i)) + .mapToDouble(i -> this.get(i) * b.get(i)) .sum(); } + public double euclid() + { + return Math.sqrt(IntStream.range(0, + this.dimension()) + .mapToDouble(i -> this.get(i) * this.get(i)) + .sum()); + } + + public Vector divide(double div) + { + var divided = new ArrayList(); + + for (int i = 0; i < this.dimension(); i++) + { + divided.add(this.values.get(i) / div); + } + + return new Vector(divided); + } + public double get(int index) { return this.values.get(index); @@ -88,8 +108,6 @@ public class Vector @Override public String toString() { - return values.toString() - .replace("[", "(") - .replace("]", ")"); + return this.values.toString(); } } diff --git a/test/machine_learning/perceptron/PerceptronTest.java b/test/machine_learning/perceptron/PerceptronTest.java index 5acce2e..3c376e1 100644 --- a/test/machine_learning/perceptron/PerceptronTest.java +++ b/test/machine_learning/perceptron/PerceptronTest.java @@ -37,4 +37,20 @@ class PerceptronTest { new Perceptron().learn(this.positives, this.negatives); } + + @Test + void shouldClassifyCorrect2() + { + var positives = new ArrayList<>(List.of( + new Vector(List.of(0d, 1.8d)), + new Vector(List.of(2d, 0.6d))) + ); + + var negatives = new ArrayList<>(List.of( + new Vector(List.of(-1.2d, 1.4d)), + new Vector(List.of(0.4d, -1d))) + ); + + new Perceptron().learn(positives, negatives); + } } \ No newline at end of file diff --git a/test/machine_learning/perceptron/VectorTest.java b/test/machine_learning/perceptron/VectorTest.java index e2c7514..bf43109 100644 --- a/test/machine_learning/perceptron/VectorTest.java +++ b/test/machine_learning/perceptron/VectorTest.java @@ -45,7 +45,7 @@ class VectorTest } @Test - void shouldReturnCorrectVectorWhenScalar() + void shouldReturnCorrectVectorWhenScalarMultiplying() { var v1 = new Vector(List.of(1d, 2d)); var v2 = new Vector(List.of(3d, 4d)); @@ -55,4 +55,27 @@ class VectorTest assertEquals(expected, result); } + + @Test + void shouldReturnCorrectVectorWhenEuclid() + { + var v1 = new Vector(List.of(1d, 2d)); + + var result = v1.euclid(); + var expected = Math.sqrt(5); + + assertEquals(expected, result); + } + + @Test + void shouldReturnCorrectVectorWhenDividing() + { + var v1 = new Vector(List.of(1d, 2d)); + var div = 2d; + + var result = v1.divide(div); + var expected = new Vector(List.of(0.5d, 1d)); + + assertEquals(expected, result); + } } \ No newline at end of file