From 599c84d07c32ec69f00ab9d5dcf3bc9972153331 Mon Sep 17 00:00:00 2001 From: Kin Date: Thu, 13 Mar 2025 08:48:09 +0100 Subject: [PATCH] feat(loss): add mambaflow loss. --- src/lossfuncs.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/src/lossfuncs.py b/src/lossfuncs.py index cd161d9..c4feae9 100644 --- a/src/lossfuncs.py +++ b/src/lossfuncs.py @@ -10,6 +10,7 @@ # Description: Define the loss function for training. """ import torch +import numpy as np from assets.cuda.chamfer3D import nnChamferDis MyCUDAChamferDis = nnChamferDis() from src.utils.av2_eval import CATEGORY_TO_INDEX, BUCKETED_METACATAGORIES @@ -18,7 +19,7 @@ # If your scenario is different, may need adjust this TRUNCATED to 80-120km/h vel. TRUNCATED_DIST = 4 - +# ---------------------- Self-Supervised Flow Loss without GT ---------------------- def seflowLoss(res_dict, timer=None): pc0_label = res_dict['pc0_labels'] pc1_label = res_dict['pc1_labels'] @@ -99,6 +100,52 @@ def seflowLoss(res_dict, timer=None): } return res_loss +# ---------------------- Supervised Flow Loss with GT ---------------------- +# designed from MambaFlow: https://github.com/SCNU-RISLAB/MambaFlow +def mambaflowLoss(res_dict): + pred = res_dict['est_flow'] + gt = res_dict['gt_flow'] + mask_no_nan = (~gt.isnan() & ~pred.isnan() & ~gt.isinf() & ~pred.isinf()) + + pred = pred[mask_no_nan].reshape(-1, 3) + gt = gt[mask_no_nan].reshape(-1, 3) + + speed = gt.norm(dim=1, p=2) / 0.1 + # pts_loss = torch.norm(pred - gt, dim=1, p=2) + pts_loss = torch.linalg.vector_norm(pred - gt, dim=-1) + + velocities = speed.cpu().numpy() + + # 计算直方图,返回每个区间的计数和区间边界 + counts, bin_edges = np.histogram(velocities, bins=100, density=False) + + # 计算每个区间的点数占总点数的比例 + total_points = len(velocities) + proportions = counts / total_points + + # 计算每个区间的中心位置,用于绘图 + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + + # 设置占比阈值 + proportion_threshold = 0.01 # 可以根据需要调整这个值 + + # 找出第一个占比小于阈值的柱子 + first_below_threshold = next((i for i, prop in enumerate(proportions) if prop < proportion_threshold), None) + turning_speed = bin_centers[first_below_threshold] + + weight_loss = 0.0 + speed_mid = 2 + speed_0 = pts_loss[speed < turning_speed].mean() + speed_1 = pts_loss[(speed >= turning_speed) & (speed <= speed_mid)].mean() + speed_2 = pts_loss[speed > speed_mid].mean() + if ~speed_1.isnan(): + weight_loss += speed_1 + if ~speed_0.isnan(): + weight_loss += speed_0 + if ~speed_2.isnan(): + weight_loss += speed_2 + return {'loss': weight_loss} + def deflowLoss(res_dict): pred = res_dict['est_flow'] gt = res_dict['gt_flow']