import numpy as np
from milliontrees.common.utils import avg_over_groups, get_counts, numel
import torch
# Base class shared by all metrics.
class Metric:
    """Parent class for metrics.

    Subclasses are expected to implement ``_compute`` and ``worst``. The public
    ``compute`` and ``compute_group_wise`` methods wrap those hooks, handle
    empty inputs/groups, and optionally package the results in a dictionary.
    """

    def __init__(self, name):
        """Stores the metric name.

        Args:
            - name (str): Name used to build the keys of the results dictionaries
        """
        self._name = name

    def _compute(self, y_pred, y_true):
        """Helper function for computing the metric.

        Subclasses should implement this.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - metric (0-dim tensor): metric
        Raises:
            - NotImplementedError: always, in this base class
        """
        # Raise (rather than return the exception class) so that a missing
        # override fails here instead of later on `.item()`.
        raise NotImplementedError

    def worst(self, metrics):
        """Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

        Subclasses should implement this.
        Args:
            - metrics (Tensor, numpy array, or list): Metrics
        Output:
            - worst_metric (0-dim tensor): Worst-case metric
        Raises:
            - NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    @property
    def name(self):
        """Metric name.

        Used to name the key in the results dictionaries returned by the metric.
        """
        return self._name

    @property
    def agg_metric_field(self):
        """The name of the key in the results dictionary returned by Metric.compute().

        This should correspond to the aggregate metric computed on all of y_pred and y_true, in
        contrast to a group-wise evaluation.
        """
        return f'{self.name}_all'

    def group_metric_field(self, group_idx):
        """The name of the keys corresponding to individual group evaluations in the results
        dictionary returned by Metric.compute_group_wise()."""
        return f'{self.name}_group:{group_idx}'

    @property
    def worst_group_metric_field(self):
        """The name of the keys corresponding to the worst-group metric in the results dictionary
        returned by Metric.compute_group_wise()."""
        return f'{self.name}_wg'

    def group_count_field(self, group_idx):
        """The name of the keys corresponding to each group's count in the results dictionary
        returned by Metric.compute_group_wise()."""
        return f'count_group:{group_idx}'

    def compute(self, y_pred, y_true, return_dict=True):
        """Computes metric.

        This is a wrapper around _compute.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - return_dict (bool): Whether to return the output as a dictionary or a tensor
        Output (return_dict=False):
            - metric (0-dim tensor): metric. If the inputs are empty, returns tensor(0.)
        Output (return_dict=True):
            - results (dict): Dictionary of results, mapping metric.agg_metric_field to avg_metric
        """
        if numel(y_true) == 0:
            # Nothing to evaluate: report 0, placed on y_true's device when it
            # exposes one (device=None selects the default device).
            agg_metric = torch.tensor(0., device=getattr(y_true, 'device', None))
        else:
            agg_metric = self._compute(y_pred, y_true)

        if return_dict:
            return {self.agg_metric_field: agg_metric.item()}
        return agg_metric

    def compute_group_wise(self, y_pred, y_true, g, n_groups, return_dict=True):
        """Computes metrics for each group.

        This is a wrapper around _compute_group_wise.
        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - g (Tensor): groups
            - n_groups (int): number of groups
            - return_dict (bool): Whether to return the output as a dictionary or a tensor
        Output (return_dict=False):
            - group_metrics (Tensor): tensor of size (n_groups, ) including the average metric for each group
            - group_counts (Tensor): tensor of size (n_groups, ) including the group count
            - worst_group_metric (0-dim tensor): worst-group metric
            - For empty inputs/groups, corresponding metrics are tensor(0.)
        Output (return_dict=True):
            - results (dict): Dictionary of results, with one metric entry and one count entry
              per group plus the worst-group metric, all converted with .item()
        """
        group_metrics, group_counts, worst_group_metric = self._compute_group_wise(
            y_pred, y_true, g, n_groups)
        if not return_dict:
            return group_metrics, group_counts, worst_group_metric

        results = {}
        for group_idx in range(n_groups):
            results[self.group_metric_field(group_idx)] = group_metrics[group_idx].item()
            results[self.group_count_field(group_idx)] = group_counts[group_idx].item()
        results[self.worst_group_metric_field] = worst_group_metric.item()
        return results

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        """Computes the metric separately on each group by calling _compute on its members.

        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - g (Tensor): groups
            - n_groups (int): number of groups
        Output:
            - group_metrics (Tensor): stacked per-group metrics; tensor(0.) for empty groups
            - group_counts: result of get_counts(g, n_groups)
              (presumably a tensor of size (n_groups, ) -- verify against get_counts)
            - worst_group_metric: self.worst() applied to the metrics of non-empty groups
        """
        group_counts = get_counts(g, n_groups)
        group_metrics = []
        for group_idx in range(n_groups):
            if group_counts[group_idx] == 0:
                # Empty group: placeholder 0 on the same device as g.
                group_metrics.append(torch.tensor(0., device=g.device))
            else:
                # Build the membership mask once and reuse it for both inputs.
                in_group = g == group_idx
                group_metrics.append(
                    self._compute(y_pred[in_group], y_true[in_group]))
        group_metrics = torch.stack(group_metrics)
        # Empty groups are excluded so their placeholder 0 cannot be picked as the worst.
        worst_group_metric = self.worst(group_metrics[group_counts > 0])
        return group_metrics, group_counts, worst_group_metric
# Metrics computed per element and then averaged.
class ElementwiseMetric(Metric):
    """Averages.

    A metric that is evaluated per element by ``_compute_element_wise`` and then
    averaged, either over the whole input or within each group.
    """

    def _compute_element_wise(self, y_pred, y_true):
        """Helper for computing element-wise metric, implemented for each metric.

        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - element_wise_metrics (Tensor): tensor of size (batch_size, )
        Raises:
            - NotImplementedError: always, in this class; subclasses implement it
        """
        raise NotImplementedError

    def worst(self, metrics):
        """Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

        Args:
            - metrics (Tensor, numpy array, or list): Metrics
        Output:
            - worst_metric (0-dim tensor): Worst-case metric
        Raises:
            - NotImplementedError: always, in this class; subclasses implement it
        """
        raise NotImplementedError

    def _compute(self, y_pred, y_true):
        """Helper function for computing the metric.

        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
        Output:
            - avg_metric (0-dim tensor): average of element-wise metrics
        """
        element_wise_metrics = self._compute_element_wise(y_pred, y_true)
        return element_wise_metrics.mean()

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        """Averages the element-wise metric within each group.

        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - g (Tensor): groups
            - n_groups (int): number of groups
        Output:
            - group_metrics, group_counts: as returned by avg_over_groups
              (all zeros of size (n_groups, ) when y_pred is empty)
            - worst_group_metric: self.worst() applied to the metrics of non-empty
              groups (tensor(0.) when y_pred is empty)
        """
        if len(y_pred) == 0:
            # Keep the placeholder results on g's device, matching how the base
            # class places its empty-group zeros; device=None is the default device.
            device = getattr(g, 'device', None)
            group_metrics = torch.zeros(n_groups, device=device)
            group_counts = torch.zeros(n_groups, dtype=torch.long, device=device)
            return group_metrics, group_counts, torch.tensor(0., device=device)

        element_wise_metrics = self._compute_element_wise(y_pred, y_true)
        group_metrics, group_counts = avg_over_groups(element_wise_metrics, g,
                                                      n_groups)
        # Only groups that actually contain elements compete for the worst case.
        worst_group_metric = self.worst(group_metrics[group_counts > 0])
        return group_metrics, group_counts, worst_group_metric

    @property
    def agg_metric_field(self):
        """The name of the key in the results dictionary returned by Metric.compute()."""
        return f'{self.name}_avg'

    def compute_element_wise(self, y_pred, y_true, return_dict=True):
        """Computes element-wise metric.

        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - return_dict (bool): Whether to return the output as a dictionary or a tensor
        Output (return_dict=False):
            - element_wise_metrics (Tensor): tensor of size (batch_size, )
        Output (return_dict=True):
            - results (dict): Dictionary of results, mapping metric.name to element_wise_metrics
        """
        element_wise_metrics = self._compute_element_wise(y_pred, y_true)
        batch_size = y_pred.size()[0]
        # Sanity check on the subclass hook: exactly one value per element of y_pred.
        assert (element_wise_metrics.dim() == 1
                and element_wise_metrics.numel() == batch_size)

        if return_dict:
            return {self.name: element_wise_metrics}
        return element_wise_metrics

    def compute_flattened(self, y_pred, y_true, return_dict=True):
        """Computes the element-wise metric together with a running index.

        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets
            - return_dict (bool): Whether to return the output as a dictionary or a tuple
        Output (return_dict=False):
            - flattened_metrics (Tensor): result of compute_element_wise
            - index (Tensor): torch.arange(y_true.numel())
        Output (return_dict=True):
            - results (dict): maps metric.name to flattened_metrics and 'index' to index
        """
        flattened_metrics = self.compute_element_wise(y_pred,
                                                      y_true,
                                                      return_dict=False)
        index = torch.arange(y_true.numel())
        if return_dict:
            return {self.name: flattened_metrics, 'index': index}
        return flattened_metrics, index
# Metrics computed only over entries of y_true that are not NaN.
class MultiTaskMetric(Metric):
    """Metric evaluated only on the entries of y_true that are not NaN.

    Subclasses implement ``_compute_flattened``, which receives the selected
    (flattened) predictions and targets.
    """

    def _compute_flattened(self, flattened_y_pred, flattened_y_true):
        """Hook computing one metric value per selected entry; subclasses implement it.

        Args:
            - flattened_y_pred (Tensor): entries of y_pred where y_true is not NaN
            - flattened_y_true (Tensor): entries of y_true that are not NaN
        Raises:
            - NotImplementedError: always, in this class
        """
        raise NotImplementedError

    def _compute(self, y_pred, y_true):
        """Averages the flattened metrics.

        Output:
            - metric (0-dim tensor): mean of the flattened metrics, or tensor(0.) on
              y_true's device when there are none
        """
        per_entry = self.compute_flattened(y_pred, y_true, return_dict=False)[0]
        if per_entry.numel() != 0:
            return per_entry.mean()
        return torch.tensor(0., device=y_true.device)

    def _compute_group_wise(self, y_pred, y_true, g, n_groups):
        """Averages the flattened metrics within each group.

        Each flattened entry takes the group of the first-dimension index it came from.

        Output:
            - group_metrics, group_counts: as returned by avg_over_groups
            - worst_group_metric: self.worst() applied to the metrics of non-empty groups
        """
        per_entry, rows = self.compute_flattened(y_pred, y_true, return_dict=False)
        per_group_avg, per_group_count = avg_over_groups(per_entry, g[rows],
                                                         n_groups)
        non_empty = per_group_count > 0
        return per_group_avg, per_group_count, self.worst(per_group_avg[non_empty])

    def compute_flattened(self, y_pred, y_true, return_dict=True):
        """Computes the metric on the non-NaN entries of y_true.

        Args:
            - y_pred (Tensor): Predicted targets or model output
            - y_true (Tensor): True targets; NaN entries are skipped
            - return_dict (bool): Whether to return the output as a dictionary or a tuple
        Output (return_dict=False):
            - flattened_metrics: result of _compute_flattened on the selected entries
            - index (Tensor): first-dimension index of each selected entry
        Output (return_dict=True):
            - results (dict): maps metric.name to flattened_metrics and 'index' to index
        """
        labeled_mask = ~torch.isnan(y_true)
        # torch.where returns one index tensor per dimension; keep the first dimension's.
        rows = torch.where(labeled_mask)[0]
        scores = self._compute_flattened(y_pred[labeled_mask], y_true[labeled_mask])
        if not return_dict:
            return scores, rows
        return {self.name: scores, 'index': rows}