"""
Raster Layers
-------------
"""
from __future__ import annotations
import copy
import inspect
import itertools
import math
import warnings
from collections.abc import Callable, Iterable, Iterator, Sequence
from typing import Any, cast, overload
import numpy as np
import rasterio as rio
from affine import Affine
from mesa import Model
from mesa.agent import Agent
from mesa.space import Coordinate, FloatCoordinate, accept_tuple_argument
from rasterio.warp import (
Resampling,
calculate_default_transform,
reproject,
transform_bounds,
)
from mesa_geo.geo_base import GeoBase
[docs]
class RasterBase(GeoBase):
    """
    Shared foundation for raster-like layers.

    Stores the raster dimensions and bounding box, and keeps an affine
    transform (derived from both) up to date whenever either changes.
    """

    _width: int
    _height: int
    _transform: Affine
    _total_bounds: np.ndarray  # [min_x, min_y, max_x, max_y]

    def __init__(self, width, height, crs, total_bounds):
        """
        Create a raster base layer.

        :param width: Number of columns of the raster.
        :param height: Number of rows of the raster.
        :param crs: Coordinate reference system of the raster.
        :param total_bounds: Extent of the raster in
            [min_x, min_y, max_x, max_y] format.
        """
        super().__init__(crs)
        self._width = width
        self._height = height
        self._total_bounds = total_bounds
        self._update_transform()

    @property
    def width(self) -> int:
        """
        Number of columns of the raster.

        :return: Width of the raster base layer.
        :rtype: int
        """
        return self._width

    @width.setter
    def width(self, width: int) -> None:
        """
        Change the number of columns and recompute the transform.

        :param int width: New width of the raster base layer.
        """
        self._width = width
        self._update_transform()

    @property
    def height(self) -> int:
        """
        Number of rows of the raster.

        :return: Height of the raster base layer.
        :rtype: int
        """
        return self._height

    @height.setter
    def height(self, height: int) -> None:
        """
        Change the number of rows and recompute the transform.

        :param int height: New height of the raster base layer.
        """
        self._height = height
        self._update_transform()

    @property
    def total_bounds(self) -> np.ndarray | None:
        """
        Extent of the raster in [min_x, min_y, max_x, max_y] format.

        :return: Bounds of the raster layer.
        :rtype: np.ndarray | None
        """
        return self._total_bounds

    @total_bounds.setter
    def total_bounds(self, total_bounds: np.ndarray) -> None:
        """
        Replace the extent of the raster and recompute the transform.

        :param np.ndarray total_bounds: New bounds in
            [min_x, min_y, max_x, max_y] format.
        """
        self._total_bounds = total_bounds
        self._update_transform()

    @property
    def transform(self) -> Affine:
        """
        Affine transformation mapping (col, row) pixel space to CRS space.

        :return: Affine transformation of the raster base layer.
        :rtype: Affine
        """
        return self._transform

    @property
    def resolution(self) -> tuple[float, float]:
        """
        Size of a single cell, expressed in CRS units.

        :return: (cell width, cell height) in the units of the CRS.
        :rtype: Tuple[float, float]
        """
        affine = self.transform
        # Column/row basis vectors of the affine matrix; their lengths are
        # the cell size even when the raster is rotated or sheared.
        cell_width = math.sqrt(affine.a**2 + affine.d**2)
        cell_height = math.sqrt(affine.b**2 + affine.e**2)
        return cell_width, cell_height

    def _update_transform(self) -> None:
        """Rebuild the affine transform from the current bounds and size."""
        min_x, min_y, max_x, max_y = self.total_bounds
        self._transform = rio.transform.from_bounds(
            min_x, min_y, max_x, max_y, width=self.width, height=self.height
        )

    def to_crs(self, crs, inplace=False) -> RasterBase | None:
        """Reproject the layer; concrete subclasses must implement this."""
        raise NotImplementedError

    def out_of_bounds(
        self,
        pos: Coordinate | None = None,
        *,
        rowcol: Coordinate | None = None,
        xy: FloatCoordinate | None = None,
    ) -> bool:
        """
        Tell whether a coordinate falls outside the raster extent.

        Exactly one selector must be provided.

        :param Coordinate | None pos: Grid position in ``(grid_x, grid_y)`` format
            with origin at lower left.
        :param Coordinate | None rowcol: Raster indices in ``(row, col)`` format
            with origin at upper left.
        :param FloatCoordinate | None xy: Continuous ``(x, y)`` coordinate in CRS units.
        :return: True if the selected coordinate is off the raster, False otherwise.
        :rtype: bool
        :raises ValueError: If zero or more than one selector is given.
        """
        selectors = {"pos": pos, "rowcol": rowcol, "xy": xy}
        given = [key for key, value in selectors.items() if value is not None]
        if len(given) != 1:
            received = ", ".join(given) or "none"
            raise ValueError(
                "Exactly one of ``pos``, ``rowcol``, or ``xy`` must be provided. "
                f"Received: {received}."
            )

        if pos is not None:
            grid_x, grid_y = pos
            return (
                grid_x < 0
                or grid_x >= self.width
                or grid_y < 0
                or grid_y >= self.height
            )

        if rowcol is not None:
            row, col = rowcol
            return row < 0 or row >= self.height or col < 0 or col >= self.width

        # Only the continuous selector is left at this point.
        assert xy is not None
        x_coord, y_coord = xy
        if not (np.isfinite(x_coord) and np.isfinite(y_coord)):
            return True

        # Go through the inverse affine mapping so rotated/sheared rasters are
        # handled correctly (total_bounds alone can include points outside
        # the actual coverage).
        inverse = ~self.transform
        col, row = inverse * (x_coord, y_coord)
        if not (np.isfinite(col) and np.isfinite(row)):
            return True

        # The inverse transform yields floats, so points exactly on the border
        # may land marginally outside [0, width] / [0, height]; allow for
        # floating-point roundoff proportional to the magnitudes involved.
        scale = max(
            1.0,
            abs(col),
            abs(row),
            float(self.width),
            float(self.height),
        )
        tol = np.finfo(float).eps * scale
        return (
            col < -tol
            or col > self.width + tol
            or row < -tol
            or row > self.height + tol
        )
[docs]
class Cell(Agent):
    """
    A single raster cell: an agent that carries raster attributes and serves
    as the building block of `RasterLayer`.

    Deprecated:
        `Cell.indices` is deprecated. Use `Cell.rowcol` instead.
    """

    _pos: Coordinate | None
    _rowcol: Coordinate | None
    _xy: FloatCoordinate | None

    def __init__(
        self,
        model,
        pos=None,
        indices=None,
        *,
        rowcol=None,
        xy=None,
    ):
        """
        Create a cell.

        :param model: Model the cell belongs to (forwarded to the agent base class).
        :param pos: Grid position of the cell in (grid_x, grid_y) format.
            Origin is at lower left corner of the grid.
        :param indices: (Deprecated) Indices of the cell in (row, col) format.
            Origin is at upper left corner of the grid. Use rowcol instead.
        :param rowcol: Indices of the cell in (row, col) format.
            Origin is at upper left corner of the grid. Takes precedence over
            ``indices`` when both are given.
        :param xy: Geographic/projected (x, y) coordinates of the cell center in the CRS.
        """
        super().__init__(model)
        self._pos = pos
        # ``rowcol`` wins; the legacy ``indices`` argument is only a fallback.
        self._rowcol = rowcol if rowcol is not None else indices
        self._xy = xy

    @property
    def pos(self) -> Coordinate | None:
        """
        Position on the grid as (grid_x, grid_y), origin at the lower left.
        """
        return self._pos

    @pos.setter
    def pos(self, pos: Coordinate | None) -> None:
        """
        Deprecated setter for `pos`.
        """
        # The mesa Agent base class assigns ``pos = None`` on construction;
        # stay silent for that case and only warn on real assignments.
        if pos is not None:
            warnings.warn(
                "Cell.pos setter is deprecated and will be read-only in a future release.",
                DeprecationWarning,
                stacklevel=2,
            )
        # Still honour the assignment for backward compatibility; this will
        # go away once ``pos`` becomes read-only.
        self._pos = pos

    @property
    def indices(self) -> Coordinate | None:
        """
        Deprecated alias of `rowcol`.
        """
        warnings.warn(
            "Cell.indices is deprecated and will be removed in a future release. "
            "Use Cell.rowcol instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._rowcol

    @indices.setter
    def indices(self, indices: Coordinate | None) -> None:
        """
        Deprecated setter for `rowcol`.
        """
        warnings.warn(
            "Cell.indices is deprecated and will be removed in a future release. "
            "Use Cell.rowcol instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Kept writable for backward compatibility only; a future release
        # will make ``indices`` read-only and raise AttributeError here.
        self._rowcol = indices

    @property
    def rowcol(self) -> Coordinate | None:
        """
        Raster indices as (row, col), origin at the upper left of the grid.
        """
        return self._rowcol

    @property
    def xy(self) -> FloatCoordinate | None:
        """
        Geographic/projected (x, y) coordinates of the cell center in the CRS.
        """
        return self._xy
[docs]
class RasterLayer(RasterBase):
    """
    A raster layer backed by a 2D grid of `Cell` agents.

    Cells are stored column-major as ``self.cells[grid_x][grid_y]`` with the
    grid origin at the lower left, while raster data uses ``(row, col)`` with
    the origin at the upper left.

    Some methods in `RasterLayer` are copied from `mesa.space.Grid`, including:
        __getitem__
        __iter__
        coord_iter
        iter_neighborhood
        get_neighborhood
        iter_neighbors
        get_neighbors # copied and renamed to `get_neighboring_cells`
        out_of_bounds # copied into `RasterBase`
        iter_cell_list_contents
        get_cell_list_contents

    Methods from `mesa.space.Grid` that are not copied over:
        torus_adj
        neighbor_iter
        move_agent
        place_agent
        _place_agent
        remove_agent
        is_cell_empty
        move_to_empty
        find_empty
        exists_empty_cells

    Another difference is that `mesa.space.Grid` has `self.grid: List[List[Agent | None]]`,
    whereas it is `self.cells: List[List[Cell]]` here in `RasterLayer`.
    """

    cells: list[list[Cell]]
    _neighborhood_cache: dict[Any, list[Coordinate]]
    _attributes: set[str]

    def __init__(
        self, width, height, crs, total_bounds, model, cell_cls: type[Cell] = Cell
    ):
        """
        Create a raster layer and populate it with cells.

        :param width: Number of columns of the raster.
        :param height: Number of rows of the raster.
        :param crs: Coordinate reference system of the raster.
        :param total_bounds: Extent in [min_x, min_y, max_x, max_y] format.
        :param model: Model passed to every created cell.
        :param Type[Cell] cell_cls: Class used to instantiate the cells.
        """
        super().__init__(width, height, crs, total_bounds)
        self.model = model
        self.cell_cls = cell_cls
        self._initialize_cells()
        self._attributes = set()
        self._neighborhood_cache = {}

    def _update_transform(self) -> None:
        """Recompute the transform and refresh cell centers if cells exist."""
        super()._update_transform()
        # During RasterBase.__init__ the cells are not created yet.
        if getattr(self, "cells", None):
            self._sync_cell_xy()

    def _sync_cell_xy(self) -> None:
        """Recompute every cell's center coordinate from the current transform."""
        for column in self.cells:
            for cell in column:
                row, col = cell.rowcol
                cell._xy = rio.transform.xy(self.transform, row, col, offset="center")

    def _initialize_cells(self) -> None:
        """
        Build ``self.cells`` using ``self.cell_cls``.

        Cell classes whose ``__init__`` accepts both ``pos`` and ``indices``
        are constructed through the legacy keyword path; all others receive
        ``pos``, ``rowcol`` and ``xy``.
        """
        try:
            init_params = inspect.signature(self.cell_cls.__init__).parameters
        except (TypeError, ValueError):
            supports_legacy_pos_indices = False
        else:
            supports_legacy_pos_indices = (
                "pos" in init_params and "indices" in init_params
            )

        def make_cell(grid_x: int, grid_y: int, row_idx: int, col_idx: int, xy):
            if supports_legacy_pos_indices:
                # Backward-compatible path for legacy signature:
                # __init__(self, model, pos=None, indices=None, ...)
                cell = self.cell_cls(
                    self.model,
                    pos=(grid_x, grid_y),
                    indices=(row_idx, col_idx),
                )
                # Legacy constructor path does not accept xy; set it manually.
                cell._xy = xy
                return cell
            # New constructor path:
            # __init__(self, model, pos=None, rowcol=None, xy=None, ...)
            # or: __init__(self, model, **kwargs)
            return self.cell_cls(
                self.model,
                pos=(grid_x, grid_y),
                rowcol=(row_idx, col_idx),
                xy=xy,
            )

        self.cells = []
        for grid_x in range(self.width):
            column: list[Cell] = []
            for grid_y in range(self.height):
                # Grid y grows upwards, raster rows grow downwards.
                row_idx, col_idx = self.height - grid_y - 1, grid_x
                xy = rio.transform.xy(self.transform, row_idx, col_idx, offset="center")
                column.append(make_cell(grid_x, grid_y, row_idx, col_idx, xy))
            self.cells.append(column)

    @property
    def attributes(self) -> set[str]:
        """
        Return the attributes of the cells in the raster layer.

        :return: Names of the attributes applied through `apply_raster`.
        :rtype: Set[str]
        """
        return self._attributes

    @overload
    def __getitem__(self, index: int) -> list[Cell]: ...

    @overload
    def __getitem__(
        self, index: tuple[int | slice, int | slice]
    ) -> Cell | list[Cell]: ...

    @overload
    def __getitem__(self, index: Sequence[Coordinate]) -> list[Cell]: ...

    def __getitem__(
        self, index: int | Sequence[Coordinate] | tuple[int | slice, int | slice]
    ) -> Cell | list[Cell]:
        """
        Access contents from the grid.

        Supported forms: ``layer[x]`` (one column), ``layer[x, y]`` (one cell),
        ``layer[(x1, y1), (x2, y2)]`` (several cells) and any mix of ints and
        slices such as ``layer[:, y]`` (flat list of cells).
        """
        if isinstance(index, int):
            # cells[x]
            return self.cells[index]

        if isinstance(index[0], tuple):
            # cells[(x1, y1), (x2, y2)]
            index = cast(Sequence[Coordinate], index)
            return [self.cells[x1][y1] for x1, y1 in index]

        x, y = index
        if isinstance(x, int) and isinstance(y, int):
            # cells[x, y]
            return self.cells[x][y]

        # An int mixed with a slice is widened to a one-element slice.
        # ``i + 1 or None`` keeps ``-1`` working: slice(-1, 0) would be empty.
        if isinstance(x, int):
            # cells[x, :]
            x = slice(x, x + 1 or None)
        if isinstance(y, int):
            # cells[:, y]
            y = slice(y, y + 1 or None)

        # cells[:, :]
        x, y = (cast(slice, x), cast(slice, y))
        return [cell for column in self.cells[x] for cell in column[y]]

    def __iter__(self) -> Iterator[Cell]:
        """
        Create an iterator that chains the columns of cells together
        as if they were one list.
        """
        return itertools.chain.from_iterable(self.cells)

    def coord_iter(self) -> Iterator[tuple[Cell, int, int]]:
        """
        An iterator that returns coordinates as well as cell contents.

        :return: Tuples of ``(cell, grid_x, grid_y)``.
        """
        for grid_x in range(self.width):
            for grid_y in range(self.height):
                yield self.cells[grid_x][grid_y], grid_x, grid_y

    def apply_raster(
        self, data: np.ndarray, attr_name: str | Sequence[str] | None = None
    ) -> None:
        """
        Apply raster data to the cells.

        :param np.ndarray data: 3D numpy array with shape (bands, height, width).
        :param str | Sequence[str] | None attr_name: Attribute name(s) to be added to the
            cells. For multi-band rasters, pass a list of names with length equal to
            the number of bands, or a single base name to be suffixed per band. If None,
            names are generated. Default is None.
        :raises ValueError: If the shape of the data does not match the raster,
            or if the number of names does not match the number of bands.
        """
        if data.ndim != 3 or data.shape[1:] != (self.height, self.width):
            raise ValueError(
                f"Data shape does not match raster shape. "
                f"Expected (*, {self.height}, {self.width}), received {data.shape}."
            )

        num_bands = data.shape[0]
        names: list[str | None]
        if isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
            if len(attr_name) != num_bands:
                raise ValueError(
                    "attr_name sequence length must match the number of raster bands; "
                    f"expected {num_bands} band names, got {len(attr_name)}."
                )
            names = list(attr_name)
        elif isinstance(attr_name, str):
            if num_bands == 1:
                # A single band keeps the name as given (no suffix).
                names = [attr_name]
            else:
                names = [f"{attr_name}_{band_idx + 1}" for band_idx in range(num_bands)]
        else:
            names = [None] * num_bands

        def _default_attr_name() -> str:
            # Generated names must not collide with already applied attributes.
            base = f"attribute_{len(self.cell_cls.__dict__)}"
            if base not in self._attributes:
                return base
            suffix = 1
            candidate = f"{base}_{suffix}"
            while candidate in self._attributes:
                suffix += 1
                candidate = f"{base}_{suffix}"
            return candidate

        for band_idx, name in enumerate(names):
            attr = _default_attr_name() if name is None else name
            self._attributes.add(attr)
            for grid_x in range(self.width):
                for grid_y in range(self.height):
                    setattr(
                        self.cells[grid_x][grid_y],
                        attr,
                        data[band_idx, self.height - grid_y - 1, grid_x],
                    )

    def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray:
        """
        Return the values of given attribute.

        :param str | Sequence[str] | None attr_name: Name(s) of attributes to be returned.
            If None, returns all attributes in lexicographic name order, so that
            the band order is reproducible between runs. Default is None.
        :return: The values of given attribute(s) as a numpy array with shape
            (bands, height, width).
        :rtype: np.ndarray
        :raises ValueError: If a requested attribute does not exist.
        """
        if attr_name is None:
            # ``attributes`` is a set; iterating it directly would make the
            # band order depend on string hash randomization.
            attr_names = sorted(self.attributes)
        elif isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
            attr_names = list(attr_name)
        else:
            attr_names = [attr_name]

        if attr_name is not None:
            missing = [name for name in attr_names if name not in self.attributes]
            if missing:
                raise ValueError(
                    f"Attribute {missing[0]} does not exist. "
                    f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all."
                )

        data = np.empty((len(attr_names), self.height, self.width))
        for band_idx, name in enumerate(attr_names):
            for grid_x in range(self.width):
                for grid_y in range(self.height):
                    data[band_idx, self.height - grid_y - 1, grid_x] = getattr(
                        self.cells[grid_x][grid_y], name
                    )
        return data

    def get_random_xy(
        self,
        cell: Cell | None = None,
        *,
        pos: Coordinate | None = None,
        rowcol: Coordinate | None = None,
    ) -> FloatCoordinate:
        """
        Generate random continuous (x, y) coordinates within a specific raster cell.

        Exactly one of ``cell``, ``pos``, or ``rowcol`` must be provided.

        :param Cell | None cell: Cell to sample from.
        :param Coordinate | None pos: Grid coordinate in ``(grid_x, grid_y)``
            format with origin at lower left.
        :param Coordinate | None rowcol: Raster index in ``(row, col)`` format
            with origin at upper left.
        :return: Random continuous ``(x, y)`` coordinate within the selected
            cell in CRS units.
        :rtype: FloatCoordinate
        :raises ValueError: If selector arguments are invalid or out of bounds.
        """
        provided = [
            name
            for name, arg in (("cell", cell), ("pos", pos), ("rowcol", rowcol))
            if arg is not None
        ]
        if len(provided) != 1:
            selected = ", ".join(provided) if provided else "none"
            raise ValueError(
                "Exactly one of ``cell``, ``pos``, or ``rowcol`` must be provided. "
                f"Received: {selected}."
            )

        # Resolve the selector to pixel coordinates (row, col).
        if cell is not None:
            if cell.rowcol is None:
                raise ValueError("`cell.rowcol` is None; cannot derive raster indices.")
            row, col = cell.rowcol
            if self.out_of_bounds(rowcol=(row, col)):
                raise ValueError(
                    f"`cell.rowcol` {(row, col)} is out of bounds for raster with "
                    f"height={self.height} and width={self.width}."
                )
        elif pos is not None:
            if self.out_of_bounds(pos):
                raise ValueError(
                    f"`pos` {pos} is out of bounds for raster with width={self.width} and "
                    f"height={self.height}."
                )
            grid_x, grid_y = pos
            row, col = self.height - grid_y - 1, grid_x
        else:
            assert rowcol is not None
            row, col = rowcol
            if self.out_of_bounds(rowcol=(row, col)):
                raise ValueError(
                    f"`rowcol` {(row, col)} is out of bounds for raster with "
                    f"height={self.height} and width={self.width}."
                )

        # Random fractional offsets in [0.0, 1.0) drawn from the model's RNG.
        u = self.model.random.random()
        v = self.model.random.random()
        # Map pixel space to continuous CRS space using the affine matrix.
        x, y = self.transform * (col + u, row + v)
        return x, y

    def iter_neighborhood(
        self,
        pos: Coordinate,
        moore: bool,
        include_center: bool = False,
        radius: int = 1,
    ) -> Iterator[Coordinate]:
        """
        Return an iterator over cell coordinates that are in the
        neighborhood of a certain point.

        :param Coordinate pos: Grid coordinate tuple (grid_x, grid_y) for the
            neighborhood to get. Origin is at lower left corner of the grid.
        :param bool moore: Whether to use Moore neighborhood or not. If True,
            return Moore neighborhood (including diagonals). If False, return
            Von Neumann neighborhood (exclude diagonals).
        :param bool include_center: If True, return the (grid_x, grid_y) cell as
            well. Otherwise, return surrounding cells only. Default is False.
        :param int radius: Radius, in cells, of the neighborhood. Default is 1.
        :return: An iterator over cell coordinates that are in the neighborhood.
            For example with radius 1, it will return list with number of elements
            equals at most 9 (8) if Moore, 5 (4) if Von Neumann (if not including
            the center).
        :rtype: Iterator[Coordinate]
        """
        yield from self.get_neighborhood(pos, moore, include_center, radius)

    def iter_neighbors(
        self,
        pos: Coordinate,
        moore: bool,
        include_center: bool = False,
        radius: int = 1,
    ) -> Iterator[Cell]:
        """
        Return an iterator over neighbors to a certain point.

        :param Coordinate pos: Grid coordinate tuple (grid_x, grid_y) for the
            neighborhood to get. Origin is at lower left corner of the grid.
        :param bool moore: Whether to use Moore neighborhood or not. If True,
            return Moore neighborhood (including diagonals). If False, return
            Von Neumann neighborhood (exclude diagonals).
        :param bool include_center: If True, return the (grid_x, grid_y) cell
            as well. Otherwise, return surrounding cells only. Default is False.
        :param int radius: Radius, in cells, of the neighborhood. Default is 1.
        :return: An iterator of cells that are in the neighborhood; at most 9 (8)
            if Moore, 5 (4) if Von Neumann (if not including the center).
        :rtype: Iterator[Cell]
        """
        neighborhood = self.get_neighborhood(pos, moore, include_center, radius)
        return self.iter_cell_list_contents(neighborhood)

    @accept_tuple_argument
    def iter_cell_list_contents(
        self, cell_list: Iterable[Coordinate]
    ) -> Iterator[Cell]:
        """
        Returns an iterator of the contents of the cells
        identified in cell_list.

        :param Iterable[Coordinate] cell_list: Array-like of grid (grid_x, grid_y) tuples,
            or single tuple (grid_x, grid_y). Origin is at lower left corner of the grid.
        :return: An iterator of the contents of the cells identified in cell_list.
        :rtype: Iterator[Cell]
        """
        # Note: filter(None, iterator) filters away an element of iterator that
        # is falsy. Hence, iter_cell_list_contents returns only non-empty
        # contents.
        return filter(None, (self.cells[x][y] for x, y in cell_list))

    @accept_tuple_argument
    def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Cell]:
        """
        Returns a list of the contents of the cells
        identified in cell_list.

        Note: this method returns a list of cells.

        :param Iterable[Coordinate] cell_list: Array-like of grid (grid_x, grid_y) tuples,
            or single tuple (grid_x, grid_y). Origin is at lower left corner of the grid.
        :return: A list of the contents of the cells identified in cell_list.
        :rtype: List[Cell]
        """
        return list(self.iter_cell_list_contents(cell_list))

    def get_neighborhood(
        self,
        pos: Coordinate,
        moore: bool,
        include_center: bool = False,
        radius: int = 1,
    ) -> list[Coordinate]:
        """
        Return a list of cell coordinates that are in the
        neighborhood of a certain point.

        Results are cached per ``(pos, moore, include_center, radius)``; each
        call returns a fresh list, so callers may freely mutate the result.

        :param Coordinate pos: Grid coordinate tuple (grid_x, grid_y) for the
            neighborhood to get. Origin is at lower left corner of the grid.
        :param bool moore: Whether to use Moore neighborhood or not. If True,
            return Moore neighborhood (including diagonals). If False, return
            Von Neumann neighborhood (exclude diagonals).
        :param bool include_center: If True, return the (grid_x, grid_y) cell as
            well. Otherwise, return surrounding cells only. Default is False.
        :param int radius: Radius, in cells, of the neighborhood. Default is 1.
        :return: A sorted list of cell coordinates that are in the neighborhood.
            For example with radius 1, it will return list with number of elements
            equals at most 9 (8) if Moore, 5 (4) if Von Neumann (if not including
            the center).
        :rtype: List[Coordinate]
        """
        x, y = pos
        # Key on the unpacked coordinate so that list-like positions (which
        # are unhashable) can be used as well.
        cache_key = ((x, y), moore, include_center, radius)
        neighborhood = self._neighborhood_cache.get(cache_key)

        if neighborhood is None:
            coordinates: set[Coordinate] = set()
            for dy in range(-radius, radius + 1):
                for dx in range(-radius, radius + 1):
                    if dx == 0 and dy == 0 and not include_center:
                        continue
                    # Skip coordinates that are outside manhattan distance
                    if not moore and abs(dx) + abs(dy) > radius:
                        continue
                    coord = (x + dx, y + dy)
                    if self.out_of_bounds(coord):
                        continue
                    coordinates.add(coord)
            neighborhood = sorted(coordinates)
            self._neighborhood_cache[cache_key] = neighborhood

        # Hand out a copy: the cached list must never be mutated by callers.
        return list(neighborhood)

    def get_neighboring_cells(
        self,
        pos: Coordinate,
        moore: bool,
        include_center: bool = False,
        radius: int = 1,
    ) -> list[Cell]:
        """
        Return the cells located in the neighborhood of a certain point.

        :param Coordinate pos: Grid coordinate tuple (grid_x, grid_y) for the
            neighborhood to get. Origin is at lower left corner of the grid.
        :param bool moore: If True, use the Moore neighborhood (including
            diagonals); otherwise the Von Neumann neighborhood.
        :param bool include_center: If True, include the cell at ``pos``.
            Default is False.
        :param int radius: Radius, in cells, of the neighborhood. Default is 1.
        :return: Cells at the coordinates returned by `get_neighborhood`.
        :rtype: List[Cell]
        """
        neighboring_cell_idx = self.get_neighborhood(pos, moore, include_center, radius)
        return [self.cells[grid_x][grid_y] for grid_x, grid_y in neighboring_cell_idx]

    def to_crs(self, crs, inplace=False) -> RasterLayer | None:
        """
        Transform the raster layer to a new coordinate reference system.

        :param crs: The coordinate reference system to transform to.
        :param inplace: Whether to transform the raster layer in place or
            return a new raster layer. Defaults to False.
        :return: The transformed raster layer if not inplace.
        :rtype: RasterLayer | None
        """
        super()._to_crs_check(crs)
        layer = self if inplace else copy.deepcopy(self)

        src_crs = rio.crs.CRS.from_user_input(layer.crs)
        dst_crs = rio.crs.CRS.from_user_input(crs)
        if not layer.crs.is_exact_same(crs):
            transform, _, _ = calculate_default_transform(
                src_crs,
                dst_crs,
                layer.width,
                layer.height,
                *layer.total_bounds,
            )
            layer._total_bounds = [
                *transform_bounds(src_crs, dst_crs, *layer.total_bounds)
            ]
            layer.crs = crs
            layer._transform = transform
            # Cell centers are expressed in CRS units and must follow.
            if getattr(layer, "cells", None):
                layer._sync_cell_xy()

        return None if inplace else layer

    def to_image(self, colormap) -> ImageLayer:
        """
        Returns an ImageLayer colored by the provided colormap.

        :param colormap: Callable invoked with each cell; its result is
            assigned to the four image bands at that cell's (row, col).
        :return: Image layer sharing this layer's CRS and bounds.
        :rtype: ImageLayer
        """
        values = np.empty(shape=(4, self.height, self.width))
        for cell in self:
            row, col = cell.rowcol
            values[:, row, col] = colormap(cell)
        return ImageLayer(values=values, crs=self.crs, total_bounds=self.total_bounds)

    @classmethod
    def from_file(
        cls,
        raster_file: str,
        model: Model,
        cell_cls: type[Cell] = Cell,
        attr_name: str | Sequence[str] | None = None,
        rio_opener: Callable | None = None,
    ) -> RasterLayer:
        """
        Creates a RasterLayer from a raster file.

        :param str raster_file: Path to the raster file.
        :param Model model: Model passed to every created cell.
        :param Type[Cell] cell_cls: The class of the cells in the layer.
        :param str | Sequence[str] | None attr_name: Attribute name(s) to use for the cell
            values. For multi-band rasters, pass a list of names with length equal to
            the number of bands, or a single base name to be suffixed per band. If None,
            names are generated. Default is None.
        :param Callable | None rio_opener: A callable passed to Rasterio open() function.
        :return: The raster layer populated with the file's band values.
        :rtype: RasterLayer
        """
        with rio.open(raster_file, "r", opener=rio_opener) as dataset:
            values = dataset.read()
            _, height, width = values.shape
            total_bounds = [
                dataset.bounds.left,
                dataset.bounds.bottom,
                dataset.bounds.right,
                dataset.bounds.top,
            ]
            obj = cls(width, height, dataset.crs, total_bounds, model, cell_cls)
            # Use the dataset's own transform instead of the one derived from
            # the bounds, then refresh the cell centers accordingly.
            obj._transform = dataset.transform
            obj._sync_cell_xy()
            obj.apply_raster(values, attr_name=attr_name)
            return obj

    def to_file(
        self,
        raster_file: str,
        attr_name: str | Sequence[str] | None = None,
        driver: str = "GTiff",
    ) -> None:
        """
        Writes a raster layer to a file.

        :param str raster_file: The path to the raster file to write to.
        :param str | Sequence[str] | None attr_name: The name(s) of attributes to write
            to the raster. If None, all attributes are written. Default is None.
        :param str driver: The GDAL driver to use for writing the raster file.
            Default is 'GTiff'. See GDAL docs at https://gdal.org/drivers/raster/index.html.
        """
        data = self.get_raster(attr_name)
        with rio.open(
            raster_file,
            "w",
            driver=driver,
            width=self.width,
            height=self.height,
            count=data.shape[0],
            dtype=data.dtype,
            crs=self.crs,
            transform=self.transform,
        ) as dataset:
            dataset.write(data)
[docs]
class ImageLayer(RasterBase):
    """
    A raster layer holding image values as a single numpy array of shape
    (bands, height, width).
    """

    _values: np.ndarray

    def __init__(self, values, crs, total_bounds):
        """
        Initializes an ImageLayer.

        :param values: The values of the image layer, with shape
            (bands, height, width). A copy is stored.
        :param crs: The coordinate reference system of the image layer.
        :param total_bounds: The bounds of the image layer in [min_x, min_y, max_x, max_y] format.
        """
        super().__init__(
            width=values.shape[2],
            height=values.shape[1],
            crs=crs,
            total_bounds=total_bounds,
        )
        self._values = values.copy()

    @property
    def values(self) -> np.ndarray:
        """
        Returns the values of the image layer.

        :return: The values of the image layer.
        :rtype: np.ndarray
        """
        return self._values

    @values.setter
    def values(self, values: np.ndarray) -> None:
        """
        Sets the values of the image layer.

        The layer's width and height are taken from the new array and the
        transform is recomputed.

        :param np.ndarray values: The values of the image layer, with shape
            (bands, height, width).
        """
        self._values = values
        self._width = values.shape[2]
        self._height = values.shape[1]
        self._update_transform()

    def to_crs(self, crs, inplace=False) -> ImageLayer | None:
        """
        Transform the image layer to a new coordinate reference system.

        :param crs: The coordinate reference system to transform to.
        :param inplace: Whether to transform the image layer in place or
            return a new image layer. Defaults to False.
        :return: The transformed image layer if not inplace.
        :rtype: ImageLayer | None
        """
        super()._to_crs_check(crs)
        layer = self if inplace else copy.copy(self)

        src_crs = rio.crs.CRS.from_user_input(layer.crs)
        dst_crs = rio.crs.CRS.from_user_input(crs)
        if not layer.crs.is_exact_same(crs):
            num_bands, src_height, src_width = layer.values.shape
            transform, dst_width, dst_height = calculate_default_transform(
                src_crs,
                dst_crs,
                src_width,
                src_height,
                *layer.total_bounds,
            )
            # Reproject band by band into a freshly allocated destination.
            dst = np.empty(shape=(num_bands, dst_height, dst_width))
            for i, band in enumerate(layer.values):
                reproject(
                    source=band,
                    destination=dst[i],
                    src_transform=layer.transform,
                    src_crs=src_crs,
                    dst_transform=transform,
                    dst_crs=dst_crs,
                    resampling=Resampling.nearest,
                )
            layer._total_bounds = [
                *transform_bounds(src_crs, dst_crs, *layer.total_bounds)
            ]
            layer._values = dst
            layer._height = layer._values.shape[1]
            layer._width = layer._values.shape[2]
            layer.crs = crs
            layer._transform = transform

        return None if inplace else layer

    @classmethod
    def from_file(cls, image_file, rio_opener: Callable | None = None) -> ImageLayer:
        """
        Creates an ImageLayer from an image file.

        :param image_file: The path to the image file.
        :param Callable | None rio_opener: A callable passed to Rasterio open() function.
            Default is None.
        :return: The ImageLayer.
        :rtype: ImageLayer
        """
        with rio.open(image_file, "r", opener=rio_opener) as dataset:
            values = dataset.read()
            total_bounds = [
                dataset.bounds.left,
                dataset.bounds.bottom,
                dataset.bounds.right,
                dataset.bounds.top,
            ]
            obj = cls(values=values, crs=dataset.crs, total_bounds=total_bounds)
            # Prefer the dataset's own transform over the one derived from bounds.
            obj._transform = dataset.transform
            return obj

    def __repr__(self) -> str:
        """Return a string showing the layer's CRS, bounds and values."""
        return f"{self.__class__.__name__}(crs={self.crs}, total_bounds={self.total_bounds}, values={self.values!r})"