/**
 * Builds a grid world from an explicit wall map and gives it deterministic
 * transition dynamics.
 *
 * @param map the wall map; the first index is the x index, the second the y,
 *            and entries equal to 1 indicate a wall
 */
public GridWorldDomain(int [][] map){
    // Install the caller's map first, then configure the transition dynamics.
    setMap(map);
    setDeterministicTransitionDynamics();
}
/**
 * Builds an empty grid world of the given dimensions with deterministic
 * transitions.
 *
 * @param width  width of the map
 * @param height height of the map
 */
public GridWorldDomain(int width, int height){
    // Record the dimensions before the map is created.
    this.width = width;
    this.height = height;

    setDeterministicTransitionDynamics();
    makeEmptyMap();
}
/**
 * Creates a reward function for a grid world of size width by height in which
 * every agent position transition initially returns initializingReward.
 * Rewards for specific agent positions may be changed afterwards with the
 * {@link #setReward(int, int, double)} method.
 *
 * @param width              the width of the grid world
 * @param height             the height of the grid world
 * @param initializingReward the reward that all agent position transitions
 *                           are initialized to return
 */
public GridWorldRewardFunction(int width, int height, double initializingReward){
    // All set-up work is delegated to initialize.
    initialize(width, height, initializingReward);
}
/**
 * Sets up the example: a 5x5 grid world generator configured with 5 location
 * types and an empty map, the domain generated from it, a state generator
 * built from the basic state, and a visualizer for the generator's map.
 */
public IRLExample(){
    // Grid world generator and the domain produced from it.
    this.gwd = new GridWorldDomain(5, 5);
    this.gwd.setNumberOfLocationTypes(5);
    this.gwd.makeEmptyMap();
    this.domain = this.gwd.generateDomain();

    // State generator seeded with the basic state.
    this.sg = new LeftSideGen(5, this.basicState());

    // Visualizer driven by the generator's wall map.
    this.v = GridWorldVisualizer.getVisualizer(this.gwd.getMap());
}
/**
 * Test fixture set-up: creates an 11x11 {@code GridWorldDomain}, applies the
 * four rooms map, calls {@code setProbSucceedTransitionDynamics(1.0)}, and
 * stores the generated domain in {@code this.domain}.
 * NOTE(review): {@code @Before} is presumably the JUnit 4 annotation, so this
 * would run before each test - confirm against the file's imports.
 */
@Before public void setup() { this.gw = new GridWorldDomain(11,11); gw.setMapToFourRooms(); gw.setProbSucceedTransitionDynamics(1.0); this.domain = gw.generateDomain(); //generate the grid world domain } public State generateState() {
/**
 * Replaces the current map with the classic Four Rooms layout used in the
 * original options work (Sutton, R.S. and Precup, D. and Singh, S., 1999).
 * The world is resized to 11 by 11 and cleared before the walls are drawn.
 */
public void setMapToFourRooms(){
    final int size = 11;
    width = size;
    height = size;
    makeEmptyMap();

    // Horizontal wall segments.
    horizontalWall(0, 0, 5);
    horizontalWall(2, 4, 5);
    horizontalWall(6, 7, 4);
    horizontalWall(9, 10, 4);

    // Vertical wall segments.
    verticalWall(0, 0, 5);
    verticalWall(2, 7, 5);
    verticalWall(9, 10, 5);
}
/**
 * Opens an episode sequence visualizer over the given path, rendering with a
 * grid world visualizer built from the generator's map.
 *
 * @param outputpath the path handed to the episode sequence visualizer
 */
public void visualize(String outputpath){
    // The visualizer object is created for its side effects; no reference is kept.
    new EpisodeSequenceVisualizer(
            GridWorldVisualizer.getVisualizer(gwdg.getMap()),
            domain,
            outputpath);
}
/**
 * Returns a state render layer for a grid world domain with the provided wall map.
 *
 * @param map the wall map matrix where 0s indicate it is clear of walls, 1s indicate
 *            a full cell wall in that cell, 2s indicate a 1D north wall, 3s indicate
 *            a 1D east wall, and 4s indicate a 1D north and east wall.
 * @return a grid world domain state render layer
 */
public static StateRenderLayer getRenderLayer(int [][] map){
    StateRenderLayer layer = new StateRenderLayer();

    // The map painter is registered first, ahead of the object painters.
    layer.addStatePainter(new MapPainter(map));

    // Per-object-class painters for locations and the agent.
    OOStatePainter objectPainters = new OOStatePainter();
    objectPainters.addObjectClassPainter(GridWorldDomain.CLASS_LOCATION, new LocationPainter(map));
    objectPainters.addObjectClassPainter(GridWorldDomain.CLASS_AGENT, new CellPainter(1, Color.gray, map));
    layer.addStatePainter(objectPainters);

    return layer;
}
/**
 * Samples a successor state. The input state is copied, a movement direction
 * is drawn by comparing a uniform random roll against the running sum of the
 * direction probabilities for the action, and the copy is moved that way.
 * If the roll is never below the running sum, direction index 0 is used.
 *
 * @param s the state in which the action is applied; it is copied, not modified here
 * @param a the action whose name selects the row of transition dynamics
 * @return the state returned by {@code move} for the sampled direction
 */
@Override
public State sample(State s, Action a) {
    State next = s.copy();

    double [] probs = transitionDynamics[actionInd(a.actionName())];
    double threshold = rand.nextDouble();

    // Walk the cumulative distribution until it exceeds the roll.
    int chosen = 0;
    double cumulative = 0.;
    int idx = 0;
    while(idx < probs.length){
        cumulative += probs[idx];
        if(threshold < cumulative){
            chosen = idx;
            break;
        }
        idx++;
    }

    int [] delta = movementDirectionFromIndex(chosen);
    return move(next, delta[0], delta[1]);
}
/**
 * Shows a value function / policy visualization over every state reachable
 * from {@code initialState}, using an 11 by 11 grid world visualization.
 *
 * @param valueFunction the value function to display
 * @param p             the policy to display
 */
public void simpleValueFunctionVis(ValueFunction valueFunction, Policy p){
    // Reachable states are computed first and passed straight to the GUI factory.
    ValueFunctionVisualizerGUI window = GridWorldDomain.getGridWorldValueFunctionVisualization(
            StateReachability.getReachableStates(initialState, domain, hashingFactory),
            11, 11, valueFunction, p);
    window.initGUI();
}
/**
 * Creates the propositional functions of this domain: one at-location
 * function over agent and location objects, followed by four wall functions
 * over the agent constructed with direction indices 0 (north), 1 (south),
 * 2 (east) and 3 (west).
 *
 * @return a fixed-size list of the five propositional functions, in that order
 */
public List<PropositionalFunction> generatePfs(){
    final String[] agentAndLocation = new String[]{CLASS_AGENT, CLASS_LOCATION};
    return Arrays.<PropositionalFunction>asList(
            new AtLocationPF(PF_AT_LOCATION, agentAndLocation),
            new WallToPF(PF_WALL_NORTH, new String[]{CLASS_AGENT}, 0),
            new WallToPF(PF_WALL_SOUTH, new String[]{CLASS_AGENT}, 1),
            new WallToPF(PF_WALL_EAST, new String[]{CLASS_AGENT}, 2),
            new WallToPF(PF_WALL_WEST, new String[]{CLASS_AGENT}, 3));
}
/**
 * Creates a terminal function with a single terminal position at the
 * specified agent x and y location.
 *
 * @param x the x location of the agent
 * @param y the y location of the agent
 */
public GridWorldTerminalFunction(int x, int y){
    IntPair initialTerminal = new IntPair(x, y);
    terminalPositions.add(initialTerminal);
}
/**
 * Creates a wall-check propositional function for one direction.
 *
 * @param name             the name of the function
 * @param parameterClasses the object class parameter types
 * @param direction        the unit distance direction from the agent to check for a wall
 *                         (0,1,2,3 corresponds to north,south,east,west).
 */
public WallToPF(String name, String[] parameterClasses, int direction) {
    super(name, parameterClasses);

    // Translate the direction index into x/y offsets via the enclosing domain.
    final int [] offsets = GridWorldDomain.this.movementDirectionFromIndex(direction);
    xdelta = offsets[0];
    ydelta = offsets[1];
}
/**
 * Returns a visualizer for a grid world domain with the provided wall map.
 * The visualizer wraps the render layer produced by {@link #getRenderLayer(int[][])}.
 *
 * @param map the wall map matrix where 0s indicate it is clear of walls, 1s indicate
 *            a full cell wall in that cell, 2s indicate a 1D north wall, 3s indicate
 *            a 1D east wall, and 4s indicate a 1D north and east wall.
 * @return a grid world domain visualizer
 */
public static Visualizer getVisualizer(int [][] map){
    return new Visualizer(getRenderLayer(map));
}
/**
 * Returns a state render layer for a grid world domain with the provided wall map.
 * The domain argument is not used; this overload simply forwards to
 * {@link #getRenderLayer(int[][])} so the two cannot drift apart.
 *
 * @param d   the domain of the grid world (ignored)
 * @param map the wall map matrix where 0s indicate it is clear of walls, 1s indicate
 *            a full cell wall in that cell, 2s indicate a 1D north wall, 3s indicate
 *            a 1D east wall, and 4s indicate a 1D north and east wall.
 * @return a grid world domain state render layer
 * @deprecated the domain object is no longer necessary; use
 *             {@link #getRenderLayer(int[][])} instead.
 */
@Deprecated
public static StateRenderLayer getRenderLayer(Domain d, int [][] map){
    // Delegate rather than duplicate the painter set-up of the one-argument overload.
    return getRenderLayer(map);
}
/**
 * Shows a titled value function / policy visualization over every state
 * reachable from the given initial state.
 *
 * @param valueFunction  the value function to display
 * @param p              the policy to display
 * @param initialState   the state from which reachable states are enumerated
 * @param domain         the domain to search; it is cast to {@code SADomain}
 * @param hashingFactory the hashing factory passed to the reachability search
 * @param title          the window title
 */
public void simpleValueFunctionVis(ValueFunction valueFunction, Policy p,
        State initialState, Domain domain, HashableStateFactory hashingFactory, String title){

    SADomain saDomain = (SADomain)domain;
    List<State> reachable = StateReachability.getReachableStates(initialState, saDomain, hashingFactory);

    ValueFunctionVisualizerGUI window =
            GridWorldDomain.getGridWorldValueFunctionVisualization(reachable, valueFunction, p);
    window.setTitle(title);
    window.initGUI();
}
/**
 * Reports whether the agent's position in the given state is one of the
 * marked terminal positions.
 *
 * @param s the state to test; it is cast to {@code GridWorldState}
 * @return true if the agent's (x, y) position is a terminal position; false otherwise
 */
@Override
public boolean isTerminal(State s) {
    GridWorldState gridState = (GridWorldState)s;
    IntPair agentCell = new IntPair(gridState.agent.x, gridState.agent.y);
    return terminalPositions.contains(agentCell);
}
/**
 * Adds the given agent position to the set of terminal positions.
 *
 * @param x the x location of the agent.
 * @param y the y location of the agent.
 */
public void markAsTerminalPosition(int x, int y){
    IntPair position = new IntPair(x, y);
    terminalPositions.add(position);
}
/**
 * Removes the given agent position from the terminal positions.
 *
 * @param x the x location of the agent.
 * @param y the y location of the agent.
 */
public void unmarkTerminalPosition(int x, int y){
    IntPair position = new IntPair(x, y);
    terminalPositions.remove(position);
}
/**
 * Checks whether an agent position is currently marked as terminal.
 *
 * @param x the x location of the agent.
 * @param y the y location of the agent.
 * @return true if the position is marked as a terminal position; false otherwise.
 */
public boolean isTerminalPosition(int x, int y){
    IntPair query = new IntPair(x, y);
    return terminalPositions.contains(query);
}