Minor changes in names and extracted method, etc.

This commit is contained in:
Niklas Birk 2021-11-30 22:27:45 +01:00
parent d0adbf0acb
commit 24c9e0b247
4 changed files with 26 additions and 22 deletions

View File

@ -34,7 +34,7 @@ public class Vector
public Vector add(Vector b) public Vector add(Vector b)
{ {
if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals."); checkEqualDimensions(b);
return new Vector(IntStream.range(0, return new Vector(IntStream.range(0,
this.dimension()) this.dimension())
.mapToObj(i -> this.get(i) + b.get(i)) .mapToObj(i -> this.get(i) + b.get(i))
@ -45,7 +45,7 @@ public class Vector
public Vector subtract(Vector b) public Vector subtract(Vector b)
{ {
if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals."); checkEqualDimensions(b);
return new Vector(IntStream.range(0, return new Vector(IntStream.range(0,
this.dimension()) this.dimension())
.mapToObj(i -> this.get(i) - b.get(i)) .mapToObj(i -> this.get(i) - b.get(i))
@ -55,7 +55,7 @@ public class Vector
public double scalar(Vector b) public double scalar(Vector b)
{ {
if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals."); checkEqualDimensions(b);
return IntStream.range(0, return IntStream.range(0,
this.dimension()) this.dimension())
.mapToDouble(i -> this.get(i) * b.get(i)) .mapToDouble(i -> this.get(i) * b.get(i))
@ -80,14 +80,9 @@ public class Vector
public Vector divide(double div) public Vector divide(double div)
{ {
var divided = new ArrayList<Double>(); return new Vector(IntStream.range(0, this.dimension())
.mapToObj(i -> this.values.get(i) / div)
for (int i = 0; i < this.dimension(); i++) .collect(Collectors.toCollection(ArrayList::new)));
{
divided.add(this.values.get(i) / div);
}
return new Vector(divided);
} }
public double get(int index) public double get(int index)
@ -133,4 +128,10 @@ public class Vector
{ {
return this.values.toString(); return this.values.toString();
} }
private void checkEqualDimensions(Vector b)
{
if (this.dimension() != b.dimension())
throw new IllegalArgumentException("Dimensions must be equal");
}
} }

View File

@ -2,7 +2,7 @@ package machine_learning.nearest_neighbour;
import machine_learning.Vector; import machine_learning.Vector;
public interface Distance public interface DistanceFunction
{ {
double distance(Vector a, Vector b); double distance(Vector a, Vector b);
} }

View File

@ -15,18 +15,18 @@ public class KNearestNeighbour implements MachineLearning
private List<Vector> positives; private List<Vector> positives;
private List<Vector> negatives; private List<Vector> negatives;
private final Distance distance; private final DistanceFunction distanceFunction;
private final int k; private final int k;
public KNearestNeighbour(Distance distance) public KNearestNeighbour(DistanceFunction distanceFunction)
{ {
this(distance, 1); this(distanceFunction, 1);
} }
public KNearestNeighbour(Distance distance, int k) public KNearestNeighbour(DistanceFunction distanceFunction, int k)
{ {
this.distance = distance; this.distanceFunction = distanceFunction;
this.k = k; this.k = k;
} }
@ -67,7 +67,7 @@ public class KNearestNeighbour implements MachineLearning
private List<Vector> nearestNeighbours(List<Vector> vectors, Vector vector) private List<Vector> nearestNeighbours(List<Vector> vectors, Vector vector)
{ {
return vectors.parallelStream() return vectors.parallelStream()
.map(v -> Map.entry(this.distance.distance(v, vector), v)) .map(v -> Map.entry(this.distanceFunction.distance(v, vector), v))
.sorted((e1, e2) -> e1.getKey() >= e2.getKey() ? (e1.getKey().equals(e2.getKey()) ? 0 : 1) : -1) .sorted((e1, e2) -> e1.getKey() >= e2.getKey() ? (e1.getKey().equals(e2.getKey()) ? 0 : 1) : -1)
.map(Map.Entry::getValue) .map(Map.Entry::getValue)
.collect(Collectors.toList()) .collect(Collectors.toList())

View File

@ -22,8 +22,9 @@ import static org.junit.jupiter.api.Assertions.*;
@TestInstance(TestInstance.Lifecycle.PER_CLASS) @TestInstance(TestInstance.Lifecycle.PER_CLASS)
class KNearestNeighbourTest class KNearestNeighbourTest
{ {
List<Vector> positives; private List<Vector> positives;
List<Vector> negatives; private List<Vector> negatives;
private DistanceFunction distanceFunction;
@BeforeAll @BeforeAll
void initLearnData() void initLearnData()
@ -41,12 +42,14 @@ class KNearestNeighbourTest
new Vector(8d, 2d), new Vector(8d, 2d),
new Vector(9d, 0d)) new Vector(9d, 0d))
); );
this.distanceFunction = (a, b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1));
} }
@Test @Test
public void shouldReturnCorrectClassForVectorWithKEquals3() public void shouldReturnCorrectClassForVectorWithKEquals3()
{ {
var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 3); var kNearestNeighbour = new KNearestNeighbour(this.distanceFunction, 3);
kNearestNeighbour.learn(this.positives, this.negatives); kNearestNeighbour.learn(this.positives, this.negatives);
var vector = new Vector(8, 3.5); var vector = new Vector(8, 3.5);
@ -59,7 +62,7 @@ class KNearestNeighbourTest
@Test @Test
public void shouldReturnCorrectClassForVectorWithKEquals5() public void shouldReturnCorrectClassForVectorWithKEquals5()
{ {
var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 5); var kNearestNeighbour = new KNearestNeighbour(this.distanceFunction, 5);
kNearestNeighbour.learn(this.positives, this.negatives); kNearestNeighbour.learn(this.positives, this.negatives);
var vector = new Vector(8, 3.5); var vector = new Vector(8, 3.5);