diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 5c2953a5c..8619a18eb 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -87,19 +87,19 @@ def __init__(self, config, vecenv, policy, logger=None): ) device = config['device'] - self.observations = torch.zeros(segments, horizon, *obs_space.shape, + self.observations = torch.zeros(segments, horizon + 1, *obs_space.shape, dtype=pufferlib.pytorch.numpy_to_torch_dtype_dict[obs_space.dtype], pin_memory=device == 'cuda' and config['cpu_offload'], device='cpu' if config['cpu_offload'] else device) - self.actions = torch.zeros(segments, horizon, *atn_space.shape, device=device, + self.actions = torch.zeros(segments, horizon + 1, *atn_space.shape, device=device, dtype=pufferlib.pytorch.numpy_to_torch_dtype_dict[atn_space.dtype]) - self.values = torch.zeros(segments, horizon, device=device) - self.logprobs = torch.zeros(segments, horizon, device=device) - self.rewards = torch.zeros(segments, horizon, device=device) - self.terminals = torch.zeros(segments, horizon, device=device) - self.truncations = torch.zeros(segments, horizon, device=device) - self.ratio = torch.ones(segments, horizon, device=device) - self.importance = torch.ones(segments, horizon, device=device) + self.values = torch.zeros(segments, horizon + 1, device=device) + self.logprobs = torch.zeros(segments, horizon + 1, device=device) + self.rewards = torch.zeros(segments, horizon + 1, device=device) + self.terminals = torch.zeros(segments, horizon + 1, device=device) + self.truncations = torch.zeros(segments, horizon + 1, device=device) + self.ratio = torch.ones(segments, horizon + 1, device=device) + self.importance = torch.ones(segments, horizon + 1, device=device) self.ep_lengths = torch.zeros(total_agents, device=device, dtype=torch.int32) self.ep_indices = torch.arange(total_agents, device=device, dtype=torch.int32) self.free_idx = total_agents @@ -282,7 +282,7 @@ def evaluate(self): # Note: We are not yet handling masks in this version self.ep_lengths[env_id] += 1 - if l+1 >= config['bptt_horizon']: + if l+1 > config['bptt_horizon']: num_full = env_id.stop - env_id.start self.ep_indices[env_id] = self.free_idx + torch.arange(num_full, device=config['device']).int() self.ep_lengths[env_id] = 0