Minor changes in names and extracted method, etc.
This commit is contained in:
@ -34,7 +34,7 @@ public class Vector
|
||||
|
||||
public Vector add(Vector b)
|
||||
{
|
||||
if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals.");
|
||||
checkEqualDimensions(b);
|
||||
return new Vector(IntStream.range(0,
|
||||
this.dimension())
|
||||
.mapToObj(i -> this.get(i) + b.get(i))
|
||||
@ -45,7 +45,7 @@ public class Vector
|
||||
|
||||
public Vector subtract(Vector b)
|
||||
{
|
||||
if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals.");
|
||||
checkEqualDimensions(b);
|
||||
return new Vector(IntStream.range(0,
|
||||
this.dimension())
|
||||
.mapToObj(i -> this.get(i) - b.get(i))
|
||||
@ -55,7 +55,7 @@ public class Vector
|
||||
|
||||
public double scalar(Vector b)
|
||||
{
|
||||
if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals.");
|
||||
checkEqualDimensions(b);
|
||||
return IntStream.range(0,
|
||||
this.dimension())
|
||||
.mapToDouble(i -> this.get(i) * b.get(i))
|
||||
@ -80,14 +80,9 @@ public class Vector
|
||||
|
||||
public Vector divide(double div)
|
||||
{
|
||||
var divided = new ArrayList<Double>();
|
||||
|
||||
for (int i = 0; i < this.dimension(); i++)
|
||||
{
|
||||
divided.add(this.values.get(i) / div);
|
||||
}
|
||||
|
||||
return new Vector(divided);
|
||||
return new Vector(IntStream.range(0, this.dimension())
|
||||
.mapToObj(i -> this.values.get(i) / div)
|
||||
.collect(Collectors.toCollection(ArrayList::new)));
|
||||
}
|
||||
|
||||
public double get(int index)
|
||||
@ -133,4 +128,10 @@ public class Vector
|
||||
{
|
||||
return this.values.toString();
|
||||
}
|
||||
|
||||
private void checkEqualDimensions(Vector b)
|
||||
{
|
||||
if (this.dimension() != b.dimension())
|
||||
throw new IllegalArgumentException("Dimensions must be equal");
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ package machine_learning.nearest_neighbour;
|
||||
|
||||
import machine_learning.Vector;
|
||||
|
||||
public interface Distance
|
||||
public interface DistanceFunction
|
||||
{
|
||||
double distance(Vector a, Vector b);
|
||||
}
|
@ -15,18 +15,18 @@ public class KNearestNeighbour implements MachineLearning
|
||||
private List<Vector> positives;
|
||||
private List<Vector> negatives;
|
||||
|
||||
private final Distance distance;
|
||||
private final DistanceFunction distanceFunction;
|
||||
|
||||
private final int k;
|
||||
|
||||
public KNearestNeighbour(Distance distance)
|
||||
public KNearestNeighbour(DistanceFunction distanceFunction)
|
||||
{
|
||||
this(distance, 1);
|
||||
this(distanceFunction, 1);
|
||||
}
|
||||
|
||||
public KNearestNeighbour(Distance distance, int k)
|
||||
public KNearestNeighbour(DistanceFunction distanceFunction, int k)
|
||||
{
|
||||
this.distance = distance;
|
||||
this.distanceFunction = distanceFunction;
|
||||
this.k = k;
|
||||
}
|
||||
|
||||
@ -67,7 +67,7 @@ public class KNearestNeighbour implements MachineLearning
|
||||
private List<Vector> nearestNeighbours(List<Vector> vectors, Vector vector)
|
||||
{
|
||||
return vectors.parallelStream()
|
||||
.map(v -> Map.entry(this.distance.distance(v, vector), v))
|
||||
.map(v -> Map.entry(this.distanceFunction.distance(v, vector), v))
|
||||
.sorted((e1, e2) -> e1.getKey() >= e2.getKey() ? (e1.getKey().equals(e2.getKey()) ? 0 : 1) : -1)
|
||||
.map(Map.Entry::getValue)
|
||||
.collect(Collectors.toList())
|
||||
|
Reference in New Issue
Block a user