Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion src/lossfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# Description: Define the loss function for training.
"""
import torch
import numpy as np
from assets.cuda.chamfer3D import nnChamferDis
MyCUDAChamferDis = nnChamferDis()
from src.utils.av2_eval import CATEGORY_TO_INDEX, BUCKETED_METACATAGORIES
Expand All @@ -18,7 +19,7 @@
# If your scenario is different, may need adjust this TRUNCATED to 80-120km/h vel.
TRUNCATED_DIST = 4


# ---------------------- Self-Supervised Flow Loss without GT ----------------------
def seflowLoss(res_dict, timer=None):
pc0_label = res_dict['pc0_labels']
pc1_label = res_dict['pc1_labels']
Expand Down Expand Up @@ -99,6 +100,52 @@ def seflowLoss(res_dict, timer=None):
}
return res_loss

# ---------------------- Supervised Flow Loss with GT ----------------------
# designed from MambaFlow: https://github.com/SCNU-RISLAB/MambaFlow
def mambaflowLoss(res_dict):
pred = res_dict['est_flow']
gt = res_dict['gt_flow']
mask_no_nan = (~gt.isnan() & ~pred.isnan() & ~gt.isinf() & ~pred.isinf())

pred = pred[mask_no_nan].reshape(-1, 3)
gt = gt[mask_no_nan].reshape(-1, 3)

speed = gt.norm(dim=1, p=2) / 0.1
# pts_loss = torch.norm(pred - gt, dim=1, p=2)
pts_loss = torch.linalg.vector_norm(pred - gt, dim=-1)

velocities = speed.cpu().numpy()

# 计算直方图,返回每个区间的计数和区间边界
counts, bin_edges = np.histogram(velocities, bins=100, density=False)

# 计算每个区间的点数占总点数的比例
total_points = len(velocities)
proportions = counts / total_points

# 计算每个区间的中心位置,用于绘图
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

# 设置占比阈值
proportion_threshold = 0.01 # 可以根据需要调整这个值

# 找出第一个占比小于阈值的柱子
first_below_threshold = next((i for i, prop in enumerate(proportions) if prop < proportion_threshold), None)
turning_speed = bin_centers[first_below_threshold]

weight_loss = 0.0
speed_mid = 2
speed_0 = pts_loss[speed < turning_speed].mean()
speed_1 = pts_loss[(speed >= turning_speed) & (speed <= speed_mid)].mean()
speed_2 = pts_loss[speed > speed_mid].mean()
if ~speed_1.isnan():
weight_loss += speed_1
if ~speed_0.isnan():
weight_loss += speed_0
if ~speed_2.isnan():
weight_loss += speed_2
return {'loss': weight_loss}

def deflowLoss(res_dict):
pred = res_dict['est_flow']
gt = res_dict['gt_flow']
Expand Down