finished
This commit is contained in:
parent
c44e53707d
commit
248593af6c
@ -15,7 +15,7 @@ public class Perceptron
|
||||
{
|
||||
if (weight.scalar(x) <= 0)
|
||||
{
|
||||
weight = weight.add(x.divide(x.euclid()));
|
||||
weight = weight.add(x);
|
||||
}
|
||||
}
|
||||
|
||||
@ -29,11 +29,11 @@ public class Perceptron
|
||||
|
||||
System.out.println(weight);
|
||||
}
|
||||
while (elementsAreCorrectClassified(positives, weight, 1) && elementsAreCorrectClassified(negatives, weight, 0));
|
||||
while (!elementsAreCorrectClassified(positives, negatives, weight));
|
||||
|
||||
System.out.println("-----------------------------------");
|
||||
System.out.println("-- All are classified correctly. --");
|
||||
System.out.println("-----------------------------------");
|
||||
System.out.println("----------------------------------------------");
|
||||
System.out.println("-- All datapoints are classified correctly. --");
|
||||
System.out.println("----------------------------------------------");
|
||||
}
|
||||
|
||||
private Vector getInitializationVector(List<Vector> positives, List<Vector> negatives)
|
||||
@ -53,28 +53,18 @@ public class Perceptron
|
||||
return a.subtract(b);
|
||||
}
|
||||
|
||||
private boolean elementsAreCorrectClassified(List<Vector> vectors, Vector weight, int expectedClass)
|
||||
private boolean elementsAreCorrectClassified(List<Vector> positives, List<Vector> negatives, Vector weight)
|
||||
{
|
||||
int actualClass;
|
||||
|
||||
for (var x : vectors)
|
||||
for (var x : positives)
|
||||
{
|
||||
actualClass = weight.scalar(x) > 0 ? 1 : 0;
|
||||
|
||||
if (actualClass != expectedClass)
|
||||
{
|
||||
return false;
|
||||
if (weight.scalar(x) <= 0) return false;
|
||||
}
|
||||
|
||||
for (var x : negatives)
|
||||
{
|
||||
if (weight.scalar(x) > 0) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private double scalarForThresholdPerceptron(Vector a, Vector b)
|
||||
{
|
||||
return IntStream.range(0,
|
||||
a.dimension()-1)
|
||||
.mapToDouble(i -> a.get(i) * b.get(i))
|
||||
.sum();
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,6 @@
|
||||
package machine_learning.perceptron;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
|
@ -37,20 +37,4 @@ class PerceptronTest
|
||||
{
|
||||
new Perceptron().learn(this.positives, this.negatives);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldClassifyCorrect2()
|
||||
{
|
||||
var positives = new ArrayList<>(List.of(
|
||||
new Vector(List.of(0d, 1.8d)),
|
||||
new Vector(List.of(2d, 0.6d)))
|
||||
);
|
||||
|
||||
var negatives = new ArrayList<>(List.of(
|
||||
new Vector(List.of(-1.2d, 1.4d)),
|
||||
new Vector(List.of(0.4d, -1d)))
|
||||
);
|
||||
|
||||
new Perceptron().learn(positives, negatives);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user