import copy
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.ops.boxes import box_iou
from torchvision.ops import masks_to_boxes
from milliontrees.common.metrics.metric import Metric, ElementwiseMetric, MultiTaskMetric
from milliontrees.common.metrics.matching import (
greedy_distance_match,
greedy_iou_match,
merge_commission_rate_distance,
merge_commission_rate_iou,
n_matched_gt,
)
from milliontrees.common.metrics.loss import ElementwiseLoss
from milliontrees.common.utils import minimum, maximum, get_counts
import sklearn.metrics
from scipy.stats import pearsonr
[docs]
def detection_map_backend() -> str:
"""Prefer faster_coco_eval when installed; fall back to pycocotools."""
from torchmetrics.utilities.imports import _FASTER_COCO_EVAL_AVAILABLE
if _FASTER_COCO_EVAL_AVAILABLE:
return "faster_coco_eval"
return "pycocotools"
[docs]
def make_mean_average_precision(**kwargs):
"""Construct torchmetrics MeanAveragePrecision with the fastest available backend."""
from torchmetrics.detection import MeanAveragePrecision
return MeanAveragePrecision(backend=detection_map_backend(), **kwargs)
[docs]
def compute_polygon_mask_elementwise_batch(
y_pred: list,
y_true: list,
*,
accuracy_metric: "MaskAccuracy",
recall_metric: "MaskAccuracy",
maskaware_metric: "MaskAwareMaskPrecision",
merge_metric: "MergeCommissionMetric",
) -> dict[str, torch.Tensor]:
"""Per-image elementwise polygon metrics with a single ``_mask_iou`` per image."""
acc_out: list[torch.Tensor] = []
rec_out: list[torch.Tensor] = []
map_out: list[torch.Tensor] = []
merge_out: list[torch.Tensor] = []
for gt, target in zip(y_true, y_pred):
scores = target["scores"]
if not isinstance(scores, torch.Tensor):
scores = torch.as_tensor(scores, dtype=torch.float32)
target_masks = target[accuracy_metric.geometry_name]
gt_masks = gt[accuracy_metric.geometry_name]
pred_masks = target_masks[scores > accuracy_metric.score_threshold]
total_gt = len(gt_masks)
total_pred = len(pred_masks)
if total_gt > 0 and total_pred > 0:
iou = accuracy_metric._mask_iou(gt_masks, pred_masks)
else:
iou = None
acc_out.append(
accuracy_metric._accuracy(gt_masks,
pred_masks,
accuracy_metric.iou_threshold,
iou=iou))
rec_out.append(
recall_metric._recall(gt_masks,
pred_masks,
recall_metric.iou_threshold,
iou=iou))
if total_gt == 0 or total_pred == 0:
merge_out.append(torch.tensor(0.0))
else:
merge_out.append(
merge_commission_rate_iou(
iou, merge_metric.iou_threshold).to(dtype=torch.float32))
tree_mask = gt.get(maskaware_metric.tree_coverage_key)
map_out.append(
maskaware_metric._precision(gt_masks,
pred_masks,
tree_mask,
iou=iou))
return {
"accuracy": torch.stack(acc_out),
"recall": torch.stack(rec_out),
"maskaware_precision": torch.stack(map_out),
"merge_commission": torch.stack(merge_out),
}
[docs]
def binary_logits_to_score(logits):
assert logits.dim() in (1, 2)
if logits.dim() == 2: #multi-class logits
assert logits.size(1) == 2, "Only binary classification"
score = F.softmax(logits, dim=1)[:, 1]
else:
score = logits
return score
[docs]
def multiclass_logits_to_pred(logits):
"""Converts multi-class logits into predictions.
This function takes a tensor of logits with shape (batch_size, ..., n_classes)
and computes predictions by applying `argmax` along the last dimension.
Args:
logits (Tensor): A tensor of shape (batch_size, ..., n_classes) representing
multi-class logits.
Returns:
Tensor: A tensor containing predicted class indices.
"""
assert logits.dim() > 1
return logits.argmax(-1)
[docs]
def binary_logits_to_pred(logits):
return (logits > 0).long()
[docs]
def pseudolabel_binary_logits(logits, confidence_threshold):
"""Applies a confidence threshold to binary logits and generates pseudo- labels.
Args:
logits (Tensor): A tensor of shape (batch_size, n_tasks) representing binary logits.
A positive value (>0) indicates a positive prediction for the corresponding (example, task).
confidence_threshold (float): A threshold in the range [0,1] used to filter predictions.
Returns:
tuple:
- unlabeled_y_pred (Tensor): A filtered version of `logits`, discarding rows (examples)
where no predictions exceed the confidence threshold.
- unlabeled_y_pseudo (Tensor): A hard pseudo-labeled version of `logits`, where entries
below the confidence threshold are set to NaN. Rows with no
confident predictions are discarded.
- pseudolabels_kept_frac (float): The fraction of (example, task) pairs that are not set
to NaN or discarded.
- mask (Tensor): A mask indicating which predictions meet the confidence threshold.
"""
if len(logits.shape) != 2:
raise ValueError('Logits must be 2-dimensional.')
probs = 1 / (1 + torch.exp(-logits))
mask = (torch.max(probs, 1 - probs) >= confidence_threshold)
unlabeled_y_pseudo = (logits > 0).float()
unlabeled_y_pseudo[~mask] = float('nan')
# mask is bool, so no .mean()
pseudolabels_kept_frac = mask.sum() / mask.numel()
example_mask = torch.any(~torch.isnan(unlabeled_y_pseudo), dim=1)
unlabeled_y_pseudo = unlabeled_y_pseudo[example_mask]
unlabeled_y_pred = logits[example_mask]
return (unlabeled_y_pred, unlabeled_y_pseudo, pseudolabels_kept_frac,
example_mask)
[docs]
def pseudolabel_multiclass_logits(logits, confidence_threshold):
"""Applies a confidence threshold to multi-class logits and generates pseudo-labels.
Args:
logits (Tensor): A tensor of shape (batch_size, ..., n_classes) representing multi-class logits.
confidence_threshold (float): A threshold in the range [0,1] used to filter predictions.
Returns:
tuple:
- unlabeled_y_pred (Tensor): A filtered version of `logits`, discarding rows (examples)
where no predictions exceed the confidence threshold.
- unlabeled_y_pseudo (Tensor): A hard pseudo-labeled version of `logits`, where examples
with confidence below the threshold are discarded.
- pseudolabels_kept_frac (float): The fraction of examples retained after filtering.
- mask (Tensor): A mask indicating which predictions meet the confidence threshold.
"""
mask = torch.max(F.softmax(logits, -1), -1)[0] >= confidence_threshold
unlabeled_y_pseudo = multiclass_logits_to_pred(logits)
unlabeled_y_pseudo = unlabeled_y_pseudo[mask]
unlabeled_y_pred = logits[mask]
# mask is bool, so no .mean()
pseudolabels_kept_frac = mask.sum() / mask.numel()
return unlabeled_y_pred, unlabeled_y_pseudo, pseudolabels_kept_frac, mask
[docs]
def pseudolabel_identity(logits, confidence_threshold):
return logits, logits, 1, None
[docs]
def pseudolabel_detection(preds, confidence_threshold):
"""Filters detection predictions based on a confidence threshold.
Args:
preds (List[dict]): A list of length `batch_size`, where each entry is a dictionary
containing the following keys:
- 'boxes' (Tensor): Bounding box coordinates.
- 'labels' (Tensor): Class labels for detected objects.
- 'scores' (Tensor): Confidence scores for each detection.
- 'losses' (dict): An empty dictionary (not used).
confidence_threshold (float): A threshold in the range [0,1] used to filter predictions.
Returns:
List[dict]: A filtered version of `preds`, where detections with confidence scores
below `confidence_threshold` are removed.
"""
preds, pseudolabels_kept_frac = _mask_pseudolabels_detection(
preds, confidence_threshold)
unlabeled_y_pred = [{
'boxes': pred['boxes'],
'labels': pred['labels'],
'scores': pred['scores'],
'losses': pred['losses'],
} for pred in preds]
unlabeled_y_pseudo = [{
'boxes': pred['boxes'],
'labels': pred['labels'],
} for pred in preds]
# Keep all examples even if they don't have any confident-enough predictions
# They will be treated as empty images
example_mask = torch.ones(len(preds), dtype=torch.bool)
return (unlabeled_y_pred, unlabeled_y_pseudo, pseudolabels_kept_frac,
example_mask)
[docs]
def pseudolabel_detection_discard_empty(preds, confidence_threshold):
"""Filters detection predictions based on a confidence threshold and discards empty entries.
Args:
preds (List[dict]): A list of length `batch_size`, where each entry is a dictionary
containing the following keys:
- 'boxes' (Tensor): Bounding box coordinates.
- 'labels' (Tensor): Class labels for detected objects.
- 'scores' (Tensor): Confidence scores for each detection.
- 'losses' (dict): An empty dictionary (not used).
confidence_threshold (float): A threshold in the range [0,1] used to filter predictions.
Returns:
List[dict]: A filtered version of `preds`, where detections with confidence scores
below `confidence_threshold` are removed. Entries with no remaining detections
are discarded from the list.
"""
preds, pseudolabels_kept_frac = _mask_pseudolabels_detection(
preds, confidence_threshold)
unlabeled_y_pred = [{
'boxes': pred['boxes'],
'labels': pred['labels'],
'scores': pred['scores'],
'losses': pred['losses'],
} for pred in preds if len(pred['labels']) > 0]
unlabeled_y_pseudo = [{
'boxes': pred['boxes'],
'labels': pred['labels'],
} for pred in preds if len(pred['labels']) > 0]
example_mask = torch.tensor([len(pred['labels']) > 0 for pred in preds])
return unlabeled_y_pred, unlabeled_y_pseudo, pseudolabels_kept_frac, example_mask
def _mask_pseudolabels_detection(preds, confidence_threshold):
total_boxes = 0.0
kept_boxes = 0.0
preds = copy.deepcopy(preds)
for pred in preds:
mask = (pred['scores'] >= confidence_threshold)
pred['boxes'] = pred['boxes'][mask]
pred['labels'] = pred['labels'][mask]
pred['scores'] = pred['scores'][mask]
total_boxes += len(mask)
kept_boxes += mask.sum()
pseudolabels_kept_frac = kept_boxes / total_boxes
return preds, pseudolabels_kept_frac
[docs]
class Accuracy(ElementwiseMetric):
def __init__(self, prediction_fn=None, name=None):
self.prediction_fn = prediction_fn
if name is None:
name = 'acc'
super().__init__(name=name)
def _compute_element_wise(self, y_pred, y_true):
if self.prediction_fn is not None:
y_pred = self.prediction_fn(y_pred)
return torch.round((y_pred == y_true).float(), decimals=3)
[docs]
def worst(self, metrics):
return torch.round(minimum(metrics), decimals=3)
[docs]
class MultiTaskAccuracy(MultiTaskMetric):
def __init__(self, prediction_fn=None, name=None):
self.prediction_fn = prediction_fn # should work on flattened inputs
if name is None:
name = 'acc'
super().__init__(name=name)
def _compute_flattened(self, flattened_y_pred, flattened_y_true):
if self.prediction_fn is not None:
flattened_y_pred = self.prediction_fn(flattened_y_pred)
return (flattened_y_pred == flattened_y_true).float()
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class MultiTaskAveragePrecision(MultiTaskMetric):
def __init__(self, prediction_fn=None, name=None, average='macro'):
self.prediction_fn = prediction_fn
if name is None:
name = f'avgprec'
if average is not None:
name += f'-{average}'
self.average = average
super().__init__(name=name)
def _compute_flattened(self, flattened_y_pred, flattened_y_true):
if self.prediction_fn is not None:
flattened_y_pred = self.prediction_fn(flattened_y_pred)
ytr = np.array(flattened_y_true.squeeze().detach().cpu().numpy() > 0)
ypr = flattened_y_pred.squeeze().detach().cpu().numpy()
score = sklearn.metrics.average_precision_score(ytr,
ypr,
average=self.average)
to_ret = torch.tensor(score).to(flattened_y_pred.device)
return to_ret
def _compute_group_wise(self, y_pred, y_true, g, n_groups):
group_metrics = []
group_counts = get_counts(g, n_groups)
for group_idx in range(n_groups):
if group_counts[group_idx] == 0:
group_metrics.append(torch.tensor(0., device=g.device))
else:
flattened_metrics, _ = self.compute_flattened(
y_pred[g == group_idx],
y_true[g == group_idx],
return_dict=False)
group_metrics.append(flattened_metrics)
group_metrics = torch.stack(group_metrics)
worst_group_metric = self.worst(group_metrics[group_counts > 0])
return group_metrics, group_counts, worst_group_metric
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class Recall(Metric):
def __init__(self, prediction_fn=None, name=None, average='binary'):
self.prediction_fn = prediction_fn
if name is None:
name = f'recall'
if average is not None:
name += f'-{average}'
self.average = average
super().__init__(name=name)
def _compute(self, y_pred, y_true):
if self.prediction_fn is not None:
y_pred = self.prediction_fn(y_pred)
recall = sklearn.metrics.recall_score(y_true,
y_pred,
average=self.average,
labels=torch.unique(y_true))
return torch.tensor(recall)
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class F1(Metric):
def __init__(self, prediction_fn=None, name=None, average='binary'):
self.prediction_fn = prediction_fn
if name is None:
name = f'F1'
if average is not None:
name += f'-{average}'
self.average = average
super().__init__(name=name)
def _compute(self, y_pred, y_true):
if self.prediction_fn is not None:
y_pred = self.prediction_fn(y_pred)
score = sklearn.metrics.f1_score(y_true,
y_pred,
average=self.average,
labels=torch.unique(y_true))
return torch.tensor(score)
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class PearsonCorrelation(Metric):
def __init__(self, name=None):
if name is None:
name = 'r'
super().__init__(name=name)
def _compute(self, y_pred, y_true):
r = pearsonr(y_pred.squeeze().detach().cpu().numpy(),
y_true.squeeze().detach().cpu().numpy())[0]
return torch.tensor(r)
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
def mse_loss(out, targets):
assert out.size() == targets.size()
if out.numel() == 0:
return torch.Tensor()
else:
assert out.dim(
) > 1, 'MSE loss currently supports Tensors of dimensions > 1'
losses = (out - targets)**2
reduce_dims = tuple(list(range(1, len(targets.shape))))
losses = torch.mean(losses, dim=reduce_dims)
return losses
[docs]
class MSE(ElementwiseLoss):
def __init__(self, name=None):
if name is None:
name = 'mse'
super().__init__(name=name, loss_fn=mse_loss)
[docs]
class PrecisionAtRecall(Metric):
"""Given a specific model threshold, determine the precision score achieved."""
def __init__(self, threshold, score_fn=None, name=None):
self.score_fn = score_fn
self.threshold = threshold
if name is None:
name = "precision_at_global_recall"
super().__init__(name=name)
def _compute(self, y_pred, y_true):
score = self.score_fn(y_pred)
predictions = (score > self.threshold)
return torch.tensor(sklearn.metrics.precision_score(
y_true, predictions))
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class DummyMetric(Metric):
"""For testing purposes.
This Metric always returns -1.
"""
def __init__(self, prediction_fn=None, name=None):
self.prediction_fn = prediction_fn
if name is None:
name = 'dummy'
super().__init__(name=name)
def _compute(self, y_pred, y_true):
return torch.tensor(-1)
def _compute_group_wise(self, y_pred, y_true, g, n_groups):
group_metrics = torch.ones(n_groups, device=g.device) * -1
group_counts = get_counts(g, n_groups)
worst_group_metric = self.worst(group_metrics)
return group_metrics, group_counts, worst_group_metric
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class DetectionAccuracy(ElementwiseMetric):
"""Per-image detection recall or accuracy with greedy 1:1 IoU matching."""
def __init__(self,
iou_threshold=0.4,
score_threshold=0.1,
name=None,
geometry_name="boxes",
metric="accuracy"):
self.iou_threshold = iou_threshold
self.score_threshold = score_threshold
self.geometry_name = geometry_name
self.metric = metric
if name is None:
name = "detection_{}".format(metric)
super().__init__(name=name)
def _compute_element_wise(self, y_pred, y_true):
batch_results = []
for gt, target in zip(y_true, y_pred):
target_boxes = target[self.geometry_name]
target_scores = target["scores"]
gt_boxes = gt[self.geometry_name]
if target_boxes.dim() == 1:
target_boxes = target_boxes.view(-1, 4)
pred_boxes = target_boxes[target_scores > self.score_threshold]
if self.metric == "accuracy":
det_accuracy = self._accuracy(gt_boxes, pred_boxes,
self.iou_threshold)
elif self.metric == "recall":
det_accuracy = self._recall(gt_boxes, pred_boxes,
self.iou_threshold)
batch_results.append(det_accuracy)
return torch.tensor(batch_results)
def _recall(self, src_boxes, pred_boxes, iou_threshold):
total_gt = len(src_boxes)
total_pred = len(pred_boxes)
if total_gt > 0 and total_pred > 0:
iou = box_iou(src_boxes, pred_boxes)
gt_to_pred = greedy_iou_match(iou, iou_threshold)
tp = n_matched_gt(gt_to_pred)
return tp / float(total_gt)
if total_gt == 0:
return torch.tensor(0.) if total_pred > 0 else torch.tensor(1.)
return torch.tensor(0.)
def _accuracy(self, src_boxes, pred_boxes, iou_threshold):
total_gt = len(src_boxes)
total_pred = len(pred_boxes)
if total_gt > 0 and total_pred > 0:
iou = box_iou(src_boxes, pred_boxes)
gt_to_pred = greedy_iou_match(iou, iou_threshold)
tp = n_matched_gt(gt_to_pred)
fp = float(total_pred) - float(tp)
fn = float(total_gt) - float(tp)
return torch.tensor(float(tp / (tp + fp + fn)))
if total_gt == 0:
return torch.tensor(0.) if total_pred > 0 else torch.tensor(1.)
return torch.tensor(0.)
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class MaskAwareDetectionPrecision(ElementwiseMetric):
"""Precision metric that avoids penalizing predictions on unannotated tree regions.
Unmatched predictions are excluded from false positives when enough of their box area overlaps
tree pixels from ``tree_coverage_mask``.
"""
def __init__(self,
iou_threshold=0.4,
score_threshold=0.1,
tree_fraction_threshold=0.5,
require_tree_coverage_mask=False,
name=None,
geometry_name="boxes",
tree_coverage_key="tree_coverage_mask"):
self.iou_threshold = iou_threshold
self.score_threshold = score_threshold
self.tree_fraction_threshold = tree_fraction_threshold
self.require_tree_coverage_mask = require_tree_coverage_mask
self.geometry_name = geometry_name
self.tree_coverage_key = tree_coverage_key
if name is None:
name = "maskaware_detection_precision"
super().__init__(name=name)
def _compute_element_wise(self, y_pred, y_true):
batch_results = []
for gt, target in zip(y_true, y_pred):
target_boxes = target[self.geometry_name]
target_scores = target["scores"]
gt_boxes = gt[self.geometry_name]
if target_boxes.dim() == 1:
target_boxes = target_boxes.view(-1, 4)
pred_boxes = target_boxes[target_scores > self.score_threshold]
tree_mask = gt.get(self.tree_coverage_key)
det_precision = self._precision(gt_boxes, pred_boxes, tree_mask,
self.iou_threshold)
batch_results.append(det_precision)
return torch.tensor(batch_results)
def _prepare_tree_mask(self, tree_mask):
if tree_mask is None:
return None
if not isinstance(tree_mask, torch.Tensor):
tree_mask = torch.as_tensor(tree_mask)
if tree_mask.dim() == 3:
tree_mask = tree_mask.squeeze(0)
if tree_mask.dim() != 2:
raise ValueError(
f"Expected tree coverage mask to have shape [H, W], got {tuple(tree_mask.shape)}"
)
return tree_mask.bool()
def _tree_pixel_fraction(self, box, tree_mask):
height, width = tree_mask.shape
x1 = int(torch.floor(box[0]).item())
y1 = int(torch.floor(box[1]).item())
x2 = int(torch.ceil(box[2]).item())
y2 = int(torch.ceil(box[3]).item())
x1 = min(max(x1, 0), width)
x2 = min(max(x2, 0), width)
y1 = min(max(y1, 0), height)
y2 = min(max(y2, 0), height)
if x2 <= x1 or y2 <= y1:
return 0.0
crop = tree_mask[y1:y2, x1:x2]
return float(crop.float().mean().item())
def _count_ignored_unmatched_predictions(self,
unmatched_boxes,
tree_mask,
gt_boxes=None,
iou_threshold=None):
if tree_mask is None or len(unmatched_boxes) == 0:
return 0
ignored_count = 0
for box in unmatched_boxes:
if (gt_boxes is not None and len(gt_boxes) > 0 and
iou_threshold is not None):
max_iou = box_iou(box.unsqueeze(0), gt_boxes).max().item()
if max_iou > iou_threshold:
continue
tree_fraction = self._tree_pixel_fraction(box, tree_mask)
if tree_fraction >= self.tree_fraction_threshold:
ignored_count += 1
return ignored_count
def _precision(self, src_boxes, pred_boxes, tree_mask, iou_threshold):
total_gt = len(src_boxes)
total_pred = len(pred_boxes)
tree_mask = self._prepare_tree_mask(tree_mask)
if self.require_tree_coverage_mask and total_pred > 0 and tree_mask is None:
raise ValueError(
"tree_coverage_mask is required but missing for this example")
if total_pred == 0:
return torch.tensor(0.) if total_gt > 0 else torch.tensor(1.)
if total_gt == 0:
unmatched_false_positive = total_pred
ignored_unmatched = self._count_ignored_unmatched_predictions(
pred_boxes, tree_mask)
adjusted_false_positive = unmatched_false_positive - ignored_unmatched
if adjusted_false_positive == 0:
return torch.tensor(1.)
return torch.tensor(0.)
iou = box_iou(src_boxes, pred_boxes)
gt_to_pred = greedy_iou_match(iou, iou_threshold)
true_positive = n_matched_gt(gt_to_pred)
matched = gt_to_pred[gt_to_pred >= 0]
matched_pred_idx = matched.unique()
unmatched_mask = torch.ones(total_pred,
dtype=torch.bool,
device=pred_boxes.device)
if matched_pred_idx.numel() > 0:
unmatched_mask[matched_pred_idx.long()] = False
unmatched_indices = torch.nonzero(unmatched_mask,
as_tuple=False).squeeze(1)
unmatched_boxes = pred_boxes[unmatched_indices]
ignored_unmatched = self._count_ignored_unmatched_predictions(
unmatched_boxes, tree_mask, src_boxes, iou_threshold)
adjusted_false_positive = float(
unmatched_indices.numel()) - float(ignored_unmatched)
denominator = true_positive + adjusted_false_positive
if float(denominator) == 0:
return torch.tensor(1.)
return (true_positive / denominator).clone()
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class KeypointAccuracy(ElementwiseMetric):
"""Keypoint recall for a one-class detector: tp / (tp + fn).
The ``distance_threshold`` is interpreted as a **normalized** distance with
respect to the image size rather than raw pixels. For a square image of
side length ``image_size``, the effective pixel threshold is
pixel_threshold = distance_threshold * image_size
This makes the metric less sensitive to the absolute crop size while still
behaving like a fixed-radius matching rule in pixel space.
"""
def __init__(
self,
distance_threshold: float = 0.02,
score_threshold: float = 0.1,
name: str | None = None,
geometry_name: str = "y",
image_size: int = 448,
):
self.distance_threshold = distance_threshold
self.score_threshold = score_threshold
self.geometry_name = geometry_name
self.image_size = image_size
if name is None:
name = "keypoint_acc"
super().__init__(name=name)
def _compute_element_wise(self, y_pred, y_true):
batch_results = []
for gt, target in zip(y_true, y_pred):
target_boxes = target[self.geometry_name]
target_scores = target["scores"]
gt_boxes = gt[self.geometry_name]
pred_boxes = target_boxes[target_scores > self.score_threshold]
det_accuracy = self._accuracy(gt_boxes, pred_boxes)
batch_results.append(det_accuracy)
return torch.tensor(batch_results)
def _point_nearness(self, src_keypoints, pred_keypoints):
distance = torch.cdist(src_keypoints.float(),
pred_keypoints.float(),
p=2)
return distance
def _accuracy(self, src_keypoints, pred_keypoints):
total_gt = len(src_keypoints)
total_pred = len(pred_keypoints)
if total_gt > 0 and total_pred > 0:
distance_matrix = self._point_nearness(src_keypoints,
pred_keypoints)
pixel_threshold = self.distance_threshold * float(self.image_size)
gt_to_pred = greedy_distance_match(distance_matrix, pixel_threshold)
tp = n_matched_gt(gt_to_pred)
fn = float(total_gt) - float(tp)
return torch.tensor(float(tp / (tp + fn)))
elif total_gt == 0:
if total_pred > 0:
return torch.round(torch.tensor(0.), decimals=3)
else:
return torch.round(torch.tensor(1.), decimals=3)
elif total_gt > 0 and total_pred == 0:
return torch.round(torch.tensor(0.), decimals=3)
[docs]
def worst(self, metrics):
return torch.round(minimum(metrics), decimals=3)
[docs]
class MaskAwareKeypointPrecision(ElementwiseMetric):
"""Precision for point detection that ignores unmatched points on tree-covered pixels."""
def __init__(self,
distance_threshold: float = 0.02,
score_threshold: float = 0.1,
tree_fraction_threshold: float = 0.5,
require_tree_coverage_mask: bool = False,
name: str | None = None,
geometry_name: str = "y",
tree_coverage_key: str = "tree_coverage_mask",
image_size: int = 448):
self.distance_threshold = distance_threshold
self.score_threshold = score_threshold
self.tree_fraction_threshold = tree_fraction_threshold
self.require_tree_coverage_mask = require_tree_coverage_mask
self.geometry_name = geometry_name
self.tree_coverage_key = tree_coverage_key
self.image_size = image_size
if name is None:
name = "maskaware_keypoint_precision"
super().__init__(name=name)
def _compute_element_wise(self, y_pred, y_true):
batch_results = []
for gt, target in zip(y_true, y_pred):
target_points = target[self.geometry_name]
target_scores = target["scores"]
gt_points = gt[self.geometry_name]
pred_points = target_points[target_scores > self.score_threshold]
tree_mask = gt.get(self.tree_coverage_key)
precision = self._precision(gt_points, pred_points, tree_mask)
batch_results.append(precision)
return torch.tensor(batch_results)
def _prepare_tree_mask(self, tree_mask):
if tree_mask is None:
return None
if not isinstance(tree_mask, torch.Tensor):
tree_mask = torch.as_tensor(tree_mask)
if tree_mask.dim() == 3:
tree_mask = tree_mask.squeeze(0)
if tree_mask.dim() != 2:
raise ValueError(
f"Expected tree coverage mask to have shape [H, W], got {tuple(tree_mask.shape)}"
)
return tree_mask.bool()
def _point_tree_fraction(self, point, tree_mask):
height, width = tree_mask.shape
x = int(round(float(point[0])))
y = int(round(float(point[1])))
x = min(max(x, 0), width - 1)
y = min(max(y, 0), height - 1)
return float(tree_mask[y, x].float().item())
def _count_ignored_unmatched_predictions(self,
unmatched_points,
tree_mask,
gt_points=None,
max_distance=None):
if tree_mask is None or len(unmatched_points) == 0:
return 0
ignored_count = 0
for point in unmatched_points:
if (gt_points is not None and len(gt_points) > 0 and
max_distance is not None):
d = torch.norm(gt_points.float() - point.float(),
dim=1).min().item()
if d <= max_distance:
continue
tree_fraction = self._point_tree_fraction(point, tree_mask)
if tree_fraction >= self.tree_fraction_threshold:
ignored_count += 1
return ignored_count
def _precision(self, src_points, pred_points, tree_mask):
total_gt = len(src_points)
total_pred = len(pred_points)
tree_mask = self._prepare_tree_mask(tree_mask)
if self.require_tree_coverage_mask and total_pred > 0 and tree_mask is None:
raise ValueError(
"tree_coverage_mask is required but missing for this example")
if total_pred == 0:
return torch.tensor(0.) if total_gt > 0 else torch.tensor(1.)
if total_gt == 0:
ignored_unmatched = self._count_ignored_unmatched_predictions(
pred_points, tree_mask)
adjusted_false_positive = total_pred - ignored_unmatched
if adjusted_false_positive == 0:
return torch.tensor(1.)
return torch.tensor(0.)
distance_matrix = torch.cdist(src_points.float(),
pred_points.float(),
p=2)
pixel_threshold = self.distance_threshold * float(self.image_size)
gt_to_pred = greedy_distance_match(distance_matrix, pixel_threshold)
true_positive = n_matched_gt(gt_to_pred)
matched = gt_to_pred[gt_to_pred >= 0]
matched_pred_idx = matched.unique()
unmatched_mask = torch.ones(total_pred,
dtype=torch.bool,
device=pred_points.device)
if matched_pred_idx.numel() > 0:
unmatched_mask[matched_pred_idx.long()] = False
unmatched_indices = torch.nonzero(unmatched_mask,
as_tuple=False).squeeze(1)
unmatched_points = pred_points[unmatched_indices]
ignored_unmatched = self._count_ignored_unmatched_predictions(
unmatched_points, tree_mask, src_points, pixel_threshold)
adjusted_false_positive = float(
unmatched_indices.numel()) - float(ignored_unmatched)
denominator = true_positive + adjusted_false_positive
if float(denominator) == 0:
return torch.tensor(1.)
return (true_positive / denominator).clone()
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class MaskAccuracy(ElementwiseMetric):
"""Per-image mask recall or accuracy with greedy 1:1 mask IoU matching."""
def __init__(self,
iou_threshold=0.4,
score_threshold=0.1,
name=None,
geometry_name="masks",
metric="accuracy"):
self.iou_threshold = iou_threshold
self.score_threshold = score_threshold
self.geometry_name = geometry_name
self.metric = metric
if name is None:
name = "mask_acc"
super().__init__(name=name)
def _compute_element_wise(self, y_pred, y_true):
batch_results = []
for gt, target in zip(y_true, y_pred):
target_masks = target[self.geometry_name]
target_scores = target["scores"]
gt_masks = gt[self.geometry_name]
# Convert to tensors if needed
if not isinstance(target_scores, torch.Tensor):
target_scores = torch.as_tensor(target_scores,
dtype=torch.float32)
pred_masks = target_masks[target_scores > self.score_threshold]
if self.metric == "recall":
det_accuracy = self._recall(gt_masks, pred_masks,
self.iou_threshold)
else:
det_accuracy = self._accuracy(gt_masks, pred_masks,
self.iou_threshold)
batch_results.append(det_accuracy)
return torch.tensor(batch_results)
def _boxes_to_masks(self, boxes, height, width):
"""Convert bounding boxes [N, 4] (xyxy format) to masks [N, H, W]."""
if len(boxes) == 0:
device = boxes.device if isinstance(boxes, torch.Tensor) else 'cpu'
return torch.zeros((0, height, width),
dtype=torch.bool,
device=device)
# Convert to tensor if needed
if not isinstance(boxes, torch.Tensor):
boxes = torch.as_tensor(boxes, dtype=torch.float32)
boxes = boxes.clone()
# Clamp boxes to image bounds
boxes[:, 0] = torch.clamp(boxes[:, 0], 0, width)
boxes[:, 1] = torch.clamp(boxes[:, 1], 0, height)
boxes[:, 2] = torch.clamp(boxes[:, 2], 0, width)
boxes[:, 3] = torch.clamp(boxes[:, 3], 0, height)
device = boxes.device
masks = torch.zeros((len(boxes), height, width),
dtype=torch.bool,
device=device)
for i, box in enumerate(boxes):
x1, y1, x2, y2 = box.int()
# Ensure valid box
if x2 > x1 and y2 > y1:
masks[i, y1:y2, x1:x2] = True
return masks
def _mask_iou(self, src_masks, pred_masks):
# Convert to tensors if needed (preserve original dtype for shape detection)
src_is_tensor = isinstance(src_masks, torch.Tensor)
pred_is_tensor = isinstance(pred_masks, torch.Tensor)
if not src_is_tensor:
src_masks = torch.as_tensor(src_masks)
if not pred_is_tensor:
pred_masks = torch.as_tensor(pred_masks)
# Handle case where pred_masks are actually bounding boxes [M, 4]
# Check if pred_masks are boxes (shape [M, 4]) instead of masks [M, H, W]
# For empty tensors, check the shape tuple
is_pred_boxes = (pred_masks.dim() == 2 and
(len(pred_masks) == 0 or
(pred_masks.shape[1] == 4 and pred_masks.dim() == 2)))
if is_pred_boxes:
# Get image dimensions from src_masks
if len(src_masks) > 0 and src_masks.dim() == 3:
height, width = src_masks.shape[1], src_masks.shape[2]
# Convert boxes to masks
pred_masks = self._boxes_to_masks(pred_masks, height, width)
else:
# If no ground truth masks, return zero IoU
device = pred_masks.device if isinstance(
pred_masks, torch.Tensor) else 'cpu'
return torch.zeros((0, len(pred_masks)),
dtype=torch.float32,
device=device)
# Handle case where src_masks are boxes (shouldn't happen, but handle gracefully)
is_src_boxes = (src_masks.dim() == 2 and
(len(src_masks) == 0 or
(src_masks.shape[1] == 4 and src_masks.dim() == 2)))
if is_src_boxes:
if len(pred_masks) > 0 and pred_masks.dim() == 3:
height, width = pred_masks.shape[1], pred_masks.shape[2]
src_masks = self._boxes_to_masks(src_masks, height, width)
else:
device = src_masks.device if isinstance(src_masks,
torch.Tensor) else 'cpu'
return torch.zeros((len(src_masks), 0),
dtype=torch.float32,
device=device)
# Ensure masks are bool type for bitwise operations
if src_masks.dtype != torch.bool:
src_masks = src_masks.bool()
if pred_masks.dtype != torch.bool:
pred_masks = pred_masks.bool()
# Memory optimization: Use bbox IoU to pre-filter before computing expensive mask IoU
# This reduces memory usage from O(N*M*H*W) to O(N*M) for filtering, then only
# compute mask IoU for candidate pairs
device = src_masks.device
N, M = len(src_masks), len(pred_masks)
if N == 0 or M == 0:
return torch.zeros((N, M), dtype=torch.float32, device=device)
# Compute bboxes from masks for pre-filtering.
# Torchvision's masks_to_boxes errors if any individual mask is empty (all zeros),
# so compute boxes only for non-empty masks and zero-fill the rest.
src_nonempty = src_masks.flatten(1).any(dim=1)
pred_nonempty = pred_masks.flatten(1).any(dim=1)
src_boxes = torch.zeros((N, 4), dtype=torch.float32, device=device)
pred_boxes = torch.zeros((M, 4), dtype=torch.float32, device=device)
if src_nonempty.any():
src_boxes[src_nonempty] = masks_to_boxes(src_masks[src_nonempty])
if pred_nonempty.any():
pred_boxes[pred_nonempty] = masks_to_boxes(
pred_masks[pred_nonempty])
# Compute bbox IoU for all pairs (cheap: O(N*M))
bbox_iou = box_iou(src_boxes, pred_boxes) # [N, M]
bbox_iou[~src_nonempty, :] = 0.0
bbox_iou[:, ~pred_nonempty] = 0.0
# Initialize IoU matrix with bbox IoU values (will be refined for ambiguous cases)
iou = bbox_iou.clone()
# Option 3: Hybrid bbox/mask IoU - use bbox IoU as approximation for obvious cases
# For very low bbox IoU (< 0.1), no mask overlap is possible
iou[bbox_iou < 0.1] = 0.0
# For very high bbox IoU (> 0.9), bbox IoU is a good approximation of mask IoU
# Only compute expensive mask IoU for ambiguous cases (0.1 <= bbox_iou <= 0.9)
ambiguous_mask = (bbox_iou >= 0.1) & (bbox_iou <= 0.9)
if ambiguous_mask.any():
# Get indices of ambiguous pairs that need mask IoU computation
ambiguous_gt_indices, ambiguous_pred_indices = torch.where(
ambiguous_mask)
num_ambiguous = len(ambiguous_gt_indices)
# Process ambiguous pairs in chunks to avoid creating huge tensors
# Even with downsampling, U_gt * U_pred * H * W can be massive
chunk_size = 500 # Process 500 ambiguous pairs at a time
target_size = 224 # Downsample to 224x224 for memory efficiency
for chunk_start in range(0, num_ambiguous, chunk_size):
chunk_end = min(chunk_start + chunk_size, num_ambiguous)
chunk_gt_idx = ambiguous_gt_indices[chunk_start:chunk_end]
chunk_pred_idx = ambiguous_pred_indices[chunk_start:chunk_end]
# Get unique indices for this chunk to avoid redundant mask loading
unique_gt_idx = torch.unique(chunk_gt_idx)
unique_pred_idx = torch.unique(chunk_pred_idx)
# Load masks for this chunk
chunk_src_masks = src_masks[unique_gt_idx] # [U_gt, H, W]
chunk_pred_masks = pred_masks[unique_pred_idx] # [U_pred, H, W]
# Option 1: Downsample masks for IoU computation to reduce memory
if chunk_src_masks.shape[1] > target_size:
# Downsample using nearest neighbor to preserve binary nature
chunk_src_masks = F.interpolate(
chunk_src_masks.unsqueeze(1).float(),
size=(target_size, target_size),
mode='nearest').squeeze(1).bool()
chunk_pred_masks = F.interpolate(
chunk_pred_masks.unsqueeze(1).float(),
size=(target_size, target_size),
mode='nearest').squeeze(1).bool()
# Create mapping from original indices to chunk indices
gt_idx_map = {
int(idx): i for i, idx in enumerate(unique_gt_idx)
}
pred_idx_map = {
int(idx): i for i, idx in enumerate(unique_pred_idx)
}
# Compute mask IoU for chunk pairs using vectorized operations
src_expanded = chunk_src_masks.unsqueeze(1) # [U_gt, 1, H, W]
pred_expanded = chunk_pred_masks.unsqueeze(
0) # [1, U_pred, H, W]
intersection = (src_expanded & pred_expanded).float().sum(
(2, 3)) # [U_gt, U_pred]
union = (src_expanded | pred_expanded).float().sum(
(2, 3)) # [U_gt, U_pred]
chunk_mask_iou = intersection / union.clamp(min=1e-6)
chunk_mask_iou[union == 0] = 0.0
# Map chunk results back to original indices
for i, j in zip(chunk_gt_idx, chunk_pred_idx):
orig_gt_idx = int(i)
orig_pred_idx = int(j)
chunk_gt_pos = gt_idx_map[orig_gt_idx]
chunk_pred_pos = pred_idx_map[orig_pred_idx]
iou[orig_gt_idx,
orig_pred_idx] = chunk_mask_iou[chunk_gt_pos,
chunk_pred_pos]
return iou # Returns [N, M] matrix
def _recall(self, src_masks, pred_masks, iou_threshold, *, iou=None):
total_gt = len(src_masks)
total_pred = len(pred_masks)
if total_gt > 0 and total_pred > 0:
if iou is None:
iou = self._mask_iou(src_masks, pred_masks)
gt_to_pred = greedy_iou_match(iou, iou_threshold)
tp = n_matched_gt(gt_to_pred)
return tp / float(total_gt)
if total_gt == 0:
return torch.tensor(0.) if total_pred > 0 else torch.tensor(1.)
return torch.tensor(0.)
def _accuracy(self, src_masks, pred_masks, iou_threshold, *, iou=None):
total_gt = len(src_masks)
total_pred = len(pred_masks)
if total_gt > 0 and total_pred > 0:
if iou is None:
iou = self._mask_iou(src_masks, pred_masks)
gt_to_pred = greedy_iou_match(iou, iou_threshold)
tp = n_matched_gt(gt_to_pred)
fp = float(total_pred) - float(tp)
fn = float(total_gt) - float(tp)
return torch.tensor(float(tp / (tp + fp + fn)))
if total_gt == 0:
return torch.tensor(0.) if total_pred > 0 else torch.tensor(1.)
return torch.tensor(0.)
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class MaskAwareMaskPrecision(ElementwiseMetric):
"""Precision for mask detection that ignores unmatched masks on tree-covered pixels."""
def __init__(self,
iou_threshold=0.4,
score_threshold=0.1,
tree_fraction_threshold=0.5,
require_tree_coverage_mask=False,
name=None,
geometry_name="masks",
tree_coverage_key="tree_coverage_mask"):
self.iou_threshold = iou_threshold
self.score_threshold = score_threshold
self.tree_fraction_threshold = tree_fraction_threshold
self.require_tree_coverage_mask = require_tree_coverage_mask
self.geometry_name = geometry_name
self.tree_coverage_key = tree_coverage_key
self._mask_accuracy = MaskAccuracy(iou_threshold=iou_threshold,
score_threshold=score_threshold,
geometry_name=geometry_name)
if name is None:
name = "maskaware_mask_precision"
super().__init__(name=name)
def _compute_element_wise(self, y_pred, y_true):
batch_results = []
for gt, target in zip(y_true, y_pred):
target_masks = target[self.geometry_name]
target_scores = target["scores"]
gt_masks = gt[self.geometry_name]
if not isinstance(target_scores, torch.Tensor):
target_scores = torch.as_tensor(target_scores,
dtype=torch.float32)
pred_masks = target_masks[target_scores > self.score_threshold]
tree_mask = gt.get(self.tree_coverage_key)
precision = self._precision(gt_masks, pred_masks, tree_mask)
batch_results.append(precision)
return torch.tensor(batch_results)
def _prepare_tree_mask(self, tree_mask):
if tree_mask is None:
return None
if not isinstance(tree_mask, torch.Tensor):
tree_mask = torch.as_tensor(tree_mask)
if tree_mask.dim() == 3:
tree_mask = tree_mask.squeeze(0)
if tree_mask.dim() != 2:
raise ValueError(
f"Expected tree coverage mask to have shape [H, W], got {tuple(tree_mask.shape)}"
)
return tree_mask.bool()
def _prepare_masks(self, masks):
if not isinstance(masks, torch.Tensor):
masks = torch.as_tensor(masks)
if masks.dtype != torch.bool:
masks = masks.bool()
return masks
def _mask_tree_fraction(self, pred_mask, tree_mask):
if pred_mask.dim() != 2:
# Not a spatial mask (e.g. box coords [4]) — can't determine tree coverage
return 0.0
pred_area = pred_mask.float().sum()
if float(pred_area) == 0:
return 0.0
overlap = (pred_mask & tree_mask).float().sum()
return float((overlap / pred_area).item())
def _count_ignored_unmatched_predictions(self,
unmatched_masks,
tree_mask,
*,
iou=None,
unmatched_indices=None,
iou_threshold=None):
if tree_mask is None or len(unmatched_masks) == 0:
return 0
ignored_count = 0
for local_idx, pred_mask in enumerate(unmatched_masks):
if (iou is not None and unmatched_indices is not None and
iou_threshold is not None and iou.numel() > 0):
pred_col = int(unmatched_indices[local_idx])
if iou[:, pred_col].max().item() > iou_threshold:
continue
tree_fraction = self._mask_tree_fraction(pred_mask, tree_mask)
if tree_fraction >= self.tree_fraction_threshold:
ignored_count += 1
return ignored_count
def _precision(self, src_masks, pred_masks, tree_mask, *, iou=None):
src_masks = self._prepare_masks(src_masks)
pred_masks = self._prepare_masks(pred_masks)
tree_mask = self._prepare_tree_mask(tree_mask)
total_gt = len(src_masks)
total_pred = len(pred_masks)
if self.require_tree_coverage_mask and total_pred > 0 and tree_mask is None:
raise ValueError(
"tree_coverage_mask is required but missing for this example")
if total_pred == 0:
return torch.tensor(0.) if total_gt > 0 else torch.tensor(1.)
if total_gt == 0:
ignored_unmatched = self._count_ignored_unmatched_predictions(
pred_masks, tree_mask)
adjusted_false_positive = total_pred - ignored_unmatched
if adjusted_false_positive == 0:
return torch.tensor(1.)
return torch.tensor(0.)
if iou is None:
iou = self._mask_accuracy._mask_iou(src_masks, pred_masks)
gt_to_pred = greedy_iou_match(iou, self.iou_threshold)
true_positive = n_matched_gt(gt_to_pred)
matched = gt_to_pred[gt_to_pred >= 0]
matched_pred_idx = matched.unique()
unmatched_mask = torch.ones(total_pred,
dtype=torch.bool,
device=pred_masks.device)
if matched_pred_idx.numel() > 0:
unmatched_mask[matched_pred_idx.long()] = False
unmatched_indices = torch.nonzero(unmatched_mask,
as_tuple=False).squeeze(1)
unmatched_masks = pred_masks[unmatched_indices]
ignored_unmatched = self._count_ignored_unmatched_predictions(
unmatched_masks,
tree_mask,
iou=iou,
unmatched_indices=unmatched_indices,
iou_threshold=self.iou_threshold,
)
adjusted_false_positive = float(
unmatched_indices.numel()) - float(ignored_unmatched)
denominator = true_positive + adjusted_false_positive
if float(denominator) == 0:
return torch.tensor(1.)
return (true_positive / denominator).clone()
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class MergeCommissionMetric(ElementwiseMetric):
"""Fraction of predictions with IoU > ``iou_threshold`` against two or more GT objects."""
def __init__(self,
iou_threshold: float = 0.4,
score_threshold: float = 0.1,
geometry_name: str = "y",
modality: str = "bbox",
name: str | None = None):
self.iou_threshold = iou_threshold
self.score_threshold = score_threshold
self.geometry_name = geometry_name
self.modality = modality
self._mask_iou_backend: MaskAccuracy | None = None
if modality == "mask":
self._mask_iou_backend = MaskAccuracy(
iou_threshold=iou_threshold,
score_threshold=score_threshold,
geometry_name=geometry_name,
metric="accuracy",
)
if name is None:
name = "merge_commission"
super().__init__(name=name)
def _compute_element_wise(self, y_pred, y_true):
out = []
for gt, target in zip(y_true, y_pred):
scores = target["scores"]
if not isinstance(scores, torch.Tensor):
scores = torch.as_tensor(scores, dtype=torch.float32)
geo = target[self.geometry_name]
if not isinstance(geo, torch.Tensor):
geo = torch.as_tensor(geo)
if self.modality == "bbox" and geo.dim() == 1:
geo = geo.view(-1, 4)
pred = geo[scores > self.score_threshold]
gt_geo = gt[self.geometry_name]
if not isinstance(gt_geo, torch.Tensor):
gt_geo = torch.as_tensor(gt_geo)
if self.modality == "bbox":
if len(gt_geo) == 0 or len(pred) == 0:
out.append(torch.tensor(0.0))
continue
iou = box_iou(gt_geo, pred)
else:
if len(gt_geo) == 0 or len(pred) == 0:
out.append(torch.tensor(0.0))
continue
assert self._mask_iou_backend is not None
iou = self._mask_iou_backend._mask_iou(gt_geo, pred)
out.append(
merge_commission_rate_iou(
iou, self.iou_threshold).to(dtype=torch.float32))
return torch.stack(out)
[docs]
def worst(self, metrics):
return maximum(metrics)
[docs]
class KeypointMergeCommissionMetric(ElementwiseMetric):
"""Fraction of predictions within ``max_distance`` of two or more GT points."""
def __init__(self,
distance_threshold: float = 0.02,
score_threshold: float = 0.1,
geometry_name: str = "y",
image_size: int = 448,
name: str | None = None):
self.distance_threshold = distance_threshold
self.score_threshold = score_threshold
self.geometry_name = geometry_name
self.image_size = image_size
if name is None:
name = "keypoint_merge_commission"
super().__init__(name=name)
def _compute_element_wise(self, y_pred, y_true):
pixel_threshold = self.distance_threshold * float(self.image_size)
out = []
for gt, target in zip(y_true, y_pred):
pts = target[self.geometry_name]
scores = target["scores"]
pred = pts[scores > self.score_threshold]
gt_pts = gt[self.geometry_name]
if len(gt_pts) == 0 or len(pred) == 0:
out.append(torch.tensor(0.0))
continue
dist = torch.cdist(gt_pts.float(), pred.float(), p=2)
out.append(
merge_commission_rate_distance(
dist, pixel_threshold).to(dtype=torch.float32))
return torch.stack(out)
[docs]
def worst(self, metrics):
return maximum(metrics)
[docs]
class DetectionMAP(Metric):
"""Average Precision for object detection using torchmetrics.
Supports bounding boxes (iou_type="bbox") and instance segmentation masks (iou_type="segm").
Single-class: all labels are normalised to 0 so that predictions always match the ground-truth
class regardless of the model's raw label output.
``iou_thresholds`` controls the IoU threshold(s) AP is computed at. Pass ``[0.5]`` (the default
for TreeBoxes/TreePolygons) for PASCAL-style AP@0.5; pass ``None`` for COCO-style mAP averaged
over IoU 0.50:0.05:0.95.
"""
def __init__(self,
geometry_name="y",
score_threshold=0.1,
iou_type="bbox",
iou_thresholds=None,
max_detection_thresholds=None,
name=None):
self.geometry_name = geometry_name
self.score_threshold = score_threshold
self.iou_type = iou_type
self.iou_thresholds = iou_thresholds
self.max_detection_thresholds = max_detection_thresholds
if name is None:
name = "AP50" if iou_thresholds == [0.5] else "mAP"
super().__init__(name=name)
@property
def agg_metric_field(self):
return f'{self.name}_avg'
def _format(self, y_pred, y_true):
"""Convert MillionTrees dicts to the list-of-dicts format expected by torchmetrics
MeanAveragePrecision."""
key = "masks" if self.iou_type == "segm" else "boxes"
preds, targets = [], []
for pred, gt in zip(y_pred, y_true):
scores = pred["scores"]
if not isinstance(scores, torch.Tensor):
scores = torch.as_tensor(scores, dtype=torch.float32)
keep = scores > self.score_threshold
geo = pred[self.geometry_name]
if not isinstance(geo, torch.Tensor):
geo = torch.as_tensor(geo)
if self.iou_type == "bbox" and geo.dim() == 1:
geo = geo.view(-1, 4)
gt_geo = gt[self.geometry_name]
if not isinstance(gt_geo, torch.Tensor):
gt_geo = torch.as_tensor(gt_geo)
n_pred = int(keep.sum())
n_gt = len(gt_geo)
if self.iou_type == "segm":
# geo must be [N, H, W] masks; if it's box/point coords (wrong ndim)
# treat as no predictions — segm mAP is 0 for non-mask models
if geo.dim() != 3:
n_pred = 0
pred_masks = torch.zeros((0, 1, 1), dtype=torch.bool)
pred_scores = torch.zeros(0, dtype=torch.float32)
else:
pred_masks = geo[keep].bool()
pred_scores = scores[keep].float()
preds.append({
"masks": pred_masks,
"scores": pred_scores,
"labels": torch.zeros(n_pred, dtype=torch.long),
})
targets.append({
"masks": gt_geo.bool(),
"labels": torch.zeros(n_gt, dtype=torch.long),
})
else:
preds.append({
"boxes": geo[keep].float(),
"scores": scores[keep].float(),
"labels": torch.zeros(n_pred, dtype=torch.long),
})
targets.append({
"boxes": gt_geo.float(),
"labels": torch.zeros(n_gt, dtype=torch.long),
})
return preds, targets
def _compute(self, y_pred, y_true):
import torch
metric = make_mean_average_precision(
iou_type=self.iou_type,
iou_thresholds=self.iou_thresholds,
max_detection_thresholds=self.max_detection_thresholds,
class_metrics=False,
)
preds, targets = self._format(y_pred, y_true)
metric.update(preds, targets)
# NCCL (GPU distributed backend) cannot all_gather CPU tensors. When
# evaluate() is called after Lightning DDP training the process group is
# still live, so torchmetrics tries to sync and crashes. Disable sync so
# the metric is computed locally on each rank instead.
if torch.distributed.is_available(
) and torch.distributed.is_initialized():
metric._to_sync = False
result = metric.compute()
# torchmetrics >=1.x puts AP@0.5 in "map_50"; the primary "map" key is
# -1 when a custom max_detection_thresholds is combined with a single
# IoU threshold, so read map_50 directly for the AP50 configuration.
if self.iou_thresholds == [0.5]:
return result["map_50"]
return result["map"]
def _compute_group_wise(self, y_pred, y_true, g, n_groups):
group_counts = get_counts(g, n_groups)
group_metrics = []
for group_idx in range(n_groups):
if group_counts[group_idx] == 0:
group_metrics.append(torch.tensor(0., device=g.device))
else:
idx = (g == group_idx).nonzero(as_tuple=True)[0].tolist()
gp = [y_pred[i] for i in idx]
gt = [y_true[i] for i in idx]
group_metrics.append(self._compute(gp, gt))
group_metrics = torch.stack(group_metrics)
worst_group_metric = self.worst(group_metrics[group_counts > 0])
return group_metrics, group_counts, worst_group_metric
[docs]
def worst(self, metrics):
return minimum(metrics)
[docs]
class CountingError(ElementwiseMetric):
"""Mean Absolute Error between ground truth and predicted detection counts.
Counting MAE is only meaningful when annotations are exhaustive — partially-annotated images
artificially inflate the per-image error. The metric therefore only contributes a value for
images whose target dict carries ``complete=True`` (set per-source from
``data_prep/source_completeness.csv``). Other images yield NaN and are dropped from aggregation,
so ``counting_mae`` is computed only over fully-annotated sources.
"""
def __init__(self,
score_threshold=0.1,
name=None,
geometry_name="y",
complete_key="complete"):
self.score_threshold = score_threshold
self.geometry_name = geometry_name
self.complete_key = complete_key
if name is None:
name = "counting_mae"
super().__init__(name=name)
def _compute_element_wise(self, y_pred, y_true):
batch_results = []
for gt, target in zip(y_true, y_pred):
if not bool(gt.get(self.complete_key, False)):
batch_results.append(float("nan"))
continue
target_scores = target["scores"]
score_mask = target_scores > self.score_threshold
pred_count = int(score_mask.sum().item()) if isinstance(
score_mask, torch.Tensor) else int(sum(score_mask))
gt_count = len(gt[self.geometry_name])
batch_results.append(float(abs(gt_count - pred_count)))
return torch.tensor(batch_results, dtype=torch.float)
def _compute(self, y_pred, y_true):
"""Aggregate that ignores images without ``complete=True`` (NaN)."""
elementwise = self._compute_element_wise(y_pred, y_true)
finite = elementwise[~torch.isnan(elementwise)]
if finite.numel() == 0:
return torch.tensor(float("nan"))
return finite.mean()
def _compute_group_wise(self, y_pred, y_true, g, n_groups):
if len(y_pred) == 0:
group_metrics = torch.full((n_groups,), float("nan"))
group_counts = torch.zeros(n_groups, dtype=torch.long)
return group_metrics, group_counts, torch.tensor(float("nan"))
elementwise = self._compute_element_wise(y_pred, y_true)
finite_mask = ~torch.isnan(elementwise)
group_metrics = torch.full((n_groups,), float("nan"))
group_counts = torch.zeros(n_groups, dtype=torch.long)
for group_idx in range(n_groups):
mask = (g == group_idx) & finite_mask
count = int(mask.sum().item())
group_counts[group_idx] = count
if count > 0:
group_metrics[group_idx] = elementwise[mask].mean()
evaluated = group_metrics[group_counts > 0]
if evaluated.numel() == 0:
worst_group_metric = torch.tensor(float("nan"))
else:
worst_group_metric = self.worst(evaluated)
return group_metrics, group_counts, worst_group_metric
[docs]
def worst(self, metrics):
if isinstance(metrics, torch.Tensor) and metrics.numel() == 0:
return torch.tensor(float("nan"))
return maximum(metrics)