2022-06-23 22:29:19 +02:00

182 lines
5.6 KiB
Java
Executable File

package machine_learning.nearest_neighbour;
import machine_learning.DataClass;
import machine_learning.Vector;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.opentest4j.AssertionFailedError;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Tests for {@link KNearestNeighbour}: two small hand-checkable 2-D cases and two runs
 * against the appendicitis data files located under {@code ./resources}.
 *
 * <p>The {@code PER_CLASS} lifecycle is used so that the non-static {@code @BeforeAll}
 * method can fill the shared training data once for all tests.
 */
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class KNearestNeighbourTest
{
    /** Training data, resolved relative to the working directory of the test run. */
    private static final String TRAIN_DATA_FILE = "./resources/app1.data";

    /** Held-out test data, resolved relative to the working directory of the test run. */
    private static final String TEST_DATA_FILE = "./resources/app1.test";

    private List<Vector> positives;
    private List<Vector> negatives;
    private DistanceFunction distanceFunction;

    /**
     * Builds the hand-made 2-D training set (four positive and four negative points) and the
     * distance function used by the small examples: the sum of absolute differences of the
     * first two components (Manhattan distance in two dimensions).
     */
    @BeforeAll
    void initLearnData()
    {
        this.positives = new ArrayList<>(List.of(
                new Vector(8d, 4d),
                new Vector(8d, 6d),
                new Vector(9d, 2d),
                new Vector(9d, 5d))
        );
        this.negatives = new ArrayList<>(List.of(
                new Vector(6d, 1d),
                new Vector(7d, 3d),
                new Vector(8d, 2d),
                new Vector(9d, 0d))
        );
        this.distanceFunction = (a, b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1));
    }

    /**
     * With k = 3 the query point (8, 3.5) is expected to be classified as NEGATIVE.
     *
     * <p>Under the distance function above its three closest training points are
     * (8, 4) positive at 0.5, and (7, 3) and (8, 2) negative at 1.5 each.
     */
    @Test
    public void shouldReturnCorrectClassForVectorWithKEquals3()
    {
        var kNearestNeighbour = new KNearestNeighbour(this.distanceFunction, 3);
        kNearestNeighbour.learn(this.positives, this.negatives);
        var vector = new Vector(8, 3.5);
        var actualClass = kNearestNeighbour.classify(vector);
        var expectedClass = DataClass.NEGATIVE;
        assertEquals(expectedClass, actualClass);
    }

    /**
     * With k = 5 the same query point (8, 3.5) is expected to be classified as POSITIVE.
     *
     * <p>All three remaining positives lie at distance 2.5 and the remaining negatives at 4.5,
     * so whichever two are added to the k = 3 neighbourhood, positives outnumber negatives 3 to 2.
     */
    @Test
    public void shouldReturnCorrectClassForVectorWithKEquals5()
    {
        var kNearestNeighbour = new KNearestNeighbour(this.distanceFunction, 5);
        kNearestNeighbour.learn(this.positives, this.negatives);
        var vector = new Vector(8, 3.5);
        var actualClass = kNearestNeighbour.classify(vector);
        var expectedClass = DataClass.POSITIVE;
        assertEquals(expectedClass, actualClass);
    }

    /**
     * Trains on the appendicitis training file, classifies every vector of the test file and
     * reports how many were misclassified.
     *
     * <p>The last component of each data vector is read as the class label; it is removed and
     * the remaining vector is normalized before learning and before classification.
     *
     * <p>The misclassification rate is only printed, not asserted, because no accepted
     * threshold is defined here. The test does fail when either data file yields no vectors,
     * so it can no longer pass without having classified anything.
     */
    @Test
    public void shouldReturnCorrectClassesForAppendicitisData()
    {
        var trainDataVectors = readFromFile(TRAIN_DATA_FILE);
        assertFalse(trainDataVectors.isEmpty(), "No training vectors read from " + TRAIN_DATA_FILE);

        var dataClasses = splitIntoClasses(trainDataVectors);
        var negatives = dataClasses.get(DataClass.NEGATIVE);
        var positives = dataClasses.get(DataClass.POSITIVE);

        // Single-argument constructor: the value of k is whatever KNearestNeighbour defaults to.
        var kNearestNeighbour = new KNearestNeighbour(Vector::distance);
        kNearestNeighbour.learn(positives, negatives);

        var testDataVectors = readFromFile(TEST_DATA_FILE);
        assertFalse(testDataVectors.isEmpty(), "No test vectors read from " + TEST_DATA_FILE);

        var failCount = 0;
        for (var vector : testDataVectors)
        {
            var expectedClass = DataClass.valueOf(Double.valueOf(vector.get(vector.dimension() - 1)).intValue());
            var testVector = vector.decreasedDimension();
            var actualClass = kNearestNeighbour.classify(testVector.normalized());
            // Count mismatches directly instead of catching assertion errors.
            if (!Objects.equals(expectedClass, actualClass))
            {
                failCount++;
            }
        }
        System.out.println(failCount + " of " + testDataVectors.size() + " are not correct classified.");
        System.out.println("Fail rate of " + Math.round(100d * failCount / testDataVectors.size()) + " %");
    }

    /**
     * Runs {@link CrossValidation} over the union of the training and test files.
     * Disabled by default because of its run time.
     */
    @Disabled("Takes long")
    @Test
    public void shouldReturnOptimum()
    {
        var trainDataVectors = readFromFile(TRAIN_DATA_FILE);
        var testDataVectors = readFromFile(TEST_DATA_FILE);
        var data = Stream.concat(trainDataVectors.stream(), testDataVectors.stream())
                .collect(Collectors.toList());
        assertFalse(data.isEmpty(), "No vectors read from the data files");

        var crossValidation = new CrossValidation(1, 100);
        var kNearestNeighbour = crossValidation.validate(data, data.size());
        // NOTE(review): assumes validate(...) never legitimately returns null - confirm.
        assertNotNull(kNearestNeighbour, "Cross validation returned no result");
    }

    /**
     * Reads a data file in which every non-blank line is a comma-separated list of numbers
     * and converts each such line into one {@link Vector}.
     *
     * @param file path of the file to read
     * @return the vectors in file order; empty when the file contains no non-blank line
     * @throws UncheckedIOException if the file cannot be opened or read, so that a missing
     *                              resource fails the calling test instead of producing an
     *                              empty data set
     * @throws NumberFormatException if a component cannot be parsed as a double
     */
    private List<Vector> readFromFile(String file)
    {
        List<Vector> vectorList = new ArrayList<>();
        try (var reader = new BufferedReader(new FileReader(file)))
        {
            String line;
            while ((line = reader.readLine()) != null)
            {
                // A blank line would split into a single empty token and fail to parse.
                if (line.trim().isEmpty())
                {
                    continue;
                }
                vectorList.add(new Vector(
                        Arrays.stream(line.split(","))
                                .map(Double::valueOf)
                                .collect(Collectors.toList())
                ));
            }
        }
        catch (IOException e)
        {
            throw new UncheckedIOException("Could not read data file " + file, e);
        }
        return vectorList;
    }

    /**
     * Partitions labelled vectors by their last component: 1 is treated as positive and 0 as
     * negative. Vectors carrying any other label are dropped.
     *
     * @param data vectors whose last component is the class label
     * @return an immutable map holding, per class, the vectors after
     *         {@code decreasedDimension()} and {@code normalized()} have been applied
     */
    private Map<DataClass, List<Vector>> splitIntoClasses(List<Vector> data)
    {
        return Map.ofEntries(
                Map.entry(DataClass.NEGATIVE, featuresWithLabel(data, 0)),
                Map.entry(DataClass.POSITIVE, featuresWithLabel(data, 1)));
    }

    /** Selects the vectors whose last component equals {@code label} and prepares them for learning. */
    private List<Vector> featuresWithLabel(List<Vector> data, int label)
    {
        return data.stream()
                .filter(v -> v.get(v.dimension() - 1) == label)
                .map(Vector::decreasedDimension)
                .map(Vector::normalized)
                .collect(Collectors.toList());
    }
}