Source code for dpdata.data_type

from enum import Enum, unique
from typing import TYPE_CHECKING, Tuple

import numpy as np

from dpdata.plugin import Plugin

if TYPE_CHECKING:
    from dpdata.system import System


[docs] @unique class Axis(Enum): """Data axis.""" NFRAMES = "nframes" NATOMS = "natoms" NTYPES = "ntypes" NBONDS = "nbonds"
[docs] class AnyInt(int): """AnyInt equals to any other integer.""" def __eq__(self, other): return True
[docs] class DataError(Exception): """Data is not correct."""
[docs] class DataType: """DataType represents a type of data, like coordinates, energies, etc. Parameters ---------- name : str name of data dtype : type or tuple[type] data type, e.g. np.ndarray shape : tuple[int], optional shape of data. Used when data is list or np.ndarray. Use Axis to represents numbers required : bool, default=True whether this data is required """ def __init__( self, name: str, dtype: type, shape: Tuple[int, Axis] = None, required: bool = True, ) -> None: self.name = name self.dtype = dtype self.shape = shape self.required = required
[docs] def real_shape(self, system: "System") -> Tuple[int]: """Returns expected real shape of a system.""" shape = [] for ii in self.shape: if ii is Axis.NFRAMES: shape.append(system.get_nframes()) elif ii is Axis.NTYPES: shape.append(system.get_ntypes()) elif ii is Axis.NATOMS: shape.append(system.get_natoms()) elif ii is Axis.NBONDS: # BondOrderSystem shape.append(system.get_nbonds()) elif ii == -1: shape.append(AnyInt(-1)) elif isinstance(ii, int): shape.append(ii) else: raise RuntimeError("Shape is not an int!") return tuple(shape)
[docs] def check(self, system: "System"): """Check if a system has correct data of this type. Parameters ---------- system : System checked system Raises ------ DataError type or shape of data is not correct """ # check if exists if self.name in system.data: data = system.data[self.name] # check dtype # allow list for empty np.ndarray if isinstance(data, list) and not len(data): pass elif not isinstance(data, self.dtype): raise DataError( f"Type of {self.name} is {type(data).__name__}, but expected {self.dtype.__name__}" ) # check shape if self.shape is not None: shape = self.real_shape(system) # skip checking empty list of np.ndarray if isinstance(data, np.ndarray): if data.size and shape != data.shape: raise DataError( f"Shape of {self.name} is {data.shape}, but expected {shape}" ) elif isinstance(data, list): if len(shape) and shape[0] != len(data): raise DataError( "Length of %s is %d, but expected %d" % (self.name, len(data), shape[0]) ) else: raise RuntimeError("Unsupported type to check shape") elif self.required: raise DataError("%s not found in data" % self.name)
__system_data_type_plugin = Plugin() __labeled_system_data_type_plugin = Plugin()
[docs] def register_data_type(data_type: DataType, labeled: bool): """Register a data type. Parameters ---------- data_type : DataType data type to be registered labeled : bool whether this data type is for LabeledSystem """ plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin plugin.register(data_type.name)(data_type)
[docs] def get_data_types(labeled: bool): """Get all registered data types. Parameters ---------- labeled : bool whether this data type is for LabeledSystem """ plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin return tuple(plugin.plugins.values())