diff --git a/src/search/EightPuzzleNode.java b/src/search/EightPuzzleNode.java index eb8d235..3581c04 100644 --- a/src/search/EightPuzzleNode.java +++ b/src/search/EightPuzzleNode.java @@ -17,6 +17,11 @@ public class EightPuzzleNode extends Node super(value, parent); } + private EightPuzzleNode(final int[][] value, final Node parent, final int heuristicCosts) + { + super(value, parent, heuristicCosts); + } + @Override public boolean isTargetReached(final Node target) { @@ -44,7 +49,8 @@ public class EightPuzzleNode extends Node case LEFT -> new IntPair(x-1, y); }; - final var successor = this.swapStateField(newState, emptyPosition, posToSwap); + this.swapStateField(newState, emptyPosition, posToSwap); + final var successor = new EightPuzzleNode(newState, this, super.heuristicCosts+1); if (!successor.valueEquals(this) && !successor.valueEquals(super.getParent())) { @@ -121,35 +127,11 @@ public class EightPuzzleNode extends Node return copy; } - private EightPuzzleNode swapStateField(final int[][] newState, final IntPair emptyPos, final IntPair posToSwap) + private void swapStateField(final int[][] newState, final IntPair emptyPos, final IntPair posToSwap) { final var tmp = newState[posToSwap.getY()][posToSwap.getX()]; newState[posToSwap.getY()][posToSwap.getX()] = newState[emptyPos.getY()][emptyPos.getX()]; newState[emptyPos.getY()][emptyPos.getX()] = tmp; - - return new EightPuzzleNode(newState, this); - } - - private class IntPair - { - private final int x; - private final int y; - - public IntPair(final int x, final int y) - { - this.x = x; - this.y = y; - } - - public int getX() - { - return x; - } - public int getY() - { - return y; - } - } private enum Direction diff --git a/src/search/IntPair.java b/src/search/IntPair.java new file mode 100644 index 0000000..1317336 --- /dev/null +++ b/src/search/IntPair.java @@ -0,0 +1,23 @@ +package search; + +public class IntPair +{ + private final int x; + private final int y; + + public IntPair(final int x, final int y) + { + this.x = x; + this.y = y; + } + + public int getX() + { + return x; + } + public int getY() + { + return y; + } + +} diff --git a/src/search/Node.java b/src/search/Node.java index 295fff1..a017c88 100644 --- a/src/search/Node.java +++ b/src/search/Node.java @@ -1,24 +1,30 @@ package search; -import search.heuristic.Heuristic; - import java.util.List; import java.util.Objects; public abstract class Node { protected final T value; + protected int heuristicCosts; private final Node parent; + private int heuristicEstimation; protected Node(final T value) { - this(value, null); + this(value, null, 0); } protected Node(final T value, final Node parent) + { + this(value, parent, 0); + } + + protected Node(final T value, final Node parent, final int heuristicCosts) { this.value = Objects.requireNonNull(value); this.parent = parent; + this.heuristicCosts = heuristicCosts; } public T getValue() @@ -31,6 +37,16 @@ public abstract class Node return this.parent; } + public int getHeuristic() + { + return heuristicCosts + heuristicEstimation; + } + + public void setHeuristicEstimation(final int heuristicEstimation) + { + this.heuristicEstimation = heuristicEstimation; + } + public abstract boolean isTargetReached(final Node target); public abstract List> generateSuccessors(); } diff --git a/src/search/heuristic/AStar.java b/src/search/heuristic/AStar.java index ec2cc4a..c39f9ad 100644 --- a/src/search/heuristic/AStar.java +++ b/src/search/heuristic/AStar.java @@ -7,16 +7,17 @@ import java.util.PriorityQueue; public class AStar { - private final Heuristic heuristicFunction; + private final HeuristicEstimationFunction heuristicFunction; - public AStar(final Heuristic heuristicFunction) + public AStar(final HeuristicEstimationFunction heuristicFunction) { this.heuristicFunction = heuristicFunction; } public Node heuristicSearch(final Node start, final Node target) { - final var nodes = new PriorityQueue>(Comparator.comparingInt(node -> heuristicFunction.heuristic(node, target))); + final var nodes = new PriorityQueue>(Comparator.comparingInt(Node::getHeuristic)); + start.setHeuristicEstimation(this.heuristicFunction.heuristicEstimation(start, target)); nodes.add(start); while (true) @@ -33,7 +34,14 @@ public class AStar return node; } - nodes.addAll(node.generateSuccessors()); + final var successors = node.generateSuccessors(); + + for (final var successor : successors) + { + successor.setHeuristicEstimation(this.heuristicFunction.heuristicEstimation(successor, target)); + } + + nodes.addAll(successors); } } } diff --git a/src/search/heuristic/Heuristic.java b/src/search/heuristic/Heuristic.java deleted file mode 100644 index 8d040e1..0000000 --- a/src/search/heuristic/Heuristic.java +++ /dev/null @@ -1,8 +0,0 @@ -package search.heuristic; - -import search.Node; - -public interface Heuristic -{ - int heuristic(Node node, Node target); -} diff --git a/src/search/heuristic/HeuristicEstimationFunction.java b/src/search/heuristic/HeuristicEstimationFunction.java new file mode 100644 index 0000000..8a1c667 --- /dev/null +++ b/src/search/heuristic/HeuristicEstimationFunction.java @@ -0,0 +1,8 @@ +package search.heuristic; + +import search.Node; + +public interface HeuristicEstimationFunction +{ + int heuristicEstimation(Node node, Node target); +} diff --git a/test/search/heuristic/AStarTest.java b/test/search/heuristic/AStarTest.java index 70e38af..1e96bb9 100644 --- a/test/search/heuristic/AStarTest.java +++ b/test/search/heuristic/AStarTest.java @@ -2,20 +2,20 @@ package search.heuristic; import org.junit.jupiter.api.Test; import search.EightPuzzleNode; +import search.IntPair; import search.Node; -import static org.junit.jupiter.api.Assertions.*; import static search.SearchTestUtils.printSolution; class AStarTest { @Test - void shouldReturnCorrectTargetCubekNodeHeuristik1() + void shouldReturnCorrectTargetCubekNodeHeuristic1() { final int[][] state = { - {3, 5, 0}, - {1, 2, 6}, - {4, 7, 8} + {2, 0, 4}, + {6, 7, 1}, + {8, 5, 3} }; final var root = new EightPuzzleNode(state); @@ -63,7 +63,7 @@ class AStarTest { for (var col = 0; col < value[row].length; col++) { - if (value[row][col] != targetValue[row][col]) + if (value[row][col] != 0 && value[row][col] != targetValue[row][col]) { counter++; } @@ -75,6 +75,40 @@ class AStarTest private int h2(final Node node, final Node target) { - return 0; + final var value = node.getValue(); + final var targetValue = target.getValue(); + var manhattanDistance = 0; + + for (int i = 1; i <= 8; i++) + { + final var targetPos = detectPositionOf(i, targetValue); + final var actualPos = detectPositionOf(i, value); + + if (targetPos != null && actualPos != null) + { + final var xDistance = Math.abs(targetPos.getX() - actualPos.getX()); + final var yDistance = Math.abs(targetPos.getY() - actualPos.getY()); + + manhattanDistance += xDistance + yDistance; + } + } + + return manhattanDistance; + } + + private IntPair detectPositionOf(final int i, final int[][] value) + { + for (var row = 0; row < value.length; row++) + { + for (var col = 0; col < value[row].length; col++) + { + if (value[row][col] == i) + { + return new IntPair(col, row); + } + } + } + + return null; } } \ No newline at end of file