This commit is contained in:
Niklas Birk 2019-06-24 14:23:27 +02:00
parent c44e53707d
commit 248593af6c
3 changed files with 13 additions and 41 deletions

View File

@ -15,7 +15,7 @@ public class Perceptron
{
if (weight.scalar(x) <= 0)
{
weight = weight.add(x.divide(x.euclid()));
weight = weight.add(x);
}
}
@ -29,11 +29,11 @@ public class Perceptron
System.out.println(weight);
}
while (elementsAreCorrectClassified(positives, weight, 1) && elementsAreCorrectClassified(negatives, weight, 0));
while (!elementsAreCorrectClassified(positives, negatives, weight));
System.out.println("-----------------------------------");
System.out.println("-- All are classified correctly. --");
System.out.println("-----------------------------------");
System.out.println("----------------------------------------------");
System.out.println("-- All datapoints are classified correctly. --");
System.out.println("----------------------------------------------");
}
private Vector getInitializationVector(List<Vector> positives, List<Vector> negatives)
@ -53,28 +53,18 @@ public class Perceptron
return a.subtract(b);
}
private boolean elementsAreCorrectClassified(List<Vector> vectors, Vector weight, int expectedClass)
private boolean elementsAreCorrectClassified(List<Vector> positives, List<Vector> negatives, Vector weight)
{
int actualClass;
for (var x : vectors)
for (var x : positives)
{
actualClass = weight.scalar(x) > 0 ? 1 : 0;
if (weight.scalar(x) <= 0) return false;
}
if (actualClass != expectedClass)
{
return false;
}
for (var x : negatives)
{
if (weight.scalar(x) > 0) return false;
}
return true;
}
private double scalarForThresholdPerceptron(Vector a, Vector b)
{
return IntStream.range(0,
a.dimension()-1)
.mapToDouble(i -> a.get(i) * b.get(i))
.sum();
}
}

View File

@ -1,8 +1,6 @@
package machine_learning.perceptron;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

View File

@ -37,20 +37,4 @@ class PerceptronTest
{
new Perceptron().learn(this.positives, this.negatives);
}
@Test
void shouldClassifyCorrect2()
{
var positives = new ArrayList<>(List.of(
new Vector(List.of(0d, 1.8d)),
new Vector(List.of(2d, 0.6d)))
);
var negatives = new ArrayList<>(List.of(
new Vector(List.of(-1.2d, 1.4d)),
new Vector(List.of(0.4d, -1d)))
);
new Perceptron().learn(positives, negatives);
}
}