From 757eb4dd2b7142359ec5c6510a626c856944a7b2 Mon Sep 17 00:00:00 2001 From: Niklas Birk Date: Thu, 27 Jun 2019 00:11:59 +0200 Subject: [PATCH] Added classify method for perceptron --- .../perceptron/Perceptron.java | 6 ++++ .../perceptron/PerceptronTest.java | 28 +++++++++++++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/machine_learning/perceptron/Perceptron.java b/src/machine_learning/perceptron/Perceptron.java index 5ee45ef..4f55b1b 100644 --- a/src/machine_learning/perceptron/Perceptron.java +++ b/src/machine_learning/perceptron/Perceptron.java @@ -1,6 +1,7 @@ package machine_learning.perceptron; import machine_learning.Vector; +import machine_learning.nearest_neighbour.DataClass; import java.util.List; @@ -42,6 +43,11 @@ public class Perceptron System.out.println("-----------------------------------------------------------------"); } + public DataClass classify(Vector vector) + { + return this.weight.scalar(vector) > 0 ? DataClass.POSITIVE : DataClass.NEGATIVE; + } + private Vector getInitializationVector(List positives, List negatives) { var a = new Vector(positives.get(0).dimension()); diff --git a/test/machine_learning/perceptron/PerceptronTest.java b/test/machine_learning/perceptron/PerceptronTest.java index 4740fca..36cdaf3 100644 --- a/test/machine_learning/perceptron/PerceptronTest.java +++ b/test/machine_learning/perceptron/PerceptronTest.java @@ -1,6 +1,8 @@ package machine_learning.perceptron; import machine_learning.Vector; +import machine_learning.nearest_neighbour.DataClass; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -8,11 +10,14 @@ import org.junit.jupiter.api.TestInstance; import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.*; + @TestInstance(TestInstance.Lifecycle.PER_CLASS) class PerceptronTest { List positives; List negatives; + Perceptron perceptron; @BeforeAll void initLearnData() @@ -31,11 +36,30 @@ class PerceptronTest new Vector(8d, 2d, biasUnit), new Vector(9d, 0d, biasUnit)) ); + + this.perceptron = new Perceptron(); + this.perceptron.learn(this.positives, this.negatives); } @Test - void shouldClassifyCorrect() + void shouldClassifyVectorCorrectAsNegative() { - new Perceptron().learn(this.positives, this.negatives); + var vector = new Vector(0d, 0d, 1d); + + var actualClass = this.perceptron.classify(vector); + var expectedClass = DataClass.NEGATIVE; + + assertEquals(expectedClass, actualClass); + } + + @Test + void shouldClassifyVectorCorrectAsPositive() + { + var vector = new Vector(9d, 3d, 1d); + + var actualClass = this.perceptron.classify(vector); + var expectedClass = DataClass.POSITIVE; + + assertEquals(expectedClass, actualClass); } } \ No newline at end of file