Implemented A* at full state

This commit is contained in:
Niklas Birk 2019-04-04 18:59:25 +02:00
parent 9e51fc8c69
commit 3032aeeca8
7 changed files with 111 additions and 48 deletions

View File

@ -17,6 +17,11 @@ public class EightPuzzleNode extends Node<int[][]>
super(value, parent); super(value, parent);
} }
private EightPuzzleNode(final int[][] value, final Node<int[][]> parent, final int heuristicCosts)
{
super(value, parent, heuristicCosts);
}
@Override @Override
public boolean isTargetReached(final Node<int[][]> target) public boolean isTargetReached(final Node<int[][]> target)
{ {
@ -44,7 +49,8 @@ public class EightPuzzleNode extends Node<int[][]>
case LEFT -> new IntPair(x-1, y); case LEFT -> new IntPair(x-1, y);
}; };
final var successor = this.swapStateField(newState, emptyPosition, posToSwap); this.swapStateField(newState, emptyPosition, posToSwap);
final var successor = new EightPuzzleNode(newState, this, super.heuristicCosts+1);
if (!successor.valueEquals(this) && !successor.valueEquals(super.getParent())) if (!successor.valueEquals(this) && !successor.valueEquals(super.getParent()))
{ {
@ -121,35 +127,11 @@ public class EightPuzzleNode extends Node<int[][]>
return copy; return copy;
} }
private EightPuzzleNode swapStateField(final int[][] newState, final IntPair emptyPos, final IntPair posToSwap) private void swapStateField(final int[][] newState, final IntPair emptyPos, final IntPair posToSwap)
{ {
final var tmp = newState[posToSwap.getY()][posToSwap.getX()]; final var tmp = newState[posToSwap.getY()][posToSwap.getX()];
newState[posToSwap.getY()][posToSwap.getX()] = newState[emptyPos.getY()][emptyPos.getX()]; newState[posToSwap.getY()][posToSwap.getX()] = newState[emptyPos.getY()][emptyPos.getX()];
newState[emptyPos.getY()][emptyPos.getX()] = tmp; newState[emptyPos.getY()][emptyPos.getX()] = tmp;
return new EightPuzzleNode(newState, this);
}
private class IntPair
{
private final int x;
private final int y;
public IntPair(final int x, final int y)
{
this.x = x;
this.y = y;
}
public int getX()
{
return x;
}
public int getY()
{
return y;
}
} }
private enum Direction private enum Direction

23
src/search/IntPair.java Normal file
View File

@ -0,0 +1,23 @@
package search;
public class IntPair
{
private final int x;
private final int y;
public IntPair(final int x, final int y)
{
this.x = x;
this.y = y;
}
public int getX()
{
return x;
}
public int getY()
{
return y;
}
}

View File

@ -1,24 +1,30 @@
package search; package search;
import search.heuristic.Heuristic;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
public abstract class Node<T> public abstract class Node<T>
{ {
protected final T value; protected final T value;
protected int heuristicCosts;
private final Node<T> parent; private final Node<T> parent;
private int heuristicEstimation;
protected Node(final T value) protected Node(final T value)
{ {
this(value, null); this(value, null, 0);
} }
protected Node(final T value, final Node<T> parent) protected Node(final T value, final Node<T> parent)
{
this(value, parent, 0);
}
protected Node(final T value, final Node<T> parent, final int heuristicCosts)
{ {
this.value = Objects.requireNonNull(value); this.value = Objects.requireNonNull(value);
this.parent = parent; this.parent = parent;
this.heuristicCosts = heuristicCosts;
} }
public T getValue() public T getValue()
@ -31,6 +37,16 @@ public abstract class Node<T>
return this.parent; return this.parent;
} }
public int getHeuristic()
{
return heuristicCosts + heuristicEstimation;
}
public void setHeuristicEstimation(final int heuristicEstimation)
{
this.heuristicEstimation = heuristicEstimation;
}
public abstract boolean isTargetReached(final Node<T> target); public abstract boolean isTargetReached(final Node<T> target);
public abstract List<Node<T>> generateSuccessors(); public abstract List<Node<T>> generateSuccessors();
} }

View File

@ -7,16 +7,17 @@ import java.util.PriorityQueue;
public class AStar<T> public class AStar<T>
{ {
private final Heuristic<T> heuristicFunction; private final HeuristicEstimationFunction<T> heuristicFunction;
public AStar(final Heuristic<T> heuristicFunction) public AStar(final HeuristicEstimationFunction<T> heuristicFunction)
{ {
this.heuristicFunction = heuristicFunction; this.heuristicFunction = heuristicFunction;
} }
public Node<T> heuristicSearch(final Node<T> start, final Node<T> target) public Node<T> heuristicSearch(final Node<T> start, final Node<T> target)
{ {
final var nodes = new PriorityQueue<Node<T>>(Comparator.comparingInt(node -> heuristicFunction.heuristic(node, target))); final var nodes = new PriorityQueue<Node<T>>(Comparator.comparingInt(Node::getHeuristic));
start.setHeuristicEstimation(this.heuristicFunction.heuristicEstimation(start, target));
nodes.add(start); nodes.add(start);
while (true) while (true)
@ -33,7 +34,14 @@ public class AStar<T>
return node; return node;
} }
nodes.addAll(node.generateSuccessors()); final var successors = node.generateSuccessors();
for (final var successor : successors)
{
successor.setHeuristicEstimation(this.heuristicFunction.heuristicEstimation(successor, target));
}
nodes.addAll(successors);
} }
} }
} }

View File

@ -1,8 +0,0 @@
package search.heuristic;
import search.Node;
public interface Heuristic<T>
{
int heuristic(Node<T> node, Node<T> target);
}

View File

@ -0,0 +1,8 @@
package search.heuristic;
import search.Node;
public interface HeuristicEstimationFunction<T>
{
int heuristicEstimation(Node<T> node, Node<T> target);
}

View File

@ -2,20 +2,20 @@ package search.heuristic;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import search.EightPuzzleNode; import search.EightPuzzleNode;
import search.IntPair;
import search.Node; import search.Node;
import static org.junit.jupiter.api.Assertions.*;
import static search.SearchTestUtils.printSolution; import static search.SearchTestUtils.printSolution;
class AStarTest class AStarTest
{ {
@Test @Test
void shouldReturnCorrectTargetCubekNodeHeuristik1() void shouldReturnCorrectTargetCubekNodeHeuristic1()
{ {
final int[][] state = { final int[][] state = {
{3, 5, 0}, {2, 0, 4},
{1, 2, 6}, {6, 7, 1},
{4, 7, 8} {8, 5, 3}
}; };
final var root = new EightPuzzleNode(state); final var root = new EightPuzzleNode(state);
@ -63,7 +63,7 @@ class AStarTest
{ {
for (var col = 0; col < value[row].length; col++) for (var col = 0; col < value[row].length; col++)
{ {
if (value[row][col] != targetValue[row][col]) if (value[row][col] != 0 && value[row][col] != targetValue[row][col])
{ {
counter++; counter++;
} }
@ -75,6 +75,40 @@ class AStarTest
private int h2(final Node<int[][]> node, final Node<int[][]> target) private int h2(final Node<int[][]> node, final Node<int[][]> target)
{ {
return 0; final var value = node.getValue();
final var targetValue = target.getValue();
var manhattanDistance = 0;
for (int i = 1; i <= 8; i++)
{
final var targetPos = detectPositionOf(i, targetValue);
final var actualPos = detectPositionOf(i, value);
if (targetPos != null && actualPos != null)
{
final var xDistance = Math.abs(targetPos.getX() - actualPos.getX());
final var yDistance = Math.abs(targetPos.getY() - actualPos.getY());
manhattanDistance += xDistance + yDistance;
}
}
return manhattanDistance;
}
private IntPair detectPositionOf(final int i, final int[][] value)
{
for (var row = 0; row < value.length; row++)
{
for (var col = 0; col < value[row].length; col++)
{
if (value[row][col] == i)
{
return new IntPair(col, row);
}
}
}
return null;
} }
} }