From 84a1b61477918f7c7d3d27247075415db4724be8 Mon Sep 17 00:00:00 2001 From: Niklas Birk Date: Thu, 20 Jun 2019 13:17:23 +0200 Subject: [PATCH] Perceptron (unfinished) --- .../perceptron/Perceptron.java | 73 +++++++++++++++ src/machine_learning/perceptron/Vector.java | 90 +++++++++++++++++++ .../perceptron/PerceptronTest.java | 39 ++++++++ .../perceptron/VectorTest.java | 58 ++++++++++++ 4 files changed, 260 insertions(+) create mode 100644 src/machine_learning/perceptron/Perceptron.java create mode 100644 src/machine_learning/perceptron/Vector.java create mode 100644 test/machine_learning/perceptron/PerceptronTest.java create mode 100644 test/machine_learning/perceptron/VectorTest.java diff --git a/src/machine_learning/perceptron/Perceptron.java b/src/machine_learning/perceptron/Perceptron.java new file mode 100644 index 0000000..089d2cf --- /dev/null +++ b/src/machine_learning/perceptron/Perceptron.java @@ -0,0 +1,73 @@ +package machine_learning.perceptron; + +import java.util.List; + +public class Perceptron +{ + public void learn(List positives, List negatives) + { + var weight = this.getInitializationVector(positives, negatives); + + do + { + for (var x : positives) + { + if (weight.scalar(x) <= 0) + { + weight = weight.add(x); + } + } + + for (var x : negatives) + { + if (weight.scalar(x) > 0) + { + weight = weight.subtract(x); + } + } + + System.out.println(weight); + } + while (elementsAreCorrectClassified(positives, weight, 1) && elementsAreCorrectClassified(negatives, weight, 0)); + + System.out.println("-----------------------------------"); + System.out.println("-- All are classified correctly. --"); + System.out.println("-----------------------------------"); + } + + private Vector getInitializationVector(List positives, List negatives) + {/* + var a = new Vector(positives.get(0).dimension()); + for (var x : positives) + { + a = a.add(x); + } + + var b = new Vector(positives.get(0).dimension()); + for (var x : negatives) + { + b = b.add(x); + } + + return a.subtract(b);*/ + + return new Vector(positives.get(0).dimension()); + } + + private boolean elementsAreCorrectClassified(List vectors, Vector weight, int expectedClass) + { + int actualClass; + + for (var x : vectors) + { + actualClass = weight.scalar(x) > 0 ? 1 : 0; + + if (actualClass != expectedClass) + { + return false; + } + } + + return true; + } +} diff --git a/src/machine_learning/perceptron/Vector.java b/src/machine_learning/perceptron/Vector.java new file mode 100644 index 0000000..27b416a --- /dev/null +++ b/src/machine_learning/perceptron/Vector.java @@ -0,0 +1,90 @@ +package machine_learning.perceptron; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class Vector +{ + private List values; + + public Vector(int dim) + { + this.values = new ArrayList<>(); + + for (int i = 0; i < dim; i++) + { + this.values.add(0d); + } + } + + public Vector(List values) + { + this.values = values; + } + + public int dimension() + { + return this.values.size(); + } + + public Vector add(Vector b) + { + return new Vector(IntStream.range(0, + this.values.size()) + .mapToObj(i -> this.values.get(i) + b.values.get(i)) + .collect(Collectors.toCollection(ArrayList::new)) + ); + + } + + public Vector subtract(Vector b) + { + return new Vector(IntStream.range(0, + this.values.size()) + .mapToObj(i -> this.values.get(i) - b.values.get(i)) + .collect(Collectors.toCollection(ArrayList::new)) + ); + } + + public double scalar(Vector b) + { + return IntStream.range(0, + this.values.size()) + .mapToDouble(i -> this.values.get(i) * b.values.get(i)) + .sum(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) + { + return true; + } + if (o == null || getClass() != o.getClass()) + { + return false; + } + + Vector vector = (Vector) o; + + return Objects.equals(values, vector.values); + } + + @Override + public int hashCode() + { + return values != null ? values.hashCode() : 0; + } + + @Override + public String toString() + { + return values.toString() + .replace("[", "(") + .replace("]", ")"); + } +} diff --git a/test/machine_learning/perceptron/PerceptronTest.java b/test/machine_learning/perceptron/PerceptronTest.java new file mode 100644 index 0000000..2c98b93 --- /dev/null +++ b/test/machine_learning/perceptron/PerceptronTest.java @@ -0,0 +1,39 @@ +package machine_learning.perceptron; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.util.ArrayList; +import java.util.List; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class PerceptronTest +{ + List positives; + List negatives; + + @BeforeAll + void initLearnData() + { + this.positives = new ArrayList<>(List.of( + new Vector(List.of(8d, 4d)), + new Vector(List.of(8d, 6d)), + new Vector(List.of(9d, 2d)), + new Vector(List.of(9d, 5d))) + ); + + this.negatives = new ArrayList<>(List.of( + new Vector(List.of(6d, 1d)), + new Vector(List.of(7d, 3d)), + new Vector(List.of(8d, 2d)), + new Vector(List.of(9d, 0d))) + ); + } + + @Test + void shouldClassifyCorrect() + { + new Perceptron().learn(this.positives, this.negatives); + } +} \ No newline at end of file diff --git a/test/machine_learning/perceptron/VectorTest.java b/test/machine_learning/perceptron/VectorTest.java new file mode 100644 index 0000000..e2c7514 --- /dev/null +++ b/test/machine_learning/perceptron/VectorTest.java @@ -0,0 +1,58 @@ +package machine_learning.perceptron; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +class VectorTest +{ + + @Test + void shouldInitializeZeroVector() + { + var v = new Vector(3); + + var expected = new Vector(List.of(0d, 0d, 0d)); + + assertEquals(3, v.dimension()); + assertEquals(expected, v); + } + + @Test + void shouldReturnCorrectVectorWhenAdding() + { + var v1 = new Vector(List.of(1d, 2d)); + var v2 = new Vector(List.of(3d, 4d)); + + var result = v1.add(v2); + var expected = new Vector(List.of(4d, 6d)); + + assertEquals(expected, result); + } + + @Test + void shouldReturnCorrectVectorWhenSubtracting() + { + var v1 = new Vector(List.of(1d, 2d)); + var v2 = new Vector(List.of(3d, 4d)); + + var result = v1.subtract(v2); + var expected = new Vector(List.of(-2d, -2d)); + + assertEquals(expected, result); + } + + @Test + void shouldReturnCorrectVectorWhenScalar() + { + var v1 = new Vector(List.of(1d, 2d)); + var v2 = new Vector(List.of(3d, 4d)); + + var result = v1.scalar(v2); + var expected = 11d; + + assertEquals(expected, result); + } +} \ No newline at end of file