Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
353 changes: 136 additions & 217 deletions lzero/mcts/buffer/game_buffer.py

Large diffs are not rendered by default.

189 changes: 105 additions & 84 deletions lzero/mcts/buffer/game_buffer_unizero.py

Large diffs are not rendered by default.

54 changes: 29 additions & 25 deletions lzero/mcts/buffer/game_segment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import List, Tuple, Optional
from typing import List, Tuple

import numpy as np
from easydict import EasyDict
Expand Down Expand Up @@ -31,48 +31,33 @@ class GameSegment:
- store_search_stats
"""

def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None, task_id: Optional[int] = None) -> None:
def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None:
"""
Overview:
Init the ``GameSegment`` according to the provided arguments.
Arguments:
- action_space (:obj:`int`): action space
action_space (:obj:`int`): action space
- game_segment_length (:obj:`int`): the transition number of one ``GameSegment`` block
- task_id (:obj:`Optional[int]`): The identifier for the task, used to select the correct obs and act space in multi-task settings. Defaults to None.

"""
self.action_space = action_space
self.game_segment_length = game_segment_length
self.num_unroll_steps = config.num_unroll_steps
self.td_steps = config.td_steps
self.frame_stack_num = config.model.frame_stack_num
self.discount_factor = config.discount_factor
if not hasattr(config.model, "action_space_size_list"):
# for single-task setting or fixed action space in multi-task setting
self.action_space_size = config.model.action_space_size
self.action_space_size = config.model.action_space_size
self.gray_scale = config.gray_scale
self.transform2string = config.transform2string
self.sampled_algo = config.sampled_algo
self.gumbel_algo = config.gumbel_algo
self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder

if task_id is None:
if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
# for vector obs input, e.g. classical control and box2d environments
self.zero_obs_shape = config.model.observation_shape
elif len(config.model.observation_shape) == 3:
# image obs input, e.g. atari environments
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])
else:
if hasattr(config.model, "observation_shape_list"):
if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1:
# for vector obs input, e.g. classical control and box2d environments
self.zero_obs_shape = config.model.observation_shape_list[task_id]
elif len(config.model.observation_shape_list[task_id]) == 3:
# image obs input, e.g. atari environments
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1])
else:
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])
if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
# for vector obs input, e.g. classical control and box2d environments
self.zero_obs_shape = config.model.observation_shape
elif len(config.model.observation_shape) == 3:
# image obs input, e.g. atari environments
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])

self.obs_segment = []
self.action_segment = []
Expand All @@ -96,6 +81,12 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []

# PPO related fields
self.episode_id = None # 标记该 segment 属于哪个 episode
self.advantage_segment = [] # 用于存储 GAE advantages
self.old_log_prob_segment = [] # 用于存储收集时的 log_prob (PPO 需要)
self.return_segment = [] # 用于存储 return (PPO value training 需要)

self.reanalyze_time = 0

def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
Expand Down Expand Up @@ -320,6 +311,14 @@ def game_segment_to_array(self) -> None:
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = np.array(self.chance_segment)

# Convert PPO related fields to numpy array
if len(self.advantage_segment) > 0:
self.advantage_segment = np.array(self.advantage_segment)
if len(self.old_log_prob_segment) > 0:
self.old_log_prob_segment = np.array(self.old_log_prob_segment)
if len(self.return_segment) > 0:
self.return_segment = np.array(self.return_segment)

def reset(self, init_observations: np.ndarray) -> None:
"""
Overview:
Expand All @@ -342,6 +341,11 @@ def reset(self, init_observations: np.ndarray) -> None:
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []

# Reset PPO related fields
self.advantage_segment = []
self.old_log_prob_segment = []
self.return_segment = []

assert len(init_observations) == self.frame_stack_num

for observation in init_observations:
Expand Down
Loading