Cross-validation added
This commit is contained in:
@ -24,10 +24,10 @@ class VectorTest
|
||||
var v1 = new Vector(1d, 2d);
|
||||
var v2 = new Vector(3d, 4d);
|
||||
|
||||
var result = v1.add(v2);
|
||||
var actual = v1.add(v2);
|
||||
var expected = new Vector(4d, 6d);
|
||||
|
||||
assertEquals(expected, result);
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -36,10 +36,10 @@ class VectorTest
|
||||
var v1 = new Vector(1d, 2d);
|
||||
var v2 = new Vector(3d, 4d);
|
||||
|
||||
var result = v1.subtract(v2);
|
||||
var actual = v1.subtract(v2);
|
||||
var expected = new Vector(-2d, -2d);
|
||||
|
||||
assertEquals(expected, result);
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -48,10 +48,10 @@ class VectorTest
|
||||
var v1 = new Vector(1d, 2d);
|
||||
var v2 = new Vector(3d, 4d);
|
||||
|
||||
var result = v1.scalar(v2);
|
||||
var actual = v1.scalar(v2);
|
||||
var expected = 11d;
|
||||
|
||||
assertEquals(expected, result);
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -59,10 +59,10 @@ class VectorTest
|
||||
{
|
||||
var v1 = new Vector(1d, 2d);
|
||||
|
||||
var result = v1.euclid();
|
||||
var actual = v1.euclid();
|
||||
var expected = Math.sqrt(5);
|
||||
|
||||
assertEquals(expected, result);
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -71,10 +71,10 @@ class VectorTest
|
||||
var v1 = new Vector(1d, 2d);
|
||||
var v2 = new Vector(3d, 4d);
|
||||
|
||||
var result = v1.distance(v2);
|
||||
var actual = v1.distance(v2);
|
||||
var expected = Math.sqrt(8);
|
||||
|
||||
assertEquals(expected, result);
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -83,9 +83,34 @@ class VectorTest
|
||||
var v1 = new Vector(1d, 2d);
|
||||
var div = 2d;
|
||||
|
||||
var result = v1.divide(div);
|
||||
var actual = v1.divide(div);
|
||||
var expected = new Vector(0.5d, 1d);
|
||||
|
||||
assertEquals(expected, result);
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldDecreaseDimensionCorrect()
|
||||
{
|
||||
var v = new Vector(1d, 2d, 3d, 4d);
|
||||
|
||||
var decreasedDimensionVector = v.decreasedDimension();
|
||||
|
||||
var actual = decreasedDimensionVector.dimension();
|
||||
var expected = 3;
|
||||
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldNormalizeCorrect()
|
||||
{
|
||||
var v = new Vector(4d, 4d, 4d, 4d);
|
||||
|
||||
var actual = v.normalized();
|
||||
|
||||
var expected = new Vector(0.5d, 0.5d, 0.5d, 0.5d);
|
||||
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
}
|
@ -1,13 +1,20 @@
|
||||
package machine_learning.nearest_neighbour;
|
||||
|
||||
import machine_learning.DataClass;
|
||||
import machine_learning.Vector;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
import org.opentest4j.AssertionFailedError;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
@ -40,9 +47,10 @@ class KNearestNeighbourTest
|
||||
public void shouldReturnCorrectClassForVectorWithKEquals3()
|
||||
{
|
||||
var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 3);
|
||||
kNearestNeighbour.learn(this.positives, this.negatives);
|
||||
var vector = new Vector(8, 3.5);
|
||||
|
||||
var actualClass = kNearestNeighbour.kNearestNeighbour(this.positives, this.negatives, vector);
|
||||
var actualClass = kNearestNeighbour.classify(vector);
|
||||
var expectedClass = DataClass.NEGATIVE;
|
||||
|
||||
assertEquals(expectedClass, actualClass);
|
||||
@ -52,11 +60,118 @@ class KNearestNeighbourTest
|
||||
public void shouldReturnCorrectClassForVectorWithKEquals5()
|
||||
{
|
||||
var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 5);
|
||||
kNearestNeighbour.learn(this.positives, this.negatives);
|
||||
var vector = new Vector(8, 3.5);
|
||||
|
||||
var actualClass = kNearestNeighbour.kNearestNeighbour(this.positives, this.negatives, vector);
|
||||
var actualClass = kNearestNeighbour.classify(vector);
|
||||
var expectedClass = DataClass.POSITIVE;
|
||||
|
||||
assertEquals(expectedClass, actualClass);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldReturnCorrectClassesForAppendicitisData()
|
||||
{
|
||||
var trainDataFile = "./resources/app1.data";
|
||||
var testDataFile = "./resources/app1.test";
|
||||
|
||||
var trainDataVectors = readFromFile(trainDataFile);
|
||||
|
||||
var dataClasses = splitIntoClasses(trainDataVectors);
|
||||
var negatives = dataClasses.get(DataClass.NEGATIVE);
|
||||
var positives = dataClasses.get(DataClass.POSITIVE);
|
||||
|
||||
var kNearestNeighbour = new KNearestNeighbour(Vector::distance);
|
||||
kNearestNeighbour.learn(positives, negatives);
|
||||
|
||||
var testDataVectors = readFromFile(testDataFile);
|
||||
var failCount = 0;
|
||||
|
||||
for (var vector : testDataVectors)
|
||||
{
|
||||
var expectedClass = DataClass.valueOf(Double.valueOf(vector.get(vector.dimension() - 1)).intValue());
|
||||
|
||||
var testVector = vector.decreasedDimension();
|
||||
|
||||
var actualClass = kNearestNeighbour.classify(testVector.normalized());
|
||||
|
||||
try
|
||||
{
|
||||
assertEquals(expectedClass, actualClass);
|
||||
}
|
||||
catch (AssertionFailedError e)
|
||||
{
|
||||
failCount++;
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println(failCount + " of " + testDataVectors.size() + " are not correct classified.");
|
||||
System.out.println("Fail rate of " + Math.round(100d * failCount / testDataVectors.size()) + " %");
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldReturnOptimum()
|
||||
{
|
||||
var trainDataFile = "./resources/app1.data";
|
||||
var testDataFile = "./resources/app1.test";
|
||||
|
||||
var trainDataVectors = readFromFile(trainDataFile);
|
||||
var testDataVectors = readFromFile(testDataFile);
|
||||
var data = Stream.concat(trainDataVectors.stream(), testDataVectors.stream())
|
||||
.collect(Collectors.toList());
|
||||
|
||||
var crossValidation = new CrossValidation(1, 100);
|
||||
|
||||
var kNearestNeighbour = crossValidation.validate(data, data.size());
|
||||
}
|
||||
|
||||
private List<Vector> readFromFile(String file)
|
||||
{
|
||||
List<Vector> vectorList = new ArrayList<>();
|
||||
|
||||
try (var reader = new BufferedReader(new FileReader(file)))
|
||||
{
|
||||
String line;
|
||||
|
||||
while ((line = reader.readLine()) != null)
|
||||
{
|
||||
vectorList.add(new Vector(
|
||||
Arrays.stream(line.split(","))
|
||||
.map(Double::valueOf)
|
||||
.collect(Collectors.toList())
|
||||
));
|
||||
}
|
||||
|
||||
}
|
||||
catch (IOException e)
|
||||
{
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
private Map<DataClass, List<Vector>> splitIntoClasses(List<Vector> data)
|
||||
{
|
||||
var positives = data.stream()
|
||||
.filter(v -> v.get(v.dimension()-1) == 1)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
var negatives = data.stream()
|
||||
.filter(v -> v.get(v.dimension()-1) == 0)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
positives = positives.stream()
|
||||
.map(Vector::decreasedDimension)
|
||||
.map(Vector::normalized)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
negatives = negatives.stream()
|
||||
.map(Vector::decreasedDimension)
|
||||
.map(Vector::normalized)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
return Map.ofEntries(Map.entry(DataClass.NEGATIVE, negatives), Map.entry(DataClass.POSITIVE, positives));
|
||||
}
|
||||
}
|
@ -1,8 +1,7 @@
|
||||
package machine_learning.perceptron;
|
||||
|
||||
import machine_learning.Vector;
|
||||
import machine_learning.nearest_neighbour.DataClass;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import machine_learning.DataClass;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
|
Reference in New Issue
Block a user