diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 7cd4308b2..6a4458a03 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -102,23 +102,22 @@ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]: """ pass - def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) -> Tuple: + def _sample_orig_data(self, batch_size: int) -> Tuple: """ Overview: - Sample original data which includes: - - game_segment_list: A list of game segments. - - pos_in_game_segment_list: Transition index in the game (relative index). - - batch_index_list: The index of the start transition of the sampled mini-batch in the replay buffer. - - weights_list: The weight concerning the priority. - - make_time: The time the batch is made (for correctly updating the replay buffer when data is deleted). + sample orig_data that contains: + game_segment_list: a list of game segments + pos_in_game_segment_list: transition index in game (relative index) + batch_index_list: the index of start transition of sampled minibatch in replay buffer + weights_list: the weight concerning the priority + make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) Arguments: - - batch_size (:obj:`int`): The size of the batch. - - print_priority_logs (:obj:`bool`): Whether to print logs related to priority statistics, defaults to False. + - batch_size (:obj:`int`): batch size + - beta: float the parameter in PER for calculating the priority """ - assert self._beta > 0, "Beta should be greater than 0" + assert self._beta > 0 num_of_transitions = self.get_num_of_transitions() - if not self._cfg.use_priority: - # If priority is not used, set all priorities to 1 + if self._cfg.use_priority is False: self.game_pos_priorities = np.ones_like(self.game_pos_priorities) # +1e-6 for numerical stability @@ -127,21 +126,20 @@ def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) # sample according to transition index batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) - - if self._cfg.reanalyze_outdated: - # Sort the batch indices if reanalyze is enabled + + if self._cfg.reanalyze_outdated is True: + # NOTE: used in reanalyze part batch_index_list.sort() - - # Calculate weights for the sampled transitions + weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) - weights_list /= weights_list.max() # Normalize weights + weights_list /= weights_list.max() game_segment_list = [] pos_in_game_segment_list = [] for idx in batch_index_list: game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] - game_segment_idx -= self.base_idx # Adjust index based on base index + game_segment_idx -= self.base_idx game_segment = self.game_segment_buffer[game_segment_idx] game_segment_list.append(game_segment) @@ -155,50 +153,22 @@ def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) # [0, game_segment_length - num_unroll_steps] to avoid padded data. 
if self._cfg.action_type == 'varied_action_space': - # For varied action space environments (e.g., board games with short game length like TicTacToe) - # We need to handle cases where game_segment_length might be smaller than num_unroll_steps + td_steps - # Strategy: progressively relax sampling constraints to accommodate short games - - # Step 1: Calculate ideal sampling upper bound - # Ideally, reserve space for both num_unroll_steps and td_steps to ensure complete trajectories - ideal_bound = self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps - - # Step 2: Handle different game length scenarios with graceful degradation - if ideal_bound > 0: - # Case A: Normal/long games - enough space for full unroll + td steps - # This is the standard case for most Atari games - sampling_upper_bound = ideal_bound - else: - # Case B: Short games - need to relax constraints - # Try to at least reserve space for unroll steps (most critical for training) - fallback_bound = self._cfg.game_segment_length - self._cfg.num_unroll_steps - - if fallback_bound > 0: - # Can still accommodate unroll steps, though td_steps might need padding - sampling_upper_bound = fallback_bound - else: - # Case C: Very short games (e.g., TicTacToe with 5-9 moves) - # Allow sampling from entire segment length, padding will be applied during unrolling - # This allows sampling from position 0 (beginning of game) when necessary - sampling_upper_bound = self._cfg.game_segment_length - - # Ensure at least 1 to avoid np.random.choice errors - if sampling_upper_bound <= 0: - sampling_upper_bound = 1 - - # Step 3: Resample position if it exceeds calculated bound - if pos_in_game_segment >= sampling_upper_bound: - pos_in_game_segment = np.random.choice(sampling_upper_bound, 1).item() - - # Step 4: Further adjust based on actual segment length (runtime check) + # For some environments (e.g., Jericho), the action space size may be different. + # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length), + # we avoid sampling from the last `num_unroll_steps` steps of the game segment. + if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item() + segment_len = len(game_segment.action_segment) if pos_in_game_segment >= segment_len - 1: - # Position exceeds actual segment, resample within valid range + # If the segment is very short (length 0 or 1), we can't randomly sample a position + # before the last one. The only safe position is 0. if segment_len > 1: - # Sample from [0, segment_len-1] to allow at least 1 step forward + # If the segment has at least 2 actions, we can safely sample from [0, len-2]. + # The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct. pos_in_game_segment = np.random.choice(segment_len - 1, 1).item() else: - # Segment has 0 or 1 actions, can only use position 0 + # If segment length is 0 or 1, the only valid/safe position is 0. pos_in_game_segment = 0 else: @@ -206,14 +176,8 @@ def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) # we can safely sample from the entire game segment range. 
if pos_in_game_segment >= self._cfg.game_segment_length: pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item() - - # Compatibility handling for both GameSegment objects and list data (for unittests) - try: - segment_len = len(game_segment.action_segment) - except (AttributeError, TypeError): - # For unittest compatibility: when game_segment is a list instead of GameSegment object - segment_len = len(game_segment) - + + segment_len = len(game_segment.action_segment) if pos_in_game_segment >= segment_len - 1: # If the segment is very short (length 0 or 1), we can't randomly sample a position # before the last one. The only safe position is 0. @@ -228,152 +192,115 @@ def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) pos_in_game_segment_list.append(pos_in_game_segment) - # make_time = [time.time() for _ in range(len(batch_index_list))] + make_time = [time.time() for _ in range(len(batch_index_list))] + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + return orig_data + + def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: + """ + Overview: + This function samples a batch of game segments for reanalysis from the replay buffer. + It uses priority sampling based on the `reanalyze_time` of each game segment, with segments + that have been reanalyzed more frequently receiving lower priority. + + The function returns a tuple containing information about the sampled game segments, + including their positions within each segment and the time the batch was created. + Arguments: + - batch_size (:obj:`int`): + The number of samples to draw in this batch. + + Returns: + - Tuple: + A tuple containing the following elements: + - game_segment_list: A list of the sampled game segments. + - pos_in_game_segment_list: A list of indices representing the position of each transition + within its corresponding game segment. + - batch_index_list: The indices of the sampled game segments in the replay buffer. + - make_time: A list of timestamps (set to `0` in this implementation) indicating when + the batch was created. + + Key Details: + 1. **Priority Sampling**: + Game segments are sampled based on a probability distribution calculated using + the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently + are less likely to be selected. + 2. **Segment Slicing**: + Each selected game segment is sampled at regular intervals determined by the + `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled + from each selected segment. + 3. **Handling Extra Samples**: + If the `batch_size` is not perfectly divisible by the number of samples per segment, + additional segments are sampled to make up the difference. + 4. **Reanalyze Time Update**: + The `reanalyze_time` attribute of each sampled game segment is incremented to reflect + that it has been selected for reanalysis again. + Raises: + - ValueError: + If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. + """ + train_sample_num = len(self.game_segment_buffer) + assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." 
+ valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) + + # Calculate the number of samples per segment + samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps + + # Make sure that the batch size can be divided by the number of samples per segment + if samples_per_segment == 0: + raise ValueError("The game segment length is too small for num_unroll_steps.") + + # Calculate the number of samples per segment + batch_size_per_segment = batch_size // samples_per_segment + + # If the batch size cannot be divided, process the remainder part + extra_samples = batch_size % samples_per_segment + + # We use the reanalyze_time in the game_segment_buffer to generate weights + reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) + + # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) + base_decay_rate = 100 + decay_rate = base_decay_rate / valid_sample_num + weights = np.exp(-decay_rate * reanalyze_times) + + # Normalize the weights to a probability distribution + probabilities = weights / np.sum(weights) + + # Sample game segments according to the probabilities + selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, + p=probabilities) + + # If there are extra samples to be allocated, randomly select some game segments and sample again + if extra_samples > 0: + extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=False, p=probabilities) + selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) + + game_segment_list = [] + pos_in_game_segment_list = [] + batch_index_list = [] + + for game_segment_idx in selected_game_segments: + game_segment_idx -= self.base_idx + game_segment = self.game_segment_buffer[game_segment_idx] + + # Update reanalyze_time only once + game_segment.reanalyze_time += 1 + + # The sampling position should be 0, 0 + num_unroll_steps, ... (integer multiples of num_unroll_steps) + for i in range(samples_per_segment): + game_segment_list.append(game_segment) + pos_in_game_segment = i * self._cfg.num_unroll_steps + if pos_in_game_segment >= len(game_segment): + pos_in_game_segment = np.random.choice(len(game_segment), 1).item() + pos_in_game_segment_list.append(pos_in_game_segment) + batch_index_list.append(game_segment_idx) # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). make_time = [0. for _ in range(len(batch_index_list))] - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) - - if print_priority_logs: - print(f"Sampled batch indices: {batch_index_list}") - print(f"Sampled priorities: {self.game_pos_priorities[batch_index_list]}") - print(f"Sampled weights: {weights_list}") - + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) return orig_data - def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: - """ - Overview: - This function samples a batch of game segments for reanalysis from the replay buffer. - It uses priority sampling based on the `reanalyze_time` of each game segment, with segments - that have been reanalyzed more frequently receiving lower priority. - - The function returns a tuple containing information about the sampled game segments, - including their positions within each segment and the time the batch was created. 
- Arguments: - - batch_size (:obj:`int`): - The number of samples to draw in this batch. - - Returns: - - Tuple: - A tuple containing the following elements: - - game_segment_list: A list of the sampled game segments. - - pos_in_game_segment_list: A list of indices representing the position of each transition - within its corresponding game segment. - - batch_index_list: The indices of the sampled game segments in the replay buffer. - - make_time: A list of timestamps (set to `0` in this implementation) indicating when - the batch was created. - - Key Details: - 1. **Priority Sampling**: - Game segments are sampled based on a probability distribution calculated using - the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently - are less likely to be selected. - 2. **Segment Slicing**: - Each selected game segment is sampled at regular intervals determined by the - `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled - from each selected segment. - 3. **Handling Extra Samples**: - If the `batch_size` is not perfectly divisible by the number of samples per segment, - additional segments are sampled to make up the difference. - 4. **Reanalyze Time Update**: - The `reanalyze_time` attribute of each sampled game segment is incremented to reflect - that it has been selected for reanalysis again. - Raises: - - ValueError: - If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. - """ - train_sample_num = len(self.game_segment_buffer) - assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." - valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) - - # Calculate the number of samples per segment - samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps - - # Make sure that the batch size can be divided by the number of samples per segment - if samples_per_segment == 0: - raise ValueError("The game segment length is too small for num_unroll_steps.") - - # Calculate the number of samples per segment - batch_size_per_segment = batch_size // samples_per_segment - - # If the batch size cannot be divided, process the remainder part - extra_samples = batch_size % samples_per_segment - - # We use the reanalyze_time in the game_segment_buffer to generate weights - reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) - - # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) - base_decay_rate = 100 - # Add a small epsilon to avoid division by zero if valid_sample_num is 0 - decay_rate = base_decay_rate / (valid_sample_num + 1e-6) - weights = np.exp(-decay_rate * reanalyze_times) - - # Normalize the weights to a probability distribution, handle case where sum is zero - sum_weights = np.sum(weights) - if sum_weights > 0: - probabilities = weights / sum_weights - else: - # If all weights are zero, use a uniform distribution - probabilities = np.ones(valid_sample_num) / valid_sample_num - - # Sample game segments according to the probabilities - # Ensure valid_sample_num is not zero before sampling - if valid_sample_num == 0: - return ([], [], [], [], []) - - selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, - p=probabilities) - - # If there are extra samples to be allocated, randomly select some game segments and sample again - if extra_samples > 0: - # We need to handle the case where we 
might sample the same segment again. - # A simple way is to allow replacement for extra samples or sample from remaining ones. - # For simplicity, let's stick to the original logic but ensure it's safe. - remaining_segments = np.setdiff1d(np.arange(valid_sample_num), selected_game_segments) - if len(remaining_segments) < extra_samples: - # If not enough unique segments left, sample with replacement from all valid segments - extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=True, p=probabilities) - else: - # Sample from the remaining unique segments - remaining_probs = probabilities[remaining_segments] - remaining_probs /= np.sum(remaining_probs) - extra_game_segments = np.random.choice(remaining_segments, extra_samples, replace=False, p=remaining_probs) - - selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) - - game_segment_list = [] - pos_in_game_segment_list = [] - batch_index_list = [] - print(f"selected_game_segments:{selected_game_segments}") - for game_segment_idx in selected_game_segments: - # ========================================================================= - # FIX: The line below is the source of the error and has been removed. - # `game_segment_idx` is already a valid physical index for `game_segment_buffer`. - # game_segment_idx -= self.base_idx - # ========================================================================= - game_segment = self.game_segment_buffer[game_segment_idx] - - # Update reanalyze_time only once - game_segment.reanalyze_time += 1 - - # The sampling position should be 0, 0 + num_unroll_steps, ... (integer multiples of num_unroll_steps) - for i in range(samples_per_segment): - game_segment_list.append(game_segment) - pos_in_game_segment = i * self._cfg.num_unroll_steps - if pos_in_game_segment >= len(game_segment): - pos_in_game_segment = np.random.choice(len(game_segment), 1).item() - pos_in_game_segment_list.append(pos_in_game_segment) - # NOTE: We should append the physical index here, as it corresponds to the sampled segment. - batch_index_list.append(game_segment_idx) - - # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). - make_time = [0. for _ in range(len(batch_index_list))] - - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) - return orig_data - def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple: """ Overview: @@ -690,8 +617,7 @@ def remove_oldest_data_to_fit(self) -> None: Overview: remove some oldest data if the replay buffer is full. 
""" - if isinstance(self._cfg.batch_size, int): - assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" + assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" nums_of_game_segments = self.get_num_of_game_segments() total_transition = self.get_num_of_transitions() if total_transition > self.replay_buffer_size: @@ -703,15 +629,8 @@ def remove_oldest_data_to_fit(self) -> None: # find the max game_segment index to keep in the buffer index = i break - if isinstance(self._cfg.batch_size, int): - if total_transition >= self._cfg.batch_size: - self._remove(index + 1) - else: - try: - if total_transition >= self._cfg.batch_size[0]: - self._remove(index + 1) - except Exception as e: - print(e) + if total_transition >= self._cfg.batch_size: + self._remove(index + 1) def _remove(self, excess_game_segment_index: List[int]) -> None: """ diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index f4652e1cf..3ac7bb715 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy -from line_profiler import line_profiler @BUFFER_REGISTRY.register('game_buffer_unizero') @@ -49,22 +48,9 @@ def __init__(self, cfg: dict): self.game_segment_game_pos_look_up = [] self.sample_type = self._cfg.sample_type # 'transition' or 'episode' - if hasattr(self._cfg, 'task_id'): - self.task_id = self._cfg.task_id - print(f"Task ID is set to {self.task_id}.") - try: - self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] - except Exception as e: - self.action_space_size = self._cfg.model.action_space_size - else: - self.task_id = None - print("No task_id found in configuration. 
Task ID is set to None.") - self.action_space_size = self._cfg.model.action_space_size - self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) - #@profile def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -95,7 +81,7 @@ def sample( # target policy batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, current_batch[1], current_batch[-1]) # current_batch[1] is batch_action batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self.action_space_size + policy_non_re_context, self._cfg.model.action_space_size ) # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies @@ -112,7 +98,6 @@ def sample( train_data = [current_batch, target_batch] return train_data - #@profile def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: """ Overview: @@ -138,6 +123,9 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: obs_list, action_list, mask_list = [], [], [] timestep_list = [] bootstrap_action_list = [] + advantage_list = [] # PPO: for storing GAE advantages + old_log_prob_list = [] # PPO: for storing old log probabilities + return_list = [] # PPO: for storing returns # prepare the inputs of a batch for i in range(batch_size): @@ -148,6 +136,9 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: self._cfg.num_unroll_steps].tolist() timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() + # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid + # mask_tmp = [1. for i in range(len(actions_tmp))] + # mask_tmp += [0. 
for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] # TODO: the child_visits after position in the segment (with padded part) may not be updated # So the corresponding position should not be used in the training @@ -186,12 +177,44 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: ] bootstrap_action_list.append(bootstrap_action_tmp) + # import pudb;pudb.set_trace() + # PPO: extract GAE advantages if available + if hasattr(game, 'advantage_segment') and len(game.advantage_segment) > 0: + # Extract advantages for the sampled positions + advantage_tmp = game.advantage_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + # Pad with zeros if not enough advantages (shouldn't happen if GAE is computed correctly) + advantage_tmp += [0.0 for _ in range(self._cfg.num_unroll_steps - len(advantage_tmp))] + else: + # If no advantage computed, fill with zeros + advantage_tmp = [0.0 for _ in range(self._cfg.num_unroll_steps)] + advantage_list.append(advantage_tmp) + + # PPO: extract old_log_prob if available + if hasattr(game, 'old_log_prob_segment') and len(game.old_log_prob_segment) > 0: + log_prob_tmp = game.old_log_prob_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + log_prob_tmp += [0.0 for _ in range(self._cfg.num_unroll_steps - len(log_prob_tmp))] + else: + log_prob_tmp = [0.0 for _ in range(self._cfg.num_unroll_steps)] + old_log_prob_list.append(log_prob_tmp) + + # PPO: extract return if available + if hasattr(game, 'return_segment') and len(game.return_segment) > 0: + return_tmp = game.return_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + return_tmp += [0.0 for _ in range(self._cfg.num_unroll_steps - len(return_tmp))] + else: + return_tmp = [0.0 for _ in range(self._cfg.num_unroll_steps)] + return_list.append(return_tmp) + # formalize the input observations obs_list = prepare_observation(obs_list, self._cfg.model.model_type) # formalize the inputs of a batch - current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] + # PPO: added advantage_list (9th), old_log_prob_list (10th), return_list (11th) + current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, advantage_list, old_log_prob_list, return_list] for i in range(len(current_batch)): current_batch[i] = np.asarray(current_batch[i]) @@ -276,6 +299,9 @@ def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]: obs_list, action_list, mask_list = [], [], [] bootstrap_action_list = [] timestep_list = [] + advantage_list = [] # PPO: for storing GAE advantages + old_log_prob_list = [] # PPO: for storing old log probabilities + return_list = [] # PPO: for storing returns # prepare the inputs of a batch for i in range(batch_size): @@ -290,6 +316,9 @@ def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]: mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() + # TODO: original buffer mask + # mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] + # mask_tmp += [0. 
for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] # pad random action actions_tmp += [ @@ -324,11 +353,39 @@ def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]: ] bootstrap_action_list.append(bootstrap_action_tmp) + # PPO: extract GAE advantages if available + if hasattr(game, 'advantage_segment') and len(game.advantage_segment) > 0: + advantage_tmp = game.advantage_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + advantage_tmp += [0.0 for _ in range(self._cfg.num_unroll_steps - len(advantage_tmp))] + else: + advantage_tmp = [0.0 for _ in range(self._cfg.num_unroll_steps)] + advantage_list.append(advantage_tmp) + + # PPO: extract old_log_prob if available + if hasattr(game, 'old_log_prob_segment') and len(game.old_log_prob_segment) > 0: + log_prob_tmp = game.old_log_prob_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + log_prob_tmp += [0.0 for _ in range(self._cfg.num_unroll_steps - len(log_prob_tmp))] + else: + log_prob_tmp = [0.0 for _ in range(self._cfg.num_unroll_steps)] + old_log_prob_list.append(log_prob_tmp) + + # PPO: extract return if available + if hasattr(game, 'return_segment') and len(game.return_segment) > 0: + return_tmp = game.return_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + return_tmp += [0.0 for _ in range(self._cfg.num_unroll_steps - len(return_tmp))] + else: + return_tmp = [0.0 for _ in range(self._cfg.num_unroll_steps)] + return_list.append(return_tmp) + # formalize the input observations obs_list = prepare_observation(obs_list, self._cfg.model.model_type) # formalize the inputs of a batch - current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] + # PPO: added advantage_list (9th), old_log_prob_list (10th), return_list (11th) + current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, advantage_list, old_log_prob_list, return_list] for i in range(len(current_batch)): current_batch[i] = np.asarray(current_batch[i]) @@ -424,11 +481,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if self._cfg.model.continuous_action_space is True: # when the action space of the environment is continuous, action_mask[:] is None. 
action_mask = [ - list(np.ones(self.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ - [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size) + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) ] else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] @@ -444,25 +501,18 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # =============== NOTE: The key difference with MuZero ================= # To obtain the target policy from MCTS guided by the recent target model # TODO: batch_obs (policy_obs_list) is at timestep t, batch_action is at timestep t - - if self.task_id is not None: - # TODO: support RoPE - # m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num - m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num - - else: - m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num]) # NOTE: :self.reanalyze_num - + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num]) # NOTE: :self.reanalyze_num # ======================================================================= - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self.value_support), - m_output.policy_logits - ] - ) + if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self.value_support), + m_output.policy_logits + ] + ) network_output.append(m_output) @@ -470,7 +520,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: reward_pool = reward_pool.squeeze().tolist() policy_logits_pool = policy_logits_pool.tolist() noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size ).astype(np.float32).tolist() for _ in range(transition_batch_size) ] if self._cfg.mcts_ctree: @@ -478,21 +528,13 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots = MCTSCtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - if self.task_id is not None: - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) - # TODO: adapt unizero multitask to timestep in rope - # MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num], task_id=self.task_id) - else: - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) + 
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) else: # python mcts_tree roots = MCTSPtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - if self.task_id is not None: - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num], task_id=self.task_id) - else: - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -503,7 +545,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: distributions = roots_distributions[policy_index] if policy_mask[policy_index] == 0: # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self.action_space_size)]) + target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) else: # NOTE: It is very important to use the latest MCTS visit count distribution. sum_visits = sum(distributions) @@ -512,7 +554,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if distributions is None: # if at some obs, the legal_action is None, add the fake target_policy target_policies.append( - list(np.ones(self.action_space_size) / self.action_space_size) + list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) ) else: if self._cfg.env_type == 'not_board_games': @@ -522,7 +564,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: target_policies.append(policy) else: # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self.action_space_size)] + policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] # to make sure target_policies have the same dimension sum_visits = sum(distributions) policy = [visit_count / sum_visits for visit_count in distributions] @@ -567,13 +609,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the bootstrapped value and target value # NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps - if self.task_id is not None: - # m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, task_id=self.task_id) - m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) - - else: - m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) - + m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) # ====================================================================== # if not in training, obtain the scalars of the value/reward @@ -660,32 +696,17 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_target_values = np.asarray(batch_target_values) return batch_rewards, batch_target_values - - def update_priority(self, train_data: List[np.ndarray], batch_priorities: np.ndarray) -> None: + + def clear(self) -> None: """ Overview: - Update the priority of training data. 
- Arguments: - - train_data (:obj:`List[np.ndarray]`): training data to be updated priority. - - batch_priorities (:obj:`np.ndarray`): priorities to update to. - NOTE: - train_data = [current_batch, target_batch] - current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] + Clear all data in the replay buffer for online learning. + This method resets the buffer to its initial empty state. """ - # TODO: NOTE: -4 is batch_index_list - indices = train_data[0][-4] - metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} - # only update the priorities for data still in replay buffer - for i in range(len(indices)): - - # Handle ValueError by using the first timestamp of the segment for comparison. - first_transition_time = metas['make_time'][i][0] - - if first_transition_time > self.clear_time: - # Handle IndexError by converting the float index to an integer before use. - idx = int(indices[i]) - prio = metas['batch_priorities'][i] - - # Now, idx is a valid integer index. - self.game_pos_priorities[idx] = prio - + self.game_segment_buffer.clear() + # game_pos_priorities might be a list or numpy array, reset to empty list + self.game_pos_priorities = [] + self.game_segment_game_pos_look_up.clear() + self.num_of_collected_episodes = 0 + self.base_idx = 0 + self.clear_time += 1 diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index 6c2cd1999..becca5d72 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -1,5 +1,5 @@ import copy -from typing import List, Tuple, Optional +from typing import List, Tuple import numpy as np from easydict import EasyDict @@ -31,15 +31,13 @@ class GameSegment: - store_search_stats """ - def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None, task_id: Optional[int] = None) -> None: + def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None: """ Overview: Init the ``GameSegment`` according to the provided arguments. Arguments: - - action_space (:obj:`int`): action space + action_space (:obj:`int`): action space - game_segment_length (:obj:`int`): the transition number of one ``GameSegment`` block - - task_id (:obj:`Optional[int]`): The identifier for the task, used to select the correct obs and act space in multi-task settings. Defaults to None. - """ self.action_space = action_space self.game_segment_length = game_segment_length @@ -47,32 +45,19 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.td_steps = config.td_steps self.frame_stack_num = config.model.frame_stack_num self.discount_factor = config.discount_factor - if not hasattr(config.model, "action_space_size_list"): - # for single-task setting or fixed action space in multi-task setting - self.action_space_size = config.model.action_space_size + self.action_space_size = config.model.action_space_size self.gray_scale = config.gray_scale self.transform2string = config.transform2string self.sampled_algo = config.sampled_algo self.gumbel_algo = config.gumbel_algo self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder - if task_id is None: - if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: - # for vector obs input, e.g. 
classical control and box2d environments
-                self.zero_obs_shape = config.model.observation_shape
-            elif len(config.model.observation_shape) == 3:
-                # image obs input, e.g. atari environments
-                self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])
-        else:
-            if hasattr(config.model, "observation_shape_list"):
-                if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1:
-                    # for vector obs input, e.g. classical control and box2d environments
-                    self.zero_obs_shape = config.model.observation_shape_list[task_id]
-                elif len(config.model.observation_shape_list[task_id]) == 3:
-                    # image obs input, e.g. atari environments
-                    self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1])
-            else:
-                self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])
+        if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
+            # for vector obs input, e.g. classical control and box2d environments
+            self.zero_obs_shape = config.model.observation_shape
+        elif len(config.model.observation_shape) == 3:
+            # image obs input, e.g. atari environments
+            self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])

         self.obs_segment = []
         self.action_segment = []
@@ -96,6 +81,12 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
         if self.use_ture_chance_label_in_chance_encoder:
             self.chance_segment = []

+        # PPO related fields
+        self.episode_id = None  # marks which episode this segment belongs to
+        self.advantage_segment = []  # used to store GAE advantages
+        self.old_log_prob_segment = []  # used to store the log_prob recorded at collection time (required by PPO)
+        self.return_segment = []  # used to store returns (required for PPO value training)
+
         self.reanalyze_time = 0

     def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
@@ -320,6 +311,14 @@ def game_segment_to_array(self) -> None:
         if self.use_ture_chance_label_in_chance_encoder:
             self.chance_segment = np.array(self.chance_segment)

+        # Convert PPO related fields to numpy array
+        if len(self.advantage_segment) > 0:
+            self.advantage_segment = np.array(self.advantage_segment)
+        if len(self.old_log_prob_segment) > 0:
+            self.old_log_prob_segment = np.array(self.old_log_prob_segment)
+        if len(self.return_segment) > 0:
+            self.return_segment = np.array(self.return_segment)
+
     def reset(self, init_observations: np.ndarray) -> None:
         """
         Overview:
@@ -342,6 +341,11 @@ def reset(self, init_observations: np.ndarray) -> None:
         if self.use_ture_chance_label_in_chance_encoder:
             self.chance_segment = []

+        # Reset PPO related fields
+        self.advantage_segment = []
+        self.old_log_prob_segment = []
+        self.return_segment = []
+
         assert len(init_observations) == self.frame_stack_num

         for observation in init_observations:
diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py
index b680a6e2d..9d57b3c5f 100644
--- a/lzero/model/unizero_model.py
+++ b/lzero/model/unizero_model.py
@@ -1,20 +1,17 @@
 from typing import Optional
+
 import torch
 import torch.nn as nn
-from ding.utils import (ENV_REGISTRY, MODEL_REGISTRY, SequenceType, get_rank,
-                        get_world_size, set_pkg_seed)
-from ditk import logging
+from ding.utils import MODEL_REGISTRY, SequenceType
 from easydict import EasyDict
-from .common import (FeatureAndGradientHook,
HFLanguageRepresentationNetwork, - LatentDecoder, LatentDecoderForMemoryEnv, - LatentEncoderForMemoryEnv, MZNetworkOutput, QwenNetwork, - RepresentationNetworkMLP, RepresentationNetworkUniZero, - VectorDecoderForMemoryEnv) +# from transformers import T5ForConditionalGeneration, T5Tokenizer + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \ + HFLanguageRepresentationNetwork, QwenNetwork from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model import WorldModel -from .vit import ViT, ViTConfig - -# from transformers import T5ForConditionalGeneration, T5Tokenizer +from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size # use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. @@ -91,13 +88,13 @@ def __init__( # TODO: only for MemoryEnv now self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) self.tokenizer = Tokenizer(encoder=self.representation_network, - decoder=self.decoder_network, with_lpips=False, obs_type=world_model_cfg.obs_type) + decoder_network=self.decoder_network, with_lpips=False) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) - logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - logging.info('==' * 20) - logging.info(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - logging.info(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - logging.info('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) elif world_model_cfg.obs_type == 'text': if kwargs['encoder_option'] == 'legacy': self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder) @@ -128,72 +125,38 @@ def __init__( self.decoder_network_tokenizer = None else: raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}") - - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, - with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option']) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, + with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option']) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) - - # --- Log parameter counts for analysis --- - self._log_model_parameters(obs_type) - - logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - logging.info('==' * 20) - logging.info(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in 
agent.world_model.transformer') - logging.info(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - logging.info('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) elif world_model_cfg.obs_type == 'image': - if world_model_cfg.encoder_type == "resnet": - self.representation_network = RepresentationNetworkUniZero( - observation_shape, - num_res_blocks, - num_channels, - self.downsample, - activation=self.activation, - norm_type=norm_type, - embedding_dim=world_model_cfg.embed_dim, - group_size=world_model_cfg.group_size, - final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, - ) - elif world_model_cfg.encoder_type == "vit": - # vit base - vit_config = ViTConfig( - image_size=observation_shape[1], - patch_size=8, - num_classes=world_model_cfg.embed_dim, - dim=768, - depth=12, - heads=12, - mlp_dim=3072, - dropout=0.1, - emb_dropout=0.1, - final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, - lora_config=world_model_cfg, - ) - self.representation_network = ViT(config=vit_config) + self.representation_network = RepresentationNetworkUniZero( + observation_shape, + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=world_model_cfg.embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder + ) # ====== for analysis ====== if world_model_cfg.analysis_sim_norm: self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - - if world_model_cfg.latent_recon_loss_weight == 0: - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=None, with_lpips=False, obs_type=world_model_cfg.obs_type) - else: - # TODO: customize LatentDecoder - self.decoder_network = LatentDecoder( - embedding_dim=world_model_cfg.embed_dim, - output_shape=[3, 64, 64], - num_channels = 64, - activation=self.activation, - ) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, with_lpips=True, obs_type=world_model_cfg.obs_type) - + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) - logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - logging.info('==' * 20) - logging.info(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - logging.info(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - logging.info('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) elif world_model_cfg.obs_type == 'image_memory': self.representation_network = 
LatentEncoderForMemoryEnv( image_shape=(3, 5, 5), @@ -218,170 +181,17 @@ def __init__( self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, obs_type=world_model_cfg.obs_type) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') + print('==' * 20) - - logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - logging.info(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') - - logging.info('==' * 20) - logging.info(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - logging.info(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - logging.info(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') - logging.info('==' * 20) - - # --- Log parameter counts for analysis --- - self._log_model_parameters(world_model_cfg.obs_type) - - - def _log_model_parameters(self, obs_type: str) -> None: - """ - Overview: - Logs detailed parameter counts for all model components with a comprehensive breakdown. - Includes encoder, transformer, prediction heads, and other components. - Arguments: - - obs_type (:obj:`str`): The type of observation ('vector', 'image', or 'image_memory'). 
- """ - from ding.utils import get_rank - - # Only print from rank 0 to avoid duplicate logs in DDP - if get_rank() != 0: - return - - logging.info('=' * 80) - logging.info('MODEL PARAMETER STATISTICS'.center(80)) - logging.info('=' * 80) - - # --- Total Model Parameters --- - total_params = sum(p.numel() for p in self.parameters()) - total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) - logging.info(f'\n{"TOTAL MODEL":<40} {total_params:>15,} parameters') - logging.info(f'{" └─ Trainable":<40} {total_trainable:>15,} parameters') - logging.info(f'{" └─ Frozen":<40} {total_params - total_trainable:>15,} parameters') - - # --- World Model Components --- - logging.info(f'\n{"-" * 80}') - logging.info(f'{"WORLD MODEL BREAKDOWN":<40}') - logging.info(f'{"-" * 80}') - - wm_params = sum(p.numel() for p in self.world_model.parameters()) - wm_trainable = sum(p.numel() for p in self.world_model.parameters() if p.requires_grad) - logging.info(f'{"World Model Total":<40} {wm_params:>15,} parameters') - logging.info(f'{" └─ Trainable":<40} {wm_trainable:>15,} parameters ({100*wm_trainable/wm_params:.1f}%)') - - # --- Encoder --- - encoder_params = sum(p.numel() for p in self.tokenizer.encoder.parameters()) - encoder_trainable = sum(p.numel() for p in self.tokenizer.encoder.parameters() if p.requires_grad) - logging.info(f'\n{"1. ENCODER (Tokenizer)":<40} {encoder_params:>15,} parameters') - logging.info(f'{" └─ Trainable":<40} {encoder_trainable:>15,} parameters ({100*encoder_trainable/encoder_params:.1f}%)') - - # --- Transformer Backbone --- - transformer_params = sum(p.numel() for p in self.world_model.transformer.parameters()) - transformer_trainable = sum(p.numel() for p in self.world_model.transformer.parameters() if p.requires_grad) - logging.info(f'\n{"2. TRANSFORMER BACKBONE":<40} {transformer_params:>15,} parameters') - logging.info(f'{" └─ Trainable":<40} {transformer_trainable:>15,} parameters ({100*transformer_trainable/transformer_params:.1f}%)') - - # --- Prediction Heads (Detailed Breakdown) --- - logging.info(f'\n{"3. 
PREDICTION HEADS":<40}') - - # Access head_dict from world_model - if hasattr(self.world_model, 'head_dict'): - head_dict = self.world_model.head_dict - - # Calculate total heads parameters - total_heads_params = sum(p.numel() for module in head_dict.values() for p in module.parameters()) - total_heads_trainable = sum(p.numel() for module in head_dict.values() for p in module.parameters() if p.requires_grad) - logging.info(f'{" Total (All Heads)":<40} {total_heads_params:>15,} parameters') - logging.info(f'{" └─ Trainable":<40} {total_heads_trainable:>15,} parameters ({100*total_heads_trainable/total_heads_params:.1f}%)') - - # Breakdown by head type - head_names_map = { - 'head_policy_multi_task': 'Policy Head', - 'head_value_multi_task': 'Value Head', - 'head_rewards_multi_task': 'Reward Head', - 'head_observations_multi_task': 'Next Latent (Obs) Head' - } - - logging.info(f'\n{" Breakdown by Head Type:":<40}') - for head_key, head_name in head_names_map.items(): - if head_key in head_dict: - head_module = head_dict[head_key] - head_params = sum(p.numel() for p in head_module.parameters()) - head_trainable = sum(p.numel() for p in head_module.parameters() if p.requires_grad) - - # Count number of task-specific heads (for ModuleList) - if isinstance(head_module, nn.ModuleList): - num_heads = len(head_module) - params_per_head = head_params // num_heads if num_heads > 0 else 0 - logging.info(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') - logging.info(f'{" └─ " + f"{num_heads} task-specific heads":<38} {params_per_head:>15,} params/head') - else: - logging.info(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') - logging.info(f'{" └─ Shared across tasks":<38}') - - # --- Positional & Task Embeddings --- - logging.info(f'\n{"4. EMBEDDINGS":<40}') - - if hasattr(self.world_model, 'pos_emb'): - pos_emb_params = sum(p.numel() for p in self.world_model.pos_emb.parameters()) - pos_emb_trainable = sum(p.numel() for p in self.world_model.pos_emb.parameters() if p.requires_grad) - logging.info(f'{" ├─ Positional Embedding":<40} {pos_emb_params:>15,} parameters') - if pos_emb_trainable == 0: - logging.info(f'{" └─ (Frozen)":<40}') - - if hasattr(self.world_model, 'task_emb') and self.world_model.task_emb is not None: - task_emb_params = sum(p.numel() for p in self.world_model.task_emb.parameters()) - task_emb_trainable = sum(p.numel() for p in self.world_model.task_emb.parameters() if p.requires_grad) - logging.info(f'{" ├─ Task Embedding":<40} {task_emb_params:>15,} parameters') - logging.info(f'{" └─ Trainable":<40} {task_emb_trainable:>15,} parameters') - - if hasattr(self.world_model, 'act_embedding_table'): - act_emb_params = sum(p.numel() for p in self.world_model.act_embedding_table.parameters()) - act_emb_trainable = sum(p.numel() for p in self.world_model.act_embedding_table.parameters() if p.requires_grad) - logging.info(f'{" └─ Action Embedding":<40} {act_emb_params:>15,} parameters') - logging.info(f'{" └─ Trainable":<40} {act_emb_trainable:>15,} parameters') - - # --- Decoder (if applicable) --- - if obs_type in ['vector', 'image_memory'] and self.tokenizer.decoder_network is not None: - logging.info(f'\n{"5. 
DECODER":<40}') - decoder_params = sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - decoder_trainable = sum(p.numel() for p in self.tokenizer.decoder_network.parameters() if p.requires_grad) - logging.info(f'{" Decoder Network":<40} {decoder_params:>15,} parameters') - logging.info(f'{" └─ Trainable":<40} {decoder_trainable:>15,} parameters') - - if obs_type == 'image_memory' and hasattr(self.tokenizer, 'lpips'): - lpips_params = sum(p.numel() for p in self.tokenizer.lpips.parameters()) - logging.info(f'{" LPIPS Loss Network":<40} {lpips_params:>15,} parameters') - - # Calculate world model params excluding decoder and LPIPS - params_without_decoder = wm_params - decoder_params - lpips_params - logging.info(f'\n{" World Model (exc. Decoder & LPIPS)":<40} {params_without_decoder:>15,} parameters') - - # --- Summary Table --- - logging.info(f'\n{"=" * 80}') - logging.info(f'{"SUMMARY":<40}') - logging.info(f'{"=" * 80}') - logging.info(f'{"Component":<30} {"Total Params":>15} {"Trainable":>15} {"% of Total":>15}') - logging.info(f'{"-" * 80}') - - components = [ - ("Encoder", encoder_params, encoder_trainable), - ("Transformer", transformer_params, transformer_trainable), - ] - - if hasattr(self.world_model, 'head_dict'): - components.append(("Prediction Heads", total_heads_params, total_heads_trainable)) - - for name, total, trainable in components: - pct = 100 * total / total_params if total_params > 0 else 0 - logging.info(f'{name:<30} {total:>15,} {trainable:>15,} {pct:>14.1f}%') - - logging.info(f'{"=" * 80}') - logging.info(f'{"TOTAL":<30} {total_params:>15,} {total_trainable:>15,} {"100.0%":>15}') - logging.info(f'{"=" * 80}\n') - def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None, current_obs_batch: Optional[torch.Tensor] = None, start_pos: int = 0) -> MZNetworkOutput: """ @@ -467,4 +277,4 @@ def recurrent_inference(self, state_action_history: torch.Tensor, simulation_ind policy_logits = logits_policy.squeeze(1) value = logits_value.squeeze(1) - return MZNetworkOutput(value=value, reward=reward, policy_logits=policy_logits, latent_state=next_latent_state) + return MZNetworkOutput(value=value, reward=reward, policy_logits=policy_logits, latent_state=next_latent_state) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py old mode 100755 new mode 100644 index 5d34e9fe6..7d8a2a403 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -1,30 +1,20 @@ -import datetime import logging -import os -from collections import OrderedDict, defaultdict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, Union, Optional, List, Tuple, Any -import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from lzero.model.common import SimNorm -from lzero.model.utils import (calculate_dormant_ratio, - compute_average_weight_magnitude, - compute_effective_rank) -from matplotlib.offsetbox import AnnotationBbox, OffsetImage -from sklearn.manifold import TSNE -from torch.distributions import (Categorical, Independent, Normal, - TanhTransform, TransformedDistribution) +from torch.distributions import Categorical, Independent, Normal, TransformedDistribution, TanhTransform +from lzero.model.common import SimNorm +from lzero.model.utils import cal_dormant_ratio from .kv_caching import 
KeysValues from .slicer import Head, PolicyHeadCont from .tokenizer import Tokenizer from .transformer import Transformer, TransformerConfig -from .utils import (LossWithIntermediateLosses, WorldModelOutput, hash_state, - init_weights) +from .utils import LossWithIntermediateLosses, init_weights, WorldModelOutput, hash_state logging.getLogger().setLevel(logging.DEBUG) @@ -51,11 +41,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: super().__init__() self.tokenizer = tokenizer self.config = config - self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings - self.transformer = Transformer(self.config) - self.task_num = 1 - self.env_num = self.config.env_num + if self.config.device == 'cpu': self.device = torch.device('cpu') else: @@ -64,8 +51,6 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: logging.info(f"self.device: {self.device}") self.to(self.device) - self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96 - # Initialize configuration parameters self._initialize_config_parameters() @@ -80,11 +65,6 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.precompute_pos_emb_diff_kv() print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") - self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 - if self.task_embed_option == "concat_task_embed": - self.obs_per_embdding_dim = self.config.embed_dim - self.task_embed_dim - else: - self.obs_per_embdding_dim = self.config.embed_dim self.continuous_action_space = self.config.continuous_action_space # Initialize action embedding table @@ -102,7 +82,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: # Head modules self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) - self.head_observations = self._create_head_for_latent(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, \ + self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, \ self._get_final_norm(self.final_norm_option_in_obs_head) # NOTE: using the specified normalization method for observations head ) if self.continuous_action_space: @@ -113,13 +93,6 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) - self.head_dict = {} - for name, module in self.named_children(): - if name.startswith("head_"): - self.head_dict[name] = module - if self.head_dict: - self.head_dict = nn.ModuleDict(self.head_dict) - # Build the set of modules to skip during re-initialization. # This is compatible with cases where self.tokenizer.encoder does not have 'pretrained_model', # or self.tokenizer does not have 'decoder_network'. 
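Both this hunk and the parameter-statistics logger removed earlier lean on the same two idioms: collecting the `head_*` children into an `nn.ModuleDict` and summing `p.numel()` per component. A minimal, self-contained sketch of that pattern, using toy module names rather than the actual UniZero classes:

import torch.nn as nn

class ToyModel(nn.Module):
    """Toy stand-in for the world model: a backbone plus several 'head_*' children."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.head_policy = nn.Linear(16, 4)
        self.head_value = nn.Linear(16, 1)
        # Collect every child whose name starts with 'head_' into an nn.ModuleDict,
        # mirroring the head_dict construction removed in the hunk above.
        self.head_dict = nn.ModuleDict(
            {name: module for name, module in self.named_children() if name.startswith('head_')}
        )

def count_params(module: nn.Module):
    """Return (total, trainable) parameter counts, the counting idiom used by the removed logger."""
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return total, trainable

model = ToyModel()
for name, module in model.head_dict.items():
    total, trainable = count_params(module)
    print(f'{name:<20} {total:>8,} total  {trainable:>8,} trainable')

In the patch the same counting pattern is applied to the encoder, the transformer backbone, and each prediction head in turn; the sketch above only illustrates the mechanics.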
@@ -142,6 +115,9 @@ def custom_init(module): self._initialize_last_layer() + # Cache structures + self._initialize_cache_structures() + # Projection input dimension self._initialize_projection_input_dim() @@ -154,25 +130,18 @@ def custom_init(module): self.latent_recon_loss = torch.tensor(0., device=self.device) self.perceptual_loss = torch.tensor(0., device=self.device) - # Set to game_segment_length first to keep self.shared_pool_init_infer valid - # TODO: Very important, should be changed to match segment_length - self.shared_pool_size_init = int(self.config.game_segment_length) # NOTE: Will having too many cause incorrect retrieval of the kv cache? - # TODO: check the size of the shared pool # for self.kv_cache_recurrent_infer # If needed, recurrent_infer should store the results of the one MCTS search. self.num_simulations = getattr(self.config, 'num_simulations', 50) - - - self.shared_pool_size_recur = int(self.num_simulations*self.env_num) - self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur + self.shared_pool_size = int(self.num_simulations*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size self.shared_pool_index = 0 - # Cache structures - self._initialize_cache_structures() - # for self.kv_cache_init_infer # In contrast, init_infer only needs to retain the results of the most recent step. + # self.shared_pool_size_init = int(2*self.env_num) + self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] @@ -183,237 +152,6 @@ def custom_init(module): self.reanalyze_phase = False - def _initialize_cache_structures(self) -> None: - """Initialize cache structures for past keys and values.""" - from collections import defaultdict - - # ==================== Parallel KV Cache Systems ==================== - # Check if we should use the new KV cache manager - self.use_new_cache_manager = getattr(self.config, 'use_new_cache_manager', False) - - if self.use_new_cache_manager: - # Use new unified KV cache manager - from .kv_cache_manager import KVCacheManager - self.kv_cache_manager = KVCacheManager( - config=self.config, - env_num=self.env_num, - enable_stats=True, - clear_recur_log_freq=1000, # MCTS recurrent clearing log, print every 1000 times - clear_all_log_freq=100 # Episode reset clearing log, print every 100 times - ) - # Keep backward compatibility references - self.keys_values_wm_list = self.kv_cache_manager.keys_values_wm_list - self.keys_values_wm_size_list = self.kv_cache_manager.keys_values_wm_size_list - - # ==================== BUG FIX: Complete Refactoring ==================== - # DO NOT initialize old system attributes when using new cache manager. - # Any code that depends on these old attributes must be refactored to use - # kv_cache_manager instead. 
- # - # Old attributes that are NO LONGER available in new system: - # - self.past_kv_cache_recurrent_infer - # - self.pool_idx_to_key_map_recur_infer - # - self.past_kv_cache_init_infer_envs - # - self.pool_idx_to_key_map_init_envs - # - # Migration guide: - # - For accessing init cache: use kv_cache_manager.get_init_cache(env_id, key) - # - For accessing recur cache: use kv_cache_manager.get_recur_cache(key) - # - For hierarchical lookup: use kv_cache_manager.hierarchical_get(env_id, key) - # ====================================================================== - - logging.info("✓ Using NEW KVCacheManager for cache management") - else: - # Use old cache system (original implementation) - self.past_kv_cache_recurrent_infer = {} - self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur - self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)] - # Auxiliary data structure for reverse lookup: pool_index -> key - self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] - - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - logging.info("Using OLD cache system (original implementation)") - # ============================================================================= - - def _inspect_and_log_head_params(self, head_name: str, head_module: nn.Module, status: str): - """ - Inspect and log parameter statistics for the specified Head module. - - Args: - head_name (str): The name of the Head to inspect (e.g., "Value Head"). - head_module (nn.Module): The actual nn.Sequential module of the Head. - status (str): A string describing the current status (e.g., "Before Re-init"). - """ - logging.info(f"--- Inspecting {head_name} parameters ({status}) ---") - with torch.no_grad(): - for param_name, param in head_module.named_parameters(): - if param.numel() > 0: - stats = { - "mean": param.mean().item(), - "std": param.std().item(), - "abs_mean": param.abs().mean().item(), - "max": param.max().item(), - "min": param.min().item(), - } - logging.info( - f" -> {param_name:<20} | " - f"Mean: {stats['mean']:.4f}, Std: {stats['std']:.4f}, " - f"AbsMean: {stats['abs_mean']:.4f}, " - f"Max: {stats['max']:.4f}, Min: {stats['min']:.4f}" - ) - logging.info("-" * (23 + len(head_name) + len(status))) - - def reinit_prediction_heads(self, heads_to_reinit: List[str] = ['value', 'reward']) -> None: - """ - Reinitialize the parameters of specified prediction heads (e.g., Value Head and Reward Head). - Parameter statistics are logged before and after reinitialization for analysis. - - Args: - heads_to_reinit (List[str]): A list containing the names of the heads to reinitialize. - Defaults to ['value', 'reward']. - """ - logging.info(f"Starting reinitialization of prediction heads: {heads_to_reinit}") - - head_map = { - 'value': self.head_value, - 'reward': self.head_rewards, - 'policy': self.head_policy, - } - - def _init_weights_for_head(module): - # TODO - init_weights(module, norm_type=self.config.norm_type, liner_weight_zero=True) - - for head_name in heads_to_reinit: - if head_name in head_map and hasattr(head_map[head_name], 'head_module'): - head_instance = head_map[head_name] - capitalized_name = head_name.capitalize() + " Head" - - # 1. Inspect parameters before reinitialization - self._inspect_and_log_head_params(capitalized_name, head_instance.head_module, "Before Re-init") - - # 2. 
Apply reinitialization - logging.info(f"Reinitializing {capitalized_name}...") - head_instance.head_module.apply(_init_weights_for_head) - - # 3. Inspect parameters again after reinitialization - self._inspect_and_log_head_params(capitalized_name, head_instance.head_module, "After Re-init") - - logging.info(f"{capitalized_name} parameters successfully reinitialized.") - else: - logging.warning(f"Prediction head named '{head_name}' or its 'head_module' not found. Skipping.") - - logging.info("Reinitialization of all specified prediction heads completed.") - - def _analyze_latent_representation( - self, - latent_states: torch.Tensor, - timesteps: torch.Tensor, - game_states: torch.Tensor, - predicted_values: torch.Tensor, - predicted_rewards: torch.Tensor, - step_counter: int - ): - """ - Analyze and log statistics of latent states with t-SNE visualization. - [New feature]: Display corresponding game images on t-SNE plot with predicted Value and Reward annotations. - [Modified]: If the save path already exists, append a timestamp to the filename. - - Args: - latent_states (torch.Tensor): Encoder output, shape (B*L, 1, E) - timesteps (torch.Tensor): Corresponding timesteps, shape (B, L) - game_states (torch.Tensor): Original game observations, shape (B, L, C, H, W) - predicted_values (torch.Tensor): Predicted scalar Values, shape (B*L,) - predicted_rewards (torch.Tensor): Predicted scalar Rewards, shape (B*L,) - step_counter (int): Global training step count - """ - # Ensure latent_states and game_states have shape (N, ...) - if latent_states.dim() > 2: - latent_states = latent_states.reshape(-1, latent_states.shape[-1]) - num_c, num_h, num_w = game_states.shape[-3:] - game_states = game_states.reshape(-1, num_c, num_h, num_w) - - with torch.no_grad(): - l2_norm = torch.norm(latent_states, p=2, dim=1).mean() - mean = latent_states.mean() - std = latent_states.std() - print(f"[Step {step_counter}] Latent Stats | L2 Norm: {l2_norm:.4f}, Mean: {mean:.4f}, Std: {std:.4f}") - - # t-SNE visualization with images and V/R values - if step_counter >= 0: - print(f"[Step {step_counter}] Performing t-SNE analysis with images, values, and rewards...") - - # Convert data to CPU - latents_np = latent_states.detach().cpu().numpy() - images_np = game_states.detach().cpu().numpy() - values_np = predicted_values.detach().cpu().numpy() - rewards_np = predicted_rewards.detach().cpu().numpy() - - tsne = TSNE(n_components=2, perplexity=30, n_iter=300, random_state=42) - tsne_results = tsne.fit_transform(latents_np) - - # Draw scatter plot with images and annotations - - # Reduce number of images to keep clarity - num_points_to_plot = min(len(latents_np), 70) # Reduce to 70 points - indices = np.random.choice(len(latents_np), num_points_to_plot, replace=False) - - fig, ax = plt.subplots(figsize=(20, 18)) # Increase canvas size - - # First draw all points as background scatter plot - ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=values_np, cmap='viridis', alpha=0.3, s=10) - - for i in indices: - x, y = tsne_results[i] - img = images_np[i].transpose(1, 2, 0) - img = np.clip(img, 0, 1) - - # Place image - im = OffsetImage(img, zoom=0.7) # Slightly enlarge image - ab = AnnotationBbox(im, (x, y), frameon=True, pad=0.0, bboxprops=dict(edgecolor='none')) - ax.add_artist(ab) - - # Add text annotation below image - text_label = f"V:{values_np[i]:.1f} R:{rewards_np[i]:.1f}" - ax.text(x, y - 1.0, text_label, ha='center', va='top', fontsize=8, color='red', - bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.5)) - 
- ax.update_datalim(tsne_results) - ax.autoscale() - - ax.set_title(f't-SNE of Latent States (Value as Color) at Step {step_counter}', fontsize=16) - ax.set_xlabel('t-SNE dimension 1', fontsize=12) - ax.set_ylabel('t-SNE dimension 2', fontsize=12) - - # Add colorbar to explain background point colors - norm = plt.Normalize(values_np.min(), values_np.max()) - sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm) - sm.set_array([]) - fig.colorbar(sm, ax=ax, label='Predicted Value') - - # Modified section: Check if file exists, add timestamp if it does - base_save_path = ( - f'/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/unizero_mspacman_analyze/' - f'tsne_with_vr_{self.config.optim_type}_step_{step_counter}.png' - ) - - # Check if file exists and determine final save path - if os.path.exists(base_save_path): - # If file already exists, generate timestamp and append to filename - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - path_root, path_ext = os.path.splitext(base_save_path) - save_path = f"{path_root}_{timestamp}{path_ext}" - print(f"File '{base_save_path}' already exists. Saving to new path with timestamp.") - else: - # If file doesn't exist, use original path - save_path = base_save_path - - # Save image - plt.savefig(save_path) - plt.close(fig) # Explicitly close figure object - print(f"t-SNE plot with V/R annotations saved to {save_path}") - def _get_final_norm(self, norm_option: str) -> nn.Module: """ Return the corresponding normalization module based on the specified normalization option. @@ -471,7 +209,7 @@ def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored. """ src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape - + if self.shared_pool_wm[self.shared_pool_index_wm] is None: self.shared_pool_wm[self.shared_pool_index_wm] = KeysValues( src_kv_shape[0], # Number of elements (n) @@ -483,7 +221,7 @@ def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: ) dst_kv = self.shared_pool_wm[self.shared_pool_index_wm] - + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): # Copy the key and value caches using torch.copy_() for efficient data transfer dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) @@ -526,7 +264,7 @@ def custom_copy_kv_cache_to_shared_recur(self, src_kv: KeysValues) -> int: dst_layer._v_cache._size = src_layer._v_cache._size index = self.shared_pool_index - self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size_recur + self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size return index @@ -542,7 +280,7 @@ def _initialize_config_parameters(self) -> None: self.gamma = self.config.gamma self.context_length = self.config.context_length self.dormant_threshold = self.config.dormant_threshold - self.analysis_dormant_ratio_weight_rank = self.config.analysis_dormant_ratio_weight_rank + self.analysis_dormant_ratio = self.config.analysis_dormant_ratio self.num_observations_tokens = self.config.tokens_per_block - 1 self.latent_recon_loss_weight = self.config.latent_recon_loss_weight self.perceptual_loss_weight = self.config.perceptual_loss_weight @@ -551,52 +289,9 @@ def _initialize_config_parameters(self) -> None: self.max_cache_size = self.config.max_cache_size self.env_num = self.config.env_num self.num_layers = self.config.num_layers + self.obs_per_embdding_dim = self.config.embed_dim self.sim_norm = SimNorm(simnorm_dim=self.group_size) - # 
==================== [NEW] Policy Stability Fix Options ==================== - # Load fix options from config (with defaults for backward compatibility) - self.use_policy_logits_clip = getattr(self.config, 'use_policy_logits_clip', False) - self.policy_logits_clip_method = getattr(self.config, 'policy_logits_clip_method', 'normalize_max') - self.policy_logits_clip_min = getattr(self.config, 'policy_logits_clip_min', -10.0) - self.policy_logits_clip_max = getattr(self.config, 'policy_logits_clip_max', 10.0) - self.policy_logits_soft_beta = getattr(self.config, 'policy_logits_soft_beta', 1.0) - self.policy_logits_adaptive_percentile = getattr(self.config, 'policy_logits_adaptive_percentile', 95) - - # Running statistics for adaptive clipping - if self.policy_logits_clip_method == 'adaptive': - self.register_buffer('policy_logits_running_max', torch.tensor(10.0)) - self.register_buffer('policy_logits_running_min', torch.tensor(-10.0)) - self.policy_logits_momentum = 0.99 - - # [NEW] Fix5: Temperature scaling for policy loss - self.use_policy_loss_temperature = getattr(self.config, 'use_policy_loss_temperature', False) - self.policy_loss_temperature = getattr(self.config, 'policy_loss_temperature', 1.0) - - # [NEW] Fix3: Check if target policy re-smooth is enabled (now deprecated in favor of Fix2) - use_target_policy_resmooth = getattr(self.config, 'use_target_policy_resmooth', False) - if use_target_policy_resmooth: - logging.warning( - "[DEPRECATED] use_target_policy_resmooth=True is deprecated! " - "Policy label smoothing should now be controlled by 'continuous_ls_eps' in policy config. " - "Fix3 (use_target_policy_resmooth) creates redundant smoothing with Fix2. " - "Please set use_target_policy_resmooth=False and use continuous_ls_eps instead." - ) - - # [NEW] Debug: Print configuration on initialization - if self.use_policy_logits_clip: - logging.info( - f"[Policy Logits Control] ENABLED\n" - f" Method: {self.policy_logits_clip_method}\n" - f" Range: [{self.policy_logits_clip_min}, {self.policy_logits_clip_max}]\n" - f" Soft Beta: {self.policy_logits_soft_beta if 'soft' in self.policy_logits_clip_method else 'N/A'}" - ) - else: - logging.warning(f"[Policy Logits Control] DISABLED! Logits may grow unbounded.") - - if self.use_policy_loss_temperature and self.policy_loss_temperature != 1.0: - logging.info(f"[Policy Loss Temperature] ENABLED: temperature={self.policy_loss_temperature}") - # ============================================================================= - def _initialize_patterns(self) -> None: """Initialize patterns for block masks.""" self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) @@ -606,144 +301,12 @@ def _initialize_patterns(self) -> None: self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) self.value_policy_tokens_pattern[-2] = 1 - def _apply_policy_logits_control(self, logits_policy: torch.Tensor) -> torch.Tensor: - """ - Apply policy logits control using various methods to prevent explosion. - - This method implements multiple strategies to constrain policy logits: - 1. 'hard': Hard clamp (torch.clamp) - Simple but gradients die at boundaries - 2. 'soft_tanh': Soft clamp using tanh - Smooth, gradients never zero - 3. 'soft_sigmoid': Soft clamp using sigmoid - Similar to tanh but different curve - 4. 'normalize_max': Subtract max then clamp - Preserves relative order, safer - 5. 'normalize_mean': Subtract mean then clamp - Centers distribution - 6. 'adaptive': Adaptive clipping based on running statistics - 7. 
'none': No clipping - - Arguments: - - logits_policy (:obj:`torch.Tensor`): Raw policy logits from head_policy - Shape: [batch_size, num_steps, action_dim] or [batch_size * num_steps, action_dim] - - Returns: - - torch.Tensor: Controlled policy logits with the same shape - - Examples: - >>> logits = torch.randn(32, 10, 6) * 20 # Large logits - >>> controlled = self._apply_policy_logits_control(logits) - >>> assert controlled.abs().max() <= self.policy_logits_clip_max - """ - if not self.use_policy_logits_clip or self.policy_logits_clip_method == 'none': - return logits_policy - - method = self.policy_logits_clip_method - clip_min = self.policy_logits_clip_min - clip_max = self.policy_logits_clip_max - - # ==================== Method 1: Hard Clamp ==================== - if method == 'hard': - # Simple hard clipping - # Pros: Simple, fast - # Cons: Gradients become zero outside [clip_min, clip_max] - return torch.clamp(logits_policy, min=clip_min, max=clip_max) - - # ==================== Method 2: Soft Tanh Clamp ==================== - elif method == 'soft_tanh': - # Soft clamp using tanh function: clip_max * tanh(x / clip_max) - # Pros: Gradients never zero, smooth transition - # Cons: Slightly more computation - # When x is small: tanh(x) ≈ x, so output ≈ x (unchanged) - # When x is large: tanh(x) → 1, so output → clip_max (smoothly saturates) - C = clip_max # Use positive bound as scale - beta = self.policy_logits_soft_beta # Smoothness parameter - return C * torch.tanh(logits_policy / (C * beta)) - - # ==================== Method 3: Soft Sigmoid Clamp ==================== - elif method == 'soft_sigmoid': - # Soft clamp using sigmoid: maps (-∞, ∞) to (clip_min, clip_max) - # Formula: clip_min + (clip_max - clip_min) * sigmoid(x / beta) - # Pros: Smooth, bounded - # Cons: Compresses entire range, may lose relative ordering - beta = self.policy_logits_soft_beta - range_size = clip_max - clip_min - return clip_min + range_size * torch.sigmoid(logits_policy / beta) - - # ==================== Method 4: Normalize Max + Hard Clamp ==================== - elif method == 'normalize_max': - # Subtract max value first (exploits softmax translation invariance) - # softmax(x) = softmax(x - c) for any constant c - # By subtracting max, we ensure the largest logit is 0, others are negative - # Then apply hard clamp (mainly affects the negative tail) - # Pros: Preserves relative ordering, safer than pure hard clamp - # Cons: Still has gradient issues for very negative values - logits_normalized = logits_policy - logits_policy.max(dim=-1, keepdim=True)[0].detach() - return torch.clamp(logits_normalized, min=clip_min, max=clip_max) - - # ==================== Method 5: Normalize Mean + Hard Clamp ==================== - elif method == 'normalize_mean': - # Subtract mean (centers the distribution) - # Pros: Centers logits around 0, prevents drift - # Cons: May change relative probabilities more than normalize_max - logits_normalized = logits_policy - logits_policy.mean(dim=-1, keepdim=True).detach() - return torch.clamp(logits_normalized, min=clip_min, max=clip_max) - - # ==================== Method 6: Adaptive Clipping ==================== - elif method == 'adaptive': - # Dynamically adjust clipping thresholds based on running statistics - # Update running stats (only during training) - if self.training: - with torch.no_grad(): - # Compute percentile-based bounds - flat_logits = logits_policy.view(-1) - percentile = self.policy_logits_adaptive_percentile - current_max = torch.quantile(flat_logits, percentile 
/ 100.0) - current_min = torch.quantile(flat_logits, (100 - percentile) / 100.0) - - # Update running statistics with momentum - self.policy_logits_running_max = ( - self.policy_logits_momentum * self.policy_logits_running_max + - (1 - self.policy_logits_momentum) * current_max - ) - self.policy_logits_running_min = ( - self.policy_logits_momentum * self.policy_logits_running_min + - (1 - self.policy_logits_momentum) * current_min - ) - - # Use running stats for clipping - adaptive_max = torch.clamp(self.policy_logits_running_max, max=clip_max) - adaptive_min = torch.clamp(self.policy_logits_running_min, min=clip_min) - return torch.clamp(logits_policy, min=adaptive_min, max=adaptive_max) - - else: - raise ValueError( - f"Unknown policy_logits_clip_method: {method}. " - f"Valid options: 'hard', 'soft_tanh', 'soft_sigmoid', 'normalize_max', " - f"'normalize_mean', 'adaptive', 'none'" - ) - def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: """Create head modules for the transformer.""" modules = [ - nn.LayerNorm(self.config.embed_dim), # Core optimization! # TODO - nn.Linear(self.config.embed_dim, self.config.embed_dim*4), - nn.LayerNorm(self.config.embed_dim*4), # 2. New! Stabilize internal activations - nn.GELU(approximate='tanh'), - nn.Linear(self.config.embed_dim*4, output_dim) - ] - if norm_layer: - modules.append(norm_layer) - return Head( - max_blocks=self.config.max_blocks, - block_mask=block_mask, - head_module=nn.Sequential(*modules) - ) - - def _create_head_for_latent(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: - """Create head modules for the transformer.""" - modules = [ - nn.LayerNorm(self.config.embed_dim), # Core optimization! # TODO - nn.Linear(self.config.embed_dim, self.config.embed_dim*4), - nn.LayerNorm(self.config.embed_dim*4), # 2. New! 
Stabilize internal activations + nn.Linear(self.config.embed_dim, self.config.embed_dim), nn.GELU(approximate='tanh'), - nn.Linear(self.config.embed_dim*4, output_dim) + nn.Linear(self.config.embed_dim, output_dim) ] if norm_layer: modules.append(norm_layer) @@ -788,22 +351,21 @@ def _initialize_last_layer(self) -> None: nn.init.zeros_(layer.bias) break + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + from collections import defaultdict + self.past_kv_cache_recurrent_infer = defaultdict(dict) + self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] def _initialize_projection_input_dim(self) -> None: """Initialize the projection input dimension based on the number of observation tokens.""" if self.num_observations_tokens == 16: self.projection_input_dim = 128 elif self.num_observations_tokens == 1: - # self.projection_input_dim = self.config.embed_dim - if self.task_embed_option == "concat_task_embed": - self.projection_input_dim = self.config.embed_dim - self.task_embed_dim - elif self.task_embed_option == "register_task_embed": - self.projection_input_dim = self.config.embed_dim - elif self.task_embed_option == "add_task_embed": - self.projection_input_dim = self.config.embed_dim - else: - self.projection_input_dim = self.config.embed_dim + self.projection_input_dim = self.obs_per_embdding_dim def _initialize_statistics(self) -> None: """Initialize counters for hit count and query count statistics.""" @@ -859,7 +421,6 @@ def precompute_pos_emb_diff_kv(self): self.pos_emb_diff_k.append(layer_pos_emb_diff_k) self.pos_emb_diff_v.append(layer_pos_emb_diff_v) - #@profile def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: """ Helper function to get positional embedding for a given layer and attention type. @@ -1060,25 +621,17 @@ def forward( x = self._transformer_pass( sequences, past_keys_values, kvcache_independent, valid_context_lengths, start_pos=start_pos_adjusted ) - + # Generate logits for various components. + # import pudb;pudb.set_traces() logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - - # ==================== [NEW] Advanced Policy Logits Control ==================== - # Apply configurable policy logits control to prevent explosion - # Multiple methods available: hard, soft_tanh, soft_sigmoid, normalize_max, etc. - if self.use_policy_logits_clip: - logits_policy = self._apply_policy_logits_control(logits_policy) - # ================================================================================ - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) # The 'logits_ends' is intentionally set to None. 
        return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value)

-    #@profile
    def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer,
                                 valid_context_lengths):
        """
@@ -1107,7 +660,6 @@ def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_in
                valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1)
            return embeddings + position_embeddings

-    #@profile
    def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps):
        """
        Process combined observation embeddings and action tokens.
@@ -1147,7 +699,6 @@ def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_step
        return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device))
        return return_result, num_steps

-    #@profile
    def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps):
        """
        Process combined observation embeddings and action tokens.
@@ -1200,7 +751,6 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va
        else:
            return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths, start_pos=start_pos)

-    #@profile
    @torch.no_grad()
    def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos: int = 0) -> torch.FloatTensor:
        """
@@ -1229,108 +779,123 @@ def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos
            outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action, current_obs_embeddings, start_pos)
        else:
-            # ================ calculate the target value in Train phase or calculate the target policy in reanalyze phase ================
+            # ================ calculate the target value in the train phase or the target policy in the reanalyze phase ================
            self.latent_state = obs_embeddings
+            # import pudb;pudb.set_trace()
            outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action, None, start_pos)

        return outputs_wm, self.latent_state

-    #@profile
    @torch.no_grad()
    def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTensor,
                                           batch_action=None,
                                           current_obs_embeddings=None, start_pos: int = 0) -> torch.FloatTensor:
        """
-        Refresh key-value pairs with the initial latent state for inference.
+        Refresh the key-value (KV) cache during the initial inference phase.
+
+        How the KV cache works:
+        =======================
+        1. **Purpose**: avoid recomputing the Transformer's attention keys and values, improving inference efficiency.
+        2. **Core idea**: identical latent states correspond to identical key-value pairs, so they can be reused directly.
+        3. **Multi-environment support**: each environment maintains its own cache state, enabling parallel inference.
+        4. **Cache hierarchy**:
+            - shared_pool_init_infer: shared cache pool for the initial inference phase (grouped per environment)
+            - shared_pool_recur_infer: shared cache pool for the recurrent inference phase (global)
+            - past_kv_cache_init_infer_envs: per-environment mapping tables from state hash to cache index
        Arguments:
-            - last_obs_embeddings (:obj:`torch.LongTensor`): The latent state embeddings.
-            - batch_action (optional): Actions taken.
-            - current_obs_embeddings (optional): Current observation embeddings.
+            - last_obs_embeddings (:obj:`torch.LongTensor`): the latent state embeddings of the previous step
+            - batch_action (optional): the actions that were taken
+            - current_obs_embeddings (optional): the embeddings of the current observation
        Returns:
-            - torch.FloatTensor: The outputs from the world model.
+            - torch.FloatTensor: the outputs of the world model
        """
        n, num_observations_tokens, _ = last_obs_embeddings.shape
+
+        # import pudb;pudb.set_trace()
+
        if n <= self.env_num and current_obs_embeddings is not None:
-            # ================ Collect and Evaluation Phase ================
+            # ================ Collect and evaluation phase ================
            if current_obs_embeddings is not None:
-                # Determine whether it is the first step in an episode.
+                # Determine whether this is the first step of an episode.
+
+                # if -1 in batch_action:
+                #     import pudb;pudb.set_trace()
+
                if self.continuous_action_space:
                    first_step_flag = not isinstance(batch_action[0], np.ndarray)
                else:
+                    # import pudb;pudb.set_trace()
                    first_step_flag = max(batch_action) == -1
                if first_step_flag:
-                    # ------------------------- First Step of an Episode -------------------------
+                    # ------------------------- First step of an episode: initialize the KV cache -------------------------
+                    # Generate an empty KV cache for every environment in the current batch.
+                    # keys_values_wm is the global multi-environment KV cache that stores the key-value pairs of all environments.
                    self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], max_tokens=self.context_length)
                    # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}")
+
+                    # Run a forward pass with the current observation embeddings, updating the KV cache in the process.
                    outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm, is_init_infer=True, start_pos=start_pos)

-                    # Copy and store keys_values_wm for a single environment
+                    # Copy the updated KV cache into the per-environment cache pool for later lookups.
                    self.update_cache_context(current_obs_embeddings, is_init_infer=True)
                else:
-                    # --------------------- Continuing an Episode (Multi-environment) ---------------------
-                    # current_obs_embeddings is the new latent_state, containing information from ready_env_num environments
+                    # --------------------- Subsequent steps of an episode: KV cache lookup and reuse ---------------------
+                    # current_obs_embeddings is the new latent state, containing information from ready_env_num environments.
                    ready_env_num = current_obs_embeddings.shape[0]

-                    self.keys_values_wm_list = []
-                    self.keys_values_wm_size_list = []
+                    self.keys_values_wm_list = []  # stores the KV cache of each environment
+                    self.keys_values_wm_size_list = []  # stores the KV cache size of each environment
                    for i in range(ready_env_num):
-                        # Retrieve latent state for a single environment
+                        # Get the latent state of a single environment.
                        # TODO: len(last_obs_embeddings) may be smaller than len(current_obs_embeddings), because some environments may already be done
                        state_single_env = last_obs_embeddings[i]
-                        # Compute hash value using latent state for a single environment
+                        # Compute a hash of the latent state and use it as the cache key.
+                        # This is the core of the KV cache lookup: identical states map to the same cache entry.
                        cache_key = hash_state(state_single_env.view(-1).cpu().numpy())  # last_obs_embeddings[i] is torch.Tensor

-                        # ==================== Storage Layer Integration ====================
-                        # Retrieve cached value
-                        if self.use_new_cache_manager:
-                            # NEW SYSTEM: Use KVCacheManager
-                            matched_value = self.kv_cache_manager.get_init_cache(env_id=i, cache_key=cache_key)
+                        # Retrieve the cached value from the initial-inference cache pool.
+                        # past_kv_cache_init_infer_envs[i] is the cache dict of environment i: {state hash -> cache index}.
+                        cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key)
+                        if cache_index is not None:
+                            # If a cache index is found, fetch the corresponding KV cache from the shared pool.
+                            matched_value = self.shared_pool_init_infer[i][cache_index]
                        else:
-                            # OLD SYSTEM: Use legacy cache dictionaries
-                            cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key)
-                            if cache_index is not None:
-                                matched_value = self.shared_pool_init_infer[i][cache_index]
-                            else:
-                                matched_value = None
-                        # =============================================================================
+                            matched_value = None

+                        # Count cache queries (used for performance analysis).
                        self.root_total_query_cnt += 1
                        if matched_value is not None:
-                            # If a matching value is found, add it to the list
+                            # ========== KV cache hit: reuse the existing key-value pairs ==========
                            self.root_hit_cnt += 1

-                            # ==================== BUG FIX: Cache Corruption Prevention ====================
-                            # Perform a deep copy because the transformer's forward pass modifies matched_value in-place.
-                            if self.use_new_cache_manager:
-                                # NEW SYSTEM: Use KeysValues.clone() for deep copy
-                                cached_copy = matched_value.clone()
-                                self.keys_values_wm_list.append(cached_copy)
-                            else:
-                                # OLD SYSTEM: Use custom_copy_kv_cache_to_shared_wm
-                                self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value))
-                            # =============================================================================
+                            # NOTE: a deep copy is required because forward modifies matched_value in place.
+                            # custom_copy_kv_cache_to_shared_wm copies the cache into the world-model shared pool.
+                            self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value))
                            self.keys_values_wm_size_list.append(matched_value.size)
                        else:
-                            # Reset using zero values
+                            # ========== KV cache miss: recompute ==========
+                            # Generate an empty single-environment KV cache.
                            self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length)
-                            # If using RoPE positional encoding, then at reset, the pos_embed should use the absolute position start_pos[i].
+                            # When RoPE positional encoding is used, the positional embedding at reset should use the absolute position start_pos[i].
                            outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, start_pos=start_pos[i].item())
                            self.keys_values_wm_list.append(self.keys_values_wm_single_env)
                            self.keys_values_wm_size_list.append(1)

-                    # Input self.keys_values_wm_list, output self.keys_values_wm
+                    # ========== KV cache batching: unify sizes and merge ==========
+                    # Unify the KV caches of the individual environments to the same size and merge them into a batched format.
+                    # trim_and_pad_kv_cache ensures that the caches of all environments share the same sequence length, so they can be batched.
                    self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True)

                    start_pos = start_pos[:ready_env_num]  # TODO: len(last_obs_embeddings) may be smaller than len(current_obs_embeddings), because some environments may already be done
                    # TODO: the order may not be correct? len(batch_action) may be smaller than len(current_obs_embeddings), because some environments may already be done
                    batch_action = batch_action[:ready_env_num]
-
+                    # TODO: only for debug
                    # if ready_env_num < self.env_num:
                    #     print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}')
@@ -1341,30 +906,38 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens
                    #     print(f"len(batch_action): {len(batch_action)}")
                    #     print(f"len(current_obs_embeddings): {len(current_obs_embeddings)}")

+
                    if self.continuous_action_space:
                        act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(1)
                    else:
-                        act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1)
-
+                        act_tokens = torch.tensor(batch_action, dtype=torch.long, device=last_obs_embeddings.device).unsqueeze(-1)
+
+                    # ========== Two-step forward pass: action -> observation ==========
+                    # Step 1: process the action tokens and update the KV cache.
+                    # past_keys_values=self.keys_values_wm passes in the previous cache state.
                    outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, is_init_infer=True, start_pos=start_pos)

+                    # Step 2: process the observation embeddings and continue updating the KV cache.
+                    # At this point self.keys_values_wm already contains the key-value pairs of the action.
                    outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm, is_init_infer=True, start_pos=start_pos)

-                    # Copy and store keys_values_wm for a single environment
+                    # Save the latest KV cache state into the cache pool for later lookups.
                    self.update_cache_context(current_obs_embeddings, is_init_infer=True)

        elif batch_action is not None and current_obs_embeddings is None:
            # ================ calculate the target value in Train phase or calculate the target policy in reanalyze phase ================
            # [192, 16, 64] -> [32, 6, 16, 64]
            last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens,
-                                                                        self.config.embed_dim)  # (BL, K) for unroll_step=1
+                                                                        self.obs_per_embdding_dim)  # (BL, K) for unroll_step=1

            last_obs_embeddings = last_obs_embeddings[:, :-1, :]
            batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device)

            if self.continuous_action_space:
                act_tokens = batch_action
            else:
+
+                # import pudb;pudb.set_trace()
                act_tokens = rearrange(batch_action, 'b l -> b l 1')

            # select the last timestep for each sample
@@ -1386,10 +959,16 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens
            # outputs_wm.logits_value.shape (B, H, 101) = (B*H, 101)
            outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e')
            outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e')
-
+        else:
+            raise ValueError(
+                f"Unhandled case in wm_forward_for_initial_infererence:\n"
+                f"  n={n}, env_num={self.env_num}\n"
+                f"  batch_action is None: {batch_action is None}\n"
+                f"  current_obs_embeddings is None: {current_obs_embeddings is None}\n"
+                f"  This should not happen. Please check the calling logic."
+            )
        return outputs_wm

-    #@profile
    @torch.no_grad()
    def forward_initial_inference(self, obs_act_dict, start_pos: int = 0):
        """
@@ -1400,22 +979,15 @@ def forward_initial_inference(self, obs_act_dict, start_pos: int = 0):
        Returns:
            - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value.
""" + # UniZero has context in the root node + # import pudb;pudb.set_trace() outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, start_pos) - - # ==================== BUG FIX: Clear Cache Using Correct API ==================== - if self.use_new_cache_manager: - # NEW SYSTEM: Clear recurrent cache using KVCacheManager - self.kv_cache_manager.clear_recur_cache() - else: - # OLD SYSTEM: Clear using legacy attribute - self.past_kv_cache_recurrent_infer.clear() - # ============================================================================= + self.past_kv_cache_recurrent_infer.clear() return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value) - #@profile @torch.no_grad() def forward_recurrent_inference(self, state_action_history, simulation_index=0, search_depth=[], start_pos: int = 0): @@ -1502,7 +1074,6 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) - #@profile def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: """ Adjusts the key-value cache for each environment to ensure they all have the same size. @@ -1555,7 +1126,6 @@ def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: return self.keys_values_wm_size_list - #@profile def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, search_depth=[], valid_context_lengths=None): """ @@ -1689,72 +1259,16 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 - if self.use_new_cache_manager: - # NEW SYSTEM: Use KVCacheManager for cache storage - # ==================== BUG FIX: Deep Copy Before Storage ==================== - # CRITICAL: Must clone before storing to prevent cache corruption. - # self.keys_values_wm_single_env is a shared object that gets modified. - # Without cloning, all cache entries would point to the same object, - # causing incorrect KV retrieval and training divergence. - kv_cache_to_store = self.keys_values_wm_single_env.clone() - # ============================================================================= - - if is_init_infer: - # Store to per-environment init cache pool - # Note: KVCacheManager automatically handles eviction logic (FIFO/LRU) - self.kv_cache_manager.set_init_cache( - env_id=i, - cache_key=cache_key, - kv_cache=kv_cache_to_store # Store cloned copy, not reference - ) - else: - # Store to global recurrent cache pool - self.kv_cache_manager.set_recur_cache( - cache_key=cache_key, - kv_cache=kv_cache_to_store # Store cloned copy, not reference - ) + if is_init_infer: + # Store the latest key-value cache for initial inference + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index else: - # OLD SYSTEM: Use legacy cache with manual eviction - if is_init_infer: - # ==================== Active Eviction Fix Logic ==================== - # 1. Get the physical index that will be overwritten - index_to_write = self.shared_pool_index_init_envs[i] - # 2. Use auxiliary list to find the old key stored at this index - old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write] - # 3. 
If old key exists, delete it from the main cache map - if old_key_to_evict is not None: - # Ensure the key to be deleted actually exists to avoid unexpected errors - if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]: - del self.past_kv_cache_init_infer_envs[i][old_key_to_evict] - - # Now it's safe to write new data - cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) - - # 4. Update both the main cache map and auxiliary list with new mapping - self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index - self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key - else: - # ==================== RECURRENT INFER FIX ==================== - # 1. Get the physical index that will be overwritten - index_to_write = self.shared_pool_index - # 2. Use auxiliary list to find the old key stored at this index - old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write] - # 3. If old key exists, delete it from the main cache map - if old_key_to_evict is not None: - if old_key_to_evict in self.past_kv_cache_recurrent_infer: - del self.past_kv_cache_recurrent_infer[old_key_to_evict] - - # 4. Now it's safe to write new data - cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + # Store the latest key-value cache for recurrent inference + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + self.past_kv_cache_recurrent_infer[cache_key] = cache_index - # 5. Update both the main cache map and auxiliary list with new mapping - self.past_kv_cache_recurrent_infer[cache_key] = cache_index - self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key - # ============================================================================= - - - #@profile def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, simulation_index: int = 0, start_pos: int = 0) -> list: """ @@ -1780,47 +1294,22 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, # TODO: check if this is correct matched_value = None else: - if self.use_new_cache_manager: - # NEW SYSTEM: Use KVCacheManager's hierarchical_get for unified lookup - matched_value = self.kv_cache_manager.hierarchical_get(env_id=index, cache_key=cache_key) - - # Log cache miss (statistics are automatically handled by KVCacheManager) - if matched_value is None: - logging.debug(f"[NEW CACHE MISS] Not found for key={cache_key} in both init and recurrent cache.") + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + cache_index = self.past_kv_cache_init_infer_envs[index].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[index][cache_index] else: - # OLD SYSTEM: Use legacy cache dictionaries and pools - # Try to retrieve the cached value from past_kv_cache_init_infer_envs - cache_index = self.past_kv_cache_init_infer_envs[index].get(cache_key) - if cache_index is not None: - matched_value = self.shared_pool_init_infer[index][cache_index] - else: - matched_value = None + matched_value = None - # Only try to find from recurrent_infer cache if not found in init_infer - if matched_value is None: - # Safely get the index from dictionary, it may return None - recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key) - # Only use it to retrieve value from physical pool if the index is valid (not None) - if recur_cache_index is not None: - matched_value = self.shared_pool_recur_infer[recur_cache_index] - - if 
recur_cache_index is None: - logging.debug(f"[OLD CACHE MISS] Not found for key={cache_key} in recurrent infer. Generating new cache.") - # ============================================================================= + # If not found, try to retrieve from past_kv_cache_recurrent_infer + if matched_value is None: + matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] if matched_value is not None: # If a matching cache is found, add it to the lists self.hit_count += 1 - # Perform a deep copy because the transformer's forward pass modifies matched_value in-place. - # Without cloning, the original cache in init_pool or recur_pool would be polluted, - # causing incorrect predictions in subsequent queries. - if self.use_new_cache_manager: - # NEW SYSTEM: Use KeysValues.clone() for deep copy - cached_copy = matched_value.clone() - self.keys_values_wm_list.append(cached_copy) - else: - # OLD SYSTEM: Use custom_copy_kv_cache_to_shared_wm - self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + # Perform a deep copy because the transformer's forward pass might modify matched_value in-place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) self.keys_values_wm_size_list.append(matched_value.size) else: # If no matching cache is found, generate a new one using zero reset @@ -1853,60 +1342,28 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar **kwargs: Any) -> LossWithIntermediateLosses: start_pos = batch['timestep'] # Encode observations into latent state representations - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations']) - - # ======================== Logging for Analysis ======================== - # This block calculates various metrics for model analysis if the corresponding config flag is enabled. - # These metrics help in debugging and understanding model behavior during training. - if self.analysis_dormant_ratio_weight_rank: - # --- Dormant Ratio Calculation --- - # Calculate the dormant ratio of the encoder to monitor neuron activity. - shape = batch['observations'].shape # Original shape, e.g., (B, T, C, H, W) - # Reshape observations to create a single large batch for the encoder. - # E.g., (32, 5, 3, 64, 64) -> (160, 3, 64, 64) - inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) - - dormant_ratio_encoder_dict = calculate_dormant_ratio( - self.tokenizer.encoder, inputs.detach(), dormant_threshold=self.dormant_threshold - ) - dormant_ratio_encoder = dormant_ratio_encoder_dict['global'] - - # --- Average Weight Magnitude Calculation --- - # Calculate the global average absolute weight magnitude for different model components. - # This is a useful metric for monitoring training stability. - avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder) - avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer) - avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict) - - # --- Effective Rank Calculation --- - # Calculate the effective rank of representations from specific layers in the encoder. - # This metric helps analyze the dimensionality and information content of the learned features. - # The 'representation_layer_name' argument specifies the target layer within the model's named modules. - - # Effective rank for the final linear layer of the encoder. 
- e_rank_last_linear = compute_effective_rank( - self.tokenizer.encoder, inputs, representation_layer_name="last_linear" - ) - # Effective rank for the SimNorm layer of the encoder. - e_rank_sim_norm = compute_effective_rank( - self.tokenizer.encoder, inputs, representation_layer_name="sim_norm" - ) - - # ==================== Clear Cache Using Correct API ==================== - if self.use_new_cache_manager: - self.kv_cache_manager.clear_recur_cache() - else: - self.past_kv_cache_recurrent_infer.clear() - # ============================================================================= + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations']) # torch.Size([256, 5, 4]) + + # ========= for visual analysis ========= + # Uncomment the lines below for visual analysis in Pong + # self.plot_latent_tsne_each_and_all_for_pong(obs_embeddings, suffix='pong_H10_H4_tsne') + # self.save_as_image_with_timestep(batch['observations'], suffix='pong_H10_H4_tsne') + # Uncomment the lines below for visual analysis in visual match + # self.plot_latent_tsne_each_and_all(obs_embeddings, suffix='visual_match_memlen1-60-15_tsne') + # self.save_as_image_with_timestep(batch['observations'], suffix='visual_match_memlen1-60-15_tsne') + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the encoder + shape = batch['observations'].shape # (..., C, H, W) + inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) + dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), + percentage=self.dormant_threshold) + self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: dormant_ratio_encoder = torch.tensor(0.) - avg_weight_mag_encoder = torch.tensor(0.) - avg_weight_mag_transformer = torch.tensor(0.) - avg_weight_mag_head = torch.tensor(0.) - e_rank_last_linear = torch.tensor(0.) - e_rank_sim_norm = torch.tensor(0.) 
# Calculate the L2 norm of the latent state roots latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() @@ -1915,74 +1372,42 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar if self.continuous_action_space: act_tokens = batch['actions'] else: - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') # torch.Size([256, 5]) # Forward pass to obtain predictions for observations, rewards, and policies outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos) - # Get intermediate tensor x from model output and detach computation graph - intermediate_tensor_x = outputs.output_sequence.detach() - - global_step = kwargs.get('global_step', 0) - if global_step > 0 and global_step % 100000000000 == 0: # TODO - - with torch.no_grad(): - # Convert logits to scalar values - # Note: outputs shape is (B, L, E), we need to reshape - batch_size, seq_len = batch['actions'].shape[0], batch['actions'].shape[1] - - pred_val_logits = outputs.logits_value.view(batch_size * seq_len, -1) - pred_rew_logits = outputs.logits_rewards.view(batch_size * seq_len, -1) - - scalar_values = inverse_scalar_transform_handle(pred_val_logits).squeeze(-1) - scalar_rewards = inverse_scalar_transform_handle(pred_rew_logits).squeeze(-1) - - self._analyze_latent_representation( - latent_states=obs_embeddings, - timesteps=batch['timestep'], - game_states=batch['observations'], - predicted_values=scalar_values, - predicted_rewards=scalar_rewards, - step_counter=global_step - ) - - if self.config.use_priority: - # Calculate value_priority, similar to MuZero. - with torch.no_grad(): - # 1. Get the predicted value logits for the first step of the sequence (t=0). - # The shape is (B, support_size). - predicted_value_logits_step0 = outputs.logits_value[:, 0, :] - - # 2. Convert the categorical prediction to a scalar value. - # The shape becomes (B, 1). - predicted_scalar_value_step0 = inverse_scalar_transform_handle(predicted_value_logits_step0) - - # 3. Get the target scalar value for the first step from the batch. - # The shape is (B, num_unroll_steps), so we take the first column. - target_scalar_value_step0 = batch['scalar_target_value'][:, 0] - - # 4. Calculate the L1 loss (absolute difference) between prediction and target. - # This is the priority. We use reduction='none' to get per-sample priorities. - value_priority = F.l1_loss(predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none') - else: - value_priority = torch.tensor(0.) 
- if self.obs_type == 'image': - if self.config.latent_recon_loss_weight > 0: - # Reconstruct observations from latent state representations - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - - # Calculate reconstruction loss and perceptual loss - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 - else: - # TODO: - latent_recon_loss = self.latent_recon_loss - perceptual_loss = self.perceptual_loss + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # original_images, reconstructed_images = batch['observations'], reconstructed_images + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # ========== Calculate reconstruction loss and perceptual loss ============ + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + + latent_recon_loss = self.latent_recon_loss + perceptual_loss = self.perceptual_loss elif self.obs_type == 'vector': perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) latent_recon_loss = self.latent_recon_loss elif self.obs_type == 'text': @@ -2013,29 +1438,49 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar latent_recon_loss = self.latent_recon_loss elif self.obs_type == 'image_memory': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + # original_images, reconstructed_images = batch['observations'], reconstructed_images + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), + # reconstructed_images) latent_recon_loss = 
self.latent_recon_loss perceptual_loss = self.perceptual_loss - # ========= Logging for analysis ========= - if self.analysis_dormant_ratio_weight_rank: + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: # Calculate dormant ratio of the world model - dormant_ratio_world_model = calculate_dormant_ratio(self, { + dormant_ratio_world_model = cal_dormant_ratio(self, { 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, - dormant_threshold=self.dormant_threshold) - dormant_ratio_transformer = dormant_ratio_world_model['transformer'] - dormant_ratio_head = dormant_ratio_world_model['head'] - - # ==================== Clear Cache Using Correct API ==================== - if self.use_new_cache_manager: - self.kv_cache_manager.clear_recur_cache() - else: - self.past_kv_cache_recurrent_infer.clear() - # ============================================================================= + percentage=self.dormant_threshold) + self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: - dormant_ratio_transformer = torch.tensor(0.) - dormant_ratio_head = torch.tensor(0.) + dormant_ratio_world_model = torch.tensor(0.) + + # ========== for visualization ========== + # Uncomment the lines below for visualization + # predict_policy = outputs.logits_policy + # predict_policy = F.softmax(outputs.logits_policy, dim=-1) + # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # import pdb; pdb.set_trace() + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') + + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') + # ========== for visualization ========== # For training stability, use target_tokenizer to compute the true next latent state representations with torch.no_grad(): @@ -2071,29 +1516,15 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # assert not torch.isinf(loss_obs).any(), "loss_obs contains Inf values" # for name, param in self.tokenizer.encoder.named_parameters(): # print('name, param.mean(), param.std():', name, param.mean(), param.std()) - elif self.predict_latent_loss_type == 'cos_sim': - # Cosine Similarity Loss - # print("predict_latent_loss_type == 'cos_sim'") - cosine_sim_loss = 1 - F.cosine_similarity(logits_observations, labels_observations, dim=-1) - loss_obs = cosine_sim_loss # Apply mask to loss_obs mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) loss_obs = (loss_obs * mask_padding_expanded) - # ==================== [NEW] Fix3: Load re-smooth options from config ==================== - 
use_target_policy_resmooth = getattr(self.config, 'use_target_policy_resmooth', False) - target_policy_resmooth_eps = getattr(self.config, 'target_policy_resmooth_eps', 0.05) - # ====================================================================================== - - # Compute labels for policy and value (with optional re-smoothing) - labels_policy, labels_value = self.compute_labels_world_model_value_policy( - batch['target_value'], - batch['target_policy'], - batch['mask_padding'], - use_target_policy_resmooth=use_target_policy_resmooth, - target_policy_resmooth_eps=target_policy_resmooth_eps - ) + # Compute labels for policy and value + labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], + batch['target_policy'], + batch['mask_padding']) # Compute losses for rewards, policy, and value loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') @@ -2114,6 +1545,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + # ==== TODO: calculate the new priorities for each transition. ==== + # value_priority = L1Loss(reduction='none')(labels_value.squeeze(-1), outputs['logits_value'][:, 0]) + # value_priority = value_priority.data.cpu().numpy() + 1e-6 + # Compute timesteps timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) # Compute discount coefficients for each timestep @@ -2165,10 +1600,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() - # Add encoder output to return dictionary for external training loop access - # Using .detach() because this tensor is only used for subsequent clip operations and should not affect gradient computation - detached_obs_embeddings = obs_embeddings.detach() - if self.continuous_action_space: return LossWithIntermediateLosses( latent_recon_loss_weight=self.latent_recon_loss_weight, @@ -2186,24 +1617,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar middle_step_losses=middle_step_losses, last_step_losses=last_step_losses, dormant_ratio_encoder=dormant_ratio_encoder, - dormant_ratio_transformer=dormant_ratio_transformer, - dormant_ratio_head=dormant_ratio_head, - avg_weight_mag_encoder = avg_weight_mag_encoder, - avg_weight_mag_transformer = avg_weight_mag_transformer, - avg_weight_mag_head = avg_weight_mag_head, - e_rank_last_linear = e_rank_last_linear, - e_rank_sim_norm = e_rank_sim_norm, + dormant_ratio_world_model=dormant_ratio_world_model, latent_state_l2_norms=latent_state_l2_norms, policy_mu=mu, policy_sigma=sigma, target_sampled_actions=target_sampled_actions, - - value_priority=value_priority, - intermediate_tensor_x=intermediate_tensor_x, - obs_embeddings=detached_obs_embeddings, - logits_value=outputs.logits_value.detach(), - logits_reward=outputs.logits_rewards.detach(), - logits_policy=outputs.logits_policy.detach(), ) else: return LossWithIntermediateLosses( @@ -2222,23 +1640,253 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar middle_step_losses=middle_step_losses, last_step_losses=last_step_losses, dormant_ratio_encoder=dormant_ratio_encoder, - 
dormant_ratio_transformer=dormant_ratio_transformer, - dormant_ratio_head=dormant_ratio_head, - avg_weight_mag_encoder = avg_weight_mag_encoder, - avg_weight_mag_transformer = avg_weight_mag_transformer, - avg_weight_mag_head = avg_weight_mag_head, - e_rank_last_linear = e_rank_last_linear, - e_rank_sim_norm = e_rank_sim_norm, + dormant_ratio_world_model=dormant_ratio_world_model, latent_state_l2_norms=latent_state_l2_norms, - value_priority=value_priority, - intermediate_tensor_x=intermediate_tensor_x, - obs_embeddings=detached_obs_embeddings, - logits_value=outputs.logits_value.detach(), - logits_reward=outputs.logits_rewards.detach(), - logits_policy=outputs.logits_policy.detach(), ) + def compute_loss_ppo( + self, + batch: Dict[str, torch.Tensor], + target_tokenizer: Tokenizer = None, + inverse_scalar_transform_handle=None, + clip_ratio: float = 0.2, + value_coef: float = 0.5, + entropy_coef: float = 0.01, + **kwargs: Any + ) -> LossWithIntermediateLosses: + """ + Compute PPO losses combined with UniZero's observation and reward losses. + + Args: + batch: Dictionary containing batch data including PPO-specific fields: + - 'advantages': GAE advantages [B, T] + - 'old_log_prob': Old policy log probabilities [B, T] + - 'returns': Target returns for value function [B, T] + target_tokenizer: Target tokenizer for computing labels + inverse_scalar_transform_handle: Function to convert categorical values to scalars + clip_ratio: PPO clipping ratio (default: 0.2) + value_coef: Coefficient for value loss (default: 0.5) + entropy_coef: Coefficient for entropy loss (default: 0.01) + """ + start_pos = batch['timestep'] + # ========== 1. Observation encoding and forward pass (same as compute_loss) ========== + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations']) + + # Action tokens + if self.continuous_action_space: + act_tokens = batch['actions'] + else: + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + # Forward pass + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos) + + # ========== 2. 
Observation and reward losses (same as compute_loss) ========== + # Handle different observation types + if self.obs_type == 'vector': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + latent_recon_loss = self.latent_recon_loss + elif self.obs_type == 'image': + latent_recon_loss = self.latent_recon_loss + perceptual_loss = self.perceptual_loss + elif self.obs_type == 'text': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=torch.float32) + decode_loss_mode = self.config.decode_loss_mode + if decode_loss_mode == "after_backbone": + next_latent_state = outputs.logits_observations[:, :-1, :] + next_target_ids = batch['observations'][:, 1:, :] + latent_recon_loss = self.tokenizer.decode_to_reconstruction_outputs( + embeddings=next_latent_state, + target_ids=next_target_ids, + ).loss + elif decode_loss_mode == "before_backbone": + latent_recon_loss = self.tokenizer.decode_to_reconstruction_outputs( + embeddings=obs_embeddings, + target_ids=batch['observations'], + ).loss + else: + latent_recon_loss = self.latent_recon_loss + else: + latent_recon_loss = self.latent_recon_loss + perceptual_loss = self.perceptual_loss + + # Compute labels for observations and rewards + with torch.no_grad(): + target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations']) + + labels_observations, labels_rewards, _ = self.compute_labels_world_model( + target_obs_embeddings, batch['rewards'], batch['ends'], batch['mask_padding'] + ) + + # Observation loss + logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + if self.predict_latent_loss_type == 'mse': + loss_obs = F.mse_loss(logits_observations, labels_observations, reduction='none').mean(-1) + elif self.predict_latent_loss_type == 'group_kl': + batch_size, num_features = logits_observations.shape + epsilon = 1e-6 + logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) + else: + loss_obs = torch.tensor(0.0, device=logits_observations.device) + + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded) + + # Reward loss + loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') + + # ========== 3. 
PPO Policy Loss ==========
+        # Get PPO data from batch
+        advantages = batch['advantages'].float()  # [B, T]
+        old_log_prob = batch['old_log_prob'].float()  # [B, T]
+        actions = batch['actions'].long()  # [B, T] for discrete
+
+        # Get policy logits and create distribution
+        policy_logits = outputs.logits_policy  # [B, T, A]
+
+        if not self.continuous_action_space:
+            # Discrete action space
+            # Apply action mask if available
+            if 'action_mask' in batch:
+                action_mask = batch['action_mask'].bool()
+                masked_logits = policy_logits.masked_fill(~action_mask, -1e9)
+            else:
+                masked_logits = policy_logits
+
+            # Create categorical distribution
+            dist = Categorical(logits=masked_logits)
+            log_prob = dist.log_prob(actions)  # [B, T]
+            entropy = dist.entropy()  # [B, T]
+        else:
+            # Continuous action space - extract mu and sigma
+            action_space_size = self.config.action_space_size
+            mu = policy_logits[:, :, :action_space_size]
+            sigma = policy_logits[:, :, action_space_size:]
+            dist = Independent(Normal(mu, sigma), 1)
+            log_prob = dist.log_prob(actions)  # [B, T]
+            entropy = dist.entropy()  # [B, T]
+
+        # Calculate importance sampling ratio
+        ratio = torch.exp(log_prob - old_log_prob)  # [B, T]
+
+        # Clipped surrogate loss
+        surrogate1 = ratio * advantages
+        surrogate2 = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages
+        clipped_surrogate = torch.min(surrogate1, surrogate2)  # [B, T]
+
+        # Apply mask and compute policy loss
+        mask_padding = batch['mask_padding'][:, :policy_logits.shape[1]]  # [B, T]
+        policy_loss = -(clipped_surrogate * mask_padding).sum() / (mask_padding.sum() + 1e-8)
+
+        # Policy entropy (for logging)
+        policy_entropy = (entropy * mask_padding).sum() / (mask_padding.sum() + 1e-8)
+
+        # ========== 4. PPO Value Loss (cross-entropy, consistent with compute_loss) ==========
+        returns_categorical = batch['returns']  # [B, T, support_size] - already a categorical distribution
+
+        # Compute the loss with compute_cross_entropy_loss (consistent with compute_loss)
+        # Prepare labels in the labels_value format
+        labels_returns = returns_categorical.reshape(-1, self.support_size)  # [B*T, support_size]
+
+        # Use the existing compute_cross_entropy_loss function
+        value_loss = self.compute_cross_entropy_loss(outputs, returns_categorical, batch, element='value')
+        # value_loss is already masked, so take the average
+        value_loss = value_loss.sum() / (batch['mask_padding'].sum() + 1e-8)
+
+        # ========== 5. Entropy Loss ==========
+        entropy_loss = -policy_entropy  # Negative entropy to encourage exploration
+
+        # ========== 6. Total Loss ==========
+        # Discount coefficients
+        timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device)
+        discounts = self.gamma ** timesteps
+
+        # Discounted losses
+        discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum() / (batch['mask_padding'][:, 1:].sum() + 1e-8)
+        discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum() / (batch['mask_padding'].sum() + 1e-8)
+
+        # Total loss
+        loss_total = (
+            discounted_loss_obs * self.latent_recon_loss_weight +
+            discounted_loss_rewards +
+            policy_loss +
+            value_coef * value_loss +
+            entropy_coef * entropy_loss
+        )
+
+        # ========== 7.
Return LossWithIntermediateLosses ========== + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=self.continuous_action_space, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=value_loss, + loss_policy=policy_loss, + latent_recon_loss=discounted_loss_obs, # Using obs loss as latent recon loss + perceptual_loss=perceptual_loss, + orig_policy_loss=policy_loss, + policy_entropy=policy_entropy, + first_step_losses={}, + middle_step_losses={}, + last_step_losses={}, + dormant_ratio_encoder=torch.tensor(0.0), + dormant_ratio_world_model=torch.tensor(0.0), + latent_state_l2_norms=torch.tensor(0.0), + loss_total=loss_total, + ) + # def compute_loss_ppo( + # self, + # batch: Dict[str, torch.Tensor], + # inverse_scalar_transform_handle, + # clip_ratio: float, + # value_coef: float, + # entropy_coef: float, + # ) -> Dict[str, torch.Tensor]: + # """Compute PPO losses given policy logits and associated targets.""" + # policy_logits = batch['policy_logits'] + # action_mask = batch['action_mask'].bool() + # actions = batch['actions'].long() + # old_log_prob = batch['old_log_prob'].float() + # advantages = batch['advantages'].float() + # returns = batch['returns'].float() + + # # import pudb;pudb.set_trace() + + # pred_values = inverse_scalar_transform_handle(batch['values']).squeeze(-1) + + # masked_logits = policy_logits.masked_fill(~action_mask, -1e9) + # dist = Categorical(logits=masked_logits) + # log_prob = dist.log_prob(actions) + # entropy = dist.entropy() + + # ratio = torch.exp(log_prob - old_log_prob) + # surrogate1 = ratio * advantages + # surrogate2 = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages + # policy_loss = -torch.min(surrogate1, surrogate2).mean() + # value_loss = F.mse_loss(pred_values, returns) + # entropy_mean = entropy.mean() + # entropy_loss = -entropy_mean + + # loss_total = policy_loss + value_coef * value_loss + entropy_coef * entropy_loss + + # return { + # 'loss_total': loss_total, + # 'loss_policy': policy_loss, + # 'loss_value': value_loss, + # 'loss_entropy': entropy_loss, + # 'entropy_mean': entropy_mean, + # 'ratio_mean': ratio.mean(), + # 'advantage_mean': advantages.mean(), + # 'return_mean': returns.mean(), + # } # TODO: test correctness def _calculate_policy_loss_cont_simple(self, outputs, batch: dict): """ @@ -2301,7 +1949,7 @@ def _calculate_policy_loss_cont_simple(self, outputs, batch: dict): return policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma - def _calculate_policy_loss_cont(self, outputs, batch: dict, task_id=None) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: + def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculate the policy loss for continuous actions. @@ -2316,12 +1964,9 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict, task_id=None) -> Tup - mu (:obj:`torch.Tensor`): The mean of the normal distribution. - sigma (:obj:`torch.Tensor`): The standard deviation of the normal distribution. 
""" - if task_id is None: - batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ 0], self.config.num_unroll_steps, self.config.action_space_size - else: - batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ - 0], self.config.num_unroll_steps, self.config.action_space_size_list[task_id] + policy_logits_all = outputs.logits_policy mask_batch = batch['mask_padding'] child_sampled_actions_batch = batch['child_sampled_actions'] @@ -2363,8 +2008,6 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict, task_id=None) -> Tup # KL as projector target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) - - # KL as projector policy_loss = -torch.sum( torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 ) * mask_batch @@ -2385,15 +2028,9 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): logits = getattr(outputs, f'logits_{element}') - # ==================== TODO: Temperature Scaling for Policy ==================== - if element == 'policy' and self.use_policy_loss_temperature and self.policy_loss_temperature != 1.0: - # Apply temperature scaling to soften the distribution - logits = logits / self.policy_loss_temperature - # =================================================================================== - if torch.isnan(logits).any(): raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") - + if torch.isnan(labels).any(): raise ValueError(f"NaN detected in labels_value for batch {batch} and element '{element}'") @@ -2420,7 +2057,6 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): return loss - #@profile def compute_policy_entropy_loss(self, logits, mask): # Compute entropy of the policy probs = torch.softmax(logits, dim=1) @@ -2430,7 +2066,6 @@ def compute_policy_entropy_loss(self, logits, mask): entropy_loss = (entropy * mask) return entropy_loss - #@profile def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag @@ -2450,23 +2085,11 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc return labels_observations, labels_rewards.view(-1, self.support_size), None - #@profile def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor, - use_target_policy_resmooth: bool = False, - target_policy_resmooth_eps: float = 0.05) -> Tuple[torch.Tensor, torch.Tensor]: + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute labels for value and policy predictions. 
""" mask_fill = torch.logical_not(mask_padding) - # ==================== [NEW] Fix3: Re-smooth Target Policy ==================== - # Re-smooth target_policy to prevent extreme distributions in buffer - if use_target_policy_resmooth and target_policy_resmooth_eps > 0: - num_actions = target_policy.shape[-1] - uniform_dist = torch.ones_like(target_policy) / num_actions - target_policy = (1 - target_policy_resmooth_eps) * target_policy + \ - target_policy_resmooth_eps * uniform_dist - # ============================================================================= - # Fill the masked areas of policy mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) labels_policy = target_policy.masked_fill(mask_fill_policy, -100) @@ -2484,23 +2107,11 @@ def clear_caches(self): """ Clears the caches of the world model. """ - if self.use_new_cache_manager: - # Use new KV cache manager's clear method - self.kv_cache_manager.clear_all() - print(f'Cleared {self.__class__.__name__} KV caches (NEW system).') - - # Optionally print stats before clearing - if hasattr(self.kv_cache_manager, 'get_stats_summary'): - stats = self.kv_cache_manager.get_stats_summary() - if stats.get('stats_enabled'): - logging.debug(f'Cache stats before clear: {stats}') - else: - # Use old cache clearing logic - for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: - kv_cache_dict_env.clear() - self.past_kv_cache_recurrent_infer.clear() - self.keys_values_wm_list.clear() - print(f'Cleared {self.__class__.__name__} past_kv_cache (OLD system).') + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + print(f'Cleared {self.__class__.__name__} past_kv_cache.') def __repr__(self) -> str: return "transformer-based latent world_model of UniZero" diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py old mode 100755 new mode 100644 index 766012870..baadffaca --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -1,92 +1,23 @@ import copy -import logging from collections import defaultdict -from typing import Any, Dict, List, Tuple, Union +from typing import List, Dict, Any, Tuple, Union import numpy as np import torch -import torch.nn.functional as F import wandb from ding.model import model_wrap from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch, initialize_pad_batch from lzero.mcts import UniZeroMCTSCtree as MCTSCtree from lzero.model import ImageTransforms -from lzero.policy import (DiscreteSupport, InverseScalarTransform, - mz_network_output_unpack, phi_transform, prepare_obs, - prepare_obs_stack_for_unizero, scalar_transform, - select_action, to_torch_float_tensor) -from lzero.policy.head_clip_manager import (HeadClipConfig, HeadClipManager, - create_head_clip_manager_from_dict) +from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs, \ + prepare_obs_stack_for_unizero from lzero.policy.muzero import MuZeroPolicy -from lzero.policy.utils import initialize_pad_batch -from torch.nn.utils.convert_parameters import (parameters_to_vector, - vector_to_parameters) - from .utils import configure_optimizers_nanogpt -def scale_module_weights_vectorized(module: torch.nn.Module, scale_factor: float): - """ - Efficiently scale all weights of a module using vectorized operations. 
- """ - if not (0.0 < scale_factor < 1.0): - return # Do nothing if the scaling factor is invalid - - # 1. Flatten all parameters of the module into a single vector - params_vec = parameters_to_vector(module.parameters()) - - # 2. Perform multiplication operation on this vector - params_vec.data.mul_(scale_factor) - - # 3. Copy the scaled vector back to the individual parameters of the module - vector_to_parameters(params_vec, module.parameters()) - - -def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): - """ - Configure optimizer with differentiated learning rates and weight decay for encoder/backbone/head of UniZero model. - """ - # 1. Define parameters that need special handling - param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} - - # 2. Divide parameters into three groups: Transformer backbone, Tokenizer, and Heads - transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn} - tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn} - - # Head parameters are those that belong to neither transformer nor tokenizer - head_params = { - pn: p for pn, p in param_dict.items() - if 'transformer' not in pn and 'tokenizer' not in pn - } - - # 3. Set different optimizer parameters for each group (especially learning rate) - # We still use AdamW here, but with more reasonable learning rate settings - optim_groups = [ - { - 'params': list(tokenizer_params.values()), - 'lr': learning_rate, # Tokenizer uses base learning rate, e.g., 1e-4 - 'weight_decay': weight_decay - }, - { - 'params': list(transformer_params.values()), - 'lr': learning_rate, # Tokenizer uses base learning rate, e.g., 1e-4 - 'weight_decay': weight_decay - }, - { - 'params': list(head_params.values()), - 'lr': learning_rate, # Heads also use base learning rate, e.g., 1e-4 - 'weight_decay': weight_decay - - } - ] - - logging.info("--- Optimizer Groups ---") - logging.info(f"Transformer LR: {learning_rate}") - logging.info(f"Tokenizer/Heads LR: {learning_rate}") - - optimizer = torch.optim.AdamW(optim_groups, betas=betas) - return optimizer - @POLICY_REGISTRY.register('unizero') class UniZeroPolicy(MuZeroPolicy): """ @@ -134,8 +65,6 @@ class UniZeroPolicy(MuZeroPolicy): # (int) The save interval of the model. learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), world_model_cfg=dict( - # (str) The encoder type, e.g., 'resnet' or 'vit'. - encoder_type='resnet', # (bool) If True, the action space of the environment is continuous, otherwise discrete. continuous_action_space=False, # (int) The number of tokens per block. @@ -152,8 +81,8 @@ class UniZeroPolicy(MuZeroPolicy): device='cpu', # (bool) Whether to analyze simulation normalization. analysis_sim_norm=False, - # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent. - analysis_dormant_ratio_weight_rank=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, # (int) The shape of the action space. action_space_size=6, # (int) The size of the group, related to simulation normalization. @@ -210,131 +139,13 @@ class UniZeroPolicy(MuZeroPolicy): rope_theta=10000, # (int) The maximum sequence length for position encoding. max_seq_len=8192, - # (int) The rank parameter for LoRA (Low-Rank Adaptation). Set to 0 to disable LoRA. - lora_r=0, - # (float) The alpha parameter for LoRA scaling. - lora_alpha=1, - # (float) The dropout probability for LoRA layers. 
- lora_dropout=0.0, # Controls where to compute reconstruction loss: 'after_backbone', 'before_backbone', or None. # - after_backbone: The reconstruction loss is computed after the encoded representation passes through the backbone. - # - before_backbone: The reconstruction loss is computed directly on the encoded representation, without the backbone. + # - before_backbone: The reconstruction loss is computed directly on the encoded representation, without the backbone. decode_loss_mode=None, - # (str/None) Task embedding option. Set to None to disable task-specific embeddings. Options are ['concat_task_embed', 'add_task_embed', 'register_task_embed']. - # Please note that "register_task_embed" has not yet been fully tested. - task_embed_option=None, - # (bool) Whether to use task embeddings. - use_task_embed=False, - # TODO: optimize the following configs. - # (bool) Whether to use normal head (standard prediction heads). - use_normal_head=True, - # (bool) Whether to use Soft Mixture-of-Experts (MoE) head. - use_softmoe_head=False, - # (bool) Whether to use Mixture-of-Experts (MoE) head. - use_moe_head=False, - # (int) Number of experts in the MoE head. - num_experts_in_moe_head=4, - # (bool) Whether to use MoE in the transformer layers. - moe_in_transformer=False, - # (bool) Whether to use multiplicative MoE in the transformer layers. - multiplication_moe_in_transformer=False, - # (int) Number of shared experts in MoE. - n_shared_experts=1, - # (int) Number of experts to use per token in MoE. - num_experts_per_tok=1, - # (int) Total number of experts in the transformer MoE. - num_experts_of_moe_in_transformer=8, - # ****** Priority ****** - # (bool) Whether to use priority when sampling training data from the buffer. - use_priority=False, ), ), # ****** common ****** - # (bool) Whether to enable adaptive policy entropy weight (alpha) - use_adaptive_entropy_weight=True, - # (float) Learning rate for adaptive alpha optimizer - adaptive_entropy_alpha_lr=1e-3, - # (float) Target entropy ratio at the start of training (higher = more exploration) - target_entropy_start_ratio=0.98, - # (float) Target entropy ratio at the end of training (lower = more exploitation) - target_entropy_end_ratio=0.05, - # (int) Number of training steps to decay target entropy from start to end ratio - target_entropy_decay_steps=500000, - - # ==================== START: Encoder-Clip Annealing Config ==================== - # (bool) Whether to enable annealing for encoder-clip values. - use_encoder_clip_annealing=True, - # (str) Annealing type. Options: 'linear' or 'cosine'. - encoder_clip_anneal_type='cosine', - # (float) Starting clip value for annealing (looser in early training). - encoder_clip_start_value=30.0, - # (float) Ending clip value for annealing (stricter in later training). - encoder_clip_end_value=10.0, - # (int) Training iteration steps required to complete annealing from start to end value. 
- encoder_clip_anneal_steps=100000, # e.g., reach final value after 100k iterations - # (float) Fixed latent norm clip threshold (used when encoder_clip_annealing is disabled) - latent_norm_clip_threshold=20.0, - # ===================== END: Encoder-Clip Annealing Config ===================== - - # ==================== START: Head-Clip Annealing Config ==================== - # NOTE: The usage and implementation of Head-Clip may need to be optimized - # (bool) Whether to enable head-clip (dynamically clip head output range) - use_head_clip=False, # Disabled by default - # Detailed Head-Clip configuration - head_clip_config=dict( - enabled=False, - # Specify heads that need clipping (optional, defaults to empty list) - enabled_heads=[], # Example: ['policy', 'value', 'rewards'] - # Detailed configuration for each head (optional) - head_configs={ - # 'policy': { - # 'use_annealing': True, - # 'anneal_type': 'cosine', # 'cosine' or 'linear' - # 'start_value': 30.0, # Loose in early phase - # 'end_value': 10.0, # Strict in later phase - # 'anneal_steps': 500000, - # }, - # 'value': { - # 'clip_threshold': 20.0, - # 'use_annealing': False, - # }, - }, - # Monitoring configuration - monitor_freq=1, # Check every iteration - log_freq=1000, # Print log every 1000 iterations - ), - # ===================== END: Head-Clip Annealing Config ===================== - - # ==================== START: Policy Label Smoothing Config ==================== - # (float) Starting epsilon value for policy label smoothing (higher = more smoothing) - policy_ls_eps_start=0.05, - # (float) Ending epsilon value for policy label smoothing (lower = less smoothing) - policy_ls_eps_end=0.01, - # (int) Number of training steps to decay label smoothing epsilon from start to end - policy_ls_eps_decay_steps=50000, - - label_smoothing_eps=0.1, # TODO: For value - - # (bool) Whether to use continuous (fixed) label smoothing throughout training - use_continuous_label_smoothing=False, - # (float) Fixed epsilon value for continuous label smoothing (only used when use_continuous_label_smoothing=True) - continuous_ls_eps=0.05, - # ===================== END: Policy Label Smoothing Config ===================== - - # ==================== START: Learning Rate Scheduler Config ==================== - # (int) Total training iterations for cosine annealing LR scheduler (only used when cos_lr_scheduler=True) - total_iterations=500000, - # (float) Final learning rate for cosine annealing LR scheduler (only used when cos_lr_scheduler=True) - final_learning_rate=4e-5, - # ===================== END: Learning Rate Scheduler Config ===================== - - # ==================== START: Monitoring Config ==================== - # (int) Frequency of monitoring model parameter and gradient norms (in training iterations). Set to 0 to disable. - monitor_norm_freq=5000, - # (bool) Whether to enable enhanced policy monitoring (logits statistics, target policy entropy, etc.) - use_enhanced_policy_monitoring=False, - # ===================== END: Monitoring Config ===================== - # (bool) whether to use rnd model. use_rnd_model=False, # (bool) Whether to use multi-gpu training. @@ -367,7 +178,7 @@ class UniZeroPolicy(MuZeroPolicy): # (bool) Whether to use the pure policy to collect data. collect_with_pure_policy=False, # (int) The evaluation frequency. - eval_freq=int(5e3), + eval_freq=int(2e3), # (str) The sample type. Options are ['episode', 'transition']. 
sample_type='transition', # ****** observation ****** @@ -416,12 +227,8 @@ class UniZeroPolicy(MuZeroPolicy): n_episode=8, # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. num_segments=8, - # (int) the number of simulations in MCTS for renalyze. + # (int) the number of simulations in MCTS. num_simulations=50, - # (int) The number of simulations in MCTS for the collect phase. - collect_num_simulations=25, - # (int) The number of simulations in MCTS for the eval phase. - eval_num_simulations=50, # (float) Discount factor (gamma) for returns. discount_factor=0.997, # (int) The number of steps for calculating target q_value. @@ -466,8 +273,6 @@ class UniZeroPolicy(MuZeroPolicy): priority_prob_beta=0.4, # (int) The initial Env Steps for training. train_start_after_envsteps=int(0), - # (bool) Whether to use task_exploitation_weight. - use_task_exploitation_weight=False, # ****** UCB ****** # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. @@ -508,139 +313,24 @@ def default_model(self) -> Tuple[str, List[str]]: """ return 'UniZeroModel', ['lzero.model.unizero_model'] - - # ==================== Model Norm Monitoring Function ==================== - def _monitor_model_norms(self) -> Dict[str, float]: - """ - Overview: - Calculate and return parameter matrix norms for key model components (Encoder, Transformer, Heads). - This function should be called within a torch.no_grad() context for efficiency. - Returns: - - norm_metrics (:obj:`Dict[str, float]`): Dictionary containing all norm metrics for logging. - """ - world_model = self._learn_model.world_model - norm_metrics = {} - - # Define module groups to monitor - module_groups = { - 'encoder': world_model.tokenizer.encoder, - 'transformer': world_model.transformer, - 'head_value': world_model.head_value, - 'head_reward': world_model.head_rewards, - 'head_policy': world_model.head_policy, - } - - for group_name, group_module in module_groups.items(): - total_norm_sq = 0.0 - for param_name, param in group_module.named_parameters(): - if param.requires_grad: - # Calculate L2 norm for single layer parameters - param_norm = param.data.norm(2).item() - # Replace dots to display correctly as hierarchy in TensorBoard - log_name = f'norm/{group_name}/{param_name.replace(".", "/")}' - norm_metrics[log_name] = param_norm - total_norm_sq += param_norm ** 2 - - # Calculate total norm for entire module - total_group_norm = np.sqrt(total_norm_sq) - norm_metrics[f'norm/{group_name}/_total_norm'] = total_group_norm - - return norm_metrics - - def _monitor_gradient_norms(self) -> Dict[str, float]: - """ - Overview: - Calculate and return gradient norms for key model components. - This function should be called after gradient computation and before parameter updates. - Returns: - - grad_metrics (:obj:`Dict[str, float]`): Dictionary containing all gradient norm metrics for logging. 
- """ - world_model = self._learn_model.world_model - grad_metrics = {} - - # Define module groups to monitor - module_groups = { - 'encoder': world_model.tokenizer.encoder, - 'transformer': world_model.transformer, - 'head_value': world_model.head_value, - 'head_reward': world_model.head_rewards, - 'head_policy': world_model.head_policy, - } - - for group_name, group_module in module_groups.items(): - total_grad_norm_sq = 0.0 - num_params_with_grad = 0 - - for param_name, param in group_module.named_parameters(): - if param.requires_grad and param.grad is not None: - # Calculate L2 norm for single layer parameter gradients - grad_norm = param.grad.data.norm(2).item() - # Replace dots to display correctly as hierarchy in TensorBoard - log_name = f'grad/{group_name}/{param_name.replace(".", "/")}' - grad_metrics[log_name] = grad_norm - total_grad_norm_sq += grad_norm ** 2 - num_params_with_grad += 1 - - # Calculate total gradient norm for entire module - if num_params_with_grad > 0: - total_group_grad_norm = np.sqrt(total_grad_norm_sq) - grad_metrics[f'grad/{group_name}/_total_norm'] = total_group_grad_norm - else: - grad_metrics[f'grad/{group_name}/_total_norm'] = 0.0 - - return grad_metrics - # ================================================================= - def _init_learn(self) -> None: """ Overview: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. """ - if self._cfg.optim_type == 'SGD': - # Configure SGD optimizer - self._optimizer_world_model = torch.optim.SGD( - self._model.world_model.parameters(), - lr=self._cfg.learning_rate, - momentum=self._cfg.momentum, - weight_decay=self._cfg.weight_decay - ) - elif self._cfg.optim_type == 'AdamW': - # NOTE: nanoGPT optimizer - self._optimizer_world_model = configure_optimizers_nanogpt( - model=self._model.world_model, - learning_rate=self._cfg.learning_rate, - weight_decay=self._cfg.weight_decay, - device_type=self._cfg.device, - betas=(0.9, 0.95), - ) - elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': - self._optimizer_world_model = configure_optimizer_unizero( - model=self._model.world_model, - learning_rate=self._cfg.learning_rate, - weight_decay=self._cfg.weight_decay, - device_type=self._cfg.device, - betas=(0.9, 0.95), - ) + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) if self._cfg.cos_lr_scheduler: from torch.optim.lr_scheduler import CosineAnnealingLR - total_iters = self._cfg.total_iterations - final_lr = self._cfg.final_learning_rate - - self.lr_scheduler = CosineAnnealingLR( - self._optimizer_world_model, - T_max=total_iters, - eta_min=final_lr - ) - logging.info(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") - - - if self._cfg.piecewise_decay_lr_scheduler: - from torch.optim.lr_scheduler import LambdaLR - max_step = self._cfg.threshold_training_steps_for_final_lr - # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
- lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa - self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) + # TODO: check the total training steps + self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) @@ -667,120 +357,28 @@ def _init_learn(self) -> None: self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + # PPO: Initialize PPO hyperparameters from config + # Note: self._cfg is already the policy config, so use self._cfg.ppo directly + self.ppo_clip_ratio = getattr(self._cfg.ppo, 'clip_ratio', 0.2) + self.ppo_value_coef = getattr(self._cfg.ppo, 'value_coef', 0.5) + self.ppo_entropy_coef = getattr(self._cfg.ppo, 'entropy_coef', 0.01) + self.intermediate_losses = defaultdict(float) self.l2_norm_before = 0. self.l2_norm_after = 0. self.grad_norm_before = 0. self.grad_norm_after = 0. - if self._cfg.model.model_type == 'conv': - # for image-input env - self.pad_token_id = -1 - else: - # for text-input env and vector-input env - # Retrieve the tokenizer from the encoder module if it exists - encoder_tokenizer = getattr(self._model.tokenizer.encoder, 'tokenizer', None) - - # Extract the padding token ID from the tokenizer if available, otherwise use 0 as default. Used in _reset_collect() - # The pad_token_id is used to identify padding tokens in sequences, which is essential for: - # 1. Masking padded positions during attention computation to prevent them from affecting the output - # 2. Properly handling variable-length sequences in batch processing - # 3. Distinguishing between actual tokens and padding in loss calculation - # Default value 0 is a common convention when no specific padding token is defined - self.pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0 - + encoder_tokenizer = getattr(self._model.tokenizer.encoder, 'tokenizer', None) + self.pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0 + if self._cfg.use_wandb: # TODO: add the model to wandb wandb.watch(self._learn_model.representation_network, log="all") self.accumulation_steps = self._cfg.accumulation_steps - # ==================== START: Target Entropy Regularization Initialization ==================== - # Read whether to enable adaptive alpha from config, and provide a default value - self.use_adaptive_entropy_weight = self._cfg.use_adaptive_entropy_weight - - # Add configuration in _init_learn - self.target_entropy_start_ratio = self._cfg.target_entropy_start_ratio - self.target_entropy_end_ratio = self._cfg.target_entropy_end_ratio - self.target_entropy_decay_steps = self._cfg.target_entropy_decay_steps # e.g., complete annealing within 200k steps (2M envsteps) - - if self.use_adaptive_entropy_weight: - # 1. Set target entropy. For discrete action spaces, a common heuristic is the negative logarithm - # of action space dimension multiplied by a coefficient. - # This coefficient (e.g., 0.98) can be used as a hyperparameter. - action_space_size = self._cfg.model.action_space_size - self.target_entropy = -np.log(1.0 / action_space_size) * 0.98 - - # 2. Initialize a learnable log_alpha parameter. 
- # Initialized to 0, meaning initial alpha = exp(0) = 1.0. - self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True) - - # 3. Create a dedicated optimizer for log_alpha. - # Using a smaller learning rate (e.g., 1e-4) different from the main optimizer is usually more stable. - alpha_lr = self._cfg.adaptive_entropy_alpha_lr - self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr) - - logging.info("="*20) - logging.info(">>> Target Entropy Regularization (Adaptive Alpha) Enabled <<<") - logging.info(f" Target Entropy: {self.target_entropy:.4f}") - logging.info(f" Alpha Optimizer Learning Rate: {alpha_lr:.2e}") - logging.info("="*20) - # ===================== END: Target Entropy Regularization Initialization ===================== - - # ==================== START: Initialize Encoder-Clip Annealing Parameters ==================== - self.use_encoder_clip_annealing = self._cfg.use_encoder_clip_annealing - self.latent_norm_clip_threshold = self._cfg.latent_norm_clip_threshold # TODO - if self.use_encoder_clip_annealing: - self.encoder_clip_anneal_type = self._cfg.encoder_clip_anneal_type - self.encoder_clip_start = self._cfg.encoder_clip_start_value - self.encoder_clip_end = self._cfg.encoder_clip_end_value - self.encoder_clip_anneal_steps = self._cfg.encoder_clip_anneal_steps - - logging.info("="*20) - logging.info(">>> Encoder-Clip Annealing Enabled <<<") - logging.info(f" Type: {self.encoder_clip_anneal_type}") - logging.info(f" Range: {self.encoder_clip_start} -> {self.encoder_clip_end}") - logging.info(f" Steps: {self.encoder_clip_anneal_steps}") - logging.info("="*20) - else: - # If annealing is not enabled, use a fixed clip threshold - self.latent_norm_clip_threshold = self._cfg.latent_norm_clip_threshold - # ===================== END: Initialize Encoder-Clip Annealing Parameters ===================== - - # ==================== START: Initialize Head-Clip Manager ==================== - self.use_head_clip = self._cfg.use_head_clip - - if self.use_head_clip: - head_clip_config_dict = self._cfg.head_clip_config - # Ensure enabled is consistent with top-level configuration - head_clip_config_dict['enabled'] = self.use_head_clip - - # Create HeadClipManager - self.head_clip_manager = create_head_clip_manager_from_dict(head_clip_config_dict) - - logging.info("=" * 60) - logging.info(">>> Head-Clip Manager Initialized <<<") - logging.info(f" Enabled heads: {self.head_clip_manager.enabled_heads}") - for head_name in self.head_clip_manager.enabled_heads: - config = self.head_clip_manager.get_head_config(head_name) - if config.use_annealing: - logging.info( - f" {head_name}: annealing {config.start_value:.1f} → {config.end_value:.1f} " - f"over {config.anneal_steps} steps ({config.anneal_type})" - ) - else: - logging.info(f" {head_name}: fixed threshold = {config.clip_threshold:.1f}") - logging.info("=" * 60) - else: - self.head_clip_manager = None - # ===================== END: Initialize Head-Clip Manager ===================== - - # Policy Label Smoothing Parameters - self.policy_ls_eps_start = self._cfg.policy_ls_eps_start - self.policy_ls_eps_end = self._cfg.policy_ls_eps_end - self.policy_ls_eps_decay_steps = self._cfg.policy_ls_eps_decay_steps - logging.info(f"self.policy_ls_eps_start: {self.policy_ls_eps_start}") - + # @profile def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: """ Overview: @@ -796,32 +394,17 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in """ 
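In the PPO branch of `_forward_learn` below, scalar returns are projected onto the discrete value support (via `scalar_transform` + `phi_transform`) so that `compute_loss_ppo` can reuse the categorical cross-entropy value loss. A minimal two-hot sketch of that projection is shown here under assumed names and an assumed 11-bin support; it is not LightZero's `phi_transform`, which additionally applies the MuZero value rescaling first.

```python
# Minimal sketch (assumed names): project scalar returns onto a discrete support
# as "two-hot" categorical targets.
import torch


def two_hot(scalars: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
    """Distribute each scalar over its two neighbouring support bins."""
    scalars = scalars.clamp(support.min().item(), support.max().item())
    step = support[1] - support[0]
    low_idx = ((scalars - support[0]) / step).floor().long()
    low_idx = low_idx.clamp(max=support.numel() - 2)
    high_w = (scalars - support[low_idx]) / step
    target = torch.zeros(*scalars.shape, support.numel())
    target.scatter_(-1, low_idx.unsqueeze(-1), (1 - high_w).unsqueeze(-1))
    target.scatter_(-1, (low_idx + 1).unsqueeze(-1), high_w.unsqueeze(-1))
    return target  # (..., support_size); each row sums to 1


support = torch.arange(-5., 6.)                    # 11-bin support, assumed
returns = torch.tensor([[0.3, -1.7], [4.2, 0.0]])  # (B=2, T=2) scalar returns
returns_categorical = two_hot(returns, support)    # (2, 2, 11)
```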
self._learn_model.train() self._target_model.train() - current_batch, target_batch, train_iter = data - obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch + # PPO: current_batch now contains 11 elements: obs, action, bootstrap_action, mask, indices, weights, make_time, timestep, advantage, old_log_prob, return + obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch, advantage_batch, old_log_prob_batch, return_batch = current_batch target_reward, target_value, target_policy = target_batch - - # Calculate current epsilon for policy label smoothing - # ==================== Continuous Label Smoothing ==================== - use_continuous_label_smoothing = self._cfg.use_continuous_label_smoothing - if use_continuous_label_smoothing: - # Use fixed high epsilon throughout training - current_policy_label_eps = self._cfg.continuous_ls_eps - else: - # Use original decay schedule - if self.policy_ls_eps_start > 0: - progress = min(1.0, train_iter / self.policy_ls_eps_decay_steps) - current_policy_label_eps = self.policy_ls_eps_start * (1 - progress) + self.policy_ls_eps_end * progress - else: - current_policy_label_eps = 0.0 - # ================================================================================ - + # Prepare observations based on frame stack number if self._cfg.model.frame_stack_num > 1: obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) else: - obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) # TODO: optimize + # Apply augmentations if needed if self._cfg.use_augmentation: obs_batch = self.image_transforms.transform(obs_batch) @@ -844,8 +427,16 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in transformed_target_value = scalar_transform(target_value) # Convert to categorical distributions - target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps= self._cfg.label_smoothing_eps) - target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps=self._cfg.label_smoothing_eps) + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # PPO: Transform returns to categorical distribution (same as target_value) + # Convert return_batch to torch tensor and reshape + return_batch_tensor = torch.from_numpy(return_batch).to(self._cfg.device).float() + return_batch_reshaped = return_batch_tensor.view(self._cfg.batch_size, -1) # [B, num_unroll_steps] + # Apply scalar_transform and phi_transform + transformed_returns = scalar_transform(return_batch_reshaped) + returns_categorical = phi_transform(self.value_support, transformed_returns) # [B, num_unroll_steps, support_size] # Prepare batch for GPT model batch_for_gpt = {} @@ -866,135 +457,42 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) batch_for_gpt['target_value'] = target_value_categorical[:, :-1] - - # ==================== Apply Policy Label Smoothing ==================== - # This was previously computed but never applied. Now we actually smooth the target_policy. 
- smoothed_target_policy = target_policy[:, :-1] - if current_policy_label_eps > 0: - num_actions = smoothed_target_policy.shape[-1] - uniform_dist = torch.ones_like(smoothed_target_policy) / num_actions - smoothed_target_policy = (1.0 - current_policy_label_eps) * smoothed_target_policy + \ - current_policy_label_eps * uniform_dist - batch_for_gpt['target_policy'] = smoothed_target_policy - # =================================================================================== - - batch_for_gpt['scalar_target_value'] = target_value + batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # PPO: Add PPO-specific data to batch_for_gpt + # Convert numpy arrays to torch tensors and align shapes + advantage_batch_tensor = torch.from_numpy(advantage_batch).to(self._cfg.device).float() + old_log_prob_batch_tensor = torch.from_numpy(old_log_prob_batch).to(self._cfg.device).float() + + # Align shapes: [B, num_unroll_steps] -> [B, T] where T matches target_value_categorical + # target_value_categorical is [B, num_unroll_steps+1, support_size], we take [:, :-1] to get [B, num_unroll_steps, support_size] + # returns_categorical is [B, num_unroll_steps, support_size], we need to align with target_value_categorical[:, :-1] + target_seq_len = batch_for_gpt['target_value'].shape[1] # This is num_unroll_steps (after [:, :-1]) + batch_for_gpt['advantages'] = advantage_batch_tensor[:, :target_seq_len] + batch_for_gpt['old_log_prob'] = old_log_prob_batch_tensor[:, :target_seq_len] + # Use categorical distribution version of returns (already transformed above) + # returns_categorical is [B, num_unroll_steps, support_size], align with target_seq_len + batch_for_gpt['returns'] = returns_categorical[:, :target_seq_len] # [B, T, support_size] # Extract valid target policy data and compute entropy valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) average_target_policy_entropy = target_policy_entropy.mean() - # Update world model - losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, global_step=train_iter, current_policy_label_eps=current_policy_label_eps, + # Update world model with PPO loss + losses = self._learn_model.world_model.compute_loss_ppo( + batch_for_gpt, + self._target_model.world_model.tokenizer, + self.value_inverse_scalar_transform_handle, + clip_ratio=self.ppo_clip_ratio, + value_coef=self.ppo_value_coef, + entropy_coef=self.ppo_entropy_coef, ) - # ==================== Integrate norm monitoring logic ==================== - norm_log_dict = {} - # Check if monitoring frequency is reached - if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): - with torch.no_grad(): - # 1. Monitor model parameter norms - param_norm_metrics = self._monitor_model_norms() - norm_log_dict.update(param_norm_metrics) - - # 2. 
Monitor intermediate tensor x (Transformer output) - intermediate_x = losses.intermediate_losses.get('intermediate_tensor_x') - if intermediate_x is not None: - # x shape is (B, T, E) - # Calculate L2 norm for each token - token_norms = intermediate_x.norm(p=2, dim=-1) - - # Record statistics of these norms - norm_log_dict['norm/x_token/mean'] = token_norms.mean().item() - norm_log_dict['norm/x_token/std'] = token_norms.std().item() - norm_log_dict['norm/x_token/max'] = token_norms.max().item() - norm_log_dict['norm/x_token/min'] = token_norms.min().item() - - # 3. Monitor detailed statistics of logits (Value, Policy, Reward) - logits_value = losses.intermediate_losses.get('logits_value') - if logits_value is not None: - norm_log_dict['logits/value/mean'] = logits_value.mean().item() - norm_log_dict['logits/value/std'] = logits_value.std().item() - norm_log_dict['logits/value/max'] = logits_value.max().item() - norm_log_dict['logits/value/min'] = logits_value.min().item() - norm_log_dict['logits/value/abs_max'] = logits_value.abs().max().item() - - logits_policy = losses.intermediate_losses.get('logits_policy') - if logits_policy is not None: - norm_log_dict['logits/policy/mean'] = logits_policy.mean().item() - norm_log_dict['logits/policy/std'] = logits_policy.std().item() - norm_log_dict['logits/policy/max'] = logits_policy.max().item() - norm_log_dict['logits/policy/min'] = logits_policy.min().item() - norm_log_dict['logits/policy/abs_max'] = logits_policy.abs().max().item() - - logits_reward = losses.intermediate_losses.get('logits_reward') - if logits_reward is not None: - norm_log_dict['logits/reward/mean'] = logits_reward.mean().item() - norm_log_dict['logits/reward/std'] = logits_reward.std().item() - norm_log_dict['logits/reward/max'] = logits_reward.max().item() - norm_log_dict['logits/reward/min'] = logits_reward.min().item() - norm_log_dict['logits/reward/abs_max'] = logits_reward.abs().max().item() - - # 4. Monitor obs_embeddings (Encoder output) statistics - obs_embeddings = losses.intermediate_losses.get('obs_embeddings') - if obs_embeddings is not None: - # Calculate L2 norm for each embedding - emb_norms = obs_embeddings.norm(p=2, dim=-1) - norm_log_dict['embeddings/obs/norm_mean'] = emb_norms.mean().item() - norm_log_dict['embeddings/obs/norm_std'] = emb_norms.std().item() - norm_log_dict['embeddings/obs/norm_max'] = emb_norms.max().item() - norm_log_dict['embeddings/obs/norm_min'] = emb_norms.min().item() - - # ==================== Early Warning System ==================== - # Detect potential training instability and issue warnings - warnings_issued = [] - - # Check 1: Policy logits explosion (should be caught by clip, but warn anyway) - if 'logits/policy/abs_max' in norm_log_dict: - policy_abs_max = norm_log_dict['logits/policy/abs_max'] - if policy_abs_max > 8.0: - warnings_issued.append(f"⚠️ CRITICAL: Policy logits explosion detected! abs_max={policy_abs_max:.2f} (threshold: 8.0)") - elif policy_abs_max > 5.0: - warnings_issued.append(f"⚠️ WARNING: Policy logits getting large! abs_max={policy_abs_max:.2f} (threshold: 5.0)") - - # Check 2: Embedding norm explosion - if 'embeddings/obs/norm_std' in norm_log_dict: - emb_norm_std = norm_log_dict['embeddings/obs/norm_std'] - if emb_norm_std > 10.0: - warnings_issued.append(f"⚠️ CRITICAL: Embedding norm std explosion! std={emb_norm_std:.2f} (threshold: 10.0)") - elif emb_norm_std > 5.0: - warnings_issued.append(f"⚠️ WARNING: Embedding norm std increasing! 
std={emb_norm_std:.2f} (threshold: 5.0)") - - # Check 3: X token norm collapse - if 'norm/x_token/std' in norm_log_dict: - x_token_std = norm_log_dict['norm/x_token/std'] - if x_token_std < 0.1: - warnings_issued.append(f"⚠️ CRITICAL: X token norm collapse! std={x_token_std:.4f} (threshold: 0.1)") - elif x_token_std < 0.5: - warnings_issued.append(f"⚠️ WARNING: X token norm decreasing! std={x_token_std:.4f} (threshold: 0.5)") - - # Log warnings if any - if warnings_issued: - logging.warning(f"\n{'='*80}\n[TRAINING STABILITY] Iteration {train_iter}:\n" + "\n".join(warnings_issued) + f"\n{'='*80}") - norm_log_dict['stability/warning_count'] = float(len(warnings_issued)) - else: - norm_log_dict['stability/warning_count'] = 0.0 - # ==================================================================== - # ================================================================= - - # Extract the calculated value_priority from the returned losses. - value_priority_tensor = losses.intermediate_losses['value_priority'] - # Convert to numpy array for the replay buffer, adding a small epsilon. - value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6 - - weighted_total_loss = (weights * losses.loss_total).mean() - + weighted_total_loss = losses.loss_total for loss_name, loss_value in losses.intermediate_losses.items(): self.intermediate_losses[f"{loss_name}"] = loss_value - # Extract losses from intermediate_losses dictionary obs_loss = self.intermediate_losses['loss_obs'] reward_loss = self.intermediate_losses['loss_rewards'] policy_loss = self.intermediate_losses['loss_policy'] @@ -1003,23 +501,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in perceptual_loss = self.intermediate_losses['perceptual_loss'] orig_policy_loss = self.intermediate_losses['orig_policy_loss'] policy_entropy = self.intermediate_losses['policy_entropy'] - first_step_losses = self.intermediate_losses['first_step_losses'] - middle_step_losses = self.intermediate_losses['middle_step_losses'] - last_step_losses = self.intermediate_losses['last_step_losses'] + # first_step_losses = self.intermediate_losses['first_step_losses'] + # middle_step_losses = self.intermediate_losses['middle_step_losses'] + # last_step_losses = self.intermediate_losses['last_step_losses'] dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder'] - dormant_ratio_transformer = self.intermediate_losses['dormant_ratio_transformer'] - dormant_ratio_head = self.intermediate_losses['dormant_ratio_head'] - avg_weight_mag_encoder = self.intermediate_losses['avg_weight_mag_encoder'] - avg_weight_mag_transformer = self.intermediate_losses['avg_weight_mag_transformer'] - avg_weight_mag_head = self.intermediate_losses['avg_weight_mag_head'] - e_rank_last_linear = self.intermediate_losses['e_rank_last_linear'] - e_rank_sim_norm = self.intermediate_losses['e_rank_sim_norm'] + dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model'] latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] - latent_action_l2_norms = self.intermediate_losses['latent_action_l2_norms'] - - temperature_value=self.intermediate_losses['temperature_value'] - temperature_reward=self.intermediate_losses['temperature_reward'] - temperature_policy=self.intermediate_losses['temperature_policy'] assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" @@ -1029,129 +516,19 @@ def _forward_learn(self, data: 
Tuple[torch.Tensor]) -> Dict[str, Union[float, in if (train_iter % self.accumulation_steps) == 0: self._optimizer_world_model.zero_grad() - - # ==================== START: Target Entropy Regularization Update Logic ==================== - alpha_loss = None - current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight # Default to fixed value - if self.use_adaptive_entropy_weight: - # Dynamically calculate target entropy (this logic is correct and preserved) - progress = min(1.0, train_iter / self.target_entropy_decay_steps) - current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress - action_space_size = self._cfg.model.action_space_size - # Note: We define target_entropy as a positive number, which is more intuitive - current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio - - # Calculate alpha_loss (corrected sign) - # This is the core correction: removed the negative sign at the front - # detach() is still critical to ensure alpha_loss gradient only flows to log_alpha - alpha_loss = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean() - - # Update log_alpha - self.alpha_optimizer.zero_grad() - alpha_loss.backward() - self.alpha_optimizer.step() - # [Optimization suggestion] Add log_alpha clipping as a safety measure - with torch.no_grad(): - # Limit alpha to a range, e.g., [1e-4, 10.0] - self.log_alpha.clamp_(np.log(5e-2), np.log(10.0)) - - # Use current updated alpha (with gradient flow truncated) - current_alpha = self.log_alpha.exp().detach() - - # Recalculate weighted policy loss and total loss - # Note: policy_entropy here is already an average value of a batch - weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy - # Rebuild total loss (not using losses.loss_total) - # Ensure the weights here are consistent with the calculation in LossWithIntermediateLosses class - self.obs_loss_weight = 2 - self.value_loss_weight = 0.5 - self.reward_loss_weight = 1. - self.policy_loss_weight = 1. - self.ends_loss_weight = 0. - - self.latent_recon_loss_weight = self._cfg.model.world_model_cfg.latent_recon_loss_weight - self.perceptual_loss_weight = self._cfg.model.world_model_cfg.perceptual_loss_weight - - if self.latent_recon_loss_weight>0: - total_loss = ( - self.reward_loss_weight * reward_loss + - self.value_loss_weight * value_loss + - self.policy_loss_weight * weighted_policy_loss + - self.obs_loss_weight * obs_loss + - self.latent_recon_loss_weight * latent_recon_loss+ - self.perceptual_loss_weight*perceptual_loss - ) - else: - - total_loss = ( - self.reward_loss_weight * reward_loss + - self.value_loss_weight * value_loss + - self.policy_loss_weight * weighted_policy_loss + - self.obs_loss_weight * obs_loss - - ) - weighted_total_loss = (weights * total_loss).mean() - # ===================== END: Target Entropy Regularization Update Logic ===================== - # Scale the loss by the number of accumulation steps weighted_total_loss = weighted_total_loss / self.accumulation_steps weighted_total_loss.backward() - # Still executed within torch.no_grad() context - # ================================================================= - with torch.no_grad(): - # 1. 
Encoder-Clip - # ==================== START: Dynamically calculate current Clip threshold ==================== - current_clip_value = self.latent_norm_clip_threshold # Default to fixed value - if self.use_encoder_clip_annealing: - progress = min(1.0, train_iter / self.encoder_clip_anneal_steps) - - if self.encoder_clip_anneal_type == 'cosine': - # Cosine schedule: smoothly transition from 1 to 0 - cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress)) - current_clip_value = self.encoder_clip_end + \ - (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress - else: # Default to linear schedule - current_clip_value = self.encoder_clip_start * (1 - progress) + \ - self.encoder_clip_end * progress - # ===================== END: Dynamically calculate current Clip threshold ===================== - - # 1. Encoder-Clip (using dynamically calculated current_clip_value) - if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses: - obs_embeddings = losses.intermediate_losses['obs_embeddings'] - if obs_embeddings is not None: - max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max() - if max_latent_norm > current_clip_value: - scale_factor = current_clip_value / max_latent_norm.item() - # No longer print frequently, or can be changed to print every N steps - if train_iter % 1000 == 0: - logging.info(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. Scaling by {scale_factor:.4f}.") - scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor) - - if self.use_head_clip and self.head_clip_manager is not None: - head_clip_results = self.head_clip_manager.apply_head_clip( - self._learn_model.world_model, - losses, - train_iter - ) - - # Check if the current iteration completes an accumulation cycle if (train_iter + 1) % self.accumulation_steps == 0: - # ==================== [NEW] Monitor gradient norms ==================== - # Monitor gradient norms before gradient clipping to diagnose gradient explosion/vanishing issues - if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): - grad_norm_metrics = self._monitor_gradient_norms() - norm_log_dict.update(grad_norm_metrics) - # ================================================================= - # Analyze gradient norms if simulation normalization analysis is enabled if self._cfg.analysis_sim_norm: # Clear previous analysis results to prevent memory overflow del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() self._target_model.encoder_hook.clear_data() - + # Clip gradients to prevent exploding gradients total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( self._learn_model.world_model.parameters(), self._cfg.grad_clip_value @@ -1188,20 +565,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in max_memory_allocated_gb = 0. 
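`compute_loss_ppo` itself is implemented elsewhere in this patch; as a rough, self-contained reference for how a clipped-surrogate objective is typically assembled from the `clip_ratio` / `value_coef` / `entropy_coef` arguments passed above, see the sketch below. Shapes, masking, and the scalar value term are illustrative assumptions; the world model here uses a categorical value head, so its value term would be a cross-entropy against `returns_categorical` rather than MSE.

```python
import torch
import torch.nn.functional as F

def ppo_style_loss(log_prob, old_log_prob, advantages, value_pred, value_target,
                   clip_ratio: float = 0.2, value_coef: float = 0.5,
                   entropy_coef: float = 0.01, entropy: torch.Tensor = None) -> torch.Tensor:
    # Clipped surrogate policy objective (PPO).
    ratio = torch.exp(log_prob - old_log_prob)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()
    # Scalar value regression shown for brevity only.
    value_loss = F.mse_loss(value_pred, value_target)
    entropy_bonus = entropy.mean() if entropy is not None else torch.zeros(())
    return policy_loss + value_coef * value_loss - entropy_coef * entropy_bonus
```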
return_log_dict = { - 'analysis/first_step_loss_value': first_step_losses['loss_value'].item(), - 'analysis/first_step_loss_policy': first_step_losses['loss_policy'].item(), - 'analysis/first_step_loss_rewards': first_step_losses['loss_rewards'].item(), - 'analysis/first_step_loss_obs': first_step_losses['loss_obs'].item(), - - 'analysis/middle_step_loss_value': middle_step_losses['loss_value'].item(), - 'analysis/middle_step_loss_policy': middle_step_losses['loss_policy'].item(), - 'analysis/middle_step_loss_rewards': middle_step_losses['loss_rewards'].item(), - 'analysis/middle_step_loss_obs': middle_step_losses['loss_obs'].item(), - - 'analysis/last_step_loss_value': last_step_losses['loss_value'].item(), - 'analysis/last_step_loss_policy': last_step_losses['loss_policy'].item(), - 'analysis/last_step_loss_rewards': last_step_losses['loss_rewards'].item(), - 'analysis/last_step_loss_obs': last_step_losses['loss_obs'].item(), + # Step losses statistics removed + # 'analysis/first_step_loss_value': first_step_losses['loss_value'].item(), + # 'analysis/first_step_loss_policy': first_step_losses['loss_policy'].item(), + # 'analysis/first_step_loss_rewards': first_step_losses['loss_rewards'].item(), + # 'analysis/first_step_loss_obs': first_step_losses['loss_obs'].item(), + # 'analysis/middle_step_loss_value': middle_step_losses['loss_value'].item(), + # 'analysis/middle_step_loss_policy': middle_step_losses['loss_policy'].item(), + # 'analysis/middle_step_loss_rewards': middle_step_losses['loss_rewards'].item(), + # 'analysis/middle_step_loss_obs': middle_step_losses['loss_obs'].item(), + # 'analysis/last_step_loss_value': last_step_losses['loss_value'].item(), + # 'analysis/last_step_loss_policy': last_step_losses['loss_policy'].item(), + # 'analysis/last_step_loss_rewards': last_step_losses['loss_rewards'].item(), + # 'analysis/last_step_loss_obs': last_step_losses['loss_obs'].item(), 'Current_GPU': current_memory_allocated_gb, 'Max_GPU': max_memory_allocated_gb, @@ -1218,91 +594,21 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'target_policy_entropy': average_target_policy_entropy.item(), 'reward_loss': reward_loss.item(), 'value_loss': value_loss.item(), - # Add value_priority to the log dictionary. 
- 'value_priority': value_priority_np.mean().item(), - 'value_priority_orig': value_priority_np, + # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO 'target_reward': target_reward.mean().item(), 'target_value': target_value.mean().item(), 'transformed_target_reward': transformed_target_reward.mean().item(), 'transformed_target_value': transformed_target_value.mean().item(), 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), - 'analysis/dormant_ratio_encoder': dormant_ratio_encoder, - 'analysis/dormant_ratio_transformer': dormant_ratio_transformer, - 'analysis/dormant_ratio_head': dormant_ratio_head, - - 'analysis/avg_weight_mag_encoder': avg_weight_mag_encoder, - 'analysis/avg_weight_mag_transformer': avg_weight_mag_transformer, - 'analysis/avg_weight_mag_head': avg_weight_mag_head, - 'analysis/e_rank_last_linear': e_rank_last_linear, - 'analysis/e_rank_sim_norm': e_rank_sim_norm, - + 'analysis/dormant_ratio_encoder': dormant_ratio_encoder.item(), + 'analysis/dormant_ratio_world_model': dormant_ratio_world_model.item(), 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), - 'analysis/latent_action_l2_norms': latent_action_l2_norms, 'analysis/l2_norm_before': self.l2_norm_before, 'analysis/l2_norm_after': self.l2_norm_after, 'analysis/grad_norm_before': self.grad_norm_before, 'analysis/grad_norm_after': self.grad_norm_after, - - "temperature_value":temperature_value, - "temperature_reward":temperature_reward, - "temperature_policy":temperature_policy, - - "current_policy_label_eps":current_policy_label_eps, } - - if norm_log_dict: - return_log_dict.update(norm_log_dict) - - use_enhanced_policy_monitoring = self._cfg.use_enhanced_policy_monitoring - if use_enhanced_policy_monitoring: - # Monitor policy logits statistics - with torch.no_grad(): - logits_policy = losses.intermediate_losses.get('logits_policy') - if logits_policy is not None: - return_log_dict['policy_logits/norm'] = logits_policy.norm(dim=-1).mean().item() - return_log_dict['policy_logits/max'] = logits_policy.max().item() - return_log_dict['policy_logits/min'] = logits_policy.min().item() - return_log_dict['policy_logits/std'] = logits_policy.std().item() - - # [NEW] Also monitor Value and Reward logits - logits_value = losses.intermediate_losses.get('logits_value') - if logits_value is not None: - return_log_dict['value_logits/abs_max'] = logits_value.abs().max().item() - return_log_dict['value_logits/norm'] = logits_value.norm(dim=-1).mean().item() - - logits_reward = losses.intermediate_losses.get('logits_reward') - if logits_reward is not None: - return_log_dict['reward_logits/abs_max'] = logits_reward.abs().max().item() - return_log_dict['reward_logits/norm'] = logits_reward.norm(dim=-1).mean().item() - - # Monitor target_policy entropy statistics (minimum entropy indicates extreme distributions) - valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] - target_policy_entropies = -torch.sum( - valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1 - ) - return_log_dict['target_policy_entropy/mean'] = target_policy_entropies.mean().item() - return_log_dict['target_policy_entropy/min'] = target_policy_entropies.min().item() - return_log_dict['target_policy_entropy/max'] = target_policy_entropies.max().item() - return_log_dict['target_policy_entropy/std'] = target_policy_entropies.std().item() - # ================================================================================ - - if self.use_adaptive_entropy_weight: - 
return_log_dict['adaptive_alpha'] = current_alpha.item() - return_log_dict['adaptive_target_entropy_ratio'] = current_ratio - return_log_dict['alpha_loss'] = alpha_loss.item() - - if self.use_encoder_clip_annealing: - return_log_dict['current_encoder_clip_value'] = current_clip_value - - if self.use_head_clip and self.head_clip_manager is not None: - # Add head clip results to log (if any) - if head_clip_results: - for head_name, info in head_clip_results.items(): - return_log_dict[f'head_clip/{head_name}/max_logits'] = info['max_logits'] - return_log_dict[f'head_clip/{head_name}/threshold'] = info['threshold'] - if info['scaled']: - return_log_dict[f'head_clip/{head_name}/scale_factor'] = info['scale_factor'] - + if self._cfg.use_wandb: wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step) wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) @@ -1312,7 +618,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in def monitor_weights_and_grads(self, model): for name, param in model.named_parameters(): if param.requires_grad: - logging.info(f"Layer: {name} | " + print(f"Layer: {name} | " f"Weight mean: {param.data.mean():.4f} | " f"Weight std: {param.data.std():.4f} | " f"Grad mean: {param.grad.mean():.4f} | " @@ -1324,25 +630,24 @@ def _init_collect(self) -> None: Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. """ self._collect_model = self._model - # Create a configuration copy for collect MCTS and set specific simulation count - mcts_collect_cfg = copy.deepcopy(self._cfg) - mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations + if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(mcts_collect_cfg) + self._mcts_collect = MCTSCtree(self._cfg) else: - self._mcts_collect = MCTSPtree(mcts_collect_cfg) + self._mcts_collect = MCTSPtree(self._cfg) self._collect_mcts_temperature = 1. self._collect_epsilon = 0.0 self.collector_env_num = self._cfg.collector_env_num if self._cfg.model.model_type == 'conv': self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) - self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)] + self.last_batch_action = [-1 for i in range(self.collector_env_num)] elif self._cfg.model.model_type == 'mlp': self.last_batch_obs = torch.full( [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id, ).to(self._cfg.device) - self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)] + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + # @profile def _forward_collect( self, data: torch.Tensor, @@ -1350,9 +655,8 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id: np.array = None, - timestep: List = [0], - task_id: int = None, + ready_env_id: np.ndarray = None, + timestep: List = [0] ) -> Dict: """ Overview: @@ -1365,7 +669,6 @@ def _forward_collect( - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - timestep (:obj:`list`): The step index of the env in one episode. - - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. 
        Shape:
            - data (:obj:`torch.Tensor`):
                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
@@ -1390,7 +693,7 @@ def _forward_collect(
         output = {i: None for i in ready_env_id}

         with torch.no_grad():
-            network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action_collect, data, timestep)
+            network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep)
             latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)

             pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
@@ -1398,91 +701,138 @@ def _forward_collect(
             policy_logits = policy_logits.detach().cpu().numpy().tolist()

             legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
-            # the only difference between collect and eval is the dirichlet noise
-            noises = [
-                np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
-                                    ).astype(np.float32).tolist() for j in range(active_collect_env_num)
-            ]
-            if self._cfg.mcts_ctree:
-                # cpp mcts_tree
-                roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
+
+            if self._cfg.collect_with_pure_policy:
+                # Pure policy mode: use policy_logits directly and skip MCTS.
+                batch_action = []
+                for i, env_id in enumerate(ready_env_id):
+                    # 1. Convert policy_logits to a numpy array.
+                    logits = np.array(policy_logits[i])
+
+                    # 2. Apply the action_mask.
+                    masked_logits = logits.copy()
+                    masked_logits[action_mask[i] == 0] = -1e9
+
+                    # 3. Apply softmax with the collect temperature.
+                    exp_logits = np.exp((masked_logits - np.max(masked_logits)) / self._collect_mcts_temperature)
+                    probs = exp_logits / (np.sum(exp_logits) + 1e-8)
+
+                    # 4. Sample the action (or take the argmax, depending on the eps_greedy configuration).
+                    if self._cfg.eps.eps_greedy_exploration_in_collect:
+                        action = np.argmax(probs)
+                        if np.random.rand() < self._collect_epsilon:
+                            action = np.random.choice(legal_actions[i])
+                    else:
+                        # Sample from the masked policy distribution.
+                        action = np.random.choice(len(probs), p=probs)
+
+                    # 5. Compute the entropy of the distribution.
+                    visit_count_distribution_entropy = -np.sum(probs * np.log(probs + 1e-8))
+
+                    # 6. Assemble the return values.
+                    distributions = probs.tolist()
+                    value = pred_values[i]  # use predicted_value
+
+                    # 7. Handle predicted_next_text (it could be obtained via recurrent_inference if needed; set to None for now).
+                    # NOTE: if predicted_next_text is required, a recurrent_inference call can be added here.
+                    predicted_next = None
+
+                    output[env_id] = {
+                        'action': int(action),
+                        'visit_count_distributions': distributions,
+                        'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                        'searched_value': value,
+                        'predicted_value': pred_values[i],
+                        'predicted_policy_logits': policy_logits[i],
+                        'timestep': timestep[i],
+                        'predicted_next_text': predicted_next,
+                    }
+                    batch_action.append(int(action))
+
+                self.last_batch_obs = data
+                self.last_batch_action = batch_action
             else:
-                # python mcts_tree
-                roots = MCTSPtree.roots(active_collect_env_num, legal_actions)
-
-            roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
-
-            next_latent_state_with_env = self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep)
-
-            # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
-            roots_visit_count_distributions = roots.get_distributions()
-            roots_values = roots.get_values()  # shape: {list: batch_size}
-
-
-            batch_action = []
-            for i, env_id in enumerate(ready_env_id):
-                distributions, value = roots_visit_count_distributions[i], roots_values[i]
-
-                if self._cfg.eps.eps_greedy_exploration_in_collect:
-                    # eps greedy collect
-                    action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
-                        distributions, temperature=self._collect_mcts_temperature, deterministic=True
-                    )
-                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
-                    if np.random.rand() < self._collect_epsilon:
-                        action = np.random.choice(legal_actions[i])
-                else:
-                    # normal collect
-                    # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
-                    # the index within the legal action set, rather than the index in the entire action set.
-                    action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
-                        distributions, temperature=self._collect_mcts_temperature, deterministic=False
-                    )
-                    # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
- action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - - next_latent_state = next_latent_state_with_env[i][action] - - if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none': - # Output the plain text content decoded by the decoder from the next latent state - predicted_next = self._collect_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) + # 原有 MCTS 逻辑 + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) else: - predicted_next = None - - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], - 'timestep': timestep[i], - 'predicted_next_text': predicted_next, - } - batch_action.append(action) - - self.last_batch_obs = data - self.last_batch_action_collect = batch_action - - # This logic is a temporary workaround specific to the muzero_segment_collector. + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + + next_latent_state_with_env = self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self._collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + next_latent_state = next_latent_state_with_env[i][action] + + if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none': + # Output the plain text content decoded by the decoder from the next latent state + predicted_next = self._collect_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) + else: + predicted_next = None + + # ============== TODO: only for visualize ============== + # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self._collect_mcts_temperature, deterministic=True + # ) + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # ============== TODO: only for visualize ============== + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + 'timestep': timestep[i], + 'predicted_next_text': predicted_next, + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + # ========= TODO: for muzero_segment_collector now ========= if active_collect_env_num < self.collector_env_num: - # When an environment finishes an episode ('done'), the length of `self.last_batch_obs` passed back - # becomes smaller than the total number of collector environments. - # Handling this dynamic batch size is complex, as the transformer's KV cache retrieval - # requires a stable environment ID for correct indexing. A mismatch would cause retrieval errors. - # - # Therefore, as a simpler solution, we reset the collection state for ALL environments. - # By resetting `self.last_batch_action` to -1 for all `self.collector_env_num` environments, - # we force the transformer to start its context from scratch, avoiding incorrect cache lookups. - logging.info('========== collect_forward ============') - logging.info(f'An environment has finished. Active envs: {active_collect_env_num} < Total envs: {self.collector_env_num}. Resetting all.') - + print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') self._reset_collect(reset_init_data=True) - - # If the sampling type is 'episode', it's unexpected for the number of active environments to drop, - # as this suggests an inconsistent state or a potential issue in the collection logic. if getattr(self._cfg, 'sample_type', '') == 'episode': - logging.warning('Inconsistent state detected. `sample_type` is "episode", but the number of active environments has changed.') + print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') return output @@ -1492,29 +842,23 @@ def _init_eval(self) -> None: Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. 
""" self._eval_model = self._model - - # Create a configuration copy for eval MCTS and set specific simulation count - mcts_eval_cfg = copy.deepcopy(self._cfg) - mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations - if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(mcts_eval_cfg) + self._mcts_eval = MCTSCtree(self._cfg) else: - self._mcts_eval = MCTSPtree(mcts_eval_cfg) - + self._mcts_eval = MCTSPtree(self._cfg) self.evaluator_env_num = self._cfg.evaluator_env_num if self._cfg.model.model_type == 'conv': self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) - self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)] + self.last_batch_action = [-1 for i in range(self.collector_env_num)] elif self._cfg.model.model_type == 'mlp': self.last_batch_obs = torch.full( [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id, ).to(self._cfg.device) - self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)] + self.last_batch_action = [-1 for i in range(self.collector_env_num)] - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, - ready_env_id: np.array = None, timestep: List = [0], task_id: int = None,) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1], + ready_env_id: np.array = None, timestep: List = [0]) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -1525,7 +869,6 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to eval. - timestep (:obj:`list`): The step index of the env in one episode. - - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of eval_env, C is the number of channels, \ @@ -1546,7 +889,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 ready_env_id = np.arange(active_eval_env_num) output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action_eval, data, timestep) + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) # if not in training, obtain the scalars of the value/reward @@ -1555,62 +898,111 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + + # 检查是否使用纯策略模式(复用 collect_with_pure_policy 配置,或使用单独的 eval_with_pure_policy) + use_pure_policy = getattr(self._cfg, 'eval_with_pure_policy', False) or getattr(self._cfg, 'collect_with_pure_policy', False) + + if use_pure_policy: + # 纯策略模式:直接使用 policy_logits,跳过 MCTS + batch_action = [] + for i, env_id in enumerate(ready_env_id): + # 1. 将 policy_logits 转换为 numpy array + logits = np.array(policy_logits[i]) + + # 2. 
应用 action_mask + masked_logits = logits.copy() + masked_logits[action_mask[i] == 0] = -1e9 + + # 3. 应用 softmax(评估模式使用 temperature=1,确定性选择) + exp_logits = np.exp(masked_logits - np.max(masked_logits)) + probs = exp_logits / (np.sum(exp_logits) + 1e-8) + + # 4. 选择动作(评估模式使用 argmax,确定性) + action = np.argmax(probs) + + # 5. 计算熵 + visit_count_distribution_entropy = -np.sum(probs * np.log(probs + 1e-8)) + + # 6. 设置返回值 + distributions = probs.tolist() + value = pred_values[i] # 使用 predicted_value + + # 7. 处理 predicted_next_text(如果需要,可以通过 recurrent_inference 获取,这里先设为 None) + predicted_next = None + + output[env_id] = { + 'action': int(action), + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + 'timestep': timestep[i], + 'predicted_next_text': predicted_next, + } + batch_action.append(int(action)) + + self.last_batch_obs = data + self.last_batch_action = batch_action else: - # python mcts_tree - roots = MCTSPtree.roots(active_eval_env_num, legal_actions) - roots.prepare_no_noise(reward_roots, policy_logits, to_play) - next_latent_state_with_env = self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep) - - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` - roots_visit_count_distributions = roots.get_distributions() - roots_values = roots.get_values() # shape: {list: batch_size} - - batch_action = [] - - for i, env_id in enumerate(ready_env_id): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than - # sampling during the evaluation phase. - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=1, deterministic=True - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the - # entire action set. 
- action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - - # Predict the next latent state based on the selected action and policy - next_latent_state = next_latent_state_with_env[i][action] - - if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none': - # Output the plain text content decoded by the decoder from the next latent state - predicted_next = self._eval_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) + # 原有 MCTS 逻辑 + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) else: - predicted_next = None + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + next_latent_state_with_env = self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # print("roots_visit_count_distributions:", distributions, "root_value:", value) + + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], - 'timestep': timestep[i], - 'predicted_next_text': predicted_next, - } - batch_action.append(action) - - self.last_batch_obs_eval = data - self.last_batch_action_eval = batch_action + # Predict the next latent state based on the selected action and policy + next_latent_state = next_latent_state_with_env[i][action] + + if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none': + # Output the plain text content decoded by the decoder from the next latent state + predicted_next = self._eval_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) + else: + predicted_next = None + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + 'timestep': timestep[i], + 'predicted_next_text': predicted_next, + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action return output - def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: + def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: """ Overview: This method resets the collection process for a specific environment. It clears caches and memory @@ -1631,52 +1023,31 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - - # We must handle both single int and list of ints for env_id. - if env_id is not None: - if isinstance(env_id, int): - env_ids_to_reset = [env_id] - else: # Assumes it's a list - env_ids_to_reset = env_id - - # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. - if current_steps is None: - world_model = self._collect_model.world_model - for eid in env_ids_to_reset: - # ==================== BUG FIX: Refactored Cache Clearing ==================== - # Clear the specific environment's initial inference cache. 
- if hasattr(world_model, 'use_new_cache_manager') and world_model.use_new_cache_manager: - # NEW SYSTEM: Use KVCacheManager to clear per-environment cache - if eid < world_model.env_num: - world_model.kv_cache_manager.init_pools[eid].clear() - logging.info(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end (NEW system).') - else: - # OLD SYSTEM: Use legacy cache dictionary - if eid < len(world_model.past_kv_cache_init_infer_envs): - world_model.past_kv_cache_init_infer_envs[eid].clear() - logging.info(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end (OLD system).') - # ============================================================================= + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 # Clear caches if the current steps are a multiple of the clear interval - if current_steps is not None and current_steps % clear_interval == 0: - logging.info(f'clear_interval: {clear_interval}') + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') # Clear various caches in the collect model's world model world_model = self._collect_model.world_model - # ==================== Phase 1.5: Use unified clear_caches() method ==================== - # This automatically handles both old and new cache systems - world_model.clear_caches() - # ====================================================================================== + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() # Free up GPU memory torch.cuda.empty_cache() - logging.info(f'eps_steps_lst[{env_id}]: {current_steps}, collector: collect_model clear()') + print('collector: collect_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') - def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: + def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: """ Overview: This method resets the evaluation process for a specific environment. It clears caches and memory @@ -1689,80 +1060,37 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
""" if reset_init_data: - if task_id is not None: - self.last_batch_obs_eval = initialize_pad_batch( - self._cfg.model.observation_shape_list[task_id], - self._cfg.evaluator_env_num, - self._cfg.device, - pad_token_id=self.pad_token_id - ) - logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) - - else: - self.last_batch_obs_eval = initialize_pad_batch( - self._cfg.model.observation_shape, - self._cfg.evaluator_env_num, - self._cfg.device, - pad_token_id=self.pad_token_id - ) - logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) - + self.last_batch_obs = initialize_pad_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device, + pad_token_id=self.pad_token_id + ) self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] - # This logic handles the crucial end-of-episode cache clearing for evaluation. - # The evaluator calls `_policy.reset([env_id])` when an episode is done. - if env_id is not None: - if isinstance(env_id, int): - env_ids_to_reset = [env_id] - else: # Assumes it's a list - env_ids_to_reset = env_id - - # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. - if current_steps is None: - world_model = self._eval_model.world_model - for eid in env_ids_to_reset: - # ==================== BUG FIX: Refactored Cache Clearing ==================== - # Clear the specific environment's initial inference cache. - if hasattr(world_model, 'use_new_cache_manager') and world_model.use_new_cache_manager: - # NEW SYSTEM: Use KVCacheManager to clear per-environment cache - if eid < world_model.env_num: - world_model.kv_cache_manager.init_pools[eid].clear() - logging.info(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end (NEW system).') - else: - # OLD SYSTEM: Use legacy cache dictionary - if eid < len(world_model.past_kv_cache_init_infer_envs): - world_model.past_kv_cache_init_infer_envs[eid].clear() - logging.info(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end (OLD system).') - # ============================================================================= - - # The recurrent cache is global. 
- # ==================== Phase 1.5: Use unified clear_caches() method ==================== - # This automatically handles both old and new cache systems - world_model.clear_caches() - # ====================================================================================== - - world_model.keys_values_wm_list.clear() - torch.cuda.empty_cache() - return + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # Clear caches if the current steps are a multiple of the clear interval - if current_steps is not None and current_steps % clear_interval == 0: - logging.info(f'clear_interval: {clear_interval}') + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') # Clear various caches in the eval model's world model world_model = self._eval_model.world_model - # ==================== Phase 1.5: Use unified clear_caches() method ==================== - # This automatically handles both old and new cache systems - world_model.clear_caches() - # ====================================================================================== + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() # Free up GPU memory torch.cuda.empty_cache() - logging.info('evaluator: eval_model clear()') - logging.info(f'eps_steps_lst[{env_id}]: {current_steps}') + print('evaluator: eval_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') def _monitor_vars_learn(self) -> List[str]: """ @@ -1770,158 +1098,57 @@ def _monitor_vars_learn(self) -> List[str]: Register the variables to be monitored in learn mode. The registered variables will be logged in tensorboard according to the return value ``_forward_learn``. 
""" - base_vars = [ - # ==================== Analysis Metrics ==================== + return [ 'analysis/dormant_ratio_encoder', - 'analysis/dormant_ratio_transformer', - 'analysis/dormant_ratio_head', - 'analysis/avg_weight_mag_encoder', - 'analysis/avg_weight_mag_transformer', - 'analysis/avg_weight_mag_head', - 'analysis/e_rank_last_linear', - 'analysis/e_rank_sim_norm', + 'analysis/dormant_ratio_world_model', 'analysis/latent_state_l2_norms', - 'analysis/latent_action_l2_norms', 'analysis/l2_norm_before', 'analysis/l2_norm_after', 'analysis/grad_norm_before', 'analysis/grad_norm_after', - # ==================== Step-wise Loss Analysis ==================== 'analysis/first_step_loss_value', 'analysis/first_step_loss_policy', 'analysis/first_step_loss_rewards', 'analysis/first_step_loss_obs', + 'analysis/middle_step_loss_value', 'analysis/middle_step_loss_policy', 'analysis/middle_step_loss_rewards', 'analysis/middle_step_loss_obs', + 'analysis/last_step_loss_value', 'analysis/last_step_loss_policy', 'analysis/last_step_loss_rewards', 'analysis/last_step_loss_obs', - # ==================== System Metrics ==================== 'Current_GPU', 'Max_GPU', 'collect_epsilon', 'collect_mcts_temperature', 'cur_lr_world_model', + 'cur_lr_tokenizer', - # ==================== Core Losses ==================== 'weighted_total_loss', 'obs_loss', 'policy_loss', 'orig_policy_loss', 'policy_entropy', 'latent_recon_loss', - 'perceptual_loss', 'target_policy_entropy', 'reward_loss', 'value_loss', + 'consistency_loss', 'value_priority', 'target_reward', 'target_value', - 'transformed_target_reward', - 'transformed_target_value', - - # ==================== Gradient Norms ==================== 'total_grad_norm_before_clip_wm', - - # ==================== Temperature Parameters ==================== - 'temperature_value', - 'temperature_reward', - 'temperature_policy', - - # ==================== Training Configuration ==================== - 'current_policy_label_eps', - 'adaptive_alpha', - 'adaptive_target_entropy_ratio', - 'alpha_loss', - 'current_encoder_clip_value', - ] - - # ==================== [NEW] Norm and Intermediate Tensor Monitoring Variables ==================== - norm_vars = [ - # Module total norms (parameter norms) - 'norm/encoder/_total_norm', - 'norm/transformer/_total_norm', - 'norm/head_value/_total_norm', - 'norm/head_reward/_total_norm', - 'norm/head_policy/_total_norm', - - # Module total norms (gradient norms) - 'grad/encoder/_total_norm', - 'grad/transformer/_total_norm', - 'grad/head_value/_total_norm', - 'grad/head_reward/_total_norm', - 'grad/head_policy/_total_norm', - - # Intermediate tensor x (Transformer output) statistics - 'norm/x_token/mean', - 'norm/x_token/std', - 'norm/x_token/max', - 'norm/x_token/min', - - # Detailed logits statistics (Value) - 'logits/value/mean', - 'logits/value/std', - 'logits/value/max', - 'logits/value/min', - 'logits/value/abs_max', - - # Detailed logits statistics (Policy) - 'logits/policy/mean', - 'logits/policy/std', - 'logits/policy/max', - 'logits/policy/min', - 'logits/policy/abs_max', - - # Detailed logits statistics (Reward) - 'logits/reward/mean', - 'logits/reward/std', - 'logits/reward/max', - 'logits/reward/min', - 'logits/reward/abs_max', - - # Embeddings statistics - 'embeddings/obs/norm_mean', - 'embeddings/obs/norm_std', - 'embeddings/obs/norm_max', - 'embeddings/obs/norm_min', - - ] - - head_clip_vars = [] - # Check if head_clip is enabled and manager exists - if getattr(self, 'use_head_clip', False) and getattr(self, 'head_clip_manager', 
None) is not None: - # Iterate through all enabled heads and generate corresponding monitoring keys - for head_name in self.head_clip_manager.enabled_heads: - head_clip_vars.append(f'head_clip/{head_name}/max_logits') - head_clip_vars.append(f'head_clip/{head_name}/threshold') - head_clip_vars.append(f'head_clip/{head_name}/scale_factor') - - - enhanced_policy_vars = [ - # Policy logits statistics - 'policy_logits/norm', - 'policy_logits/max', - 'policy_logits/min', - 'policy_logits/std', - # Target policy entropy statistics - 'target_policy_entropy/mean', - 'target_policy_entropy/min', - 'target_policy_entropy/max', - 'target_policy_entropy/std', - ] - - stability_vars = [ - 'stability/warning_count', # Number of warnings issued in current check + # tokenizer + 'commitment_loss', + 'reconstruction_loss', + 'perceptual_loss', ] - return base_vars + norm_vars+ head_clip_vars + enhanced_policy_vars + stability_vars - - def _state_dict_learn(self) -> Dict[str, Any]: """ Overview: @@ -1929,16 +1156,11 @@ def _state_dict_learn(self) -> Dict[str, Any]: Returns: - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. """ - state_dict = { + return { 'model': self._learn_model.state_dict(), 'target_model': self._target_model.state_dict(), 'optimizer_world_model': self._optimizer_world_model.state_dict(), } - # ==================== START: Save Alpha Optimizer State ==================== - if self.use_adaptive_entropy_weight: - state_dict['alpha_optimizer'] = self.alpha_optimizer.state_dict() - # ===================== END: Save Alpha Optimizer State ===================== - return state_dict def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ @@ -1949,6 +1171,7 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ self._learn_model.load_state_dict(state_dict['model']) self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) def recompute_pos_emb_diff_and_clear_cache(self) -> None: """ @@ -1960,4 +1183,4 @@ def recompute_pos_emb_diff_and_clear_cache(self) -> None: # If rotary_emb is False, nn.Embedding is used for absolute position encoding. 
model.world_model.precompute_pos_emb_diff_kv() model.world_model.clear_caches() - torch.cuda.empty_cache() + torch.cuda.empty_cache() \ No newline at end of file diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 1e0a65845..800121cee 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -1,6 +1,7 @@ +import os import time from collections import deque, namedtuple -from typing import Optional, Any, List, Dict, Set +from typing import Optional, Any, List import numpy as np import torch @@ -9,99 +10,98 @@ from ding.torch_utils import to_ndarray from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, \ allreduce_data +from ding.rl_utils import gae_data, gae from ding.worker.collector.base_serial_collector import ISerialCollector from torch.nn import L1Loss import torch.distributed as dist from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation +from lzero.policy.utils import compute_bleu @SERIAL_COLLECTOR_REGISTRY.register('episode_muzero') class MuZeroCollector(ISerialCollector): """ Overview: - The episode-based collector for MCTS-based reinforcement learning algorithms, - including MuZero, EfficientZero, Sampled EfficientZero, and Gumbel MuZero. - It orchestrates the data collection process in a serial manner, managing interactions - between the policy and the environment to generate game segments for training. + The Episode Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel MuZero. + It manages the data collection process for training these algorithms using a serial mechanism. Interfaces: - ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``collect``, - ``_compute_priorities``, ``pad_and_save_last_trajectory``, ``_output_log``, ``close``, ``__del__``. + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``, + ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close`` Properties: - ``envstep``. + ``envstep`` """ - # Default configuration for the collector. To be compatible with ISerialCollector. + # TO be compatible with ISerialCollector config = dict() def __init__( self, collect_print_freq: int = 100, - env: Optional[BaseEnvManager] = None, - policy: Optional[namedtuple] = None, + env: BaseEnvManager = None, + policy: namedtuple = None, tb_logger: 'SummaryWriter' = None, # noqa - exp_name: str = 'default_experiment', - instance_name: str = 'collector', + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'collector', policy_config: 'policy_config' = None, # noqa - task_id: Optional[int] = None, ) -> None: """ Overview: - Initializes the MuZeroCollector with the given configuration. + Initialize the MuZeroCollector with the given parameters. Arguments: - - collect_print_freq (:obj:`int`): The frequency (in training iterations) at which to print collection statistics. - - env (:obj:`Optional[BaseEnvManager]`): An instance of a vectorized environment manager. - - policy (:obj:`Optional[namedtuple]`): A namedtuple containing the policy's forward pass and other methods. - - tb_logger (:obj:`Optional[SummaryWriter]`): A TensorBoard logger instance for logging metrics. - - exp_name (:obj:`str`): The name of the experiment, used for organizing logs. - - instance_name (:obj:`str`): A unique name for this collector instance. 
- - policy_config (:obj:`'policy_config'`): The configuration object for the policy. - - task_id (:obj:`Optional[int]`): The identifier for the current task in a multi-task setting. If None, operates in single-task mode. + - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. + - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. + - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. + - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. + - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. + - instance_name (:obj:`str`): Unique identifier for this collector instance. + - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. """ - self.task_id = task_id self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq self._timer = EasyTimer() self._end_flag = False - # Get distributed training info self._rank = get_rank() self._world_size = get_world_size() - - # Logger setup: only rank 0 creates the main logger and TensorBoard logger. if self._rank == 0: if tb_logger is not None: self._logger, _ = build_logger( - path=f'./{self._exp_name}/log/{self._instance_name}', + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name ) else: self._logger, _ = build_logger( - path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False ) self._tb_logger = None self.policy_config = policy_config self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy + # PPO configuration (required) + self.ppo_gamma = policy_config.ppo.gamma + self.ppo_gae_lambda = policy_config.ppo.gae_lambda + self.reset(policy, env) def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Resets or replaces the environment managed by the collector. - If `_env` is None, it resets the existing environment. Otherwise, it replaces the old - environment with the new one and launches it. + Reset or replace the environment managed by this collector. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. Arguments: - - _env (:obj:`Optional[BaseEnvManager]`): The new environment to be used. If None, resets the current environment. + - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. """ if _env is not None: self._env = _env @@ -113,39 +113,42 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Resets or replaces the policy used by the collector. - If `_policy` is None, it resets the existing policy. Otherwise, it replaces the old - policy with the new one. + Reset or replace the policy used by this collector. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. 
Arguments: - - _policy (:obj:`Optional[namedtuple]`): The new policy to be used. + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy """ - assert hasattr(self, '_env'), "Please set env first before resetting policy." + assert hasattr(self, '_env'), "please set env first" if _policy is not None: self._policy = _policy self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None) self._logger.debug( - f"Set default n_episode mode(n_episode({self._default_n_episode}), env_num({self._env_num}))" + 'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num) ) self._policy.reset() def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Resets the collector, including the environment and policy. Also re-initializes - internal state variables for tracking collection progress. + Reset the collector with the given policy and/or environment. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. Arguments: - - _policy (:obj:`Optional[namedtuple]`): The new policy to use. - - _env (:obj:`Optional[BaseEnvManager]`): The new environment to use. + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ + env_manager(BaseEnvManager) """ if _env is not None: self.reset_env(_env) if _policy is not None: self.reset_policy(_policy) - # Initialize per-environment tracking info - self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} + self._env_info = {env_id: {'time': 0., 'step': 0, 'text_bleu': 0.} for env_id in range(self._env_num)} - # Reset overall statistics self._episode_info = [] self._total_envstep_count = 0 self._total_episode_count = 0 @@ -153,35 +156,39 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._last_train_iter = 0 self._end_flag = False - # A pool to store completed game segments, implemented using a deque. + # A game_segment_pool implementation based on the deque structure. self.game_segment_pool = deque(maxlen=int(1e6)) self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps + # Global episode_id counter for tracking segments belonging to the same episode + self._global_episode_id = 0 + def _reset_stat(self, env_id: int) -> None: """ Overview: - Resets the statistics for a specific environment, identified by `env_id`. - This is typically called when an episode in that environment ends. + Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \ + and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ + to get more messages. Arguments: - - env_id (:obj:`int`): The ID of the environment to reset statistics for. + - env_id (:obj:`int`): the id where we need to reset the collector's state """ - self._env_info[env_id] = {'time': 0., 'step': 0} + self._env_info[env_id] = {'time': 0., 'step': 0, 'text_bleu': 0.} @property def envstep(self) -> int: """ Overview: - Returns the total number of environment steps collected since the last reset. + Get the total number of environment steps collected. 
Returns: - - envstep (:obj:`int`): The total environment step count. + - envstep (:obj:`int`): Total number of environment steps collected. """ return self._total_envstep_count def close(self) -> None: """ Overview: - Closes the collector, including the environment and any loggers. - Ensures that all resources are properly released. + Close the collector. If end_flag is False, close the environment, flush the tb_logger \ + and close the tb_logger. """ if self._end_flag: return @@ -194,456 +201,665 @@ def close(self) -> None: def __del__(self) -> None: """ Overview: - Destructor for the collector instance, ensuring that `close` is called - to clean up resources. + Execute the close command and close the collector. __del__ is automatically called to \ + destroy the collector instance when the collector finishes its work """ self.close() # ============================================================== - # MCTS+RL Core Collection Logic + # MCTS+RL related core code # ============================================================== - def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> Optional[np.ndarray]: + def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray: """ Overview: - Computes priorities for experience replay based on the discrepancy between - predicted values and MCTS search values. + Compute the priorities for transitions based on prediction and search value discrepancies. Arguments: - - i (:obj:`int`): The index of the environment's data in the lists. - - pred_values_lst (:obj:`List[float]`): A list containing lists of predicted values for each environment. - - search_values_lst (:obj:`List[float]`): A list containing lists of search values from MCTS for each environment. + - i (:obj:`int`): Index of the values in the list to compute the priority for. + - pred_values_lst (:obj:`List[float]`): List of predicted values. + - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS. Returns: - - priorities (:obj:`Optional[np.ndarray]`): An array of priorities for the transitions. Returns None if priority is not used. + - priorities (:obj:`np.ndarray`): Array of computed priorities. """ if self.policy_config.use_priority: - # Calculate priorities as the L1 loss between predicted values and search values. - # 'reduction=none' ensures the loss is calculated for each element individually. + # Calculate priorities. The priorities are the L1 losses between the predicted + # values and the search values. We use 'none' as the reduction parameter, which + # means the loss is calculated for each element individually, instead of being summed or averaged. + # A small constant (1e-6) is added to the results to avoid zero priorities. This + # is done because zero priorities could potentially cause issues in some scenarios. pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) - search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device).float().view(-1) - - # A small epsilon is added to avoid zero priorities. - priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6 + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device + ).float().view(-1) + priorities = L1Loss(reduction='none' + )(pred_values, + search_values).detach().cpu().numpy() + 1e-6 else: - # If priority is not used, return None. 
The replay buffer will use max priority for new data. + # priorities is None -> use the max priority for all newly collected data priorities = None return priorities - def pad_and_save_last_trajectory( - self, i: int, last_game_segments: List[Optional[GameSegment]], - last_game_priorities: List[Optional[np.ndarray]], - game_segments: List[GameSegment], done: np.ndarray - ) -> None: + def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], + last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray) -> None: """ Overview: - Pads the end of the `last_game_segment` with data from the start of the current `game_segment`. - This is necessary to compute target values for the final transitions of a segment. After padding, - the completed segment is stored in the `game_segment_pool`. + Save the game segment to the pool if the current game is finished, padding it if necessary. Arguments: - - i (:obj:`int`): The index of the environment being processed. - - last_game_segments (:obj:`List[Optional[GameSegment]]`): List of game segments from the previous collection chunk. - - last_game_priorities (:obj:`List[Optional[np.ndarray]]`): List of priorities corresponding to the last game segments. - - game_segments (:obj:`List[GameSegment]`): List of game segments from the current collection chunk. - - done (:obj:`np.ndarray`): Array indicating if the episode has terminated for each environment. + - i (:obj:`int`): Index of the current game segment. + - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. + - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. + - game_segments (:obj:`List[GameSegment]`): List of the current game segments. + - done (:obj:`np.ndarray`): Array indicating whether each game is done. Note: - An implicit assumption is that the start of the new segment's observation history overlaps with the - end of the last segment's, e.g., `(last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all()` is True. + (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True """ - # --- Prepare padding data from the current game segment --- - # Observations for padding are taken from the start of the new segment. - beg_index_obs = self.policy_config.model.frame_stack_num - end_index_obs = beg_index_obs + self.policy_config.num_unroll_steps + self.policy_config.td_steps - pad_obs_lst = game_segments[i].obs_segment[beg_index_obs:end_index_obs] - - # Actions for padding. - beg_index_ac = 0 - end_index_ac = beg_index_ac + self.policy_config.num_unroll_steps + self.policy_config.td_steps - pad_action_lst = game_segments[i].action_segment[beg_index_ac:end_index_ac] - - # Child visits for padding. - pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] - - # Rewards for padding. - beg_index_rew = 0 - end_index_rew = beg_index_rew + self.unroll_plus_td_steps - 1 - pad_reward_lst = game_segments[i].reward_segment[beg_index_rew:end_index_rew] - - # Root values for padding. 
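# --- Editor's illustrative sketch (not part of the diff above). ---
# NumPy-only equivalent of the priority rule in _compute_priorities: the
# per-transition L1 distance between the predicted value and the MCTS searched
# value, plus 1e-6 so no priority is exactly zero. The arrays are made-up numbers.
import numpy as np

pred_values = np.array([0.10, -0.25, 0.40])      # value-head predictions
search_values = np.array([0.30, -0.20, 0.35])    # MCTS root values
priorities = np.abs(pred_values - search_values) + 1e-6
# -> approx [0.200001, 0.050001, 0.050001]; larger disagreement => higher replay priority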
- beg_index_val = 0 - end_index_val = beg_index_val + self.unroll_plus_td_steps - pad_root_values_lst = game_segments[i].root_value_segment[beg_index_val:end_index_val] + # pad over last segment trajectory + beg_index = self.policy_config.model.frame_stack_num + end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps + + # the start obs is init zero obs, so we take the + # [ : +] obs as the pad obs + # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] + + # NOTE: for unizero + beg_index = 0 + end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps + pad_action_lst = game_segments[i].action_segment[beg_index:end_index] + + # NOTE: for unizero + pad_child_visits_lst = game_segments[i].child_visit_segment[ + :self.policy_config.num_unroll_steps + self.policy_config.td_steps] + + # EfficientZero original repo bug: + # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps - 1 + + pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_lst = game_segments[i].chance_segment[beg_index_rew:end_index_rew] - + chance_lst = game_segments[i].chance_segment[beg_index:end_index] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps + + pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] + if self.policy_config.gumbel_algo: - pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index_val:end_index_val] + pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] - # --- Pad the last game segment and save it --- + # pad over and save if self.policy_config.gumbel_algo: - last_game_segments[i].pad_over( - pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, - pad_child_visits_lst, next_segment_improved_policy=pad_improved_policy_prob - ) + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob) else: if self.policy_config.use_ture_chance_label_in_chance_encoder: - last_game_segments[i].pad_over( - pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, - pad_child_visits_lst, next_chances=chance_lst - ) + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_chances=chance_lst) else: - last_game_segments[i].pad_over( - pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst - ) - - # Convert the segment's lists to NumPy arrays for efficient storage. 
+ last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst) + """ + Note: + game_segment element shape: + obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 + rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 + action: game_segment_length -> 20 + root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 + child_visits: game_segment_length + num_unroll_steps -> 20 +5 + to_play: game_segment_length -> 20 + action_mask: game_segment_length -> 20 + """ + last_game_segments[i].game_segment_to_array() - # Add the completed game segment and its associated data to the pool. + # put the game segment into the pool self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) - # Reset the placeholder for the last game segment. + # reset last game_segments last_game_segments[i] = None last_game_priorities[i] = None - def collect( - self, - n_episode: Optional[int] = None, - train_iter: int = 0, - policy_kwargs: Optional[Dict] = None, - collect_with_pure_policy: bool = False - ) -> List[Any]: + return None + + def collect(self, + n_episode: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[dict] = None, + collect_with_pure_policy: bool = True) -> List[Any]: """ Overview: - Collects `n_episode` episodes of data. It manages the entire lifecycle of an episode, - from getting actions from the policy, stepping the environment, storing transitions, - and saving completed game segments. + Collect `n_episode` episodes of data with policy_kwargs, trained for `train_iter` iterations. Arguments: - - n_episode (:obj:`Optional[int]`): The number of episodes to collect. If None, uses the default from the policy config. - - train_iter (:obj:`int`): The current training iteration, used for logging. - - policy_kwargs (:obj:`Optional[Dict]`): Additional keyword arguments to pass to the policy's forward method, like temperature for exploration. - - collect_with_pure_policy (:obj:`bool`): If True, collects data using a pure policy (e.g., greedy action) without MCTS. + - n_episode (:obj:`Optional[int]`): Number of episodes to collect. + - train_iter (:obj:`int`): Number of training iterations completed so far. + - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. + - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. Returns: - - return_data (:obj:`List[Any]`): A list containing the collected game segments and metadata. + - return_data (:obj:`List[Any]`): Collected data in the form of a list. """ - # TODO(author): Consider implementing `collect_with_pure_policy` as a separate, more streamlined collector for clarity and modularity. + # TODO: collect_with_pure_policy as a separate collector if n_episode is None: if self._default_n_episode is None: - raise RuntimeError("Please specify `n_episode` for collection.") + raise RuntimeError("Please specify collect n_episode") else: n_episode = self._default_n_episode - assert n_episode >= self._env_num, f"Please ensure n_episode ({n_episode}) >= env_num ({self._env_num})." 
- + assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) if policy_kwargs is None: policy_kwargs = {} - temperature = policy_kwargs.get('temperature', 1.0) - epsilon = policy_kwargs.get('epsilon', 0.0) + temperature = policy_kwargs['temperature'] + epsilon = policy_kwargs['epsilon'] - # --- Initializations --- collected_episode = 0 collected_step = 0 env_nums = self._env_num retry_waiting_time = 0.05 - # Wait for all environments to be ready and get initial observations. + # initializations init_obs = self._env.ready_obs while len(init_obs.keys()) != self._env_num: - self._logger.warning(f"Waiting for all environments to reset. Ready envs: {list(init_obs.keys())}") + # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) + ) init_obs = self._env.ready_obs - # Prepare initial state dictionaries from observations. action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} - timestep_dict = {i: to_ndarray(init_obs[i].get('timestep', -1)) for i in range(env_nums)} + + timestep_dict = {} + for i in range(env_nums): + if 'timestep' not in init_obs[i]: + if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: + print(f"Warning: 'timestep' key is missing in init_obs[{i}]. Assigning value -1. Please note that the unizero algorithm may require the 'timestep' key in init_obs.") + timestep_dict[i] = to_ndarray(init_obs[i].get('timestep', -1)) + if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} - # Initialize game segments and observation stacks for each environment. - game_segments = [GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) for _ in range(env_nums)] - observation_window_stack = [deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums)] + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + # stacked observation windows in reset stage for init game_segments + observation_window_stack = [[] for _ in range(env_nums)] for env_id in range(env_nums): - for _ in range(self.policy_config.model.frame_stack_num): - observation_window_stack[env_id].append(to_ndarray(init_obs[env_id]['observation'])) + observation_window_stack[env_id] = deque( + [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) game_segments[env_id].reset(observation_window_stack[env_id]) + + # Set initial episode_id for each game segment + game_segments[env_id].episode_id = self._global_episode_id + self._global_episode_id += 1 - # State tracking variables for the collection loop. 
dones = np.array([False for _ in range(env_nums)]) - last_game_segments: List[Optional[GameSegment]] = [None for _ in range(env_nums)] - last_game_priorities: List[Optional[np.ndarray]] = [None for _ in range(env_nums)] - - # Buffers for priority calculation. + last_game_segments = [None for _ in range(env_nums)] + last_game_priorities = [None for _ in range(env_nums)] + # for priorities in self-play search_values_lst = [[] for _ in range(env_nums)] pred_values_lst = [[] for _ in range(env_nums)] if self.policy_config.gumbel_algo: improved_policy_lst = [[] for _ in range(env_nums)] - # Logging variables. - eps_steps_lst = np.zeros(env_nums) - visit_entropies_lst = np.zeros(env_nums) + # some logs + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) + self_play_moves = 0. + self_play_episodes = 0. + self_play_moves_max = 0 + self_play_visit_entropy = [] + total_transitions = 0 - ready_env_id: Set[int] = set() + ready_env_id = set() remain_episode = n_episode if collect_with_pure_policy: - # Dummy visit counts for pure policy collection. - temp_visit_list = [0.0 for _ in range(self._env.action_space.n)] + temp_visit_list = [0.0 for i in range(self._env.action_space.n)] - # --- Main Collection Loop --- while True: with self._timer: - # Get observations from ready environments. + # Get current ready env obs. obs = self._env.ready_obs + new_available_env_id = set(obs.keys()).difference(ready_env_id) - ready_env_id.update(list(new_available_env_id)[:remain_episode]) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) + + # NOTE: If waiting for N environments to synchronize, it may result in some environments not being completed (done) by the time of return. + # However, the current muzero_collector does not properly maintain the global self.last_game_segments, leading to some data not being collected. + + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + + action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} + to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} + timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id} - # Prepare policy inputs. - stack_obs_list = [game_segments[env_id].get_obs() for env_id in ready_env_id] action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] timestep = [timestep_dict[env_id] for env_id in ready_env_id] - stack_obs_array = to_ndarray(stack_obs_list) - stack_obs_tensor = prepare_observation(stack_obs_array, self.policy_config.model.model_type) - stack_obs_tensor = torch.from_numpy(stack_obs_tensor).to(self.policy_config.device) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} + + stack_obs = to_ndarray(stack_obs) + # return stack_obs shape: [B, S*C, W, H] e.g. 
[8, 4*1, 96, 96] + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device) # ============================================================== - # Policy Forward Pass + # Key policy forward step # ============================================================== - policy_input = { - 'data': stack_obs_tensor, - 'action_mask': action_mask, - 'temperature': temperature, - 'to_play': to_play, - 'epsilon': epsilon, - 'ready_env_id': ready_env_id, - 'timestep': timestep - } - if self.task_id is not None: - policy_input['task_id'] = self.task_id + # print(f'ready_env_id:{ready_env_id}') + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) - policy_output = self._policy.forward(**policy_input) + pred_next_text_with_env_id = {k: v['predicted_next_text'] if 'predicted_next_text' in v else -1 for k, v in policy_output.items()} + + # Extract relevant policy outputs + actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} + value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} + pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} + timestep_dict_with_env_id = { + k: v['timestep'] if 'timestep' in v else -1 for k, v in policy_output.items() + } + # PPO: calculate log_prob from policy_logits and action + # Use predicted_policy_logits to compute log probability of the selected action + log_prob_dict_with_env_id = {} + # import pudb;pudb.set_trace() - # --- Unpack policy outputs --- - actions, value_dict, pred_value_dict = {}, {}, {} - distributions_dict, visit_entropy_dict = {}, {} + for k, v in policy_output.items(): + if 'predicted_policy_logits' in v: + # Compute log_prob from policy_logits: log(softmax(logits)[action]) + policy_logits = np.array(v['predicted_policy_logits']) + action = v['action'] + # Apply softmax to get probabilities (with numerical stability) + exp_logits = np.exp(policy_logits - np.max(policy_logits)) + probs = exp_logits / (np.sum(exp_logits) + 1e-8) + # Get log probability of the selected action + log_prob = np.log(probs[action] + 1e-8) + log_prob_dict_with_env_id[k] = float(log_prob) + else: + # Fallback: if no policy_logits available, set to 0.0 + log_prob_dict_with_env_id[k] = 0.0 + if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - if self.policy_config.gumbel_algo: - improved_policy_dict, completed_value_dict = {}, {} + root_sampled_actions_dict_with_env_id = { + k: v['root_sampled_actions'] for k, v in policy_output.items() + } + + if not collect_with_pure_policy: + distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in + policy_output.items()} + visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in + policy_output.items()} + + if self.policy_config.gumbel_algo: + improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in + policy_output.items()} + completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} + + # Initialize dictionaries to store results + actions = {} + value_dict = {} + pred_value_dict = {} + timestep_dict = {} + pred_next_text = {} + log_prob_dict = {} # PPO: log_prob dictionary + + if not collect_with_pure_policy: + distributions_dict = {} + visit_entropy_dict = {} + + if self.policy_config.sampled_algo: + root_sampled_actions_dict = {} + + if 
self.policy_config.gumbel_algo: + improved_policy_dict = {} + completed_value_dict = {} + # Populate the result dictionaries for env_id in ready_env_id: - output = policy_output[env_id] - actions[env_id] = output['action'] - value_dict[env_id] = output['searched_value'] - pred_value_dict[env_id] = output['predicted_value'] - + actions[env_id] = actions_with_env_id.pop(env_id) + value_dict[env_id] = value_dict_with_env_id.pop(env_id) + pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) + timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id) + pred_next_text[env_id] = pred_next_text_with_env_id.pop(env_id) + log_prob_dict[env_id] = log_prob_dict_with_env_id.pop(env_id) # PPO: populate log_prob + if not collect_with_pure_policy: - distributions_dict[env_id] = output['visit_count_distributions'] - visit_entropy_dict[env_id] = output['visit_count_distribution_entropy'] + distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) + if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = output['root_sampled_actions'] - if self.policy_config.gumbel_algo: - improved_policy_dict[env_id] = output['improved_policy_probs'] - completed_value_dict[env_id] = output['roots_completed_value'] + root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id) + + visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) + if self.policy_config.gumbel_algo: + improved_policy_dict[env_id] = improved_policy_dict_with_env_id.pop(env_id) + completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id) + # ============================================================== - # Environment Interaction + # Interact with the environment # ============================================================== timesteps = self._env.step(actions) - interaction_duration = self._timer.value / len(timesteps) if timesteps else 0 - + interaction_duration = self._timer.value / len(timesteps) + + groundtrut_next_text = {} for env_id, episode_timestep in timesteps.items(): with self._timer: - # Handle abnormal timesteps by resetting the environment and policy state. if episode_timestep.info.get('abnormal', False): + # If there is an abnormal episode_timestep, reset all the related variables(including this env). + # suppose there is no reset param, reset this env self._env.reset({env_id: None}) self._policy.reset([env_id]) self._reset_stat(env_id) - self._logger.info(f"Environment {env_id} returned an abnormal step, info: {episode_timestep.info}") + self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, episode_timestep.info)) continue - obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info + - # Store MCTS search statistics. + if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': + obs_input_ids = torch.tensor(obs['observation'], dtype=torch.long) # shape: [L] + obs_attn_mask = torch.tensor(obs['obs_attn_mask'][0], dtype=torch.long) + valid_input_ids = obs_input_ids[obs_attn_mask == 1].tolist() + + groundtrut_next_text[env_id] = self._env._envs[env_id].tokenizer.decode(valid_input_ids, skip_special_tokens=True) + text_bleu = compute_bleu(reference=groundtrut_next_text[env_id], prediction=pred_next_text[env_id]) + # Whether to output text comparisons with high BLEU scores to evaluate the effectiveness of decoding the next latent. 
+ if text_bleu > 0.85: + os.makedirs("./log", exist_ok=True) + with open("./log/bleu_match.txt", "a", encoding="utf-8") as f: + f.write(f"pred_text={pred_next_text[env_id]}\ngroundtruth_text={groundtrut_next_text[env_id]}\ntext_bleu={text_bleu:.4f}\n\n") + if collect_with_pure_policy: game_segments[env_id].store_search_stats(temp_visit_list, 0) else: if self.policy_config.sampled_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id]) + game_segments[env_id].store_search_stats( + distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] + ) elif self.policy_config.gumbel_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], improved_policy=improved_policy_dict[env_id]) + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], + improved_policy=improved_policy_dict[env_id]) else: game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + + # PPO: store log_prob for PPO training + game_segments[env_id].old_log_prob_segment.append(log_prob_dict[env_id]) - # Append the current transition to the game segment. - append_args = (actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], to_play_dict[env_id]) + # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} + # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` if self.policy_config.use_ture_chance_label_in_chance_encoder: - append_args += (chance_dict[env_id],) - append_args += (timestep_dict[env_id],) - game_segments[env_id].append(*append_args) + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id], timestep_dict[env_id], chance_dict[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id], timestep_dict[env_id] + ) - # Update state dictionaries for the next step. + # NOTE: the position of code snippet is very important. + # the obs['action_mask'] and obs['to_play'] are corresponding to the next action action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict[env_id] = to_ndarray(obs['chance']) - dones[env_id] = done if not self.policy_config.ignore_done else False - - # Update logging and priority data. 
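# --- Editor's illustrative sketch (not part of the diff above). ---
# Standalone version of the arithmetic used to fill old_log_prob_segment: a
# numerically stable softmax over the predicted policy logits, then the
# log-probability of the executed action. The input numbers are made up.
import numpy as np

def action_log_prob(policy_logits: np.ndarray, action: int) -> float:
    exp_logits = np.exp(policy_logits - np.max(policy_logits))   # subtract max for stability
    probs = exp_logits / (np.sum(exp_logits) + 1e-8)
    return float(np.log(probs[action] + 1e-8))

print(action_log_prob(np.array([2.0, 0.5, -1.0]), action=0))     # ≈ -0.24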
+ if self.policy_config.ignore_done: + dones[env_id] = False + else: + dones[env_id] = done + if not collect_with_pure_policy: visit_entropies_lst[env_id] += visit_entropy_dict[env_id] if self.policy_config.gumbel_algo: completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) - + eps_steps_lst[env_id] += 1 + if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: + # only for UniZero now + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + + total_transitions += 1 + if self.policy_config.use_priority: pred_values_lst[env_id].append(pred_value_dict[env_id]) search_values_lst[env_id].append(value_dict[env_id]) + if self.policy_config.gumbel_algo and not collect_with_pure_policy: + improved_policy_lst[env_id].append(improved_policy_dict[env_id]) - # Update the observation window with the new observation. + # append the newest obs observation_window_stack[env_id].append(to_ndarray(obs['observation'])) # ============================================================== - # Game Segment Saving Logic + # we will save a game segment if it is the end of the game or the next game segment is finished. # ============================================================== - # If a segment is full, pad and save the previous segment. + + # if game segment is full, we will save the last game segment if game_segments[env_id].is_full(): + # pad over last segment trajectory if last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory(env_id, last_game_segments, last_game_priorities, game_segments, dones) + # TODO(pu): return the one game segment + self.pad_and_save_last_trajectory( + env_id, last_game_segments, last_game_priorities, game_segments, dones + ) - # Calculate priorities for the now-completed `last_game_segment`. + # calculate priority priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - pred_values_lst[env_id], search_values_lst[env_id] = [], [] + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + if self.policy_config.gumbel_algo and not collect_with_pure_policy: + improved_policy_lst[env_id] = [] - # The current segment becomes the `last_game_segment`. + # the current game_segments become last_game_segment last_game_segments[env_id] = game_segments[env_id] last_game_priorities[env_id] = priorities - # Start a new game segment. 
- game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) + # create new GameSegment + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) game_segments[env_id].reset(observation_window_stack[env_id]) + + # Inherit episode_id from the previous segment (same episode) + game_segments[env_id].episode_id = last_game_segments[env_id].episode_id self._env_info[env_id]['step'] += 1 + if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': + self._env_info[env_id]['text_bleu'] += text_bleu + collected_step += 1 self._env_info[env_id]['time'] += self._timer.value + interaction_duration - - # --- Episode Termination Handling --- - if done: - collected_episode += 1 - reward = info['eval_episode_return'] - log_info = {'reward': reward, 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step']} + if episode_timestep.done: + reward = episode_timestep.info['eval_episode_return'] + info = { + 'reward': reward, + 'time': self._env_info[env_id]['time'], + 'step': self._env_info[env_id]['step'], + } + if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': + info.update({'text_bleu':self._env_info[env_id]['text_bleu'] / self._env_info[env_id]['step']}) + if not collect_with_pure_policy: - log_info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 + info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] if self.policy_config.gumbel_algo: - log_info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 - self._episode_info.append(log_info) + info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] + + collected_episode += 1 + self._episode_info.append(info) - # Pad and save the segment before the final one. + # ============================================================== + # if it is the end of the game, we will save the game segment + # ============================================================== + + # NOTE: put the penultimate game segment in one episode into the trajectory_pool + # pad over 2th last game_segment using the last game_segment if last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory(env_id, last_game_segments, last_game_priorities, game_segments, dones) - - # Process and save the final segment of the episode. + self.pad_and_save_last_trajectory( + env_id, last_game_segments, last_game_priorities, game_segments, dones + ) + + # store current segment trajectory priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + + # NOTE: put the last game segment in one episode into the trajectory_pool game_segments[env_id].game_segment_to_array() - if len(game_segments[env_id].reward_segment) > 0: + + # assert len(game_segments[env_id]) == len(priorities) + # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null + if len(game_segments[env_id].reward_segment) != 0: self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) - # Reset environment-specific states for a new episode. 
+ # print(game_segments[env_id].reward_segment) + # reset the finished env and init game_segments if n_episode > self._env_num: - # Re-initialize the state for this env_id. + # Get current ready env obs. init_obs = self._env.ready_obs - while env_id not in init_obs: - self._logger.warning(f"Waiting for env {env_id} to reset...") + retry_waiting_time = 0.001 + while len(init_obs.keys()) != self._env_num: + # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) time.sleep(retry_waiting_time) + self._logger.info( + '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 + ) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format( + retry_waiting_time, self._env._env_states + ) + ) init_obs = self._env.ready_obs - + + new_available_env_id = set(init_obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) - # Reset game segment and observation stack. - game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) - observation_window_stack[env_id].clear() - for _ in range(self.policy_config.model.frame_stack_num): - observation_window_stack[env_id].append(init_obs[env_id]['observation']) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) + + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + observation_window_stack[env_id] = deque( + [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) game_segments[env_id].reset(observation_window_stack[env_id]) last_game_segments[env_id] = None last_game_priorities[env_id] = None + + # New episode starts, assign new episode_id + game_segments[env_id].episode_id = self._global_episode_id + self._global_episode_id += 1 - # Reset tracking and logging variables. - pred_values_lst[env_id], search_values_lst[env_id] = [], [] - eps_steps_lst[env_id], visit_entropies_lst[env_id] = 0, 0 - if self.policy_config.gumbel_algo: - completed_value_lst[env_id] = 0 + # log + self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) + if not collect_with_pure_policy: + self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) + self_play_moves += eps_steps_lst[env_id] + self_play_episodes += 1 - # Reset policy and collector stats for the finished environment. 
- self._policy.reset([env_id]) + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + eps_steps_lst[env_id] = 0 + visit_entropies_lst[env_id] = 0 + + # Env reset is done by env_manager automatically + self._policy.reset([env_id]) # NOTE: reset the policy for the env_id. Default reset_init_data=True. self._reset_stat(env_id) ready_env_id.remove(env_id) - # --- Check for Collection Completion --- if collected_episode >= n_episode: - # Prepare data for returning. - return_data = [ - [item[0] for item in self.game_segment_pool], - [{ - 'priorities': item[1], - 'done': item[2], + # Batch compute GAE for all episodes in the pool + self._batch_compute_gae_for_pool() + + # [data, meta_data] + return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ + { + 'priorities': self.game_segment_pool[i][1], + 'done': self.game_segment_pool[i][2], 'unroll_plus_td_steps': self.unroll_plus_td_steps - } for item in self.game_segment_pool] + } for i in range(len(self.game_segment_pool)) ] self.game_segment_pool.clear() break - - # --- Finalize and Log --- + collected_duration = sum([d['time'] for d in self._episode_info]) - # NOTE: Only for usual DDP not for unizero_multitask pipeline. - # In DDP, aggregate statistics across all processes. - # if self._world_size > 1: - # collected_step = allreduce_data(collected_step, 'sum') - # collected_episode = allreduce_data(collected_episode, 'sum') - # collected_duration = allreduce_data(collected_duration, 'sum') + # reduce data when enables DDP + if self._world_size > 1: + # Before allreduce + self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, collected_episode={collected_episode}") + collected_step = allreduce_data(collected_step, 'sum') + collected_episode = allreduce_data(collected_episode, 'sum') + collected_duration = allreduce_data(collected_duration, 'sum') + # After allreduce + self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}") self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration + # log self._output_log(train_iter) return return_data def _output_log(self, train_iter: int) -> None: """ Overview: - Aggregates and logs collection statistics to the console, TensorBoard, and WandB. - This method is only executed by the rank 0 process in a distributed setup. + Log the collector's data and output the log information. Arguments: - - train_iter (:obj:`int`): The current training iteration number, used as the logging step. + - train_iter (:obj:`int`): Current training iteration number for logging context. 
""" if self._rank != 0: return - if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) envstep_count = sum([d['step'] for d in self._episode_info]) duration = sum([d['time'] for d in self._episode_info]) episode_reward = [d['reward'] for d in self._episode_info] - + if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': + episode_bleu = [d['text_bleu'] for d in self._episode_info] + + if not self.collect_with_pure_policy: + visit_entropy = [d['visit_entropy'] for d in self._episode_info] + else: + visit_entropy = [0.0] + if self.policy_config.gumbel_algo: + completed_value = [d['completed_value'] for d in self._episode_info] + self._total_duration += duration info = { 'episode_count': episode_count, 'envstep_count': envstep_count, 'avg_envstep_per_episode': envstep_count / episode_count, - 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, - 'avg_episode_per_sec': episode_count / duration if duration > 0 else 0, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_episode_per_sec': episode_count / duration, 'collect_time': duration, 'reward_mean': np.mean(episode_reward), 'reward_std': np.std(episode_reward), @@ -652,32 +868,187 @@ def _output_log(self, train_iter: int) -> None: 'total_envstep_count': self._total_envstep_count, 'total_episode_count': self._total_episode_count, 'total_duration': self._total_duration, + 'visit_entropy': np.mean(visit_entropy), } - - if not self.collect_with_pure_policy: - visit_entropy = [d['visit_entropy'] for d in self._episode_info] - info['visit_entropy_mean'] = np.mean(visit_entropy) + if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': + info.update({'text_avg_bleu':np.mean(episode_bleu)}) if self.policy_config.gumbel_algo: - completed_value = [d['completed_value'] for d in self._episode_info] - info['completed_value_mean'] = np.mean(completed_value) - + info['completed_value'] = np.mean(completed_value) self._episode_info.clear() + self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) - # Log to console - self._logger.info("Collector Training Summary:\n{}".format('\n'.join([f' {k}: {v}' for k, v in info.items()]))) - - # Log to TensorBoard and WandB for k, v in info.items(): - if self.task_id is None: - tb_prefix_iter = f'{self._instance_name}_iter/' - tb_prefix_step = f'{self._instance_name}_step/' + if k in ['each_reward']: + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if k in ['total_envstep_count']: + continue + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + + if self.policy_config.use_wandb: + wandb.log({'{}_step/'.format(self._instance_name) + k: v for k, v in info.items()}, step=self._total_envstep_count) + + def _batch_compute_gae_for_pool_bak(self) -> None: + """ + Overview: + Batch compute GAE (Generalized Advantage Estimation) for all segments in game_segment_pool + at the end of collect. Process by grouping segments by episode_id. + Original implementation using manual GAE computation. + """ + if len(self.game_segment_pool) == 0: + return + + gamma = self.ppo_gamma + gae_lambda = self.ppo_gae_lambda + + # 1. 
Group all segments by episode_id + episode_groups = {} # {episode_id: [(pool_idx, segment, priorities, done), ...]} + + for pool_idx in range(len(self.game_segment_pool)): + segment, priorities, done_flag = self.game_segment_pool[pool_idx] + episode_id = segment.episode_id + + if episode_id not in episode_groups: + episode_groups[episode_id] = [] + episode_groups[episode_id].append((pool_idx, segment, priorities, done_flag)) + + # 2. Compute GAE for each episode + for episode_id, segments_info in episode_groups.items(): + # Sort by pool_idx to ensure temporal order + segments_info.sort(key=lambda x: x[0]) + + # Extract values and rewards for the entire episode + all_values = [] + all_rewards = [] + segment_lengths = [] + + for pool_idx, segment, _, _ in segments_info: + seg_len = len(segment.action_segment) + segment_lengths.append(seg_len) + + # Extract values and rewards from this segment + values = segment.root_value_segment[:seg_len] + rewards = segment.reward_segment[:seg_len] + + all_values.extend(values) + all_rewards.extend(rewards) + + # Convert to numpy arrays + all_values = np.array(all_values, dtype=np.float32) + all_rewards = np.array(all_rewards, dtype=np.float32) + + # Compute GAE from back to front + advantages = np.zeros_like(all_rewards, dtype=np.float32) + returns = np.zeros_like(all_rewards, dtype=np.float32) # PPO: compute return simultaneously + gae = 0.0 + + for t in reversed(range(len(all_rewards))): + # Get next value + if t == len(all_rewards) - 1: + next_value = 0.0 # Episode end else: - tb_prefix_iter = f'{self._instance_name}_iter_task{self.task_id}/' - tb_prefix_step = f'{self._instance_name}_step_task{self.task_id}/' + next_value = all_values[t + 1] + + # TD error: δ_t = r_t + γ*V(s_{t+1}) - V(s_t) + delta = all_rewards[t] + gamma * next_value - all_values[t] - self._tb_logger.add_scalar(tb_prefix_iter + k, v, train_iter) - self._tb_logger.add_scalar(tb_prefix_step + k, v, self._total_envstep_count) + # GAE: A_t = δ_t + γ*λ*A_{t+1} + gae = delta + gamma * gae_lambda * gae + advantages[t] = gae + + # PPO: Return = Advantage + Value + returns[t] = gae + all_values[t] - if self.policy_config.use_wandb: - wandb_log_data = {tb_prefix_step + k: v for k, v in info.items()} - wandb.log(wandb_log_data, step=self._total_envstep_count) + # 3. Distribute advantages and returns back to segments + offset = 0 + for i, (pool_idx, segment, priorities, done_flag) in enumerate(segments_info): + seg_len = segment_lengths[i] + + # Assign advantages and returns + segment.advantage_segment = advantages[offset:offset + seg_len].copy() + segment.return_segment = returns[offset:offset + seg_len].copy() # PPO: assign returns + offset += seg_len + + # Update segment in pool + self.game_segment_pool[pool_idx] = (segment, priorities, done_flag) + + self._logger.info(f"Batch computed GAE for {len(episode_groups)} episodes in game_segment_pool") + + def _batch_compute_gae_for_pool(self) -> None: + """ + Overview: + Batch compute GAE (Generalized Advantage Estimation) for all segments in game_segment_pool + at the end of collect. Process by grouping segments by episode_id. + Uses ding library's GAE functions for computation. + """ + if len(self.game_segment_pool) == 0: + return + + gamma = self.ppo_gamma + gae_lambda = self.ppo_gae_lambda + + # 1. 
Group all segments by episode_id + episode_groups = {} # {episode_id: [(pool_idx, segment, priorities, done), ...]} + + for pool_idx in range(len(self.game_segment_pool)): + segment, priorities, done_flag = self.game_segment_pool[pool_idx] + episode_id = segment.episode_id + + if episode_id not in episode_groups: + episode_groups[episode_id] = [] + episode_groups[episode_id].append((pool_idx, segment, priorities, done_flag)) + + # 2. Compute GAE for each episode using ding library + for episode_id, segments_info in episode_groups.items(): + # Sort by pool_idx to ensure temporal order + segments_info.sort(key=lambda x: x[0]) + + # Extract values and rewards for the entire episode + all_values = [] + all_rewards = [] + segment_lengths = [] + + for pool_idx, segment, _, _ in segments_info: + seg_len = len(segment.action_segment) + segment_lengths.append(seg_len) + + # Extract values and rewards from this segment + values = segment.root_value_segment[:seg_len] + rewards = segment.reward_segment[:seg_len] + + all_values.extend(values) + all_rewards.extend(rewards) + + # Convert to torch tensors for ding library + value = torch.tensor(all_values, dtype=torch.float32) + # Create next_value: [v1, v2, ..., v_{T-1}, 0.0] for last step + next_value = torch.cat([value[1:], torch.tensor([0.0], dtype=torch.float32)]) + reward = torch.tensor(all_rewards, dtype=torch.float32) + # Create done flags: False for all steps except the last one in the episode + done = torch.zeros(len(all_rewards), dtype=torch.bool) + if len(segments_info) > 0: + # Mark the last step of the episode as done + done[-1] = True + + # Use ding library's GAE functions + compute_adv_data = gae_data(value, next_value, reward, done, None) + advantages = gae(compute_adv_data, gamma, gae_lambda) + + # Convert back to numpy and compute returns + advantages_np = advantages.cpu().numpy().astype(np.float32) + returns_np = advantages_np + np.array(all_values, dtype=np.float32) + + # 3. 
Distribute advantages and returns back to segments + offset = 0 + for i, (pool_idx, segment, priorities, done_flag) in enumerate(segments_info): + seg_len = segment_lengths[i] + + # Assign advantages and returns + segment.advantage_segment = advantages_np[offset:offset + seg_len].copy() + segment.return_segment = returns_np[offset:offset + seg_len].copy() + offset += seg_len + + # Update segment in pool + self.game_segment_pool[pool_idx] = (segment, priorities, done_flag) + + self._logger.info(f"Batch computed GAE for {len(episode_groups)} episodes in game_segment_pool using ding library") diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 5fc680d97..331275692 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -1,122 +1,115 @@ import copy -import threading import time from collections import namedtuple -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Optional, Callable, Tuple, Dict, Any import numpy as np import torch import wandb from ding.envs import BaseEnvManager -from ding.torch_utils import to_item, to_ndarray, to_tensor -from ding.utils import (EasyTimer, broadcast_object_list, build_logger, - get_rank, get_world_size) -from ding.worker.collector.base_serial_evaluator import (ISerialEvaluator, - VectorEvalMonitor) -from ditk import logging +from ding.torch_utils import to_ndarray, to_item, to_tensor +from ding.utils import build_logger, EasyTimer +from ding.utils import get_world_size, get_rank, broadcast_object_list +from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor from easydict import EasyDict + from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation +from lzero.policy.utils import mz_network_output_unpack class MuZeroEvaluator(ISerialEvaluator): """ Overview: - The Evaluator for MCTS-based reinforcement learning algorithms, such as MuZero, EfficientZero, and Sampled EfficientZero. + The Evaluator class for MCTS+RL algorithms, such as MuZero, EfficientZero, and Sampled EfficientZero. Interfaces: __init__, reset, reset_policy, reset_env, close, should_eval, eval Properties: env, policy """ - # Default configuration for the MuZeroEvaluator. - config = dict( - # The frequency of evaluation, measured in training iterations. - eval_freq=5000, - ) - @classmethod def default_config(cls: type) -> EasyDict: """ Overview: - Get the default configuration of the MuZeroEvaluator. + Retrieve the default configuration for the evaluator by merging evaluator-specific defaults with other + defaults and any user-provided configuration. Returns: - - cfg (:obj:`EasyDict`): An EasyDict object representing the default configuration. + - cfg (:obj:`EasyDict`): The default configuration for the evaluator. """ cfg = EasyDict(copy.deepcopy(cls.config)) cfg.cfg_type = cls.__name__ + 'Dict' return cfg + config = dict( + # Evaluate every "eval_freq" training iterations. 
+ eval_freq=50, + ) + def __init__( self, eval_freq: int = 1000, n_evaluator_episode: int = 3, - stop_value: float = 1e6, - env: Optional[BaseEnvManager] = None, - policy: Optional[namedtuple] = None, - tb_logger: Optional['SummaryWriter'] = None, - exp_name: str = 'default_experiment', - instance_name: str = 'evaluator', - policy_config: Optional[EasyDict] = None, - task_id: Optional[int] = None, + stop_value: int = 1e6, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'evaluator', + policy_config: 'policy_config' = None, # noqa ) -> None: """ Overview: - Initializes the MuZeroEvaluator. This evaluator is compatible with MuZero, Sampled MuZero, Gumbel MuZero, EfficientZero, UniZero, and Sampled UniZero (i.e., all algorithms except AlphaZero). + Initialize the evaluator with configuration settings for various components such as logger helper and timer. Arguments: - - eval_freq (:obj:`int`): The frequency, in training iterations, at which to run evaluation. - - n_evaluator_episode (:obj:`int`): The total number of episodes to run during each evaluation. - - stop_value (:obj:`float`): The reward threshold at which training is considered converged and will stop. - - env (:obj:`Optional[BaseEnvManager]`): An optional environment manager for evaluation. - - policy (:obj:`Optional[namedtuple]`): An optional policy for evaluation. - - tb_logger (:obj:`Optional['SummaryWriter']`): An optional TensorBoard logger. - - exp_name (:obj:`str`): The name of the experiment, used for logging. - - instance_name (:obj:`str`): The name of this evaluator instance. - - policy_config (:obj:`Optional[EasyDict]`): Configuration for the policy. - - task_id (:obj:`Optional[int]`): The unique identifier for the task. If None, the evaluator operates in single-task mode. In a multi-task setting, each task corresponds to a specific evaluator instance. + - eval_freq (:obj:`int`): Evaluation frequency in terms of training steps. + - n_evaluator_episode (:obj:`int`): Number of episodes to evaluate in total. + - stop_value (:obj:`float`): A reward threshold above which the training is considered converged. + - env (:obj:`Optional[BaseEnvManager]`): An optional instance of a subclass of BaseEnvManager. + - policy (:obj:`Optional[namedtuple]`): An optional API namedtuple defining the policy for evaluation. + - tb_logger (:obj:`Optional[SummaryWriter]`): Optional TensorBoard logger instance. + - exp_name (:obj:`str`): Name of the experiment, used to determine output directory. + - instance_name (:obj:`str`): Name of this evaluator instance. + - policy_config (:obj:`Optional[dict]`): Optional configuration for the game policy. """ - self.stop_event = threading.Event() # Event to signal a stop, e.g., due to a timeout. - self.task_id = task_id self._eval_freq = eval_freq self._exp_name = exp_name self._instance_name = instance_name - self._rank = get_rank() - # Initialize logger. Only rank 0 needs a full logger with TensorBoard. - if self._rank == 0: + # Logger (Monitor will be initialized in policy setter) + # Only rank == 0 learner needs monitor and tb_logger, others only need text_logger to display terminal output. 
+ if get_rank() == 0: if tb_logger is not None: self._logger, _ = build_logger( - f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False + './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - f'./{self._exp_name}/log/{self._instance_name}', self._instance_name + './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name ) else: - if tb_logger is not None: - self._logger, _ = build_logger( - f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False - ) - self._tb_logger = tb_logger + self._logger, self._tb_logger = None, None # for close elegantly - logging.info(f'rank {self._rank}, self.task_id: {self.task_id}') self.reset(policy, env) + self._timer = EasyTimer() self._default_n_episode = n_evaluator_episode self._stop_value = stop_value # ============================================================== - # MCTS+RL related core properties + # MCTS+RL related core code # ============================================================== self.policy_config = policy_config def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the environment. If a new environment is provided, it replaces the old one. + Reset the environment for the evaluator, optionally replacing it with a new environment. + If _env is None, reset the old environment. If _env is not None, replace the old environment + in the evaluator with the new passed in environment and launch. Arguments: - - _env (:obj:`Optional[BaseEnvManager]`): New environment manager to use. If None, resets the existing environment. + - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one. """ if _env is not None: self._env = _env @@ -128,22 +121,29 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset the policy. If a new policy is provided, it replaces the old one. + Reset the policy for the evaluator, optionally replacing it with a new policy. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the evaluator with the new passed in policy. Arguments: - - _policy (:obj:`Optional[namedtuple]`): New policy to use. If None, resets the existing policy. + - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one. """ - assert hasattr(self, '_env'), "Please set environment first." + assert hasattr(self, '_env'), "please set env first" if _policy is not None: self._policy = _policy - self._policy.reset(task_id=self.task_id) + self._policy.reset() def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset both the policy and the environment. + Reset both the policy and environment for the evaluator, optionally replacing them. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the evaluator with the new passed in \ + environment and launch. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the evaluator with the new passed in policy. Arguments: - - _policy (:obj:`Optional[namedtuple]`): New policy to use. - - _env (:obj:`Optional[BaseEnvManager]`): New environment manager to use. 
+ - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one. + - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one. """ if _env is not None: self.reset_env(_env) @@ -152,36 +152,37 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._max_episode_return = float("-inf") self._last_eval_iter = 0 self._end_flag = False + def close(self) -> None: """ Overview: - Close the evaluator, including the environment and the TensorBoard logger. + Close the evaluator, the environment, flush and close the TensorBoard logger if applicable. """ if self._end_flag: return self._end_flag = True - if hasattr(self, '_env'): - self._env.close() + self._env.close() if self._tb_logger: self._tb_logger.flush() self._tb_logger.close() - def __del__(self) -> None: + def __del__(self): """ Overview: - Destructor that ensures `close` is called to clean up resources. + Execute the close command and close the evaluator. __del__ is automatically called \ + to destroy the evaluator instance when the evaluator finishes its work """ self.close() def should_eval(self, train_iter: int) -> bool: """ Overview: - Determine whether it's time to run an evaluation based on the training iteration. + Determine whether to initiate evaluation based on the training iteration count and evaluation frequency. Arguments: - - train_iter (:obj:`int`): The current training iteration. + - train_iter (:obj:`int`): The current count of training iterations. Returns: - - (:obj:`bool`): True if evaluation should be run, otherwise False. + - (:obj:`bool`): `True` if evaluation should be initiated, otherwise `False`. """ if train_iter == self._last_eval_iter: return False @@ -192,64 +193,54 @@ def should_eval(self, train_iter: int) -> bool: def eval( self, - save_ckpt_fn: Optional[Callable] = None, + save_ckpt_fn: Callable = None, train_iter: int = -1, envstep: int = -1, n_episode: Optional[int] = None, return_trajectory: bool = False, - ) -> Tuple[bool, Dict[str, Any]]: + ) -> Tuple[bool, float]: """ Overview: - Run a full evaluation process. It will evaluate the current policy, log the results, - and save a checkpoint if a new best performance is achieved. + Evaluate the current policy, storing the best policy if it achieves the highest historical reward. Arguments: - - save_ckpt_fn (:obj:`Optional[Callable]`): A function to save a checkpoint. Called when a new best reward is achieved. - - train_iter (:obj:`int`): The current training iteration. - - envstep (:obj:`int`): The current total environment steps. - - n_episode (:obj:`Optional[int]`): The number of episodes to evaluate. Defaults to the value set in `__init__`. - - return_trajectory (:obj:`bool`): Whether to return the collected `game_segments` in the result dictionary. + - save_ckpt_fn (:obj:`Optional[Callable]`): Optional function to save a checkpoint when a new best reward is achieved. + - train_iter (:obj:`int`): The current training iteration count. + - envstep (:obj:`int`): The current environment step count. + - n_episode (:obj:`Optional[int]`): Optional number of evaluation episodes; defaults to the evaluator's setting. + - return_trajectory (:obj:`bool`): Return the evaluated trajectory `game_segments` in `episode_info` if True. Returns: - - stop_flag (:obj:`bool`): A flag indicating whether the training should stop (e.g., if the stop value is reached). 
- - episode_info (:obj:`Dict[str, Any]`): A dictionary containing evaluation results, such as rewards and episode lengths. + - stop_flag (:obj:`bool`): Indicates whether the training can be stopped based on the stop value. + - episode_info (:obj:`Dict[str, Any]`): A dictionary containing information about the evaluation episodes. """ - if torch.cuda.is_available() and self.task_id is not None: - # NOTE: important for unizero_multitask pipeline. - self._logger.info(f"=========in eval() Rank {get_rank()} ===========") - device = torch.cuda.current_device() - self._logger.info(f"before set device: {device}") - torch.cuda.set_device(get_rank()) - self._logger.info(f"after set device: {get_rank()}") - + # the evaluator only works on rank0 episode_info = None stop_flag = False - if self.task_id is not None and get_rank() >= 0: - # In a multi-task setting, each task corresponds to a specific evaluator instance. - eval_flag = True - elif self.task_id is None and get_rank() == 0: - # In a single-task setting, only evaluate rank 0. - eval_flag = True - else: - eval_flag = False - - if eval_flag: + if get_rank() == 0: if n_episode is None: n_episode = self._default_n_episode - assert n_episode is not None, "Please specify the number of evaluation episodes (n_episode)." + assert n_episode is not None, "please indicate eval n_episode" envstep_count = 0 eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode) env_nums = self._env.env_num self._env.reset() - self._policy.reset(task_id=self.task_id) + self._policy.reset() - # Initializations + # initializations init_obs = self._env.ready_obs - # Wait for all environments to be ready, especially in subprocess-based environment managers. retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: - self._logger.info(f"Waiting for all environments to reset. Current ready envs: {list(init_obs.keys())}") + # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, + self._env._env_states) + ) init_obs = self._env.ready_obs action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} @@ -258,17 +249,20 @@ def eval( timestep_dict = {} for i in range(env_nums): if 'timestep' not in init_obs[i]: - self._logger.warning(f"'timestep' key is missing in init_obs[{i}], assigning value -1") + if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: + print(f"Warning: 'timestep' key is missing in init_obs[{i}]. Assigning value -1. 
Please note that the unizero algorithm may require the 'timestep' key in init_obs.") timestep_dict[i] = to_ndarray(init_obs[i].get('timestep', -1)) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} + dones = np.array([False for _ in range(env_nums)]) game_segments = [ GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config, - task_id=self.task_id + config=self.policy_config ) for _ in range(env_nums) ] for i in range(env_nums): @@ -279,91 +273,111 @@ def eval( ready_env_id = set() remain_episode = n_episode eps_steps_lst = np.zeros(env_nums) + with self._timer: while not eval_monitor.is_finished(): - # Check if a timeout has occurred. - if self.stop_event.is_set(): - # self.stop_event may be set in safe_eval() methd in lzero/entry/utils.py - self._logger.info("[EVALUATOR]: Evaluation aborted due to timeout.") - break - - # Get observations from ready environments. + # Get current ready env obs. obs = self._env.ready_obs new_available_env_id = set(obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - # Prepare stacked observations and other inputs for the policy. + # In a parallel evaluation setting, it's possible for all active environments to finish their + # episodes simultaneously. This can leave `ready_env_id` temporarily empty while the environments + # are being reset by the manager. + # To prevent processing an empty batch, which would cause an IndexError or other errors downstream, + # we check if `ready_env_id` is empty. If so, we sleep briefly to prevent a busy-wait, + # and `continue` to the next loop iteration to wait for newly reset environments to become available. + if not ready_env_id: + time.sleep(0.01) + continue + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) + + action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} + to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} + timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id} action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] timestep = [timestep_dict[env_id] for env_id in ready_env_id] + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} + stack_obs = to_ndarray(stack_obs) stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== - # Policy Forward Pass + # policy forward (without MCTS - use predicted_policy_logits directly) # ============================================================== - if self.task_id is None: - # Single-task setting - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep) - else: - # Multi-task setting - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep, task_id=self.task_id) - - # Unpack policy outputs. 
- actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} - if self.policy_config.sampled_algo: - root_sampled_actions_dict_with_env_id = {k: v['root_sampled_actions'] for k, v in policy_output.items()} - value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} - pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - timestep_dict_with_env_id = {k: v.get('timestep', -1) for k, v in policy_output.items()} - visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in policy_output.items()} - - # Remap outputs from policy's internal IDs to environment IDs. - actions, distributions_dict, value_dict, pred_value_dict, timestep_dict, visit_entropy_dict = {}, {}, {}, {}, {}, {} - if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - - for index, env_id in enumerate(ready_env_id): - actions[env_id] = actions_with_env_id.pop(env_id) - distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) - if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id) - value_dict[env_id] = value_dict_with_env_id.pop(env_id) - pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) - timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id) - visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) + # Call policy.forward() to get policy output (this will still run MCTS internally, + # but we'll ignore the MCTS action and use predicted_policy_logits instead) + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep) + + # Extract predicted_policy_logits from policy output + actions = {} + for idx, env_id in enumerate(ready_env_id): + if env_id in policy_output: + # Get predicted_policy_logits (from initial inference, before MCTS) + if 'predicted_policy_logits' in policy_output[env_id]: + policy_logits = np.array(policy_output[env_id]['predicted_policy_logits']) + + # Apply action mask (action_mask is a list, indexed by ready_env_id order) + masked_logits = policy_logits.copy() + masked_logits[action_mask[idx] == 0] = -1e9 + + # Select action with highest probability (argmax) - deterministic evaluation + action = np.argmax(masked_logits) + actions[env_id] = int(action) + else: + # Fallback: use MCTS action if predicted_policy_logits not available + actions[env_id] = policy_output[env_id]['action'] + else: + # If env_id not in output, use a default action (should not happen) + actions[env_id] = 0 # ============================================================== - # Environment Interaction + # Interact with env. # ============================================================== timesteps = self._env.step(actions) timesteps = to_tensor(timesteps, dtype=torch.float32) + for env_id, episode_timestep in timesteps.items(): obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info + # obs_input_ids = obs['observation'].long() + # obs_attn_mask = obs['obs_attn_mask'][0].long() + # valid_input_ids = obs_input_ids[obs_attn_mask == 1].tolist() + eps_steps_lst[env_id] += 1 - # This reset logic is specific to UniZero-like models. 
if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False, task_id=self.task_id) + # only for UniZero now + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id] - ) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id], timestep_dict[env_id], chance_dict[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id], timestep_dict[env_id] + ) - # IMPORTANT: The action_mask and to_play from the new observation correspond to the *next* state. + # NOTE: the position of code snippet is very important. + # the obs['action_mask'] and obs['to_play'] are corresponding to next action action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict[env_id] = to_ndarray(obs['chance']) dones[env_id] = done if episode_timestep.done: + # Env reset is done by env_manager automatically. self._policy.reset([env_id]) reward = episode_timestep.info['eval_episode_return'] saved_info = {'eval_episode_return': episode_timestep.info['eval_episode_return']} @@ -372,105 +386,115 @@ def eval( eval_monitor.update_info(env_id, saved_info) eval_monitor.update_reward(env_id, reward) self._logger.info( - f"[EVALUATOR] env {env_id} finished episode, final reward: {eval_monitor.get_latest_reward(env_id)}, " - f"current episode count: {eval_monitor.get_current_episode()}" + "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format( + env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() + ) ) - # If there are more episodes to run than available environments, reset and reuse this one. + # reset the finished env and init game_segments if n_episode > self._env_num: + # Get current ready env obs. init_obs = self._env.ready_obs - # Wait for the environment to be ready again. + retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: - self._logger.info(f"Waiting for env {env_id} to reset. Current ready envs: {list(init_obs.keys())}") + # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info( + 'Before sleeping, the _env_states is {}'.format(self._env._env_states) + ) time.sleep(retry_waiting_time) + self._logger.info( + '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10 + ) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format( + retry_waiting_time, self._env._env_states + ) + ) init_obs = self._env.ready_obs new_available_env_id = set(init_obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - # Re-initialize state for the new episode. action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) + game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config, - task_id=self.task_id + config=self.policy_config ) + game_segments[env_id].reset( - [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)] + [ + init_obs[env_id]['observation'] + for _ in range(self.policy_config.model.frame_stack_num) + ] ) eps_steps_lst[env_id] = 0 - # NOTE: Reset the policy state for this env_id. `reset_init_data` defaults to True. - self._policy.reset([env_id]) + + # Env reset is done by env_manager automatically. + self._policy.reset([env_id]) # NOTE: reset the policy for the env_id. Default reset_init_data=True. ready_env_id.remove(env_id) envstep_count += 1 - + duration = self._timer.value episode_return = eval_monitor.get_episode_return() info = { 'train_iter': train_iter, - 'ckpt_name': f'iteration_{train_iter}.pth.tar', + 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), 'episode_count': n_episode, 'envstep_count': envstep_count, - 'avg_envstep_per_episode': envstep_count / n_episode if n_episode > 0 else 0, + 'avg_envstep_per_episode': envstep_count / n_episode, 'evaluate_time': duration, - 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, - 'avg_time_per_episode': n_episode / duration if duration > 0 else 0, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_time_per_episode': n_episode / duration, 'reward_mean': np.mean(episode_return), 'reward_std': np.std(episode_return), 'reward_max': np.max(episode_return), - 'reward_min': np.min(episode_return), + 'reward_min': np.min(episode_return) + # 'each_reward': episode_return, } episode_info = eval_monitor.get_episode_info() if episode_info is not None: info.update(episode_info) - - logging.info(f'rank {self._rank}, self.task_id: {self.task_id}') self._logger.info(self._logger.get_tabulate_vars_hor(info)) - - # Log to TensorBoard and WandB. 
for k, v in info.items(): - if k in ['train_iter', 'ckpt_name', 'each_reward'] or not np.isscalar(v): + if k in ['train_iter', 'ckpt_name', 'each_reward']: continue - if self.task_id is None: - self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter) - self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, envstep) - else: - self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter) - self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, envstep) + if not np.isscalar(v): + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) if self.policy_config.use_wandb: - wandb.log({f'{self._instance_name}_step/{k}': v}, step=envstep) + wandb.log({'{}_step/'.format(self._instance_name) + k: v}, step=envstep) - # Check for new best performance and save checkpoint. - mean_episode_return = np.mean(episode_return) - if mean_episode_return > self._max_episode_return: + episode_return = np.mean(episode_return) + if episode_return > self._max_episode_return: if save_ckpt_fn: save_ckpt_fn('ckpt_best.pth.tar') - self._max_episode_return = mean_episode_return - - # Check if the stop condition is met. - stop_flag = mean_episode_return >= self._stop_value and train_iter > 0 + self._max_episode_return = episode_return + stop_flag = episode_return >= self._stop_value and train_iter > 0 if stop_flag: self._logger.info( - f"[LightZero serial pipeline] Current episode_return: {mean_episode_return} is greater than " - f"stop_value: {self._stop_value}. The agent is considered converged." + "[LightZero serial pipeline] " + + "Current episode_return: {} is greater than stop_value: {}".format(episode_return, + self._stop_value) + + ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." ) - # NOTE: Only for usual DDP not for unizero_multitask pipeline. - # Finalize DDP synchronization for evaluation results. - # if get_world_size() > 1: - # objects = [stop_flag, episode_info] - # print(f'rank {self._rank}, self.task_id: {self.task_id}') - # print('before broadcast_object_list') - # broadcast_object_list(objects, src=0) - # print('evaluator after broadcast_object_list') - # stop_flag, episode_info = objects + if get_world_size() > 1: + objects = [stop_flag, episode_info] + broadcast_object_list(objects, src=0) + stop_flag, episode_info = objects episode_info = to_item(episode_info) if return_trajectory: diff --git a/lzero/worker/test_gae_computation.py b/lzero/worker/test_gae_computation.py new file mode 100644 index 000000000..3c9f792fc --- /dev/null +++ b/lzero/worker/test_gae_computation.py @@ -0,0 +1,295 @@ +""" +Test script for GAE computation in MuZeroCollector. + +This script tests both the original (_batch_compute_gae_for_pool_bak) and new +(_batch_compute_gae_for_pool) implementations to ensure they produce the same results. 
+""" + +import numpy as np +import torch +from easydict import EasyDict +from unittest.mock import Mock, MagicMock + +# Import the collector class +import sys +import os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from lzero.worker.muzero_collector import MuZeroCollector +from lzero.mcts.buffer.game_segment import GameSegment + + +class MockGameSegment: + """Mock GameSegment for testing""" + def __init__(self, episode_id, values, rewards, actions): + self.episode_id = episode_id + self.root_value_segment = values.copy() + self.reward_segment = rewards.copy() + self.action_segment = actions.copy() + self.advantage_segment = [] + self.return_segment = [] + + def __len__(self): + return len(self.action_segment) + + +def create_test_data(): + """Create test data for GAE computation""" + + # Test case 1: Simple episode with 5 steps + # Episode 0: 3 segments + episode_0_segments = [ + MockGameSegment( + episode_id=0, + values=[1.0, 2.0, 3.0], + rewards=[0.1, 0.2, 0.3], + actions=[0, 1, 0] + ), + MockGameSegment( + episode_id=0, + values=[4.0, 5.0], + rewards=[0.4, 0.5], + actions=[1, 0] + ), + MockGameSegment( + episode_id=0, + values=[6.0], + rewards=[0.6], + actions=[0] + ), + ] + + # Test case 2: Another episode with 4 steps + # Episode 1: 2 segments + episode_1_segments = [ + MockGameSegment( + episode_id=1, + values=[10.0, 11.0], + rewards=[1.0, 1.1], + actions=[0, 1] + ), + MockGameSegment( + episode_id=1, + values=[12.0, 13.0], + rewards=[1.2, 1.3], + actions=[1, 0] + ), + ] + + return episode_0_segments, episode_1_segments + + +def create_mock_collector(): + """Create a mock MuZeroCollector instance""" + collector = Mock(spec=MuZeroCollector) + + # Set up PPO parameters + collector.ppo_gamma = 0.99 + collector.ppo_gae_lambda = 0.95 + + # Create logger mock + collector._logger = Mock() + collector._logger.info = Mock() + + return collector + + +def test_gae_computation(): + """Test GAE computation with both implementations""" + + print("=" * 80) + print("Testing GAE Computation") + print("=" * 80) + + # Create test data + episode_0_segments, episode_1_segments = create_test_data() + + # Create game_segment_pool + game_segment_pool = [] + priorities = 1.0 + done_flag = False + + # Add episode 0 segments + for seg in episode_0_segments: + game_segment_pool.append((seg, priorities, done_flag)) + + # Add episode 1 segments + for seg in episode_1_segments: + game_segment_pool.append((seg, priorities, done_flag)) + + # Make deep copies for comparison + def copy_segment(seg): + new_seg = MockGameSegment( + seg.episode_id, + seg.root_value_segment.copy() if isinstance(seg.root_value_segment, list) else seg.root_value_segment.copy(), + seg.reward_segment.copy() if isinstance(seg.reward_segment, list) else seg.reward_segment.copy(), + seg.action_segment.copy() if isinstance(seg.action_segment, list) else seg.action_segment.copy() + ) + return new_seg + + pool_bak = [(copy_segment(seg), priorities, done_flag) for seg, _, _ in game_segment_pool] + pool_new = [(copy_segment(seg), priorities, done_flag) for seg, _, _ in game_segment_pool] + + collector_bak = create_mock_collector() + collector_bak.game_segment_pool = pool_bak + + collector_new = create_mock_collector() + collector_new.game_segment_pool = pool_new + + # Test original implementation + print("\n[1] Testing original implementation (_batch_compute_gae_for_pool_bak)...") + MuZeroCollector._batch_compute_gae_for_pool_bak(collector_bak) + + # Test new implementation + print("[2] Testing new 
implementation (_batch_compute_gae_for_pool)...") + MuZeroCollector._batch_compute_gae_for_pool(collector_new) + + # Compare results + print("\n[3] Comparing results...") + print("-" * 80) + + all_match = True + for i, ((seg_bak, _, _), (seg_new, _, _)) in enumerate(zip(pool_bak, pool_new)): + print(f"\nSegment {i} (Episode {seg_bak.episode_id}):") + print(f" Length: {len(seg_bak.action_segment)}") + + # Compare advantages + adv_bak = np.array(seg_bak.advantage_segment) + adv_new = np.array(seg_new.advantage_segment) + + if len(adv_bak) != len(adv_new): + print(f" ❌ Advantage length mismatch: {len(adv_bak)} vs {len(adv_new)}") + all_match = False + else: + max_diff = np.max(np.abs(adv_bak - adv_new)) + print(f" Advantages - Max difference: {max_diff:.6f}") + if max_diff > 1e-5: + print(f" ❌ Advantages don't match!") + print(f" Original: {adv_bak}") + print(f" New: {adv_new}") + all_match = False + else: + print(f" ✓ Advantages match") + + # Compare returns + ret_bak = np.array(seg_bak.return_segment) + ret_new = np.array(seg_new.return_segment) + + if len(ret_bak) != len(ret_new): + print(f" ❌ Return length mismatch: {len(ret_bak)} vs {len(ret_new)}") + all_match = False + else: + max_diff = np.max(np.abs(ret_bak - ret_new)) + print(f" Returns - Max difference: {max_diff:.6f}") + if max_diff > 1e-5: + print(f" ❌ Returns don't match!") + print(f" Original: {ret_bak}") + print(f" New: {ret_new}") + all_match = False + else: + print(f" ✓ Returns match") + + # Print detailed values for first segment of each episode + if i == 0 or i == len(episode_0_segments): + print(f"\n Detailed values for Segment {i}:") + print(f" Values: {seg_bak.root_value_segment}") + print(f" Rewards: {seg_bak.reward_segment}") + print(f" Advantages (original): {adv_bak}") + print(f" Advantages (new): {adv_new}") + print(f" Returns (original): {ret_bak}") + print(f" Returns (new): {ret_new}") + + print("\n" + "=" * 80) + if all_match: + print("✓ All tests passed! Both implementations produce identical results.") + else: + print("❌ Tests failed! 
Implementations produce different results.") + print("=" * 80) + + return all_match + + +def test_manual_gae_verification(): + """Manually verify GAE computation for a simple case""" + print("\n" + "=" * 80) + print("Manual GAE Verification (Simple Case)") + print("=" * 80) + + # Simple case: 3 steps, gamma=0.99, lambda=0.95 + values = np.array([1.0, 2.0, 3.0], dtype=np.float32) + rewards = np.array([0.1, 0.2, 0.3], dtype=np.float32) + gamma = 0.99 + gae_lambda = 0.95 + + print(f"\nInput:") + print(f" Values: {values}") + print(f" Rewards: {rewards}") + print(f" Gamma: {gamma}") + print(f" Lambda: {gae_lambda}") + + # Manual computation + advantages = np.zeros_like(rewards) + gae_val = 0.0 + + print(f"\nManual computation (backward):") + for t in reversed(range(len(rewards))): + if t == len(rewards) - 1: + next_value = 0.0 + else: + next_value = values[t + 1] + + delta = rewards[t] + gamma * next_value - values[t] + gae_val = delta + gamma * gae_lambda * gae_val + advantages[t] = gae_val + print(f" t={t}: delta={delta:.6f}, gae={gae_val:.6f}") + + returns = advantages + values + + print(f"\nResults:") + print(f" Advantages: {advantages}") + print(f" Returns: {returns}") + + # Test with ding library + print(f"\nDing library computation:") + from ding.rl_utils import gae_data, gae + + value = torch.tensor(values, dtype=torch.float32) + next_value = torch.cat([value[1:], torch.tensor([0.0], dtype=torch.float32)]) + reward = torch.tensor(rewards, dtype=torch.float32) + done = torch.tensor([False, False, True], dtype=torch.bool) + + compute_adv_data = gae_data(value, next_value, reward, done, None) + advantages_ding = gae(compute_adv_data, gamma, gae_lambda) + returns_ding = advantages_ding + value + + print(f" Advantages: {advantages_ding.cpu().numpy()}") + print(f" Returns: {returns_ding.cpu().numpy()}") + + # Compare + max_diff_adv = np.max(np.abs(advantages - advantages_ding.cpu().numpy())) + max_diff_ret = np.max(np.abs(returns - returns_ding.cpu().numpy())) + + print(f"\nComparison:") + print(f" Advantages max diff: {max_diff_adv:.6f}") + print(f" Returns max diff: {max_diff_ret:.6f}") + + if max_diff_adv < 1e-5 and max_diff_ret < 1e-5: + print(" ✓ Manual and ding library results match!") + else: + print(" ❌ Results don't match!") + + +if __name__ == "__main__": + # Run manual verification first + test_manual_gae_verification() + + # Run full test + success = test_gae_computation() + + if success: + print("\n✅ All tests passed!") + exit(0) + else: + print("\n❌ Some tests failed!") + exit(1) + diff --git a/zoo/atari/config/atari_unizero_config.py b/zoo/atari/config/atari_unizero_config.py index 01b75eab8..2c68c80fe 100644 --- a/zoo/atari/config/atari_unizero_config.py +++ b/zoo/atari/config/atari_unizero_config.py @@ -65,7 +65,6 @@ def main(env_id='PongNoFrameskip-v4', seed=0): num_heads=8, embed_dim=768, obs_type='image', - encoder_type='resnet', env_num=max(collector_env_num, evaluator_env_num), rotary_emb=False, ), diff --git a/zoo/atari/config/atari_unizero_ppo_config.py b/zoo/atari/config/atari_unizero_ppo_config.py new file mode 100644 index 000000000..e1d427754 --- /dev/null +++ b/zoo/atari/config/atari_unizero_ppo_config.py @@ -0,0 +1,131 @@ +from easydict import EasyDict + +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + + +def main(env_id='PongNoFrameskip-v4', seed=0): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config 
specified by the user + # ============================================================== + collector_env_num = 8 + game_segment_length = 20 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + batch_size = 64 + num_unroll_steps = 10 + infer_context_length = 4 + num_layers = 2 + replay_ratio = 0.25 + + # TODO: only for debug + # collector_env_num = 2 + # game_segment_length = 20 + # evaluator_env_num = 2 + # num_simulations = 2 + # max_env_step = int(5e5) + # batch_size = 10 + # num_unroll_steps = 5 + # infer_context_length = 2 + # num_layers = 1 + # replay_ratio = 0.1 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + atari_unizero_ppo_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: only for debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + world_model_cfg=dict( + policy_entropy_weight=1e-4, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + rotary_emb=False, + ), + ), + model_path=None, + num_unroll_steps=num_unroll_steps, + replay_ratio=replay_ratio, + batch_size=batch_size, + learning_rate=0.0001, + num_simulations=num_simulations, + train_start_after_envsteps=2000, + # train_start_after_envsteps=0, # TODO: only for debug + game_segment_length=game_segment_length, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # Whether to use pure policy (without MCTS) for data collection + collect_with_pure_policy=True, + # Whether to use pure policy (without MCTS) for evaluation + # If not set, will use collect_with_pure_policy value + eval_with_pure_policy=True, + # Whether to use online learning (clear replay_buffer after each training iteration) + online_learning=True, + # PPO configuration for GAE computation + ppo=dict( + gamma=0.99, # Discount factor + gae_lambda=0.95, # GAE lambda parameter + clip_ratio=0.2, # PPO clipping ratio + value_coef=0.5, # Value loss coefficient + entropy_coef=0.01, # Entropy loss coefficient + ), + ), + ) + atari_unizero_ppo_config = EasyDict(atari_unizero_ppo_config) + main_config = atari_unizero_ppo_config + + atari_unizero_ppo_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + atari_unizero_ppo_create_config = EasyDict(atari_unizero_ppo_create_config) + create_config = atari_unizero_ppo_create_config + + main_config.exp_name = 
f'data_lz/data_unizero_ppo/{env_id[:-14]}/{env_id[:-14]}_uz_ppo_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + from lzero.entry import train_unizero + train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process some environment.') + parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + main(args.env, args.seed) diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_unizero_ppo_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_unizero_ppo_config.py new file mode 100644 index 000000000..3c7ceba01 --- /dev/null +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_unizero_ppo_config.py @@ -0,0 +1,114 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +reanalyze_ratio = 0. +update_per_collect = None +replay_ratio = 0.25 +max_env_step = int(5e5) +batch_size = 256 +num_unroll_steps = 10 +infer_context_length = 4 +norm_type = 'BN' + +# debug +# collector_env_num = 2 +# n_episode = 2 +# evaluator_env_num = 2 +# num_simulations = 5 +# batch_size = 2 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +lunarlander_unizero_ppo_config = dict( + exp_name=f'data_unizero_ppo/lunarlander_unizero_ppo_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}-infer{infer_context_length}_bs{batch_size}_{norm_type}_seed0', + env=dict( + env_name='LunarLander-v2', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=8, + action_space_size=4, + model_type='mlp', + norm_type=norm_type, + world_model_cfg=dict( + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=4, + group_size=8, # NOTE: sim_norm + num_layers=4, + num_heads=4, + embed_dim=256, + env_num=max(collector_env_num, evaluator_env_num), + obs_type='vector', + norm_type=norm_type, + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
+ model_path=None, + num_unroll_steps=num_unroll_steps, + cuda=True, + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='AdamW', + piecewise_decay_lr_scheduler=False, + learning_rate=0.0001, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # Whether to use pure policy (without MCTS) for data collection + collect_with_pure_policy=True, + # Whether to use pure policy (without MCTS) for evaluation + # If not set, will use collect_with_pure_policy value + eval_with_pure_policy=True, + # Whether to use online learning (clear replay_buffer after each training iteration) + online_learning=True, + # PPO configuration for GAE computation + ppo=dict( + gamma=0.99, # Discount factor + gae_lambda=0.95, # GAE lambda parameter + clip_ratio=0.2, # PPO clipping ratio + value_coef=0.5, # Value loss coefficient + entropy_coef=0.01, # Entropy loss coefficient + ), + ), +) +lunarlander_unizero_ppo_config = EasyDict(lunarlander_unizero_ppo_config) +main_config = lunarlander_unizero_ppo_config + +lunarlander_unizero_ppo_create_config = dict( + env=dict( + type='lunarlander', + import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), +) +lunarlander_unizero_ppo_create_config = EasyDict(lunarlander_unizero_ppo_create_config) +create_config = lunarlander_unizero_ppo_create_config + +if __name__ == "__main__": + from lzero.entry import train_unizero + train_unizero([main_config, create_config], seed=0, max_env_step=max_env_step) + diff --git a/zoo/box2d/lunarlander/envs/lunarlander_env.py b/zoo/box2d/lunarlander/envs/lunarlander_env.py index 4f5f15834..729110827 100755 --- a/zoo/box2d/lunarlander/envs/lunarlander_env.py +++ b/zoo/box2d/lunarlander/envs/lunarlander_env.py @@ -131,8 +131,13 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: Returns: - timestep (:obj:`BaseEnvTimestep`): The timestep information including observation, reward, done flag, and info. 
""" - if action.shape == (1,): - action = action.item() # 0-dim array + # Handle both numpy array and int/float action types + if hasattr(action, 'shape'): + if action.shape == (1,): + action = action.item() # 0-dim array + elif isinstance(action, (int, float, np.integer, np.floating)): + # Already a scalar, use as is + action = int(action) if self._act_scale: action = affine_transform(action, min_val=-1, max_val=1) if self._save_replay_gif: diff --git a/zoo/classic_control/cartpole/config/cartpole_unizero_config.py b/zoo/classic_control/cartpole/config/cartpole_unizero_config.py index 7cb8d98d4..3c7ed3acf 100644 --- a/zoo/classic_control/cartpole/config/cartpole_unizero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_unizero_config.py @@ -78,6 +78,21 @@ replay_buffer_size=int(1e6), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, + # Whether to use pure policy (without MCTS) for data collection + collect_with_pure_policy=True, + # Whether to use pure policy (without MCTS) for evaluation + # If not set, will use collect_with_pure_policy value + eval_with_pure_policy=True, + # Whether to use online learning (clear replay_buffer after each training iteration) + online_learning=True, + # PPO configuration for GAE computation + ppo=dict( + gamma=0.99, # Discount factor + gae_lambda=0.95, # GAE lambda parameter + clip_ratio=0.2, # PPO clipping ratio + value_coef=0.5, # Value loss coefficient + entropy_coef=0.01, # Entropy loss coefficient + ), ), )