forked from aimacode/aima-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from aimacode/AIMA4e
Aima4e
- Loading branch information
Showing
20 changed files
with
1,445 additions
and
662 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
core/src/main/java/aima/core/search/basic/SearchUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package aima.core.search.basic; | ||
|
||
import aima.core.search.api.Node; | ||
import aima.core.search.api.NodeFactory; | ||
import aima.core.search.api.Problem; | ||
import aima.core.search.basic.support.BasicNodeFactory; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.List; | ||
|
||
/** | ||
* Some utility functions for the search module | ||
*/ | ||
public class SearchUtils { | ||
|
||
/** | ||
* Calculates the successors of a given node for a given problem. | ||
* | ||
* @param problem | ||
* @param parent | ||
* @param <A> | ||
* @param <S> | ||
* @return | ||
*/ | ||
public static <A, S> List<Node<A, S>> successors(Problem<A, S> problem, Node<A, S> parent) { | ||
S s = parent.state(); | ||
List<Node<A, S>> nodes = new ArrayList<>(); | ||
|
||
NodeFactory<A, S> nodeFactory = new BasicNodeFactory<>(); | ||
for (A action : | ||
problem.actions(s)) { | ||
S sPrime = problem.result(s, action); | ||
double cost = parent.pathCost() + problem.stepCost(s, action, sPrime); | ||
Node<A, S> node = nodeFactory.newChildNode(problem, parent, action); | ||
nodes.add(node); | ||
} | ||
return nodes; | ||
} | ||
|
||
/** | ||
* Calculates the depth of a node in a particular tree. | ||
* | ||
* @param node | ||
* @return | ||
*/ | ||
public static int depth(Node node) { | ||
Node temp = node; | ||
int count = 0; | ||
while (temp != null) { | ||
count++; | ||
temp = temp.parent(); | ||
} | ||
return count; | ||
} | ||
|
||
/** | ||
* Extracts the list of actions from a solution state. | ||
* | ||
* @param solution | ||
* @param <A> | ||
* @param <S> | ||
* @return | ||
*/ | ||
public static <A, S> List<A> generateActions(Node<A, S> solution) { | ||
Node<A, S> parent = solution; | ||
List<A> actions = new ArrayList<>(); | ||
while (parent.parent() != null) { | ||
actions.add(parent.action()); | ||
parent = parent.parent(); | ||
} | ||
Collections.reverse(actions); | ||
return actions; | ||
} | ||
} |
211 changes: 96 additions & 115 deletions
211
core/src/main/java/aima/core/search/basic/informed/AStarSearch.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,137 +1,118 @@ | ||
package aima.core.search.basic.informed; | ||
|
||
import java.util.Comparator; | ||
import java.util.HashSet; | ||
import java.util.List; | ||
import java.util.PriorityQueue; | ||
import java.util.Queue; | ||
import java.util.Set; | ||
import java.util.function.ToDoubleFunction; | ||
|
||
import aima.core.search.api.Node; | ||
import aima.core.search.api.NodeFactory; | ||
import aima.core.search.api.Problem; | ||
import aima.core.search.api.SearchController; | ||
import aima.core.search.api.SearchForActionsFunction; | ||
import aima.core.search.basic.SearchUtils; | ||
import aima.core.search.basic.support.BasicNodeFactory; | ||
import aima.core.search.basic.support.BasicSearchController; | ||
import aima.core.search.basic.uninformedsearch.GenericSearchInterface; | ||
|
||
import java.util.*; | ||
import java.util.function.ToDoubleFunction; | ||
|
||
/** | ||
* <pre> | ||
* function A*-SEARCH(problem) returns a solution, or failure | ||
* node ← a node with STATE = problem.INITIAL-STATE, PATH-COST=0 | ||
* frontier ← a priority queue ordered by PATH-COST + h(NODE), with node as the only element | ||
* explored ← an empty set | ||
* loop do | ||
* if EMPTY?(frontier) then return failure | ||
* node <- POP(frontier) // chooses the lowest-cost node in frontier | ||
* if problem.GOAL-TEST(node.STATE) then return SOLUTION(node) | ||
* add node.STATE to explored | ||
* for each action in problem.ACTIONS(node.STATE) do | ||
* child ← CHILD-NODE(problem, node, action) | ||
* if child.STATE is not in explored or frontier then | ||
* frontier ← INSERT(child, frontier) | ||
* else if child.STATE is in frontier with higher COST then | ||
* replace that frontier node with child | ||
* if problem's initial state is a goal then return empty path to initial state | ||
* frontier ← a priority queue ordered by f(n) = h(n) + g(n), with a node for the initial state | ||
* reached ← a table of {state: the best path that reached state}; initially empty | ||
* solution ← failure | ||
* while frontier is not empty and top(frontier) is cheaper than solution do | ||
* parent ← pop(frontier) | ||
* for child in successors(parent) do | ||
* s ← child.state | ||
* if s is not in reached or child is a cheaper path than reached[s] then | ||
* reached[s] ← child | ||
* add child to the frontier | ||
* if child is a goal and is cheaper than solution then | ||
* solution = child | ||
* return solution | ||
* </pre> | ||
* | ||
* | ||
* @author Ciaran O'Reilly | ||
* @author samagra | ||
*/ | ||
public class AStarSearch<A, S> implements SearchForActionsFunction<A, S> { | ||
// function A*-SEARCH((problem) returns a solution, or failure | ||
@Override | ||
public List<A> apply(Problem<A, S> problem) { | ||
// node <- a node with STATE = problem.INITIAL-STATE, PATH-COST=0 | ||
Node<A, S> node = newRootNode(problem.initialState(), 0); | ||
// frontier <- a priority queue ordered by PATH-COST + h(NODE), with | ||
// node as the | ||
// only element | ||
Queue<Node<A, S>> frontier = newPriorityQueueOrderedByPathCostPlusH(node); | ||
// explored <- an empty set | ||
Set<S> explored = newExploredSet(); | ||
// loop do | ||
while (true) { | ||
// if EMPTY?(frontier) then return failure | ||
if (frontier.isEmpty()) { | ||
return failure(); | ||
} | ||
// node <- POP(frontier) // chooses the lowest-cost node in frontier | ||
node = frontier.remove(); | ||
// if problem.GOAL-TEST(node.STATE) then return SOLUTION(node) | ||
if (isGoalState(node, problem)) { | ||
return solution(node); | ||
} | ||
// add node.STATE to explored | ||
explored.add(node.state()); | ||
// for each action in problem.ACTIONS(node.STATE) do | ||
for (A action : problem.actions(node.state())) { | ||
// child <- CHILD-NODE(problem, node, action) | ||
Node<A, S> child = newChildNode(problem, node, action); | ||
// if child.STATE is not in explored or frontier then | ||
if (!(explored.contains(child.state()) || containsState(frontier, child.state()))) { | ||
// frontier <- INSERT(child, frontier) | ||
frontier.add(child); | ||
} // else if child.STATE is in frontier with higher COST then | ||
else if (removedNodeFromFrontierWithSameStateAndHigherCost(child, frontier)) { | ||
// replace that frontier node with child | ||
frontier.add(child); | ||
} | ||
} | ||
} | ||
} | ||
|
||
// | ||
// Supporting Code | ||
protected ToDoubleFunction<Node<A, S>> h; | ||
protected NodeFactory<A, S> nodeFactory = new BasicNodeFactory<>(); | ||
protected SearchController<A, S> searchController = new BasicSearchController<A, S>(); | ||
|
||
public AStarSearch(ToDoubleFunction<Node<A, S>> h) { | ||
this.h = h; | ||
} | ||
|
||
public ToDoubleFunction<Node<A, S>> getHeuristicFunctionH() { | ||
return h; | ||
} | ||
|
||
public Node<A, S> newRootNode(S initialState, double pathCost) { | ||
return nodeFactory.newRootNode(initialState, pathCost); | ||
} | ||
public class AStarSearch<A, S> implements GenericSearchInterface<A, S>, SearchForActionsFunction<A, S> { | ||
|
||
public Node<A, S> newChildNode(Problem<A, S> problem, Node<A, S> node, A action) { | ||
return nodeFactory.newChildNode(problem, node, action); | ||
} | ||
// The heuristic function | ||
protected ToDoubleFunction<Node<A, S>> h; | ||
// A helper class to generate new nodes. | ||
protected NodeFactory<A, S> nodeFactory = new BasicNodeFactory<>(); | ||
|
||
public Queue<Node<A, S>> newPriorityQueueOrderedByPathCostPlusH(Node<A, S> initialNode) { | ||
Queue<Node<A, S>> frontier = new PriorityQueue<>( | ||
Comparator.comparingDouble(n -> n.pathCost() + h.applyAsDouble(n))); | ||
frontier.add(initialNode); | ||
return frontier; | ||
} | ||
// frontier ← a priority queue ordered by f(n) = h(n)+g(n), with a node for the initial state | ||
PriorityQueue<Node<A, S>> frontier = new PriorityQueue<>(new Comparator<Node<A, S>>() { | ||
@Override | ||
public int compare(Node<A, S> o1, Node<A, S> o2) { | ||
return (int) (getCostValue(o1) - getCostValue(o2)); | ||
} | ||
}); | ||
|
||
public Set<S> newExploredSet() { | ||
return new HashSet<>(); | ||
} | ||
|
||
public List<A> failure() { | ||
return searchController.failure(); | ||
} | ||
/** | ||
* The constructor that takes in the heuristics function. | ||
* | ||
* @param h | ||
*/ | ||
public AStarSearch(ToDoubleFunction<Node<A, S>> h) { | ||
this.h = h; | ||
} | ||
|
||
public List<A> solution(Node<A, S> node) { | ||
return searchController.solution(node); | ||
} | ||
@Override | ||
public Node<A, S> search(Problem<A, S> problem) { | ||
if (problem.isGoalState(problem.initialState())) { | ||
return nodeFactory.newRootNode(problem.initialState()); | ||
} | ||
frontier.clear(); | ||
frontier.add(nodeFactory.newRootNode(problem.initialState())); | ||
// reached ← a table of {state: the best path that reached state}; initially empty | ||
HashMap<S, Node<A, S>> reached = new HashMap<>(); | ||
Node<A, S> solution = null; | ||
// while frontier is not empty and top(frontier) is cheaper than solution do | ||
while (!frontier.isEmpty() && | ||
(solution == null || getCostValue(frontier.peek()) < getCostValue(solution))) { | ||
Node<A, S> parent = frontier.poll(); | ||
for (Node<A, S> child : | ||
SearchUtils.successors(problem, parent)) { | ||
S s = child.state(); | ||
// if s is not in reached or child is a cheaper path than reached[s] then | ||
if (!reached.containsKey(s) || | ||
getCostValue(child) < getCostValue(reached.get(s))) { | ||
reached.put(s, child); | ||
frontier.add(child); | ||
// if child is a goal and is cheaper than solution | ||
if (problem.isGoalState(s) && | ||
(solution == null || getCostValue(child) < getCostValue(solution))) { | ||
solution = child; | ||
} | ||
} | ||
} | ||
} | ||
return solution; | ||
} | ||
|
||
public boolean isGoalState(Node<A, S> node, Problem<A, S> problem) { | ||
return searchController.isGoalState(node, problem); | ||
} | ||
|
||
public boolean containsState(Queue<Node<A, S>> frontier, S state) { | ||
// NOTE: Not very efficient (i.e. linear in the size of the frontier) | ||
return frontier.stream().anyMatch(frontierNode -> frontierNode.state().equals(state)); | ||
} | ||
/** | ||
* Returns the list of actions that need to be taken in order to achieve the goal. | ||
* | ||
* @param problem The search problem | ||
* @return the list of actions | ||
*/ | ||
@Override | ||
public List<A> apply(Problem<A, S> problem) { | ||
Node<A, S> solution = this.search(problem); | ||
if (solution == null) | ||
return new ArrayList<>(); | ||
else | ||
return SearchUtils.generateActions(solution); | ||
} | ||
|
||
public boolean removedNodeFromFrontierWithSameStateAndHigherCost(Node<A, S> child, Queue<Node<A, S>> frontier) { | ||
// NOTE: Not very efficient (i.e. linear in the size of the frontier) | ||
return frontier.removeIf(n -> n.state().equals(child.state()) && n.pathCost() > child.pathCost()); | ||
} | ||
/** | ||
* Finds the value of f(n) = g(n)+h(n) for a node n. | ||
* | ||
* @param node The node n | ||
* @return f(n) | ||
*/ | ||
private double getCostValue(Node<A, S> node) { | ||
return node.pathCost() + h.applyAsDouble(node); | ||
} | ||
} |
Oops, something went wrong.