package machine_learning.nearest_neighbour;

import machine_learning.DataClass;
import machine_learning.Vector;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.opentest4j.AssertionFailedError;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.*;

/**
 * Tests for {@link KNearestNeighbour}.
 *
 * <p>Two tests use a small hand-made two-dimensional data set; the others read
 * comma-separated vectors from the {@code app1} data files, whose last column is
 * treated as the class label ({@code 1} = positive, {@code 0} = negative).
 */
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class KNearestNeighbourTest {

    /** Training data; each line is one comma-separated vector with the label in the last column. */
    private static final String TRAIN_DATA_FILE = "./resources/app1.data";

    /** Test data in the same format as {@link #TRAIN_DATA_FILE}. */
    private static final String TEST_DATA_FILE = "./resources/app1.test";

    private List<Vector> positives;
    private List<Vector> negatives;
    private DistanceFunction distanceFunction;

    /**
     * Builds the hand-made training set and the distance function shared by the
     * small-data tests. The distance is the Manhattan (L1) distance over the first
     * two components of the vectors.
     */
    @BeforeAll
    void initLearnData() {
        this.positives = new ArrayList<>(List.of(
                new Vector(8d, 4d),
                new Vector(8d, 6d),
                new Vector(9d, 2d),
                new Vector(9d, 5d))
        );
        this.negatives = new ArrayList<>(List.of(
                new Vector(6d, 1d),
                new Vector(7d, 3d),
                new Vector(8d, 2d),
                new Vector(9d, 0d))
        );
        this.distanceFunction = (a, b) ->
                Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1));
    }

    /**
     * With k = 3 the nearest training points to (8, 3.5) are one positive
     * (distance 0.5) and two negatives (distance 1.5 each), so NEGATIVE is expected.
     */
    @Test
    public void shouldReturnCorrectClassForVectorWithKEquals3() {
        var kNearestNeighbour = new KNearestNeighbour(this.distanceFunction, 3);
        kNearestNeighbour.learn(this.positives, this.negatives);

        var vector = new Vector(8, 3.5);
        var actualClass = kNearestNeighbour.classify(vector);
        var expectedClass = DataClass.NEGATIVE;

        assertEquals(expectedClass, actualClass);
    }

    /**
     * With k = 5 the remaining candidates after the three nearest points are all
     * positives (distance 2.5), so the vote becomes three positives against two
     * negatives and POSITIVE is expected.
     */
    @Test
    public void shouldReturnCorrectClassForVectorWithKEquals5() {
        var kNearestNeighbour = new KNearestNeighbour(this.distanceFunction, 5);
        kNearestNeighbour.learn(this.positives, this.negatives);

        var vector = new Vector(8, 3.5);
        var actualClass = kNearestNeighbour.classify(vector);
        var expectedClass = DataClass.POSITIVE;

        assertEquals(expectedClass, actualClass);
    }

    /**
     * Trains on {@link #TRAIN_DATA_FILE}, classifies every vector of
     * {@link #TEST_DATA_FILE} and prints how many were misclassified.
     *
     * <p>The misclassification rate is only reported, not asserted; the test fails
     * only if the data cannot be read or the test set is empty.
     */
    @Test
    public void shouldReturnCorrectClassesForAppendicitisData() {
        var trainDataVectors = readFromFile(TRAIN_DATA_FILE);
        var dataClasses = splitIntoClasses(trainDataVectors);
        var negatives = dataClasses.get(DataClass.NEGATIVE);
        var positives = dataClasses.get(DataClass.POSITIVE);

        var kNearestNeighbour = new KNearestNeighbour(Vector::distance);
        kNearestNeighbour.learn(positives, negatives);

        var testDataVectors = readFromFile(TEST_DATA_FILE);
        // An empty test set would otherwise report a meaningless "0 of 0" result.
        assertFalse(testDataVectors.isEmpty(), "No test vectors read from " + TEST_DATA_FILE);

        var failCount = 0;
        for (var vector : testDataVectors) {
            // The last component carries the label; it is converted to an int for
            // DataClass.valueOf (presumably 0/1 -> NEGATIVE/POSITIVE; see DataClass).
            var expectedClass = DataClass.valueOf(
                    Double.valueOf(vector.get(vector.dimension() - 1)).intValue());
            var testVector = vector.decreasedDimension();
            var actualClass = kNearestNeighbour.classify(testVector.normalized());

            // Plain comparison instead of catching an assertion error per vector.
            if (!Objects.equals(expectedClass, actualClass)) {
                failCount++;
            }
        }

        System.out.println(failCount + " of " + testDataVectors.size() + " are not correct classified.");
        System.out.println("Fail rate of " + Math.round(100d * failCount / testDataVectors.size()) + " %");
    }

    /**
     * Runs {@link CrossValidation} over the union of training and test data.
     *
     * <p>NOTE(review): the meaning of the constructor arguments {@code (1, 100)} and
     * of passing {@code data.size()} to {@code validate} is defined by
     * {@code CrossValidation}, which is not visible here - confirm there.
     */
    @Disabled("Takes long")
    @Test
    public void shouldReturnOptimum() {
        var trainDataVectors = readFromFile(TRAIN_DATA_FILE);
        var testDataVectors = readFromFile(TEST_DATA_FILE);
        var data = Stream.concat(trainDataVectors.stream(), testDataVectors.stream())
                .collect(Collectors.toList());

        var crossValidation = new CrossValidation(1, 100);
        var kNearestNeighbour = crossValidation.validate(data, data.size());

        assertNotNull(kNearestNeighbour);
    }

    /**
     * Reads one vector per non-blank line; each line is a comma-separated list of
     * numbers parsed with {@link Double#valueOf(String)}.
     *
     * @param file path of the file to read
     * @return the vectors in file order
     * @throws UncheckedIOException if the file cannot be opened or read
     */
    private List<Vector> readFromFile(String file) {
        List<Vector> vectorList = new ArrayList<>();
        try (var reader = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines instead of failing on Double.valueOf("")
                }
                vectorList.add(new Vector(
                        Arrays.stream(line.split(","))
                                .map(Double::valueOf)
                                .collect(Collectors.toList())
                ));
            }
        } catch (IOException e) {
            // Fail loudly: an empty result would let data-driven tests pass vacuously.
            throw new UncheckedIOException("Could not read vectors from " + file, e);
        }
        return vectorList;
    }

    /**
     * Partitions labelled vectors by their last component ({@code 1} = positive,
     * {@code 0} = negative). The label column is dropped via
     * {@code decreasedDimension()} and {@code normalized()} is applied to each
     * remaining vector. Vectors with any other label value are ignored.
     *
     * @param data labelled vectors with the label in the last component
     * @return an immutable map from {@link DataClass} to its prepared vectors
     */
    private Map<DataClass, List<Vector>> splitIntoClasses(List<Vector> data) {
        var positives = data.stream()
                .filter(v -> v.get(v.dimension() - 1) == 1)
                .map(Vector::decreasedDimension)
                .map(Vector::normalized)
                .collect(Collectors.toList());
        var negatives = data.stream()
                .filter(v -> v.get(v.dimension() - 1) == 0)
                .map(Vector::decreasedDimension)
                .map(Vector::normalized)
                .collect(Collectors.toList());

        return Map.ofEntries(
                Map.entry(DataClass.NEGATIVE, negatives),
                Map.entry(DataClass.POSITIVE, positives));
    }
}