Minor changes in names and extracted method, etc.
This commit is contained in:
parent
d0adbf0acb
commit
24c9e0b247
@ -34,7 +34,7 @@ public class Vector
|
|||||||
|
|
||||||
public Vector add(Vector b)
|
public Vector add(Vector b)
|
||||||
{
|
{
|
||||||
if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals.");
|
checkEqualDimensions(b);
|
||||||
return new Vector(IntStream.range(0,
|
return new Vector(IntStream.range(0,
|
||||||
this.dimension())
|
this.dimension())
|
||||||
.mapToObj(i -> this.get(i) + b.get(i))
|
.mapToObj(i -> this.get(i) + b.get(i))
|
||||||
@ -45,7 +45,7 @@ public class Vector
|
|||||||
|
|
||||||
public Vector subtract(Vector b)
|
public Vector subtract(Vector b)
|
||||||
{
|
{
|
||||||
if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals.");
|
checkEqualDimensions(b);
|
||||||
return new Vector(IntStream.range(0,
|
return new Vector(IntStream.range(0,
|
||||||
this.dimension())
|
this.dimension())
|
||||||
.mapToObj(i -> this.get(i) - b.get(i))
|
.mapToObj(i -> this.get(i) - b.get(i))
|
||||||
@ -55,7 +55,7 @@ public class Vector
|
|||||||
|
|
||||||
public double scalar(Vector b)
|
public double scalar(Vector b)
|
||||||
{
|
{
|
||||||
if (this.dimension() != b.dimension()) throw new IllegalArgumentException("Dimensions must be equals.");
|
checkEqualDimensions(b);
|
||||||
return IntStream.range(0,
|
return IntStream.range(0,
|
||||||
this.dimension())
|
this.dimension())
|
||||||
.mapToDouble(i -> this.get(i) * b.get(i))
|
.mapToDouble(i -> this.get(i) * b.get(i))
|
||||||
@ -80,14 +80,9 @@ public class Vector
|
|||||||
|
|
||||||
public Vector divide(double div)
|
public Vector divide(double div)
|
||||||
{
|
{
|
||||||
var divided = new ArrayList<Double>();
|
return new Vector(IntStream.range(0, this.dimension())
|
||||||
|
.mapToObj(i -> this.values.get(i) / div)
|
||||||
for (int i = 0; i < this.dimension(); i++)
|
.collect(Collectors.toCollection(ArrayList::new)));
|
||||||
{
|
|
||||||
divided.add(this.values.get(i) / div);
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Vector(divided);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public double get(int index)
|
public double get(int index)
|
||||||
@ -133,4 +128,10 @@ public class Vector
|
|||||||
{
|
{
|
||||||
return this.values.toString();
|
return this.values.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void checkEqualDimensions(Vector b)
|
||||||
|
{
|
||||||
|
if (this.dimension() != b.dimension())
|
||||||
|
throw new IllegalArgumentException("Dimensions must be equal");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ package machine_learning.nearest_neighbour;
|
|||||||
|
|
||||||
import machine_learning.Vector;
|
import machine_learning.Vector;
|
||||||
|
|
||||||
public interface Distance
|
public interface DistanceFunction
|
||||||
{
|
{
|
||||||
double distance(Vector a, Vector b);
|
double distance(Vector a, Vector b);
|
||||||
}
|
}
|
@ -15,18 +15,18 @@ public class KNearestNeighbour implements MachineLearning
|
|||||||
private List<Vector> positives;
|
private List<Vector> positives;
|
||||||
private List<Vector> negatives;
|
private List<Vector> negatives;
|
||||||
|
|
||||||
private final Distance distance;
|
private final DistanceFunction distanceFunction;
|
||||||
|
|
||||||
private final int k;
|
private final int k;
|
||||||
|
|
||||||
public KNearestNeighbour(Distance distance)
|
public KNearestNeighbour(DistanceFunction distanceFunction)
|
||||||
{
|
{
|
||||||
this(distance, 1);
|
this(distanceFunction, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
public KNearestNeighbour(Distance distance, int k)
|
public KNearestNeighbour(DistanceFunction distanceFunction, int k)
|
||||||
{
|
{
|
||||||
this.distance = distance;
|
this.distanceFunction = distanceFunction;
|
||||||
this.k = k;
|
this.k = k;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,7 +67,7 @@ public class KNearestNeighbour implements MachineLearning
|
|||||||
private List<Vector> nearestNeighbours(List<Vector> vectors, Vector vector)
|
private List<Vector> nearestNeighbours(List<Vector> vectors, Vector vector)
|
||||||
{
|
{
|
||||||
return vectors.parallelStream()
|
return vectors.parallelStream()
|
||||||
.map(v -> Map.entry(this.distance.distance(v, vector), v))
|
.map(v -> Map.entry(this.distanceFunction.distance(v, vector), v))
|
||||||
.sorted((e1, e2) -> e1.getKey() >= e2.getKey() ? (e1.getKey().equals(e2.getKey()) ? 0 : 1) : -1)
|
.sorted((e1, e2) -> e1.getKey() >= e2.getKey() ? (e1.getKey().equals(e2.getKey()) ? 0 : 1) : -1)
|
||||||
.map(Map.Entry::getValue)
|
.map(Map.Entry::getValue)
|
||||||
.collect(Collectors.toList())
|
.collect(Collectors.toList())
|
||||||
|
@ -22,8 +22,9 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||||
class KNearestNeighbourTest
|
class KNearestNeighbourTest
|
||||||
{
|
{
|
||||||
List<Vector> positives;
|
private List<Vector> positives;
|
||||||
List<Vector> negatives;
|
private List<Vector> negatives;
|
||||||
|
private DistanceFunction distanceFunction;
|
||||||
|
|
||||||
@BeforeAll
|
@BeforeAll
|
||||||
void initLearnData()
|
void initLearnData()
|
||||||
@ -41,12 +42,14 @@ class KNearestNeighbourTest
|
|||||||
new Vector(8d, 2d),
|
new Vector(8d, 2d),
|
||||||
new Vector(9d, 0d))
|
new Vector(9d, 0d))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
this.distanceFunction = (a, b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void shouldReturnCorrectClassForVectorWithKEquals3()
|
public void shouldReturnCorrectClassForVectorWithKEquals3()
|
||||||
{
|
{
|
||||||
var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 3);
|
var kNearestNeighbour = new KNearestNeighbour(this.distanceFunction, 3);
|
||||||
kNearestNeighbour.learn(this.positives, this.negatives);
|
kNearestNeighbour.learn(this.positives, this.negatives);
|
||||||
var vector = new Vector(8, 3.5);
|
var vector = new Vector(8, 3.5);
|
||||||
|
|
||||||
@ -59,7 +62,7 @@ class KNearestNeighbourTest
|
|||||||
@Test
|
@Test
|
||||||
public void shouldReturnCorrectClassForVectorWithKEquals5()
|
public void shouldReturnCorrectClassForVectorWithKEquals5()
|
||||||
{
|
{
|
||||||
var kNearestNeighbour = new KNearestNeighbour((a ,b) -> Math.abs(a.get(0) - b.get(0)) + Math.abs(a.get(1) - b.get(1)), 5);
|
var kNearestNeighbour = new KNearestNeighbour(this.distanceFunction, 5);
|
||||||
kNearestNeighbour.learn(this.positives, this.negatives);
|
kNearestNeighbour.learn(this.positives, this.negatives);
|
||||||
var vector = new Vector(8, 3.5);
|
var vector = new Vector(8, 3.5);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user