Implemented A* at full state

This commit is contained in:
Niklas Birk 2019-04-04 18:59:25 +02:00
parent 9e51fc8c69
commit 3032aeeca8
7 changed files with 111 additions and 48 deletions

View File

@ -17,6 +17,11 @@ public class EightPuzzleNode extends Node<int[][]>
super(value, parent);
}
private EightPuzzleNode(final int[][] value, final Node<int[][]> parent, final int heuristicCosts)
{
super(value, parent, heuristicCosts);
}
@Override
public boolean isTargetReached(final Node<int[][]> target)
{
@ -44,7 +49,8 @@ public class EightPuzzleNode extends Node<int[][]>
case LEFT -> new IntPair(x-1, y);
};
final var successor = this.swapStateField(newState, emptyPosition, posToSwap);
this.swapStateField(newState, emptyPosition, posToSwap);
final var successor = new EightPuzzleNode(newState, this, super.heuristicCosts+1);
if (!successor.valueEquals(this) && !successor.valueEquals(super.getParent()))
{
@ -121,35 +127,11 @@ public class EightPuzzleNode extends Node<int[][]>
return copy;
}
private EightPuzzleNode swapStateField(final int[][] newState, final IntPair emptyPos, final IntPair posToSwap)
private void swapStateField(final int[][] newState, final IntPair emptyPos, final IntPair posToSwap)
{
final var tmp = newState[posToSwap.getY()][posToSwap.getX()];
newState[posToSwap.getY()][posToSwap.getX()] = newState[emptyPos.getY()][emptyPos.getX()];
newState[emptyPos.getY()][emptyPos.getX()] = tmp;
return new EightPuzzleNode(newState, this);
}
private class IntPair
{
private final int x;
private final int y;
public IntPair(final int x, final int y)
{
this.x = x;
this.y = y;
}
public int getX()
{
return x;
}
public int getY()
{
return y;
}
}
private enum Direction

23
src/search/IntPair.java Normal file
View File

@ -0,0 +1,23 @@
package search;
public class IntPair
{
private final int x;
private final int y;
public IntPair(final int x, final int y)
{
this.x = x;
this.y = y;
}
public int getX()
{
return x;
}
public int getY()
{
return y;
}
}

View File

@ -1,24 +1,30 @@
package search;
import search.heuristic.Heuristic;
import java.util.List;
import java.util.Objects;
public abstract class Node<T>
{
protected final T value;
protected int heuristicCosts;
private final Node<T> parent;
private int heuristicEstimation;
protected Node(final T value)
{
this(value, null);
this(value, null, 0);
}
protected Node(final T value, final Node<T> parent)
{
this(value, parent, 0);
}
protected Node(final T value, final Node<T> parent, final int heuristicCosts)
{
this.value = Objects.requireNonNull(value);
this.parent = parent;
this.heuristicCosts = heuristicCosts;
}
public T getValue()
@ -31,6 +37,16 @@ public abstract class Node<T>
return this.parent;
}
public int getHeuristic()
{
return heuristicCosts + heuristicEstimation;
}
public void setHeuristicEstimation(final int heuristicEstimation)
{
this.heuristicEstimation = heuristicEstimation;
}
public abstract boolean isTargetReached(final Node<T> target);
public abstract List<Node<T>> generateSuccessors();
}

View File

@ -7,16 +7,17 @@ import java.util.PriorityQueue;
public class AStar<T>
{
private final Heuristic<T> heuristicFunction;
private final HeuristicEstimationFunction<T> heuristicFunction;
public AStar(final Heuristic<T> heuristicFunction)
public AStar(final HeuristicEstimationFunction<T> heuristicFunction)
{
this.heuristicFunction = heuristicFunction;
}
public Node<T> heuristicSearch(final Node<T> start, final Node<T> target)
{
final var nodes = new PriorityQueue<Node<T>>(Comparator.comparingInt(node -> heuristicFunction.heuristic(node, target)));
final var nodes = new PriorityQueue<Node<T>>(Comparator.comparingInt(Node::getHeuristic));
start.setHeuristicEstimation(this.heuristicFunction.heuristicEstimation(start, target));
nodes.add(start);
while (true)
@ -33,7 +34,14 @@ public class AStar<T>
return node;
}
nodes.addAll(node.generateSuccessors());
final var successors = node.generateSuccessors();
for (final var successor : successors)
{
successor.setHeuristicEstimation(this.heuristicFunction.heuristicEstimation(successor, target));
}
nodes.addAll(successors);
}
}
}

View File

@ -1,8 +0,0 @@
package search.heuristic;
import search.Node;
public interface Heuristic<T>
{
int heuristic(Node<T> node, Node<T> target);
}

View File

@ -0,0 +1,8 @@
package search.heuristic;
import search.Node;
public interface HeuristicEstimationFunction<T>
{
int heuristicEstimation(Node<T> node, Node<T> target);
}

View File

@ -2,20 +2,20 @@ package search.heuristic;
import org.junit.jupiter.api.Test;
import search.EightPuzzleNode;
import search.IntPair;
import search.Node;
import static org.junit.jupiter.api.Assertions.*;
import static search.SearchTestUtils.printSolution;
class AStarTest
{
@Test
void shouldReturnCorrectTargetCubekNodeHeuristik1()
void shouldReturnCorrectTargetCubekNodeHeuristic1()
{
final int[][] state = {
{3, 5, 0},
{1, 2, 6},
{4, 7, 8}
{2, 0, 4},
{6, 7, 1},
{8, 5, 3}
};
final var root = new EightPuzzleNode(state);
@ -63,7 +63,7 @@ class AStarTest
{
for (var col = 0; col < value[row].length; col++)
{
if (value[row][col] != targetValue[row][col])
if (value[row][col] != 0 && value[row][col] != targetValue[row][col])
{
counter++;
}
@ -75,6 +75,40 @@ class AStarTest
private int h2(final Node<int[][]> node, final Node<int[][]> target)
{
return 0;
final var value = node.getValue();
final var targetValue = target.getValue();
var manhattanDistance = 0;
for (int i = 1; i <= 8; i++)
{
final var targetPos = detectPositionOf(i, targetValue);
final var actualPos = detectPositionOf(i, value);
if (targetPos != null && actualPos != null)
{
final var xDistance = Math.abs(targetPos.getX() - actualPos.getX());
final var yDistance = Math.abs(targetPos.getY() - actualPos.getY());
manhattanDistance += xDistance + yDistance;
}
}
return manhattanDistance;
}
private IntPair detectPositionOf(final int i, final int[][] value)
{
for (var row = 0; row < value.length; row++)
{
for (var col = 0; col < value[row].length; col++)
{
if (value[row][col] == i)
{
return new IntPair(col, row);
}
}
}
return null;
}
}