from pathlib import Path
import os
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
import torch
import albumentations as A
import torchvision.transforms as T
import fnmatch
from milliontrees.datasets.milliontrees_dataset import MillionTreesDataset
from milliontrees.common.eval_visualization import save_eval_visualizations
from milliontrees.common.grouper import CombinatorialGrouper
from milliontrees.common.metrics.all_metrics import (
CountingError,
DetectionAccuracy,
DetectionMAP,
MaskAwareDetectionPrecision,
MergeCommissionMetric,
)
from milliontrees.common.onboarding import print_dataset_summary
from PIL import Image
from albumentations.pytorch import ToTensorV2
[docs]
class TreeBoxesDataset(MillionTreesDataset):
"""A dataset of tree annotations with bounding box coordinates from multiple global sources.
The dataset contains aerial imagery of trees with their corresponding bounding box annotations.
Each tree is annotated with a 4-point bounding box (x_min, y_min, x_max, y_max).
Dataset Splits:
- random: For each source, a portion of images is in train and a portion in test.
- crossgeometry: Boxes and Points are used to predict polygons.
- zeroshot: Selected sources are entirely held out for testing.
Data Format:
Input (x): RGB aerial imagery
Labels (y): Nx4 array of bounding box coordinates
Metadata: Location identifiers for each image
Args:
version (str): The version of the dataset to load.
root_dir (str): The root directory to store the dataset.
download (bool): Whether to download the dataset if it is not already present.
split_scheme (str): The split scheme to use.
geometry_name (str): The name of the geometry to use.
eval_score_threshold (float): The threshold for the evaluation score.
remove_incomplete (bool): Whether to remove incomplete data.
image_size (int): The size of the image to use.
include_sources (list): The sources to include.
exclude_sources (list): The sources to exclude.
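include_unsupervised (bool): If True, use the full release and keep sources whose
names match '*unsupervised*'. By default the supervised-only release is used and
those sources are excluded (unless exclude_sources is given explicitly).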
mini (bool): If True, download mini versions of datasets for development.
Mini datasets are smaller subsets that maintain the same structure.
small (bool): If True, download small releases (up to 50 images per source).
verbose (bool): If True, print a dataset summary once loading has finished.
References:
Website: https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1009180
Citation:
@article{Weinstein2020,
title={A benchmark dataset for canopy crown detection and delineation in co-registered airborne RGB, LiDAR and hyperspectral imagery from the National Ecological Observation Network.},
author={Weinstein BG, Graves SJ, Marconi S, Singh A, Zare A, Stewart D, et al.},
journal={PLoS Comput Biol},
year={2021},
doi={10.1371/journal.pcbi.1009180}
}
License: Creative Commons Attribution License
"""
# Canonical dataset name. __init__ temporarily overrides it on the instance
# while the data directory is resolved, then restores this value.
_dataset_name = 'TreeBoxes'
# Release table keyed by version string. Each entry carries:
#   'download_url'            - archive used when include_unsupervised=True
#   'supervised_download_url' - archive substituted for 'download_url' when
#                               include_unsupervised=False (the default)
#   'compressed_size'         - archive size (presumably bytes - TODO confirm)
_versions_dict = {
# 0.0 is a placeholder for the testing dataset
'0.0': {
'download_url': '',
'supervised_download_url': '',
'compressed_size': 105525592
},
"0.17": {
'download_url':
"https://data.rc.ufl.edu/pub/ewhite/MillionTrees/TreeBoxes_v0.17.zip",
'supervised_download_url':
"https://data.rc.ufl.edu/pub/ewhite/MillionTrees/TreeBoxes_supervised_v0.17.zip",
'compressed_size':
50996758836
}
}
def __init__(self,
             version=None,
             root_dir='data',
             download=False,
             split_scheme='random',
             geometry_name='y',
             eval_score_threshold=0.1,
             remove_incomplete=False,
             image_size=448,
             include_sources=None,
             exclude_sources=None,
             mini=False,
             small=False,
             verbose=True,
             include_unsupervised=False):
    """Load the split CSV, filter annotations, and build the label/metadata arrays.

    Args:
        version (str): Stored as ``self._version``; presumably resolved against
            ``_versions_dict`` by the base class (TODO confirm).
        root_dir (str): Passed to ``initialize_data_dir`` and the base-class
            constructor.
        download (bool): Passed to ``initialize_data_dir`` and the base-class
            constructor.
        split_scheme (str): One of 'random', 'crossgeometry', 'zeroshot'. The
            annotations are read from ``<data_dir>/<split_scheme>.csv``.
        geometry_name (str): Forwarded to every metric.
        eval_score_threshold (float): Forwarded to every metric as
            ``score_threshold``.
        remove_incomplete (bool): If True, keep only rows whose ``complete``
            column equals True.
        image_size (int): Stored as ``self.image_size``.
        include_sources (str | list | tuple | None): fnmatch-style pattern(s),
            matched case-insensitively against the ``source`` column. None or
            an empty list/tuple disables the include filter.
        exclude_sources (str | list | tuple | None): fnmatch-style pattern(s)
            to drop. None means ``['*unsupervised*']`` unless
            ``include_unsupervised`` is True, in which case nothing is
            excluded.
        mini (bool): Use the "Mini" versions table. Mutually exclusive with
            ``small``.
        small (bool): Use the "Small" versions table. Mutually exclusive with
            ``mini``.
        verbose (bool): If True, print a dataset summary after loading.
        include_unsupervised (bool): If False (default), each version's
            ``download_url`` is replaced by its ``supervised_download_url``.

    Raises:
        ValueError: If both ``mini`` and ``small`` are set, if ``split_scheme``
            is not recognized, or if no annotations remain after filtering.
    """
    if mini and small:
        raise ValueError(
            'At most one of mini=True and small=True may be set.')

    self._version = version
    self._split_scheme = split_scheme
    self.geometry_name = geometry_name
    self.eval_score_threshold = eval_score_threshold
    self.image_size = image_size
    self.mini = mini
    self.small = small
    self.verbose = verbose
    self.include_unsupervised = include_unsupervised

    if self._split_scheme not in ['random', 'crossgeometry', 'zeroshot']:
        raise ValueError(
            f'Split scheme {self._split_scheme} not recognized')

    if mini:
        self._versions_dict = self._get_mini_versions_dict()
    elif small:
        self._versions_dict = self._get_small_versions_dict()

    # Select supervised-only dataset by default (smaller download).
    # Users must opt in with include_unsupervised=True to get the full dataset.
    if not include_unsupervised:
        modified_versions = {}
        for v, info in self._versions_dict.items():
            modified_info = dict(info)
            if info.get('supervised_download_url') is not None:
                modified_info['download_url'] = info[
                    'supervised_download_url']
            modified_versions[v] = modified_info
        self._versions_dict = modified_versions
        # Temporary name used only while the data directory is resolved.
        if small:
            self._dataset_name = 'SmallTreeBoxes'
        else:
            self._dataset_name = 'TreeBoxes_supervised'

    self._data_dir = Path(self.initialize_data_dir(root_dir, download))
    # Restore dataset name for proper operation after directory setup
    self._dataset_name = 'TreeBoxes'

    # Load splits (low_memory=False avoids mixed-type DtypeWarning on large CSVs)
    df = pd.read_csv(self._data_dir / f"{self._split_scheme}.csv",
                     low_memory=False)
    box_columns = ["xmin", "ymin", "xmax", "ymax"]
    for column in box_columns:
        df[column] = pd.to_numeric(df[column], errors="coerce")
    # Drop rows with unparseable coordinates or non-positive width/height.
    df = df.dropna(subset=box_columns)
    df = df[(df["xmax"] > df["xmin"]) &
            (df["ymax"] > df["ymin"])].reset_index(drop=True)

    # Cache available sources for convenience (before any source filtering).
    self.sources = df['source'].unique()
    available_source_count = len(self.sources)

    # Remove incomplete data based on flag
    if remove_incomplete:
        df = df[df['complete'].eq(True)]

    # Filter by include/exclude source names with wildcard support.
    # An empty list/tuple means "no include filter", same as None.
    include_patterns = None
    if isinstance(include_sources, (list, tuple)):
        if len(include_sources) > 0:
            include_patterns = include_sources
    elif include_sources is not None:
        include_patterns = [include_sources]

    # Default: exclude sources containing 'unsupervised' unless
    # include_unsupervised=True.
    exclude_patterns = exclude_sources
    if exclude_patterns is None:
        exclude_patterns = [] if include_unsupervised else [
            '*unsupervised*'
        ]
    elif not isinstance(exclude_patterns, (list, tuple)):
        exclude_patterns = [exclude_patterns]

    source_str = df['source'].astype(str).str.lower()

    def _source_mask(patterns):
        # Match each distinct source name once, then broadcast with isin();
        # this avoids calling fnmatch for every annotation row.
        lowered = [p.lower() for p in patterns]
        matched = [
            name for name in source_str.unique()
            if any(fnmatch.fnmatch(name, p) for p in lowered)
        ]
        return source_str.isin(matched)

    # Build one mask on the current index so include and exclude stay
    # aligned with df (no reindexing of a stale boolean key).
    keep = pd.Series(True, index=df.index)
    if include_patterns is not None:
        keep = keep & _source_mask(include_patterns)
    if len(exclude_patterns) > 0:
        keep = keep & ~_source_mask(exclude_patterns)
    df = df[keep]

    selected_source_count = df['source'].nunique()
    df = df.reset_index(drop=True)

    if df.empty:
        raise ValueError(
            'No annotations remain after filtering; check include_sources, '
            'exclude_sources and remove_incomplete.')

    # Splits
    self._split_dict = {
        'train': 0,
        'validation': 1,
        'test': 2,
    }
    self._split_names = {
        'train': 'Train',
        'validation': 'Validation',
        'test': 'Test (OOD/Trans)',
    }

    # One row per image; the first annotation row of each file supplies its split.
    unique_files = df.drop_duplicates(
        subset=['filename']).reset_index(drop=True)
    unique_files['split_id'] = unique_files['split'].apply(
        lambda x: self._split_dict[x])
    self._split_array = unique_files['split_id'].values

    # Filenames
    self._input_array = unique_files.filename

    # Create lookup table for which index to select for each filename
    self._input_lookup = df.groupby('filename').apply(
        lambda x: x.index.values, include_groups=False).to_dict()
    self._y_array = df[["xmin", "ymin", "xmax",
                        "ymax"]].values.astype("float32")

    # Labels -> just 'Tree'
    self._n_classes = 1

    # Length of targets
    self._y_size = 4

    # Class labels
    self.labels = torch.zeros(df.shape[0])

    # Create source locations with a numeric ID
    df["source_id"] = df.source.astype('category').cat.codes

    # Create filename numeric ID
    df["filename_id"] = df.filename.astype('category').cat.codes

    # Create dictionary for codes to names
    self._source_id_to_code = df.set_index('source_id')['source'].to_dict()
    self._filename_id_to_code = df.set_index(
        'filename_id')['filename'].to_dict()

    # Location/group info. int() keeps the "+ 1" in Python integers; category
    # codes may be a narrow dtype such as int8.
    n_groups = int(df['source_id'].max()) + 1
    self._n_groups = n_groups
    assert len(np.unique(df['source_id'])) == self._n_groups

    # Metadata is at the image level
    unique_sources = df[['filename_id', 'source_id']].drop_duplicates(
        subset="filename_id").reset_index(drop=True)
    self._metadata_array = torch.tensor(unique_sources.values.astype('int'))
    self._metadata_fields = ['filename_id', 'source_id']

    # Map source_id -> complete (used by CountingError to gate which images
    # contribute to MAE). A source whose flag is missing is recorded as not
    # complete: bool(NaN) is True, which would wrongly mark it complete.
    if 'complete' in df.columns:
        source_complete = df.groupby('source_id')['complete'].first()
        self._source_id_complete = {
            int(k): bool(v) if pd.notna(v) else False
            for k, v in source_complete.items()
        }
    else:
        self._source_id_complete = {}

    self._collate = TreeBoxesDataset._collate_fn

    self.metrics = {
        "accuracy":
            DetectionAccuracy(geometry_name=self.geometry_name,
                              score_threshold=self.eval_score_threshold,
                              metric="accuracy"),
        "recall":
            DetectionAccuracy(geometry_name=self.geometry_name,
                              score_threshold=self.eval_score_threshold,
                              metric="recall"),
        "maskaware_precision":
            MaskAwareDetectionPrecision(
                geometry_name=self.geometry_name,
                score_threshold=self.eval_score_threshold),
        "AP50":
            DetectionMAP(geometry_name=self.geometry_name,
                         score_threshold=self.eval_score_threshold,
                         iou_type="bbox",
                         iou_thresholds=[0.5]),
        "merge_commission":
            MergeCommissionMetric(
                geometry_name=self.geometry_name,
                score_threshold=self.eval_score_threshold,
                modality="bbox",
            ),
        "counting_mae":
            CountingError(
                score_threshold=self.eval_score_threshold,
                geometry_name=self.geometry_name,
            ),
    }

    # eval grouper
    self._eval_grouper = CombinatorialGrouper(dataset=self,
                                              groupby_fields=(['source_id'
                                                              ]))

    if self.verbose:
        n_train_images = int(
            (self._split_array == self._split_dict['train']).sum())
        n_test_images = int(
            (self._split_array == self._split_dict['test']).sum())
        print_dataset_summary(
            dataset_name=self._dataset_name,
            version=self.version,
            data_dir=self._data_dir,
            split_scheme=self._split_scheme,
            n_annotations=len(df),
            n_total_images=len(unique_files),
            n_train_images=n_train_images,
            n_test_images=n_test_images,
            n_available_sources=available_source_count,
            n_selected_sources=selected_source_count,
            mini=self.mini,
            small=self.small,
            include_patterns=include_patterns,
            exclude_patterns=exclude_patterns,
        )

    super().__init__(root_dir, download, self._split_scheme)
[docs]
def eval(self,
         y_pred,
         y_true,
         metadata,
         *,
         viz_dir=None,
         viz_n_per_source=10):
    """Performs evaluation on the given predictions.

    Every metric in ``self.metrics`` is run through ``standard_group_eval``
    with the source grouper. The main evaluation metric,
    ``detection_acc_avg_dom``, is the simple average of the per-source
    detection accuracies, taken over sources whose ``source:<name>`` entry in
    the accuracy results is greater than zero (that entry is presumably the
    per-source example count - TODO confirm). It is NaN when no source
    qualifies.

    If ``viz_dir`` is set, writes overlay PNGs (purple = ground truth, orange = predictions above
    the eval score threshold), up to ``viz_n_per_source`` images per source, in subfolders
    named by source.

    Args:
        y_pred: Predictions, forwarded unchanged to each metric.
        y_true: Ground truth, forwarded unchanged to each metric.
        metadata: Per-image metadata, forwarded unchanged to each metric.
        viz_dir: Optional output directory for evaluation overlays.
        viz_n_per_source (int): Maximum number of overlays per source.

    Returns:
        tuple: ``(results, results_str)``. ``results`` maps each metric name
        to its result, plus ``'detection_acc_avg_dom'`` and, when ``viz_dir``
        is given, ``'eval_visualization_paths'``. ``results_str`` is a
        printable report.
    """
    results = {}
    metric_reports = []
    for metric_name, metric in self.metrics.items():
        result, result_str = self.standard_group_eval(
            metric, self._eval_grouper, y_pred, y_true, metadata)
        results[metric_name] = result
        metric_reports.append(result_str)
    results_str = ''.join(metric_reports)

    accuracy_results = results["accuracy"]
    prefix = 'detection_acc_source:'
    detection_accs = []
    for key, value in accuracy_results.items():
        if not key.startswith(prefix):
            continue
        # Slice the prefix off rather than splitting on ':' so that source
        # names which themselves contain ':' are kept whole.
        source_name = key[len(prefix):]
        if accuracy_results[f'source:{source_name}'] > 0:
            detection_accs.append(value)

    if detection_accs:
        detection_acc_avg_dom = np.array(detection_accs).mean()
    else:
        # Same value the mean of an empty array would give, without the
        # "Mean of empty slice" RuntimeWarning.
        detection_acc_avg_dom = np.float64('nan')
    results['detection_acc_avg_dom'] = detection_acc_avg_dom
    results_str = f'Average detection_acc across source: {detection_acc_avg_dom:.3f}\n' + results_str

    # Format results with tables (imported at call time, as in the original).
    from milliontrees.common.utils import format_eval_results
    formatted_results = format_eval_results(results, self)
    results_str = formatted_results + '\n' + results_str

    if viz_dir is not None:
        paths = save_eval_visualizations(
            self,
            y_pred,
            y_true,
            metadata,
            viz_dir,
            n_per_source=viz_n_per_source,
            score_threshold=self.eval_score_threshold,
        )
        results["eval_visualization_paths"] = [str(p) for p in paths]
    return results, results_str
def _get_mini_versions_dict(self):
    """Return the versions table for the "Mini" TreeBoxes release.

    Delegates to ``subset_versions_dict`` with the current
    ``self._versions_dict``.
    """
    # Imported at call time rather than at module import.
    from milliontrees.common.release_sizes import subset_versions_dict

    current_versions = self._versions_dict
    mini_versions = subset_versions_dict(current_versions, "TreeBoxes",
                                         "Mini")
    return mini_versions
def _get_small_versions_dict(self):
    """Return the versions table for the "Small" TreeBoxes release.

    Delegates to ``subset_versions_dict`` with the current
    ``self._versions_dict``.
    """
    # Imported at call time rather than at module import.
    from milliontrees.common.release_sizes import subset_versions_dict

    current_versions = self._versions_dict
    small_versions = subset_versions_dict(current_versions, "TreeBoxes",
                                          "Small")
    return small_versions
@staticmethod
def _collate_fn(batch):
"""Collates a batch by stacking `x` (features) and `metadata`, but not `y` (targets).
The batch is initially a tuple of individual data points: (item1, item2, item3, ...).
After zipping, it transforms into a list of tuples:
[(item1[0], item2[0], ...), (item1[1], item2[1], ...), ...].
Args:
batch (list): A batch of data points, where each data point is a tuple (metadata, x, y).
Returns:
tuple: A tuple containing:
- Stacked `x` (features).
- Stacked `metadata`.
"""
batch = list(zip(*batch))
batch[1] = torch.stack(batch[1])
batch[0] = torch.stack(batch[0])
batch[2] = list(batch[2])
return tuple(batch)
def _transform_(self):
    """Build the albumentations pipeline applied to an image and its boxes.

    The pipeline resizes to ``image_size`` x ``image_size`` and converts to a
    tensor. Boxes are declared in ``pascal_voc`` format with ``clip=True``,
    and their labels travel in the ``labels`` field.

    Returns:
        A.Compose: The composed transform.
    """
    side = self.image_size
    steps = [
        A.Resize(height=side, width=side, p=1.0),
        ToTensorV2(),
    ]
    box_params = A.BboxParams(format='pascal_voc',
                              label_fields=['labels'],
                              clip=True)
    return A.Compose(steps, bbox_params=box_params)