import copy
from typing import Dict, List, Union
import numpy as np
import torch
from milliontrees.common.utils import get_counts
from milliontrees.datasets.milliontrees_dataset import MillionTreesDataset, MillionTreesSubset
import warnings
class Grouper:
    """Abstract base for objects that partition data points by their metadata.

    A grouper maps each data point to an integer group so that training and
    evaluation code can, for example, report accuracy separately per group.
    Concrete subclasses must provide their own ``__init__`` (which sets
    ``self._n_groups``) and override the string helpers below.
    """

    def __init__(self):
        # The base class is not meant to be instantiated directly.
        raise NotImplementedError()

    @property
    def n_groups(self):
        """Total number of groups this grouper defines (reads ``self._n_groups``)."""
        return self._n_groups

    def group_str(self, group):
        """Return the pretty, human-readable name of a group.

        Args:
            - group (int): A single integer identifying a group.
        Output:
            - group_str (str): The pretty name of that group.
        """
        raise NotImplementedError()

    def group_field_str(self, group):
        """Return the name of a group.

        Args:
            - group (int): A single integer identifying a group.
        Output:
            - group_str (str): The name of that group.
        """
        raise NotImplementedError()
class CombinatorialGrouper(Grouper):
def __init__(self, dataset, groupby_fields):
"""
CombinatorialGroupers form groups by taking all possible combinations of the metadata
fields specified in groupby_fields, in lexicographical order.
For example, if:
dataset.metadata_fields = ['country', 'time', 'y']
groupby_fields = ['country', 'time']
and if in dataset.metadata, country is in {0, 1} and time is in {0, 1, 2},
then the grouper will assign groups in the following way:
country = 0, time = 0 -> group 0
country = 1, time = 0 -> group 1
country = 0, time = 1 -> group 2
country = 1, time = 1 -> group 3
country = 0, time = 2 -> group 4
country = 1, time = 2 -> group 5
If groupby_fields is None, then all data points are assigned to group 0.
Args:
- dataset (MillionTreesDataset or list of MillionTreesDataset)
- groupby_fields (list of str)
"""
if isinstance(dataset, list):
if len(dataset) == 0:
raise ValueError(
"At least one dataset must be defined for Grouper.")
datasets: List[MillionTreesDataset] = dataset
else:
datasets: List[MillionTreesDataset] = [dataset]
metadata_fields: List[str] = datasets[0].metadata_fields
# Build the largest metadata_map to see to check if all the metadata_maps are subsets of each other
largest_metadata_map: Dict[str, Union[List,
np.ndarray]] = copy.deepcopy(
datasets[0].metadata_map)
for i, dataset in enumerate(datasets):
if isinstance(dataset, MillionTreesSubset):
raise ValueError(
"Grouper should be defined with full dataset(s) and not subset(s)."
)
# The first dataset was used to get the metadata_fields and initial metadata_map
if i == 0:
continue
if dataset.metadata_fields != metadata_fields:
raise ValueError(
f"The datasets passed in have different metadata_fields: {dataset.metadata_fields}. "
f"Expected: {metadata_fields}")
if dataset.metadata_map is None:
continue
for field, values in dataset.metadata_map.items():
n_overlap = min(len(values), len(largest_metadata_map[field]))
if not (np.asarray(values[:n_overlap]) == np.asarray(
largest_metadata_map[field][:n_overlap])).all():
raise ValueError(
"The metadata_maps of the datasets need to be ordered subsets of each other."
)
if len(values) > len(largest_metadata_map[field]):
largest_metadata_map[field] = values
self.groupby_fields = groupby_fields
if groupby_fields is None:
self._n_groups = 1
else:
self.groupby_field_indices = [
i for (i, field) in enumerate(metadata_fields)
if field in groupby_fields
]
if len(self.groupby_field_indices) != len(self.groupby_fields):
raise ValueError(
'At least one group field not found in dataset.metadata_fields'
)
metadata_array = torch.cat(
[dataset.metadata_array for dataset in datasets])
grouped_metadata = metadata_array[:, self.groupby_field_indices]
if not isinstance(grouped_metadata, torch.LongTensor):
grouped_metadata_long = grouped_metadata.long()
if not torch.all(grouped_metadata == grouped_metadata_long):
warnings.warn(
f'CombinatorialGrouper: converting metadata with fields [{", ".join(groupby_fields)}] into long'
)
grouped_metadata = grouped_metadata_long
for idx, field in enumerate(self.groupby_fields):
min_value = grouped_metadata[:, idx].min()
if min_value < 0:
raise ValueError(
f"Metadata for CombinatorialGrouper cannot have values less than 0: {field}, {min_value}"
)
if min_value > 0:
warnings.warn(
f"Minimum metadata value for CombinatorialGrouper is not 0 ({field}, {min_value}). This will result in empty groups"
)
# We assume that the metadata fields are integers,
# so we can measure the cardinality of each field by taking its max + 1.
# Note that this might result in some empty groups.
assert grouped_metadata.min(
) >= 0, "Group numbers cannot be negative."
self.cardinality = 1 + torch.max(grouped_metadata, dim=0)[0]
cumprod = torch.cumprod(self.cardinality, dim=0)
self._n_groups = cumprod[-1].item()
self.factors_np = np.concatenate(([1], cumprod[:-1]))
self.factors = torch.from_numpy(self.factors_np)
self.metadata_map = largest_metadata_map
def _group_to_values(self, group):
"""Decode an integer group index into its per-field integer values."""
# group is just an integer, not a Tensor
n = len(self.factors_np)
metadata = np.zeros(n)
for i in range(n - 1):
metadata[i] = (group % self.factors_np[i + 1]) // self.factors_np[i]
metadata[n - 1] = group // self.factors_np[n - 1]
return [int(metadata[i]) for i in range(n)]
[docs]
def group_str(self, group):
if self.groupby_fields is None:
return 'all'
values = self._group_to_values(group)
group_name = ''
for i in reversed(range(len(values))):
meta_val = values[i]
if self.metadata_map is not None:
if self.groupby_fields[i] in self.metadata_map:
meta_val = self.metadata_map[
self.groupby_fields[i]][meta_val]
group_name += f'{self.groupby_fields[i]} = {meta_val}, '
group_name = group_name[:-2]
return group_name
# a_n = S / x_n
# a_{n-1} = (S % x_n) / x_{n-1}
# a_{n-2} = (S % x_{n-1}) / x_{n-2}
# ...
#
# g =
# a_1 * x_1 +
# a_2 * x_2 + ...
# a_n * x_n
[docs]
def group_field_str(self, group):
# Result-dict keys must stay numeric (e.g. ``source_id:3``) even when
# ``group_str`` renders human-readable names via ``metadata_map``.
if self.groupby_fields is None:
return 'all'
values = self._group_to_values(group)
parts = [
f'{self.groupby_fields[i]}:{values[i]}'
for i in reversed(range(len(values)))
]
return '_'.join(parts)