diff --git a/src/machine_learning/perceptron/Perceptron.java b/src/machine_learning/perceptron/Perceptron.java index a69ed94..ec27c11 100644 --- a/src/machine_learning/perceptron/Perceptron.java +++ b/src/machine_learning/perceptron/Perceptron.java @@ -15,7 +15,7 @@ public class Perceptron { if (weight.scalar(x) <= 0) { - weight = weight.add(x.divide(x.euclid())); + weight = weight.add(x); } } @@ -29,11 +29,11 @@ public class Perceptron System.out.println(weight); } - while (elementsAreCorrectClassified(positives, weight, 1) && elementsAreCorrectClassified(negatives, weight, 0)); + while (!elementsAreCorrectClassified(positives, negatives, weight)); - System.out.println("-----------------------------------"); - System.out.println("-- All are classified correctly. --"); - System.out.println("-----------------------------------"); + System.out.println("----------------------------------------------"); + System.out.println("-- All datapoints are classified correctly. --"); + System.out.println("----------------------------------------------"); } private Vector getInitializationVector(List positives, List negatives) @@ -53,28 +53,18 @@ public class Perceptron return a.subtract(b); } - private boolean elementsAreCorrectClassified(List vectors, Vector weight, int expectedClass) + private boolean elementsAreCorrectClassified(List positives, List negatives, Vector weight) { - int actualClass; - - for (var x : vectors) + for (var x : positives) { - actualClass = weight.scalar(x) > 0 ? 1 : 0; + if (weight.scalar(x) <= 0) return false; + } - if (actualClass != expectedClass) - { - return false; - } + for (var x : negatives) + { + if (weight.scalar(x) > 0) return false; } return true; } - - private double scalarForThresholdPerceptron(Vector a, Vector b) - { - return IntStream.range(0, - a.dimension()-1) - .mapToDouble(i -> a.get(i) * b.get(i)) - .sum(); - } } diff --git a/src/machine_learning/perceptron/Vector.java b/src/machine_learning/perceptron/Vector.java index 9bfe96e..dbbcabb 100644 --- a/src/machine_learning/perceptron/Vector.java +++ b/src/machine_learning/perceptron/Vector.java @@ -1,8 +1,6 @@ package machine_learning.perceptron; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; +import java.util.*; import java.util.stream.Collectors; import java.util.stream.IntStream; diff --git a/test/machine_learning/perceptron/PerceptronTest.java b/test/machine_learning/perceptron/PerceptronTest.java index 3c376e1..5acce2e 100644 --- a/test/machine_learning/perceptron/PerceptronTest.java +++ b/test/machine_learning/perceptron/PerceptronTest.java @@ -37,20 +37,4 @@ class PerceptronTest { new Perceptron().learn(this.positives, this.negatives); } - - @Test - void shouldClassifyCorrect2() - { - var positives = new ArrayList<>(List.of( - new Vector(List.of(0d, 1.8d)), - new Vector(List.of(2d, 0.6d))) - ); - - var negatives = new ArrayList<>(List.of( - new Vector(List.of(-1.2d, 1.4d)), - new Vector(List.of(0.4d, -1d))) - ); - - new Perceptron().learn(positives, negatives); - } } \ No newline at end of file