import os
import time
from pathlib import Path
import torch
import numpy as np
from PIL import Image
class MillionTreesDataset:
    """Shared dataset class for all MillionTrees datasets.

    Indexing the dataset returns a tuple ``(metadata, x, targets)``, where:
    - metadata: The metadata row for the data point (e.g., domain / source).
    - x: The input features (the image returned by ``get_input``).
    - targets: A dict holding the geometry under ``self.geometry_name``,
      a ``"labels"`` array, a ``"complete"`` flag and, when a mask file is
      available, a ``"tree_coverage_mask"``.

    Subclasses are expected to provide ``get_input``, ``geometry_name``,
    ``_input_array``, ``_input_lookup`` and the attributes verified by
    ``check_init``; none of them are defined in this base class.
    """

    DEFAULT_SPLITS = {'train': 0, 'val': 1}
    DEFAULT_SPLIT_NAMES = {
        'train': 'Train',
        'val': 'Validation',
    }
    DEFAULT_SOURCE_DOMAIN_SPLITS = [0]

    def __init__(
        self,
        root_dir,
        download,
        split_scheme,
    ):
        """Finalize construction once the subclass has populated its arrays.

        ``self._metadata_array`` is read immediately, so the subclass must
        set it (and everything ``check_init`` verifies) before calling this
        constructor. The three parameters are accepted for interface
        compatibility; this base implementation does not read them.
        """
        # Promote a 1-D metadata vector to a single-column 2-D table.
        if len(self._metadata_array.shape) == 1:
            self._metadata_array = self._metadata_array.unsqueeze(1)
        self.check_init()

    def __len__(self):
        """Return the number of inputs (entries of ``_input_array``)."""
        return len(self._input_array)

    def __getitem__(self, idx):
        """Return ``(metadata, x, targets)`` for the ``idx``-th input.

        No transform is applied here: transformations are handled by
        ``MillionTreesSubset`` since different subsets (e.g., train vs test)
        might have different transforms.
        """
        x = self.get_input(idx)
        # Map the input name to the rows of y_array that annotate it.
        y_indices = self._input_lookup[self._input_array[idx]]
        y = torch.tensor(self.y_array[y_indices])
        metadata = self.metadata_array[idx].clone()
        targets = {
            self.geometry_name: y,
            "labels": np.zeros(len(y), dtype=int),
        }
        tree_coverage_mask = self.get_tree_coverage_mask(idx, x.shape[:2])
        if tree_coverage_mask is not None:
            targets["tree_coverage_mask"] = tree_coverage_mask
        targets["complete"] = self._sample_is_complete(idx)
        return metadata, x, targets

    def _sample_is_complete(self, idx):
        """Whether the source for this sample is fully annotated.

        Drives the counting MAE metric: ``CountingError`` only contributes a
        per-image MAE when ``targets["complete"]`` is ``True``; otherwise it
        emits ``NaN`` and the value is dropped from aggregation.

        Returns ``False`` when no ``_source_id_complete`` mapping is set (or
        it is empty), when ``"source_id"`` is not a metadata field, or when
        the sample's source id is absent from the mapping.
        """
        source_id_complete = getattr(self, "_source_id_complete", None)
        if not source_id_complete:
            return False
        try:
            source_field_idx = self._metadata_fields.index("source_id")
        except ValueError:
            return False
        source_id = int(self._metadata_array[idx, source_field_idx].item())
        return bool(source_id_complete.get(source_id, False))

    def get_tree_coverage_mask(self, idx, image_shape):
        """Load a precomputed tree/no-tree mask for an image if available.

        Args:
            - idx (int): Index into ``_input_array``.
            - image_shape (tuple): Expected (height, width) of the mask.
        Output:
            - mask (np.ndarray | None): Binary ``uint8`` mask (values 0/1), or
              None when the dataset has no ``masks`` directory.
        Raises:
            - FileNotFoundError: The ``masks`` directory exists but has no
              ``<image stem>.png`` for this image.
            - ValueError: The mask's shape differs from ``image_shape``.
        """
        masks_dir = Path(self._data_dir) / "masks"
        if not masks_dir.exists():
            return None
        image_name = self._input_array[idx]
        mask_path = masks_dir / f"{Path(image_name).stem}.png"
        if not mask_path.exists():
            raise FileNotFoundError(
                f"Missing tree coverage mask for {image_name}: expected {mask_path}"
            )
        # Context manager guarantees the file handle is released promptly.
        with Image.open(mask_path) as mask_image:
            mask = np.array(mask_image.convert("L"), dtype=np.uint8)
        # Binarize: any non-zero pixel counts as tree cover.
        mask = (mask > 0).astype(np.uint8)
        if tuple(mask.shape[:2]) != tuple(image_shape):
            raise ValueError(
                f"Mask shape {mask.shape[:2]} does not match image shape {image_shape} for {image_name}"
            )
        return mask

    def eval(self,
             y_pred,
             y_true,
             metadata,
             *,
             viz_dir=None,
             viz_n_per_source=10):
        """
        Args:
            - y_pred (list[dict]): Predicted targets per image
            - y_true (list[dict]): True targets per image
            - metadata (Tensor): Metadata rows aligned with predictions
            - viz_dir (str | Path | None): If set, write up to ``viz_n_per_source`` overlay
              PNGs per ``source_id`` under this directory (see ``eval_visualization``).
            - viz_n_per_source (int | None): Max images to save per source when ``viz_dir`` is set.
              Pass None to write all images.
        Output:
            - results (dict): Dictionary of results (may include ``eval_visualization_paths``)
            - results_str (str): Pretty print version of the results
        Raises:
            - NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def get_subset(self, split, frac=1.0, transform=None):
        """
        Args:
            - split (str): Split identifier, e.g., 'train', 'val', 'test'.
              Must be in self.split_dict.
            - frac (float): What fraction of the split to randomly sample.
              Used for fast development on a small dataset.
            - transform (function): Any data transformations to be applied to the input x.
        Output:
            - subset (MillionTreesSubset): A (potentially subsampled) subset of the MillionTreesDataset.
        Raises:
            - ValueError: ``split`` is not a key of ``self.split_dict``.
        """
        if split not in self.split_dict:
            raise ValueError(
                f"Split {split} not found in dataset's split_dict.")
        split_mask = self.split_array == self.split_dict[split]
        split_idx = np.where(split_mask)[0]
        if frac < 1.0:
            # Randomly sample a fraction of the split; indices are re-sorted
            # so the subset keeps the dataset's original ordering.
            num_to_retain = int(np.round(float(len(split_idx)) * frac))
            split_idx = np.sort(
                np.random.permutation(split_idx)[:num_to_retain])
        return MillionTreesSubset(self, split_idx, transform,
                                  self.geometry_name)

    def check_init(self):
        """Convenience function to check that the MillionTreesDataset is properly configured.

        Raises:
            - AssertionError: A required attribute is missing or malformed.
            - ValueError: ``self.data_dir`` does not exist on disk.
        """
        required_attrs = [
            '_dataset_name', '_data_dir', '_split_scheme', '_split_array',
            '_y_array', '_y_size', '_metadata_fields', '_metadata_array'
        ]
        for attr_name in required_attrs:
            assert hasattr(
                self, attr_name), f'MillionTreesDataset is missing {attr_name}.'
        # Check that data directory exists
        if not os.path.exists(self.data_dir):
            raise ValueError(
                f'{self.data_dir} does not exist yet. Please generate the dataset first.'
            )
        # Check splits
        # NOTE(review): DEFAULT_SPLITS has no 'test' key, so a subclass that
        # relies on the default split_dict fails the 'test' assertion below.
        # Confirm whether the defaults or this check reflects the intent.
        assert self.split_dict.keys() == self.split_names.keys()
        assert 'train' in self.split_dict
        assert 'test' in self.split_dict
        # Check the form of the required arrays
        assert isinstance(self.y_array, np.ndarray) or isinstance(
            self.y_array, list), 'y_array must be a numpy array or list'
        assert isinstance(self.metadata_array,
                          torch.Tensor), 'metadata_array must be a torch tensor'
        # Check that dimensions match
        assert len(self._input_array) == len(self.metadata_array)
        # Check metadata
        assert len(self.metadata_array.shape) == 2
        assert len(self.metadata_fields) == self.metadata_array.shape[1]
        # For convenience, include y in metadata_fields if y_size == 1
        if self.y_size == 1:
            assert 'y' in self.metadata_fields

    @property
    def latest_version(self):
        """The highest version string among the keys of ``versions_dict``.

        Versions are compared component-wise as tuples of integers, so
        '0.10' is later than '0.9' and strings with any number of
        dot-separated components (e.g. '1.2.3') are supported. Returns
        '0.0' when no key is later than '0.0'.
        """

        def parse(version):
            """Split a dotted version string into a tuple of ints."""
            return tuple(map(int, version.split('.')))

        latest_version = '0.0'
        for key in self.versions_dict.keys():
            if parse(key) > parse(latest_version):
                latest_version = key
        return latest_version

    @property
    def dataset_name(self):
        """A string that identifies the dataset, e.g., 'TreeBoxes', 'TreePoints'."""
        return self._dataset_name

    @property
    def version(self):
        """A string that identifies the dataset version, e.g., '1.0'.

        Falls back to ``latest_version`` when ``_version`` is None or was
        never set.
        """
        version = getattr(self, '_version', None)
        if version is None:
            return self.latest_version
        return version

    @property
    def versions_dict(self):
        """A dictionary where each key is a version string (e.g., '1.0') and each value is a
        dictionary containing the 'download_url' and 'compressed_size' keys.

        'download_url' is the URL for downloading the dataset archive. If None, the dataset cannot
        be downloaded automatically (e.g., because it first requires accepting a usage agreement).
        'compressed_size' is the approximate size of the compressed dataset in bytes.
        """
        return self._versions_dict

    @property
    def data_dir(self):
        """The full path to the folder in which the dataset is stored."""
        return self._data_dir

    @property
    def collate(self):
        """Torch function to collate items in a batch.

        By default returns None -> uses default torch collate.
        """
        return getattr(self, '_collate', None)

    @property
    def split_scheme(self):
        """A string identifier of how the split is constructed, e.g., 'standard', 'mixed-to-test',
        'user', etc."""
        return self._split_scheme

    @property
    def split_dict(self):
        """A dictionary mapping splits to integer identifiers (used in split_array), e.g., {'train':
        0, 'val': 1, 'test': 2}.

        Keys should match up with split_names. Defaults to ``DEFAULT_SPLITS``.
        """
        return getattr(self, '_split_dict', MillionTreesDataset.DEFAULT_SPLITS)

    @property
    def split_names(self):
        """A dictionary mapping splits to their pretty names, e.g., {'train': 'Train', 'val':
        'Validation', 'test': 'Test'}.

        Keys should match up with split_dict. Defaults to ``DEFAULT_SPLIT_NAMES``.
        """
        return getattr(self, '_split_names',
                       MillionTreesDataset.DEFAULT_SPLIT_NAMES)

    @property
    def source_domain_splits(self):
        """List of split IDs that are from the source domain."""
        return getattr(self, '_source_domain_splits',
                       MillionTreesDataset.DEFAULT_SOURCE_DOMAIN_SPLITS)

    @property
    def split_array(self):
        """An array of integers, with split_array[i] representing what split the i-th data point
        belongs to."""
        return self._split_array

    @property
    def y_array(self):
        """The targets, stored as a numpy array or list (enforced by ``check_init``).

        ``__getitem__`` selects the rows belonging to one input through
        ``_input_lookup``.
        """
        return self._y_array

    @property
    def y_size(self):
        """The number of dimensions/elements in the target, i.e., len(y_array[i]).

        For standard classification/regression tasks, y_size = 1. For multi-task or structured
        prediction settings, y_size > 1. Used for logging and to configure models to produce
        appropriately-sized output.
        """
        return self._y_size

    @property
    def n_classes(self):
        """Number of classes for single-task classification datasets.

        Used for logging and to configure models to produce appropriately-sized output. None by
        default. Leave as None if not applicable (e.g., regression or multi-task classification).
        """
        return getattr(self, '_n_classes', None)

    @property
    def is_detection(self):
        """Boolean.

        True if the task is detection, and false otherwise.
        """
        return getattr(self, '_is_detection', False)

    @property
    def metadata_fields(self):
        """A list of strings naming each column of the metadata table, e.g., ['source_id', 'y'].

        ``check_init`` requires 'y' to be present when ``y_size == 1``.
        """
        return self._metadata_fields

    @property
    def metadata_array(self):
        """A Tensor of metadata, with the i-th row representing the metadata associated with the
        i-th data point.

        The columns correspond to the metadata_fields defined above.
        """
        return self._metadata_array

    @property
    def metadata_map(self):
        """An optional dictionary that, for each metadata field, contains a list that maps from
        integers (in metadata_array) to a string representing what that integer means.

        This is only used for logging, so that we print out more intelligible metadata values. Each
        key must be in metadata_fields. For example, if we have metadata_fields = ['hospital',
        'y'] metadata_map = {'hospital': ['East', 'West']} then if metadata_array[i, 0] == 0,
        the i-th data point belongs to the 'East' hospital while if metadata_array[i, 0] == 1, it
        belongs to the 'West' hospital.
        """
        return getattr(self, '_metadata_map', None)

    @property
    def original_resolution(self):
        """Original image resolution for image datasets."""
        return getattr(self, '_original_resolution', None)

    def initialize_data_dir(self, root_dir, download):
        """Helper function for downloading/updating the dataset if required.

        Validates the requested version, creates ``root_dir`` if needed and
        resolves the versioned folder ``<root_dir>/<dataset_name>_v<version>``.
        The dataset is downloaded only when ``dataset_exists_locally`` reports
        that it is not already present.

        Args:
            - root_dir (str): Directory that holds all dataset folders.
            - download (bool): Whether downloading is permitted.
        Output:
            - data_dir (str): Path to the versioned dataset folder.
        """
        self.check_version()
        os.makedirs(root_dir, exist_ok=True)
        data_dir = os.path.join(root_dir,
                                f'{self.dataset_name}_v{self.version}')
        version_file = os.path.join(data_dir, f'RELEASE_v{self.version}.txt')
        # If the dataset exists at root_dir, then don't download.
        if not self.dataset_exists_locally(data_dir, version_file):
            self.download_dataset(data_dir, download)
        return data_dir

    def dataset_exists_locally(self, data_dir, version_file):
        """Return True if a usable copy of the dataset is already in ``data_dir``.

        There are two ways to obtain a dataset:
        1. Automatically through the MillionTrees package, which writes the
           RELEASE text file after a successful download.
        2. Manually from a third party. Such datasets need not have a
           download_url or a RELEASE text file, so a non-empty folder is
           accepted when download_url is None.
        """
        download_url = self.versions_dict[self.version]['download_url']
        return (os.path.exists(data_dir) and
                (os.path.exists(version_file) or
                 (len(os.listdir(data_dir)) > 0 and download_url is None)))

    def download_dataset(self, data_dir, download_flag):
        """Download and extract the dataset archive into ``data_dir``.

        Args:
            - data_dir (str): Destination folder for the extracted dataset.
            - download_flag (bool): Must be truthy for the download to proceed.
        Raises:
            - ValueError: The current version has no ``download_url``.
            - FileNotFoundError: ``download_flag`` is falsy.
        """
        version_dict = self.versions_dict[self.version]
        download_url = version_dict['download_url']
        compressed_size = version_dict['compressed_size']
        # Check that download_url exists.
        if download_url is None:
            raise ValueError(
                f'{self.dataset_name} cannot be automatically downloaded. Please download it manually.'
            )
        # Check that the download_flag is set to true.
        if not download_flag:
            raise FileNotFoundError(
                f'The {self.dataset_name} dataset could not be found in {data_dir}. Initialize the dataset with '
                f'download=True to download the dataset. If you are using the example script, run with --download. '
                f'This might take some time for large datasets.')
        # Imported lazily so the download helpers are only needed when a
        # download is actually performed.
        from milliontrees.datasets.download_utils import download_and_extract_archive
        print(f'Downloading dataset to {data_dir}...')
        try:
            start_time = time.time()
            download_and_extract_archive(url=download_url,
                                         download_root=data_dir,
                                         filename='archive.zip',
                                         remove_finished=True,
                                         size=compressed_size)
            download_time_in_minutes = (time.time() - start_time) / 60
            print(
                f"\nIt took {round(download_time_in_minutes, 2)} minutes to download and uncompress the dataset.\n"
            )
            # The RELEASE marker is what dataset_exists_locally looks for.
            version_file = os.path.join(data_dir,
                                        f'RELEASE_v{self.version}.txt')
            with open(version_file, 'w') as f:
                f.write(f'v{self.version}\n')
        except Exception as e:
            # NOTE(review): failures are reported but not re-raised, so the
            # caller continues with a possibly missing/partial dataset.
            # Kept as-is to preserve existing behavior; consider re-raising.
            print(
                f"\n{os.path.join(data_dir, 'archive.zip')} may be corrupted. Please try deleting it and rerunning this command.\n"
            )
            print("Exception: ", e)

    def check_version(self):
        """Check that ``self.version`` is a key of ``self.versions_dict``.

        Raises:
            - ValueError: The version is not supported.
        """
        # Check that the version is valid.
        if self.version not in self.versions_dict:
            raise ValueError(
                f'Version {self.version} not supported. Must be in {self.versions_dict.keys()}.'
            )
        # Temporarily disabled version warnings since newer versions aren't publicly available yet
        # TODO: Re-enable when version 0.5 datasets are publicly released
        # Check that the specified version is the latest version. Otherwise, warn.
        # current_major_version, current_minor_version = tuple(
        #     map(int, self.version.split('.')))
        # latest_major_version, latest_minor_version = tuple(
        #     map(int, self.latest_version.split('.')))
        # if latest_major_version > current_major_version:
        #     print(
        #         f'*****************************\n'
        #         f'{self.dataset_name} has been updated to version {self.latest_version}.\n'
        #         f'You are currently using version {self.version}.\n'
        #         f'We highly recommend updating the dataset by not specifying the older version in the '
        #         f'command-line argument or dataset constructor.\n'
        #         f'*****************************\n')
        # elif latest_minor_version > current_minor_version:
        #     print(
        #         f'*****************************\n'
        #         f'{self.dataset_name} has been updated to version {self.latest_version}.\n'
        #         f'You are currently using version {self.version}.\n'
        #         f'Please consider updating the dataset.\n'
        #         f'*****************************\n')

    @staticmethod
    def standard_eval(metric, y_pred, y_true):
        """
        Args:
            - metric (Metric): Metric to use for eval
            - y_pred (Tensor): Predicted targets
            - y_true (Tensor): True targets
        Output:
            - results (dict): Dictionary of results
            - results_str (str): Pretty print version of the results
        """
        results = {
            **metric.compute(y_pred, y_true),
        }
        results_str = (
            f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n")
        return results, results_str

    @staticmethod
    def standard_group_eval(metric,
                            grouper,
                            y_pred,
                            y_true,
                            metadata,
                            aggregate=True):
        """
        Args:
            - metric (Metric): Metric to use for eval
            - grouper (CombinatorialGrouper): Grouper object that converts metadata into groups
            - y_pred (Tensor): Predicted targets
            - y_true (Tensor): True targets
            - metadata (Tensor): Metadata
            - aggregate (bool): Whether to also report the metric over all data points
        Output:
            - results (dict): Dictionary of results
            - results_str (str): Pretty print version of the results
        """
        results, results_str = {}, ''
        if aggregate:
            results.update(metric.compute(y_pred, y_true))
            results_str += f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n"
        g = grouper.metadata_to_group(metadata)
        group_results = metric.compute_group_wise(y_pred, y_true, g,
                                                  grouper.n_groups)
        for group_idx in range(grouper.n_groups):
            group_str = grouper.group_field_str(group_idx)
            group_metric = group_results[metric.group_metric_field(group_idx)]
            group_counts = group_results[metric.group_count_field(group_idx)]
            results[f'{metric.name}_{group_str}'] = group_metric
            results[f'count_{group_str}'] = group_counts
            # Empty groups are recorded in results but left out of the
            # printed summary.
            if group_counts == 0:
                continue
            results_str += (
                f'  {grouper.group_str(group_idx)}  '
                f"[n = {group_counts:6.0f}]:\t"
                f"{metric.name} = {group_metric:5.3f}\n"
            )
        worst_field = f'{metric.worst_group_metric_field}'
        results[worst_field] = group_results[worst_field]
        results_str += f"Worst-group {metric.name}: {group_results[metric.worst_group_metric_field]:.3f}\n"
        return results, results_str
class MillionTreesSubset(MillionTreesDataset):
    """A view over selected indices of a MillionTreesDataset with its own transform."""

    def __init__(self, dataset, indices, transform=None, geometry_name="y"):
        """This acts like `torch.utils.data.Subset`, but on MillionTrees datasets.

        We pass in `transform` (which is used for data augmentation) explicitly because it can
        potentially vary on the training vs. test subsets.

        Args:
            - dataset (MillionTreesDataset): The dataset to draw items from.
            - indices (array-like of int): Positions in ``dataset`` that make up this subset.
            - transform (callable | None): Transform applied in ``__getitem__``. When None,
              the dataset's default ``dataset._transform_()`` is used.
            - geometry_name (str): Key under which the geometry is stored in ``targets``.
        """
        self.dataset = dataset
        self.indices = indices
        self.geometry_name = geometry_name
        # Copy over only the attributes the parent dataset actually defines,
        # so the inherited properties fall back to their defaults otherwise.
        inherited_attrs = [
            '_dataset_name', '_data_dir', '_collate', '_split_scheme',
            '_split_dict', '_split_names', '_y_size', '_n_classes',
            '_metadata_fields', '_metadata_map'
        ]
        for attr_name in inherited_attrs:
            if hasattr(dataset, attr_name):
                setattr(self, attr_name, getattr(dataset, attr_name))
        if transform is None:
            self.transform = dataset._transform_()
        else:
            self.transform = transform

    def __getitem__(self, idx):
        """Return ``(metadata, x, targets)`` for the ``idx``-th subset item after transforming.

        The untransformed item is fetched from the parent dataset, then the
        image and its geometry are passed through ``self.transform`` together.
        The geometry handling depends on ``self._dataset_name``: 'TreeBoxes'
        (bounding boxes), 'TreePoints' (keypoints), anything else is treated
        as instance masks. ``targets`` always carries the geometry, "labels"
        and "complete"; 'TreePolygons' adds "bboxes"; "tree_coverage_mask" is
        included when the parent dataset supplied one.
        """
        metadata, x, targets = self.dataset[self.indices[idx]]
        tree_coverage_mask = targets.get("tree_coverage_mask")
        complete = targets.get("complete", False)

        if self._dataset_name == 'TreeBoxes':
            # Extra safety: drop degenerate / out-of-bounds boxes before Albumentations
            # (check_bboxes rejects xmax <= xmin; clip+float noise can create edge cases).
            bboxes = np.asarray(targets[self.geometry_name], dtype=np.float64)
            if bboxes.ndim == 1:
                # A flat empty array must become (0, 4); reshaping it to
                # (1, 0) would make the column indexing below fail.
                if bboxes.size == 0:
                    bboxes = bboxes.reshape(0, 4)
                else:
                    bboxes = bboxes.reshape(1, -1)
            labels_arr = np.asarray(targets["labels"])
            if labels_arr.ndim == 0:
                labels_arr = labels_arr.reshape(1)
            h, w = float(x.shape[0]), float(x.shape[1])
            if len(bboxes) > 0:
                # Clamp x coordinates to [0, w] and y coordinates to [0, h].
                bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0.0, w)
                bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0.0, h)
                eps = 1e-3
                valid = ((bboxes[:, 2] - bboxes[:, 0]) > eps) & (
                    (bboxes[:, 3] - bboxes[:, 1]) > eps)
                bboxes = bboxes[valid]
                labels_arr = labels_arr[valid]
            transform_kwargs = {
                "image": x,
                "bboxes": bboxes.tolist(),
                "labels": labels_arr.tolist(),
            }
            # The coverage mask is only passed when one exists, so it is
            # transformed consistently with the image.
            if tree_coverage_mask is not None:
                transform_kwargs["mask"] = tree_coverage_mask
            augmented = self.transform(**transform_kwargs)
            if tree_coverage_mask is not None:
                tree_coverage_mask = augmented["mask"]
            y = torch.tensor(np.array(augmented["bboxes"]), dtype=torch.float32)
        elif self._dataset_name == 'TreePoints':
            transform_kwargs = {
                "image": x,
                "keypoints": targets[self.geometry_name],
                "labels": targets["labels"],
            }
            if tree_coverage_mask is not None:
                transform_kwargs["mask"] = tree_coverage_mask
            augmented = self.transform(**transform_kwargs)
            if tree_coverage_mask is not None:
                tree_coverage_mask = augmented["mask"]
            y = torch.tensor(np.array(augmented["keypoints"]),
                             dtype=torch.float32)
        else:
            masks = list(targets[self.geometry_name])
            coverage_with_masks = tree_coverage_mask is not None
            # The coverage mask rides along as the LAST entry of the masks
            # list and is split off again after the transform.
            transformed_masks = masks
            if coverage_with_masks:
                transformed_masks = masks + [tree_coverage_mask]
            # Albumentations rejects empty mask lists; use a dummy mask then discard
            if len(masks) == 0:
                h, w = x.shape[0], x.shape[1]
                dummy_mask = [np.zeros((h, w), dtype=np.uint8)]
                dummy_bboxes = [[0, 0, 1, 1]]
                dummy_labels = [0]
                masks_input = dummy_mask + ([tree_coverage_mask]
                                            if coverage_with_masks else [])
                augmented = self.transform(
                    image=x,
                    masks=masks_input,
                    bboxes=dummy_bboxes,
                    labels=dummy_labels,
                )
                # assumes the transformed image is channel-first (C, H, W) — TODO confirm
                img_h, img_w = augmented["image"].shape[1], augmented[
                    "image"].shape[2]
                y = torch.zeros((0, img_h, img_w), dtype=torch.uint8)
                bboxes = torch.zeros(0, 4)
                labels = torch.zeros(0, dtype=torch.long)
                if coverage_with_masks:
                    tree_coverage_mask = augmented["masks"][-1]
            else:
                augmented = self.transform(
                    image=x,
                    masks=transformed_masks,
                    bboxes=targets["bboxes"],
                    labels=targets["labels"],
                )
                if coverage_with_masks:
                    tree_coverage_mask = augmented["masks"][-1]
                    y = augmented["masks"][:-1]
                else:
                    y = augmented["masks"]
                if len(y) == 0:
                    img_h, img_w = augmented["image"].shape[1], augmented[
                        "image"].shape[2]
                    y = torch.zeros((0, img_h, img_w), dtype=torch.uint8)
                else:
                    y = torch.stack(y, dim=0)
                bboxes = augmented["bboxes"]
                labels = torch.from_numpy(np.array(augmented["labels"]))

        x = augmented["image"]
        if self._dataset_name != "TreePolygons":
            labels = torch.from_numpy(np.array(augmented["labels"]))
        if tree_coverage_mask is not None:
            if isinstance(tree_coverage_mask, np.ndarray):
                tree_coverage_mask = torch.from_numpy(tree_coverage_mask)
            # Re-binarize after the transform and store as uint8.
            tree_coverage_mask = (tree_coverage_mask > 0).to(torch.uint8)
        # If image has no annotations, set zeros
        if len(y) == 0:
            if self._dataset_name == 'TreeBoxes':
                y = torch.zeros(0, 4)
            elif self._dataset_name == 'TreePoints':
                y = torch.zeros(0, 2)
            else:
                bboxes = torch.zeros(0, 4)
        if self._dataset_name == 'TreePolygons':
            targets = {
                self.geometry_name: y,
                "labels": labels,
                "bboxes": bboxes
            }
        else:
            targets = {self.geometry_name: y, "labels": labels}
        if tree_coverage_mask is not None:
            targets["tree_coverage_mask"] = tree_coverage_mask
        targets["complete"] = bool(complete)
        return metadata, x, targets

    def __len__(self):
        """Return the number of items in this subset."""
        return len(self.indices)

    @property
    def split_array(self):
        """The parent dataset's split ids restricted to this subset's indices."""
        return self.dataset._split_array[self.indices]

    @property
    def y_array(self):
        """The parent dataset's ``_y_array`` entries at this subset's indices, as a Tensor."""
        return torch.tensor(self.dataset._y_array[self.indices])

    @property
    def metadata_array(self):
        """The parent dataset's metadata rows at this subset's indices, as a Tensor.

        Indexing with an index array already yields a new tensor, so the
        result is returned directly; wrapping it in ``torch.tensor`` again
        would only add a second copy and a copy-construct warning.
        """
        return self.dataset.metadata_array[self.indices]

    def eval(self,
             y_pred,
             y_true,
             metadata,
             *,
             viz_dir=None,
             viz_n_per_source=10):
        """Delegate evaluation to the parent dataset's ``eval`` with the same arguments."""
        return self.dataset.eval(
            y_pred,
            y_true,
            metadata,
            viz_dir=viz_dir,
            viz_n_per_source=viz_n_per_source,
        )