This commit is contained in:
Niklas Birk 2019-06-24 14:23:27 +02:00
parent c44e53707d
commit 248593af6c
3 changed files with 13 additions and 41 deletions

View File

@ -15,7 +15,7 @@ public class Perceptron
{ {
if (weight.scalar(x) <= 0) if (weight.scalar(x) <= 0)
{ {
weight = weight.add(x.divide(x.euclid())); weight = weight.add(x);
} }
} }
@ -29,11 +29,11 @@ public class Perceptron
System.out.println(weight); System.out.println(weight);
} }
while (elementsAreCorrectClassified(positives, weight, 1) && elementsAreCorrectClassified(negatives, weight, 0)); while (!elementsAreCorrectClassified(positives, negatives, weight));
System.out.println("-----------------------------------"); System.out.println("----------------------------------------------");
System.out.println("-- All are classified correctly. --"); System.out.println("-- All datapoints are classified correctly. --");
System.out.println("-----------------------------------"); System.out.println("----------------------------------------------");
} }
private Vector getInitializationVector(List<Vector> positives, List<Vector> negatives) private Vector getInitializationVector(List<Vector> positives, List<Vector> negatives)
@ -53,28 +53,18 @@ public class Perceptron
return a.subtract(b); return a.subtract(b);
} }
private boolean elementsAreCorrectClassified(List<Vector> vectors, Vector weight, int expectedClass) private boolean elementsAreCorrectClassified(List<Vector> positives, List<Vector> negatives, Vector weight)
{ {
int actualClass; for (var x : positives)
for (var x : vectors)
{ {
actualClass = weight.scalar(x) > 0 ? 1 : 0; if (weight.scalar(x) <= 0) return false;
}
if (actualClass != expectedClass) for (var x : negatives)
{ {
return false; if (weight.scalar(x) > 0) return false;
}
} }
return true; return true;
} }
private double scalarForThresholdPerceptron(Vector a, Vector b)
{
return IntStream.range(0,
a.dimension()-1)
.mapToDouble(i -> a.get(i) * b.get(i))
.sum();
}
} }

View File

@ -1,8 +1,6 @@
package machine_learning.perceptron; package machine_learning.perceptron;
import java.util.ArrayList; import java.util.*;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;

View File

@ -37,20 +37,4 @@ class PerceptronTest
{ {
new Perceptron().learn(this.positives, this.negatives); new Perceptron().learn(this.positives, this.negatives);
} }
@Test
void shouldClassifyCorrect2()
{
var positives = new ArrayList<>(List.of(
new Vector(List.of(0d, 1.8d)),
new Vector(List.of(2d, 0.6d)))
);
var negatives = new ArrayList<>(List.of(
new Vector(List.of(-1.2d, 1.4d)),
new Vector(List.of(0.4d, -1d)))
);
new Perceptron().learn(positives, negatives);
}
} }