# Source code for iris.fileformats.netcdf._thread_safe_nc

# Copyright Iris contributors
#
# This file is part of Iris and is released under the BSD license.
# See LICENSE in the root of the repository for full licensing details.
"""Module to ensure all calls to the netCDF4 library are thread-safe.

Intention is that no other Iris module should import the netCDF4 module.

"""

from abc import ABC
from threading import Lock
from time import sleep
import typing

import netCDF4
import numpy as np

# A single module-wide lock, held around every call into the netCDF4 library
#  (see module docstring) so that concurrent workers never enter the C layer
#  at the same time.
_GLOBAL_NETCDF4_LOCK = Lock()

# Doesn't need thread protection, but this allows all netCDF4 refs to be
#  replaced with thread_safe refs.
default_fillvals = netCDF4.default_fillvals
VLType = netCDF4.VLType


class _ThreadSafeWrapper(ABC):
    """Contains a netCDF4 class instance, ensuring wrapping all API calls.

    Contains a netCDF4 class instance, ensuring wrapping all API calls within
    _GLOBAL_NETCDF4_LOCK.

    Designed to 'gate keep' all the instance's API calls, but allowing the
    same API as if working directly with the instance itself.

    Using a contained object instead of inheritance, as we cannot successfully
    subclass or monkeypatch netCDF4 classes, because they are only wrappers for
    the C-layer.
    """

    # Note: this is only used to create a "contained" from passed args.
    CONTAINED_CLASS = NotImplemented
    # Note: this defines how we identify/check that a contained is of the expected type
    # (in a duck-type way).
    _DUCKTYPE_CHECK_PROPERTIES: typing.List[str] = [NotImplemented]

    # Allows easy type checking, avoiding difficulties with isinstance and mocking.
    THREAD_SAFE_FLAG = True

    @classmethod
    def is_contained_type(cls, instance) -> bool:
        """Return whether `instance` duck-types as our contained class.

        True when `instance` has every attribute named in
        cls._DUCKTYPE_CHECK_PROPERTIES.
        """
        return all(hasattr(instance, attr) for attr in cls._DUCKTYPE_CHECK_PROPERTIES)

    @classmethod
    def from_existing(cls, instance):
        """Routine to pass an existing instance to __init__, where it is contained."""
        assert cls.is_contained_type(instance)
        return cls(instance)

    def __init__(self, *args, **kwargs):
        """Contain an existing instance, or generate a new one from arguments."""
        if len(args) == 1 and self.is_contained_type(args[0]):
            # Passed a contained-type object : Wrap ourself around that.
            instance = args[0]
            # We should never find ourselves "wrapping a wrapper".
            assert not hasattr(instance, "THREAD_SAFE_FLAG")
        else:
            # Create a contained object of the intended type from passed args.
            with _GLOBAL_NETCDF4_LOCK:
                instance = self.CONTAINED_CLASS(*args, **kwargs)

        # NOTE: this assignment goes via __setattr__ (below), which stores
        #  this one attribute on the wrapper itself rather than delegating.
        self._contained_instance = instance

    def __getattr__(self, item):
        # Delegate attribute reads to the contained instance, under the lock.
        # (__getattr__ is only invoked when normal lookup on the wrapper
        #  fails, so wrapper-level attributes like THREAD_SAFE_FLAG are
        #  unaffected.)
        if item == "_contained_instance":
            # Special behaviour when accessing the _contained_instance itself.
            # Fetch it directly from the wrapper, avoiding infinite recursion
            #  if it has not been set yet.
            return object.__getattribute__(self, item)
        else:
            with _GLOBAL_NETCDF4_LOCK:
                return getattr(self._contained_instance, item)

    def __setattr__(self, key, value):
        # Delegate attribute writes to the contained instance, under the lock.
        if key == "_contained_instance":
            # Special behaviour when accessing the _contained_instance itself.
            # Store it on the wrapper, not on the contained object.
            object.__setattr__(self, key, value)
        else:
            with _GLOBAL_NETCDF4_LOCK:
                return setattr(self._contained_instance, key, value)

    def __getitem__(self, item):
        # Delegate indexed reads (e.g. variable data access) under the lock.
        with _GLOBAL_NETCDF4_LOCK:
            return self._contained_instance.__getitem__(item)

    def __setitem__(self, key, value):
        # Delegate indexed writes (e.g. variable data assignment) under the lock.
        with _GLOBAL_NETCDF4_LOCK:
            return self._contained_instance.__setitem__(key, value)


class DimensionWrapper(_ThreadSafeWrapper):
    """Thread-safe accessor for a netCDF4.Dimension.

    Presents exactly the same API as netCDF4.Dimension, but every call is
    made while holding _GLOBAL_NETCDF4_LOCK.
    """

    CONTAINED_CLASS = netCDF4.Dimension
    _DUCKTYPE_CHECK_PROPERTIES = ["isunlimited"]


class VariableWrapper(_ThreadSafeWrapper):
    """Thread-safe accessor for a netCDF4.Variable.

    Presents exactly the same API as netCDF4.Variable, but every call is
    made while holding _GLOBAL_NETCDF4_LOCK.
    """

    CONTAINED_CLASS = netCDF4.Variable
    _DUCKTYPE_CHECK_PROPERTIES = ["dimensions", "dtype"]

    def setncattr(self, *args, **kwargs) -> None:
        """Call netCDF4.Variable.setncattr within _GLOBAL_NETCDF4_LOCK.

        Defined explicitly (rather than relying on __getattr__ delegation)
        in order to get some mocks to work.
        """
        with _GLOBAL_NETCDF4_LOCK:
            return self._contained_instance.setncattr(*args, **kwargs)

    @property
    def dimensions(self) -> typing.List[str]:
        """Call netCDF4.Variable.dimensions within _GLOBAL_NETCDF4_LOCK.

        Defined explicitly (rather than relying on __getattr__ delegation)
        in order to get some mocks to work.
        """
        with _GLOBAL_NETCDF4_LOCK:
            # These are just dimension *names* (strings), so no
            #  DimensionWrapper is needed - unlike self.get_dims().
            return self._contained_instance.dimensions

    # Any Variable API that returns Dimension(s) is re-wrapped to return
    #  DimensionWrapper(s) instead.

    def get_dims(self, *args, **kwargs) -> typing.Tuple[DimensionWrapper]:
        """Call netCDF4.Variable.get_dims() within _GLOBAL_NETCDF4_LOCK.

        Each netCDF4.Dimension in the result is replaced by its respective
        DimensionWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            raw_dims = list(self._contained_instance.get_dims(*args, **kwargs))
        return tuple(DimensionWrapper.from_existing(dim) for dim in raw_dims)


class GroupWrapper(_ThreadSafeWrapper):
    """Thread-safe accessor for a netCDF4.Group.

    Presents exactly the same API as netCDF4.Group, but every call is made
    while holding _GLOBAL_NETCDF4_LOCK.
    """

    CONTAINED_CLASS = netCDF4.Group
    # Note: will also accept a whole Dataset object, but that is OK.
    _DUCKTYPE_CHECK_PROPERTIES = ["createVariable"]

    # Any Group API that returns Dimension(s) is re-wrapped to return
    #  DimensionWrapper(s) instead.

    @property
    def dimensions(self) -> typing.Dict[str, DimensionWrapper]:
        """Access dimensions of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        Each netCDF4.Dimension in the result is replaced by its respective
        DimensionWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            raw_dims = self._contained_instance.dimensions
        return {name: DimensionWrapper.from_existing(dim) for name, dim in raw_dims.items()}

    def createDimension(self, *args, **kwargs) -> DimensionWrapper:
        """Call createDimension() of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        The returned netCDF4.Dimension is replaced by its respective
        DimensionWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            dim = self._contained_instance.createDimension(*args, **kwargs)
        return DimensionWrapper.from_existing(dim)

    # Any Group API that returns Variable(s) is re-wrapped to return
    #  VariableWrapper(s) instead.

    @property
    def variables(self) -> typing.Dict[str, VariableWrapper]:
        """Access variables of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        Each netCDF4.Variable in the result is replaced by its respective
        VariableWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            raw_vars = self._contained_instance.variables
        return {name: VariableWrapper.from_existing(var) for name, var in raw_vars.items()}

    def createVariable(self, *args, **kwargs) -> VariableWrapper:
        """Call createVariable() of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        The returned netCDF4.Variable is replaced by its respective
        VariableWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            var = self._contained_instance.createVariable(*args, **kwargs)
        return VariableWrapper.from_existing(var)

    def get_variables_by_attributes(
        self, *args, **kwargs
    ) -> typing.List[VariableWrapper]:
        """Call get_variables_by_attributes() of netCDF4.Group/Dataset.

        Performed within _GLOBAL_NETCDF4_LOCK.  Each netCDF4.Variable in the
        result is replaced by its respective VariableWrapper, ensuring that
        downstream calls are also performed within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            found = list(
                self._contained_instance.get_variables_by_attributes(*args, **kwargs)
            )
        return [VariableWrapper.from_existing(var) for var in found]

    # Any Group API that returns Group(s) is re-wrapped to return
    #  GroupWrapper(s) instead.

    @property
    def groups(self):
        """Access groups of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        Each netCDF4.Group in the result is replaced by its respective
        GroupWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            raw_groups = self._contained_instance.groups
        return {name: GroupWrapper.from_existing(grp) for name, grp in raw_groups.items()}

    @property
    def parent(self):
        """Access parent of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        The returned netCDF4.Group is replaced by its respective
        GroupWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            raw_parent = self._contained_instance.parent
        return GroupWrapper.from_existing(raw_parent)

    def createGroup(self, *args, **kwargs):
        """Call createGroup() of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        The returned netCDF4.Group is replaced by its respective
        GroupWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            grp = self._contained_instance.createGroup(*args, **kwargs)
        return GroupWrapper.from_existing(grp)


class DatasetWrapper(GroupWrapper):
    """Thread-safe accessor for a netCDF4.Dataset.

    Presents exactly the same API as netCDF4.Dataset, but every call is made
    while holding _GLOBAL_NETCDF4_LOCK.
    """

    CONTAINED_CLASS = netCDF4.Dataset
    # Note: 'close' exists on Dataset but not Group (though a rather weak distinction).
    _DUCKTYPE_CHECK_PROPERTIES = ["createVariable", "close"]

    @classmethod
    def fromcdl(cls, *args, **kwargs):
        """Call netCDF4.Dataset.fromcdl() within _GLOBAL_NETCDF4_LOCK.

        The returned netCDF4.Dataset is replaced by its respective
        DatasetWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            dataset = cls.CONTAINED_CLASS.fromcdl(*args, **kwargs)
        return cls.from_existing(dataset)


class NetCDFDataProxy:
    """A reference to the data payload of a single NetCDF file variable.

    A lightweight, picklable handle storing the file path and variable name,
    plus shape/dtype/fill-value metadata.  Indexing the proxy opens the file,
    reads the requested slice, and closes the file again.
    """

    __slots__ = ("shape", "dtype", "path", "variable_name", "fill_value")

    def __init__(self, shape, dtype, path, variable_name, fill_value):
        """Record the variable's metadata and file location.

        Parameters
        ----------
        shape :
            Shape of the file variable.
        dtype :
            Dtype of the file variable.
        path :
            Path of the file containing the variable.
        variable_name :
            Name of the variable within the file.
        fill_value :
            Fill-value of the variable.
        """
        self.shape = shape
        self.dtype = dtype
        self.path = path
        self.variable_name = variable_name
        self.fill_value = fill_value

    @property
    def ndim(self):  # noqa: D102
        return len(self.shape)

    @property
    def dask_meta(self):
        """Return an empty masked array with this proxy's ndim and dtype.

        Suitable as a Dask "meta" template for arrays produced by this proxy.
        """
        return np.ma.array(np.empty((0,) * self.ndim, dtype=self.dtype), mask=True)

    def __getitem__(self, keys):
        """Open the file, read and return the requested slice of the variable."""
        # Using a DatasetWrapper causes problems with invalid ID's and the
        # netCDF4 library, presumably because __getitem__ gets called so many
        # times by Dask. Use _GLOBAL_NETCDF4_LOCK directly instead.
        with _GLOBAL_NETCDF4_LOCK:
            dataset = netCDF4.Dataset(self.path)
            try:
                variable = dataset.variables[self.variable_name]
                # Get the NetCDF variable data and slice.
                var = variable[keys]
            finally:
                # Always release the file handle, even if the read fails.
                dataset.close()
        return np.asanyarray(var)

    def __repr__(self):
        fmt = (
            "<{self.__class__.__name__} shape={self.shape}"
            " dtype={self.dtype!r} path={self.path!r}"
            " variable_name={self.variable_name!r}>"
        )
        return fmt.format(self=self)

    def __getstate__(self):
        # Needed for pickling : __slots__ classes have no instance __dict__.
        return {attr: getattr(self, attr) for attr in self.__slots__}

    def __setstate__(self, state):
        for key, value in state.items():
            setattr(self, key, value)
class NetCDFWriteProxy:
    """An object mimicking the data access of a netCDF4.Variable.

    The "opposite" of a NetCDFDataProxy : An object mimicking the data access
    of a netCDF4.Variable, but where the data is to be ***written to***.

    It encapsulates the netcdf file and variable which are actually to be
    written to.  This opens the file each time, to enable writing the data
    chunk, then closes it.
    TODO: could be improved with a caching scheme, but this just about works.
    """

    def __init__(self, filepath, cf_var, file_write_lock):
        """Record the target file path, variable name and per-file write lock.

        Parameters
        ----------
        filepath :
            Path of the file to write to.
        cf_var :
            Variable object : only its ``name`` attribute is used.
        file_write_lock :
            Lock shared by all workers writing to this file.
        """
        self.path = filepath
        self.varname = cf_var.name
        self.lock = file_write_lock

    def __setitem__(self, keys, array_data):
        """Open the file and write `array_data` at `keys` of the variable."""
        # Write to the variable.
        # First acquire a file-specific lock for all workers writing to this file.
        self.lock.acquire()
        # Open the file for writing + write to the specific file variable.
        # Exactly as above, in NetCDFDataProxy : a DatasetWrapper causes problems with
        # invalid ID's and the netCDF4 library, for so-far unknown reasons.
        # Instead, use _GLOBAL_NETCDF4_LOCK, and netCDF4 _directly_.
        with _GLOBAL_NETCDF4_LOCK:
            dataset = None
            try:
                # Even when fully serialised - no parallelism - HDF still
                # occasionally fails to acquire the file. This is despite all
                # Python locks being available at expected moments, and the
                # file reporting as closed. During testing, 2nd retry always
                # succeeded. This is likely caused by HDF-level locking
                # running on a different timescale to Python-level locking -
                # i.e. sometimes Python has released its locks but HDF still
                # has not. Thought to be filesystem-dependent; further
                # investigation needed.
                for attempt in range(5):
                    try:
                        dataset = netCDF4.Dataset(self.path, "r+")
                        break
                    except OSError:
                        if attempt < 4:
                            # Brief pause before retrying the open.
                            sleep(0.1)
                        else:
                            # Out of retries : propagate the failure.
                            raise
                var = dataset.variables[self.varname]
                var[keys] = array_data
            finally:
                try:
                    if dataset:
                        dataset.close()
                finally:
                    # *ALWAYS* let go !
                    self.lock.release()

    def __repr__(self):
        return f"<{self.__class__.__name__} path={self.path!r} var={self.varname!r}>"