# coding=UTF-8
"""Multiprocessing implementation of raster_calculator."""
import collections
import errno
import multiprocessing
import os
import pprint
import signal
import sys
import time
import numpy
from osgeo import gdal
from ..geoprocessing import _is_raster_path_band_formatted
from ..geoprocessing import _LARGEST_ITERBLOCK
from ..geoprocessing import _LOGGING_PERIOD
from ..geoprocessing import _MAX_TIMEOUT
from ..geoprocessing import _VALID_GDAL_TYPES
from ..geoprocessing import geoprocessing_core
from ..geoprocessing import get_raster_info
from ..geoprocessing import iterblocks
from ..geoprocessing import LOGGER
from ..geoprocessing_core import DEFAULT_GTIFF_CREATION_TUPLE_OPTIONS
# ``multiprocessing.shared_memory`` only exists on Python >= 3.8; import it
# conditionally so ``use_shared_memory`` can be ignored on older versions.
if sys.version_info >= (3, 8):
    import multiprocessing.shared_memory
def _block_success_handler(callback_state):
    """Record one finished block and periodically log overall progress.

    The completed-block counter is always incremented. A progress message
    is logged only when more than ``_LOGGING_PERIOD`` seconds have passed
    since ``callback_state['last_time']``; the timestamp is then refreshed.

    Args:
        callback_state (dict): mutable progress state with the keys
            'blocks_complete' -- number of raster calculator blocks processed
            'total_blocks' -- total number of blocks to process
            'last_time' -- ``time.time()`` value of the last progress log

    Returns:
        None
    """
    callback_state['blocks_complete'] += 1
    seconds_since_log = time.time() - callback_state['last_time']
    if seconds_since_log <= _LOGGING_PERIOD:
        return
    percent_done = (
        callback_state['blocks_complete'] /
        callback_state['total_blocks'] * 100)
    LOGGER.info(f"raster_calculator {percent_done:.2f}% complete")
    callback_state['last_time'] = time.time()
# Record naming one band of a raster on disk: ``path`` is the raster file
# path and ``band_id`` is the 1-based band index handed to GetRasterBand.
RasterPathBand = collections.namedtuple('RasterPathBand', 'path band_id')
def _build_raster_calc_error_handler(pool):
    """Build an error callback that stops ``pool`` when a worker fails.

    Args:
        pool: object whose ``terminate()`` method is called on failure
            (presumably a ``multiprocessing.pool.Pool``; only ``terminate``
            is used here).

    Returns:
        function taking a single exception argument. It terminates ``pool``
        and raises a new ``Exception`` describing the failure, carrying the
        original exception's traceback and the original exception itself
        as ``__cause__``.
    """
    def _raster_calc_error_handler(exception):
        """Error handler for raster_calculator: stop the pool and re-raise."""
        pool.terminate()
        # The handler is not called from inside an ``except`` block, so no
        # exception context is attached implicitly; chain it explicitly so
        # the worker's original exception is not lost.
        raise Exception(
            f"error on raster calculator worker '{exception}'"
        ).with_traceback(exception.__traceback__) from exception

    return _raster_calc_error_handler
def _read_block_arguments(base_canonical_arg_list, block_offset):
    """Build the ``local_op`` argument list for one processing window.

    Args:
        base_canonical_arg_list (list): RasterPathBand, 2D numpy.ndarray, or
            (object, 'raw') entries as canonicalized by ``raster_calculator``.
        block_offset (dict): window with 'xoff', 'yoff', 'win_xsize' and
            'win_ysize' keys, passed directly to ``Band.ReadAsArray``.

    Returns:
        list with one value per entry of ``base_canonical_arg_list``: the
        raster window as an array, the array input sliced/tiled to the
        window shape, or the raw object unchanged.

    Raises:
        ValueError if reading a raster window does not give a numpy.ndarray.
    """
    offset_list = (block_offset['yoff'], block_offset['xoff'])
    blocksize = (block_offset['win_ysize'], block_offset['win_xsize'])
    data_blocks = []
    for value in base_canonical_arg_list:
        if isinstance(value, RasterPathBand):
            raster = gdal.OpenEx(value.path, gdal.OF_RASTER)
            band = raster.GetRasterBand(value.band_id)
            block_array = band.ReadAsArray(**block_offset)
            # A corrupt raster (often from multiple writers to the same file)
            # can make ReadAsArray return something other than an array;
            # fail here rather than hand confusing values to ``local_op``.
            if not isinstance(block_array, numpy.ndarray):
                raise ValueError(
                    f"got a {block_array} when trying to read "
                    f"{band.GetDataset().GetFileList()} at "
                    f"{block_offset}, expected numpy.ndarray.")
            data_blocks.append(block_array)
            band = None
            raster = None
        elif isinstance(value, numpy.ndarray):
            # arrays were conditioned to be 2D: a length-1 axis is taken as
            # 0:1 and tiled to the window size, a full axis is sliced to the
            # window and not tiled
            slice_list = [slice(0, 1)] * 2
            tile_dims = list(blocksize)
            for dim_index in (0, 1):
                if value.shape[dim_index] > 1:
                    slice_list[dim_index] = slice(
                        offset_list[dim_index],
                        offset_list[dim_index] + blocksize[dim_index])
                    tile_dims[dim_index] = 1
            data_blocks.append(
                numpy.tile(value[tuple(slice_list)], tile_dims))
        else:
            # (object, 'raw') tuple: pass the object straight through
            data_blocks.append(value[0])
    return data_blocks


def _send_block_to_stats_worker(
        target_block, nodata_target, stats_worker_queue,
        result_array_shared_memory):
    """Queue the non-nodata pixels of ``target_block`` for the stats worker.

    Args:
        target_block (numpy.ndarray): ``local_op`` result for one window.
        nodata_target (numeric or None): target nodata value; matching pixels
            (NaN pixels when ``nodata_target`` is NaN) are dropped.
        stats_worker_queue (queue): queue read by the stats worker.
        result_array_shared_memory (SharedMemory or None): if not None, the
            valid pixels are copied into this buffer and the buffer, with the
            array shape and dtype, is queued instead of the array.

    Returns:
        None
    """
    if nodata_target is not None:
        if numpy.isnan(nodata_target):
            # NaN never compares equal to itself, so ``!=`` would keep every
            # nodata pixel; test for NaN explicitly instead.
            valid_mask = ~numpy.isnan(target_block)
        else:
            valid_mask = target_block != nodata_target
        target_block = target_block[valid_mask]
    target_block = target_block.astype(numpy.float64).flatten()
    if result_array_shared_memory is not None:
        # NOTE(review): each worker owns a single buffer and reuses it for
        # every block it processes; confirm the stats worker consumes it
        # before this worker writes its next block into the same memory.
        shared_memory_array = numpy.ndarray(
            target_block.shape, dtype=target_block.dtype,
            buffer=result_array_shared_memory.buf)
        shared_memory_array[:] = target_block[:]
        stats_worker_queue.put((
            shared_memory_array.shape, shared_memory_array.dtype,
            result_array_shared_memory))
    else:
        stats_worker_queue.put(target_block)


def _raster_calculator_worker(
        block_offset_queue, base_canonical_arg_list, local_op,
        stats_worker_queue, nodata_target, target_raster_path, write_lock,
        processing_state, result_array_shared_memory):
    """Process raster_calculator blocks until a ``None`` sentinel arrives.

    For each block offset taken from ``block_offset_queue`` the worker
    builds the ``local_op`` arguments for that window, evaluates
    ``local_op``, writes the result to band 1 of ``target_raster_path`` and,
    if stats are being calculated, sends the valid result pixels to the
    stats worker.

    Args:
        block_offset_queue (queue): source of block offset dicts with 'xoff',
            'yoff', 'win_xsize' and 'win_ysize' keys (passed directly to
            ``Band.ReadAsArray``). A ``None`` value makes this worker return.
        base_canonical_arg_list (list): list of RasterPathBand, 2D numpy
            arrays, or (object, 'raw') tuples to pass to the ``local_op``.
        local_op (function): callable that takes as many parameters as there
            are elements in ``base_canonical_arg_list`` and returns a numpy
            array with the window's (win_ysize, win_xsize) shape. Full
            description can be found in the public facing
            ``raster_calculator`` operation.
        stats_worker_queue (queue): receives the non-nodata ``local_op``
            results, flattened to float64, if stats are being calculated.
            None otherwise.
        nodata_target (numeric or None): desired target raster nodata; pixels
            equal to it (or NaN pixels when it is NaN) are excluded from
            the values sent for statistics.
        target_raster_path (str): path to target raster.
        write_lock (lock): serializes writes to ``target_raster_path`` and
            updates of ``processing_state``.
        processing_state (dict-like): progress state passed to
            ``_block_success_handler`` after every block write.
        result_array_shared_memory (SharedMemory or None): if not None, a
            pre-allocated buffer large enough to hold the largest float64
            result from ``local_op``; used to pass data to the stats worker.

    Returns:
        None.

    Raises:
        ValueError if a raster window cannot be read as an array or if
        ``local_op`` does not return a numpy.ndarray of the window shape.
    """
    while True:
        block_offset = block_offset_queue.get()
        if block_offset is None:
            # sentinel: this worker should terminate
            return
        blocksize = (block_offset['win_ysize'], block_offset['win_xsize'])
        data_blocks = _read_block_arguments(
            base_canonical_arg_list, block_offset)
        target_block = local_op(*data_blocks)
        if (not isinstance(target_block, numpy.ndarray) or
                target_block.shape != blocksize):
            raise ValueError(
                "Expected `local_op` to return a numpy.ndarray of "
                "shape %s but got this instead: %s" % (
                    blocksize, target_block))

        # Writes are serialized so concurrent workers never write the target
        # file at the same time.
        with write_lock:
            target_raster = gdal.OpenEx(
                target_raster_path, gdal.OF_RASTER | gdal.OF_UPDATE)
            target_band = target_raster.GetRasterBand(1)
            target_band.WriteArray(
                target_block, yoff=block_offset['yoff'],
                xoff=block_offset['xoff'])
            _block_success_handler(processing_state)
            # drop the GDAL handles while the lock is still held
            target_band = None
            target_raster = None

        if stats_worker_queue is not None:
            _send_block_to_stats_worker(
                target_block, nodata_target, stats_worker_queue,
                result_array_shared_memory)
def _calculate_target_raster_size(
        raster_info_list, base_raster_path_band_const_list):
    """Determine the target raster size.

    Args:
        raster_info_list (list): list of raster info from
            ``base_raster_path_band_const_list``; only the first entry's
            'raster_size' (x/y order) is read.
        base_raster_path_band_const_list (list/tuple): argument from
            raster_calculator.

    Returns:
        ``(n_cols, n_rows)`` of the target raster. This is the first
        raster's size when rasters are present. Otherwise it is the
        broadcast shape of the numpy array inputs, with a 1D shape treated
        as a single row.

    Raises:
        ``ValueError`` if numpy array types in
        ``base_raster_path_band_const_list`` do not have sizes which can
        be broadcast against each other.
        ``ValueError`` if the broadcast size has more than 2 dimensions.
        ``ValueError`` if calculated broadcast size is incompatible with the
        raster sizes in ``base_raster_path_band_const_list``.
        ``ValueError`` if only ``'raw'`` objects have been passed as arguments.
    """
    array_list = [
        value for value in base_raster_path_band_const_list
        if isinstance(value, numpy.ndarray)]
    array_shapes = [array.shape for array in array_list]

    numpy_broadcast_size = None
    if array_list:
        # numpy.broadcast accepts a limited number of operands (32 in numpy
        # 1.x), and a broadcast object passed back in counts every operand
        # it already holds. Fold the inputs in chunks instead, carrying the
        # running shape as a zero-copy stand-in array.
        chunk_size = 31
        numpy_broadcast_size = ()
        try:
            for start in range(0, len(array_list), chunk_size):
                stand_in = numpy.broadcast_to(0, numpy_broadcast_size)
                numpy_broadcast_size = numpy.broadcast(
                    stand_in, *array_list[start:start + chunk_size]).shape
        except ValueError as error:
            # raised when the shapes are not mutually broadcastable
            raise ValueError(
                "Numpy array inputs cannot be broadcast into a single "
                "shape %s" % array_shapes) from error

        if len(numpy_broadcast_size) > 2:
            raise ValueError(
                "Numpy array inputs must be 2 dimensions or less %s" %
                array_shapes)

    # if there are both rasters and arrays, check the numpy shape will
    # be broadcastable with raster shape
    if raster_info_list and numpy_broadcast_size:
        # geospatial lists x/y order and numpy does y/x so reverse size list
        raster_shape = tuple(reversed(raster_info_list[0]['raster_size']))
        if len(numpy_broadcast_size) == 1:
            # a single dimension aligns with the last raster axis, which is
            # the columns in numpy row/column order
            dim_pairs = [(numpy_broadcast_size[0], raster_shape[1])]
        else:
            dim_pairs = list(zip(numpy_broadcast_size, raster_shape))
        # each array dimension must be 1 (broadcasts) or match the raster
        if any(array_dim != raster_dim and array_dim != 1
               for array_dim, raster_dim in dim_pairs):
            raise ValueError(
                "Raster size %s cannot be broadcast to numpy shape %s" % (
                    raster_shape, numpy_broadcast_size))

    if raster_info_list:
        # if rasters are passed, the target is the same size as the raster
        n_cols, n_rows = raster_info_list[0]['raster_size']
    elif numpy_broadcast_size:
        # only arrays: the target is the broadcast shape, expanded to two
        # dimensions if necessary
        if len(numpy_broadcast_size) == 1:
            n_rows, n_cols = 1, numpy_broadcast_size[0]
        else:
            n_rows, n_cols = numpy_broadcast_size
    else:
        raise ValueError(
            "Only (object, 'raw') values have been passed. Raster "
            "calculator requires at least a raster or numpy array as a "
            "parameter. This is the input list: %s" % pprint.pformat(
                base_raster_path_band_const_list))
    return n_cols, n_rows
def _validate_raster_input(
        base_raster_path_band_const_list, raster_info_list,
        target_raster_path):
    """Check for valid raster/arg inputs and output.

    Args:
        base_raster_path_band_const_list (list/tuple): the same object passed
            to ``raster_calculator`` indicating the datastack to process.
        raster_info_list (list): raster info dicts for the raster inputs;
            only each entry's 'raster_size' is inspected here.
        target_raster_path (str): desired target raster path from
            raster_calculator, used to ensure it is not also an input
            parameter. Paths are compared after ``os.path.abspath`` and
            ``os.path.normcase`` so relative spellings of the same file are
            detected.

    Returns:
        None

    Raises:
        ValueError if any input parameter would cause an error when passing
        to ``raster_calculator``.
    """
    if not base_raster_path_band_const_list:
        raise ValueError(
            "`base_raster_path_band_const_list` is empty and "
            "should have at least one value.")

    # It's a common error to not pass in path/band tuples, so check for that
    # and report error if so. Every entry must be a path/band tuple, an
    # ndarray, or an (object, 'raw') pair.
    bad_raster_path_list = (
        not isinstance(base_raster_path_band_const_list, (list, tuple)) or
        not all(
            _is_raster_path_band_formatted(value) or
            isinstance(value, numpy.ndarray) or
            (isinstance(value, tuple) and len(value) == 2 and
             value[1] == 'raw')
            for value in base_raster_path_band_const_list))
    if bad_raster_path_list:
        raise ValueError(
            "Expected a sequence of path / integer band tuples, "
            "ndarrays, or (value, 'raw') pairs for "
            "`base_raster_path_band_const_list`, instead got: "
            "%s" % pprint.pformat(base_raster_path_band_const_list))

    base_raster_path_band_list = [
        path_band for path_band in base_raster_path_band_const_list
        if _is_raster_path_band_formatted(path_band)]

    # Check that rasters exist on disk and contain the requested band in a
    # single pass. GDAL's own error output for unopenable paths is silenced,
    # and the handler is always restored even if opening raises.
    not_found_paths = []
    invalid_band_index_list = []
    gdal.PushErrorHandler('CPLQuietErrorHandler')
    try:
        for value in base_raster_path_band_list:
            raster = gdal.OpenEx(value[0], gdal.OF_RASTER)
            if raster is None:
                not_found_paths.append(value[0])
            elif not (1 <= value[1] <= raster.RasterCount):
                invalid_band_index_list.append(value)
            raster = None
    finally:
        gdal.PopErrorHandler()
    # missing files are reported before bad band indexes
    if not_found_paths:
        raise ValueError(
            "The following files were expected but do not exist on the "
            "filesystem: " + str(not_found_paths))
    if invalid_band_index_list:
        raise ValueError(
            "The following rasters do not contain requested band "
            "indexes: %s" % invalid_band_index_list)

    # check that the target raster is not also an input raster; normalize
    # both sides because the caller passes an absolute target path while
    # inputs may be relative
    normalized_target_path = os.path.normcase(
        os.path.abspath(target_raster_path))
    normalized_input_paths = {
        os.path.normcase(os.path.abspath(path_band[0]))
        for path_band in base_raster_path_band_list}
    if normalized_target_path in normalized_input_paths:
        raise ValueError(
            "%s is used as a target path, but it is also in the base input "
            "path list %s" % (
                target_raster_path, str(base_raster_path_band_const_list)))

    # check that raster inputs are all the same dimensions
    geospatial_info_set = {
        raster_info['raster_size'] for raster_info in raster_info_list}
    if len(geospatial_info_set) > 1:
        raise ValueError(
            "Input Rasters are not the same dimensions. The "
            "following raster are not identical %s" % str(
                geospatial_info_set))
def _join_or_kill(process, label):
    """Wait up to ``_MAX_TIMEOUT`` for ``process``; terminate it if it hangs.

    Args:
        process (multiprocessing.Process): a started process.
        label (str): description used in log messages, e.g. 'worker'.

    Returns:
        None
    """
    process.join(_MAX_TIMEOUT)
    if not process.is_alive():
        return
    LOGGER.error(
        f'{label} {process.pid} didn\'t terminate, sending kill signal.')
    try:
        # Process.terminate sends SIGTERM on POSIX
        process.terminate()
    except Exception:
        # best effort: log and keep shutting down the remaining processes
        LOGGER.exception(f'unable to kill {process.pid}')


def raster_calculator(
        base_raster_path_band_const_list, local_op, target_raster_path,
        datatype_target, nodata_target,
        n_workers=max(1, multiprocessing.cpu_count()),
        calc_raster_stats=True, use_shared_memory=False,
        largest_block=_LARGEST_ITERBLOCK,
        raster_driver_creation_tuple=DEFAULT_GTIFF_CREATION_TUPLE_OPTIONS):
    """Apply a local raster operation on a stack of rasters.

    This function applies a user defined function across a stack of
    rasters' pixel stack. The rasters in ``base_raster_path_band_const_list``
    must be spatially aligned and have the same cell sizes.

    Args:
        base_raster_path_band_const_list (sequence): a sequence containing
            either (str, int) tuples, ``numpy.ndarray`` s of up to two
            dimensions, or an (object, 'raw') tuple. A ``(str, int)``
            tuple refers to a raster path band index pair to use as an input.
            The ``numpy.ndarray`` s must be broadcastable to each other AND the
            size of the raster inputs. Values passed by ``(object, 'raw')``
            tuples pass ``object`` directly into the ``local_op``. All rasters
            must have the same raster size. If only arrays are input, numpy
            arrays must be broadcastable to each other and the final raster
            size will be the final broadcast array shape. A value error is
            raised if only "raw" inputs are passed.
        local_op (function): a function that must take in as many parameters
            as there are elements in ``base_raster_path_band_const_list``. The
            parameters in ``local_op`` will map 1-to-1 in order with the values
            in ``base_raster_path_band_const_list``. ``raster_calculator`` will
            call ``local_op`` to generate the pixel values in ``target_raster``
            along memory block aligned processing windows. Note any
            particular call to ``local_op`` will have the arguments from
            ``base_raster_path_band_const_list`` sliced to overlap that
            window. If an argument is a raster/path band tuple, it will be
            passed to ``local_op`` as a 2D numpy array of pixel values that
            align with the processing window that ``local_op`` is targeting.
            A 2D or 1D array will be sliced to match the processing window
            and in the case of a 1D array tiled in whatever dimension is
            flat. An ``(object, 'raw')`` argument passes ``object`` as is.
            The return value must be a 2D array with the shape of the
            processing window and contain the desired pixel values for the
            target raster.
        target_raster_path (string): the path of the output raster. The
            projection, size, and cell size will be the same as the rasters
            in ``base_raster_path_band_const_list`` or the final broadcast
            size of the constant/ndarray values in the list.
        datatype_target (gdal datatype; int): the desired GDAL output type of
            the target raster.
        nodata_target (numerical value): the desired nodata value of the
            target raster.
        n_workers (int): number of Processes to launch for parallel processing,
            defaults to ``max(1, multiprocessing.cpu_count())``.
        calc_raster_stats (boolean): If True, calculates and sets raster
            statistics (min, max, mean, and stdev) for target raster.
        use_shared_memory (boolean): If True, uses Python Multiprocessing
            shared memory to calculate raster stats for faster performance.
            This feature is available for Python >= 3.8 and will otherwise
            be ignored for earlier versions of Python.
        largest_block (int): Attempts to internally iterate over raster blocks
            with this many elements. Useful in cases where the blocksize is
            relatively small, memory is available, and the function call
            overhead dominates the iteration. Defaults to
            ``_LARGEST_ITERBLOCK``. A value of anything less than the
            original blocksize of the raster will result in blocksizes equal
            to the original size.
        raster_driver_creation_tuple (tuple): a tuple containing a GDAL driver
            name string as the first element and a GDAL creation options
            tuple/list as the second. Defaults to
            geoprocessing.DEFAULT_GTIFF_CREATION_TUPLE_OPTIONS.

    Returns:
        None

    Raises:
        ValueError: invalid input provided, including an unknown GDAL driver
            name.
        RuntimeError: GDAL failed to create the target raster.
    """
    raster_info_list = [
        get_raster_info(path_band[0])
        for path_band in base_raster_path_band_const_list
        if _is_raster_path_band_formatted(path_band)]
    target_raster_path = os.path.abspath(target_raster_path)
    _validate_raster_input(
        base_raster_path_band_const_list, raster_info_list, target_raster_path)

    n_cols, n_rows = _calculate_target_raster_size(
        raster_info_list, base_raster_path_band_const_list)

    # Build a "canonical" argument list of RasterPathBand tuples, 2D numpy
    # arrays, or the original (value, 'raw') tuples. 1D arrays become a single
    # row so workers can slice every array in two dimensions.
    base_canonical_arg_list = []
    for value in base_raster_path_band_const_list:
        if _is_raster_path_band_formatted(value):
            base_canonical_arg_list.append(
                RasterPathBand(value[0], value[1]))
        elif isinstance(value, numpy.ndarray) and value.ndim == 1:
            base_canonical_arg_list.append(
                value.reshape((1, value.shape[0])))
        else:
            # 2D arrays (more dimensions were rejected above) or raw tuples
            base_canonical_arg_list.append(value)

    if datatype_target not in _VALID_GDAL_TYPES:
        raise ValueError(
            'Invalid target type, should be a gdal.GDT_* type, received '
            '"%s"' % datatype_target)

    # create target raster
    raster_driver = gdal.GetDriverByName(raster_driver_creation_tuple[0])
    if raster_driver is None:
        raise ValueError(
            f"GDAL driver '{raster_driver_creation_tuple[0]}' is not "
            "available.")
    os.makedirs(os.path.dirname(target_raster_path), exist_ok=True)
    target_raster = raster_driver.Create(
        target_raster_path, n_cols, n_rows, 1, datatype_target,
        options=raster_driver_creation_tuple[1])
    if target_raster is None:
        raise RuntimeError(
            f'unable to create target raster {target_raster_path}')
    target_band = target_raster.GetRasterBand(1)
    if nodata_target is not None:
        target_band.SetNoDataValue(nodata_target)
    if raster_info_list:
        # use the first raster in the list for the projection and geotransform
        target_raster.SetProjection(raster_info_list[0]['projection_wkt'])
        target_raster.SetGeoTransform(raster_info_list[0]['geotransform'])
    # drop the GDAL handles before workers reopen the file
    target_band = None
    target_raster = None

    # iterate over each block and calculate local_op
    block_offset_list = list(iterblocks(
        (target_raster_path, 1), offset_only=True,
        largest_block=largest_block))

    manager = multiprocessing.Manager()
    stats_worker_queue = None
    stats_worker = None
    if calc_raster_stats:
        # workers send valid pixel blocks to the stats worker on this queue;
        # the final statistics payload is read back from it below
        stats_worker_queue = manager.Queue()
        LOGGER.debug('start stats worker')
        stats_worker = multiprocessing.Process(
            target=geoprocessing_core.stats_worker,
            args=(stats_worker_queue, len(block_offset_list)))
        stats_worker.start()

    LOGGER.debug('start workers')
    processing_state = manager.dict()
    processing_state['blocks_complete'] = 0
    processing_state['total_blocks'] = len(block_offset_list)
    processing_state['last_time'] = time.time()

    # shared buffers are sized from the first window, which is assumed to be
    # the largest one iterblocks produces
    block_size_bytes = (
        numpy.dtype(numpy.float64).itemsize *
        block_offset_list[0]['win_xsize'] * block_offset_list[0]['win_ysize'])
    use_stats_shared_memory = (
        calc_raster_stats and use_shared_memory and
        sys.version_info >= (3, 8))
    target_write_lock = manager.Lock()
    # bounded so the producer loop below cannot run far ahead of the workers
    block_offset_queue = multiprocessing.Queue(n_workers)
    process_list = []
    for _ in range(n_workers):
        shared_memory = None
        if use_stats_shared_memory:
            shared_memory = multiprocessing.shared_memory.SharedMemory(
                create=True, size=block_size_bytes)
        worker = multiprocessing.Process(
            target=_raster_calculator_worker,
            args=(
                block_offset_queue, base_canonical_arg_list, local_op,
                stats_worker_queue, nodata_target, target_raster_path,
                target_write_lock, processing_state, shared_memory))
        worker.start()
        process_list.append((worker, shared_memory))

    # Fill the work queue, then one None sentinel per worker so each exits
    for block_offset in block_offset_list:
        block_offset_queue.put(block_offset)
    for _ in range(n_workers):
        block_offset_queue.put(None)

    if stats_worker is not None:
        # joined before the shared memory buffers are unlinked below
        LOGGER.info('wait for stats worker to complete')
        _join_or_kill(stats_worker, 'stats worker')

    # wait for the workers to join
    LOGGER.info('all work sent, waiting for workers to finish')
    for worker, shared_memory in process_list:
        _join_or_kill(worker, 'worker')
        if shared_memory is not None:
            LOGGER.debug(f'unlink {shared_memory.name}')
            # the creating process must both close its mapping and unlink
            shared_memory.close()
            shared_memory.unlink()

    if calc_raster_stats:
        payload = stats_worker_queue.get(True, _MAX_TIMEOUT)
        if payload is not None:
            target_min, target_max, target_mean, target_stddev = payload
            target_raster = gdal.OpenEx(
                target_raster_path, gdal.OF_RASTER | gdal.OF_UPDATE)
            target_band = target_raster.GetRasterBand(1)
            target_band.SetStatistics(
                float(target_min), float(target_max), float(target_mean),
                float(target_stddev))
            target_band = None
            target_raster = None
    manager.shutdown()
    # no format args are passed, so a single '%' is logged literally
    LOGGER.info('raster_calculator 100.0% complete')