From 24c9e0b247375ecdacbeeb71bf622c3c2affc6e5 Mon Sep 17 00:00:00 2001 From: Niklas Birk Date: Tue, 30 Nov 2021 22:27:45 +0100 Subject: [PATCH] Minor changes in names and extracted method, etc. --- src/machine_learning/Vector.java | 23 ++++++++++--------- .../{Distance.java => DistanceFunction.java} | 2 +- .../nearest_neighbour/KNearestNeighbour.java | 12 +++++----- .../KNearestNeighbourTest.java | 11 +++++---- 4 files changed, 26 insertions(+), 22 deletions(-) rename src/machine_learning/nearest_neighbour/{Distance.java => DistanceFunction.java} (78%) diff --git a/src/machine_learning/Vector.java b/src/machine_learning/Vector.java index e061574..aa3bb41 100644 --- a/src/machine_learning/Vector.java +++ b/src/machine_learning/Vector.java @@ -34,7 +34,7 @@ public class Vector public Vector add(Vector b) { - if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals."); + checkEqualDimensions(b); return new Vector(IntStream.range(0, this.dimension()) .mapToObj(i -> this.get(i) + b.get(i)) @@ -45,7 +45,7 @@ public class Vector public Vector subtract(Vector b) { - if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals."); + checkEqualDimensions(b); return new Vector(IntStream.range(0, this.dimension()) .mapToObj(i -> this.get(i) - b.get(i)) @@ -55,7 +55,7 @@ public class Vector public double scalar(Vector b) { - if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals."); + checkEqualDimensions(b); return IntStream.range(0, this.dimension()) .mapToDouble(i -> this.get(i) * b.get(i)) @@ -80,14 +80,9 @@ public class Vector public Vector divide(double div) { - var divided = new ArrayList(); - - for (int i = 0; i < this.dimension(); i++) - { - divided.add(this.values.get(i) / div); - } - - return new Vector(divided); + return new Vector(IntStream.range(0, this.dimension()) + .mapToObj(i -> this.values.get(i) / div) + .collect(Collectors.toCollection(ArrayList::new))); } public double get(int index) @@ -133,4 +128,10 @@ public class Vector { return this.values.toString(); } + + private void checkEqualDimensions(Vector b) + { + if (this.dimension() != b.dimension()) + throw new IllegalArgumentException("Dimensions must be equal"); + } } diff --git a/src/machine_learning/nearest_neighbour/Distance.java b/src/machine_learning/nearest_neighbour/DistanceFunction.java similarity index 78% rename from src/machine_learning/nearest_neighbour/Distance.java rename to src/machine_learning/nearest_neighbour/DistanceFunction.java index 42092ef..79e48b5 100644 --- a/src/machine_learning/nearest_neighbour/Distance.java +++ b/src/machine_learning/nearest_neighbour/DistanceFunction.java @@ -2,7 +2,7 @@ package machine_learning.nearest_neighbour; import machine_learning.Vector; -public interface Distance +public interface DistanceFunction { double distance(Vector a, Vector b); } diff --git a/src/machine_learning/nearest_neighbour/KNearestNeighbour.java b/src/machine_learning/nearest_neighbour/KNearestNeighbour.java index 6e63038..a11361d 100644 --- a/src/machine_learning/nearest_neighbour/KNearestNeighbour.java +++ b/src/machine_learning/nearest_neighbour/KNearestNeighbour.java @@ -15,18 +15,18 @@ public class KNearestNeighbour implements MachineLearning private List positives; private List negatives; - private final Distance distance; + private final DistanceFunction distanceFunction; private final int k; - public KNearestNeighbour(Distance distance) + public KNearestNeighbour(DistanceFunction distanceFunction) { - this(distance, 1); + this(distanceFunction, 1); } - public KNearestNeighbour(Distance distance, int k) + public KNearestNeighbour(DistanceFunction distanceFunction, int k) { - this.distance = distance; + this.distanceFunction = distanceFunction; this.k = k; } @@ -67,7 +67,7 @@ public class KNearestNeighbour implements MachineLearning private List nearestNeighbours(List vectors, Vector vector) { return vectors.parallelStream() - .map(v -> Map.entry(this.distance.distance(v, vector), v)) + .map(v -> Map.entry(this.distanceFunction.distance(v, vector), v)) .sorted((e1, e2) -> e1.getKey() >= e2.getKey() ? (e1.getKey().equals(e2.getKey()) ? 0 : 1) : -1) .map(Map.Entry::getValue) .collect(Collectors.toList()) diff --git a/test/machine_learning/nearest_neighbour/KNearestNeighbourTest.java b/test/machine_learning/nearest_neighbour/KNearestNeighbourTest.java index 6f5b629..bbf25b1 100644 --- a/test/machine_learning/nearest_neighbour/KNearestNeighbourTest.java +++ b/test/machine_learning/nearest_neighbour/KNearestNeighbourTest.java @@ -22,8 +22,9 @@ import static org.junit.jupiter.api.Assertions.*; @TestInstance(TestInstance.Lifecycle.PER_CLASS) class KNearestNeighbourTest { - List positives; - List negatives; + private List positives; + private List negatives; + private DistanceFunction distanceFunction; @BeforeAll void initLearnData() @@ -41,12 +42,14 @@ class KNearestNeighbourTest new Vector(8d, 2d), new Vector(9d, 0d)) ); + + this.distanceFunction = (a, b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)); } @Test public void shouldReturnCorrectClassForVectorWithKEquals3() { - var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 3); + var kNearestNeighbour = new KNearestNeighbour(this.distanceFunction, 3); kNearestNeighbour.learn(this.positives, this.negatives); var vector = new Vector(8, 3.5); @@ -59,7 +62,7 @@ class KNearestNeighbourTest @Test public void shouldReturnCorrectClassForVectorWithKEquals5() { - var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 5); + var kNearestNeighbour = new KNearestNeighbour(this.distanceFunction, 5); kNearestNeighbour.learn(this.positives, this.negatives); var vector = new Vector(8, 3.5);