Cross-validation added

This commit is contained in:
Niklas Birk
2019-06-30 23:59:49 +02:00
parent 757eb4dd2b
commit 6937e28cdd
12 changed files with 787 additions and 34 deletions

View File

@ -24,10 +24,10 @@ class VectorTest
var v1 = new Vector(1d, 2d);
var v2 = new Vector(3d, 4d);
var result = v1.add(v2);
var actual = v1.add(v2);
var expected = new Vector(4d, 6d);
assertEquals(expected, result);
assertEquals(expected, actual);
}
@Test
@ -36,10 +36,10 @@ class VectorTest
var v1 = new Vector(1d, 2d);
var v2 = new Vector(3d, 4d);
var result = v1.subtract(v2);
var actual = v1.subtract(v2);
var expected = new Vector(-2d, -2d);
assertEquals(expected, result);
assertEquals(expected, actual);
}
@Test
@ -48,10 +48,10 @@ class VectorTest
var v1 = new Vector(1d, 2d);
var v2 = new Vector(3d, 4d);
var result = v1.scalar(v2);
var actual = v1.scalar(v2);
var expected = 11d;
assertEquals(expected, result);
assertEquals(expected, actual);
}
@Test
@ -59,10 +59,10 @@ class VectorTest
{
var v1 = new Vector(1d, 2d);
var result = v1.euclid();
var actual = v1.euclid();
var expected = Math.sqrt(5);
assertEquals(expected, result);
assertEquals(expected, actual);
}
@Test
@ -71,10 +71,10 @@ class VectorTest
var v1 = new Vector(1d, 2d);
var v2 = new Vector(3d, 4d);
var result = v1.distance(v2);
var actual = v1.distance(v2);
var expected = Math.sqrt(8);
assertEquals(expected, result);
assertEquals(expected, actual);
}
@Test
@ -83,9 +83,34 @@ class VectorTest
var v1 = new Vector(1d, 2d);
var div = 2d;
var result = v1.divide(div);
var actual = v1.divide(div);
var expected = new Vector(0.5d, 1d);
assertEquals(expected, result);
assertEquals(expected, actual);
}
@Test
void shouldDecreaseDimensionCorrect()
{
    // decreasedDimension() applied to a 4-component vector is expected to report dimension 3.
    var fourComponents = new Vector(1d, 2d, 3d, 4d);
    var expected = 3;

    var actual = fourComponents.decreasedDimension().dimension();

    assertEquals(expected, actual);
}
@Test
void shouldNormalizeCorrect()
{
    // A vector of four 4s is expected to normalize to four 0.5s.
    var expected = new Vector(0.5d, 0.5d, 0.5d, 0.5d);

    var actual = new Vector(4d, 4d, 4d, 4d).normalized();

    assertEquals(expected, actual);
}
}

View File

@ -1,13 +1,20 @@
package machine_learning.nearest_neighbour;
import machine_learning.DataClass;
import machine_learning.Vector;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.opentest4j.AssertionFailedError;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.*;
@ -40,9 +47,10 @@ class KNearestNeighbourTest
public void shouldReturnCorrectClassForVectorWithKEquals3()
{
var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 3);
kNearestNeighbour.learn(this.positives, this.negatives);
var vector = new Vector(8, 3.5);
var actualClass = kNearestNeighbour.kNearestNeighbour(this.positives, this.negatives, vector);
var actualClass = kNearestNeighbour.classify(vector);
var expectedClass = DataClass.NEGATIVE;
assertEquals(expectedClass, actualClass);
@ -52,11 +60,118 @@ class KNearestNeighbourTest
public void shouldReturnCorrectClassForVectorWithKEquals5()
{
var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 5);
kNearestNeighbour.learn(this.positives, this.negatives);
var vector = new Vector(8, 3.5);
var actualClass = kNearestNeighbour.kNearestNeighbour(this.positives, this.negatives, vector);
var actualClass = kNearestNeighbour.classify(vector);
var expectedClass = DataClass.POSITIVE;
assertEquals(expectedClass, actualClass);
}
/**
 * Trains a k-nearest-neighbour classifier on the appendicitis training file and reports
 * how many vectors of the test file are classified differently from their label.
 * <p>
 * Misclassifications are only counted and printed; they do not fail the test. The test
 * does fail when no test vectors could be read, because the reported rate would then be
 * meaningless (0 of 0).
 */
@Test
public void shouldReturnCorrectClassesForAppendicitisData()
{
    var trainDataFile = "./resources/app1.data";
    var testDataFile = "./resources/app1.test";

    var trainDataVectors = readFromFile(trainDataFile);
    var dataClasses = splitIntoClasses(trainDataVectors);
    var negatives = dataClasses.get(DataClass.NEGATIVE);
    var positives = dataClasses.get(DataClass.POSITIVE);

    var kNearestNeighbour = new KNearestNeighbour(Vector::distance);
    kNearestNeighbour.learn(positives, negatives);

    var testDataVectors = readFromFile(testDataFile);
    // Guard against an empty data set: the percentage below would otherwise be 0/0.
    assertFalse(testDataVectors.isEmpty(), "No test vectors were read from " + testDataFile);

    var failCount = 0;
    for (var vector : testDataVectors)
    {
        // The last component of each test vector carries the expected class label.
        var expectedClass = DataClass.valueOf(Double.valueOf(vector.get(vector.dimension() - 1)).intValue());
        var testVector = vector.decreasedDimension();
        var actualClass = kNearestNeighbour.classify(testVector.normalized());

        // Compare directly instead of raising and catching an assertion error per mismatch.
        if (!java.util.Objects.equals(expectedClass, actualClass))
        {
            failCount++;
        }
    }

    System.out.println(failCount + " of " + testDataVectors.size() + " are not correct classified.");
    System.out.println("Fail rate of " + Math.round(100d * failCount / testDataVectors.size()) + " %");
}
@Test
public void shouldReturnOptimum()
{
    // Pool the vectors of the training file and the test file (in that order) into one list.
    var data = Stream.of("./resources/app1.data", "./resources/app1.test")
            .map(this::readFromFile)
            .flatMap(List::stream)
            .collect(Collectors.toList());

    var crossValidation = new CrossValidation(1, 100);

    // NOTE(review): the result is not asserted on; the test only checks that validate() completes.
    var kNearestNeighbour = crossValidation.validate(data, data.size());
}
/**
 * Reads a comma-separated data file into a list of vectors, one vector per line.
 * Every comma-separated token of a line is parsed as a double; blank lines are skipped.
 *
 * @param file path of the file to read
 * @return the vectors in file order
 * @throws java.io.UncheckedIOException if the file cannot be opened or read, so that a
 *         missing data file fails the calling test instead of yielding an empty list
 * @throws NumberFormatException if a token is not a valid double
 */
private List<Vector> readFromFile(String file)
{
    List<Vector> vectorList = new ArrayList<>();

    try (var reader = new BufferedReader(new FileReader(file)))
    {
        String line;
        while ((line = reader.readLine()) != null)
        {
            // A blank line (e.g. a trailing newline) carries no vector and would not parse.
            if (line.trim().isEmpty())
            {
                continue;
            }

            vectorList.add(new Vector(
                    Arrays.stream(line.split(","))
                            .map(Double::valueOf)
                            .collect(Collectors.toList())
            ));
        }
    }
    catch (IOException e)
    {
        // Fully qualified to avoid touching the file's import block.
        throw new java.io.UncheckedIOException("Could not read data file: " + file, e);
    }

    return vectorList;
}
/**
 * Partitions labelled vectors by their last component: 1 goes to POSITIVE, 0 to NEGATIVE.
 * Each selected vector is passed through decreasedDimension() and then normalized().
 * Vectors whose last component is neither 0 nor 1 are left out.
 *
 * @param data vectors whose last component is the class label
 * @return an immutable map with one list per data class
 */
private Map<DataClass, List<Vector>> splitIntoClasses(List<Vector> data)
{
    List<Vector> positives = data.stream()
            .filter(v -> v.get(v.dimension() - 1) == 1)
            .map(v -> v.decreasedDimension().normalized())
            .collect(Collectors.toList());

    List<Vector> negatives = data.stream()
            .filter(v -> v.get(v.dimension() - 1) == 0)
            .map(v -> v.decreasedDimension().normalized())
            .collect(Collectors.toList());

    return Map.of(
            DataClass.NEGATIVE, negatives,
            DataClass.POSITIVE, positives);
}
}

View File

@ -1,8 +1,7 @@
package machine_learning.perceptron;
import machine_learning.Vector;
import machine_learning.nearest_neighbour.DataClass;
import org.junit.jupiter.api.Assertions;
import machine_learning.DataClass;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;