Perceptron (unfinished)
This commit is contained in:
parent
5f9776c4ef
commit
c44e53707d
@ -15,7 +15,7 @@ public class Perceptron
|
|||||||
{
|
{
|
||||||
if (weight.scalar(x) <= 0)
|
if (weight.scalar(x) <= 0)
|
||||||
{
|
{
|
||||||
weight = weight.add(x);
|
weight = weight.add(x.divide(x.euclid()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,7 +59,7 @@ public class Perceptron
|
|||||||
|
|
||||||
for (var x : vectors)
|
for (var x : vectors)
|
||||||
{
|
{
|
||||||
actualClass = scalarForThreshholdPerceptron(weight, x) > weight.get(weight.dimension()-1) ? 1 : 0;
|
actualClass = weight.scalar(x) > 0 ? 1 : 0;
|
||||||
|
|
||||||
if (actualClass != expectedClass)
|
if (actualClass != expectedClass)
|
||||||
{
|
{
|
||||||
@ -70,7 +70,7 @@ public class Perceptron
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private double scalarForThreshholdPerceptron(Vector a, Vector b)
|
private double scalarForThresholdPerceptron(Vector a, Vector b)
|
||||||
{
|
{
|
||||||
return IntStream.range(0,
|
return IntStream.range(0,
|
||||||
a.dimension()-1)
|
a.dimension()-1)
|
||||||
|
@ -34,8 +34,8 @@ public class Vector
|
|||||||
{
|
{
|
||||||
return new Vector(IntStream.range(0,
|
return new Vector(IntStream.range(0,
|
||||||
this.dimension())
|
this.dimension())
|
||||||
.mapToObj(i -> this.values.get(i) + b.values.get(i))
|
.mapToObj(i -> this.get(i) + b.get(i))
|
||||||
.collect(Collectors.toCollection(ArrayList::new))
|
.collect(Collectors.toList())
|
||||||
);
|
);
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -44,8 +44,8 @@ public class Vector
|
|||||||
{
|
{
|
||||||
return new Vector(IntStream.range(0,
|
return new Vector(IntStream.range(0,
|
||||||
this.dimension())
|
this.dimension())
|
||||||
.mapToObj(i -> this.values.get(i) - b.values.get(i))
|
.mapToObj(i -> this.get(i) - b.get(i))
|
||||||
.collect(Collectors.toCollection(ArrayList::new))
|
.collect(Collectors.toList())
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,10 +53,30 @@ public class Vector
|
|||||||
{
|
{
|
||||||
return IntStream.range(0,
|
return IntStream.range(0,
|
||||||
this.dimension())
|
this.dimension())
|
||||||
.mapToDouble(i -> this.values.get(i) * b.values.get(i))
|
.mapToDouble(i -> this.get(i) * b.get(i))
|
||||||
.sum();
|
.sum();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public double euclid()
|
||||||
|
{
|
||||||
|
return Math.sqrt(IntStream.range(0,
|
||||||
|
this.dimension())
|
||||||
|
.mapToDouble(i -> this.get(i) * this.get(i))
|
||||||
|
.sum());
|
||||||
|
}
|
||||||
|
|
||||||
|
public Vector divide(double div)
|
||||||
|
{
|
||||||
|
var divided = new ArrayList<Double>();
|
||||||
|
|
||||||
|
for (int i = 0; i < this.dimension(); i++)
|
||||||
|
{
|
||||||
|
divided.add(this.values.get(i) / div);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Vector(divided);
|
||||||
|
}
|
||||||
|
|
||||||
public double get(int index)
|
public double get(int index)
|
||||||
{
|
{
|
||||||
return this.values.get(index);
|
return this.values.get(index);
|
||||||
@ -88,8 +108,6 @@ public class Vector
|
|||||||
@Override
|
@Override
|
||||||
public String toString()
|
public String toString()
|
||||||
{
|
{
|
||||||
return values.toString()
|
return this.values.toString();
|
||||||
.replace("[", "(")
|
|
||||||
.replace("]", ")");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -37,4 +37,20 @@ class PerceptronTest
|
|||||||
{
|
{
|
||||||
new Perceptron().learn(this.positives, this.negatives);
|
new Perceptron().learn(this.positives, this.negatives);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldClassifyCorrect2()
|
||||||
|
{
|
||||||
|
var positives = new ArrayList<>(List.of(
|
||||||
|
new Vector(List.of(0d, 1.8d)),
|
||||||
|
new Vector(List.of(2d, 0.6d)))
|
||||||
|
);
|
||||||
|
|
||||||
|
var negatives = new ArrayList<>(List.of(
|
||||||
|
new Vector(List.of(-1.2d, 1.4d)),
|
||||||
|
new Vector(List.of(0.4d, -1d)))
|
||||||
|
);
|
||||||
|
|
||||||
|
new Perceptron().learn(positives, negatives);
|
||||||
|
}
|
||||||
}
|
}
|
@ -45,7 +45,7 @@ class VectorTest
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void shouldReturnCorrectVectorWhenScalar()
|
void shouldReturnCorrectVectorWhenScalarMultiplying()
|
||||||
{
|
{
|
||||||
var v1 = new Vector(List.of(1d, 2d));
|
var v1 = new Vector(List.of(1d, 2d));
|
||||||
var v2 = new Vector(List.of(3d, 4d));
|
var v2 = new Vector(List.of(3d, 4d));
|
||||||
@ -55,4 +55,27 @@ class VectorTest
|
|||||||
|
|
||||||
assertEquals(expected, result);
|
assertEquals(expected, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldReturnCorrectVectorWhenEuclid()
|
||||||
|
{
|
||||||
|
var v1 = new Vector(List.of(1d, 2d));
|
||||||
|
|
||||||
|
var result = v1.euclid();
|
||||||
|
var expected = Math.sqrt(5);
|
||||||
|
|
||||||
|
assertEquals(expected, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldReturnCorrectVectorWhenDividing()
|
||||||
|
{
|
||||||
|
var v1 = new Vector(List.of(1d, 2d));
|
||||||
|
var div = 2d;
|
||||||
|
|
||||||
|
var result = v1.divide(div);
|
||||||
|
var expected = new Vector(List.of(0.5d, 1d));
|
||||||
|
|
||||||
|
assertEquals(expected, result);
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user