Perceptron (unfinished)
This commit is contained in:
parent
84a1b61477
commit
5f9776c4ef
@ -1,6 +1,7 @@
|
|||||||
package machine_learning.perceptron;
|
package machine_learning.perceptron;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
public class Perceptron
|
public class Perceptron
|
||||||
{
|
{
|
||||||
@ -36,7 +37,7 @@ public class Perceptron
|
|||||||
}
|
}
|
||||||
|
|
||||||
private Vector getInitializationVector(List<Vector> positives, List<Vector> negatives)
|
private Vector getInitializationVector(List<Vector> positives, List<Vector> negatives)
|
||||||
{/*
|
{
|
||||||
var a = new Vector(positives.get(0).dimension());
|
var a = new Vector(positives.get(0).dimension());
|
||||||
for (var x : positives)
|
for (var x : positives)
|
||||||
{
|
{
|
||||||
@ -49,9 +50,7 @@ public class Perceptron
|
|||||||
b = b.add(x);
|
b = b.add(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
return a.subtract(b);*/
|
return a.subtract(b);
|
||||||
|
|
||||||
return new Vector(positives.get(0).dimension());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean elementsAreCorrectClassified(List<Vector> vectors, Vector weight, int expectedClass)
|
private boolean elementsAreCorrectClassified(List<Vector> vectors, Vector weight, int expectedClass)
|
||||||
@ -60,7 +59,7 @@ public class Perceptron
|
|||||||
|
|
||||||
for (var x : vectors)
|
for (var x : vectors)
|
||||||
{
|
{
|
||||||
actualClass = weight.scalar(x) > 0 ? 1 : 0;
|
actualClass = scalarForThreshholdPerceptron(weight, x) > weight.get(weight.dimension()-1) ? 1 : 0;
|
||||||
|
|
||||||
if (actualClass != expectedClass)
|
if (actualClass != expectedClass)
|
||||||
{
|
{
|
||||||
@ -70,4 +69,12 @@ public class Perceptron
|
|||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private double scalarForThreshholdPerceptron(Vector a, Vector b)
|
||||||
|
{
|
||||||
|
return IntStream.range(0,
|
||||||
|
a.dimension()-1)
|
||||||
|
.mapToDouble(i -> a.get(i) * b.get(i))
|
||||||
|
.sum();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ public class Vector
|
|||||||
public Vector add(Vector b)
|
public Vector add(Vector b)
|
||||||
{
|
{
|
||||||
return new Vector(IntStream.range(0,
|
return new Vector(IntStream.range(0,
|
||||||
this.values.size())
|
this.dimension())
|
||||||
.mapToObj(i -> this.values.get(i) + b.values.get(i))
|
.mapToObj(i -> this.values.get(i) + b.values.get(i))
|
||||||
.collect(Collectors.toCollection(ArrayList::new))
|
.collect(Collectors.toCollection(ArrayList::new))
|
||||||
);
|
);
|
||||||
@ -43,7 +43,7 @@ public class Vector
|
|||||||
public Vector subtract(Vector b)
|
public Vector subtract(Vector b)
|
||||||
{
|
{
|
||||||
return new Vector(IntStream.range(0,
|
return new Vector(IntStream.range(0,
|
||||||
this.values.size())
|
this.dimension())
|
||||||
.mapToObj(i -> this.values.get(i) - b.values.get(i))
|
.mapToObj(i -> this.values.get(i) - b.values.get(i))
|
||||||
.collect(Collectors.toCollection(ArrayList::new))
|
.collect(Collectors.toCollection(ArrayList::new))
|
||||||
);
|
);
|
||||||
@ -52,11 +52,16 @@ public class Vector
|
|||||||
public double scalar(Vector b)
|
public double scalar(Vector b)
|
||||||
{
|
{
|
||||||
return IntStream.range(0,
|
return IntStream.range(0,
|
||||||
this.values.size())
|
this.dimension())
|
||||||
.mapToDouble(i -> this.values.get(i) * b.values.get(i))
|
.mapToDouble(i -> this.values.get(i) * b.values.get(i))
|
||||||
.sum();
|
.sum();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public double get(int index)
|
||||||
|
{
|
||||||
|
return this.values.get(index);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object o)
|
public boolean equals(Object o)
|
||||||
{
|
{
|
||||||
|
@ -16,18 +16,19 @@ class PerceptronTest
|
|||||||
@BeforeAll
|
@BeforeAll
|
||||||
void initLearnData()
|
void initLearnData()
|
||||||
{
|
{
|
||||||
|
double biasUnit = 1d;
|
||||||
this.positives = new ArrayList<>(List.of(
|
this.positives = new ArrayList<>(List.of(
|
||||||
new Vector(List.of(8d, 4d)),
|
new Vector(List.of(8d, 4d, biasUnit)),
|
||||||
new Vector(List.of(8d, 6d)),
|
new Vector(List.of(8d, 6d, biasUnit)),
|
||||||
new Vector(List.of(9d, 2d)),
|
new Vector(List.of(9d, 2d, biasUnit)),
|
||||||
new Vector(List.of(9d, 5d)))
|
new Vector(List.of(9d, 5d, biasUnit)))
|
||||||
);
|
);
|
||||||
|
|
||||||
this.negatives = new ArrayList<>(List.of(
|
this.negatives = new ArrayList<>(List.of(
|
||||||
new Vector(List.of(6d, 1d)),
|
new Vector(List.of(6d, 1d, biasUnit)),
|
||||||
new Vector(List.of(7d, 3d)),
|
new Vector(List.of(7d, 3d, biasUnit)),
|
||||||
new Vector(List.of(8d, 2d)),
|
new Vector(List.of(8d, 2d, biasUnit)),
|
||||||
new Vector(List.of(9d, 0d)))
|
new Vector(List.of(9d, 0d, biasUnit)))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user