diff --git a/README.md b/README.md
index 4cadf675a..35af7ba5c 100644
--- a/README.md
+++ b/README.md
@@ -117,7 +117,7 @@ The `-s` flag sets up a virtual screen at 1280x720 resolution with 24-bit color
 
 ### Distributional realism
 
-We provide a PufferDrive implementation of the [Waymo Open Sim Agents Challenge (WOSAC)](https://waymo.com/open/challenges/2025/sim-agents/) for fast, easy evaluation of how well your trained agent matches distributional properties of human behavior. See details [here](https://github.com/Emerge-Lab/PufferDrive/main/pufferlib/ocean/benchmark).
+We provide a PufferDrive implementation of the [Waymo Open Sim Agents Challenge (WOSAC)](https://waymo.com/open/challenges/2025/sim-agents/) for fast, easy evaluation of how well your trained agent matches distributional properties of human behavior. See details [here](https://github.com/Emerge-Lab/PufferDrive/tree/main/pufferlib/ocean/benchmark).
 
 WOSAC evaluation with random policy:
 ```bash
diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini
index 13671d3a6..e50775d51 100644
--- a/pufferlib/config/ocean/drive.ini
+++ b/pufferlib/config/ocean/drive.ini
@@ -20,6 +20,8 @@ hidden_size = 256
 
 [env]
 num_agents = 1024
+; If True, we control non-vehicle entities as well (e.g., pedestrians, cyclists)
+control_non_vehicles = False
 ; Options: discrete, continuous
 action_type = discrete
 ; Options: classic, jerk
@@ -44,10 +46,16 @@ resample_frequency = 910
 num_maps = 1000
 ; Determines which step of the trajectory to initialize the agents at upon reset
 init_steps = 0
-; Options: "control_vehicles", "control_agents", "control_tracks_to_predict", "control_sdc_only"
+; Options: "control_vehicles", "control_agents", "control_wosac", "control_sdc_only"
 control_mode = "control_vehicles"
-; Options: "created_all_valid", "create_only_controlled"
-init_mode = "create_all_valid"
+; Options: "create_all_valid", "create_only_controlled", "dynamic_no_agents"
+init_mode = "dynamic_no_agents"
+; Only for dynamic_no_agents init_mode
+num_agents_per_world = 32
+vehicle_width = 2.0
+vehicle_length = 4.5
+vehicle_height = 1.8
+goal_curriculum = 30.0
 
 [train]
 total_timesteps = 2_000_000_000
@@ -87,7 +95,7 @@ show_grid = False
 show_lasers = False
 ; Display human xy logs in the background
 show_human_logs = True
-; Options: str to path (e.g., "resources/drive/binaries/map_001.bin"), None
+; Options: List[str to path], str to path (e.g., "resources/drive/binaries/map_001.bin"), None
 render_map = none
 
 [eval]
@@ -99,7 +107,7 @@ wosac_realism_eval = False
 wosac_num_rollouts = 32 # Number of policy rollouts per scene
 wosac_init_steps = 10 # When to start the simulation
 wosac_num_agents = 256 # Total number of WOSAC agents to evaluate
-wosac_control_mode = "control_tracks_to_predict" # Control the tracks to predict
+wosac_control_mode = "control_wosac" # Control the WOSAC agents (the tracks to predict)
 wosac_init_mode = "create_all_valid" # Initialize from the tracks to predict
 wosac_goal_behavior = 2 # Stop when reaching the goal
 wosac_goal_radius = 2.0 # Can shrink goal radius for WOSAC evaluation
@@ -145,3 +153,9 @@ min = 0.0
 max = 1.0
 mean = 0.5
 scale = auto
+
+[controlled_exp.train.learning_rate]
+values = [0.001, 0.003, 0.01]
+
+[controlled_exp.train.ent_coef]
+values = [0.01, 0.005]
diff --git a/pufferlib/ocean/benchmark/README.md b/pufferlib/ocean/benchmark/README.md
index 2e08295f4..7bae42c2e 100644
--- a/pufferlib/ocean/benchmark/README.md
+++ b/pufferlib/ocean/benchmark/README.md
@@ -60,7 +60,6 @@ Steps [for every scene]:
 Linear acceleration: 0.4658
 Angular speed: 0.5543
 Angular acceleration: 0.6589
-Kinematics realism score: 0.5607
 ```
 
 These scores go to 1.0 if we use the time-dependent estimator, except for the smoothing factor that is used to avoid bins with 0 probability.
diff --git a/pufferlib/ocean/benchmark/estimators.py b/pufferlib/ocean/benchmark/estimators.py
index 7af3592a5..bee1860c9 100644
--- a/pufferlib/ocean/benchmark/estimators.py
+++ b/pufferlib/ocean/benchmark/estimators.py
@@ -13,7 +13,7 @@ def histogram_estimate(
     min_val: float,
     max_val: float,
     num_bins: int,
-    additive_smoothing: float = 0.1,
+    additive_smoothing: float,
 ) -> np.ndarray:
     """Computes log-likelihoods of samples based on histograms.
 
@@ -68,7 +68,7 @@ def log_likelihood_estimate_timeseries(
     min_val: float,
     max_val: float,
     num_bins: int,
-    additive_smoothing: float = 0.1,
+    additive_smoothing: float,
     treat_timesteps_independently: bool = True,
     sanity_check: bool = False,
     plot_agent_idx: int = 0,
@@ -120,6 +120,86 @@ def log_likelihood_estimate_timeseries(
     return log_probs
 
 
+def bernoulli_estimate(
+    log_samples: np.ndarray,
+    sim_samples: np.ndarray,
+    additive_smoothing: float,
+) -> np.ndarray:
+    """Computes log probabilities of samples based on Bernoulli distributions.
+
+    Args:
+        log_samples: Boolean array of shape (n_agents, sample_size)
+        sim_samples: Boolean array of shape (n_agents, sample_size)
+        additive_smoothing: Pseudocount for Laplace smoothing
+
+    Returns:
+        Shape (n_agents, sample_size) - log-likelihood of each log sample
+    """
+    if log_samples.dtype != bool:
+        raise ValueError("log_samples must be boolean array for Bernoulli estimate")
+    if sim_samples.dtype != bool:
+        raise ValueError("sim_samples must be boolean array for Bernoulli estimate")
+
+    return histogram_estimate(
+        log_samples.astype(float),
+        sim_samples.astype(float),
+        min_val=-0.5,
+        max_val=1.5,
+        num_bins=2,
+        additive_smoothing=additive_smoothing,
+    )
+
+
+def log_likelihood_estimate_scenario_level(
+    log_values: np.ndarray,
+    sim_values: np.ndarray,
+    min_val: float,
+    max_val: float,
+    num_bins: int,
+    additive_smoothing: float | None = None,
+    use_bernoulli: bool = False,
+) -> np.ndarray:
+    """Computes log-likelihood estimates for scenario-level features (no time dimension).
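+
+    Each logged value is scored against the empirical distribution formed by that
+    agent's simulated rollouts (Bernoulli for boolean features, histogram otherwise).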
+
+    Args:
+        log_values: Shape (n_agents,)
+        sim_values: Shape (n_agents, n_rollouts)
+        min_val: Minimum value for histogram bins (ignored if use_bernoulli=True)
+        max_val: Maximum value for histogram bins (ignored if use_bernoulli=True)
+        num_bins: Number of histogram bins (ignored if use_bernoulli=True)
+        additive_smoothing: Pseudocount for Laplace smoothing (defaults to 0.001 on the Bernoulli path)
+        use_bernoulli: If True, use Bernoulli estimator for boolean features
+
+    Returns:
+        Shape (n_agents,) - log-likelihood of each log feature
+    """
+    if log_values.ndim != 1:
+        raise ValueError(f"log_values must be 1D, got shape {log_values.shape}")
+    if sim_values.ndim != 2:
+        raise ValueError(f"sim_values must be 2D, got shape {sim_values.shape}")
+
+    log_values_2d = log_values[:, np.newaxis]
+    sim_values_2d = sim_values
+
+    if use_bernoulli:
+        log_likelihood_2d = bernoulli_estimate(
+            log_values_2d.astype(bool),
+            sim_values_2d.astype(bool),
+            additive_smoothing=0.001 if additive_smoothing is None else additive_smoothing,
+        )
+    else:
+        log_likelihood_2d = histogram_estimate(
+            log_values_2d,
+            sim_values_2d,
+            min_val=min_val,
+            max_val=max_val,
+            num_bins=num_bins,
+            additive_smoothing=additive_smoothing,
+        )
+
+    return log_likelihood_2d[:, 0]
+
+
 def _plot_histogram_sanity_check(
     log_samples: np.ndarray,
     sim_samples: np.ndarray,
diff --git a/pufferlib/ocean/benchmark/evaluator.py b/pufferlib/ocean/benchmark/evaluator.py
index 4c50a422a..383384e62 100644
--- a/pufferlib/ocean/benchmark/evaluator.py
+++ b/pufferlib/ocean/benchmark/evaluator.py
@@ -1,11 +1,9 @@
 """WOSAC evaluation class for PufferDrive."""
 
 import torch
-import time
 import numpy as np
 import pandas as pd
-from pprint import pprint
-from typing import Dict, Optional
+from typing import Dict
 import matplotlib.pyplot as plt
 import configparser
 import os
@@ -20,6 +18,11 @@
     "linear_acceleration",
     "angular_speed",
     "angular_acceleration",
+    "distance_to_nearest_object",
+    "time_to_collision",
+    "collision_indication",
+    "distance_to_road_edge",
+    "offroad_indication",
 ]
 
 
@@ -32,6 +35,7 @@ def __init__(self, config: Dict):
         self.init_steps = config.get("eval", {}).get("wosac_init_steps", 0)
         self.sim_steps = self.num_steps - self.init_steps
         self.num_rollouts = config.get("eval", {}).get("wosac_num_rollouts", 32)
+        self.device = config.get("train", {}).get("device", "cuda")
 
         wosac_metrics_path = os.path.join(os.path.dirname(__file__), "wosac.ini")
         self.metrics_config = configparser.ConfigParser()
@@ -121,17 +125,17 @@ def compute_metrics(
         self,
         ground_truth_trajectories: Dict,
         simulated_trajectories: Dict,
+        agent_state: Dict,
+        road_edge_polylines: Dict,
         aggregate_results: bool = False,
     ) -> Dict:
         """Compute realism metrics comparing simulated and ground truth trajectories.
 
         Args:
-            ground_truth_trajectories: Dict with keys ['x', 'y', 'z', 'heading', 'id']
-                Each trajectory has shape (n_agents, n_rollouts, n_steps)
-            simulated_trajectories: Dict with same keys plus 'scenario_id'
-                shape (n_agents, n_steps) for trajectories
-                shape (n_agents,) for id
-                list of length n_agents for scenario_id
+            ground_truth_trajectories: Dict with keys ['x', 'y', 'z', 'heading', 'id', 'scenario_id', 'valid']
+            simulated_trajectories: Dict with keys ['x', 'y', 'z', 'heading', 'id']
+            agent_state: Dict with length and width of agents.
+            road_edge_polylines: Dict with keys ['x', 'y', 'lengths', 'scenario_id']
 
         Note: z-position currently not used.
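For orientation, here is a minimal sketch (editorial, not part of the patch) of how the four arguments of the updated `compute_metrics` fit together. The sizes and zero-filled arrays are made up for illustration; the shapes follow the docstrings above.

```python
import numpy as np

n_agents, n_rollouts, n_steps = 4, 32, 91  # hypothetical sizes

def traj():
    return np.zeros((n_agents, n_rollouts, n_steps))

ground_truth_trajectories = {
    "x": traj(), "y": traj(), "z": traj(), "heading": traj(),
    "valid": np.ones((n_agents, n_rollouts, n_steps), dtype=bool),
    # Negative ids mark agents excluded from evaluation (see eval_mask below).
    "id": np.tile(np.arange(n_agents)[:, None], (1, n_rollouts)),
    "scenario_id": np.zeros((n_agents, 1), dtype=np.int64),
}
simulated_trajectories = {
    "x": traj(), "y": traj(), "z": traj(), "heading": traj(),
    "id": ground_truth_trajectories["id"],
}
agent_state = {"length": np.full(n_agents, 4.5), "width": np.full(n_agents, 2.0)}
road_edge_polylines = {  # two flattened polylines of 50 points each
    "x": np.zeros(100), "y": np.zeros(100),
    "lengths": np.array([50, 50]), "scenario_id": np.array([0, 0]),
}
```

With these in hand, `evaluator.compute_metrics(ground_truth_trajectories, simulated_trajectories, agent_state, road_edge_polylines)` produces the per-agent metrics table assembled in the hunks below.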
@@ -143,6 +147,8 @@ def compute_metrics( "Agent IDs don't match between simulated and ground truth trajectories" ) + eval_mask = ground_truth_trajectories["id"][:, 0] >= 0 + # Extract trajectories sim_x = simulated_trajectories["x"] sim_y = simulated_trajectories["y"] @@ -151,26 +157,84 @@ def compute_metrics( ref_y = ground_truth_trajectories["y"] ref_heading = ground_truth_trajectories["heading"] ref_valid = ground_truth_trajectories["valid"] + agent_length = agent_state["length"] + agent_width = agent_state["width"] + scenario_ids = ground_truth_trajectories["scenario_id"] + + # We evaluate the metrics only for the Tracks to Predict. + eval_sim_x = sim_x[eval_mask] + eval_sim_y = sim_y[eval_mask] + eval_sim_heading = sim_heading[eval_mask] + eval_ref_x = ref_x[eval_mask] + eval_ref_y = ref_y[eval_mask] + eval_ref_heading = ref_heading[eval_mask] + eval_ref_valid = ref_valid[eval_mask] + eval_agent_length = agent_length[eval_mask] + eval_agent_width = agent_width[eval_mask] + eval_scenario_ids = scenario_ids[eval_mask] # Compute features # Kinematics-related features sim_linear_speed, sim_linear_accel, sim_angular_speed, sim_angular_accel = metrics.compute_kinematic_features( - sim_x, sim_y, sim_heading + eval_sim_x, eval_sim_y, eval_sim_heading ) ref_linear_speed, ref_linear_accel, ref_angular_speed, ref_angular_accel = metrics.compute_kinematic_features( - ref_x, ref_y, ref_heading + eval_ref_x, eval_ref_y, eval_ref_heading ) # Get the log speed (linear and angular) validity. Since this is computed by # a delta between steps i-1 and i+1, we verify that both of these are # valid (logical and). - speed_validity, acceleration_validity = metrics.compute_kinematic_validity(ref_valid) + speed_validity, acceleration_validity = metrics.compute_kinematic_validity(ref_valid[eval_mask]) + + # Interaction-related features + sim_signed_distances, sim_collision_per_step, sim_time_to_collision = metrics.compute_interaction_features( + sim_x, sim_y, sim_heading, scenario_ids, agent_length, agent_width, eval_mask, device=self.device + ) + + ref_signed_distances, ref_collision_per_step, ref_time_to_collision = metrics.compute_interaction_features( + ref_x, + ref_y, + ref_heading, + scenario_ids, + agent_length, + agent_width, + eval_mask, + device=self.device, + valid=ref_valid, + ) + + # Map-based features + sim_distance_to_road_edge, sim_offroad_per_step = metrics.compute_map_features( + eval_sim_x, + eval_sim_y, + eval_sim_heading, + eval_scenario_ids, + eval_agent_length, + eval_agent_width, + road_edge_polylines, + device=self.device, + ) + + ref_distance_to_road_edge, ref_offroad_per_step = metrics.compute_map_features( + eval_ref_x, + eval_ref_y, + eval_ref_heading, + eval_scenario_ids, + eval_agent_length, + eval_agent_width, + road_edge_polylines, + device=self.device, + valid=eval_ref_valid, + ) # Compute realism metrics # Average Displacement Error (ADE) and minADE # Note: This metric is not included in the scoring meta-metric, as per WOSAC rules. 
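+        # ADE averages the Euclidean displacement over valid steps; minADE keeps the best rollout per agent.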
- ade, min_ade = metrics.compute_displacement_error(sim_x, sim_y, ref_x, ref_y, ref_valid) + ade, min_ade = metrics.compute_displacement_error( + eval_sim_x, eval_sim_y, eval_ref_x, eval_ref_y, eval_ref_valid + ) # Log-likelihood metrics # Kinematic features log-likelihoods @@ -230,6 +294,49 @@ def compute_metrics( sanity_check=False, ) + min_val, max_val, num_bins, additive_smoothing, independent_timesteps = self._get_histogram_params( + "distance_to_nearest_object" + ) + distance_to_nearest_object_log_likelihood = estimators.log_likelihood_estimate_timeseries( + log_values=ref_signed_distances, + sim_values=sim_signed_distances, + treat_timesteps_independently=independent_timesteps, + min_val=min_val, + max_val=max_val, + num_bins=num_bins, + additive_smoothing=additive_smoothing, + sanity_check=False, + ) + + min_val, max_val, num_bins, additive_smoothing, independent_timesteps = self._get_histogram_params( + "time_to_collision" + ) + time_to_collision_log_likelihood = estimators.log_likelihood_estimate_timeseries( + log_values=ref_time_to_collision, + sim_values=sim_time_to_collision, + treat_timesteps_independently=independent_timesteps, + min_val=min_val, + max_val=max_val, + num_bins=num_bins, + additive_smoothing=additive_smoothing, + sanity_check=False, + ) + + # Map-based features log-likelihoods + min_val, max_val, num_bins, additive_smoothing, independent_timesteps = self._get_histogram_params( + "distance_to_road_edge" + ) + distance_to_road_edge_log_likelihood = estimators.log_likelihood_estimate_timeseries( + log_values=ref_distance_to_road_edge, + sim_values=sim_distance_to_road_edge, + treat_timesteps_independently=independent_timesteps, + min_val=min_val, + max_val=max_val, + num_bins=num_bins, + additive_smoothing=additive_smoothing, + sanity_check=False, + ) + speed_likelihood = np.exp( metrics._reduce_average_with_validity( linear_speed_log_likelihood, @@ -262,20 +369,89 @@ def compute_metrics( ) ) - # Get agent IDs and scenario IDs - agent_ids = ground_truth_trajectories["id"] - scenario_ids = ground_truth_trajectories["scenario_id"] + distance_to_nearest_object_likelihood = np.exp( + metrics._reduce_average_with_validity( + distance_to_nearest_object_log_likelihood, + eval_ref_valid[:, 0, :], + axis=1, + ) + ) + + time_to_collision_likelihood = np.exp( + metrics._reduce_average_with_validity( + time_to_collision_log_likelihood, + eval_ref_valid[:, 0, :], + axis=1, + ) + ) + + distance_to_road_edge_likelihood = np.exp( + metrics._reduce_average_with_validity( + distance_to_road_edge_log_likelihood, + eval_ref_valid[:, 0, :], + axis=1, + ) + ) + + # Collision likelihood is computed by aggregating in time. For invalid objects + # in the logged scenario, we need to filter possible collisions in simulation. + # `sim_collision_indication` shape: (n_samples, n_objects). 
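+        # Here the rollout axis plays the role of the sample axis: the any()
+        # reduction over time (axis=2) yields shape (num_eval_agents, num_rollouts),
+        # and steps where the logged agent is invalid are masked out first.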
+ + sim_collision_indication = np.any(np.where(eval_ref_valid, sim_collision_per_step, False), axis=2) + ref_collision_indication = np.any(np.where(eval_ref_valid, ref_collision_per_step, False), axis=2) + + sim_num_collisions = np.sum(sim_collision_indication, axis=1) + ref_num_collisions = np.sum(ref_collision_indication, axis=1) + + collision_log_likelihood = estimators.log_likelihood_estimate_scenario_level( + log_values=ref_collision_indication[:, 0], + sim_values=sim_collision_indication, + min_val=0.0, + max_val=1.0, + num_bins=2, + use_bernoulli=True, + ) + collision_likelihood = np.exp(collision_log_likelihood) + + # Offroad likelihood (same pattern as collision) + sim_offroad_indication = np.any(np.where(eval_ref_valid, sim_offroad_per_step, False), axis=2) + ref_offroad_indication = np.any(np.where(eval_ref_valid, ref_offroad_per_step, False), axis=2) + + sim_num_offroad = np.sum(sim_offroad_indication, axis=1) + ref_num_offroad = np.sum(ref_offroad_indication, axis=1) + + offroad_log_likelihood = estimators.log_likelihood_estimate_scenario_level( + log_values=ref_offroad_indication[:, 0], + sim_values=sim_offroad_indication, + min_val=0.0, + max_val=1.0, + num_bins=2, + use_bernoulli=True, + ) + offroad_likelihood = np.exp(offroad_log_likelihood) + + # Get agent IDs + eval_agent_ids = ground_truth_trajectories["id"][eval_mask] df = pd.DataFrame( { - "agent_id": agent_ids.flatten(), - "scenario_id": scenario_ids.flatten(), + "agent_id": eval_agent_ids.flatten(), + "scenario_id": eval_scenario_ids.flatten(), + "num_collisions_sim": sim_num_collisions.flatten(), + "num_collisions_ref": ref_num_collisions.flatten(), + "num_offroad_sim": sim_num_offroad.flatten(), + "num_offroad_ref": ref_num_offroad.flatten(), "ade": ade, "min_ade": min_ade, "likelihood_linear_speed": speed_likelihood, "likelihood_linear_acceleration": accel_likelihood, "likelihood_angular_speed": angular_speed_likelihood, "likelihood_angular_acceleration": angular_accel_likelihood, + "likelihood_distance_to_nearest_object": distance_to_nearest_object_likelihood, + "likelihood_time_to_collision": time_to_collision_likelihood, + "likelihood_collision_indication": collision_likelihood, + "likelihood_distance_to_road_edge": distance_to_road_edge_likelihood, + "likelihood_offroad_indication": offroad_likelihood, } ) @@ -283,10 +459,19 @@ def compute_metrics( [ "ade", "min_ade", + "num_collisions_sim", + "num_collisions_ref", + "num_offroad_sim", + "num_offroad_ref", "likelihood_linear_speed", "likelihood_linear_acceleration", "likelihood_angular_speed", "likelihood_angular_acceleration", + "likelihood_distance_to_nearest_object", + "likelihood_time_to_collision", + "likelihood_collision_indication", + "likelihood_distance_to_road_edge", + "likelihood_offroad_indication", ] ].mean() diff --git a/pufferlib/ocean/benchmark/geometry_utils.py b/pufferlib/ocean/benchmark/geometry_utils.py new file mode 100644 index 000000000..14cf9f8ae --- /dev/null +++ b/pufferlib/ocean/benchmark/geometry_utils.py @@ -0,0 +1,225 @@ +"""Geometry utilities for distance computation between 2D boxes. + +Adapted from: +- https://github.com/waymo-research/waymo-open-dataset/blob/master/src/waymo_open_dataset/utils/box_utils.py +- https://github.com/waymo-research/waymo-open-dataset/blob/master/src/waymo_open_dataset/utils/geometry_utils.py +""" + +import torch +from typing import Tuple + +NUM_VERTICES_IN_BOX = 4 + + +def get_yaw_rotation_2d(heading: torch.Tensor) -> torch.Tensor: + """Gets 2D rotation matrices from heading angles. 
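+
+    Each matrix is [[cos h, -sin h], [sin h, cos h]], i.e. a counter-clockwise
+    rotation by the heading h.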
+ + Args: + heading: Rotation angles in radians, any shape + + Returns: + Rotation matrices, shape [..., 2, 2] + """ + cos_heading = torch.cos(heading) + sin_heading = torch.sin(heading) + + return torch.stack( + [torch.stack([cos_heading, -sin_heading], dim=-1), torch.stack([sin_heading, cos_heading], dim=-1)], dim=-2 + ) + + +def cross_product_2d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Computes signed magnitude of cross product of 2D vectors. + + Args: + a: Tensor with shape (..., 2) + b: Tensor with same shape as a + + Returns: + Cross product a[0]*b[1] - a[1]*b[0], shape (...) + """ + return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0] + + +def _get_downmost_edge_in_box(box: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Finds the downmost (lowest y-coordinate) edge in the box. + + Assumes box edges are given in counter-clockwise order. + + Args: + box: Tensor of shape (num_boxes, num_points_per_box, 2) with x-y coordinates + + Returns: + Tuple of: + - downmost_vertex_idx: Index of downmost vertex, shape (num_boxes, 1) + - downmost_edge_direction: Tangent unit vector of downmost edge, shape (num_boxes, 1, 2) + """ + downmost_vertex_idx = torch.argmin(box[..., 1], dim=-1).unsqueeze(-1) + + edge_start_vertex = torch.gather(box, 1, downmost_vertex_idx.unsqueeze(-1).expand(-1, -1, 2)) + edge_end_idx = torch.remainder(downmost_vertex_idx + 1, NUM_VERTICES_IN_BOX) + edge_end_vertex = torch.gather(box, 1, edge_end_idx.unsqueeze(-1).expand(-1, -1, 2)) + + downmost_edge = edge_end_vertex - edge_start_vertex + downmost_edge_length = torch.linalg.norm(downmost_edge, dim=-1) + downmost_edge_direction = downmost_edge / downmost_edge_length.unsqueeze(-1) + + return downmost_vertex_idx, downmost_edge_direction + + +def _get_edge_info(polygon_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes properties about the edges of a polygon. + + Args: + polygon_points: Vertices of each polygon, shape (num_polygons, num_points_per_polygon, 2) + + Returns: + Tuple of: + - tangent_unit_vectors: Shape (num_polygons, num_points_per_polygon, 2) + - normal_unit_vectors: Shape (num_polygons, num_points_per_polygon, 2) + - edge_lengths: Shape (num_polygons, num_points_per_polygon) + """ + first_point_in_polygon = polygon_points[:, 0:1, :] + shifted_polygon_points = torch.cat([polygon_points[:, 1:, :], first_point_in_polygon], dim=-2) + edge_vectors = shifted_polygon_points - polygon_points + + edge_lengths = torch.linalg.norm(edge_vectors, dim=-1) + tangent_unit_vectors = edge_vectors / edge_lengths.unsqueeze(-1) + normal_unit_vectors = torch.stack([-tangent_unit_vectors[..., 1], tangent_unit_vectors[..., 0]], dim=-1) + + return tangent_unit_vectors, normal_unit_vectors, edge_lengths + + +def get_2d_box_corners(boxes: torch.Tensor) -> torch.Tensor: + """Given a set of 2D boxes, return its 4 corners. 
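+
+    Example: a unit square at the origin with heading 0 yields corners
+    (0.5, 0.5), (-0.5, 0.5), (-0.5, -0.5), (0.5, -0.5).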
+ + Args: + boxes: Tensor of shape [..., 5] with [center_x, center_y, length, width, heading] + + Returns: + Corners tensor of shape [..., 4, 2] in counter-clockwise order + """ + center_x = boxes[..., 0] + center_y = boxes[..., 1] + length = boxes[..., 2] + width = boxes[..., 3] + heading = boxes[..., 4] + + rotation = get_yaw_rotation_2d(heading) + translation = torch.stack([center_x, center_y], dim=-1) + + l2 = length * 0.5 + w2 = width * 0.5 + + corners = torch.stack([l2, w2, -l2, w2, -l2, -w2, l2, -w2], dim=-1).reshape(boxes.shape[:-1] + (4, 2)) + + corners = torch.einsum("...ij,...kj->...ki", rotation, corners) + translation.unsqueeze(-2) + + return corners + + +def minkowski_sum_of_box_and_box_points(box1_points: torch.Tensor, box2_points: torch.Tensor) -> torch.Tensor: + """Batched Minkowski sum of two boxes (counter-clockwise corners in xy). + + Args: + box1_points: Vertices for box 1, shape (num_boxes, 4, 2) + box2_points: Vertices for box 2, shape (num_boxes, 4, 2) + + Returns: + Minkowski sum of the two boxes, shape (num_boxes, 8, 2), in counter-clockwise order + """ + device = box1_points.device + point_order_1 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device) + point_order_2 = torch.tensor([0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.int64, device=device) + + box1_start_idx, downmost_box1_edge_direction = _get_downmost_edge_in_box(box1_points) + box2_start_idx, downmost_box2_edge_direction = _get_downmost_edge_in_box(box2_points) + + condition = cross_product_2d(downmost_box1_edge_direction, downmost_box2_edge_direction) >= 0.0 + condition = condition.repeat(1, 8) + + box1_point_order = torch.where(condition, point_order_2, point_order_1) + box1_point_order = torch.remainder(box1_point_order + box1_start_idx, NUM_VERTICES_IN_BOX) + ordered_box1_points = torch.gather(box1_points, 1, box1_point_order.unsqueeze(-1).expand(-1, -1, 2)) + + box2_point_order = torch.where(condition, point_order_1, point_order_2) + box2_point_order = torch.remainder(box2_point_order + box2_start_idx, NUM_VERTICES_IN_BOX) + ordered_box2_points = torch.gather(box2_points, 1, box2_point_order.unsqueeze(-1).expand(-1, -1, 2)) + + minkowski_sum = ordered_box1_points + ordered_box2_points + + return minkowski_sum + + +def dot_product_2d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Computes the dot product of 2D vectors. + + Args: + a: Tensor with shape (..., 2) + b: Tensor with same shape as a + + Returns: + Dot product a[0]*b[0] + a[1]*b[1], shape (...) + """ + return a[..., 0] * b[..., 0] + a[..., 1] * b[..., 1] + + +def rotate_2d_points(xys: torch.Tensor, rotation_yaws: torch.Tensor) -> torch.Tensor: + """Rotates xys counter-clockwise using rotation_yaws. + + Rotates about the origin counter-clockwise in the x-y plane. + + Args: + xys: Tensor with shape (..., 2) containing xy coordinates + rotation_yaws: Tensor with shape (...) containing angles in radians + + Returns: + Rotated xys, shape (..., 2) + """ + rel_cos_yaws = torch.cos(rotation_yaws) + rel_sin_yaws = torch.sin(rotation_yaws) + xs_out = rel_cos_yaws * xys[..., 0] - rel_sin_yaws * xys[..., 1] + ys_out = rel_sin_yaws * xys[..., 0] + rel_cos_yaws * xys[..., 1] + return torch.stack([xs_out, ys_out], dim=-1) + + +def signed_distance_from_point_to_convex_polygon( + query_points: torch.Tensor, polygon_points: torch.Tensor +) -> torch.Tensor: + """Finds signed distances from query points to convex polygons. + + Vertices must be ordered counter-clockwise. 
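+
+    Example: for the unit square centered at the origin, the origin itself gets
+    -0.5, while the outside point (1.5, 0.0) gets +1.0.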
+ + Args: + query_points: Shape (batch_size, 2) with x-y coordinates + polygon_points: Shape (batch_size, num_points_per_polygon, 2) with x-y coordinates + + Returns: + Signed distances, shape (batch_size,). Negative if point is inside polygon. + """ + tangent_unit_vectors, normal_unit_vectors, edge_lengths = _get_edge_info(polygon_points) + + query_points = query_points.unsqueeze(1) + vertices_to_query_vectors = query_points - polygon_points + vertices_distances = torch.linalg.norm(vertices_to_query_vectors, dim=-1) + + edge_signed_perp_distances = torch.sum(-normal_unit_vectors * vertices_to_query_vectors, dim=-1) + + is_inside = torch.all(edge_signed_perp_distances <= 0, dim=-1) + + projection_along_tangent = torch.sum(tangent_unit_vectors * vertices_to_query_vectors, dim=-1) + projection_along_tangent_proportion = projection_along_tangent / edge_lengths + + is_projection_on_edge = (projection_along_tangent_proportion >= 0.0) & (projection_along_tangent_proportion <= 1.0) + + edge_perp_distances = torch.abs(edge_signed_perp_distances) + edge_distances = torch.where( + is_projection_on_edge, edge_perp_distances, torch.tensor(float("inf"), device=query_points.device) + ) + + edge_and_vertex_distance = torch.cat([edge_distances, vertices_distances], dim=-1) + + min_distance = torch.min(edge_and_vertex_distance, dim=-1).values + signed_distances = torch.where(is_inside, -min_distance, min_distance) + + return signed_distances diff --git a/pufferlib/ocean/benchmark/interaction_features.py b/pufferlib/ocean/benchmark/interaction_features.py new file mode 100644 index 000000000..0a1111619 --- /dev/null +++ b/pufferlib/ocean/benchmark/interaction_features.py @@ -0,0 +1,326 @@ +"""Interaction features for the computation of the WOSAC score. +Adapted from: https://github.com/waymo-research/waymo-open-dataset/blob/master/src/waymo_open_dataset/wdl_limited/sim_agents_metrics/interaction_features.py +""" + +import math +import torch + +from pufferlib.ocean.benchmark import geometry_utils + +EXTREMELY_LARGE_DISTANCE = 1e10 +COLLISION_DISTANCE_THRESHOLD = 0.0 +CORNER_ROUNDING_FACTOR = 0.7 +MAX_HEADING_DIFF = math.radians(75.0) +MAX_HEADING_DIFF_FOR_SMALL_OVERLAP = math.radians(10.0) +SMALL_OVERLAP_THRESHOLD = 0.5 +MAXIMUM_TIME_TO_COLLISION = 5.0 + + +def compute_signed_distances( + center_x: torch.Tensor, + center_y: torch.Tensor, + length: torch.Tensor, + width: torch.Tensor, + heading: torch.Tensor, + valid: torch.Tensor, + evaluated_object_mask: torch.Tensor, + corner_rounding_factor: float = CORNER_ROUNDING_FACTOR, +) -> torch.Tensor: + """Computes pairwise signed distances between evaluated objects and all other objects. + + Objects are represented by 2D rectangles with rounded corners. 
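+
+    Rounding is implemented by shrinking each box by a size-dependent margin and
+    then subtracting both margins from the resulting box-to-box signed distances.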
+ + Args: + center_x: Shape (num_agents, num_rollouts, num_steps) + center_y: Shape (num_agents, num_rollouts, num_steps) + length: Shape (num_agents, num_rollouts) - constant per timestep + width: Shape (num_agents, num_rollouts) - constant per timestep + heading: Shape (num_agents, num_rollouts, num_steps) + valid: Shape (num_agents, num_rollouts, num_steps) + corner_rounding_factor: Rounding factor for box corners, between 0 (sharp) and 1 (capsule) + + Returns: + signed_distances: shape (num_eval, num_agents, num_rollouts, num_steps) + """ + + num_agents = center_x.shape[0] + num_rollouts = center_x.shape[1] + num_steps = center_x.shape[2] + + eval_indices = torch.nonzero(evaluated_object_mask, as_tuple=False).squeeze(-1) + num_eval = eval_indices.numel() + + if length.dim() == 2: + length = length.unsqueeze(-1) + if width.dim() == 2: + width = width.unsqueeze(-1) + length = length.expand(num_agents, num_rollouts, num_steps) + width = width.expand(num_agents, num_rollouts, num_steps) + + boxes = torch.stack([center_x, center_y, length, width, heading], dim=-1) + + shrinking_distance = torch.minimum(boxes[..., 2], boxes[..., 3]) * corner_rounding_factor / 2.0 + + shrunk_len = boxes[..., 2:3] - 2.0 * shrinking_distance.unsqueeze(-1) + shrunk_wid = boxes[..., 3:4] - 2.0 * shrinking_distance.unsqueeze(-1) + + boxes = torch.cat( + [ + boxes[..., :2], + shrunk_len, + shrunk_wid, + boxes[..., 4:], + ], + dim=-1, + ) + + boxes_flat = boxes.reshape(num_agents * num_rollouts * num_steps, 5) + box_corners = geometry_utils.get_2d_box_corners(boxes_flat) + box_corners = box_corners.reshape(num_agents, num_rollouts, num_steps, 4, 2) + + eval_corners = box_corners[eval_indices] + + batch_size = num_eval * num_agents * num_rollouts * num_steps + + corners_flat_1 = ( + eval_corners.unsqueeze(1).expand(num_eval, num_agents, num_rollouts, num_steps, 4, 2).reshape(batch_size, 4, 2) + ) + + corners_flat_2 = ( + box_corners.unsqueeze(0).expand(num_eval, num_agents, num_rollouts, num_steps, 4, 2).reshape(batch_size, 4, 2) + ) + + corners_flat_2.neg_() + + minkowski_sum = geometry_utils.minkowski_sum_of_box_and_box_points(corners_flat_1, corners_flat_2) + + del corners_flat_1, corners_flat_2 + + query_points = torch.zeros((batch_size, 2), dtype=center_x.dtype, device=center_x.device) + + signed_distances_flat = geometry_utils.signed_distance_from_point_to_convex_polygon( + query_points=query_points, polygon_points=minkowski_sum + ) + + del minkowski_sum, query_points + + signed_distances = signed_distances_flat.reshape(num_eval, num_agents, num_rollouts, num_steps) + + eval_shrinking = shrinking_distance[eval_indices] + + signed_distances.sub_(eval_shrinking[:, None, :, :]) + signed_distances.sub_(shrinking_distance[None, :, :, :]) + + agent_indices = torch.arange(num_agents, device=center_x.device) + self_mask = eval_indices[:, None] == agent_indices[None, :] + + self_mask = self_mask.unsqueeze(-1).unsqueeze(-1) + + signed_distances.masked_fill_(self_mask, EXTREMELY_LARGE_DISTANCE) + + eval_valid = valid[eval_indices] + + valid_mask = torch.logical_and(eval_valid[:, None, :, :], valid[None, :, :, :]) + + signed_distances.masked_fill_(~valid_mask, EXTREMELY_LARGE_DISTANCE) + + return signed_distances + + +def compute_distance_to_nearest_object( + center_x: torch.Tensor, + center_y: torch.Tensor, + length: torch.Tensor, + width: torch.Tensor, + heading: torch.Tensor, + valid: torch.Tensor, + evaluated_object_mask: torch.Tensor, + corner_rounding_factor: float = CORNER_ROUNDING_FACTOR, +) -> torch.Tensor: + 
signed_distances = compute_signed_distances( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + evaluated_object_mask=evaluated_object_mask, + corner_rounding_factor=corner_rounding_factor, + ) + + min_distances = torch.min(signed_distances, dim=1).values + return min_distances + + +def compute_time_to_collision( + center_x: torch.Tensor, + center_y: torch.Tensor, + length: torch.Tensor, + width: torch.Tensor, + heading: torch.Tensor, + valid: torch.Tensor, + evaluated_object_mask: torch.Tensor, + seconds_per_step: float, +) -> torch.Tensor: + """Computes time-to-collision of the evaluated objects. + + The time-to-collision measures, in seconds, the time until an object collides + with the object it is following, assuming constant speeds. + + Args: + center_x: Shape (num_agents, num_rollouts, num_steps) + center_y: Shape (num_agents, num_rollouts, num_steps) + length: Shape (num_agents, num_rollouts) - constant per timestep + width: Shape (num_agents, num_rollouts) - constant per timestep + heading: Shape (num_agents, num_rollouts, num_steps) + valid: Shape (num_agents, num_rollouts, num_steps) + evaluated_object_mask: Shape (num_agents,) - boolean mask for evaluated agents + seconds_per_step: Duration of one step in seconds + + Returns: + Time-to-collision, shape (num_eval_agents, num_rollouts, num_steps) + """ + from pufferlib.ocean.benchmark import metrics + + valid = valid.to(dtype=torch.bool, device=center_x.device) + evaluated_object_mask = evaluated_object_mask.to(dtype=torch.bool, device=center_x.device) + + num_agents = center_x.shape[0] + num_rollouts = center_x.shape[1] + num_steps = center_x.shape[2] + + eval_indices = torch.nonzero(evaluated_object_mask, as_tuple=False).squeeze(-1) + num_eval = eval_indices.numel() + + # TODO: Convert to torch + speed = metrics.compute_kinematic_features( + x=center_x.cpu().numpy(), + y=center_y.cpu().numpy(), + heading=heading.cpu().numpy(), + seconds_per_step=seconds_per_step, + )[0] + if not isinstance(speed, torch.Tensor): + speed = torch.as_tensor(speed, device=center_x.device, dtype=center_x.dtype) + + if length.dim() == 2: + length = length.unsqueeze(-1) + if width.dim() == 2: + width = width.unsqueeze(-1) + + length = length.expand(num_agents, num_rollouts, num_steps).permute(2, 0, 1) + width = width.expand(num_agents, num_rollouts, num_steps).permute(2, 0, 1) + + center_x = center_x.permute(2, 0, 1) + center_y = center_y.permute(2, 0, 1) + heading = heading.permute(2, 0, 1) + speed = speed.permute(2, 0, 1) + valid = valid.permute(2, 0, 1) + + ego_x = center_x[:, eval_indices] + ego_y = center_y[:, eval_indices] + ego_len = length[:, eval_indices] + ego_wid = width[:, eval_indices] + ego_heading = heading[:, eval_indices] + ego_speed = speed[:, eval_indices] + + yaw_diff = torch.abs(heading.unsqueeze(1) - ego_heading.unsqueeze(2)) + + yaw_diff_cos = torch.cos(yaw_diff) + yaw_diff_sin = torch.sin(yaw_diff) + + all_sizes_half = torch.stack([length, width], dim=-1).unsqueeze(1) / 2.0 + + other_long_offset = geometry_utils.dot_product_2d( + all_sizes_half, + torch.abs(torch.stack([yaw_diff_cos, yaw_diff_sin], dim=-1)), + ) + other_lat_offset = geometry_utils.dot_product_2d( + all_sizes_half, + torch.abs(torch.stack([yaw_diff_sin, yaw_diff_cos], dim=-1)), + ) + + del all_sizes_half + + relative_x = center_x.unsqueeze(1) - ego_x.unsqueeze(2) + relative_y = center_y.unsqueeze(1) - ego_y.unsqueeze(2) + relative_xy = torch.stack([relative_x, relative_y], dim=-1) + + del relative_x, 
relative_y + + rotation = -ego_heading.unsqueeze(2).expand(-1, -1, num_agents, -1) + + other_relative_xy = geometry_utils.rotate_2d_points(relative_xy, rotation) + + del relative_xy, rotation + + long_distance = other_relative_xy[..., 0] - ego_len.unsqueeze(2) / 2.0 - other_long_offset + lat_overlap = torch.abs(other_relative_xy[..., 1]) - ego_wid.unsqueeze(2) / 2.0 - other_lat_offset + + del other_relative_xy, other_long_offset, other_lat_offset + + following_mask = _get_object_following_mask( + long_distance.permute(1, 2, 0, 3), + lat_overlap.permute(1, 2, 0, 3), + yaw_diff.permute(1, 2, 0, 3), + ) + + del lat_overlap, yaw_diff + + valid_mask = torch.logical_and(valid.unsqueeze(1), following_mask.permute(2, 0, 1, 3)) + + del following_mask + + long_distance.masked_fill_(~valid_mask, EXTREMELY_LARGE_DISTANCE) + + box_ahead_index = torch.argmin(long_distance, dim=2, keepdim=True) + distance_to_box_ahead = torch.gather(long_distance, 2, box_ahead_index).squeeze(2) + + del long_distance + + speed_expanded = speed.unsqueeze(1).expand(-1, num_eval, -1, -1) + box_ahead_speed = torch.gather(speed_expanded, 2, box_ahead_index).squeeze(2) + + rel_speed = ego_speed - box_ahead_speed + + rel_speed_safe = torch.where(rel_speed > 0.0, rel_speed, torch.ones_like(rel_speed)) + + max_ttc = torch.full_like(rel_speed, MAXIMUM_TIME_TO_COLLISION) + + time_to_collision = torch.where( + rel_speed > 0.0, + torch.minimum(distance_to_box_ahead / rel_speed_safe, max_ttc), + max_ttc, + ) + + return time_to_collision.permute(1, 2, 0) + + +def _get_object_following_mask( + longitudinal_distance, + lateral_overlap, + yaw_diff, +): + """Checks whether objects satisfy criteria for following another object. + + Args: + longitudinal_distance: Shape (num_agents, num_agents, num_rollouts, num_steps) + Longitudinal distances from back side of each ego box to other boxes. + lateral_overlap: Shape (num_agents, num_agents, num_rollouts, num_steps) + Lateral overlaps of other boxes over trails of ego boxes. + yaw_diff: Shape (num_agents, num_agents, num_rollouts, num_steps) + Absolute yaw differences between egos and other boxes. + + Returns: + Boolean array indicating for each ego box if it is following the other boxes. + Shape (num_agents, num_agents, num_rollouts, num_steps) + """ + valid_mask = longitudinal_distance > 0.0 + valid_mask = torch.logical_and(valid_mask, yaw_diff <= MAX_HEADING_DIFF) + valid_mask = torch.logical_and(valid_mask, lateral_overlap < 0.0) + return torch.logical_and( + valid_mask, + torch.logical_or( + lateral_overlap < -SMALL_OVERLAP_THRESHOLD, + yaw_diff <= MAX_HEADING_DIFF_FOR_SMALL_OVERLAP, + ), + ) diff --git a/pufferlib/ocean/benchmark/kinematic_features.py b/pufferlib/ocean/benchmark/kinematic_features.py new file mode 100644 index 000000000..8ae4883e0 --- /dev/null +++ b/pufferlib/ocean/benchmark/kinematic_features.py @@ -0,0 +1,56 @@ +"""Kinematic feature computation utilities for WOSAC metrics.""" + +import numpy as np + + +def central_diff(t: np.ndarray, pad_value: float) -> np.ndarray: + """Computes the central difference along the last axis. + + This function is used to compute 1st order derivatives (speeds) when called + once. Calling this function twice is used to compute 2nd order derivatives + (accelerations) instead. + This function returns the central difference as + df(x)/dx = [f(x+h)-f(x-h)] / 2h. + + Args: + t: A float array of shape [..., steps]. + pad_value: To maintain the original tensor shape, this value is prepended + once and appended once to the difference. 
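+
+    Example: for t = [0., 1., 4., 9.] and pad_value = np.nan, the result is
+    [nan, 2., 4., nan].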
+ + Returns: + An array of shape [..., steps] containing the central differences, + appropriately prepended and appended with `pad_value` to maintain the + original shape. + """ + pad_shape = (*t.shape[:-1], 1) + pad_array = np.full(pad_shape, pad_value) + diff_t = (t[..., 2:] - t[..., :-2]) / 2 + return np.concatenate([pad_array, diff_t, pad_array], axis=-1) + + +def central_logical_and(t: np.ndarray, pad_value: bool) -> np.ndarray: + """Computes the central `logical_and` along the last axis. + + This function is used to compute the validity tensor for 1st and 2nd order + derivatives using central difference, where element [i] is valid only if + both elements [i-1] and [i+1] are valid. + + Args: + t: A bool array of shape [..., steps]. + pad_value: To maintain the original tensor shape, this value is prepended + once and appended once to the difference. + + Returns: + An array of shape [..., steps] containing the central `logical_and`, + appropriately prepended and appended with `pad_value` to maintain the + original shape. + """ + pad_shape = (*t.shape[:-1], 1) + pad_array = np.full(pad_shape, pad_value) + diff_t = np.logical_and(t[..., 2:], t[..., :-2]) + return np.concatenate([pad_array, diff_t, pad_array], axis=-1) + + +def _wrap_angle(angle: np.ndarray) -> np.ndarray: + """Wraps angles in the range [-pi, pi].""" + return (angle + np.pi) % (2 * np.pi) - np.pi diff --git a/pufferlib/ocean/benchmark/map_metric_features.py b/pufferlib/ocean/benchmark/map_metric_features.py new file mode 100644 index 000000000..10e549490 --- /dev/null +++ b/pufferlib/ocean/benchmark/map_metric_features.py @@ -0,0 +1,273 @@ +"""Map-based metric features for WOSAC evaluation. + +Adapted from Waymo Open Dataset: +https://github.com/waymo-research/waymo-open-dataset/blob/master/src/waymo_open_dataset/wdl_limited/sim_agents_metrics/map_metric_features.py +""" + +import torch + +from pufferlib.ocean.benchmark.geometry_utils import ( + get_2d_box_corners, + cross_product_2d, + dot_product_2d, +) + +EXTREMELY_LARGE_DISTANCE = 1e10 +OFFROAD_DISTANCE_THRESHOLD = 0.0 + + +def compute_distance_to_road_edge( + center_x: torch.Tensor, + center_y: torch.Tensor, + length: torch.Tensor, + width: torch.Tensor, + heading: torch.Tensor, + valid: torch.Tensor, + polyline_x: torch.Tensor, + polyline_y: torch.Tensor, + polyline_lengths: torch.Tensor, +) -> torch.Tensor: + """Computes signed distance to road edge for each agent at each timestep. + + Args: + center_x: Shape (num_agents, num_steps) + center_y: Shape (num_agents, num_steps) + length: Shape (num_agents,) or (num_agents, num_steps) + width: Shape (num_agents,) or (num_agents, num_steps) + heading: Shape (num_agents, num_steps) + valid: Shape (num_agents, num_steps) boolean + polyline_x: Flattened x coordinates of all polyline points + polyline_y: Flattened y coordinates of all polyline points + polyline_lengths: Length of each polyline + + Returns: + Signed distances, shape (num_agents, num_steps). + Negative = on-road, positive = off-road. 
+ """ + num_agents, num_steps = center_x.shape + + if length.ndim == 1: + length = length.unsqueeze(-1).expand(-1, num_steps) + if width.ndim == 1: + width = width.unsqueeze(-1).expand(-1, num_steps) + + boxes = torch.stack([center_x, center_y, length, width, heading], dim=-1) + boxes_flat = boxes.reshape(-1, 5) + + corners = get_2d_box_corners(boxes_flat) + corners = corners.reshape(num_agents, num_steps, 4, 2) + + flat_corners = corners.reshape(-1, 2) + + polylines_padded, polylines_valid = _pad_polylines(polyline_x, polyline_y, polyline_lengths) + + corner_distances = _compute_signed_distance_to_polylines(flat_corners, polylines_padded, polylines_valid) + + corner_distances = corner_distances.reshape(num_agents, num_steps, 4) + signed_distances = torch.max(corner_distances, dim=-1).values + + offroad_fill = signed_distances.new_full((), -EXTREMELY_LARGE_DISTANCE) + signed_distances = torch.where(valid, signed_distances, offroad_fill) + + return signed_distances + + +def _pad_polylines( + polyline_x: torch.Tensor, + polyline_y: torch.Tensor, + polyline_lengths: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Convert flattened polylines to padded tensor format. + + Returns: + polylines: Shape (num_polylines, max_length, 2) + valid: Shape (num_polylines, max_length) + """ + device = polyline_x.device + num_polylines = polyline_lengths.shape[0] + max_length = int(polyline_lengths.max().item()) + + polylines = torch.zeros((num_polylines, max_length, 2), dtype=torch.float32, device=device) + valid = torch.zeros((num_polylines, max_length), dtype=torch.bool, device=device) + + lengths_long = polyline_lengths.to(torch.long) + boundaries = torch.cumsum(torch.cat([lengths_long.new_zeros(1), lengths_long]), dim=0) + + for i in range(num_polylines): + start = int(boundaries[i].item()) + end = int(boundaries[i + 1].item()) + length_i = int(lengths_long[i].item()) + polylines[i, :length_i, 0] = polyline_x[start:end] + polylines[i, :length_i, 1] = polyline_y[start:end] + valid[i, :length_i] = True + + return polylines, valid + + +def _check_polyline_cycles( + polylines: torch.Tensor, + polylines_valid: torch.Tensor, + tolerance: float = 1e-3, +) -> torch.Tensor: + """Check if polylines are cyclic (first point == last point). + + Args: + polylines: Shape (num_polylines, max_length, 2) + polylines_valid: Shape (num_polylines, max_length) + tolerance: Distance threshold for considering points equal + + Returns: + Boolean array of shape (num_polylines,) + """ + device = polylines.device + max_length = polylines.shape[1] + valid_counts = polylines_valid.sum(dim=-1) + has_enough_points = valid_counts >= 2 + + indices = torch.arange(max_length, device=device) + last_idx = torch.argmax(polylines_valid.int() * indices, dim=-1) + + first_pts = polylines[:, 0] + gather_idx = last_idx.view(-1, 1, 1).expand(-1, 1, 2) + last_pts = torch.gather(polylines, 1, gather_idx).squeeze(1) + dist = torch.linalg.norm(first_pts - last_pts, dim=-1) + + return (dist < tolerance) & has_enough_points + + +def _compute_signed_distance_to_polylines( + xys: torch.Tensor, + polylines: torch.Tensor, + polylines_valid: torch.Tensor, +) -> torch.Tensor: + """Computes signed distance from points to polylines (2D). + + Args: + xys: Shape (num_points, 2) + polylines: Shape (num_polylines, max_length, 2) + polylines_valid: Shape (num_polylines, max_length) + + Returns: + Signed distances, shape (num_points,). + Negative = on-road (port side), positive = off-road (starboard). 
+ """ + num_points = xys.shape[0] + num_polylines, max_length = polylines.shape[:2] + num_segments = max_length - 1 + + is_segment_valid = polylines_valid[:, :-1] & polylines_valid[:, 1:] + is_polyline_cyclic = _check_polyline_cycles(polylines, polylines_valid) + + xy_starts = polylines[:, :-1, :] + xy_ends = polylines[:, 1:, :] + start_to_end = xy_ends - xy_starts + + start_to_point = xys.unsqueeze(0).unsqueeze(0) - xy_starts[:, :, None, :] + + dot_se_se = dot_product_2d(start_to_end, start_to_end) + dot_sp_se = dot_product_2d(start_to_point, start_to_end[:, :, None, :]) + + denom = dot_se_se[:, :, None] + rel_t = torch.where( + denom != 0, + dot_sp_se / denom, + torch.zeros_like(dot_sp_se), + ) + + n = torch.sign(cross_product_2d(start_to_point, start_to_end[:, :, None, :])) + + segment_to_point = start_to_point - (start_to_end[:, :, None, :] * torch.clamp(rel_t, 0.0, 1.0)[:, :, :, None]) + distance_to_segment_2d = torch.linalg.norm(segment_to_point, dim=-1) + + start_to_end_padded = torch.cat( + [ + start_to_end[:, -1:, :], + start_to_end, + start_to_end[:, :1, :], + ], + dim=1, + ) + + is_locally_convex = ( + cross_product_2d(start_to_end_padded[:, :-1, None, :], start_to_end_padded[:, 1:, None, :]) > 0.0 + ) + + n_prior = torch.cat( + [ + torch.where( + is_polyline_cyclic[:, None, None], + n[:, -1:, :], + n[:, :1, :], + ), + n[:, :-1, :], + ], + dim=1, + ) + n_next = torch.cat( + [ + n[:, 1:, :], + torch.where( + is_polyline_cyclic[:, None, None], + n[:, :1, :], + n[:, -1:, :], + ), + ], + dim=1, + ) + + is_prior_valid = torch.cat( + [ + torch.where( + is_polyline_cyclic[:, None], + is_segment_valid[:, -1:], + is_segment_valid[:, :1], + ), + is_segment_valid[:, :-1], + ], + dim=1, + ) + is_next_valid = torch.cat( + [ + is_segment_valid[:, 1:], + torch.where( + is_polyline_cyclic[:, None], + is_segment_valid[:, :1], + is_segment_valid[:, -1:], + ), + ], + dim=1, + ) + + sign_if_before = torch.where( + is_locally_convex[:, :-1, :], + torch.maximum(n, n_prior), + torch.minimum(n, n_prior), + ) + sign_if_after = torch.where( + is_locally_convex[:, 1:, :], + torch.maximum(n, n_next), + torch.minimum(n, n_next), + ) + + sign_to_segment = torch.where( + (rel_t < 0.0) & is_prior_valid[:, :, None], + sign_if_before, + torch.where((rel_t > 1.0) & is_next_valid[:, :, None], sign_if_after, n), + ) + + distance_to_segment_2d = distance_to_segment_2d.reshape(num_polylines * num_segments, num_points).T + sign_to_segment = sign_to_segment.reshape(num_polylines * num_segments, num_points).T + + is_segment_valid_flat = is_segment_valid.reshape(num_polylines * num_segments) + valid_mask = is_segment_valid_flat.unsqueeze(0).expand(num_points, -1) + distance_to_segment_2d = distance_to_segment_2d.masked_fill( + ~valid_mask, + EXTREMELY_LARGE_DISTANCE, + ) + + closest_idx = torch.argmin(distance_to_segment_2d, dim=1) + point_indices = torch.arange(num_points, device=xys.device) + distance_2d = distance_to_segment_2d[point_indices, closest_idx] + distance_sign = sign_to_segment[point_indices, closest_idx] + + return distance_sign * distance_2d diff --git a/pufferlib/ocean/benchmark/metrics.py b/pufferlib/ocean/benchmark/metrics.py index bc1f049e0..cf1af358e 100644 --- a/pufferlib/ocean/benchmark/metrics.py +++ b/pufferlib/ocean/benchmark/metrics.py @@ -3,7 +3,23 @@ """ import numpy as np -from typing import Dict, Tuple +import torch +from typing import Tuple + +from pufferlib.ocean.benchmark import kinematic_features, interaction_features, map_metric_features + + +def _to_tensor(value, dtype, device=None): 
+ """Utility to convert numpy inputs to tensors on the requested device.""" + if isinstance(value, torch.Tensor): + tensor = value + else: + tensor = torch.as_tensor(value, dtype=dtype) + if dtype is not None and tensor.dtype != dtype: + tensor = tensor.to(dtype) + if device is not None and tensor.device != device: + tensor = tensor.to(device) + return tensor def compute_displacement_error( @@ -47,56 +63,6 @@ def compute_displacement_error( return ade, min_ade -def central_diff(t: np.ndarray, pad_value: float) -> np.ndarray: - """Computes the central difference along the last axis. - - This function is used to compute 1st order derivatives (speeds) when called - once. Calling this function twice is used to compute 2nd order derivatives - (accelerations) instead. - This function returns the central difference as - df(x)/dx = [f(x+h)-f(x-h)] / 2h. - - Args: - t: A float array of shape [..., steps]. - pad_value: To maintain the original tensor shape, this value is prepended - once and appended once to the difference. - - Returns: - An array of shape [..., steps] containing the central differences, - appropriately prepended and appended with `pad_value` to maintain the - original shape. - """ - # Prepare the array containing the value(s) to pad the result with. - pad_shape = (*t.shape[:-1], 1) - pad_array = np.full(pad_shape, pad_value) - diff_t = (t[..., 2:] - t[..., :-2]) / 2 - return np.concatenate([pad_array, diff_t, pad_array], axis=-1) - - -def central_logical_and(t: np.ndarray, pad_value: bool) -> np.ndarray: - """Computes the central `logical_and` along the last axis. - - This function is used to compute the validity tensor for 1st and 2nd order - derivatives using central difference, where element [i] is valid only if - both elements [i-1] and [i+1] are valid. - - Args: - t: A bool array of shape [..., steps]. - pad_value: To maintain the original tensor shape, this value is prepended - once and appended once to the difference. - - Returns: - An array of shape [..., steps] containing the central `logical_and`, - appropriately prepended and appended with `pad_value` to maintain the - original shape. - """ - # Prepare the array containing the value(s) to pad the result with. - pad_shape = (*t.shape[:-1], 1) - pad_array = np.full(pad_shape, pad_value) - diff_t = np.logical_and(t[..., 2:], t[..., :-2]) - return np.concatenate([pad_array, diff_t, pad_array], axis=-1) - - def compute_displacement_error_3d( x: np.ndarray, y: np.ndarray, z: np.ndarray, ref_x: np.ndarray, ref_y: np.ndarray, ref_z: np.ndarray ) -> np.ndarray: @@ -120,6 +86,30 @@ def compute_displacement_error_3d( return np.linalg.norm(np.stack([x, y], axis=-1) - np.stack([ref_x, ref_y], axis=-1), ord=2, axis=-1) +def _reduce_average_with_validity(tensor: np.ndarray, validity: np.ndarray, axis: int = None) -> np.ndarray: + """Returns the tensor's average, only selecting valid items. + + Args: + tensor: A float array of any shape. + validity: A boolean array of the same shape as `tensor`. + axis: The axis or axes along which to average. If None, averages over all axes. + + Returns: + A float or array containing the average of the valid elements of `tensor`. + """ + if tensor.shape != validity.shape: + raise ValueError( + f"Shapes of `tensor` and `validity` must be the same. (Actual: {tensor.shape}, {validity.shape})." 
+ ) + cond_sum = np.sum(np.where(validity, tensor, np.zeros_like(tensor)), axis=axis, keepdims=False) + valid_sum = np.sum(validity.astype(np.float32), axis=axis, keepdims=False) + + # Safe division: + safe_valid_sum = np.where(valid_sum == 0, 1, valid_sum) + + return np.where(valid_sum == 0, np.nan, cond_sum / safe_valid_sum) + + def compute_kinematic_features( x: np.ndarray, y: np.ndarray, @@ -148,21 +138,16 @@ def compute_kinematic_features( angular_acceleration: Angular acceleration (changes in `angular_speed`). Shape (..., num_steps). """ - # Linear speed and acceleration. - dpos = central_diff(np.stack([x, y], axis=0), pad_value=np.nan) + dpos = kinematic_features.central_diff(np.stack([x, y], axis=0), pad_value=np.nan) linear_speed = np.linalg.norm(dpos, ord=2, axis=0) / seconds_per_step - linear_accel = central_diff(linear_speed, pad_value=np.nan) / seconds_per_step - # Angular speed and acceleration. - dh_step = _wrap_angle(central_diff(heading, pad_value=np.nan) * 2) / 2 + linear_accel = kinematic_features.central_diff(linear_speed, pad_value=np.nan) / seconds_per_step + + dh_step = kinematic_features._wrap_angle(kinematic_features.central_diff(heading, pad_value=np.nan) * 2) / 2 dh = dh_step / seconds_per_step - d2h_step = _wrap_angle(central_diff(dh_step, pad_value=np.nan) * 2) / 2 + d2h_step = kinematic_features._wrap_angle(kinematic_features.central_diff(dh_step, pad_value=np.nan) * 2) / 2 d2h = d2h_step / (seconds_per_step**2) - return linear_speed, linear_accel, dh, d2h - -def _wrap_angle(angle: np.ndarray) -> np.ndarray: - """Wraps angles in the range [-pi, pi].""" - return (angle + np.pi) % (2 * np.pi) - np.pi + return linear_speed, linear_accel, dh, d2h def compute_kinematic_validity(valid: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: @@ -180,12 +165,10 @@ def compute_kinematic_validity(valid: np.ndarray) -> Tuple[np.ndarray, np.ndarra speed_validity: A validity array for speed fields (central_and applied once). acceleration_validity: A validity array for acceleration fields (central_and applied twice). """ - # First application for speeds pad_shape = (*valid.shape[:-1], 1) pad_tensor = np.full(pad_shape, False) speed_validity = np.concatenate([pad_tensor, np.logical_and(valid[..., 2:], valid[..., :-2]), pad_tensor], axis=-1) - # Second application for accelerations pad_tensor = np.full(pad_shape, False) acceleration_validity = np.concatenate( [pad_tensor, np.logical_and(speed_validity[..., 2:], speed_validity[..., :-2]), pad_tensor], axis=-1 @@ -194,21 +177,222 @@ def compute_kinematic_validity(valid: np.ndarray) -> Tuple[np.ndarray, np.ndarra return speed_validity, acceleration_validity -def _reduce_average_with_validity(tensor: np.ndarray, validity: np.ndarray, axis: int = None) -> np.ndarray: - """Returns the tensor's average, only selecting valid items. +def compute_interaction_features( + x: np.ndarray, + y: np.ndarray, + heading: np.ndarray, + scenario_ids: np.ndarray, + agent_length: np.ndarray, + agent_width: np.ndarray, + eval_mask: np.ndarray, + device: torch.device, + valid: np.ndarray | None = None, + corner_rounding_factor: float = 0.7, + seconds_per_step: float = 0.1, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Computes distance to nearest object for each agent, grouped by scenario. Args: - tensor: A float array of any shape. - validity: A boolean array of the same shape as `tensor`. - axis: The axis or axes along which to average. If None, averages over all axes. 
+ x: Shape (num_agents, num_rollouts, num_steps) + y: Shape (num_agents, num_rollouts, num_steps) + heading: Shape (num_agents, num_rollouts, num_steps) + scenario_ids: Shape (num_agents, 1) + agent_length: Shape (num_agents,) + agent_width: Shape (num_agents,) + eval_mask: Shape (num_agents,) - boolean mask for evaluated agents + valid: Shape (num_agents, num_rollouts, num_steps), optional Returns: - A float or array containing the average of the valid elements of `tensor`. + Tuple of: + - Distance to nearest object, shape (num_eval_agents, num_rollouts, num_steps) + - Collision indicator per step, shape (num_eval_agents, num_rollouts, num_steps) + - Time to collision, shape (num_eval_agents, num_rollouts, num_steps) """ - if tensor.shape != validity.shape: - raise ValueError( - f"Shapes of `tensor` and `validity` must be the same. (Actual: {tensor.shape}, {validity.shape})." + x_t = _to_tensor(x, torch.float32, device=device) + y_t = _to_tensor(y, torch.float32, device=device) + heading_t = _to_tensor(heading, torch.float32, device=device) + agent_length_t = _to_tensor(agent_length, torch.float32, device=device) + agent_width_t = _to_tensor(agent_width, torch.float32, device=device) + + num_agents = x_t.shape[0] + num_eval_agents = int(np.sum(eval_mask)) + num_rollouts = x_t.shape[1] + num_steps = x_t.shape[2] + + if valid is None: + valid_t = torch.ones((num_agents, num_rollouts, num_steps), dtype=torch.bool, device=x_t.device) + else: + valid_t = _to_tensor(valid, torch.bool, device=x_t.device) + + length_broadcast = agent_length_t.unsqueeze(-1).expand(num_agents, num_rollouts) + width_broadcast = agent_width_t.unsqueeze(-1).expand(num_agents, num_rollouts) + + result_distances = np.full( + (num_eval_agents, num_rollouts, num_steps), interaction_features.EXTREMELY_LARGE_DISTANCE, dtype=np.float32 + ) + result_collisions = np.full((num_eval_agents, num_rollouts, num_steps), False, dtype=bool) + result_ttc = np.full( + (num_eval_agents, num_rollouts, num_steps), interaction_features.MAXIMUM_TIME_TO_COLLISION, dtype=np.float32 + ) + + unique_scenarios = np.unique(scenario_ids) + + eval_indices = np.where(eval_mask)[0] + eval_to_result = {idx: i for i, idx in enumerate(eval_indices)} + + for scenario_id in unique_scenarios: + scenario_mask_np = scenario_ids[:, 0] == scenario_id + agent_indices = np.where(scenario_mask_np)[0] + if agent_indices.size == 0: + continue + + scenario_mask = torch.as_tensor(scenario_mask_np, dtype=torch.bool, device=x_t.device) + scenario_x = x_t[scenario_mask] + scenario_y = y_t[scenario_mask] + scenario_length = length_broadcast[scenario_mask] + scenario_width = width_broadcast[scenario_mask] + scenario_heading = heading_t[scenario_mask] + scenario_valid = valid_t[scenario_mask] + + scenario_eval_mask_np = eval_mask[scenario_mask_np] + scenario_eval_mask = torch.as_tensor(scenario_eval_mask_np, dtype=torch.bool, device=x_t.device) + + distances_to_objects = interaction_features.compute_distance_to_nearest_object( + center_x=scenario_x, + center_y=scenario_y, + length=scenario_length, + width=scenario_width, + heading=scenario_heading, + valid=scenario_valid, + corner_rounding_factor=corner_rounding_factor, + evaluated_object_mask=scenario_eval_mask, ) - cond_sum = np.sum(np.where(validity, tensor, np.zeros_like(tensor)), axis=axis, keepdims=False) - valid_sum = np.sum(validity.astype(np.float32), axis=axis, keepdims=False) - return cond_sum / valid_sum + + is_colliding_per_step = distances_to_objects < interaction_features.COLLISION_DISTANCE_THRESHOLD + + 
times_to_collision = interaction_features.compute_time_to_collision( + center_x=scenario_x, + center_y=scenario_y, + length=scenario_length, + width=scenario_width, + heading=scenario_heading, + valid=scenario_valid, + seconds_per_step=seconds_per_step, + evaluated_object_mask=scenario_eval_mask, + ) + + eval_agents_in_scenario = agent_indices[scenario_eval_mask_np] + result_indices = [eval_to_result[idx] for idx in eval_agents_in_scenario] + + distances_np = distances_to_objects.cpu().numpy() + collisions_np = is_colliding_per_step.cpu().numpy() + ttc_np = times_to_collision.cpu().numpy() + + result_distances[result_indices] = distances_np + result_collisions[result_indices] = collisions_np + result_ttc[result_indices] = ttc_np + + return result_distances, result_collisions, result_ttc + + +def compute_map_features( + x: np.ndarray, + y: np.ndarray, + heading: np.ndarray, + scenario_ids: np.ndarray, + agent_length: np.ndarray, + agent_width: np.ndarray, + road_edge_polylines: dict, + device: torch.device, + valid: np.ndarray | None = None, +) -> Tuple[np.ndarray, np.ndarray]: + """Computes distance to road edge and offroad indication for each agent. + + Args: + x: Shape (num_agents, num_rollouts, num_steps) + y: Shape (num_agents, num_rollouts, num_steps) + heading: Shape (num_agents, num_rollouts, num_steps) + scenario_ids: Shape (num_agents, 1) + agent_length: Shape (num_agents,) + agent_width: Shape (num_agents,) + road_edge_polylines: Dictionary with polyline data + valid: Shape (num_agents, num_rollouts, num_steps), optional + + Returns: + Tuple of: + - Distance to road edge, shape (num_agents, num_rollouts, num_steps) + - Offroad indication per step, shape (num_agents, num_rollouts, num_steps) + """ + x_t = _to_tensor(x, torch.float32, device=device) + y_t = _to_tensor(y, torch.float32, device=device) + heading_t = _to_tensor(heading, torch.float32, device=device) + agent_length_t = _to_tensor(agent_length, torch.float32, device=device) + agent_width_t = _to_tensor(agent_width, torch.float32, device=device) + num_agents = x_t.shape[0] + num_rollouts = x_t.shape[1] + num_steps = x_t.shape[2] + + if valid is None: + valid_t = torch.ones((num_agents, num_rollouts, num_steps), dtype=torch.bool, device=device) + else: + valid_t = _to_tensor(valid, torch.bool, device=device) + + result_distances = np.zeros((num_agents, num_rollouts, num_steps), dtype=np.float32) + result_offroad = np.zeros((num_agents, num_rollouts, num_steps), dtype=bool) + + unique_scenarios = np.unique(scenario_ids) + + polyline_boundaries = np.cumsum(np.concatenate([[0], road_edge_polylines["lengths"]])) + + for scenario_id in unique_scenarios: + agent_mask_np = scenario_ids[:, 0] == scenario_id + agent_indices = np.where(agent_mask_np)[0] + + if len(agent_indices) == 0: + continue + + polyline_mask = road_edge_polylines["scenario_id"] == scenario_id + polyline_indices = np.where(polyline_mask)[0] + + scenario_lengths = road_edge_polylines["lengths"][polyline_mask] + + scenario_x_list = [] + scenario_y_list = [] + for idx in polyline_indices: + start = polyline_boundaries[idx] + end = polyline_boundaries[idx + 1] + scenario_x_list.append(road_edge_polylines["x"][start:end]) + scenario_y_list.append(road_edge_polylines["y"][start:end]) + + scenario_polyline_x = torch.as_tensor(np.concatenate(scenario_x_list), dtype=torch.float32, device=x_t.device) + scenario_polyline_y = torch.as_tensor(np.concatenate(scenario_y_list), dtype=torch.float32, device=x_t.device) + scenario_lengths_t = torch.as_tensor(scenario_lengths, 
dtype=torch.int64, device=x_t.device) + + agent_mask = torch.as_tensor(agent_mask_np, dtype=torch.bool, device=x_t.device) + scenario_x = x_t[agent_mask] + scenario_y = y_t[agent_mask] + scenario_heading = heading_t[agent_mask] + scenario_valid = valid_t[agent_mask] + scenario_length = agent_length_t[agent_mask] + scenario_width = agent_width_t[agent_mask] + + for rollout_idx in range(num_rollouts): + distances = map_metric_features.compute_distance_to_road_edge( + center_x=scenario_x[:, rollout_idx, :], + center_y=scenario_y[:, rollout_idx, :], + length=scenario_length, + width=scenario_width, + heading=scenario_heading[:, rollout_idx, :], + valid=scenario_valid[:, rollout_idx, :], + polyline_x=scenario_polyline_x, + polyline_y=scenario_polyline_y, + polyline_lengths=scenario_lengths_t, + ) + + distances_np = distances.cpu().numpy() + result_distances[agent_mask_np, rollout_idx, :] = distances_np + result_offroad[agent_mask_np, rollout_idx, :] = ( + distances_np > map_metric_features.OFFROAD_DISTANCE_THRESHOLD + ) + + return result_distances, result_offroad diff --git a/pufferlib/ocean/benchmark/metrics_sanity_check.py b/pufferlib/ocean/benchmark/metrics_sanity_check.py index bb0acfbeb..dcda26b69 100644 --- a/pufferlib/ocean/benchmark/metrics_sanity_check.py +++ b/pufferlib/ocean/benchmark/metrics_sanity_check.py @@ -31,11 +31,13 @@ def run_validation_experiment(config, vecenv, policy): gt_trajectories = evaluator.collect_ground_truth_trajectories(vecenv) simulated_trajectories = evaluator.collect_simulated_trajectories(config, vecenv, policy) + agent_state = vecenv.driver_env.get_global_agent_state() + road_edge_polylines = vecenv.driver_env.get_road_edge_polylines() results = {} for num_gt in [0, 1, 2, 8, 16, 32]: modified_sim = replace_rollouts_with_gt(simulated_trajectories, gt_trajectories, num_gt) - scene_results = evaluator.compute_metrics(gt_trajectories, modified_sim) + scene_results = evaluator.compute_metrics(gt_trajectories, modified_sim, agent_state, road_edge_polylines) results[num_gt] = { "ade": scene_results["ade"].mean(), @@ -44,7 +46,12 @@ def run_validation_experiment(config, vecenv, policy): "likelihood_linear_acceleration": scene_results["likelihood_linear_acceleration"].mean(), "likelihood_angular_speed": scene_results["likelihood_angular_speed"].mean(), "likelihood_angular_acceleration": scene_results["likelihood_angular_acceleration"].mean(), - "realism_metametric": scene_results["realism_metametric"].mean(), + "likelihood_distance_to_nearest_object": scene_results["likelihood_distance_to_nearest_object"].mean(), + "likelihood_time_to_collision": scene_results["likelihood_time_to_collision"].mean(), + "likelihood_collision_indication": scene_results["likelihood_collision_indication"].mean(), + "likelihood_distance_to_road_edge": scene_results["likelihood_distance_to_road_edge"].mean(), + "likelihood_offroad_indication": scene_results["likelihood_offroad_indication"].mean(), + "realism_meta_score": scene_results["realism_meta_score"].mean(), } return results @@ -53,8 +60,8 @@ def run_validation_experiment(config, vecenv, policy): def format_results_table(results): lines = [ "## WOSAC Log-Likelihood Validation Results\n", - "| GT Rollouts | ADE | minADE | Linear Speed | Linear Accel | Angular Speed | Angular Accel | Realism Metametric |", - "|-------------|--------|--------|--------------|--------------|---------------|---------------|--------------------|\n", + "| GT Rollouts | ADE | minADE | Linear Speed | Linear Accel | Angular Speed | Angular Accel | Dist Obj | 
TTC | Collision | Dist Road | Offroad | Metametric |", + "|-------------|--------|--------|--------------|--------------|---------------|---------------|----------|--------|-----------|-----------|---------|------------|\n", ] for num_gt in sorted(results.keys()): @@ -63,7 +70,9 @@ def format_results_table(results): lines.append( f"| {label:11s} | {r['ade']:6.4f} | {r['min_ade']:6.4f} | {r['likelihood_linear_speed']:12.4f} | " f"{r['likelihood_linear_acceleration']:12.4f} | {r['likelihood_angular_speed']:13.4f} | " - f"{r['likelihood_angular_acceleration']:13.4f} | {r['realism_metametric']:18.4f} |" + f"{r['likelihood_angular_acceleration']:13.4f} | {r['likelihood_distance_to_nearest_object']:8.4f} | " + f"{r['likelihood_time_to_collision']:6.4f} | {r['likelihood_collision_indication']:9.4f} | " + f"{r['likelihood_distance_to_road_edge']:9.4f} | {r['likelihood_offroad_indication']:7.4f} | {r['realism_meta_score']:10.4f} |" ) return "\n".join(lines) diff --git a/pufferlib/ocean/benchmark/test_boxes.png b/pufferlib/ocean/benchmark/test_boxes.png new file mode 100644 index 000000000..d59320420 Binary files /dev/null and b/pufferlib/ocean/benchmark/test_boxes.png differ diff --git a/pufferlib/ocean/benchmark/test_geometry.py b/pufferlib/ocean/benchmark/test_geometry.py new file mode 100644 index 000000000..d2b691815 --- /dev/null +++ b/pufferlib/ocean/benchmark/test_geometry.py @@ -0,0 +1,316 @@ +import numpy as np +import torch + +from pufferlib.ocean.benchmark import interaction_features + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def _tensor(value, *, dtype=torch.float32): + return torch.tensor(value, dtype=dtype, device=DEVICE) + + +def test_box_distance_calculations(): + """Test with manually designed box configurations. 
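+
+    Coordinates are in meters and headings in radians.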
+ + Box 0: x=2, y=1, heading=0, length=2, width=1 + Box 1: x=4.5, y=2.5, heading=0, l=2, w=1 + Box 2: x=1.5, y=2, heading=pi/2, l=2, w=1 + Box 3: x=3.5, y=0, heading=pi/4, l=sqrt(2), w=sqrt(2) + + Expected distances from Box 0 to others: + - Box 0 to Box 0: EXTREMELY_LARGE_DISTANCE (self-distance masked) + - Box 0 to Box 1: 1/sqrt(2) ≈ 0.707 + - Box 0 to Box 2: -0.5 (overlapping) + - Box 0 to Box 3: 0.0 (touching) + """ + + center_x_np = np.array([[[2.0]], [[4.5]], [[1.5]], [[3.5]]]) + center_y_np = np.array([[[1.0]], [[2.5]], [[2.0]], [[0.0]]]) + length_np = np.array([[2.0], [2.0], [2.0], [np.sqrt(2)]]) + width_np = np.array([[1.0], [1.0], [1.0], [np.sqrt(2)]]) + heading_np = np.array([[[0.0]], [[0.0]], [[np.pi / 2]], [[np.pi / 4]]]) + valid_np = np.ones((4, 1, 1), dtype=bool) + + print("Test: Handpicked box configurations") + print("=" * 60) + print(f"\nArray shapes:") + print(f" center_x: {center_x_np.shape} = (num_agents=4, num_rollouts=1, num_timesteps=1)") + print(f" length: {length_np.shape} = (num_agents=4, num_rollouts=1)") + print(f" This test uses: 4 agents, 1 rollout, 1 timestep") + + center_x = _tensor(center_x_np) + center_y = _tensor(center_y_np) + length = _tensor(length_np) + width = _tensor(width_np) + heading = _tensor(heading_np) + valid = _tensor(valid_np, dtype=torch.bool) + eval_mask = _tensor(np.ones(4, dtype=bool), dtype=torch.bool) + + signed_distances = interaction_features.compute_signed_distances( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + evaluated_object_mask=eval_mask, + corner_rounding_factor=0.0, + ) + signed_distances = signed_distances.cpu().numpy() + + expected_distances = np.array([interaction_features.EXTREMELY_LARGE_DISTANCE, 1.0 / np.sqrt(2), -0.5, 0.0]) + + print("\nBox configurations:") + print(" Box 0: center=(2.0, 1.0), heading=0°, length=2.0, width=1.0") + print(" Box 1: center=(4.5, 2.5), heading=0°, length=2.0, width=1.0") + print(" Box 2: center=(1.5, 2.0), heading=90°, length=2.0, width=1.0") + print(f" Box 3: center=(3.5, 0.0), heading=45°, length={np.sqrt(2):.3f}, width={np.sqrt(2):.3f}") + + print("\nDistances from Box 0:") + print(f" To itself (Box 0): {signed_distances[0, 0, 0, 0]:.6e} (expected: {expected_distances[0]:.6e})") + print(f" To Box 1: {signed_distances[0, 1, 0, 0]:.6f} (expected: {expected_distances[1]:.6f})") + print(f" To Box 2: {signed_distances[0, 2, 0, 0]:.6f} (expected: {expected_distances[2]:.6f})") + print(f" To Box 3: {signed_distances[0, 3, 0, 0]:.6f} (expected: {expected_distances[3]:.6f})") + + atol = 0.01 + assert signed_distances[0, 0, 0, 0] >= interaction_features.EXTREMELY_LARGE_DISTANCE - 1, ( + f"Self-distance should be EXTREMELY_LARGE_DISTANCE, got {signed_distances[0, 0, 0, 0]}" + ) + + assert np.abs(signed_distances[0, 1, 0, 0] - expected_distances[1]) < atol, ( + f"Distance to Box 1 should be {expected_distances[1]:.6f}, got {signed_distances[0, 1, 0, 0]:.6f}" + ) + + assert np.abs(signed_distances[0, 2, 0, 0] - expected_distances[2]) < atol, ( + f"Distance to Box 2 should be {expected_distances[2]:.6f}, got {signed_distances[0, 2, 0, 0]:.6f}" + ) + + assert np.abs(signed_distances[0, 3, 0, 0] - expected_distances[3]) < atol, ( + f"Distance to Box 3 should be {expected_distances[3]:.6f}, got {signed_distances[0, 3, 0, 0]:.6f}" + ) + + print(" ✓ Test passed!") + + +def test_invalid_objects(): + """Test invalid objects using handpicked box 1, but marked invalid.""" + + center_x_np = np.array([[[2.0]], [[4.5]]]) + center_y_np = 
np.array([[[1.0]], [[2.5]]]) + length_np = np.array([[2.0], [2.0]]) + width_np = np.array([[1.0], [1.0]]) + heading_np = np.array([[[0.0]], [[0.0]]]) + valid_np = np.array([[True], [False]], dtype=bool)[:, :, np.newaxis] + + print("\nTest: Invalid objects") + print(f"Array shapes: {center_x_np.shape} = (num_agents=2, num_rollouts=1, num_timesteps=1)") + print("This test uses: 2 agents (Box 0 and Box 1), 1 rollout, 1 timestep") + print("Box 1 is marked as invalid") + + center_x = _tensor(center_x_np) + center_y = _tensor(center_y_np) + length = _tensor(length_np) + width = _tensor(width_np) + heading = _tensor(heading_np) + valid = _tensor(valid_np, dtype=torch.bool) + eval_mask = _tensor(np.ones(2, dtype=bool), dtype=torch.bool) + + distances = interaction_features.compute_distance_to_nearest_object( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + evaluated_object_mask=eval_mask, + ) + distances = distances.cpu().numpy() + print(f" Agent 0 distance (box 1 invalid): {distances[0, 0, 0]}") + print(f" Expected: {interaction_features.EXTREMELY_LARGE_DISTANCE}") + + assert distances[0, 0, 0] >= interaction_features.EXTREMELY_LARGE_DISTANCE - 1, ( + f"Invalid object not handled correctly! Got {distances[0, 0, 0]}" + ) + + print(" ✓ Test passed!") + + +def test_multiple_rollouts(): + """Test with handpicked boxes across multiple rollouts.""" + + center_x_np = np.array( + [ + [[2.0], [2.0]], + [[4.5], [5.0]], + ] + ) + center_y_np = np.array( + [ + [[1.0], [1.0]], + [[2.5], [3.0]], + ] + ) + length_np = np.ones((2, 2)) * 2.0 + width_np = np.ones((2, 2)) * 1.0 + heading_np = np.zeros_like(center_x_np) + valid_np = np.ones((2, 2, 1), dtype=bool) + + print("\nTest: Multiple rollouts") + print(f"Array shapes: {center_x_np.shape} = (num_agents=2, num_rollouts=2, num_timesteps=1)") + print("This test uses: 2 agents (Box 0 and Box 1), 2 rollouts, 1 timestep") + print("Rollout 0: Box 1 at (4.5, 2.5) - closer to Box 0") + print("Rollout 1: Box 1 at (5.0, 3.0) - further from Box 0") + + center_x = _tensor(center_x_np) + center_y = _tensor(center_y_np) + length = _tensor(length_np) + width = _tensor(width_np) + heading = _tensor(heading_np) + valid = _tensor(valid_np, dtype=torch.bool) + eval_mask = _tensor(np.ones(2, dtype=bool), dtype=torch.bool) + + distances = interaction_features.compute_distance_to_nearest_object( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + evaluated_object_mask=eval_mask, + corner_rounding_factor=0.0, + ) + distances = distances.cpu().numpy() + print(f" Agent 0, rollout 0: {distances[0, 0, 0]:.2f}m") + print(f" Agent 0, rollout 1: {distances[0, 1, 0]:.2f}m") + print(" Distances should vary across rollouts") + + assert distances[0, 0, 0] != distances[0, 1, 0], "Distances should differ across rollouts" + + print(" ✓ Test passed!") + + +def test_multiple_timesteps(): + """Test with handpicked boxes across multiple timesteps.""" + + center_x_np = np.array( + [ + [[2.0, 2.0, 2.0]], + [[4.5, 5.0, 6.0]], + ] + ) + center_y_np = np.array( + [ + [[1.0, 1.0, 1.0]], + [[2.5, 2.5, 2.5]], + ] + ) + length_np = np.ones((2, 1)) * 2.0 + width_np = np.ones((2, 1)) * 1.0 + heading_np = np.zeros_like(center_x_np) + valid_np = np.ones((2, 1, 3), dtype=bool) + + print("\nTest: Multiple timesteps") + print(f"Array shapes: {center_x_np.shape} = (num_agents=2, num_rollouts=1, num_timesteps=3)") + print("This test uses: 2 agents (Box 0 and Box 1), 1 rollout, 3 timesteps") + 
print("Box 0 stays at (2.0, 1.0) across all timesteps") + print("Box 1 moves away: t=0 at (4.5, 2.5), t=1 at (5.0, 2.5), t=2 at (6.0, 2.5)") + + center_x = _tensor(center_x_np) + center_y = _tensor(center_y_np) + length = _tensor(length_np) + width = _tensor(width_np) + heading = _tensor(heading_np) + valid = _tensor(valid_np, dtype=torch.bool) + eval_mask = _tensor(np.ones(2, dtype=bool), dtype=torch.bool) + + distances = interaction_features.compute_distance_to_nearest_object( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + evaluated_object_mask=eval_mask, + corner_rounding_factor=0.0, + ) + distances = distances.cpu().numpy() + + print(f" Agent 0, timestep 0: {distances[0, 0, 0]:.2f}m") + print(f" Agent 0, timestep 1: {distances[0, 0, 1]:.2f}m") + print(f" Agent 0, timestep 2: {distances[0, 0, 2]:.2f}m") + print(" Distances should increase as Box 1 moves away") + + assert distances[0, 0, 0] < distances[0, 0, 1] < distances[0, 0, 2], ( + "Distances should increase over time as Box 1 moves away" + ) + + print(" ✓ Test passed!") + + +def test_rollout_isolation(): + """Test that agents from different rollouts never interact.""" + + center_x_np = np.array([[[2.0], [100.0]], [[4.5], [2.0]]]) + center_y_np = np.array([[[1.0], [1.0]], [[2.5], [100.0]]]) + length_np = np.ones((2, 2)) * 2.0 + width_np = np.ones((2, 2)) * 1.0 + heading_np = np.zeros_like(center_x_np) + valid_np = np.ones((2, 2, 1), dtype=bool) + + print("\nTest: Rollout isolation") + print(f"Array shapes: {center_x_np.shape} = (num_agents=2, num_rollouts=2, num_timesteps=1)") + print("This test uses: 2 agents (Box 0 and Box 1), 2 rollouts, 1 timestep") + print("Rollout 0: Box 0 at (2.0, 1.0), Box 1 at (4.5, 2.5) - close together") + print("Rollout 1: Box 0 at (100.0, 1.0), Box 1 at (2.0, 100.0) - far apart") + print("Each agent should only see the other agent in its own rollout") + + center_x = _tensor(center_x_np) + center_y = _tensor(center_y_np) + length = _tensor(length_np) + width = _tensor(width_np) + heading = _tensor(heading_np) + valid = _tensor(valid_np, dtype=torch.bool) + eval_mask = _tensor(np.ones(2, dtype=bool), dtype=torch.bool) + + distances = interaction_features.compute_distance_to_nearest_object( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + evaluated_object_mask=eval_mask, + corner_rounding_factor=0.0, + ) + distances = distances.cpu().numpy() + print(f" Agent 0, rollout 0: {distances[0, 0, 0]:.2f}m (should see Agent 1 rollout 0)") + print(f" Agent 0, rollout 1: {distances[0, 1, 0]:.2f}m (should see Agent 1 rollout 1)") + + assert distances[0, 0, 0] < 10.0, "Agent 0 rollout 0 should see nearby Agent 1 rollout 0" + assert distances[0, 1, 0] > 90.0, "Agent 0 rollout 1 should see distant Agent 1 rollout 1" + + print(" ✓ Test passed! 
Rollouts are properly isolated.") + + +if __name__ == "__main__": + print("Running geometry and distance computation tests...\n") + print("=" * 60) + + try: + test_box_distance_calculations() + test_invalid_objects() + test_multiple_timesteps() + test_multiple_rollouts() + test_rollout_isolation() + + print("\n" + "=" * 60) + print("✓ All tests passed!") + + except Exception as e: + print("\n" + "=" * 60) + print(f"✗ Test failed with error:") + print(f" {type(e).__name__}: {e}") + import traceback + + traceback.print_exc() diff --git a/pufferlib/ocean/benchmark/test_map_metrics.py b/pufferlib/ocean/benchmark/test_map_metrics.py new file mode 100644 index 000000000..46e690e25 --- /dev/null +++ b/pufferlib/ocean/benchmark/test_map_metrics.py @@ -0,0 +1,548 @@ +"""Tests for map metric features (distance to road edge).""" + +import math +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle +from matplotlib.transforms import Affine2D +import torch + +from pufferlib.ocean.benchmark import map_metric_features + + +def _tensor(data, dtype=torch.float32): + """Convenience helper to create torch tensors for the map feature API.""" + return torch.as_tensor(data, dtype=dtype) + + +def plot_test_cases(): + """Visualize all test cases.""" + fig, axes = plt.subplots(3, 3, figsize=(15, 15)) + + # Test 1: Sign correctness + ax = axes[0, 0] + ax.plot([0, 0], [0, 2], "b-", linewidth=2, label="Road edge") + ax.arrow(0, 0.5, 0, 0.8, head_width=0.1, head_length=0.1, fc="b", ec="b") + ax.plot(-1, 1, "go", markersize=10, label="P (left, neg)") + ax.plot(2, 1, "ro", markersize=10, label="Q (right, pos)") + ax.set_xlim(-2, 3) + ax.set_ylim(-0.5, 2.5) + ax.set_aspect("equal") + ax.set_title("Test 1: Sign convention") + ax.legend(loc="upper right", fontsize=8) + ax.grid(True, alpha=0.3) + + # Test 2: Magnitude + ax = axes[0, 1] + ax.plot([0, 2], [0, 0], "b-", linewidth=2) + ax.arrow(0.5, 0, 0.8, 0, head_width=0.1, head_length=0.1, fc="b", ec="b") + ax.plot(0, 1, "go", markersize=10, label="P (d=1)") + ax.plot(3, -1, "ro", markersize=10, label=f"Q (d={math.sqrt(2):.2f})") + ax.plot([0, 0], [0, 1], "g--", alpha=0.5) + ax.plot([2, 3], [0, -1], "r--", alpha=0.5) + ax.set_xlim(-1, 4) + ax.set_ylim(-2, 2) + ax.set_aspect("equal") + ax.set_title("Test 2: Distance magnitude") + ax.legend(loc="upper right", fontsize=8) + ax.grid(True, alpha=0.3) + + # Test 3: Two parallel lines + ax = axes[0, 2] + ax.plot([0, 0], [3, -3], "b-", linewidth=2, label="Left edge") + ax.plot([2, 2], [-3, 3], "b-", linewidth=2, label="Right edge") + ax.arrow(0, 0, 0, -1.5, head_width=0.1, head_length=0.1, fc="b", ec="b") + ax.arrow(2, 0, 0, 1.5, head_width=0.1, head_length=0.1, fc="b", ec="b") + ax.axvspan(0, 2, alpha=0.2, color="green", label="On-road") + ax.set_xlim(-1.5, 4.5) + ax.set_ylim(-3.5, 3.5) + ax.set_aspect("equal") + ax.set_title("Test 3: Road corridor") + ax.legend(loc="upper right", fontsize=8) + ax.grid(True, alpha=0.3) + + # Test 4: Padded polylines + ax = axes[1, 0] + ax.plot([0, 0, 0, 0], [4, 1.5, -1.5, -4], "b-", linewidth=2, marker="o", markersize=4, label="4-pt line") + ax.plot([2, 2], [-4, 4], "r-", linewidth=2, marker="o", markersize=4, label="2-pt line (padded)") + ax.axvspan(0, 2, alpha=0.2, color="green") + ax.arrow(0, 0, 0, -1.5, head_width=0.1, head_length=0.1, fc="b", ec="b") + ax.arrow(2, 0, 0, 1.5, head_width=0.1, head_length=0.1, fc="r", ec="r") + ax.text(1, 0, "on-road", ha="center", va="center", fontsize=8) + ax.set_xlim(-1, 4) + ax.set_ylim(-5, 5) + ax.set_aspect("equal") + 
ax.set_title("Test 4: Polylines padded to same length") + ax.legend(loc="upper right", fontsize=8) + ax.grid(True, alpha=0.3) + + # Test 5: Agent boxes + ax = axes[1, 1] + ax.plot([0, 0], [5, -5], "b-", linewidth=2) + ax.plot([2, 2], [-5, 5], "b-", linewidth=2) + ax.axvspan(0, 2, alpha=0.2, color="green") + + # A0: Fully on-road - center (1, 0), 1m x 0.5m + rect0 = Rectangle((0.5, -0.25), 1, 0.5, fill=False, edgecolor="green", linewidth=2) + ax.add_patch(rect0) + ax.text(1, 0, "A0", ha="center", va="center", fontsize=8) + + # A1: At boundary - center (1, 0), 2m x 2m (offset in y for visibility) + rect1 = Rectangle((0, 1.5), 2, 2, fill=False, edgecolor="orange", linewidth=2) + ax.add_patch(rect1) + ax.text(1, 2.5, "A1", ha="center", va="center", fontsize=8) + + # A2: One side off - center (1.75, 0), 1m x 0.5m (offset in y) + rect2 = Rectangle((1.25, -1.5), 1, 0.5, fill=False, edgecolor="red", linewidth=2) + ax.add_patch(rect2) + ax.text(1.75, -1.25, "A2", ha="center", va="center", fontsize=8) + + # A3: Fully off-road - center (5, 0), 1m x 0.5m + rect3 = Rectangle((4.5, -0.25), 1, 0.5, fill=False, edgecolor="darkred", linewidth=2) + ax.add_patch(rect3) + ax.text(5, 0, "A3", ha="center", va="center", fontsize=8) + + # A4: Rotated - center (1.5, 0), sqrt(2) x sqrt(2), heading=pi/4 + # Corners at (2.5, 0), (1.5, 1), (0.5, 0), (1.5, -1) + diamond_x = [2.5, 1.5, 0.5, 1.5, 2.5] + diamond_y = [-3.5, -2.5, -3.5, -4.5, -3.5] + ax.plot(diamond_x, diamond_y, "purple", linewidth=2) + ax.plot(2.5, -3.5, "ro", markersize=6) # off-road corner + ax.text(1.5, -3.5, "A4", ha="center", va="center", fontsize=8) + + ax.set_xlim(-1, 6) + ax.set_ylim(-6, 4) + ax.set_aspect("equal") + ax.set_title("Test 5: Agent boxes") + ax.grid(True, alpha=0.3) + + # Hide unused subplot + axes[1, 2].axis("off") + + # Test 6: Donut road (outer CCW, inner CW) + ax = axes[2, 0] + # Outer square + outer_x = [0, 4, 4, 0, 0] + outer_y = [0, 0, 4, 4, 0] + ax.plot(outer_x, outer_y, "b-", linewidth=2) + # Inner square + inner_x = [1, 1, 3, 3, 1] + inner_y = [1, 3, 3, 1, 1] + ax.plot(inner_x, inner_y, "b-", linewidth=2) + # Fill road area + from matplotlib.patches import Polygon + + outer_poly = np.array([[0, 0], [4, 0], [4, 4], [0, 4]]) + inner_poly = np.array([[1, 1], [1, 3], [3, 3], [3, 1]]) + ax.fill(outer_x, outer_y, alpha=0.2, color="green") + ax.fill(inner_x, inner_y, alpha=1.0, color="white") + + # Direction arrows - outer CCW, inner CW + ax.arrow(1.5, 0, 1, 0, head_width=0.15, head_length=0.1, fc="b", ec="b") # outer bottom: right + ax.arrow(2.5, 1, -1, 0, head_width=0.15, head_length=0.1, fc="b", ec="b") # inner bottom: left (CW) + + # Test points + ax.plot(0.5, 2, "go", markersize=10, label="P1 (on road)") + ax.plot(2, 2, "ro", markersize=10, label="P2 (inside inner)") + ax.plot(5, 2, "ro", markersize=8) + ax.plot(2, 0.5, "go", markersize=8) + + ax.text(0.5, 2.3, "P1", ha="center", fontsize=8) + ax.text(2, 2.3, "P2", ha="center", fontsize=8) + ax.text(5.2, 2, "P3", ha="left", fontsize=8) + ax.text(2, 0.2, "P4", ha="center", fontsize=8) + + ax.set_xlim(-0.5, 5.5) + ax.set_ylim(-0.5, 4.5) + ax.set_aspect("equal") + ax.set_title("Test 6: Donut road") + ax.legend(loc="upper right", fontsize=7) + ax.grid(True, alpha=0.3) + + # Test 7: Triangle with acute corner (cyclic test) + ax = axes[2, 1] + # Triangle counterclockwise + shape_x = [0, 10, 10, 0] + shape_y = [0, 0, 10, 0] + ax.plot(shape_x, shape_y, "b-", linewidth=2) + ax.fill(shape_x, shape_y, alpha=0.2, color="green") + + # Mark the acute corner at (0, 0) + ax.plot(0, 0, "ko", 
markersize=10, markerfacecolor="yellow", label="Acute corner") + + # Test point P at (-2, 1) - outside but Seg 0 thinks inside + ax.plot(-2.0, 1.0, "ro", markersize=10, label="P (-2, 1)") + + # Direction arrows + ax.arrow(3, 0, 3, 0, head_width=0.4, head_length=0.3, fc="b", ec="b") # Seg 0: right + ax.arrow(6, 6, -1.5, -1.5, head_width=0.4, head_length=0.3, fc="b", ec="b") # Seg 2: down-left (on hypotenuse) + + # Labels + ax.text(5, -1, "Seg 0", ha="center", fontsize=7) + ax.text(1.5, 5, "Seg 2", ha="center", fontsize=7) + + ax.set_xlim(-4, 12) + ax.set_ylim(-2, 12) + ax.set_aspect("equal") + ax.set_title("Test 7: Triangle (cyclic corner)") + ax.legend(loc="upper right", fontsize=7) + ax.grid(True, alpha=0.3) + + # Hide unused subplot + axes[2, 2].axis("off") + + plt.tight_layout() + plt.savefig("test_map_metrics.png", dpi=150) + print(f"Plot saved to test_map_metrics.png") + plt.close() + + +def test_signed_distance_correct_sign(): + """Test sign convention: negative = left (port), positive = right (starboard). + + R2 + ^ + P | Q + R1 + + P at (-1, 1) should be negative (left of upward line) + Q at (2, 1) should be positive (right of upward line) + """ + query_points = _tensor([[-1.0, 1.0], [2.0, 1.0]]) + + polyline_x = _tensor([0.0, 0.0]) + polyline_y = _tensor([0.0, 2.0]) + polyline_lengths = _tensor([2], dtype=torch.int64) + + polylines, valid = map_metric_features._pad_polylines(polyline_x, polyline_y, polyline_lengths) + + distances = map_metric_features._compute_signed_distance_to_polylines(query_points, polylines, valid).cpu().numpy() + + expected = np.array([-1.0, 2.0]) + np.testing.assert_allclose(distances, expected, rtol=1e-5, atol=1e-5) + print("✓ test_signed_distance_correct_sign passed") + + +def test_signed_distance_correct_magnitude(): + """Test distance magnitude for points projecting onto and beyond segment. + + P + + R1----->R2 + + Q + + P at (0, 1) projects onto segment -> distance = 1.0 + Q at (3, -1) projects beyond R2 -> distance = sqrt(2) to corner + """ + query_points = _tensor([[0.0, 1.0], [3.0, -1.0]]) + + polyline_x = _tensor([0.0, 2.0]) + polyline_y = _tensor([0.0, 0.0]) + polyline_lengths = _tensor([2], dtype=torch.int64) + + polylines, valid = map_metric_features._pad_polylines(polyline_x, polyline_y, polyline_lengths) + + distances = map_metric_features._compute_signed_distance_to_polylines(query_points, polylines, valid).cpu().numpy() + + expected_abs = np.array([1.0, math.sqrt(2)]) + np.testing.assert_allclose(np.abs(distances), expected_abs, rtol=1e-5, atol=1e-5) + print("✓ test_signed_distance_correct_magnitude passed") + + +def test_signed_distance_two_parallel_lines(): + """Test with two parallel lines forming a road corridor. + + Query grid from -1 to 4, two lines at x=0 and x=2. + Points between lines should be negative (on-road). + Points outside should be positive (off-road). 
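+    The corridor spans x in [0, 2], so the centerline sits at x = 1 and the
+    half-width is 1.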
+ Expected: |x - 1| - 1 (distance to center minus half-width) + """ + x = np.linspace(-1.0, 4.0, 10, dtype=np.float32) + mesh_xys_np = np.stack(np.meshgrid(x, x), axis=-1).reshape(-1, 2) + mesh_xys = _tensor(mesh_xys_np) + + # Line 1: x=0, pointing down (y: 10 to -10) + # Line 2: x=2, pointing up (y: -10 to 10) + polyline_x = _tensor([0.0, 0.0, 2.0, 2.0]) + polyline_y = _tensor([10.0, -10.0, -10.0, 10.0]) + polyline_lengths = _tensor([2, 2], dtype=torch.int64) + + polylines, valid = map_metric_features._pad_polylines(polyline_x, polyline_y, polyline_lengths) + + distances = map_metric_features._compute_signed_distance_to_polylines(mesh_xys, polylines, valid).cpu().numpy() + + expected = np.abs(mesh_xys_np[:, 0] - 1.0) - 1.0 + np.testing.assert_allclose(distances, expected, rtol=1e-5, atol=1e-5) + print("✓ test_signed_distance_two_parallel_lines passed") + + +def test_signed_distance_with_padding(): + """Test with polylines of different lengths (padded).""" + x = np.linspace(-1.0, 4.0, 10, dtype=np.float32) + mesh_xys_np = np.stack(np.meshgrid(x, x), axis=-1).reshape(-1, 2) + mesh_xys = _tensor(mesh_xys_np) + + # Line 1: 4 points, Line 2: 2 points (will be padded) + polyline_x = _tensor([0.0, 0.0, 0.0, 0.0, 2.0, 2.0]) + polyline_y = _tensor([10.0, 3.0, -3.0, -10.0, -10.0, 10.0]) + polyline_lengths = _tensor([4, 2], dtype=torch.int64) + + polylines, valid = map_metric_features._pad_polylines(polyline_x, polyline_y, polyline_lengths) + + distances = map_metric_features._compute_signed_distance_to_polylines(mesh_xys, polylines, valid).cpu().numpy() + + expected = np.abs(mesh_xys_np[:, 0] - 1.0) - 1.0 + np.testing.assert_allclose(distances, expected, rtol=1e-5, atol=1e-5) + print("✓ test_signed_distance_with_padding passed") + + +def test_cyclic_polyline(): + """Test with a cyclic polyline (closed square boundary). 
+
+    Square road boundary: corners at (0,0), (2,0), (2,2), (0,2), back to (0,0)
+    Winding order: counterclockwise (inside = on-road)
+
+    Points:
+    - P1 at (1, 1): center of square, should be negative (on-road)
+    - P2 at (3, 1): outside right edge, should be positive (off-road)
+    - P3 at (1, 3): outside top edge, should be positive (off-road)
+    - P4 at (-1, 1): outside left edge, should be positive (off-road)
+    - P5 at (2, 2): exactly at corner, should be ~0
+    """
+    query_points = _tensor(
+        [
+            [1.0, 1.0],  # P1: center
+            [3.0, 1.0],  # P2: outside right
+            [1.0, 3.0],  # P3: outside top
+            [-1.0, 1.0],  # P4: outside left
+            [2.0, 2.0],  # P5: at corner
+        ]
+    )
+
+    # Counterclockwise square: (0,0) -> (2,0) -> (2,2) -> (0,2) -> (0,0)
+    polyline_x = _tensor([0.0, 2.0, 2.0, 0.0, 0.0])
+    polyline_y = _tensor([0.0, 0.0, 2.0, 2.0, 0.0])
+    polyline_lengths = _tensor([5], dtype=torch.int64)
+
+    polylines, valid = map_metric_features._pad_polylines(polyline_x, polyline_y, polyline_lengths)
+
+    distances = map_metric_features._compute_signed_distance_to_polylines(query_points, polylines, valid).cpu().numpy()
+
+    print(f"P1 (center): {distances[0]:.3f} (expected ~ -1.0)")
+    print(f"P2 (outside right): {distances[1]:.3f} (expected ~ +1.0)")
+    print(f"P3 (outside top): {distances[2]:.3f} (expected ~ +1.0)")
+    print(f"P4 (outside left): {distances[3]:.3f} (expected ~ +1.0)")
+    print(f"P5 (at corner): {distances[4]:.3f} (expected ~ 0.0)")
+
+    # P1: inside square, distance to nearest edge is 1.0
+    assert distances[0] < 0, f"P1 should be on-road (negative), got {distances[0]}"
+    np.testing.assert_allclose(distances[0], -1.0, atol=0.1)
+
+    # P2, P3, P4: outside square by 1.0
+    assert distances[1] > 0, f"P2 should be off-road (positive), got {distances[1]}"
+    assert distances[2] > 0, f"P3 should be off-road (positive), got {distances[2]}"
+    assert distances[3] > 0, f"P4 should be off-road (positive), got {distances[3]}"
+    np.testing.assert_allclose(distances[1], 1.0, atol=0.1)
+    np.testing.assert_allclose(distances[2], 1.0, atol=0.1)
+    np.testing.assert_allclose(distances[3], 1.0, atol=0.1)
+
+    # P5: at corner
+    np.testing.assert_allclose(distances[4], 0.0, atol=0.1)
+
+    print("✓ test_cyclic_polyline passed")
+
+
+# NOTE: This test exists to show why cyclic polylines need special handling.
+def test_cyclic_seam():
+    """Test acute corner tie-breaking using a clean Triangle.
+
+    Shape: Triangle (0,0) -> (10,0) -> (10,10) -> (0,0)
+    Winding: Counter-Clockwise (Left is Inside).
+    Corner at (0,0) is Acute (45 degrees).
+    """
+    # Query Point P (-2, 1)
+    # - Physically: To the left of the diagonal hypotenuse -> OUTSIDE.
+    # - To Segment 0 (Bottom): Above y=0 -> INSIDE (The Blind Spot).
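+    # Correct resolution: the nearest feature is the shared vertex (0, 0), so the
+    # unsigned distance is sqrt(5) and the sign must come out positive (off-road).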
+ query_points = _tensor([[-2.0, 1.0]]) + + # The Triangle + polyline_x = _tensor([0.0, 10.0, 10.0, 0.0]) + polyline_y = _tensor([0.0, 0.0, 10.0, 0.0]) + polyline_lengths = _tensor([4], dtype=torch.int64) + + polylines, valid = map_metric_features._pad_polylines(polyline_x, polyline_y, polyline_lengths) + + distances = map_metric_features._compute_signed_distance_to_polylines(query_points, polylines, valid).cpu().numpy() + + # Distance to vertex (0,0) is sqrt(2^2 + 1^2) = sqrt(5) + expected = np.sqrt(5) + print(f"P at (-2, 1): {distances[0]:.3f} (expected ~ +{expected:.3f})") + + assert distances[0] > 0, f"P should be OFF-ROAD (Positive), but got {distances[0]}" + np.testing.assert_allclose(distances[0], expected, atol=0.01) + + print("✓ test_cyclic_seam passed (Clean Triangle)") + + +def test_donut_road(): + """Test with a donut-shaped road (outer CCW, inner CW). + + Outer square: (0,0) -> (4,0) -> (4,4) -> (0,4) -> (0,0) [counterclockwise] + Inner square: (1,1) -> (1,3) -> (3,3) -> (3,1) -> (1,1) [clockwise] + + Road is the region between the two squares (width=1 on each side). + + Points: + - P1 at (0.5, 2): on road (between outer and inner), should be negative + - P2 at (2, 2): inside inner square (off-road), should be positive + - P3 at (5, 2): outside outer square (off-road), should be positive + - P4 at (2, 0.5): on road (between outer and inner), should be negative + """ + query_points = _tensor( + [ + [0.5, 2.0], # P1: on road (left side) + [2.0, 2.0], # P2: inside inner square + [5.0, 2.0], # P3: outside outer square + [2.0, 0.5], # P4: on road (bottom side) + ] + ) + + # Outer square: counterclockwise + outer_x = _tensor([0.0, 4.0, 4.0, 0.0, 0.0]) + outer_y = _tensor([0.0, 0.0, 4.0, 4.0, 0.0]) + + # Inner square: clockwise (opposite winding) + inner_x = _tensor([1.0, 1.0, 3.0, 3.0, 1.0]) + inner_y = _tensor([1.0, 3.0, 3.0, 1.0, 1.0]) + + polyline_x = torch.cat([outer_x, inner_x]) + polyline_y = torch.cat([outer_y, inner_y]) + polyline_lengths = _tensor([5, 5], dtype=torch.int64) + + polylines, valid = map_metric_features._pad_polylines(polyline_x, polyline_y, polyline_lengths) + + distances = map_metric_features._compute_signed_distance_to_polylines(query_points, polylines, valid).cpu().numpy() + + print(f"P1 (on road, left): {distances[0]:.3f} (expected ~ -0.5)") + print(f"P2 (inside inner): {distances[1]:.3f} (expected ~ +1.0)") + print(f"P3 (outside outer): {distances[2]:.3f} (expected ~ +1.0)") + print(f"P4 (on road, bottom): {distances[3]:.3f} (expected ~ -0.5)") + + # P1: on road, 0.5m from outer edge + assert distances[0] < 0, f"P1 should be on-road (negative), got {distances[0]}" + np.testing.assert_allclose(distances[0], -0.5, atol=0.1) + + # P2: inside inner square, 1m from inner edge + assert distances[1] > 0, f"P2 should be off-road (positive), got {distances[1]}" + np.testing.assert_allclose(distances[1], 1.0, atol=0.1) + + # P3: outside outer square, 1m from outer edge + assert distances[2] > 0, f"P3 should be off-road (positive), got {distances[2]}" + np.testing.assert_allclose(distances[2], 1.0, atol=0.1) + + # P4: on road, 0.5m from outer edge + assert distances[3] < 0, f"P4 should be on-road (negative), got {distances[3]}" + np.testing.assert_allclose(distances[3], -0.5, atol=0.1) + + print("✓ test_donut_road passed") + + +def test_compute_distance_to_road_edge(): + """Test full pipeline with agent boxes.""" + num_agents = 5 + num_steps = 1 + + # Road corridor from x=0 to x=2 + # A0: Fully on-road - center (1, 0), 1m x 0.5m, heading=0 + # Corners x ∈ [0.5, 1.5] 
→ all inside, nearest edge 0.5m away, expected ~ -0.5 + # A1: At boundary - center (1, 0), 2m x 2m, heading=0 + # Corners x ∈ [0, 2] → exactly at edges, expected ~ 0 + # A2: One side off - center (1.75, 0), 1m x 0.5m, heading=0 + # Corners x ∈ [1.25, 2.25] → right off by 0.25m, expected ~ 0.25 + # A3: Fully off-road - center (5, 0), 1m x 0.5m, heading=0 + # Corners x ∈ [4.5, 5.5] → far off, expected ~ 3.5 + # A4: Rotated with one corner off - center (1.5, 0), sqrt(2) x sqrt(2), heading=pi/4 + # Corners at (2.5, 0), (1.5, 1), (0.5, 0), (1.5, -1) + # Corner at (2.5, 0) is 0.5m outside, expected ~ +0.5 + + center_x = _tensor([[1.0], [1.0], [1.75], [5.0], [1.5]]) + center_y = _tensor([[0.0], [0.0], [0.0], [0.0], [0.0]]) + length = _tensor([1.0, 2.0, 1.0, 1.0, np.sqrt(2)]) + width = _tensor([0.5, 2.0, 0.5, 0.5, np.sqrt(2)]) + heading = _tensor([[0.0], [0.0], [0.0], [0.0], [np.pi / 4]]) + valid = torch.ones((num_agents, num_steps), dtype=torch.bool) + + # Two parallel lines at x=0 and x=2 + polyline_x = _tensor([0.0, 0.0, 2.0, 2.0]) + polyline_y = _tensor([10.0, -10.0, -10.0, 10.0]) + polyline_lengths = _tensor([2, 2], dtype=torch.int64) + + distances = map_metric_features.compute_distance_to_road_edge( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + polyline_x=polyline_x, + polyline_y=polyline_y, + polyline_lengths=polyline_lengths, + ) + + distances_np = distances.cpu().numpy() + + assert distances.shape == (num_agents, num_steps) + + print(f"A0 (fully on-road): {distances_np[0, 0]:.3f} (expected ~ -0.5)") + print(f"A1 (at boundary): {distances_np[1, 0]:.3f} (expected ~ 0)") + print(f"A2 (one side off): {distances_np[2, 0]:.3f} (expected ~ 0.25)") + print(f"A3 (fully off-road): {distances_np[3, 0]:.3f} (expected ~ 3.5)") + print(f"A4 (rotated, one corner off): {distances_np[4, 0]:.3f} (expected ~ 0.5)") + + # A0: fully on-road, corners at x=0.5 and x=1.5, both 0.5m inside road + assert distances_np[0, 0] < 0, f"A0 should be on-road (negative), got {distances_np[0, 0]}" + np.testing.assert_allclose(distances_np[0, 0], -0.5, atol=0.1) + + # A1: at boundary, distance ~ 0 + np.testing.assert_allclose(distances_np[1, 0], 0.0, atol=0.1) + + # A2: one side off by 0.25m + np.testing.assert_allclose(distances_np[2, 0], 0.25, atol=0.1) + + # A3: fully off-road + np.testing.assert_allclose(distances_np[3, 0], 3.5, atol=0.1) + + # A4: rotated, corner at (2.5, 0) is 0.5m outside + np.testing.assert_allclose(distances_np[4, 0], 0.5, atol=0.1) + + print("✓ test_compute_distance_to_road_edge passed") + + +if __name__ == "__main__": + print("Running map metric feature tests...\n") + print("=" * 60) + + try: + print("Generating visualization first...") + plot_test_cases() + print() + + test_signed_distance_correct_sign() + test_signed_distance_correct_magnitude() + test_signed_distance_two_parallel_lines() + test_signed_distance_with_padding() + test_cyclic_polyline() + test_cyclic_seam() + test_donut_road() + test_compute_distance_to_road_edge() + + print("\n" + "=" * 60) + print("✓ All tests passed!") + + except Exception as e: + print("\n" + "=" * 60) + print(f"✗ Test failed:") + import traceback + + traceback.print_exc() diff --git a/pufferlib/ocean/benchmark/test_road_edges.py b/pufferlib/ocean/benchmark/test_road_edges.py new file mode 100644 index 000000000..7fbe164c3 --- /dev/null +++ b/pufferlib/ocean/benchmark/test_road_edges.py @@ -0,0 +1,88 @@ +"""Test script for road edge extraction. 
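+
+Prints per-scenario polyline statistics and saves a plot of the first scenario
+to road_edges_test.png.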
+ +Run with: python -m pufferlib.ocean.benchmark.test_road_edges +""" + +import numpy as np +import matplotlib.pyplot as plt +import pufferlib +import pufferlib.vector +from pufferlib.pufferl import load_config + + +def main(): + env_name = "puffer_drive" + args = load_config(env_name) + + args["vec"] = dict(backend="PufferEnv", num_envs=1) + args["env"]["num_agents"] = 32 + + from pufferlib.ocean import env_creator + + make_env = env_creator(env_name) + vecenv = pufferlib.vector.make(make_env, env_kwargs=args["env"], **args["vec"]) + vecenv.reset() + + polylines = vecenv.driver_env.get_road_edge_polylines() + + print("\n=== Road Edge Statistics ===") + print(f"num_polylines: {len(polylines['lengths'])}") + print(f"total_points: {len(polylines['x'])}") + print( + f"points per polyline: min={polylines['lengths'].min()}, max={polylines['lengths'].max()}, mean={polylines['lengths'].mean():.1f}" + ) + print(f"x range: [{polylines['x'].min():.1f}, {polylines['x'].max():.1f}]") + print(f"y range: [{polylines['y'].min():.1f}, {polylines['y'].max():.1f}]") + + unique_scenarios = np.unique(polylines["scenario_id"]) + print(f"unique scenarios: {len(unique_scenarios)} -> {unique_scenarios}") + + for sid in unique_scenarios[:3]: + mask = polylines["scenario_id"] == sid + n_polys = mask.sum() + pts = polylines["lengths"][mask].sum() + print(f" scenario {sid}: {n_polys} polylines, {pts} points") + + # Plot first scenario + fig, ax = plt.subplots(1, 1, figsize=(12, 12)) + + sid = unique_scenarios[0] + mask = polylines["scenario_id"] == sid + poly_indices = np.where(mask)[0] + + boundaries = np.cumsum(np.concatenate([[0], polylines["lengths"]])) + + for i, idx in enumerate(poly_indices): + start = boundaries[idx] + end = boundaries[idx + 1] + x = polylines["x"][start:end] + y = polylines["y"][start:end] + + ax.plot(x, y, "b-", linewidth=0.5, alpha=0.7) + + # Mark direction with arrow on first segment + if len(x) >= 2: + mid = len(x) // 2 + dx = x[mid] - x[mid - 1] + dy = y[mid] - y[mid - 1] + ax.annotate( + "", + xy=(x[mid], y[mid]), + xytext=(x[mid - 1], y[mid - 1]), + arrowprops=dict(arrowstyle="->", color="red", lw=0.5), + ) + + ax.set_aspect("equal") + ax.set_title(f"Road edges for scenario {sid} ({len(poly_indices)} polylines)") + ax.set_xlabel("x (m)") + ax.set_ylabel("y (m)") + + plt.tight_layout() + plt.savefig("road_edges_test.png", dpi=150) + print(f"\nPlot saved to road_edges_test.png") + + vecenv.close() + + +if __name__ == "__main__": + main() diff --git a/pufferlib/ocean/benchmark/test_ttc.py b/pufferlib/ocean/benchmark/test_ttc.py new file mode 100644 index 000000000..ce5f45bf9 --- /dev/null +++ b/pufferlib/ocean/benchmark/test_ttc.py @@ -0,0 +1,268 @@ +import math +import numpy as np +import pytest +import torch + +from pufferlib.ocean.benchmark import interaction_features + +MAX_HEADING_DIFF = interaction_features.MAX_HEADING_DIFF +SMALL_OVERLAP_THRESHOLD = interaction_features.SMALL_OVERLAP_THRESHOLD +MAX_HEADING_DIFF_FOR_SMALL_OVERLAP = interaction_features.MAX_HEADING_DIFF_FOR_SMALL_OVERLAP +MAX_TTC_SEC = interaction_features.MAXIMUM_TIME_TO_COLLISION + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def _tensor(value, *, dtype=torch.float32): + return torch.tensor(value, dtype=dtype, device=DEVICE) + + +def test_time_to_collision_output_shape(): + """Test that TTC returns correct shape.""" + num_agents = 4 + num_rollouts = 2 + num_steps = 10 + + center_x = _tensor(np.random.randn(num_agents, num_rollouts, num_steps).astype(np.float32)) + center_y = 
_tensor(np.random.randn(num_agents, num_rollouts, num_steps).astype(np.float32)) + length = _tensor(np.ones((num_agents, num_rollouts), dtype=np.float32) * 4.0) + width = _tensor(np.ones((num_agents, num_rollouts), dtype=np.float32) * 2.0) + heading = _tensor(np.zeros((num_agents, num_rollouts, num_steps), dtype=np.float32)) + valid = _tensor(np.ones((num_agents, num_rollouts, num_steps), dtype=bool), dtype=torch.bool) + + eval_mask = _tensor(np.ones(num_agents, dtype=bool), dtype=torch.bool) + ttc = interaction_features.compute_time_to_collision( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + evaluated_object_mask=eval_mask, + seconds_per_step=0.1, + ) + ttc = ttc.cpu().numpy() + + assert ttc.shape == (num_agents, num_rollouts, num_steps) + + +@pytest.mark.parametrize( + "center_xys,headings,boxes_sizes,speeds,expected_ttc_sec", + [ + # 9 square boxes in a 3x3 grid + ( + [[-3, 3], [0, 3], [3, 3], [-3, 0], [0, 0], [3, 0], [-3, -3], [0, -3], [3, -3]], + [0] * 9, + [[1, 1]] * 9, + [10, 6, 1] * 3, + [2 / 4, 2 / 5, MAX_TTC_SEC] * 3, + ), + # Rectangles in a line + ( + [[0, 0], [5, 0], [10, 0], [15, 0]], + [0, 0, 0, 0], + [[4, 2]] * 4, + [6, 10, 3, 1], + [MAX_TTC_SEC, 1 / 7, 1 / 2, MAX_TTC_SEC], + ), + # Test ignore misaligned + ( + [[0, 0], [5, 0], [10, 0], [15, 0]], + [0, MAX_HEADING_DIFF + 0.01, 0, MAX_HEADING_DIFF - 0.01], + [[4, 2]] * 4, + [10, 6, 3, 1], + [ + 6 / 7, + MAX_TTC_SEC, + (3 - math.cos(MAX_HEADING_DIFF - 0.01) * 2 - math.sin(MAX_HEADING_DIFF - 0.01)) / 2, + MAX_TTC_SEC, + ], + ), + # Test ignore no overlap + ( + [[0, 0], [5, 2.1], [10, 1.1], [15, 0]], + [0, 0, 0, 0], + [[4, 2]] * 4, + [10, 6, 3, 1], + [6 / 7, 1 / 3, 1 / 2, MAX_TTC_SEC], + ), + # Test ignore small misalignment with low overlap + ( + [ + [0, 0], + [5, 2.5 - SMALL_OVERLAP_THRESHOLD], + [10, -2.5 + SMALL_OVERLAP_THRESHOLD], + ], + [0, MAX_HEADING_DIFF_FOR_SMALL_OVERLAP + 0.01, -MAX_HEADING_DIFF_FOR_SMALL_OVERLAP + 0.01], + [[4, 2]] * 3, + [6, 3, 1], + [ + ( + 8 + - math.cos(MAX_HEADING_DIFF_FOR_SMALL_OVERLAP - 0.01) * 2 + - math.sin(MAX_HEADING_DIFF_FOR_SMALL_OVERLAP - 0.01) + ) + / 5, + MAX_TTC_SEC, + MAX_TTC_SEC, + ], + ), + ], +) +def test_time_to_collision_values(center_xys, headings, boxes_sizes, speeds, expected_ttc_sec): + """Test TTC computation with various configurations.""" + center_xys = np.array(center_xys, dtype=np.float32) + headings = np.array(headings, dtype=np.float32) + boxes_sizes = np.array(boxes_sizes, dtype=np.float32) + speeds = np.array(speeds, dtype=np.float32) + expected_ttc_sec = np.array(expected_ttc_sec, dtype=np.float32) + + num_agents = len(center_xys) + num_rollouts = 1 + seconds_per_step = 0.1 + + # Simulate 3 timesteps (t-1, t, t+1) to get proper speeds with central difference + center_x_1 = center_xys[:, 0] + center_x_2 = center_x_1 + speeds * np.cos(headings) * seconds_per_step + center_x_0 = center_x_1 - speeds * np.cos(headings) * seconds_per_step + center_x = np.stack([center_x_0, center_x_1, center_x_2], axis=-1) + + center_y_1 = center_xys[:, 1] + center_y_2 = center_y_1 - speeds * np.sin(headings) * seconds_per_step + center_y_0 = center_y_1 + speeds * np.sin(headings) * seconds_per_step + center_y = np.stack([center_y_0, center_y_1, center_y_2], axis=-1) + + # Reshape to (num_agents, num_rollouts, num_steps) + center_x = _tensor(center_x[:, np.newaxis, :]) + center_y = _tensor(center_y[:, np.newaxis, :]) + + length = _tensor(np.broadcast_to(boxes_sizes[:, 0:1], (num_agents, num_rollouts))) + width = 
_tensor(np.broadcast_to(boxes_sizes[:, 1:2], (num_agents, num_rollouts))) + heading = _tensor(np.broadcast_to(headings[:, np.newaxis, np.newaxis], (num_agents, num_rollouts, 3))) + valid = _tensor(np.ones((num_agents, num_rollouts, 3), dtype=bool), dtype=torch.bool) + + eval_mask = _tensor(np.ones(num_agents, dtype=bool), dtype=torch.bool) + ttc = interaction_features.compute_time_to_collision( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + evaluated_object_mask=eval_mask, + seconds_per_step=seconds_per_step, + ) + ttc = ttc.cpu().numpy() + + # Check TTC at timestep 1 (middle) where speeds are valid + np.testing.assert_allclose(ttc[:, 0, 1], expected_ttc_sec, rtol=1e-5, atol=1e-5) + + +def test_time_to_collision_invalid_objects(): + """Test that invalid objects are ignored in TTC computation.""" + num_agents = 3 + num_rollouts = 1 + num_steps = 3 + + # Create 3 agents in a line: agent 0 at x=0, agent 1 at x=5, agent 2 at x=10 + # All moving forward with speeds [6, 3, 1] + center_xys = np.array([[0, 0], [5, 0], [10, 0]], dtype=np.float32) + speeds = np.array([6, 3, 1], dtype=np.float32) + headings = np.zeros(num_agents, dtype=np.float32) + seconds_per_step = 0.1 + + # Create 3 timesteps + center_x_1 = center_xys[:, 0] + center_x_2 = center_x_1 + speeds * np.cos(headings) * seconds_per_step + center_x_0 = center_x_1 - speeds * np.cos(headings) * seconds_per_step + center_x = np.stack([center_x_0, center_x_1, center_x_2], axis=-1)[:, np.newaxis, :] + + center_y_1 = center_xys[:, 1] + center_y_2 = center_y_1 - speeds * np.sin(headings) * seconds_per_step + center_y_0 = center_y_1 + speeds * np.sin(headings) * seconds_per_step + center_y = np.stack([center_y_0, center_y_1, center_y_2], axis=-1)[:, np.newaxis, :] + + length = np.ones((num_agents, num_rollouts), dtype=np.float32) * 4.0 + width = np.ones((num_agents, num_rollouts), dtype=np.float32) * 2.0 + heading = np.zeros((num_agents, num_rollouts, num_steps), dtype=np.float32) + + # Test with agent 1 invalid - agent 0 should see agent 2 as nearest + center_x = _tensor(center_x) + center_y = _tensor(center_y) + length = _tensor(length) + width = _tensor(width) + heading = _tensor(heading) + + valid = np.ones((num_agents, num_rollouts, num_steps), dtype=bool) + valid[1, 0, 1] = False + valid = _tensor(valid, dtype=torch.bool) + + eval_mask = _tensor(np.ones(num_agents, dtype=bool), dtype=torch.bool) + ttc = interaction_features.compute_time_to_collision( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + evaluated_object_mask=eval_mask, + seconds_per_step=seconds_per_step, + ) + ttc = ttc.cpu().numpy() + + # Agent 0 should now see agent 2 at distance 10 with relative speed 6-1=5 + # TTC = (10 - 2 - 2) / 5 = 6/5 = 1.2 + expected_ttc_agent0 = (10 - 2 - 2) / (6 - 1) + + np.testing.assert_allclose(ttc[0, 0, 1], expected_ttc_agent0, rtol=1e-5, atol=1e-5) + + +def test_time_to_collision_no_object_ahead(): + """Test TTC returns max value when no object is ahead.""" + num_agents = 2 + num_rollouts = 1 + num_steps = 3 + + # Create 2 agents moving away from each other + center_xys = np.array([[0, 0], [10, 0]], dtype=np.float32) + speeds = np.array([5, 5], dtype=np.float32) + headings = np.array([0, math.pi], dtype=np.float32) # Opposite directions + seconds_per_step = 0.1 + + center_x_1 = center_xys[:, 0] + center_x_2 = center_x_1 + speeds * np.cos(headings) * seconds_per_step + center_x_0 = center_x_1 - speeds * np.cos(headings) * 
seconds_per_step + center_x = np.stack([center_x_0, center_x_1, center_x_2], axis=-1)[:, np.newaxis, :] + + center_y_1 = center_xys[:, 1] + center_y_2 = center_y_1 - speeds * np.sin(headings) * seconds_per_step + center_y_0 = center_y_1 + speeds * np.sin(headings) * seconds_per_step + center_y = np.stack([center_y_0, center_y_1, center_y_2], axis=-1)[:, np.newaxis, :] + + length = _tensor(np.ones((num_agents, num_rollouts), dtype=np.float32) * 4.0) + width = _tensor(np.ones((num_agents, num_rollouts), dtype=np.float32) * 2.0) + heading = _tensor(np.broadcast_to(headings[:, np.newaxis, np.newaxis], (num_agents, num_rollouts, num_steps))) + valid = _tensor(np.ones((num_agents, num_rollouts, num_steps), dtype=bool), dtype=torch.bool) + + center_x = _tensor(center_x) + center_y = _tensor(center_y) + + eval_mask = _tensor(np.ones(num_agents, dtype=bool), dtype=torch.bool) + ttc = interaction_features.compute_time_to_collision( + center_x=center_x, + center_y=center_y, + length=length, + width=width, + heading=heading, + valid=valid, + evaluated_object_mask=eval_mask, + seconds_per_step=seconds_per_step, + ) + ttc = ttc.cpu().numpy() + + # Both agents should have max TTC since they're moving away + np.testing.assert_allclose(ttc[:, 0, 1], MAX_TTC_SEC, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/pufferlib/ocean/benchmark/visual_sanity_check.py b/pufferlib/ocean/benchmark/visual_sanity_check.py new file mode 100644 index 000000000..ca4657441 --- /dev/null +++ b/pufferlib/ocean/benchmark/visual_sanity_check.py @@ -0,0 +1,217 @@ +""" +Visual validation script for WOSAC collision and offroad detection. + +Plots road edges and agent trajectories for a single scenario, +marking collision and offroad events. 
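+
+Run with: python -m pufferlib.ocean.benchmark.visual_sanity_check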
+""" + +import argparse +import numpy as np +import matplotlib.pyplot as plt +import torch + +from matplotlib.patches import Polygon + +from pufferlib.pufferl import load_config, load_env, load_policy +from pufferlib.ocean.benchmark.evaluator import WOSACEvaluator +from pufferlib.ocean.benchmark.metrics import compute_interaction_features, compute_map_features +from pufferlib.ocean.benchmark.geometry_utils import get_2d_box_corners + + +def plot_road_edges(ax, road_edge_polylines, scenario_id): + """Plot road edge polylines for a specific scenario.""" + lengths = road_edge_polylines["lengths"] + scenario_ids = road_edge_polylines["scenario_id"] + x = road_edge_polylines["x"] + y = road_edge_polylines["y"] + + pt_idx = 0 + for i in range(len(lengths)): + length = lengths[i] + if scenario_ids[i] == scenario_id: + poly_x = x[pt_idx : pt_idx + length] + poly_y = y[pt_idx : pt_idx + length] + ax.plot(poly_x, poly_y, "k-", linewidth=1, alpha=0.7) + pt_idx += length + + +def plot_agent_trajectories(ax, traj, agent_mask, rollout_idx, collisions, offroad, agent_length, agent_width): + """Plot trajectories as bounding boxes with collision/offroad coloring.""" + x = traj["x"][agent_mask, rollout_idx, :] + y = traj["y"][agent_mask, rollout_idx, :] + heading = traj["heading"][agent_mask, rollout_idx, :] + length = agent_length[agent_mask] + width = agent_width[agent_mask] + coll = collisions[agent_mask, rollout_idx, :] + off = offroad[agent_mask, rollout_idx, :] + + num_agents = x.shape[0] + num_steps = x.shape[1] + + collision_agents = [] + offroad_agents = [] + + for i in range(num_agents): + has_collision = np.any(coll[i]) + has_offroad = np.any(off[i]) + if has_collision: + collision_agents.append(i) + if has_offroad: + offroad_agents.append(i) + + for t in range(num_steps): + box = torch.as_tensor( + [[x[i, t], y[i, t], length[i], width[i], heading[i, t]]], + dtype=torch.float32, + ) + corners = get_2d_box_corners(box)[0].cpu().numpy() + + if coll[i, t]: + facecolor = "red" + alpha = 0.6 + edgecolor = "red" + elif off[i, t]: + facecolor = "orange" + alpha = 0.4 + edgecolor = "orange" + else: + facecolor = plt.cm.viridis(t / num_steps) + alpha = 0.3 + edgecolor = facecolor + + polygon = Polygon(corners, facecolor=facecolor, edgecolor=edgecolor, linewidth=0.3, alpha=alpha) + ax.add_patch(polygon) + + return collision_agents, offroad_agents + + +def main(): + parser = argparse.ArgumentParser(description="Visual validation of collision/offroad detection") + parser.add_argument("--env", default="puffer_drive") + parser.add_argument("--output", default="visual_sanity_check.png") + args = parser.parse_args() + + config = load_config(args.env) + config["vec"]["backend"] = "PufferEnv" + config["vec"]["num_envs"] = 1 + config["eval"]["enabled"] = True + config["eval"]["wosac_num_rollouts"] = 1 + + config["env"]["num_agents"] = config["eval"]["wosac_num_agents"] + config["env"]["init_mode"] = config["eval"]["wosac_init_mode"] + config["env"]["control_mode"] = config["eval"]["wosac_control_mode"] + config["env"]["init_steps"] = config["eval"]["wosac_init_steps"] + config["env"]["goal_behavior"] = config["eval"]["wosac_goal_behavior"] + + vecenv = load_env(args.env, config) + policy = load_policy(config, vecenv, args.env) + + evaluator = WOSACEvaluator(config) + gt_traj = evaluator.collect_ground_truth_trajectories(vecenv) + sim_traj = evaluator.collect_simulated_trajectories(config, vecenv, policy) + agent_state = vecenv.driver_env.get_global_agent_state() + road_edge_polylines = 
vecenv.driver_env.get_road_edge_polylines() + + scenario_ids = gt_traj["scenario_id"] + agent_length = agent_state["length"] + agent_width = agent_state["width"] + + # Compute per-timestep indicators + num_agents = sim_traj["x"].shape[0] + eval_mask = np.ones(num_agents, dtype=bool) + + device = torch.device("cpu") + + _, collisions, _ = compute_interaction_features( + sim_traj["x"], + sim_traj["y"], + sim_traj["heading"], + scenario_ids, + agent_length, + agent_width, + eval_mask, + device=device, + ) + _, offroad = compute_map_features( + sim_traj["x"], + sim_traj["y"], + sim_traj["heading"], + scenario_ids, + agent_length, + agent_width, + road_edge_polylines, + device=device, + ) + + # Plot each scenario + unique_scenarios = np.unique(scenario_ids[:, 0]) + + for scenario_idx, target_scenario in enumerate(unique_scenarios): + agent_mask = scenario_ids[:, 0] == target_scenario + + fig, ax = plt.subplots(figsize=(12, 10)) + + plot_road_edges(ax, road_edge_polylines, target_scenario) + collision_agents, offroad_agents = plot_agent_trajectories( + ax, + sim_traj, + agent_mask, + rollout_idx=0, + collisions=collisions, + offroad=offroad, + agent_length=agent_length, + agent_width=agent_width, + ) + + ax.set_aspect("equal") + ax.set_xlabel("X (m)") + ax.set_ylabel("Y (m)") + ax.set_title(f"Scenario {target_scenario} - Collision/Offroad Detection") + + # Legend + from matplotlib.patches import Patch + + legend_elements = [ + Patch(facecolor="red", alpha=0.6, label="Collision"), + Patch(facecolor="orange", alpha=0.4, label="Offroad"), + ] + ax.legend(handles=legend_elements, loc="upper right") + + # Colorbar for time + sm = plt.cm.ScalarMappable(cmap="viridis", norm=plt.Normalize(0, 1)) + sm.set_array([]) + cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04) + cbar.set_label("Time (normalized)") + + # Summary text + num_agents_in_scenario = agent_mask.sum() + summary = f"Agents: {num_agents_in_scenario}\n" + summary += f"Collisions: {len(collision_agents)} agents ({collision_agents})\n" + summary += f"Offroad: {len(offroad_agents)} agents ({offroad_agents})" + ax.text( + 0.02, + 0.98, + summary, + transform=ax.transAxes, + verticalalignment="top", + fontsize=9, + bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5), + ) + + plt.tight_layout() + + # Save with scenario index in filename + output_base = args.output.rsplit(".", 1) + if len(output_base) == 2: + output_path = f"{output_base[0]}_{scenario_idx}.{output_base[1]}" + else: + output_path = f"{args.output}_{scenario_idx}" + + plt.savefig(output_path, dpi=300) + plt.close(fig) + print(f"Scenario {target_scenario}: {output_path}") + print(f" {summary.replace(chr(10), ', ')}") + + +if __name__ == "__main__": + main() diff --git a/pufferlib/ocean/drive/README.md b/pufferlib/ocean/drive/README.md index a37907b20..571001b32 100644 --- a/pufferlib/ocean/drive/README.md +++ b/pufferlib/ocean/drive/README.md @@ -21,7 +21,7 @@ Determines which created agents are **controlled** by the policy. | ----------------------------------------- | ------------------------------------------------------------------------------------------------- | | `control_vehicles` (default) | Control only valid **vehicles** (not experts, beyond `MIN_DISTANCE_TO_GOAL`, under `MAX_AGENTS`). | | `control_agents` | Control all valid **agent types** (vehicles, cyclists, pedestrians). | -| `control_tracks_to_predict` *(WOMD only)* | Control agents listed in the `tracks_to_predict` metadata. 
| +| `control_wosac` *(WOMD only)* | Control all agents whose valid flag is set to `True` at the `init_step`. | ## Termination conditions (`done`)
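For reference, the control modes map to small integer constants in `drive.h`; this diff reuses slot 2 for the renamed WOSAC mode, as the header changes further below show:

```c
// Control-mode constants in drive.h after this change; CONTROL_WOSAC
// takes over the value previously held by CONTROL_TRACKS_TO_PREDICT.
#define CONTROL_VEHICLES 0  // valid vehicles only
#define CONTROL_AGENTS   1  // vehicles, pedestrians, cyclists
#define CONTROL_WOSAC    2  // all agents valid at init_step (WOMD only)
#define CONTROL_SDC_ONLY 3  // only the self-driving-car track
```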
diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index 1061f33b7..d8ebbcdd5 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -70,10 +70,22 @@ static int my_put(Env* env, PyObject* args, PyObject* kwargs) { static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { int num_agents = unpack(kwargs, "num_agents"); int num_maps = unpack(kwargs, "num_maps"); - int init_mode = unpack(kwargs, "init_mode"); int control_mode = unpack(kwargs, "control_mode"); int init_steps = unpack(kwargs, "init_steps"); int goal_behavior = unpack(kwargs, "goal_behavior"); + + // Get configs + char* ini_file = unpack_str(kwargs, "ini_file"); + env_init_config conf = {0}; + if(ini_parse(ini_file, handler, &conf) < 0) { + raise_error_with_message(ERROR_UNKNOWN, "Error while loading %s", ini_file); + } + Init_Mode init_mode = conf.init_mode; + int num_agents_per_world = -1; + if (init_mode == DYNAMIC_AGENTS_PER_ENV) { + num_agents_per_world = conf.num_agents_per_world; + } + clock_gettime(CLOCK_REALTIME, &ts); srand(ts.tv_nsec); int total_agent_count = 0; @@ -86,21 +98,48 @@ static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { while(total_agent_count < num_agents && env_count < max_envs){ char map_file[100]; int map_id = rand() % num_maps; - Drive* env = calloc(1, sizeof(Drive)); - env->init_mode = init_mode; - env->control_mode = control_mode; - env->init_steps = init_steps; - env->goal_behavior = goal_behavior; - sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); - env->entities = load_map_binary(map_file, env); - set_active_agents(env); + if (init_mode == DYNAMIC_AGENTS_PER_ENV) { + // Store map_id + PyObject* map_id_obj = PyLong_FromLong(map_id); + PyList_SetItem(map_ids, env_count, map_id_obj); + // Store agent offset + PyObject* offset = PyLong_FromLong(total_agent_count); + PyList_SetItem(agent_offsets, env_count, offset); + total_agent_count += num_agents_per_world; + env_count++; + } + else { + Drive* env = calloc(1, sizeof(Drive)); + env->init_mode = init_mode; + env->control_mode = control_mode; + env->init_steps = init_steps; + env->goal_behavior = goal_behavior; + sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); + env->entities = load_map_binary(map_file, env); + set_active_agents(env); + + // Skip map if it doesn't contain any controllable agents + if(env->active_agent_count == 0) { + maps_checked++; - // Skip map if it doesn't contain any controllable agents - if(env->active_agent_count == 0) { - maps_checked++; + // Safeguard: if we've checked all available maps and found no active agents, raise an error + if(maps_checked >= num_maps) { + for(int j=0;j<env->num_entities;j++) { + free_entity(&env->entities[j]); + } + free(env->entities); + free(env->active_agent_indices); + free(env->static_agent_indices); + free(env->expert_static_agent_indices); + free(env); + Py_DECREF(agent_offsets); + Py_DECREF(map_ids); + char error_msg[256]; + sprintf(error_msg, "No controllable agents found in any of the %d available maps", num_maps); + PyErr_SetString(PyExc_ValueError, error_msg); + return NULL; + } - // Safeguard: if we've checked all available maps and found no active agents, raise an error - if(maps_checked >= num_maps) { for(int j=0;j<env->num_entities;j++) { free_entity(&env->entities[j]); } @@ -109,14 +148,17 @@ static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { free(env->static_agent_indices); free(env->expert_static_agent_indices); free(env); - Py_DECREF(agent_offsets); - Py_DECREF(map_ids); - char error_msg[256]; - sprintf(error_msg, "No controllable agents found in any of the %d available maps", num_maps); - PyErr_SetString(PyExc_ValueError, error_msg); - return NULL; + continue; } + // Store map_id + PyObject* map_id_obj = PyLong_FromLong(map_id); + PyList_SetItem(map_ids, env_count, map_id_obj); + // Store agent offset + PyObject* offset = PyLong_FromLong(total_agent_count); + PyList_SetItem(agent_offsets, env_count, offset); + total_agent_count += env->active_agent_count; + env_count++; for(int j=0;j<env->num_entities;j++) { free_entity(&env->entities[j]); } @@ -125,25 +167,7 @@ static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { free(env->entities); free(env->active_agent_indices); free(env->static_agent_indices); free(env->expert_static_agent_indices); free(env); - continue; - } - - // Store map_id - PyObject* map_id_obj = PyLong_FromLong(map_id); - PyList_SetItem(map_ids, env_count, map_id_obj); - // Store agent offset - PyObject* offset = PyLong_FromLong(total_agent_count); - PyList_SetItem(agent_offsets, env_count, offset); - total_agent_count += env->active_agent_count; - env_count++; - for(int j=0;j<env->num_entities;j++) { - free_entity(&env->entities[j]); } - free(env->entities); - free(env->active_agent_indices); - free(env->static_agent_indices); - free(env->expert_static_agent_indices); - free(env); } //printf("Generated %d environments to cover %d agents (requested %d agents)\n", env_count, total_agent_count, num_agents); if(total_agent_count >= num_agents){ @@ -197,7 +221,7 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { int init_steps = unpack(kwargs, "init_steps"); char map_file[100]; sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); - env->num_agents = max_agents; + env->max_active_agents = max_agents; env->map_name = strdup(map_file); env->init_steps = init_steps; env->timestep = init_steps; diff --git a/pufferlib/ocean/drive/drive.c b/pufferlib/ocean/drive/drive.c index 4b2645836..e4afe491d 100644 --- a/pufferlib/ocean/drive/drive.c +++ b/pufferlib/ocean/drive/drive.c @@ -19,7 +19,7 @@ void test_drivenet() { //Weights* weights = load_weights("resources/drive/puffer_drive_weights.bin"); Weights* weights = load_weights("puffer_drive_weights.bin"); - DriveNet* net = init_drivenet(weights, num_agents); + DriveNet* net = init_drivenet(weights, num_agents, CLASSIC); forward(net, observations, actions); for (int i = 0; i < num_agents*num_actions; i++) { @@ -43,17 +43,20 @@ void demo() { } Drive env = { - .human_agent_idx = 0, .dynamics_model = conf.dynamics_model, .reward_vehicle_collision = conf.reward_vehicle_collision, .reward_offroad_collision = conf.reward_offroad_collision, .reward_ade = conf.reward_ade, .goal_radius = conf.goal_radius, .dt = conf.dt, - .map_name = "resources/drive/binaries/map_000.bin", .init_steps = conf.init_steps, + .max_controlled_agents = -1, .collision_behavior = conf.collision_behavior, .offroad_behavior = conf.offroad_behavior, + .goal_behavior = conf.goal_behavior, + .init_mode = conf.init_mode, + .control_mode = conf.control_mode, + .ini_file = (char*)ini_file, }; allocate(&env); c_reset(&env); diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 27084cdcf..30c0d658b 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -11,6 +11,7 @@ #include "rlgl.h" #include  #include "error.h"
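The `ini_parse(ini_file, handler, &conf)` calls above rely on an inih-style callback defined in `env_config.h` (included just below). A minimal sketch of how such a handler could map `drive.ini` keys onto `env_init_config`; the `example_handler` name, the `MATCH` macro, and the exact keys handled here are illustrative assumptions, not the actual implementation:

```c
// Hedged sketch of an inih handler for drive.ini. The real handler lives in
// env_config.h; field names come from this diff, everything else is assumed.
#include <stdlib.h>
#include <string.h>

typedef enum { CREATE_ALL_VALID, CREATE_ONLY_CONTROLLED, DYNAMIC_AGENTS_PER_ENV } Init_Mode;

typedef struct {
    Init_Mode init_mode;
    int num_agents_per_world;
    float vehicle_width;
} env_init_config;

#define MATCH(s, n) (strcmp(section, s) == 0 && strcmp(name, n) == 0)

static int example_handler(void* user, const char* section, const char* name, const char* value) {
    env_init_config* conf = (env_init_config*)user;
    if (MATCH("env", "init_mode")) {
        // ini values may be quoted (init_mode = "dynamic_no_agents"), so match loosely
        conf->init_mode = strstr(value, "dynamic_no_agents") ? DYNAMIC_AGENTS_PER_ENV : CREATE_ALL_VALID;
    } else if (MATCH("env", "num_agents_per_world")) {
        conf->num_agents_per_world = atoi(value);
    } else if (MATCH("env", "vehicle_width")) {
        conf->vehicle_width = (float)atof(value);
    }
    return 1; // nonzero tells ini_parse() the pair was handled
}
```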
+#include "../env_config.h" // Entity Types #define NONE 0 @@ -37,7 +38,7 @@ // Control modes #define CONTROL_VEHICLES 0 #define CONTROL_AGENTS 1 -#define CONTROL_TRACKS_TO_PREDICT 2 +#define CONTROL_WOSAC 2 #define CONTROL_SDC_ONLY 3 // Minimum distance to goal position @@ -62,9 +63,10 @@ #define LANE_ALIGNED_IDX 3 #define AVG_DISPLACEMENT_ERROR_IDX 4 -// Grid cell size +// Grid Map Related #define GRID_CELL_SIZE 5.0f -#define MAX_ENTITIES_PER_CELL 30 // Depends on resolution of data Formula: 3 * (2 + GRID_CELL_SIZE*sqrt(2)/resolution) => For each entity type in gridmap, diagonal poly-lines -> sqrt(2), include diagonal ends -> 2 +#define GRID_MAP_CACHE_VISION_RANGE 21 // Vision range to cache neighbor offsets +#define COLLISION_VISION_RANGE 5 // Vision range for collision checking // Max road segment observation entities #define MAX_ROAD_SEGMENT_OBSERVATIONS 200 @@ -177,6 +179,7 @@ struct Entity { float heading_y; int current_lane_idx; int valid; + int initialized; // Only for random init mode int respawn_timestep; int respawn_count; int collided_before_goal; @@ -275,6 +278,7 @@ struct GridMap { int cell_size_y; int* cell_entities_count; // number of entities in each cell of the GridMap GridMapEntity** cells; // list of gridEntities in each cell of the GridMap + int* cell_roadlanes_count; // number of road lanes in each cell // Extras/Optimizations int vision_range; @@ -290,7 +294,7 @@ struct Drive { unsigned char* terminals; Log log; Log* logs; - int num_agents; + int max_active_agents; int active_agent_count; int* active_agent_indices; int action_type; @@ -305,6 +309,13 @@ struct Drive { int* static_agent_indices; int expert_static_agent_count; int* expert_static_agent_indices; + int num_lanes; + int num_road_edges; + int num_road_lines; + int static_car_count; + int* static_car_indices; + int expert_static_car_count; + int* expert_static_car_indices; int timestep; int init_steps; int dynamics_model; @@ -320,7 +331,8 @@ struct Drive { float dt; float reward_goal; float reward_goal_post_respawn; - float goal_radius; + float goal_radius; // Distance threshold to consider goal reached + float goal_distance; // Distance from agent to goal int max_controlled_agents; int logs_capacity; int goal_behavior; @@ -333,6 +345,7 @@ struct Drive { int* tracks_to_predict_indices; int init_mode; int control_mode; + env_init_config conf; }; void add_log(Drive* env) { @@ -383,6 +396,11 @@ struct Graph { struct AdjListNode** array; }; +// Forward declarations +int* get_relative_neighbor_offsets(Drive* env, float neighbor_radius, int* offset_count); +GridMapEntity* checkNeighbors(Drive* env, float x, float y, const int (*local_offsets)[2], int* out_size, int offset_size); +bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2]); + // Function to create a new adjacency list node struct AdjListNode* newAdjListNode(int dest) { struct AdjListNode* newNode = malloc(sizeof(struct AdjListNode)); @@ -435,6 +453,77 @@ void freeTopologyGraph(struct Graph* graph) { free(graph); } +// Checks for collisions at a specified spawn point +bool is_valid_spawn_point(Drive* env, int agent_idx, float start_x, float start_y, float start_z, float heading, float vehicle_width, float vehicle_length, float vehicle_height) { + // Create Mock Entity for collision checking + Entity mock_agent = { + .x = start_x, + .y = start_y, + .z = start_z, + .width = vehicle_width, + .length = vehicle_length, + .height = vehicle_height, + .heading = heading, + .heading_x = cosf(heading), + .heading_y = sinf(heading) + }; + 
+ float half_length = mock_agent.length/2.0f; + float half_width = mock_agent.width/2.0f; + float cos_heading = cosf(mock_agent.heading); + float sin_heading = sinf(mock_agent.heading); + float agent_corners[4][2]; + for (int i = 0; i < 4; i++) { + agent_corners[i][0] = mock_agent.x + (offsets[i][0]*half_length*cos_heading - offsets[i][1]*half_width*sin_heading); + agent_corners[i][1] = mock_agent.y + (offsets[i][0]*half_length*sin_heading + offsets[i][1]*half_width*cos_heading); + } + int collided = 0; + + // Vehicle collision check + for (int i = 0; i < env->num_entities; i++) { + Entity* entity = &env->entities[i]; + if (entity->type != VEHICLE && entity-> type != PEDESTRIAN && entity->type != CYCLIST) continue; // Only check dynamic entities + if (entity->id == agent_idx) continue; // Skip self + + if (entity->type == VEHICLE) { + if (entity->initialized == 0) continue; // Skip uninitialized entities + if(check_aabb_collision(&mock_agent, entity)) { + collided = VEHICLE_COLLISION; + return false; // Collision detected with existing vehicle at spawn point + } + } + } + + // Offroad check + // int list_size = 0; + // GridMapEntity* entity_list = checkNeighbors(env, mock_agent.x, mock_agent.y, collision_offsets, &list_size, COLLISION_VISION_RANGE*COLLISION_VISION_RANGE); + // for (int i = 0; i < list_size ; i++) { + // if(entity_list[i].entity_idx == -1) continue; + // Entity* entity; + // entity = &env->entities[entity_list[i].entity_idx]; + + // // Check for offroad collision with road edges + // if(entity->type == ROAD_EDGE) { + // int geometry_idx = entity_list[i].geometry_idx; + // float start[2] = {entity->traj_x[geometry_idx], entity->traj_y[geometry_idx]}; + // float end[2] = {entity->traj_x[geometry_idx + 1], entity->traj_y[geometry_idx + 1]}; + // for (int k = 0; k < 4; k++) { // Check each edge of the bounding box + // int next = (k + 1) % 4; + // if (check_line_intersection(agent_corners[k], agent_corners[next], start, end)) { + // collided = OFFROAD; + // printf("Invalid Spawn: collision with road edge (Entity ID: %d)\n", entity->id); + // break; + // } + // } + // } + + // if (collided == OFFROAD) break; + // } + + // Cleanup + // free(entity_list); + return true; // No collisions detected +} Entity* load_map_binary(const char* filename, Drive* env) { FILE* file = fopen(filename, "rb"); @@ -456,64 +545,368 @@ Entity* load_map_binary(const char* filename, Drive* env) { env->tracks_to_predict_indices = NULL; } - fread(&env->num_objects, sizeof(int), 1, file); + fread(&env->num_objects, sizeof(int), 1, file); // Can be 0 for dynamic_no_agents init_mode fread(&env->num_roads, sizeof(int), 1, file); - env->num_entities = env->num_objects + env->num_roads; - Entity* entities = (Entity*)malloc(env->num_entities * sizeof(Entity)); - for (int i = 0; i < env->num_entities; i++) { - // Read base entity data - fread(&entities[i].scenario_id, sizeof(int), 1, file); - fread(&entities[i].type, sizeof(int), 1, file); - fread(&entities[i].id, sizeof(int), 1, file); - fread(&entities[i].array_size, sizeof(int), 1, file); - // Allocate arrays based on type - int size = entities[i].array_size; - entities[i].traj_x = (float*)malloc(size * sizeof(float)); - entities[i].traj_y = (float*)malloc(size * sizeof(float)); - entities[i].traj_z = (float*)malloc(size * sizeof(float)); - if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type - // Allocate arrays for object-specific data - entities[i].traj_vx = (float*)malloc(size * sizeof(float)); 
- entities[i].traj_vy = (float*)malloc(size * sizeof(float)); - entities[i].traj_vz = (float*)malloc(size * sizeof(float)); - entities[i].traj_heading = (float*)malloc(size * sizeof(float)); - entities[i].traj_valid = (int*)malloc(size * sizeof(int)); + int num_bin_entities = env->num_objects + env->num_roads; + if (env->conf.init_mode == DYNAMIC_AGENTS_PER_ENV) { + Entity read_bin_entities[num_bin_entities]; + // Read all binaries into temporary array + for (int i = 0; i < num_bin_entities; i++) { + // Read base entity data + fread(&read_bin_entities[i].scenario_id, sizeof(int), 1, file); + fread(&read_bin_entities[i].type, sizeof(int), 1, file); + fread(&read_bin_entities[i].id, sizeof(int), 1, file); + fread(&read_bin_entities[i].array_size, sizeof(int), 1, file); + // Allocate arrays based on type + int size = read_bin_entities[i].array_size; + read_bin_entities[i].traj_x = (float*)malloc(size * sizeof(float)); + read_bin_entities[i].traj_y = (float*)malloc(size * sizeof(float)); + read_bin_entities[i].traj_z = (float*)malloc(size * sizeof(float)); + if (read_bin_entities[i].type == VEHICLE || read_bin_entities[i].type == PEDESTRIAN || read_bin_entities[i].type == CYCLIST) { // Object type + // Allocate arrays for object-specific data + read_bin_entities[i].traj_vx = (float*)malloc(size * sizeof(float)); + read_bin_entities[i].traj_vy = (float*)malloc(size * sizeof(float)); + read_bin_entities[i].traj_vz = (float*)malloc(size * sizeof(float)); + read_bin_entities[i].traj_heading = (float*)malloc(size * sizeof(float)); + read_bin_entities[i].traj_valid = (int*)malloc(size * sizeof(int)); + } else { + // Roads don't use these arrays + read_bin_entities[i].traj_vx = NULL; + read_bin_entities[i].traj_vy = NULL; + read_bin_entities[i].traj_vz = NULL; + read_bin_entities[i].traj_heading = NULL; + read_bin_entities[i].traj_valid = NULL; + } + // Read array data + fread(read_bin_entities[i].traj_x, sizeof(float), size, file); + fread(read_bin_entities[i].traj_y, sizeof(float), size, file); + fread(read_bin_entities[i].traj_z, sizeof(float), size, file); + if (read_bin_entities[i].type == VEHICLE || read_bin_entities[i].type == PEDESTRIAN || read_bin_entities[i].type == CYCLIST) { // Object type + fread(read_bin_entities[i].traj_vx, sizeof(float), size, file); + fread(read_bin_entities[i].traj_vy, sizeof(float), size, file); + fread(read_bin_entities[i].traj_vz, sizeof(float), size, file); + fread(read_bin_entities[i].traj_heading, sizeof(float), size, file); + fread(read_bin_entities[i].traj_valid, sizeof(int), size, file); + } + // Read remaining scalar fields + fread(&read_bin_entities[i].width, sizeof(float), 1, file); + fread(&read_bin_entities[i].length, sizeof(float), 1, file); + fread(&read_bin_entities[i].height, sizeof(float), 1, file); + fread(&read_bin_entities[i].goal_position_x, sizeof(float), 1, file); + fread(&read_bin_entities[i].goal_position_y, sizeof(float), 1, file); + fread(&read_bin_entities[i].goal_position_z, sizeof(float), 1, file); + fread(&read_bin_entities[i].mark_as_expert, sizeof(int), 1, file); + + read_bin_entities[i].initialized = 1; + } + + env->num_objects = env->conf.num_agents_per_world; + env->num_entities = env->num_objects + env->num_roads; + Entity* entities = (Entity*)malloc(env->num_entities * sizeof(Entity)); + + // Only initialize road entities + int env_entity_idx = 0; + for (int i = 0; i < num_bin_entities; i++) { + if (read_bin_entities[i].type == VEHICLE || read_bin_entities[i].type == PEDESTRIAN || read_bin_entities[i].type == CYCLIST) { + 
continue; + } else { + entities[env_entity_idx] = read_bin_entities[i]; + env_entity_idx++; + } + } + + while (env_entity_idx < env->num_entities) { + entities[env_entity_idx].type = VEHICLE; // Placeholder type for non-road entities + entities[env_entity_idx].id = env_entity_idx; + entities[env_entity_idx].initialized = 0; + entities[env_entity_idx].array_size = 0; + entities[env_entity_idx].traj_x = NULL; + entities[env_entity_idx].traj_y = NULL; + entities[env_entity_idx].traj_z = NULL; + entities[env_entity_idx].traj_vx = NULL; + entities[env_entity_idx].traj_vy = NULL; + entities[env_entity_idx].traj_vz = NULL; + entities[env_entity_idx].traj_heading = NULL; + entities[env_entity_idx].traj_valid = NULL; + env_entity_idx++; + } + + fclose(file); + return entities; + } else { + env->num_entities = num_bin_entities; + Entity* entities = (Entity*)malloc(env->num_entities * sizeof(Entity)); + for (int i = 0; i < env->num_entities; i++) { + // Read base entity data + fread(&entities[i].scenario_id, sizeof(int), 1, file); + fread(&entities[i].type, sizeof(int), 1, file); + fread(&entities[i].id, sizeof(int), 1, file); + fread(&entities[i].array_size, sizeof(int), 1, file); + // Allocate arrays based on type + int size = entities[i].array_size; + entities[i].traj_x = (float*)malloc(size * sizeof(float)); + entities[i].traj_y = (float*)malloc(size * sizeof(float)); + entities[i].traj_z = (float*)malloc(size * sizeof(float)); + if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type + // Allocate arrays for object-specific data + entities[i].traj_vx = (float*)malloc(size * sizeof(float)); + entities[i].traj_vy = (float*)malloc(size * sizeof(float)); + entities[i].traj_vz = (float*)malloc(size * sizeof(float)); + entities[i].traj_heading = (float*)malloc(size * sizeof(float)); + entities[i].traj_valid = (int*)malloc(size * sizeof(int)); + } else { + // Roads don't use these arrays + entities[i].traj_vx = NULL; + entities[i].traj_vy = NULL; + entities[i].traj_vz = NULL; + entities[i].traj_heading = NULL; + entities[i].traj_valid = NULL; + } + // Read array data + fread(entities[i].traj_x, sizeof(float), size, file); + fread(entities[i].traj_y, sizeof(float), size, file); + fread(entities[i].traj_z, sizeof(float), size, file); + if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type + fread(entities[i].traj_vx, sizeof(float), size, file); + fread(entities[i].traj_vy, sizeof(float), size, file); + fread(entities[i].traj_vz, sizeof(float), size, file); + fread(entities[i].traj_heading, sizeof(float), size, file); + fread(entities[i].traj_valid, sizeof(int), size, file); + } + // Read remaining scalar fields + fread(&entities[i].width, sizeof(float), 1, file); + fread(&entities[i].length, sizeof(float), 1, file); + fread(&entities[i].height, sizeof(float), 1, file); + fread(&entities[i].goal_position_x, sizeof(float), 1, file); + fread(&entities[i].goal_position_y, sizeof(float), 1, file); + fread(&entities[i].goal_position_z, sizeof(float), 1, file); + fread(&entities[i].mark_as_expert, sizeof(int), 1, file); + + entities[i].initialized = 1; + } + fclose(file); + return entities; + } +} + +bool get_valid_goal(Drive* env, float start_point_x, float start_point_y, float start_point_z, + int start_grid_row, int start_grid_col, + int* rel_neighbor_offsets, int rel_neighbor_offsets_cnt, + float* goal_x, float* goal_y, float* goal_z) { + float goal_radius = env->goal_radius; + + int 
grid_cell_offset_x = (int)(goal_radius / env->grid_map->cell_size_x) + 1; + int grid_cell_offset_y = (int)(goal_radius / env->grid_map->cell_size_y) + 1; + + int valid_goals_count = 0; + int i = 0; + while (i < rel_neighbor_offsets_cnt) { + int offset_row = rel_neighbor_offsets[i++]; + int offset_col = rel_neighbor_offsets[i++]; + int neighbor_grid_row = start_grid_row + offset_row; + int neighbor_grid_col = start_grid_col + offset_col; + // Check grid bounds + if (neighbor_grid_row < 0 || neighbor_grid_row >= env->grid_map->grid_rows || + neighbor_grid_col < 0 || neighbor_grid_col >= env->grid_map->grid_cols) { + continue; + } + + valid_goals_count += env->grid_map->cell_roadlanes_count[neighbor_grid_row * env->grid_map->grid_cols + neighbor_grid_col]; + } + if (valid_goals_count == 0) return false; + + int rand_goal_idx = rand() % valid_goals_count; + i = 0; + while (1) { + int offset_row = rel_neighbor_offsets[i++]; + int offset_col = rel_neighbor_offsets[i++]; + int neighbor_grid_row = start_grid_row + offset_row; + int neighbor_grid_col = start_grid_col + offset_col; + // Check grid bounds + if (neighbor_grid_row < 0 || neighbor_grid_row >= env->grid_map->grid_rows || + neighbor_grid_col < 0 || neighbor_grid_col >= env->grid_map->grid_cols) { + continue; + } + + int current_grid_cell_idx = neighbor_grid_row * env->grid_map->grid_cols + neighbor_grid_col; + int current_grid_cell_roadlane_count = env->grid_map->cell_roadlanes_count[current_grid_cell_idx]; + if (rand_goal_idx - current_grid_cell_roadlane_count < 0) { + // This is the grid cell containing the goal + // Find the specific road lane within this cell + int lane_idx = rand_goal_idx; + int cnt = 0; + for (int j = 0; j < env->grid_map->cell_entities_count[current_grid_cell_idx]; j++) { + int entity_idx = env->grid_map->cells[current_grid_cell_idx][j].entity_idx; + Entity entity = env->entities[entity_idx]; + if (entity.type != ROAD_LANE) { + continue; + } + if (cnt == lane_idx) { + int goal_point_geometry_idx = env->grid_map->cells[current_grid_cell_idx][j].geometry_idx; + *goal_x = entity.traj_x[goal_point_geometry_idx]; + *goal_y = entity.traj_y[goal_point_geometry_idx]; + *goal_z = entity.traj_z[goal_point_geometry_idx]; + return true; + } + cnt++; + } } else { - // Roads don't use these arrays - entities[i].traj_vx = NULL; - entities[i].traj_vy = NULL; - entities[i].traj_vz = NULL; - entities[i].traj_heading = NULL; - entities[i].traj_valid = NULL; - } - // Read array data - fread(entities[i].traj_x, sizeof(float), size, file); - fread(entities[i].traj_y, sizeof(float), size, file); - fread(entities[i].traj_z, sizeof(float), size, file); - if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type - fread(entities[i].traj_vx, sizeof(float), size, file); - fread(entities[i].traj_vy, sizeof(float), size, file); - fread(entities[i].traj_vz, sizeof(float), size, file); - fread(entities[i].traj_heading, sizeof(float), size, file); - fread(entities[i].traj_valid, sizeof(int), size, file); - } - // Read remaining scalar fields - fread(&entities[i].width, sizeof(float), 1, file); - fread(&entities[i].length, sizeof(float), 1, file); - fread(&entities[i].height, sizeof(float), 1, file); - fread(&entities[i].goal_position_x, sizeof(float), 1, file); - fread(&entities[i].goal_position_y, sizeof(float), 1, file); - fread(&entities[i].goal_position_z, sizeof(float), 1, file); - fread(&entities[i].mark_as_expert, sizeof(int), 1, file); - } - - fclose(file); - return entities; + 
rand_goal_idx -= current_grid_cell_roadlane_count; + } + } + return false; +} + +void init_agents_random_start(Drive* env) { + Entity* entities = env->entities; + int rel_neighbor_offsets_cnt = 0; + int* rel_neighbor_offsets = get_relative_neighbor_offsets(env, env->conf.goal_distance, &rel_neighbor_offsets_cnt); + + for (int agent_idx=0; agent_idxnum_entities; agent_idx++){ + if (entities[agent_idx].type != VEHICLE) { + continue; + } + + Entity agent = entities[agent_idx]; + while(1){ + // Random Road lane of a random grid cell + int rand_grid_col = rand() % env->grid_map->grid_cols; + int rand_grid_row = rand() % env->grid_map->grid_rows; + int grid_map_idx = rand_grid_row * env->grid_map->grid_cols + rand_grid_col; + // If cell has no lanes, repeat + if (env->grid_map->cell_roadlanes_count[grid_map_idx] == 0){ + continue; + } + int cell_lane_count = env->grid_map->cell_roadlanes_count[grid_map_idx]; + int lane_idx = rand() % cell_lane_count; + // float heading = ((float)rand() / RAND_MAX) * 2.0f * PI; // Random heading + float heading = 0.0f; + + // Get start lane point info + float start_x = INVALID_POSITION, start_y = INVALID_POSITION, start_z = INVALID_POSITION, start_heading = INVALID_POSITION; + int cnt = 0; + for (int i = 0; i < env->grid_map->cell_entities_count[grid_map_idx]; i++) { + int entity_idx = env->grid_map->cells[grid_map_idx][i].entity_idx; + Entity entity = env->entities[entity_idx]; + if (entity.type != ROAD_LANE) { + continue; + } + // Current spawn point + if (cnt == lane_idx) { + int start_point_geometry_idx = env->grid_map->cells[grid_map_idx][i].geometry_idx; + if (start_point_geometry_idx + 1 < entity.array_size) { + float dx = entity.traj_x[start_point_geometry_idx + 1] - entity.traj_x[start_point_geometry_idx]; + float dy = entity.traj_y[start_point_geometry_idx + 1] - entity.traj_y[start_point_geometry_idx]; + heading = atan2f(dy, dx); + } + else if (start_point_geometry_idx - 1 >= 0) { + float dx = entity.traj_x[start_point_geometry_idx] - entity.traj_x[start_point_geometry_idx - 1]; + float dy = entity.traj_y[start_point_geometry_idx] - entity.traj_y[start_point_geometry_idx - 1]; + heading = atan2f(dy, dx); + } + else { + heading = 0.0f; // Design choice: default to 0 if no previous lane + } + // printf("Selected Point (%f, %f) at Grid Cell (%d, %d) with heading %f\n", entity.traj_x[start_point_geometry_idx], entity.traj_y[start_point_geometry_idx], rand_grid_row, rand_grid_col, heading); + bool valid_spawn = is_valid_spawn_point( + env, + agent.id, + entity.traj_x[start_point_geometry_idx], + entity.traj_y[start_point_geometry_idx], + entity.traj_z[start_point_geometry_idx], + heading, + env->conf.vehicle_width, env->conf.vehicle_length, env->conf.vehicle_height + ); + if (valid_spawn) { + start_x = entity.traj_x[start_point_geometry_idx]; + start_y = entity.traj_y[start_point_geometry_idx]; + start_z = entity.traj_z[start_point_geometry_idx]; + start_heading = heading; + break; + } + + break; // Break anyways as current spawn pt is invalid + } + cnt++; + } + + // Invalid start point, repeat + if (start_x == INVALID_POSITION || start_y == INVALID_POSITION || start_z == INVALID_POSITION) { + // printf("Invalid start point(%.2f, %.2f, %.2f) at Grid Cell (%d, %d), retrying...\n", start_x, start_y, start_z, rand_grid_row, rand_grid_col); + continue; + } + + // Get goal point info + float goal_x = INVALID_POSITION, goal_y = INVALID_POSITION, goal_z = INVALID_POSITION; + bool is_valid_goal = get_valid_goal( + env, + start_x, start_y, start_z, + rand_grid_row, 
rand_grid_col, + rel_neighbor_offsets, rel_neighbor_offsets_cnt, + &goal_x, &goal_y, &goal_z + ); + // If no surrounding cell with goal curriculum, repeat + if (!is_valid_goal) { + // printf("No valid goals found around start point (%.2f, %.2f, %.2f)\n", start_x, start_y, start_z); + continue; + } + // Else break with given start and goal + agent.initialized = 1; + agent.type = VEHICLE; + agent.array_size = 1; + agent.traj_x = (float*)calloc(agent.array_size, sizeof(float)); + agent.traj_y = (float*)calloc(agent.array_size, sizeof(float)); + agent.traj_z = (float*)calloc(agent.array_size, sizeof(float)); + agent.traj_vx = (float*)calloc(agent.array_size, sizeof(float)); + agent.traj_vy = (float*)calloc(agent.array_size, sizeof(float)); + agent.traj_vz = (float*)calloc(agent.array_size, sizeof(float)); + agent.traj_heading = (float*)calloc(agent.array_size, sizeof(float)); + agent.traj_valid = (int*)calloc(agent.array_size, sizeof(int)); + agent.traj_x[0] = start_x; + agent.traj_y[0] = start_y; + agent.traj_z[0] = start_z; + agent.x = start_x; + agent.y = start_y; + agent.z = start_z; + agent.traj_heading[0] = start_heading; + agent.heading = start_heading; + agent.heading_x = cosf(agent.heading); + agent.heading_y = sinf(agent.heading); + agent.traj_valid[0] = 1; + agent.height = env->conf.vehicle_height; + agent.width = env->conf.vehicle_width; + agent.length = env->conf.vehicle_length; + agent.goal_position_x = goal_x; + agent.goal_position_y = goal_y; + agent.goal_position_z = goal_z; + + agent.mark_as_expert = 0; + agent.goal_radius = env->goal_radius; + break; + } + entities[agent_idx] = agent; + // printf("Initialized Agent %d at random start position (%.2f, %.2f, %.2f) with goal (%.2f, %.2f, %.2f)\n", + // agent.id, + // agent.traj_x[0], agent.traj_y[0], agent.traj_z[0], + // agent.goal_position_x, agent.goal_position_y, agent.goal_position_z + // ); + // printf("Agent dimensions (W x L x H): %.2f x %.2f x %.2f\n", + // agent.width, agent.length, agent.height + // ); + } + + free(rel_neighbor_offsets); } void set_start_position(Drive* env){ //InitWindow(800, 600, "GPU Drive"); //BeginDrawing(); + + // Must start at step 0 if using DYNAMIC_AGENTS_PER_ENV init_mode + if (env->conf.init_mode == DYNAMIC_AGENTS_PER_ENV && env->init_steps != 0) { + raise_error_with_message(ERROR_INVALID_CONFIG, "init_steps must be 0 when using DYNAMIC_AGENTS_PER_ENV init_mode, %d given\n", env->init_steps); + } + for(int i = 0; i < env->num_entities; i++){ int is_active = 0; for(int j = 0; j < env->active_agent_count; j++){ @@ -581,8 +974,8 @@ int getGridIndex(Drive* env, float x1, float y1) { float relativeX = x1 - env->grid_map->top_left_x; // Distance from left float relativeY = y1 - env->grid_map->bottom_right_y; // Distance from bottom - int gridX = (int)(relativeX / GRID_CELL_SIZE); // Column index - int gridY = (int)(relativeY / GRID_CELL_SIZE); // Row index + int gridX = (int)(relativeX / env->grid_map->cell_size_x); // Column index + int gridY = (int)(relativeY / env->grid_map->cell_size_y); // Row index if (gridX < 0 || gridX >= env->grid_map->grid_cols || gridY < 0 || gridY >= env->grid_map->grid_rows) { return -1; // Return -1 for out of bounds } @@ -597,8 +990,12 @@ void add_entity_to_grid(Drive* env, int grid_index, int entity_idx, int geometry int count = cell_entities_insert_index[grid_index]; if(count >= env->grid_map->cell_entities_count[grid_index]) { - printf("Error: Exceeded precomputed entity count for grid cell %d. 
Current count: %d, Max count(Precomputed): %d\n", grid_index, count, env->grid_map->cell_entities_count[grid_index]); - return; + raise_error_with_message(ERROR_UNKNOWN, + "Exceeded precomputed entity count for grid cell %d. Current count: %d, Max count(Precomputed): %d\n", + grid_index, + count, + env->grid_map->cell_entities_count[grid_index] + ); } env->grid_map->cells[grid_index][count].entity_idx = entity_idx; @@ -670,7 +1067,6 @@ void init_topology_graph(Drive* env){ } void init_grid_map(Drive* env){ - // Allocate memory for the grid map structure env->grid_map = (GridMap*)malloc(sizeof(GridMap)); // Find top left and bottom right points of the map @@ -715,6 +1111,7 @@ void init_grid_map(Drive* env){ int grid_cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows; env->grid_map->cells = (GridMapEntity**)calloc(grid_cell_count, sizeof(GridMapEntity*)); env->grid_map->cell_entities_count = (int*)calloc(grid_cell_count, sizeof(int)); + env->grid_map->cell_roadlanes_count = (int*)calloc(grid_cell_count, sizeof(int)); // Calculate number of entities in each grid cell for(int i = 0; i < env->num_entities; i++){ for(int j = 0; j < env->entities[i].array_size - 1; j++){ float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j+1]) / 2; float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j+1]) / 2; int grid_index = getGridIndex(env, x_center, y_center); + if (grid_index == -1) { + printf("Warning: point (%f, %f) is out of grid bounds; this should not happen\n", x_center, y_center); + continue; // Out of bounds + } env->grid_map->cell_entities_count[grid_index]++; } } @@ -732,6 +1133,8 @@ void init_grid_map(Drive* env){ // Initialize grid cells for(int grid_index = 0; grid_index < grid_cell_count; grid_index++){ + if (env->grid_map->cell_entities_count[grid_index] > 100) + printf("Warning: High entity count in Grid Cell %d = %d\n", grid_index, env->grid_map->cell_entities_count[grid_index]); env->grid_map->cells[grid_index] = (GridMapEntity*)calloc(env->grid_map->cell_entities_count[grid_index], sizeof(GridMapEntity)); } for(int i = 0; i < env->num_entities; i++){ for(int j = 0; j < env->entities[i].array_size - 1; j++){ float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j+1]) / 2; float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j+1]) / 2; int grid_index = getGridIndex(env, x_center, y_center); add_entity_to_grid(env, grid_index, i, j, cell_entities_insert_index); + if (env->entities[i].type == ROAD_LANE) { + // Also add to road lane count + env->grid_map->cell_roadlanes_count[grid_index]++; + } } } }
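Both loops above funnel through `getGridIndex`, whose math appears earlier in this diff: the point's offset from the grid's top-left/bottom anchors is divided by the per-axis cell size. A standalone sketch of that mapping for reference; parameter names mirror the GridMap fields, and the row-major return follows the `row * grid_cols + col` indexing used throughout this file:

```c
// Sketch of the world-to-cell mapping used by init_grid_map via getGridIndex.
// Out-of-bounds points return -1, matching the warning branch added above.
static int grid_index_sketch(float x, float y,
                             float top_left_x, float bottom_right_y,
                             float cell_size_x, float cell_size_y,
                             int grid_cols, int grid_rows) {
    float rel_x = x - top_left_x;      // distance from the left edge
    float rel_y = y - bottom_right_y;  // distance from the bottom edge
    int col = (int)(rel_x / cell_size_x);
    int row = (int)(rel_y / cell_size_y);
    if (col < 0 || col >= grid_cols || row < 0 || row >= grid_rows) {
        return -1;
    }
    return row * grid_cols + col;
}
```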
@@ -797,6 +1204,93 @@ void init_neighbor_offsets(Drive* env) { } } +// Function to get neighbor grid offsets within a given radius +int* get_relative_neighbor_offsets(Drive* env, float neighbor_radius, int* offset_count) { + + int rel_col_offset = (int)(neighbor_radius / env->grid_map->cell_size_x) + 1; + int rel_row_offset = (int)(neighbor_radius / env->grid_map->cell_size_y) + 1; + + int number_cols = 2 * rel_col_offset + 1; + int number_rows = 2 * rel_row_offset + 1; + int rel_offsets[number_rows][number_cols]; + memset(rel_offsets, 0, number_rows * number_cols * sizeof(int)); + + // Origin is center of center cell + float center_cell_top_left[2] = {-env->grid_map->cell_size_x/2.0, env->grid_map->cell_size_y/2.0}; // Top-Left + + // (0, 0) is the top-left offset + int center_cell_idx_row = rel_row_offset; + int center_cell_idx_col = rel_col_offset; + + int offset_cnt = 0; + + // Only calculate for first quadrant and mirror + // Calculating for Top-Left quadrant + for (int i = 0; i < center_cell_idx_row; i++) { + for (int j = 0; j < center_cell_idx_col; j++) { + float cell_center_x = -(center_cell_idx_col - j) * env->grid_map->cell_size_x; + float cell_center_y = (center_cell_idx_row - i) * env->grid_map->cell_size_y; + + float bottom_right_cell_corner[2] = { + cell_center_x + env->grid_map->cell_size_x/2.0, + cell_center_y - env->grid_map->cell_size_y/2.0 + }; + + int within_radius = 0; + // Checking the bottom-right corner is sufficient: if its distance is below the radius, at least one point of the cell is within range + float min_dist = sqrtf(powf(bottom_right_cell_corner[0] - center_cell_top_left[0], 2) + powf(bottom_right_cell_corner[1] - center_cell_top_left[1], 2)); + if (neighbor_radius > min_dist) { + within_radius = 1; + } + + // Mirror offsets to other quadrants + if (within_radius) { + offset_cnt+=4; + rel_offsets[i][j] = 1; // Top-Left + // Set offsets for remaining quadrants + rel_offsets[i][j + 2*(center_cell_idx_col - j)] = 1; // Top-Right + rel_offsets[i + 2*(center_cell_idx_row - i)][j + 2*(center_cell_idx_col - j)] = 1; // Bottom-Right + rel_offsets[i + 2*(center_cell_idx_row - i)][j] = 1; // Bottom-Left + } + } + } + + // Check Vertical axis + for (int i = 0; i < center_cell_idx_row; i++) { + // Within radius + if (env->grid_map->cell_size_y * (center_cell_idx_row - i - 1) < neighbor_radius) { + offset_cnt += 2; + rel_offsets[i][center_cell_idx_col] = 1; // Top + rel_offsets[i + 2*(center_cell_idx_row - i)][center_cell_idx_col] = 1; // Bottom + } + } + + // Check Horizontal axis + for (int j = 0; j < center_cell_idx_col; j++) { + // Within radius + if (env->grid_map->cell_size_x * (center_cell_idx_col - j - 1) < neighbor_radius) { + offset_cnt += 2; + rel_offsets[center_cell_idx_row][j] = 1; // Left + rel_offsets[center_cell_idx_row][j + 2*(center_cell_idx_col - j)] = 1; // Right + } + } + + int* final_offsets = (int*)calloc(offset_cnt * 2, sizeof(int)); + int offset_idx = 0; + for (int i = 0; i < number_rows; i++) { + for (int j = 0; j < number_cols; j++) { + if (rel_offsets[i][j] == 1) { + final_offsets[offset_idx++] = i - center_cell_idx_row; // row offset + final_offsets[offset_idx++] = j - center_cell_idx_col; // col offset + } + } + } + *(offset_count) = offset_cnt * 2; + return final_offsets; +} + void cache_neighbor_offsets(Drive* env){ int count = 0; int cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows; @@ -854,19 +1348,22 @@ void cache_neighbor_offsets(Drive* env){ } } -int get_neighbor_cache_entities(Drive* env, int cell_idx, GridMapEntity* entities, int max_entities) { +GridMapEntity* get_neighbor_cache_entities(Drive* env, int cell_idx, int* neighbor_entities_cnt, int max_entities) { GridMap* grid_map = env->grid_map; if (cell_idx < 0 || cell_idx >= (grid_map->grid_cols * grid_map->grid_rows)) { return 0; // Invalid cell index } int count = grid_map->neighbor_cache_count[cell_idx]; - // Limit to available space - if (count > max_entities) { + // Limit to available size + if (count > max_entities){ count = max_entities; } + + GridMapEntity* entities = (GridMapEntity*)calloc(count, sizeof(GridMapEntity)); memcpy(entities, grid_map->neighbor_cache_entities[cell_idx], count * sizeof(GridMapEntity)); - return count; + *neighbor_entities_cnt = count; + return entities; } void set_means(Drive* env) { @@ -897,6 +1394,7 @@ void set_means(Drive* env) { env->world_mean_y = mean_y; for (int i = 0; i < env->num_entities; i++) { if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST || env->entities[i].type >= 4) { + if (env->entities[i].initialized == 0) continue; for (int j = 0; j < env->entities[i].array_size; j++) { 
if(env->entities[i].traj_x[j] == INVALID_POSITION) continue; env->entities[i].traj_x[j] -= mean_x; @@ -967,14 +1465,32 @@ bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2]) return (s >= 0 && s <= 1 && t >= 0 && t <= 1); } -int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int max_size, const int (*local_offsets)[2], int offset_size) { +GridMapEntity* checkNeighbors(Drive* env, float x, float y, const int (*local_offsets)[2], int* out_size, int offset_size) { // Get the grid index for the given position (x, y) int index = getGridIndex(env, x, y); - if (index == -1) return 0; // Return 0 size if position invalid + if (index == -1) { + *out_size = 0; + return NULL; // Return NULL if position invalid + } + // Calculate 2D grid coordinates int cellsX = env->grid_map->grid_cols; int gridX = index % cellsX; int gridY = index / cellsX; + // Calculate size of neighbor entities list + for (int i = 0; i < offset_size; i++) { + int nx = gridX + local_offsets[i][0]; + int ny = gridY + local_offsets[i][1]; + // Ensure the neighbor is within grid bounds + if(nx < 0 || nx >= env->grid_map->grid_cols || ny < 0 || ny >= env->grid_map->grid_rows) continue; + int neighborIndex = ny * env->grid_map->grid_cols + nx; + *out_size = (*out_size) + env->grid_map->cell_entities_count[neighborIndex]; + } + if ((*out_size) == 0) { + return NULL; // No neighboring entities found + } + + GridMapEntity* entity_list = (GridMapEntity*)calloc((*out_size), sizeof(GridMapEntity)); int entity_list_count = 0; // Fill the provided array for (int i = 0; i < offset_size; i++) { @@ -985,7 +1501,7 @@ int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int int neighborIndex = ny * env->grid_map->grid_cols + nx; int count = env->grid_map->cell_entities_count[neighborIndex]; // Add entities from this cell to the list - for (int j = 0; j < count && entity_list_count < max_size; j++) { + for (int j = 0; j < count; j++) { int entityId = env->grid_map->cells[neighborIndex][j].entity_idx; int geometry_idx = env->grid_map->cells[neighborIndex][j].geometry_idx; entity_list[entity_list_count].entity_idx = entityId; @@ -993,7 +1509,7 @@ int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int entity_list_count += 1; } } - return entity_list_count; + return entity_list; } int check_aabb_collision(Entity* car1, Entity* car2) { @@ -1025,6 +1541,7 @@ int check_aabb_collision(Entity* car1, Entity* car2) { {car2->x + (-half_len2 * cos2 + half_width2 * sin2), car2->y + (-half_len2 * sin2 - half_width2 * cos2)} }; + // Get the axes to check (normalized vectors perpendicular to each edge) float axes[4][2] = { {cos1, sin1}, // Car1's length axis @@ -1203,8 +1720,8 @@ void compute_agent_metrics(Drive* env, int agent_idx) { corners[i][1] = agent->y + (offsets[i][0]*half_length*sin_heading + offsets[i][1]*half_width*cos_heading); } - GridMapEntity entity_list[MAX_ENTITIES_PER_CELL*25]; // Array big enough for all neighboring cells - int list_size = checkNeighbors(env, agent->x, agent->y, entity_list, MAX_ENTITIES_PER_CELL*25, collision_offsets, 25); + int list_size = 0; + GridMapEntity* entity_list = checkNeighbors(env, agent->x, agent->y, collision_offsets, &list_size, COLLISION_VISION_RANGE*COLLISION_VISION_RANGE); for (int i = 0; i < list_size ; i++) { if(entity_list[i].entity_idx == -1) continue; if(entity_list[i].entity_idx == agent_idx) continue; @@ -1270,45 +1787,48 @@ void compute_agent_metrics(Drive* env, int agent_idx) { agent->collision_state 
= collided; + // Cleanup + free(entity_list); return; } bool should_control_agent(Drive* env, int agent_idx){ // Check if we have room for more agents or are already at capacity - if (env->active_agent_count >= env->num_agents) { + if (env->active_agent_count >= env->max_active_agents) { return false; } Entity* entity = &env->entities[agent_idx]; - // Shrink agent size for collision checking - entity->width *= 0.7f; // TODO: Move this somewhere else + // TODO: Move this elsewhere or remove + entity->width *= 0.7f; entity->length *= 0.7f; if (env->control_mode == CONTROL_SDC_ONLY) { - return (agent_idx == env->sdc_track_index); + return agent_idx == env->sdc_track_index; } - // Special mode: control only agents in prediction track list - if (env->control_mode == CONTROL_TRACKS_TO_PREDICT) { - for (int j = 0; j < env->num_tracks_to_predict; j++) { - if (env->tracks_to_predict_indices[j] == agent_idx) { - return true; - } - } - return false; - } + bool is_vehicle = (entity->type == VEHICLE); + bool is_ped_or_bike = (entity->type == PEDESTRIAN || entity->type == CYCLIST); + bool type_is_valid = false; + + switch (env->control_mode) { + case CONTROL_WOSAC: + // Valid types only, ignore expert flag and goal distance + return (is_vehicle || is_ped_or_bike); + + case CONTROL_VEHICLES: + type_is_valid = is_vehicle; + break; - // Standard mode: check type, distance to goal, and expert status - bool type_is_controllable = false; - if (env->control_mode == CONTROL_VEHICLES) { - type_is_controllable = (entity->type == VEHICLE); - } else { // CONTROL_AGENTS mode - type_is_controllable = (entity->type == VEHICLE || entity->type == PEDESTRIAN || entity->type == CYCLIST); + default: + type_is_valid = (is_vehicle || is_ped_or_bike); + break; } - if (!type_is_controllable || entity->mark_as_expert) { + // Filter invalid types or experts + if (!type_is_valid || entity->mark_as_expert) { return false; } @@ -1338,13 +1858,16 @@ void set_active_agents(Drive* env){ int static_agent_indices[MAX_AGENTS]; int expert_static_agent_indices[MAX_AGENTS]; - if(env->num_agents == 0){ - env->num_agents = MAX_AGENTS; + if(env->max_active_agents == 0){ + env->max_active_agents = MAX_AGENTS; } // Iterate through entities to find agents to create and/or control - for(int i = 0; i < env->num_objects && env->num_actors < MAX_AGENTS; i++){ - + for(int i = 0; i < env->num_entities && env->num_actors < env->max_active_agents; i++){ + // Skip non-agent entities + if (env->entities[i].type != VEHICLE && env->entities[i].type != PEDESTRIAN && env->entities[i].type != CYCLIST) { + continue; + } Entity* entity = &env->entities[i]; // Skip if not valid at initialization @@ -1379,7 +1902,7 @@ void set_active_agents(Drive* env){ static_agent_indices[env->static_agent_count] = i; env->static_agent_count++; env->entities[i].active_agent = 0; - if(env->entities[i].mark_as_expert == 1 || env->active_agent_count == env->num_agents) { + if(env->entities[i].mark_as_expert == 1 || env->active_agent_count == env->max_active_agents) { expert_static_agent_indices[env->expert_static_agent_count] = i; env->expert_static_agent_count++; env->entities[i].mark_as_expert = 1; @@ -1407,7 +1930,7 @@ void set_active_agents(Drive* env){ void remove_bad_trajectories(Drive* env){ - if (env->control_mode != CONTROL_TRACKS_TO_PREDICT) { + if (env->control_mode != CONTROL_WOSAC) { return; // Leave all trajectories in WOSAC control mode } @@ -1462,20 +1985,51 @@ void init_goal_positions(Drive* env){ } } +void seed_prng_once(void) { + static int seeded = 0; + if 
(seeded) return; + + const char* env_seed = getenv("PUFFER_SEED"); + if (env_seed && *env_seed) { + unsigned s = (unsigned)strtoul(env_seed, NULL, 10); + srand(s); + seeded = 1; + return; + } + + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + unsigned seed = (unsigned)(ts.tv_nsec ^ (ts.tv_sec * 1000003u) ^ getpid() ^ (uintptr_t)&seeded); + srand(seed); + seeded = 1; +} + void init(Drive* env){ + seed_prng_once(); // unique seed for each run(for reproducibility comment this line) env->human_agent_idx = 0; env->timestep = 0; + env_init_config conf = {0}; + if(ini_parse(env->ini_file, handler, &conf) < 0) { + printf("Error while loading config file %s", env->ini_file); + } + env->conf = conf; env->entities = load_map_binary(env->map_name, env); set_means(env); init_grid_map(env); + if (env->conf.init_mode == DYNAMIC_AGENTS_PER_ENV) { + env->goal_distance = env->conf.goal_distance; + init_agents_random_start(env); + } if (env->goal_behavior==GOAL_GENERATE_NEW) init_topology_graph(env); - env->grid_map->vision_range = 21; + env->grid_map->vision_range = GRID_MAP_CACHE_VISION_RANGE; init_neighbor_offsets(env); cache_neighbor_offsets(env); env->logs_capacity = 0; set_active_agents(env); env->logs_capacity = env->active_agent_count; - remove_bad_trajectories(env); + if (conf.init_mode != DYNAMIC_AGENTS_PER_ENV) { + remove_bad_trajectories(env); + } set_start_position(env); init_goal_positions(env); env->logs = (Log*)calloc(env->active_agent_count, sizeof(Log)); @@ -1495,6 +2049,7 @@ void c_close(Drive* env){ } free(env->grid_map->cells); free(env->grid_map->cell_entities_count); + free(env->grid_map->cell_roadlanes_count); free(env->neighbor_offsets); for(int i = 0; i < grid_cell_count; i++){ @@ -1716,7 +2271,19 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ return; } -void c_get_global_agent_state(Drive* env, float* x_out, float* y_out, float* z_out, float* heading_out, int* id_out) { +static inline int get_track_id_or_placeholder(Drive* env, int agent_idx) { + if (env->tracks_to_predict_indices == NULL || env->num_tracks_to_predict == 0) { + return -1; + } + for (int k = 0; k < env->num_tracks_to_predict; k++) { + if (env->tracks_to_predict_indices[k] == agent_idx) { + return env->tracks_to_predict_indices[k]; + } + } + return -1; +} + +void c_get_global_agent_state(Drive* env, float* x_out, float* y_out, float* z_out, float* heading_out, int* id_out, float* length_out, float* width_out) { for(int i = 0; i < env->active_agent_count; i++){ int agent_idx = env->active_agent_indices[i]; Entity* agent = &env->entities[agent_idx]; @@ -1726,7 +2293,9 @@ void c_get_global_agent_state(Drive* env, float* x_out, float* y_out, float* z_o y_out[i] = agent->y + env->world_mean_y; z_out[i] = agent->z; heading_out[i] = agent->heading; - id_out[i] = env->tracks_to_predict_indices[i]; + id_out[i] = get_track_id_or_placeholder(env, agent_idx); + length_out[i] = agent->length; + width_out[i] = agent->width; } } @@ -1734,7 +2303,7 @@ void c_get_global_ground_truth_trajectories(Drive* env, float* x_out, float* y_o for(int i = 0; i < env->active_agent_count; i++){ int agent_idx = env->active_agent_indices[i]; Entity* agent = &env->entities[agent_idx]; - id_out[i] = env->tracks_to_predict_indices[i]; + id_out[i] = get_track_id_or_placeholder(env, agent_idx); scenario_id_out[i] = agent->scenario_id; for(int t = env->init_steps; t < agent->array_size; t++){ @@ -1749,6 +2318,35 @@ void c_get_global_ground_truth_trajectories(Drive* env, float* x_out, float* y_o } } +void 
c_get_road_edge_counts(Drive* env, int* num_polylines_out, int* total_points_out) { + int count = 0, points = 0; + for(int i = env->num_objects; i < env->num_entities; i++) { + if(env->entities[i].type == ROAD_EDGE) { + count++; + points += env->entities[i].array_size; + } + } + *num_polylines_out = count; + *total_points_out = points; +} + +void c_get_road_edge_polylines(Drive* env, float* x_out, float* y_out, int* lengths_out, int* scenario_ids_out) { + int poly_idx = 0, pt_idx = 0; + for(int i = env->num_objects; i < env->num_entities; i++) { + Entity* e = &env->entities[i]; + if(e->type == ROAD_EDGE) { + lengths_out[poly_idx] = e->array_size; + scenario_ids_out[poly_idx] = e->scenario_id; + for(int j = 0; j < e->array_size; j++) { + x_out[pt_idx] = e->traj_x[j] + env->world_mean_x; + y_out[pt_idx] = e->traj_y[j] + env->world_mean_y; + pt_idx++; + } + poly_idx++; + } + } +} + void compute_observations(Drive* env) { int ego_dim = (env->dynamics_model == JERK) ? 10 : 7; int max_obs = ego_dim + 7*(MAX_AGENTS - 1) + 7*MAX_ROAD_SEGMENT_OBSERVATIONS; @@ -1837,10 +2435,10 @@ void compute_observations(Drive* env) { memset(&obs[obs_idx], 0, remaining_partner_obs * sizeof(float)); obs_idx += remaining_partner_obs; // map observations - GridMapEntity entity_list[MAX_ENTITIES_PER_CELL*25]; int grid_idx = getGridIndex(env, ego_entity->x, ego_entity->y); - int list_size = get_neighbor_cache_entities(env, grid_idx, entity_list, MAX_ROAD_SEGMENT_OBSERVATIONS); + int list_size = 0; + GridMapEntity* entity_list = get_neighbor_cache_entities(env, grid_idx, &list_size, MAX_ROAD_SEGMENT_OBSERVATIONS); for(int k = 0; k < list_size; k++) { int entity_idx = entity_list[k].entity_idx; @@ -1897,6 +2495,9 @@ void compute_observations(Drive* env) { int remaining_obs = (MAX_ROAD_SEGMENT_OBSERVATIONS - list_size) * 7; // Set the entire block to 0 at once memset(&obs[obs_idx], 0, remaining_obs * sizeof(float)); + + // Free heap memory + free(entity_list); } } @@ -2928,7 +3529,7 @@ void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers, EndMode3D(); // Draw track indices for the tracks to predict - if (mode == 1 && env->control_mode == CONTROL_TRACKS_TO_PREDICT) { + if (mode == 1 && env->control_mode == CONTROL_WOSAC) { float map_width = env->grid_map->bottom_right_x - env->grid_map->top_left_x; float map_height = env->grid_map->top_left_y - env->grid_map->bottom_right_y; float pixels_per_world_unit = client->height / map_height; diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index ad6dbd62e..3a811b89f 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -25,8 +25,8 @@ def __init__( collision_behavior=0, offroad_behavior=0, dt=0.1, - scenario_length=None, - resample_frequency=91, + scenario_length=91, + resample_frequency=910, num_maps=100, num_agents=512, action_type="discrete", @@ -37,6 +37,12 @@ def __init__( init_steps=0, init_mode="create_all_valid", control_mode="control_vehicles", + control_non_vehicles=False, + num_agents_per_world=32, + vehicle_width=2.0, + vehicle_length=4.5, + vehicle_height=1.8, + goal_curriculum=30.0, ): # env self.dt = dt @@ -80,18 +86,20 @@ def __init__( self.control_mode = 0 elif self.control_mode_str == "control_agents": self.control_mode = 1 - elif self.control_mode_str == "control_tracks_to_predict": + elif self.control_mode_str == "control_wosac": self.control_mode = 2 elif self.control_mode_str == "control_sdc_only": self.control_mode = 3 else: raise ValueError( - f"control_mode must be one of 
'control_vehicles', 'control_tracks_to_predict', or 'control_agents'. Got: {self.control_mode_str}" + f"control_mode must be one of 'control_vehicles', 'control_wosac', or 'control_agents'. Got: {self.control_mode_str}" ) if self.init_mode_str == "create_all_valid": self.init_mode = 0 elif self.init_mode_str == "create_only_controlled": self.init_mode = 1 + elif self.init_mode_str == "dynamic_no_agents": + self.init_mode = 0 # All created agents are valid else: raise ValueError( f"init_mode must be one of 'create_all_valid' or 'create_only_controlled'. Got: {self.init_mode_str}" @@ -138,6 +146,7 @@ def __init__( init_steps=self.init_steps, max_controlled_agents=self.max_controlled_agents, goal_behavior=self.goal_behavior, + ini_file="pufferlib/config/ocean/drive.ini", ) self.num_agents = num_agents @@ -210,6 +219,7 @@ def step(self, actions): init_steps=self.init_steps, max_controlled_agents=self.max_controlled_agents, goal_behavior=self.goal_behavior, + ini_file="pufferlib/config/ocean/drive.ini", ) env_ids = [] seed = np.random.randint(0, 2**32 - 1) @@ -255,7 +265,7 @@ def get_global_agent_state(self): """Get current global state of all active agents. Returns: - dict with keys 'x', 'y', 'z', 'heading', 'id' containing numpy arrays + dict with keys 'x', 'y', 'z', 'heading', 'id', 'length', 'width' containing numpy arrays of shape (num_active_agents,) """ num_agents = self.num_agents @@ -266,10 +276,19 @@ def get_global_agent_state(self): "z": np.zeros(num_agents, dtype=np.float32), "heading": np.zeros(num_agents, dtype=np.float32), "id": np.zeros(num_agents, dtype=np.int32), + "length": np.zeros(num_agents, dtype=np.float32), + "width": np.zeros(num_agents, dtype=np.float32), } binding.vec_get_global_agent_state( - self.c_envs, states["x"], states["y"], states["z"], states["heading"], states["id"] + self.c_envs, + states["x"], + states["y"], + states["z"], + states["heading"], + states["id"], + states["length"], + states["width"], ) return states @@ -308,6 +327,32 @@ def get_ground_truth_trajectories(self): return trajectories + def get_road_edge_polylines(self): + """Get road edge polylines for all scenarios. + + Returns: + dict with keys 'x', 'y', 'lengths', 'scenario_id' containing numpy arrays. + x, y are flattened point coordinates; lengths indicates points per polyline. 
+ """ + num_polylines, total_points = binding.vec_get_road_edge_counts(self.c_envs) + + polylines = { + "x": np.zeros(total_points, dtype=np.float32), + "y": np.zeros(total_points, dtype=np.float32), + "lengths": np.zeros(num_polylines, dtype=np.int32), + "scenario_id": np.zeros(num_polylines, dtype=np.int32), + } + + binding.vec_get_road_edge_polylines( + self.c_envs, + polylines["x"], + polylines["y"], + polylines["lengths"], + polylines["scenario_id"], + ) + + return polylines + def render(self): binding.vec_render(self.c_envs, 0) diff --git a/pufferlib/ocean/drive/error.h b/pufferlib/ocean/drive/error.h index b1eb78e7e..88bf1c477 100644 --- a/pufferlib/ocean/drive/error.h +++ b/pufferlib/ocean/drive/error.h @@ -15,6 +15,7 @@ typedef enum { ERROR_MEMORY_ALLOCATION, ERROR_FILE_NOT_FOUND, ERROR_INITIALIZATION_FAILED, + ERROR_INVALID_CONFIG, ERROR_UNKNOWN } ErrorType; @@ -27,6 +28,7 @@ const char* error_type_to_string(ErrorType type) { case ERROR_MEMORY_ALLOCATION: return "Memory Allocation Failed"; case ERROR_FILE_NOT_FOUND: return "File Not Found"; case ERROR_INITIALIZATION_FAILED: return "Initialization Failed"; + case ERROR_INVALID_CONFIG: return "Invalid Configuration"; default: return "Unknown Error"; } } diff --git a/pufferlib/ocean/drive/visualize.c b/pufferlib/ocean/drive/visualize.c index c7eb29beb..8bf38893d 100644 --- a/pufferlib/ocean/drive/visualize.c +++ b/pufferlib/ocean/drive/visualize.c @@ -206,7 +206,7 @@ static int make_gif_from_frames(const char *pattern, int fps, int eval_gif(const char* map_name, const char* policy_name, int show_grid, int obs_only, int lasers, int log_trajectories, int frame_skip, float goal_radius, int init_steps, int max_controlled_agents, const char* view_mode, const char* output_topdown, const char* output_agent, int num_maps, int scenario_length_override, int init_mode, int control_mode, int goal_behavior) { // Parse configuration from INI file - env_init_config conf = {0}; // Initialize to zero + env_init_config conf = {0}; const char* ini_file = "pufferlib/config/ocean/drive.ini"; if(ini_parse(ini_file, handler, &conf) < 0) { fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file); @@ -253,6 +253,7 @@ int eval_gif(const char* map_name, const char* policy_name, int show_grid, int o .goal_behavior = goal_behavior, .init_mode = init_mode, .control_mode = control_mode, + .ini_file = (char*)ini_file, }; env.scenario_length = (scenario_length_override > 0) ? 
diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h
index f11ea0f39..1b305bc4a 100644
--- a/pufferlib/ocean/env_binding.h
+++ b/pufferlib/ocean/env_binding.h
@@ -614,8 +614,8 @@ static PyObject* vec_close(PyObject* self, PyObject* args) {
 }
 
 static PyObject* get_global_agent_state(PyObject* self, PyObject* args) {
-    if (PyTuple_Size(args) != 5) {
-        PyErr_SetString(PyExc_TypeError, "get_global_agent_state requires 5 arguments");
+    if (PyTuple_Size(args) != 8) {
+        PyErr_SetString(PyExc_TypeError, "get_global_agent_state requires 8 arguments");
         return NULL;
     }
@@ -632,10 +632,13 @@ static PyObject* get_global_agent_state(PyObject* self, PyObject* args) {
     PyObject* z_arr = PyTuple_GetItem(args, 3);
     PyObject* heading_arr = PyTuple_GetItem(args, 4);
     PyObject* id_arr = PyTuple_GetItem(args, 5);
+    PyObject* length_arr = PyTuple_GetItem(args, 6);
+    PyObject* width_arr = PyTuple_GetItem(args, 7);
 
     if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) ||
         !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) ||
-        !PyArray_Check(id_arr)) {
+        !PyArray_Check(id_arr) || !PyArray_Check(length_arr) ||
+        !PyArray_Check(width_arr)) {
         PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays");
         return NULL;
     }
@@ -645,14 +648,16 @@
     float* z_data = (float*)PyArray_DATA((PyArrayObject*)z_arr);
     float* heading_data = (float*)PyArray_DATA((PyArrayObject*)heading_arr);
     int* id_data = (int*)PyArray_DATA((PyArrayObject*)id_arr);
+    float* length_data = (float*)PyArray_DATA((PyArrayObject*)length_arr);
+    float* width_data = (float*)PyArray_DATA((PyArrayObject*)width_arr);
 
-    c_get_global_agent_state(drive, x_data, y_data, z_data, heading_data, id_data);
+    c_get_global_agent_state(drive, x_data, y_data, z_data, heading_data, id_data, length_data, width_data);
 
     Py_RETURN_NONE;
 }
 
 static PyObject* vec_get_global_agent_state(PyObject* self, PyObject* args) {
-    if (PyTuple_Size(args) != 6) {
-        PyErr_SetString(PyExc_TypeError, "vec_get_global_agent_state requires 6 arguments");
+    if (PyTuple_Size(args) != 8) {
+        PyErr_SetString(PyExc_TypeError, "vec_get_global_agent_state requires 8 arguments");
         return NULL;
     }
@@ -667,10 +672,13 @@
     PyObject* z_arr = PyTuple_GetItem(args, 3);
     PyObject* heading_arr = PyTuple_GetItem(args, 4);
     PyObject* id_arr = PyTuple_GetItem(args, 5);
+    PyObject* length_arr = PyTuple_GetItem(args, 6);
+    PyObject* width_arr = PyTuple_GetItem(args, 7);
 
     if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) ||
         !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) ||
-        !PyArray_Check(id_arr)) {
+        !PyArray_Check(id_arr) || !PyArray_Check(length_arr) ||
+        !PyArray_Check(width_arr)) {
         PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays");
         return NULL;
     }
@@ -680,6 +688,8 @@
     PyArrayObject* z_array = (PyArrayObject*)z_arr;
     PyArrayObject* heading_array = (PyArrayObject*)heading_arr;
     PyArrayObject* id_array = (PyArrayObject*)id_arr;
+    PyArrayObject* length_array = (PyArrayObject*)length_arr;
+    PyArrayObject* width_array = (PyArrayObject*)width_arr;
 
     // Get base pointers to the arrays
     float* x_base = (float*)PyArray_DATA(x_array);
     float* y_base = (float*)PyArray_DATA(y_array);
     float* z_base = (float*)PyArray_DATA(z_array);
     float* heading_base = (float*)PyArray_DATA(heading_array);
     int* id_base = (int*)PyArray_DATA(id_array);
+    float* length_base = (float*)PyArray_DATA(length_array);
+    float* width_base = (float*)PyArray_DATA(width_array);
 
     // Iterate through environments and write to correct offsets
     int offset = 0;
@@ -699,7 +711,9 @@
             &y_base[offset],
             &z_base[offset],
             &heading_base[offset],
-            &id_base[offset]);
+            &id_base[offset],
+            &length_base[offset],
+            &width_base[offset]);
 
         // Move offset forward by the number of agents in this environment
         offset += drive->active_agent_count;
@@ -821,6 +835,60 @@ static PyObject* vec_get_global_ground_truth_trajectories(PyObject* self, PyObje
     Py_RETURN_NONE;
 }
 
+
+static PyObject* vec_get_road_edge_counts(PyObject* self, PyObject* args) {
+    VecEnv* vec = unpack_vecenv(args);
+    if (!vec) return NULL;
+
+    int total_polylines = 0, total_points = 0;
+    for (int i = 0; i < vec->num_envs; i++) {
+        Drive* drive = (Drive*)vec->envs[i];
+        int np, tp;
+        c_get_road_edge_counts(drive, &np, &tp);
+        total_polylines += np;
+        total_points += tp;
+    }
+    return Py_BuildValue("(ii)", total_polylines, total_points);
+}
+
+static PyObject* vec_get_road_edge_polylines(PyObject* self, PyObject* args) {
+    if (PyTuple_Size(args) != 5) {
+        PyErr_SetString(PyExc_TypeError, "vec_get_road_edge_polylines requires 5 arguments");
+        return NULL;
+    }
+
+    VecEnv* vec = unpack_vecenv(args);
+    if (!vec) return NULL;
+
+    PyObject* x_arr = PyTuple_GetItem(args, 1);
+    PyObject* y_arr = PyTuple_GetItem(args, 2);
+    PyObject* lengths_arr = PyTuple_GetItem(args, 3);
+    PyObject* scenario_ids_arr = PyTuple_GetItem(args, 4);
+
+    if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) ||
+        !PyArray_Check(lengths_arr) || !PyArray_Check(scenario_ids_arr)) {
+        PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays");
+        return NULL;
+    }
+
+    float* x_base = (float*)PyArray_DATA((PyArrayObject*)x_arr);
+    float* y_base = (float*)PyArray_DATA((PyArrayObject*)y_arr);
+    int* lengths_base = (int*)PyArray_DATA((PyArrayObject*)lengths_arr);
+    int* scenario_ids_base = (int*)PyArray_DATA((PyArrayObject*)scenario_ids_arr);
+
+    int poly_offset = 0, pt_offset = 0;
+    for (int i = 0; i < vec->num_envs; i++) {
+        Drive* drive = (Drive*)vec->envs[i];
+        int np, tp;
+        c_get_road_edge_counts(drive, &np, &tp);
+        c_get_road_edge_polylines(drive, &x_base[pt_offset], &y_base[pt_offset],
+                                  &lengths_base[poly_offset], &scenario_ids_base[poly_offset]);
+        poly_offset += np;
+        pt_offset += tp;
+    }
+    Py_RETURN_NONE;
+}
+
 static double unpack(PyObject* kwargs, char* key) {
     PyObject* val = PyDict_GetItemString(kwargs, key);
     if (val == NULL) {
@@ -896,6 +964,8 @@ static PyMethodDef methods[] = {
     {"vec_get_global_agent_state", vec_get_global_agent_state, METH_VARARGS, "Get agent state from vectorized env"},
     {"get_ground_truth_trajectories", get_ground_truth_trajectories, METH_VARARGS, "Get ground truth trajectories"},
     {"vec_get_global_ground_truth_trajectories", vec_get_global_ground_truth_trajectories, METH_VARARGS, "Get ground truth trajectories from vectorized env"},
+    {"vec_get_road_edge_counts", vec_get_road_edge_counts, METH_VARARGS, "Get road edge polyline counts from vectorized env"},
+    {"vec_get_road_edge_polylines", vec_get_road_edge_polylines, METH_VARARGS, "Get road edge polylines from vectorized env"},
     MY_METHODS,
     {NULL, NULL, 0, NULL}
 };
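Since this diff also adds `matplotlib` to `setup.py` (see the end of the patch), here is a hedged sketch of consuming the flattened road-edge buffers; the `polys` stand-in below mirrors the documented layout so the demo runs without an environment:

```python
# Illustrative sketch only -- not part of this diff.
import numpy as np
import matplotlib.pyplot as plt

# polys = env.get_road_edge_polylines()  # hypothetical constructed env
polys = {  # stand-in with the documented layout, for a runnable demo
    "x": np.array([0.0, 1.0, 2.0, 0.0, 0.0], dtype=np.float32),
    "y": np.array([0.0, 0.0, 0.5, 0.0, 2.0], dtype=np.float32),
    "lengths": np.array([3, 2], dtype=np.int32),
    "scenario_id": np.array([0, 0], dtype=np.int32),
}

# Split the flat point buffers into one array per polyline.
splits = np.cumsum(polys["lengths"])[:-1]
xs, ys = np.split(polys["x"], splits), np.split(polys["y"], splits)

# Plot only the polylines belonging to scenario 0.
for x, y, sid in zip(xs, ys, polys["scenario_id"]):
    if sid == 0:
        plt.plot(x, y, color="black", linewidth=0.5)
plt.axis("equal")
plt.savefig("road_edges_scenario_0.png")
```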
diff --git a/pufferlib/ocean/env_config.h b/pufferlib/ocean/env_config.h
index e8400c51c..3ff1d1168 100644
--- a/pufferlib/ocean/env_config.h
+++ b/pufferlib/ocean/env_config.h
@@ -6,9 +6,16 @@
 #include <string.h>
 #include <stdlib.h>
 
+typedef enum {
+    UNKNOWN_INIT_MODE = -1,
+    DEFAULT_INIT_MODE = 0,
+    DYNAMIC_AGENTS_PER_ENV = 1,
+    INIT_ALL_VALID = 2,
+    INIT_ONLY_CONTROLLABLE_AGENTS = 3,
+} Init_Mode;
+
 // Config struct for parsing INI files - contains all environment configuration
-typedef struct
-{
+typedef struct {
     int action_type;
     int dynamics_model;
     float reward_vehicle_collision;
@@ -25,7 +32,12 @@ typedef struct
     int goal_behavior;
     int scenario_length;
     int init_steps;
-    int init_mode;
+    Init_Mode init_mode;
+    int num_agents_per_world;
+    float goal_distance;
+    float vehicle_width;
+    float vehicle_length;
+    float vehicle_height;
     int control_mode;
 } env_init_config;
@@ -85,10 +97,34 @@ static int handler(
         env_config->scenario_length = atoi(value);
     } else if (MATCH("env", "init_steps")) {
         env_config->init_steps = atoi(value);
-    } else if (MATCH("env", "init_mode")) {
-        env_config->init_mode = atoi(value);
     } else if (MATCH("env", "control_mode")) {
         env_config->control_mode = atoi(value);
+    } else if (MATCH("env", "init_mode")) {
+        if (strcmp(value, "\"default\"") == 0 || strcmp(value, "default") == 0) {
+            env_config->init_mode = DEFAULT_INIT_MODE; // DEFAULT
+        } else if (strcmp(value, "\"dynamic_no_agents\"") == 0 || strcmp(value, "dynamic_no_agents") == 0) {
+            env_config->init_mode = DYNAMIC_AGENTS_PER_ENV; // DYNAMIC_NO_AGENTS
+        } else if (strcmp(value, "\"create_all_valid\"") == 0 || strcmp(value, "create_all_valid") == 0) {
+            env_config->init_mode = INIT_ALL_VALID; // CREATE_ALL_VALID
+        } else if (strcmp(value, "\"create_only_controlled\"") == 0 || strcmp(value, "create_only_controlled") == 0) {
+            env_config->init_mode = INIT_ONLY_CONTROLLABLE_AGENTS; // CREATE_ONLY_CONTROLLED
+        } else {
+            raise_error_with_message(ERROR_INVALID_CONFIG, "Unknown init_mode value: %s", value);
+            env_config->init_mode = UNKNOWN_INIT_MODE; // Default to UNKNOWN
+        }
+    } else if (MATCH("env", "num_agents_per_world")) {
+        env_config->num_agents_per_world = atoi(value);
+        if (env_config->num_agents_per_world <= 0 && env_config->init_mode == DYNAMIC_AGENTS_PER_ENV) {
+            raise_error_with_message(ERROR_INVALID_CONFIG, "num_agents_per_world must be positive for dynamic_no_agents init_mode");
+        }
+    } else if (MATCH("env", "vehicle_width")) {
+        env_config->vehicle_width = atof(value);
+    } else if (MATCH("env", "vehicle_length")) {
+        env_config->vehicle_length = atof(value);
+    } else if (MATCH("env", "vehicle_height")) {
+        env_config->vehicle_height = atof(value);
+    } else if (MATCH("env", "goal_curriculum")) {
+        env_config->goal_distance = atof(value);
     }
 #undef MATCH
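For readers skimming the C handler above, here is a Python mirror of the `init_mode` string-to-enum mapping; illustrative only, the real parsing is the handler itself, which tolerates both quoted and unquoted INI values:

```python
# Illustrative mirror of the C mapping above; not part of this patch.
INIT_MODES = {
    "default": 0,                 # DEFAULT_INIT_MODE
    "dynamic_no_agents": 1,       # DYNAMIC_AGENTS_PER_ENV
    "create_all_valid": 2,        # INIT_ALL_VALID
    "create_only_controlled": 3,  # INIT_ONLY_CONTROLLABLE_AGENTS
}

def parse_init_mode(value: str) -> int:
    """Strip optional surrounding quotes, then map; -1 mirrors UNKNOWN_INIT_MODE."""
    return INIT_MODES.get(value.strip('"'), -1)

assert parse_init_mode('"dynamic_no_agents"') == 1
assert parse_init_mode("unheard_of") == -1
```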
args["eval"]["wosac_aggregate_results"]: @@ -1200,6 +1208,53 @@ def sweep(args=None, env_name=None): args["train"]["total_timesteps"] = total_timesteps +def controlled_exp(env_name, args=None): + """Run experiments with all combinations of specified parameter values.""" + import itertools + from copy import deepcopy + + args = args or load_config(env_name) + if not args["wandb"] and not args["neptune"]: + raise pufferlib.APIUsageError("Targeted experiments require either wandb or neptune") + + # Check if controlled_exp config exists + if "controlled_exp" not in args: + raise pufferlib.APIUsageError("No [controlled_exp.*] sections found in config") + + # Extract parameters from controlled_exp namespace + params = {} + for section, section_config in args["controlled_exp"].items(): + if isinstance(section_config, dict): + for param, param_config in section_config.items(): + if isinstance(param_config, dict) and "values" in param_config: + params[f"{section}.{param}"] = param_config["values"] + + if not params: + raise pufferlib.APIUsageError("No parameters with 'values' lists found in [controlled_exp.*] sections") + + # Generate all combinations + keys = list(params.keys()) + combinations = list(itertools.product(*[params[k] for k in keys])) + + print(f"Running a total of {len(combinations)} experiments with parameters: {keys}") + + # Run each combination + for i, combo in enumerate(combinations, 1): + exp_args = deepcopy(args) + + # Set parameters + for key, value in zip(keys, combo): + section, param = key.split(".") + exp_args[section][param] = value + + print(f"\nExperiment {i}/{len(combinations)}: {dict(zip(keys, combo))}") + + # Train + train(env_name, args=exp_args) + + print(f"\n✓ Completed all {len(combinations)} experiments") + + def profile(args=None, env_name=None, vecenv=None, policy=None): args = load_config() vecenv = vecenv or load_env(env_name, args) @@ -1408,9 +1463,7 @@ def puffer_type(value): def main(): - err = ( - "Usage: puffer [train, eval, sweep, autotune, profile, export] [env_name] [optional args]. --help for more info" - ) + err = "Usage: puffer [train, eval, sweep, controlled_exp, autotune, profile, export] [env_name] [optional args]. 
--help for more info" if len(sys.argv) < 3: raise pufferlib.APIUsageError(err) @@ -1422,6 +1475,8 @@ def main(): eval(env_name=env_name) elif mode == "sweep": sweep(env_name=env_name) + elif mode == "controlled_exp": + controlled_exp(env_name=env_name) elif mode == "autotune": autotune(env_name=env_name) elif mode == "profile": diff --git a/pufferlib/utils.py b/pufferlib/utils.py index 85cb64732..1e7afbb11 100644 --- a/pufferlib/utils.py +++ b/pufferlib/utils.py @@ -114,7 +114,7 @@ def run_wosac_eval_in_subprocess(config, logger, global_step): "--eval.wosac-init-mode", str(eval_config.get("wosac_init_mode", "create_all_valid")), "--eval.wosac-control-mode", - str(eval_config.get("wosac_control_mode", "control_tracks_to_predict")), + str(eval_config.get("wosac_control_mode", "control_wosac")), "--eval.wosac-init-steps", str(eval_config.get("wosac_init_steps", 10)), "--eval.wosac-goal-behavior", @@ -151,12 +151,20 @@ def run_wosac_eval_in_subprocess(config, logger, global_step): step=global_step, ) else: - print(f"WOSAC evaluation failed with exit code {result.returncode}: {result.stderr}") + print(f"WOSAC evaluation failed with exit code {result.returncode}") + print(f"Error: {result.stderr}") + + # Check for memory issues + stderr_lower = result.stderr.lower() + if "out of memory" in stderr_lower or "cuda out of memory" in stderr_lower: + print("GPU out of memory. Skipping this WOSAC evaluation.") except subprocess.TimeoutExpired: - print("WOSAC evaluation timed out") + print("WOSAC evaluation timed out after 600 seconds") + except MemoryError as e: + print(f"WOSAC evaluation ran out of memory. Skipping this evaluation: {e}") except Exception as e: - print(f"Failed to run WOSAC evaluation: {e}") + print(f"Failed to run WOSAC evaluation: {type(e).__name__}: {e}") def render_videos(config, vecenv, logger, epoch, global_step, bin_path): @@ -194,84 +202,119 @@ def render_videos(config, vecenv, logger, epoch, global_step, bin_path): # TODO: Fix memory leaks so that this is not needed # Suppress AddressSanitizer exit code (temp) - env = os.environ.copy() - env["ASAN_OPTIONS"] = "exitcode=0" - - cmd = ["xvfb-run", "-a", "-s", "-screen 0 1280x720x24", "./visualize"] - - # Add render configurations - if config["show_grid"]: - cmd.append("--show-grid") - if config["obs_only"]: - cmd.append("--obs-only") - if config["show_lasers"]: - cmd.append("--lasers") - if config["show_human_logs"]: - cmd.append("--log-trajectories") - if vecenv.driver_env.goal_radius is not None: - cmd.extend(["--goal-radius", str(vecenv.driver_env.goal_radius)]) - if vecenv.driver_env.init_steps > 0: - cmd.extend(["--init-steps", str(vecenv.driver_env.init_steps)]) - if config["render_map"] is not None: - map_path = config["render_map"] - if os.path.exists(map_path): - cmd.extend(["--map-name", map_path]) - if vecenv.driver_env.init_mode is not None: - cmd.extend(["--init-mode", str(vecenv.driver_env.init_mode)]) - if vecenv.driver_env.control_mode is not None: - cmd.extend(["--control-mode", str(vecenv.driver_env.control_mode)]) - - # Specify output paths for videos - cmd.extend(["--output-topdown", "resources/drive/output_topdown.mp4"]) - cmd.extend(["--output-agent", "resources/drive/output_agent.mp4"]) - - # Add environment configuration + env_vars = os.environ.copy() + env_vars["ASAN_OPTIONS"] = "exitcode=0" + + # Base command (without map/output paths) + base_cmd = ["xvfb-run", "-a", "-s", "-screen 0 1280x720x24", "./visualize"] + + # Render config flags + if config.get("show_grid", False): + 
base_cmd.append("--show-grid") + if config.get("obs_only", False): + base_cmd.append("--obs-only") + if config.get("show_lasers", False): + base_cmd.append("--lasers") + if config.get("show_human_logs", False): + base_cmd.append("--log-trajectories") + env_cfg = getattr(vecenv, "driver_env", None) if env_cfg is not None: - n_policy = getattr(env_cfg, "max_controlled_agents", -1) + if getattr(env_cfg, "control_non_vehicles", False): + base_cmd.append("--control-non-vehicles") + if getattr(env_cfg, "goal_radius", None) is not None: + base_cmd.extend(["--goal-radius", str(env_cfg.goal_radius)]) + if getattr(env_cfg, "init_steps", 0) > 0: + base_cmd.extend(["--init-steps", str(env_cfg.init_steps)]) + if getattr(env_cfg, "init_mode", None) is not None: + base_cmd.extend(["--init-mode", str(env_cfg.init_mode)]) + if getattr(env_cfg, "control_mode", None) is not None: + base_cmd.extend(["--control-mode", str(env_cfg.control_mode)]) + if getattr(env_cfg, "control_all_agents", False): + base_cmd.append("--pure-self-play") + if getattr(env_cfg, "deterministic_agent_selection", False): + base_cmd.append("--deterministic-selection") + + # Policy-controlled agents (prefer num_policy_controlled_agents, fallback to max_controlled_agents) + n_policy = getattr(env_cfg, "num_policy_controlled_agents", getattr(env_cfg, "max_controlled_agents", -1)) try: n_policy = int(n_policy) except (TypeError, ValueError): n_policy = -1 if n_policy > 0: - cmd += ["--num-policy-controlled-agents", str(n_policy)] + base_cmd += ["--num-policy-controlled-agents", str(n_policy)] + if getattr(env_cfg, "num_maps", False): - cmd.extend(["--num-maps", str(env_cfg.num_maps)]) + base_cmd.extend(["--num-maps", str(env_cfg.num_maps)]) if getattr(env_cfg, "scenario_length", None): - cmd.extend(["--scenario-length", str(env_cfg.scenario_length)]) - - # Call C code that runs eval_gif() in subprocess - result = subprocess.run(cmd, cwd=os.getcwd(), capture_output=True, text=True, timeout=120, env=env) - - vids_exist = os.path.exists("resources/drive/output_topdown.mp4") and os.path.exists( - "resources/drive/output_agent.mp4" - ) - - if result.returncode == 0 or (result.returncode == 1 and vids_exist): - # Move both generated videos to the model directory - videos = [ - ("resources/drive/output_topdown.mp4", f"epoch_{epoch:06d}_topdown.mp4"), - ("resources/drive/output_agent.mp4", f"epoch_{epoch:06d}_agent.mp4"), - ] - - for source_vid, target_filename in videos: - if os.path.exists(source_vid): - target_gif = os.path.join(video_output_dir, target_filename) - shutil.move(source_vid, target_gif) - - # Log to wandb if available - if hasattr(logger, "wandb") and logger.wandb: - import wandb - - view_type = "world_state" if "topdown" in target_filename else "agent_view" - logger.wandb.log( - {f"render/{view_type}": wandb.Video(target_gif, format="mp4")}, - step=global_step, - ) - else: - print(f"Video generation completed but {source_vid} not found") + base_cmd.extend(["--scenario-length", str(env_cfg.scenario_length)]) + + # Handle single or multiple map rendering + render_maps = config.get("render_map", None) + if render_maps is None: + render_maps = [None] + elif isinstance(render_maps, (str, os.PathLike)): + render_maps = [render_maps] else: - print(f"C rendering failed with exit code {result.returncode}: {result.stdout}") + # Ensure list-like + render_maps = list(render_maps) + + # Collect videos to log as lists so W&B shows all in the same step + videos_to_log_world = [] + videos_to_log_agent = [] + + for i, map_path in 
+        for i, map_path in enumerate(render_maps):
+            cmd = list(base_cmd)  # copy
+            if map_path is not None and os.path.exists(map_path):
+                cmd.extend(["--map-name", str(map_path)])
+
+            # Output paths (overwrite each iteration; then moved/renamed)
+            cmd.extend(["--output-topdown", "resources/drive/output_topdown.mp4"])
+            cmd.extend(["--output-agent", "resources/drive/output_agent.mp4"])
+
+            result = subprocess.run(cmd, cwd=os.getcwd(), capture_output=True, text=True, timeout=120, env=env_vars)
+
+            vids_exist = os.path.exists("resources/drive/output_topdown.mp4") and os.path.exists(
+                "resources/drive/output_agent.mp4"
+            )
+
+            if result.returncode == 0 or (result.returncode == 1 and vids_exist):
+                videos = [
+                    (
+                        "resources/drive/output_topdown.mp4",
+                        f"epoch_{epoch:06d}_map{i:02d}_topdown.mp4" if map_path else f"epoch_{epoch:06d}_topdown.mp4",
+                    ),
+                    (
+                        "resources/drive/output_agent.mp4",
+                        f"epoch_{epoch:06d}_map{i:02d}_agent.mp4" if map_path else f"epoch_{epoch:06d}_agent.mp4",
+                    ),
+                ]
+
+                for source_vid, target_filename in videos:
+                    if os.path.exists(source_vid):
+                        target_path = os.path.join(video_output_dir, target_filename)
+                        shutil.move(source_vid, target_path)
+                        # Accumulate for a single wandb.log call
+                        if hasattr(logger, "wandb") and logger.wandb:
+                            import wandb
+
+                            if "topdown" in target_filename:
+                                videos_to_log_world.append(wandb.Video(target_path, format="mp4"))
+                            else:
+                                videos_to_log_agent.append(wandb.Video(target_path, format="mp4"))
+                    else:
+                        print(f"Video generation completed but {source_vid} not found")
+            else:
+                print(f"C rendering failed (map index {i}) with exit code {result.returncode}: {result.stdout}")
+
+        # Log all videos at once so W&B keeps all of them under the same step
+        if hasattr(logger, "wandb") and logger.wandb and (videos_to_log_world or videos_to_log_agent):
+            payload = {}
+            if videos_to_log_world:
+                payload["render/world_state"] = videos_to_log_world
+            if videos_to_log_agent:
+                payload["render/agent_view"] = videos_to_log_agent
+            logger.wandb.log(payload, step=global_step)
     except subprocess.TimeoutExpired:
         print("C rendering timed out")
diff --git a/setup.py b/setup.py
index e3d86c0e6..b834f94cf 100644
--- a/setup.py
+++ b/setup.py
@@ -326,6 +326,7 @@ def run(self):
     "heavyball<2.0.0",
     "neptune",
     "wandb",
+    "matplotlib",
 ]
 
 setup(