diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py
index 8729d3c30..c6210a8a2 100644
--- a/src/imitation/algorithms/adversarial/common.py
+++ b/src/imitation/algorithms/adversarial/common.py
@@ -8,12 +8,17 @@
 import torch as th
 import torch.utils.tensorboard as thboard
 import tqdm
-from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
+from stable_baselines3.common import base_class
+from stable_baselines3.common import buffers as sb3_buffers
+from stable_baselines3.common import on_policy_algorithm, policies, type_aliases
+from stable_baselines3.common import utils as sb3_utils
+from stable_baselines3.common import vec_env
 from stable_baselines3.sac import policies as sac_policies
 from torch.nn import functional as F
 
 from imitation.algorithms import base
 from imitation.data import buffer, rollout, types, wrappers
+from imitation.policies import replay_buffer_wrapper
 from imitation.rewards import reward_nets, reward_wrapper
 from imitation.util import logger, networks, util
 
@@ -246,6 +251,38 @@ def __init__(
         else:
             self.gen_train_timesteps = gen_train_timesteps
 
+        self.is_gen_on_policy = isinstance(
+            self.gen_algo, on_policy_algorithm.OnPolicyAlgorithm
+        )
+        if self.is_gen_on_policy:
+            rollout_buffer = self.gen_algo.rollout_buffer
+            self.gen_algo.rollout_buffer = (
+                replay_buffer_wrapper.RolloutBufferRewardWrapper(
+                    buffer_size=self.gen_train_timesteps // rollout_buffer.n_envs,
+                    observation_space=rollout_buffer.observation_space,
+                    action_space=rollout_buffer.action_space,
+                    rollout_buffer_class=rollout_buffer.__class__,
+                    reward_fn=self.reward_train.predict_processed,
+                    device=rollout_buffer.device,
+                    gae_lambda=rollout_buffer.gae_lambda,
+                    gamma=rollout_buffer.gamma,
+                    n_envs=rollout_buffer.n_envs,
+                )
+            )
+        else:
+            replay_buffer = self.gen_algo.replay_buffer
+            self.gen_algo.replay_buffer = (
+                replay_buffer_wrapper.ReplayBufferRewardWrapper(
+                    buffer_size=self.gen_train_timesteps,
+                    observation_space=replay_buffer.observation_space,
+                    action_space=replay_buffer.action_space,
+                    replay_buffer_class=sb3_buffers.ReplayBuffer,
+                    reward_fn=self.reward_train.predict_processed,
+                    device=replay_buffer.device,
+                    n_envs=replay_buffer.n_envs,
+                )
+            )
+
         if gen_replay_buffer_capacity is None:
             gen_replay_buffer_capacity = self.gen_train_timesteps
         self._gen_replay_buffer = buffer.ReplayBuffer(
@@ -382,41 +419,126 @@ def train_disc(
 
         return train_stats
 
-    def train_gen(
+    def collect_rollouts(
         self,
         total_timesteps: Optional[int] = None,
+        callback: type_aliases.MaybeCallback = None,
         learn_kwargs: Optional[Mapping] = None,
-    ) -> None:
-        """Trains the generator to maximize the discriminator loss.
-
-        After the end of training populates the generator replay buffer (used in
-        discriminator training) with `self.disc_batch_size` transitions.
+    ):
+        """Collect rollouts.
 
         Args:
             total_timesteps: The number of transitions to sample from
                 `self.venv_train` during training. By default,
                 `self.gen_train_timesteps`.
+            callback: Callback that will be called at each step
+                (and at the beginning and end of the rollout)
             learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
                 method.
""" - if total_timesteps is None: - total_timesteps = self.gen_train_timesteps if learn_kwargs is None: learn_kwargs = {} - with self.logger.accumulate_means("gen"): - self.gen_algo.learn( - total_timesteps=total_timesteps, - reset_num_timesteps=False, - callback=self.gen_callback, + if total_timesteps is None: + total_timesteps = self.gen_train_timesteps + + # total timesteps should be per env + total_timesteps = total_timesteps // self.gen_algo.n_envs + # NOTE (Taufeeque): call setup_learn or not? + if "eval_env" not in learn_kwargs: + total_timesteps, callback = self.gen_algo._setup_learn( + total_timesteps, + eval_env=None, + callback=callback, **learn_kwargs, ) - self._global_step += 1 + else: + total_timesteps, callback = self.gen_algo._setup_learn( + total_timesteps, + callback=callback, + **learn_kwargs, + ) + callback.on_training_start(locals(), globals()) + if self.is_gen_on_policy: + self.gen_algo.collect_rollouts( + self.gen_algo.env, + callback, + self.gen_algo.rollout_buffer, + n_rollout_steps=total_timesteps, + ) + rollouts = None + else: + self.gen_algo.train_freq = total_timesteps + self.gen_algo._convert_train_freq() + rollouts = self.gen_algo.collect_rollouts( + self.gen_algo.env, + train_freq=self.gen_algo.train_freq, + action_noise=self.gen_algo.action_noise, + callback=callback, + learning_starts=self.gen_algo.learning_starts, + replay_buffer=self.gen_algo.replay_buffer, + ) + + if self.is_gen_on_policy: + if ( + len(self.gen_algo.ep_info_buffer) > 0 + and len(self.gen_algo.ep_info_buffer[0]) > 0 + ): + self.logger.record( + "rollout/ep_rew_mean", + sb3_utils.safe_mean( + [ep_info["r"] for ep_info in self.gen_algo.ep_info_buffer] + ), + ) + self.logger.record( + "rollout/ep_len_mean", + sb3_utils.safe_mean( + [ep_info["l"] for ep_info in self.gen_algo.ep_info_buffer] + ), + ) + self.logger.record( + "time/total_timesteps", + self.gen_algo.num_timesteps, + exclude="tensorboard", + ) + else: + self.gen_algo._dump_logs() gen_trajs, ep_lens = self.venv_buffering.pop_trajectories() self._check_fixed_horizon(ep_lens) gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs) self._gen_replay_buffer.store(gen_samples) + callback.on_training_end() + return rollouts + + def train_gen( + self, + rollouts, + ) -> None: + """Trains the generator to maximize the discriminator loss. + + After the end of training populates the generator replay buffer (used in + discriminator training) with `self.disc_batch_size` transitions. + """ + with self.logger.accumulate_means("gen"): + # self.gen_algo.learn( + # total_timesteps=total_timesteps, + # reset_num_timesteps=False, + # callback=self.gen_callback, + # **learn_kwargs, + # ) + if self.is_gen_on_policy: + self.gen_algo.train() + else: + if self.gen_algo.gradient_steps >= 0: + gradient_steps = self.gen_algo.gradient_steps + else: + gradient_steps = rollouts.episode_timesteps + self.gen_algo.train( + batch_size=self.gen_algo.batch_size, + gradient_steps=gradient_steps, + ) + self._global_step += 1 def train( self, @@ -445,11 +567,14 @@ def train( f"total_timesteps={total_timesteps})!" 
             )
         for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
-            self.train_gen(self.gen_train_timesteps)
+            rollouts = self.collect_rollouts(
+                self.gen_train_timesteps, self.gen_callback
+            )
             for _ in range(self.n_disc_updates_per_round):
                 with networks.training(self.reward_train):
                     # switch to training mode (affects dropout, normalization)
                     self.train_disc()
+            self.train_gen(rollouts)
             if callback:
                 callback(r)
             self.logger.dump(self._global_step)
diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py
index 6d0d70449..535aff1a7 100644
--- a/src/imitation/policies/replay_buffer_wrapper.py
+++ b/src/imitation/policies/replay_buffer_wrapper.py
@@ -3,15 +3,16 @@
 from typing import Mapping, Type
 
 import numpy as np
+import torch as th
 from gym import spaces
-from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.buffers import BaseBuffer, ReplayBuffer, RolloutBuffer
 from stable_baselines3.common.type_aliases import ReplayBufferSamples
 
 from imitation.rewards.reward_function import RewardFn
 from imitation.util import util
 
 
-def _samples_to_reward_fn_input(
+def _replay_samples_to_reward_fn_input(
     samples: ReplayBufferSamples,
 ) -> Mapping[str, np.ndarray]:
     """Convert a sample from a replay buffer to a numpy array."""
@@ -23,6 +24,18 @@ def _samples_to_reward_fn_input(
     )
 
 
+def _rollout_samples_to_reward_fn_input(
+    buffer: RolloutBuffer,
+) -> Mapping[str, np.ndarray]:
+    """Convert a sample from a rollout buffer to a numpy array."""
+    return dict(
+        state=buffer.observations,
+        action=buffer.actions,
+        next_state=buffer.next_observations,
+        done=buffer.dones,
+    )
+
+
 class ReplayBufferRewardWrapper(ReplayBuffer):
     """Relabel the rewards in transitions sampled from a ReplayBuffer."""
 
@@ -50,7 +63,9 @@ def __init__(
         # DictReplayBuffer because the current RewardFn only takes in NumPy array-based
         # inputs, and SAC is the only use case for ReplayBuffer relabeling. See:
         # https://github.com/HumanCompatibleAI/imitation/pull/459#issuecomment-1201997194
-        assert replay_buffer_class is ReplayBuffer, "only ReplayBuffer is supported"
+        assert (
+            replay_buffer_class is ReplayBuffer
+        ), f"only ReplayBuffer is supported: given {replay_buffer_class}"
         assert not isinstance(observation_space, spaces.Dict)
         self.replay_buffer = replay_buffer_class(
             buffer_size,
@@ -80,7 +95,7 @@ def full(self, full: bool):
 
     def sample(self, *args, **kwargs):
         samples = self.replay_buffer.sample(*args, **kwargs)
-        rewards = self.reward_fn(**_samples_to_reward_fn_input(samples))
+        rewards = self.reward_fn(**_replay_samples_to_reward_fn_input(samples))
         shape = samples.rewards.shape
         device = samples.rewards.device
         rewards_th = util.safe_to_tensor(rewards).reshape(shape).to(device)
@@ -101,3 +116,120 @@ def _get_samples(self):
             "_get_samples() is intentionally not implemented."
             "This method should not be called.",
         )
+
+
+class RolloutBufferRewardWrapper(BaseBuffer):
+    """Relabel the rewards in transitions sampled from a RolloutBuffer."""
+
+    def __init__(
+        self,
+        buffer_size: int,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
+        *,
+        rollout_buffer_class: Type[RolloutBuffer],
+        reward_fn: RewardFn,
+        **kwargs,
+    ):
+        """Builds RolloutBufferRewardWrapper.
+
+        Args:
+            buffer_size: Max number of elements in the buffer
+            observation_space: Observation space
+            action_space: Action space
+            rollout_buffer_class: Class of the rollout buffer.
+            reward_fn: Reward function for reward relabeling.
+            **kwargs: keyword arguments for RolloutBuffer.
+ """ + # Note(yawen-d): we directly inherit RolloutBuffer and leave out the case of + # DictRolloutBuffer because the current RewardFn only takes in NumPy array-based + # inputs, and GAIL/AIRL is the only use case for RolloutBuffer relabeling. See: + # https://github.com/HumanCompatibleAI/imitation/pull/459#issuecomment-1201997194 + assert rollout_buffer_class is RolloutBuffer, "only RolloutBuffer is supported" + assert not isinstance(observation_space, spaces.Dict) + self.rollout_buffer = rollout_buffer_class( + buffer_size, + observation_space, + action_space, + **kwargs, + ) + self.reward_fn = reward_fn + _base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]} + super().__init__(buffer_size, observation_space, action_space, **_base_kwargs) + + @property + def pos(self) -> int: + return self.rollout_buffer.pos + + @property + def values(self): + return self.rollout_buffer.values + + @property + def observations(self): + return self.rollout_buffer.observations + + @property + def actions(self): + return self.rollout_buffer.actions + + @property + def log_probs(self): + return self.rollout_buffer.log_probs + + @property + def advantages(self): + return self.rollout_buffer.advantages + + @property + def rewards(self): + return self.rollout_buffer.rewards + + @property + def returns(self): + return self.rollout_buffer.returns + + @pos.setter + def pos(self, pos: int): + self.rollout_buffer.pos = pos + + @property + def full(self) -> bool: + return self.rollout_buffer.full + + @full.setter + def full(self, full: bool): + self.rollout_buffer.full = full + + def reset(self): + self.rollout_buffer.reset() + + def get(self, *args, **kwargs): + if not self.rollout_buffer.generator_ready: + input_dict = _rollout_samples_to_reward_fn_input(self.rollout_buffer) + rewards = np.zeros_like(self.rollout_buffer.rewards) + for i in range(self.buffer_size): + rewards[i] = self.reward_fn(**{k: v[i] for k, v in input_dict.items()}) + + self.rollout_buffer.rewards = rewards + self.rollout_buffer.compute_returns_and_advantage( + self.last_values, self.last_dones + ) + ret = self.rollout_buffer.get(*args, **kwargs) + return ret + + def add(self, *args, **kwargs): + self.rollout_buffer.add(*args, **kwargs) + + def _get_samples(self): + raise NotImplementedError( + "_get_samples() is intentionally not implemented." + "This method should not be called.", + ) + + def compute_returns_and_advantage( + self, last_values: th.Tensor, dones: np.ndarray + ) -> None: + self.last_values = last_values + self.last_dones = dones + self.rollout_buffer.compute_returns_and_advantage(last_values, dones) diff --git a/src/imitation/rewards/reward_nets.py b/src/imitation/rewards/reward_nets.py index 4e6c747e3..00a634af6 100644 --- a/src/imitation/rewards/reward_nets.py +++ b/src/imitation/rewards/reward_nets.py @@ -723,7 +723,7 @@ def forward( # series of remaining potential shapings can lead to reward shaping # that does not preserve the optimal policy if the episodes have variable # length! - new_shaping = (1 - done.float()) * new_shaping_output + new_shaping = (1 - done.float().flatten()) * new_shaping_output final_rew = ( base_reward_net_output + self.discount_factor * new_shaping