Cross-validation added
This commit is contained in:
12
src/machine_learning/DataClass.java
Normal file
12
src/machine_learning/DataClass.java
Normal file
@ -0,0 +1,12 @@
|
||||
package machine_learning;
|
||||
|
||||
public enum DataClass
{
    NEGATIVE,
    POSITIVE;

    /**
     * Maps a numeric class label to its enum constant.
     *
     * @param i the label read from a data vector
     * @return {@code NEGATIVE} for 0; every other value maps to {@code POSITIVE}
     */
    public static DataClass valueOf(int i)
    {
        if (i == 0)
        {
            return NEGATIVE;
        }
        return POSITIVE;
    }
}
|
9
src/machine_learning/MachineLearning.java
Normal file
9
src/machine_learning/MachineLearning.java
Normal file
@ -0,0 +1,9 @@
|
||||
package machine_learning;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
 * A two-class learner: it is trained from positive and negative example
 * vectors and afterwards assigns one of the two classes to a query vector.
 */
public interface MachineLearning
{
    /**
     * Trains the model from example vectors of both classes.
     *
     * @param positives examples belonging to {@code DataClass.POSITIVE}
     * @param negatives examples belonging to {@code DataClass.NEGATIVE}
     */
    void learn(List<Vector> positives, List<Vector> negatives);

    /**
     * Assigns a class to the given vector; implementations presumably require
     * {@link #learn(List, List)} to have been called first — confirm per implementation.
     *
     * @param toClassify the vector to classify
     * @return the predicted class
     */
    DataClass classify(Vector toClassify);
}
|
@ -95,6 +95,16 @@ public class Vector
|
||||
return this.values.get(index);
|
||||
}
|
||||
|
||||
public Vector decreasedDimension()
|
||||
{
|
||||
return new Vector(this.values.subList(0, this.dimension()-1));
|
||||
}
|
||||
|
||||
public Vector normalized()
|
||||
{
|
||||
return this.divide(this.euclid());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o)
|
||||
{
|
||||
|
106
src/machine_learning/nearest_neighbour/CrossValidation.java
Normal file
106
src/machine_learning/nearest_neighbour/CrossValidation.java
Normal file
@ -0,0 +1,106 @@
|
||||
package machine_learning.nearest_neighbour;
|
||||
|
||||
import machine_learning.DataClass;
|
||||
import machine_learning.MachineLearning;
|
||||
import machine_learning.Vector;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
public class CrossValidation
|
||||
{
|
||||
private int paramMin;
|
||||
private int paramMax;
|
||||
|
||||
public CrossValidation(int paramMin, int paramMax)
|
||||
{
|
||||
this.paramMin = paramMin;
|
||||
this.paramMax = paramMax;
|
||||
}
|
||||
|
||||
public KNearestNeighbour validate(List<Vector> data, int chunkSize)
|
||||
{
|
||||
Collections.shuffle(data);
|
||||
var counter = new AtomicInteger(0);
|
||||
var chunks = data.stream()
|
||||
.collect(Collectors.groupingBy(v -> counter.getAndIncrement() / (data.size() / chunkSize)))
|
||||
.values();
|
||||
|
||||
var averageFailRates = new HashMap<Double, Integer>();
|
||||
|
||||
IntStream.range(paramMin, paramMax).forEach(i -> {
|
||||
|
||||
var failRate = new AtomicReference<>(0d);
|
||||
|
||||
chunks.forEach(chunk -> {
|
||||
var dataWithoutChunk = data.parallelStream()
|
||||
.filter(v -> !chunk.contains(v))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
var mapOfClasses = splitIntoClasses(dataWithoutChunk);
|
||||
var negatives = mapOfClasses.get(DataClass.NEGATIVE);
|
||||
var positives = mapOfClasses.get(DataClass.POSITIVE);
|
||||
|
||||
var kNearestNeighbour = new KNearestNeighbour(Vector::distance, i);
|
||||
kNearestNeighbour.learn(positives, negatives);
|
||||
|
||||
var failCount = 0;
|
||||
|
||||
for (var vector : chunk)
|
||||
{
|
||||
var expectedClass = DataClass.valueOf(Double.valueOf(vector.get(vector.dimension() - 1)).intValue());
|
||||
|
||||
var testVector = vector.decreasedDimension().normalized();
|
||||
var actualClass = kNearestNeighbour.classify(testVector);
|
||||
|
||||
if (expectedClass != actualClass)
|
||||
{
|
||||
failCount++;
|
||||
}
|
||||
}
|
||||
|
||||
failRate.set(failRate.get() + failCount * 1d / chunk.size());
|
||||
});
|
||||
|
||||
averageFailRates.put(failRate.get() / chunkSize, i);
|
||||
});
|
||||
|
||||
var optimalParam = averageFailRates.get(averageFailRates.keySet().stream().min(Double::compareTo).get());
|
||||
var finalKNearestNeighbour = new KNearestNeighbour(Vector::distance, optimalParam);
|
||||
|
||||
System.out.println("Optimaler Parameter k = " + optimalParam + " mit Fehlerrate " + averageFailRates.keySet().stream().min(Double::compareTo).get()*100 + " %");
|
||||
|
||||
var classes = splitIntoClasses(data);
|
||||
var negatives = classes.get(DataClass.NEGATIVE);
|
||||
var positives = classes.get(DataClass.POSITIVE);
|
||||
finalKNearestNeighbour.learn(positives, negatives);
|
||||
|
||||
return finalKNearestNeighbour;
|
||||
}
|
||||
|
||||
private Map<DataClass, List<Vector>> splitIntoClasses(List<Vector> data)
|
||||
{
|
||||
var positives = data.parallelStream()
|
||||
.filter(v -> v.get(v.dimension()-1) == 1)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
var negatives = data.parallelStream()
|
||||
.filter(v -> v.get(v.dimension()-1) == 0)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
positives = positives.parallelStream()
|
||||
.map(Vector::decreasedDimension)
|
||||
.map(Vector::normalized)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
negatives = negatives.parallelStream()
|
||||
.map(Vector::decreasedDimension)
|
||||
.map(Vector::normalized)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
return Map.ofEntries(Map.entry(DataClass.NEGATIVE, negatives), Map.entry(DataClass.POSITIVE, positives));
|
||||
}
|
||||
}
|
@ -1,7 +0,0 @@
|
||||
package machine_learning.nearest_neighbour;
|
||||
|
||||
public enum DataClass
|
||||
{
|
||||
POSITIVE,
|
||||
NEGATIVE
|
||||
}
|
@ -1,5 +1,7 @@
|
||||
package machine_learning.nearest_neighbour;
|
||||
|
||||
import machine_learning.DataClass;
|
||||
import machine_learning.MachineLearning;
|
||||
import machine_learning.Vector;
|
||||
|
||||
import java.util.List;
|
||||
@ -8,8 +10,11 @@ import java.util.Random;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class KNearestNeighbour
|
||||
public class KNearestNeighbour implements MachineLearning
|
||||
{
|
||||
private List<Vector> positives;
|
||||
private List<Vector> negatives;
|
||||
|
||||
private Distance distance;
|
||||
|
||||
private int k;
|
||||
@ -25,20 +30,26 @@ public class KNearestNeighbour
|
||||
this.k = k;
|
||||
}
|
||||
|
||||
public DataClass kNearestNeighbour(List<Vector> positives, List<Vector> negatives, Vector toClassify)
|
||||
public void learn(List<Vector> positives, List<Vector> negatives)
|
||||
{
|
||||
this.positives = positives;
|
||||
this.negatives = negatives;
|
||||
}
|
||||
|
||||
public DataClass classify(Vector toClassify)
|
||||
{
|
||||
var nearestNeighbours = this.nearestNeighbours(
|
||||
Stream.concat(positives.stream(), negatives.stream())
|
||||
Stream.concat(this.positives.stream(), this.negatives.stream())
|
||||
.collect(Collectors.toList()),
|
||||
toClassify
|
||||
);
|
||||
|
||||
var positivesWithNearestNeighboursAmount = nearestNeighbours.stream()
|
||||
.filter(positives::contains)
|
||||
.filter(this.positives::contains)
|
||||
.count();
|
||||
|
||||
var negativesWithNearestNeighboursAmount = nearestNeighbours.stream()
|
||||
.filter(negatives::contains)
|
||||
.filter(this.negatives::contains)
|
||||
.count();
|
||||
|
||||
if (positivesWithNearestNeighboursAmount > negativesWithNearestNeighboursAmount)
|
||||
@ -55,13 +66,12 @@ public class KNearestNeighbour
|
||||
|
||||
private List<Vector> nearestNeighbours(List<Vector> vectors, Vector vector)
|
||||
{
|
||||
var nearestNeighbours = vectors.stream()
|
||||
return vectors.parallelStream()
|
||||
.map(v -> Map.entry(this.distance.distance(v, vector), v))
|
||||
.sorted((e1, e2) -> e1.getKey() >= e2.getKey() ? (e1.getKey().equals(e2.getKey()) ? 0 : 1) : -1)
|
||||
.map(Map.Entry::getValue)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
return nearestNeighbours.subList(0, this.k);
|
||||
.collect(Collectors.toList())
|
||||
.subList(0, this.k);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
package machine_learning.perceptron;
|
||||
|
||||
import machine_learning.MachineLearning;
|
||||
import machine_learning.Vector;
|
||||
import machine_learning.nearest_neighbour.DataClass;
|
||||
import machine_learning.DataClass;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class Perceptron
|
||||
public class Perceptron implements MachineLearning
|
||||
{
|
||||
private Vector weight;
|
||||
|
||||
|
Reference in New Issue
Block a user