# Authored by paul.tresson_ird.fr
# transforms.py (5.46 KiB)
# modified from torchgeo
"""TorchGeo transforms."""
from typing import Any, Optional, Union
import kornia.augmentation as K
import torch
from einops import rearrange
from kornia.geometry import crop_by_indices
from torch import Tensor
from torch.nn.modules import Module
# TODO: contribute these to Kornia and delete this file
class AugmentationSequential(Module):
    """Dictionary-based front end for kornia's ``AugmentationSequential``.

    The wrapped kornia container works on positional tensors; this module
    pulls the tensors named in ``data_keys`` out of a batch dictionary, runs
    them through the container and writes the results back under the same
    keys.

    .. deprecated:: 0.4
       Use :class:`kornia.augmentation.container.AugmentationSequential` instead.
    """

    def __init__(
        self,
        *args: Union[K.base._AugmentationBase, K.ImageSequential],
        data_keys: list[str],
        **kwargs: Any,
    ) -> None:
        """Build the wrapped kornia augmentation pipeline.

        Args:
            *args: Sequence of kornia augmentations
            data_keys: List of inputs to augment (e.g., ["image", "mask", "boxes"])
            **kwargs: Keyword arguments passed to ``K.AugmentationSequential``

        .. versionadded:: 0.5
           The ``**kwargs`` parameter.
        """
        super().__init__()
        self.data_keys = data_keys

        # Two batch keys go by a different name on the kornia side; every
        # other key is forwarded unchanged.
        kornia_names = {"image": "input", "boxes": "bbox"}
        translated: list[str] = [kornia_names.get(name, name) for name in data_keys]

        self.augs = K.AugmentationSequential(*args, data_keys=translated, **kwargs)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Augment the selected entries of ``batch`` and return the batch.

        The dictionary passed in is modified in place and is also the object
        that is returned.

        Args:
            batch: the input

        Returns:
            the augmented input
        """
        # Kornia augmentations require all inputs to be float, so remember
        # each original dtype before casting.
        original_dtypes = {}
        for name in self.data_keys:
            tensor = batch[name]
            original_dtypes[name] = tensor.dtype
            batch[name] = tensor.float()

        # Kornia requires masks to have a channel dimension
        if "mask" in batch and batch["mask"].ndim == 3:
            batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")

        augmented: Union[Tensor, list[Tensor]] = self.augs(
            *(batch[name] for name in self.data_keys)
        )
        # The container may hand back a bare tensor instead of a list;
        # normalise to a list so results can be paired with their keys.
        if not isinstance(augmented, list):
            augmented = [augmented]
        batch.update(zip(self.data_keys, augmented))

        # Convert all inputs back to their previous dtype
        for name in self.data_keys:
            batch[name] = batch[name].to(original_dtypes[name])

        # Torchmetrics does not support masks with a channel dimension
        if "mask" in batch and batch["mask"].shape[1] == 1:
            batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")

        return batch
class _RandomNCrop(K.GeometricAugmentationBase2D):
    """Augmentation that cuts ``num`` random crops out of its input.

    The augmentation is always applied (``p=1``) and draws its crop
    locations from an :class:`_NCropGenerator`.
    """

    def __init__(self, size: tuple[int, int], num: int) -> None:
        """Set up the crop generator and the static flags.

        Args:
            size: desired output size (out_h, out_w) of the crop
            num: number of crops to take
        """
        super().__init__(p=1)
        self._param_generator: _NCropGenerator = _NCropGenerator(size, num)
        self.flags = dict(size=size, num=num)

    def compute_transformation(
        self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
    ) -> Tensor:
        """Return the identity transformation for ``input``.

        ``params`` and ``flags`` are accepted for interface compatibility and
        are not used here.

        Args:
            input: the input tensor
            params: generated parameters
            flags: static parameters

        Returns:
            the transformation
        """
        identity: Tensor = self.identity_matrix(input)
        return identity

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Any],
        transform: Optional[Tensor] = None,
    ) -> Tensor:
        """Crop ``input`` once per generated box and concatenate the results.

        Args:
            input: the input tensor
            params: generated parameters
            flags: static parameters
            transform: the geometric transformation tensor

        Returns:
            the augmented input
        """
        # One crop per entry of params["src"]; the crops are joined along the
        # first dimension by torch.cat.
        crops = [
            crop_by_indices(input, params["src"][idx], flags["size"])
            for idx in range(flags["num"])
        ]
        return torch.cat(crops)
class _NCropGenerator(K.random_generator.CropGenerator):
    """Generate N random crops.

    Calls the parent ``CropGenerator`` ``num`` times and stacks the sampled
    crop boxes into a single parameter dictionary.
    """

    def __init__(self, size: Union[tuple[int, int], Tensor], num: int) -> None:
        """Initialize a new _NCropGenerator instance.

        Args:
            size: desired output size (out_h, out_w) of the crop
            num: number of crops to generate

        Raises:
            ValueError: if ``num`` is smaller than 1
        """
        # Fail fast: with no draws ``forward`` has nothing to stack and would
        # otherwise fail later with an opaque error.
        if num < 1:
            raise ValueError(f"num must be a positive integer, got {num}")
        super().__init__(size)
        self.num = num

    def forward(
        self, batch_shape: tuple[int, ...], same_on_batch: bool = False
    ) -> dict[str, Tensor]:
        """Generate the crops.

        Args:
            batch_shape: input size (b, c?, in_h, in_w)
            same_on_batch: apply the same transformation across the batch

        Returns:
            the randomly generated parameters: ``"src"`` and ``"dst"`` are the
            per-draw values stacked along a new leading dimension of length
            ``num``; ``"input_size"`` and ``"output_size"`` are taken from the
            first draw
        """
        # Bind the parent method first: zero-argument ``super()`` does not
        # work inside a comprehension on Python versions before 3.12.
        draw = super().forward
        samples = [draw(batch_shape, same_on_batch) for _ in range(self.num)]
        first = samples[0]
        return {
            "src": torch.stack([sample["src"] for sample in samples]),
            "dst": torch.stack([sample["dst"] for sample in samples]),
            "input_size": first["input_size"],
            "output_size": first["output_size"],
        }