Crossvalidation added

This commit is contained in:
Niklas Birk
2019-06-30 23:59:49 +02:00
parent 757eb4dd2b
commit 6937e28cdd
12 changed files with 787 additions and 34 deletions

View File

@ -0,0 +1,12 @@
package machine_learning;
public enum DataClass
{
NEGATIVE,
POSITIVE;
public static DataClass valueOf(int i)
{
return i == 0 ? NEGATIVE : POSITIVE;
}
}

View File

@ -0,0 +1,9 @@
package machine_learning;
import java.util.List;
public interface MachineLearning
{
void learn(List<Vector> positives, List<Vector> negatives);
DataClass classify(Vector toClassify);
}

View File

@ -95,6 +95,16 @@ public class Vector
return this.values.get(index);
}
public Vector decreasedDimension()
{
return new Vector(this.values.subList(0, this.dimension()-1));
}
public Vector normalized()
{
return this.divide(this.euclid());
}
@Override
public boolean equals(Object o)
{

View File

@ -0,0 +1,106 @@
package machine_learning.nearest_neighbour;
import machine_learning.DataClass;
import machine_learning.MachineLearning;
import machine_learning.Vector;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class CrossValidation
{
private int paramMin;
private int paramMax;
public CrossValidation(int paramMin, int paramMax)
{
this.paramMin = paramMin;
this.paramMax = paramMax;
}
public KNearestNeighbour validate(List<Vector> data, int chunkSize)
{
Collections.shuffle(data);
var counter = new AtomicInteger(0);
var chunks = data.stream()
.collect(Collectors.groupingBy(v -> counter.getAndIncrement() / (data.size() / chunkSize)))
.values();
var averageFailRates = new HashMap<Double, Integer>();
IntStream.range(paramMin, paramMax).forEach(i -> {
var failRate = new AtomicReference<>(0d);
chunks.forEach(chunk -> {
var dataWithoutChunk = data.parallelStream()
.filter(v -> !chunk.contains(v))
.collect(Collectors.toList());
var mapOfClasses = splitIntoClasses(dataWithoutChunk);
var negatives = mapOfClasses.get(DataClass.NEGATIVE);
var positives = mapOfClasses.get(DataClass.POSITIVE);
var kNearestNeighbour = new KNearestNeighbour(Vector::distance, i);
kNearestNeighbour.learn(positives, negatives);
var failCount = 0;
for (var vector : chunk)
{
var expectedClass = DataClass.valueOf(Double.valueOf(vector.get(vector.dimension() - 1)).intValue());
var testVector = vector.decreasedDimension().normalized();
var actualClass = kNearestNeighbour.classify(testVector);
if (expectedClass != actualClass)
{
failCount++;
}
}
failRate.set(failRate.get() + failCount * 1d / chunk.size());
});
averageFailRates.put(failRate.get() / chunkSize, i);
});
var optimalParam = averageFailRates.get(averageFailRates.keySet().stream().min(Double::compareTo).get());
var finalKNearestNeighbour = new KNearestNeighbour(Vector::distance, optimalParam);
System.out.println("Optimaler Parameter k = " + optimalParam + " mit Fehlerrate " + averageFailRates.keySet().stream().min(Double::compareTo).get()*100 + " %");
var classes = splitIntoClasses(data);
var negatives = classes.get(DataClass.NEGATIVE);
var positives = classes.get(DataClass.POSITIVE);
finalKNearestNeighbour.learn(positives, negatives);
return finalKNearestNeighbour;
}
private Map<DataClass, List<Vector>> splitIntoClasses(List<Vector> data)
{
var positives = data.parallelStream()
.filter(v -> v.get(v.dimension()-1) == 1)
.collect(Collectors.toList());
var negatives = data.parallelStream()
.filter(v -> v.get(v.dimension()-1) == 0)
.collect(Collectors.toList());
positives = positives.parallelStream()
.map(Vector::decreasedDimension)
.map(Vector::normalized)
.collect(Collectors.toList());
negatives = negatives.parallelStream()
.map(Vector::decreasedDimension)
.map(Vector::normalized)
.collect(Collectors.toList());
return Map.ofEntries(Map.entry(DataClass.NEGATIVE, negatives), Map.entry(DataClass.POSITIVE, positives));
}
}

View File

@ -1,7 +0,0 @@
package machine_learning.nearest_neighbour;
public enum DataClass
{
POSITIVE,
NEGATIVE
}

View File

@ -1,5 +1,7 @@
package machine_learning.nearest_neighbour;
import machine_learning.DataClass;
import machine_learning.MachineLearning;
import machine_learning.Vector;
import java.util.List;
@ -8,8 +10,11 @@ import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class KNearestNeighbour
public class KNearestNeighbour implements MachineLearning
{
private List<Vector> positives;
private List<Vector> negatives;
private Distance distance;
private int k;
@ -25,20 +30,26 @@ public class KNearestNeighbour
this.k = k;
}
public DataClass kNearestNeighbour(List<Vector> positives, List<Vector> negatives, Vector toClassify)
public void learn(List<Vector> positives, List<Vector> negatives)
{
this.positives = positives;
this.negatives = negatives;
}
public DataClass classify(Vector toClassify)
{
var nearestNeighbours = this.nearestNeighbours(
Stream.concat(positives.stream(), negatives.stream())
Stream.concat(this.positives.stream(), this.negatives.stream())
.collect(Collectors.toList()),
toClassify
);
var positivesWithNearestNeighboursAmount = nearestNeighbours.stream()
.filter(positives::contains)
.filter(this.positives::contains)
.count();
var negativesWithNearestNeighboursAmount = nearestNeighbours.stream()
.filter(negatives::contains)
.filter(this.negatives::contains)
.count();
if (positivesWithNearestNeighboursAmount > negativesWithNearestNeighboursAmount)
@ -55,13 +66,12 @@ public class KNearestNeighbour
private List<Vector> nearestNeighbours(List<Vector> vectors, Vector vector)
{
var nearestNeighbours = vectors.stream()
return vectors.parallelStream()
.map(v -> Map.entry(this.distance.distance(v, vector), v))
.sorted((e1, e2) -> e1.getKey() >= e2.getKey() ? (e1.getKey().equals(e2.getKey()) ? 0 : 1) : -1)
.map(Map.Entry::getValue)
.collect(Collectors.toList());
return nearestNeighbours.subList(0, this.k);
.collect(Collectors.toList())
.subList(0, this.k);
}
}

View File

@ -1,11 +1,12 @@
package machine_learning.perceptron;
import machine_learning.MachineLearning;
import machine_learning.Vector;
import machine_learning.nearest_neighbour.DataClass;
import machine_learning.DataClass;
import java.util.List;
public class Perceptron
public class Perceptron implements MachineLearning
{
private Vector weight;