ding.utils.bfs_helper
get_vi_sequence(env, observation)
Overview
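Given an instance of the maze environment and the current observation, plan an optimal (shortest) path with the Breadth-First Search (BFS) algorithm and record the result. The search starts at the target location and expands outward one step per iteration until it reaches the agent's start position. Each newly reached cell is labeled with the action that moves it one step closer to the target.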
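The backward search described above can be sketched without an environment object. This is a minimal illustration, not the library's code: `bfs_action_map` is a hypothetical helper, and it assumes a rectangular grid given as a list of strings where `'x'` marks a wall (the wall character is taken from the source) and `n_action = 4` is the sentinel for unreached cells. The neighbour order and action labels follow `get_vi_sequence`.

```python
from typing import List, Tuple

import numpy as np

# Neighbour offsets in the same order as get_vi_sequence. A cell reached via
# offset (dx, dy) is labeled with `action`, which during track-back means
# "step by (-dx, -dy)", i.e. back toward the cell it was reached from.
NEIGHBOURS = [(0, (-1, 0)), (1, (0, -1)), (2, (1, 0)), (3, (0, 1))]


def bfs_action_map(nav_map: List[str], start: Tuple[int, int], target: Tuple[int, int],
                   n_action: int = 4) -> np.ndarray:
    """Hypothetical helper: backward BFS from `target` over a rectangular character grid."""
    rows, cols = len(nav_map), len(nav_map[0])
    # Unreached cells (and the target itself) keep the sentinel value n_action.
    action_map = np.full((rows, cols), fill_value=n_action, dtype=np.int32)
    visited = {target}
    frontier = [target]
    # One loop iteration = one BFS level; stop once the start cell has been labeled.
    while frontier and start not in visited:
        next_frontier = []
        for x, y in frontier:
            for action, (dx, dy) in NEIGHBOURS:
                nx, ny = x + dx, y + dy
                if (nx, ny) in visited:
                    continue
                if not (0 <= nx < rows and 0 <= ny < cols):
                    continue
                if nav_map[nx][ny] == 'x':  # 'x' marks a wall, as in the source
                    continue
                visited.add((nx, ny))
                action_map[nx, ny] = action
                next_frontier.append((nx, ny))
        frontier = next_frontier
    return action_map


print(bfs_action_map(["...", ".x.", "..."], start=(2, 2), target=(0, 0)).tolist())
# [[4, 3, 3], [2, 4, 2], [2, 3, 3]]
```

Because every move costs one step, the first level at which the start cell is labeled is its shortest distance to the target, which is why BFS yields an optimal path here.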
Arguments:
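- env (:obj:`Env`): The instance of the maze environment.
- observation (:obj:`np.ndarray`): The current observation. The agent's start position is read from its last channel, where the agent's cell holds the value 1.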
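The start position is recovered from the observation with `np.where` on the last channel. The sketch below mirrors those two lines of the source; `find_start` is a hypothetical helper, and the observation layout (two grid axes first, channels last, only the last channel inspected) is an assumption drawn from how the source indexes the result.

```python
import numpy as np


def find_start(observation: np.ndarray) -> tuple:
    """Hypothetical helper mirroring the first two lines of get_vi_sequence."""
    # Indices of every cell whose last-channel value equals 1; the first match is the start.
    xy = np.where(observation[..., -1] == 1)
    return int(xy[0][0]), int(xy[1][0])


# Hypothetical 4x4 observation with 3 channels; only the last channel is inspected.
obs = np.zeros((4, 4, 3), dtype=np.float32)
obs[2, 1, -1] = 1.0
print(find_start(obs))  # (2, 1)
```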
Returns:
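- output (:obj:`Tuple[np.ndarray, List]`): The BFS result. output[0] stacks the BFS action map after each iteration (cells not yet reached hold env.n_action). output[1] is a list of (state, action) pairs, one for every step from the start position until the finishing point is reached, where each state is a torch.FloatTensor built by env.process_states. output[1] is empty if the start position cannot reach the target.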
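The track-back that produces output[1] follows the last action map from the start cell until it arrives at the target. The sketch keeps only the actions, since building states requires the environment's own `process_states`; `trace_actions` is a hypothetical helper, and `final_map` is a hand-written BFS result for an assumed 3x3 grid with a wall in the centre, target (0, 0) and start (2, 2). The action-to-move table is copied from the source.

```python
from typing import List, Sequence

import numpy as np

# Move applied for each action during track-back, copied from get_vi_sequence.
MOVES = {0: (1, 0), 1: (0, 1), 2: (-1, 0), 3: (0, -1)}


def trace_actions(action_map: np.ndarray, start: Sequence[int], target: Sequence[int]) -> List[int]:
    """Hypothetical helper: follow the final BFS action map from start to target."""
    x, y = start
    actions = []
    while (x, y) != tuple(target):
        act = int(action_map[x, y])
        if act not in MOVES:
            # A sentinel (e.g. n_action) means this cell was never reached by the BFS.
            raise ValueError(f"cell {(x, y)} has no planned action")
        actions.append(act)
        dx, dy = MOVES[act]
        x, y = x + dx, y + dy
    return actions


final_map = np.array([[4, 3, 3], [2, 4, 2], [2, 3, 3]])
print(trace_actions(final_map, start=(2, 2), target=(0, 0)))  # [3, 3, 2, 2]
```

As in the source, the target cell itself contributes no entry, so the list length equals the number of steps on the path.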
Full Source Code
../ding/utils/bfs_helper.py
```python
import numpy as np
import torch
from gym import Env
from typing import Tuple, List


def get_vi_sequence(env: Env, observation: np.ndarray) -> Tuple[np.ndarray, List]:
    """
    Overview:
        Given an instance of the maze environment and the current observation, using Broad-First-Search (BFS) \
            algorithm to plan an optimal path and record the result.
    Arguments:
        - env (:obj:`Env`): The instance of the maze environment.
        - observation (:obj:`np.ndarray`): The current observation.
    Returns:
        - output (:obj:`Tuple[np.ndarray, List]`): The BFS result. ``output[0]`` contains the BFS map after each \
            iteration and ``output[1]`` contains the optimal actions before reaching the finishing point.
    """
    xy = np.where(observation[Ellipsis, -1] == 1)
    start_x, start_y = xy[0][0], xy[1][0]
    target_location = env.target_location
    nav_map = env.nav_map
    current_points = [target_location]
    chosen_actions = {target_location: 0}
    visited_points = {target_location: True}
    vi_sequence = []

    vi_map = np.full((env.size, env.size), fill_value=env.n_action, dtype=np.int32)

    found_start = False
    while current_points and not found_start:
        next_points = []
        for point_x, point_y in current_points:
            for (action, (next_point_x, next_point_y)) in [(0, (point_x - 1, point_y)), (1, (point_x, point_y - 1)),
                                                           (2, (point_x + 1, point_y)), (3, (point_x, point_y + 1))]:

                if (next_point_x, next_point_y) in visited_points:
                    continue

                if not (0 <= next_point_x < len(nav_map) and 0 <= next_point_y < len(nav_map[next_point_x])):
                    continue

                if nav_map[next_point_x][next_point_y] == 'x':
                    continue

                next_points.append((next_point_x, next_point_y))
                visited_points[(next_point_x, next_point_y)] = True
                chosen_actions[(next_point_x, next_point_y)] = action
                vi_map[next_point_x, next_point_y] = action

                if next_point_x == start_x and next_point_y == start_y:
                    found_start = True
        vi_sequence.append(vi_map.copy())
        current_points = next_points
    track_back = []
    if found_start:
        cur_x, cur_y = start_x, start_y
        while cur_x != target_location[0] or cur_y != target_location[1]:
            act = vi_sequence[-1][cur_x, cur_y]
            track_back.append((torch.FloatTensor(env.process_states([cur_x, cur_y], env.get_maze_map())), act))
            if act == 0:
                cur_x += 1
            elif act == 1:
                cur_y += 1
            elif act == 2:
                cur_x -= 1
            elif act == 3:
                cur_y -= 1

    return np.array(vi_sequence), track_back
```
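Reading the while loop above, output[0] receives one cumulative snapshot per BFS level, and the loop stops after the level in which the start cell is labeled. So when the start is reachable (and differs from the target), the length of output[0] should equal the number of steps in output[1]. The sketch below inspects such a stack; `frontier_sizes` is a hypothetical helper, and the hand-written `vi_sequence` assumes a 1x3 corridor with the target at (0, 0), the start at (0, 2) and `n_action = 4`.

```python
from typing import List

import numpy as np


def frontier_sizes(vi_sequence: np.ndarray, n_action: int) -> List[int]:
    """Hypothetical helper: number of cells newly labeled in each BFS iteration."""
    # A cell counts as labeled once its value differs from the n_action sentinel.
    labeled = (vi_sequence != n_action).reshape(len(vi_sequence), -1).sum(axis=1)
    # Snapshots are cumulative, so consecutive differences give the per-iteration growth.
    return np.diff(labeled, prepend=0).tolist()


# Hypothetical output[0] for a 1x3 corridor: target at (0, 0), start at (0, 2), n_action = 4.
vi_sequence = np.array([[[4, 3, 4]], [[4, 3, 3]]])
print(frontier_sizes(vi_sequence, n_action=4))  # [1, 1]
print(len(vi_sequence))  # 2 iterations, matching the 2-step path
```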