diff --git a/src/machine_learning/perceptron/Perceptron.java b/src/machine_learning/perceptron/Perceptron.java index 089d2cf..b2e29ba 100644 --- a/src/machine_learning/perceptron/Perceptron.java +++ b/src/machine_learning/perceptron/Perceptron.java @@ -1,6 +1,7 @@ package machine_learning.perceptron; import java.util.List; +import java.util.stream.IntStream; public class Perceptron { @@ -36,7 +37,7 @@ public class Perceptron } private Vector getInitializationVector(List positives, List negatives) - {/* + { var a = new Vector(positives.get(0).dimension()); for (var x : positives) { @@ -49,9 +50,7 @@ public class Perceptron b = b.add(x); } - return a.subtract(b);*/ - - return new Vector(positives.get(0).dimension()); + return a.subtract(b); } private boolean elementsAreCorrectClassified(List vectors, Vector weight, int expectedClass) @@ -60,7 +59,7 @@ public class Perceptron for (var x : vectors) { - actualClass = weight.scalar(x) > 0 ? 1 : 0; + actualClass = scalarForThreshholdPerceptron(weight, x) > weight.get(weight.dimension()-1) ? 1 : 0; if (actualClass != expectedClass) { @@ -70,4 +69,12 @@ public class Perceptron return true; } + + private double scalarForThreshholdPerceptron(Vector a, Vector b) + { + return IntStream.range(0, + a.dimension()-1) + .mapToDouble(i -> a.get(i) * b.get(i)) + .sum(); + } } diff --git a/src/machine_learning/perceptron/Vector.java b/src/machine_learning/perceptron/Vector.java index 27b416a..62ecc26 100644 --- a/src/machine_learning/perceptron/Vector.java +++ b/src/machine_learning/perceptron/Vector.java @@ -33,7 +33,7 @@ public class Vector public Vector add(Vector b) { return new Vector(IntStream.range(0, - this.values.size()) + this.dimension()) .mapToObj(i -> this.values.get(i) + b.values.get(i)) .collect(Collectors.toCollection(ArrayList::new)) ); @@ -43,7 +43,7 @@ public class Vector public Vector subtract(Vector b) { return new Vector(IntStream.range(0, - this.values.size()) + this.dimension()) .mapToObj(i -> this.values.get(i) - b.values.get(i)) .collect(Collectors.toCollection(ArrayList::new)) ); @@ -52,11 +52,16 @@ public class Vector public double scalar(Vector b) { return IntStream.range(0, - this.values.size()) + this.dimension()) .mapToDouble(i -> this.values.get(i) * b.values.get(i)) .sum(); } + public double get(int index) + { + return this.values.get(index); + } + @Override public boolean equals(Object o) { diff --git a/test/machine_learning/perceptron/PerceptronTest.java b/test/machine_learning/perceptron/PerceptronTest.java index 2c98b93..5acce2e 100644 --- a/test/machine_learning/perceptron/PerceptronTest.java +++ b/test/machine_learning/perceptron/PerceptronTest.java @@ -16,18 +16,19 @@ class PerceptronTest @BeforeAll void initLearnData() { + double biasUnit = 1d; this.positives = new ArrayList<>(List.of( - new Vector(List.of(8d, 4d)), - new Vector(List.of(8d, 6d)), - new Vector(List.of(9d, 2d)), - new Vector(List.of(9d, 5d))) + new Vector(List.of(8d, 4d, biasUnit)), + new Vector(List.of(8d, 6d, biasUnit)), + new Vector(List.of(9d, 2d, biasUnit)), + new Vector(List.of(9d, 5d, biasUnit))) ); this.negatives = new ArrayList<>(List.of( - new Vector(List.of(6d, 1d)), - new Vector(List.of(7d, 3d)), - new Vector(List.of(8d, 2d)), - new Vector(List.of(9d, 0d))) + new Vector(List.of(6d, 1d, biasUnit)), + new Vector(List.of(7d, 3d, biasUnit)), + new Vector(List.of(8d, 2d, biasUnit)), + new Vector(List.of(9d, 0d, biasUnit))) ); }