From 54cfe2dece4da81e6864343460e84aeecd32c2ca Mon Sep 17 00:00:00 2001 From: Niklas Birk Date: Thu, 27 Jun 2019 00:01:49 +0200 Subject: [PATCH] Added kNearestNeighbour and refactored Vector to a package level above and added a constructor --- .../{perceptron => }/Vector.java | 24 +++++-- .../nearest_neighbour/DataClass.java | 7 ++ .../nearest_neighbour/Distance.java | 8 +++ .../nearest_neighbour/KNearestNeighbour.java | 67 +++++++++++++++++++ .../perceptron/Perceptron.java | 2 + .../{perceptron => }/VectorTest.java | 40 ++++++----- .../KNearestNeighbourTest.java | 62 +++++++++++++++++ 7 files changed, 189 insertions(+), 21 deletions(-) rename src/machine_learning/{perceptron => }/Vector.java (82%) create mode 100644 src/machine_learning/nearest_neighbour/DataClass.java create mode 100644 src/machine_learning/nearest_neighbour/Distance.java create mode 100644 src/machine_learning/nearest_neighbour/KNearestNeighbour.java rename test/machine_learning/{perceptron => }/VectorTest.java (59%) create mode 100644 test/machine_learning/nearest_neighbour/KNearestNeighbourTest.java diff --git a/src/machine_learning/perceptron/Vector.java b/src/machine_learning/Vector.java similarity index 82% rename from src/machine_learning/perceptron/Vector.java rename to src/machine_learning/Vector.java index 3677af8..5a8fc3f 100644 --- a/src/machine_learning/perceptron/Vector.java +++ b/src/machine_learning/Vector.java @@ -1,4 +1,4 @@ -package machine_learning.perceptron; +package machine_learning; import java.util.*; import java.util.stream.Collectors; @@ -10,12 +10,16 @@ public class Vector public Vector(int dim) { - this.values = new ArrayList<>(); + this(IntStream.range(0, dim) + .mapToDouble(i -> 0d) + .toArray()); + } - for (int i = 0; i < dim; i++) - { - this.values.add(0d); - } + public Vector(double... value) + { + this(Arrays.stream(value) + .boxed() + .collect(Collectors.toList())); } public Vector(List values) @@ -66,6 +70,14 @@ public class Vector .sum()); } + public double distance(Vector b) + { + return Math.sqrt(IntStream.range(0, + this.dimension()) + .mapToDouble(i -> (this.get(i) - b.get(i)) * (this.get(i) - b.get(i))) + .sum()); + } + public Vector divide(double div) { var divided = new ArrayList(); diff --git a/src/machine_learning/nearest_neighbour/DataClass.java b/src/machine_learning/nearest_neighbour/DataClass.java new file mode 100644 index 0000000..c095351 --- /dev/null +++ b/src/machine_learning/nearest_neighbour/DataClass.java @@ -0,0 +1,7 @@ +package machine_learning.nearest_neighbour; + +public enum DataClass +{ + POSITIVE, + NEGATIVE +} diff --git a/src/machine_learning/nearest_neighbour/Distance.java b/src/machine_learning/nearest_neighbour/Distance.java new file mode 100644 index 0000000..42092ef --- /dev/null +++ b/src/machine_learning/nearest_neighbour/Distance.java @@ -0,0 +1,8 @@ +package machine_learning.nearest_neighbour; + +import machine_learning.Vector; + +public interface Distance +{ + double distance(Vector a, Vector b); +} diff --git a/src/machine_learning/nearest_neighbour/KNearestNeighbour.java b/src/machine_learning/nearest_neighbour/KNearestNeighbour.java new file mode 100644 index 0000000..e63faac --- /dev/null +++ b/src/machine_learning/nearest_neighbour/KNearestNeighbour.java @@ -0,0 +1,67 @@ +package machine_learning.nearest_neighbour; + +import machine_learning.Vector; + +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class KNearestNeighbour +{ + private Distance distance; + + private int k; + + public KNearestNeighbour(Distance distance) + { + this(distance, 1); + } + + public KNearestNeighbour(Distance distance, int k) + { + this.distance = distance; + this.k = k; + } + + public DataClass kNearestNeighbour(List positives, List negatives, Vector toClassify) + { + var nearestNeighbours = this.nearestNeighbours( + Stream.concat(positives.stream(), negatives.stream()) + .collect(Collectors.toList()), + toClassify + ); + + var positivesWithNearestNeighboursAmount = nearestNeighbours.stream() + .filter(positives::contains) + .count(); + + var negativesWithNearestNeighboursAmount = nearestNeighbours.stream() + .filter(negatives::contains) + .count(); + + if (positivesWithNearestNeighboursAmount > negativesWithNearestNeighboursAmount) + { + return DataClass.POSITIVE; + } + else if (positivesWithNearestNeighboursAmount < negativesWithNearestNeighboursAmount) + { + return DataClass.NEGATIVE; + } + + return new Random().nextBoolean() ? DataClass.POSITIVE : DataClass.NEGATIVE; + } + + private List nearestNeighbours(List vectors, Vector vector) + { + var nearestNeighbours = vectors.stream() + .map(v -> Map.entry(this.distance.distance(v, vector), v)) + .sorted((e1, e2) -> e1.getKey() >= e2.getKey() ? (e1.getKey().equals(e2.getKey()) ? 0 : 1) : -1) + .map(Map.Entry::getValue) + .collect(Collectors.toList()); + + return nearestNeighbours.subList(0, this.k); + } +} + diff --git a/src/machine_learning/perceptron/Perceptron.java b/src/machine_learning/perceptron/Perceptron.java index 72ca91f..3d84713 100644 --- a/src/machine_learning/perceptron/Perceptron.java +++ b/src/machine_learning/perceptron/Perceptron.java @@ -1,5 +1,7 @@ package machine_learning.perceptron; +import machine_learning.Vector; + import java.util.List; public class Perceptron diff --git a/test/machine_learning/perceptron/VectorTest.java b/test/machine_learning/VectorTest.java similarity index 59% rename from test/machine_learning/perceptron/VectorTest.java rename to test/machine_learning/VectorTest.java index bf43109..e838cab 100644 --- a/test/machine_learning/perceptron/VectorTest.java +++ b/test/machine_learning/VectorTest.java @@ -1,9 +1,7 @@ -package machine_learning.perceptron; +package machine_learning; import org.junit.jupiter.api.Test; -import java.util.List; - import static org.junit.jupiter.api.Assertions.*; class VectorTest @@ -14,7 +12,7 @@ class VectorTest { var v = new Vector(3); - var expected = new Vector(List.of(0d, 0d, 0d)); + var expected = new Vector(0d, 0d, 0d); assertEquals(3, v.dimension()); assertEquals(expected, v); @@ -23,11 +21,11 @@ class VectorTest @Test void shouldReturnCorrectVectorWhenAdding() { - var v1 = new Vector(List.of(1d, 2d)); - var v2 = new Vector(List.of(3d, 4d)); + var v1 = new Vector(1d, 2d); + var v2 = new Vector(3d, 4d); var result = v1.add(v2); - var expected = new Vector(List.of(4d, 6d)); + var expected = new Vector(4d, 6d); assertEquals(expected, result); } @@ -35,11 +33,11 @@ class VectorTest @Test void shouldReturnCorrectVectorWhenSubtracting() { - var v1 = new Vector(List.of(1d, 2d)); - var v2 = new Vector(List.of(3d, 4d)); + var v1 = new Vector(1d, 2d); + var v2 = new Vector(3d, 4d); var result = v1.subtract(v2); - var expected = new Vector(List.of(-2d, -2d)); + var expected = new Vector(-2d, -2d); assertEquals(expected, result); } @@ -47,8 +45,8 @@ class VectorTest @Test void shouldReturnCorrectVectorWhenScalarMultiplying() { - var v1 = new Vector(List.of(1d, 2d)); - var v2 = new Vector(List.of(3d, 4d)); + var v1 = new Vector(1d, 2d); + var v2 = new Vector(3d, 4d); var result = v1.scalar(v2); var expected = 11d; @@ -59,7 +57,7 @@ class VectorTest @Test void shouldReturnCorrectVectorWhenEuclid() { - var v1 = new Vector(List.of(1d, 2d)); + var v1 = new Vector(1d, 2d); var result = v1.euclid(); var expected = Math.sqrt(5); @@ -67,14 +65,26 @@ class VectorTest assertEquals(expected, result); } + @Test + void shouldReturnCorrectDistance() + { + var v1 = new Vector(1d, 2d); + var v2 = new Vector(3d, 4d); + + var result = v1.distance(v2); + var expected = Math.sqrt(8); + + assertEquals(expected, result); + } + @Test void shouldReturnCorrectVectorWhenDividing() { - var v1 = new Vector(List.of(1d, 2d)); + var v1 = new Vector(1d, 2d); var div = 2d; var result = v1.divide(div); - var expected = new Vector(List.of(0.5d, 1d)); + var expected = new Vector(0.5d, 1d); assertEquals(expected, result); } diff --git a/test/machine_learning/nearest_neighbour/KNearestNeighbourTest.java b/test/machine_learning/nearest_neighbour/KNearestNeighbourTest.java new file mode 100644 index 0000000..3b4046c --- /dev/null +++ b/test/machine_learning/nearest_neighbour/KNearestNeighbourTest.java @@ -0,0 +1,62 @@ +package machine_learning.nearest_neighbour; + +import machine_learning.Vector; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.*; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class KNearestNeighbourTest +{ + List positives; + List negatives; + + @BeforeAll + void initLearnData() + { + this.positives = new ArrayList<>(List.of( + new Vector(8d, 4d), + new Vector(8d, 6d), + new Vector(9d, 2d), + new Vector(9d, 5d)) + ); + + this.negatives = new ArrayList<>(List.of( + new Vector(6d, 1d), + new Vector(7d, 3d), + new Vector(8d, 2d), + new Vector(9d, 0d)) + ); + } + + @Test + public void shouldReturnCorrectClassForVectorWithKEquals3() + { + var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 3); + var vector = new Vector(8, 3.5); + + var actualClass = kNearestNeighbour.kNearestNeighbour(this.positives, this.negatives, vector); + var expectedClass = DataClass.NEGATIVE; + + assertEquals(expectedClass, actualClass); + } + + @Test + public void shouldReturnCorrectClassForVectorWithKEquals5() + { + var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 5); + var vector = new Vector(8, 3.5); + + var actualClass = kNearestNeighbour.kNearestNeighbour(this.positives, this.negatives, vector); + var expectedClass = DataClass.POSITIVE; + + assertEquals(expectedClass, actualClass); + } +} \ No newline at end of file