Skip to content

Commit

Permalink
Turn all numedtuples into attrs classes
Browse files Browse the repository at this point in the history
Fixes theochem#201

Related to theochem#138 and theochem#157 (which were earlier attempts)

This PR includes:
- Attribute validation (to large extent, not every detail)
- attrutil module to facilitate validation of array attributes
- Documentation of how attrs is used in IOData
- Bug fix in CP2K loader
- Minor fixes elsewhere
  • Loading branch information
tovrstra committed Sep 3, 2020
1 parent ced6e7e commit e0d7f6a
Show file tree
Hide file tree
Showing 20 changed files with 516 additions and 105 deletions.
70 changes: 70 additions & 0 deletions CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,77 @@ to avoid duplicate efforts.
results in minor corrections at worst. We'll do our best to avoid larger
problems in step 1.


Notes on attrs
--------------

IOData uses the `attrs`_ library, not to be confused with the `attr`_ library,
for classes representing data loaded from files: ``IOData``, ``MolecularBasis``,
``Shell``, ``MolecularOrbitals`` and ``Cube``. This enables basic attribute
validation, which eliminates potentially silly bugs.
(See ``iodata/attrutils.py`` and the usage of ``validate_shape`` in all those
classes.)

The following two tricks might be convenient with working with these classes:

- The data can be turned into plain Python data types with the ``attr.asdict``
function. Make sure you add the ``retain_collection_types=True`` option, to
avoid the following issue: https://github.com/python-attrs/attrs/issues/646
For example.

.. code-block:: python
from iodata import load_one
import attr
iodata = load_one("example.xyz")
fields = attr.asdict(iodata, retain_collection_types=True)
A similar ``astuple`` function works as you would expect.

- A `shallow copy`_ with a few modified attributes can be created with the
evolve method, which is a wrapper for ``attr.evolve``:

.. code-block:: python
from iodata import load_one
import attr
iodata1 = load_one("example.xyz")
iodata2 = attr.evolve(iodata1, title="another title")
The usage of evolve becomes mandatory when you want to change two or more
attributes whose shape need to be consistent. For example, the following
would fail:

.. code-block:: python
from iodata import IOData
iodata = IOData(atnums=[7, 7], atcoords=[[0, 0, 0], [2, 0, 0]])
# The next line will fail because the size of atnums and atcoords
# becomes inconsistent.
iodata.atnums = [8, 8, 8]
iodata.atcoords = [[0, 0, 0], [2, 0, 1], [4, 0, 0]]
The following code, which has the same intent, does work:

.. code-block:: python
from iodata import IOData
import attr
iodata1 = IOData(atnums=[7, 7], atcoords=[[0, 0, 0], [2, 0, 0]])
iodata2 = attr.evolve(
iodata1,
atnums=[8, 8, 8],
atcoords=[[0, 0, 0], [2, 0, 1], [4, 0, 0]],
)
For brevity, lists (of lists) were used in these examples. These are always
converted to arrays by the constructor or when assigning them to attributes.
.. _Bash: https://en.wikipedia.org/wiki/Bash_(Unix_shell)
.. _Python: https://en.wikipedia.org/wiki/Python_(programming_language)
.. _type hinting: https://docs.python.org/3/library/typing.html
.. _PEP 0563: https://www.python.org/dev/peps/pep-0563/
.. _attrs: https://www.attrs.org/en/stable/
.. _attr: https://github.com/denis-ryzhkov/attr
.. _shallow copy: https://docs.python.org/3/library/copy.html?highlight=shallow
2 changes: 1 addition & 1 deletion doc/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ IOData has the following dependencies:

- numpy >= 1.0: https://numpy.org/
- scipy: https://scipy.org/
- attrs >= 19.1.0: https://www.attrs.org/en/stable/index.html
- attrs >= 20.1.0: https://www.attrs.org/en/stable/index.html
- importlib_resources [only for Python 3.6]: https://gitlab.com/python-devs/importlib_resources

You only need to install these dependencies manually when installing IOData from
Expand Down
129 changes: 129 additions & 0 deletions iodata/attrutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
"""Utilities for building attr classes."""


import numpy as np


__all__ = ["convert_array_to", "validate_shape"]


def convert_array_to(dtype):
"""Return a function to convert arrays to the given type."""
def converter(array):
if array is None:
return None
return np.array(array, copy=False, dtype=dtype)
return converter


# pylint: disable=too-many-branches
def validate_shape(*shape_requirements: tuple):
"""Return a validator for the shape of an array or the length of an iterable.
Parameters
----------
shape_requirements
Specifications for the required shape. Every item of the tuple describes
the required size of the corresponding axis of an array. Also the
number of items should match the dimensionality of the array. When the
validator is used for general iterables, this tuple should contain just
one element. Possible values for each item are explained in the "Notes"
section below.
Returns
-------
validator
A validator function for the attr library.
Notes
-----
Every element of ``shape_requirements`` defines the expected size of an
array along the corresponding axis. An item in this tuple at position (or
index) ``i`` can be one of the following:
1. An integer, which is taken as the expected size along axis ``i``.
2. None. In this case, the size of the array along axis ``i`` is not
checked.
3. A string, which should be the name of another integer attribute with
the expected size along axis ``i``. The other attribute is always an
attribute of the same object as the attribute being checked.
4. A 2-tuple containing a name and an integer. In this case, the name refers
to another attribute which is an array or an iterable. When the integer
is 0, just the length of the other attribute is used. When the integer is
non-zero, the other attribute must be an array and the integer selects an
axis. The size of the other array along the selected axis is then used as
the expected size of the array being checked along axis ``i``.
"""
def validator(obj, attribute, value):
# Build the expected shape, with the rules from the docstring.
expected_shape = []
for item in shape_requirements:
if isinstance(item, int) or item is None:
expected_shape.append(item)
elif isinstance(item, str):
expected_shape.append(getattr(obj, item))
elif isinstance(item, tuple) and len(item) == 2:
other_name, other_axis = item
other = getattr(obj, other_name)
if other is None:
raise TypeError(
"Other attribute '{}' is not set.".format(other_name)
)
if other_axis == 0:
expected_shape.append(len(other))
else:
if other_axis >= other.ndim or other_axis < 0:
raise TypeError(
"Cannot get length along axis "
"{} of attribute {} with ndim {}.".format(
other_axis, other_name, other.ndim
)
)
expected_shape.append(other.shape[other_axis])
else:
raise ValueError(f"Cannot interpret item in shape_requirements: {item}")
expected_shape = tuple(expected_shape)
# Get the actual shape
if isinstance(value, np.ndarray):
observed_shape = value.shape
else:
observed_shape = (len(value),)
# Compare
match = True
if len(expected_shape) != len(observed_shape):
match = False
if match:
for es, os in zip(expected_shape, observed_shape):
if es is None:
continue
if es != os:
match = False
break
# Raise TypeError if needed.
if not match:
raise TypeError(
"Expecting shape {} for attribute {}, got {}".format(
expected_shape, attribute.name, observed_shape
)
)

return validator
26 changes: 17 additions & 9 deletions iodata/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@

from functools import wraps
from numbers import Integral
from typing import List, Dict, NamedTuple, Tuple, Union
from typing import List, Dict, Tuple, Union

import attr
import numpy as np

from .attrutils import validate_shape


__all__ = ['angmom_sti', 'angmom_its', 'Shell', 'MolecularBasis',
'convert_convention_shell', 'convert_conventions',
'iter_cart_alphabet', 'HORTON2_CONVENTIONS', 'PSI4_CONVENTIONS']
Expand Down Expand Up @@ -81,7 +85,9 @@ def angmom_its(angmom: Union[int, List[int]]) -> Union[str, List[str]]:
return ANGMOM_CHARS[angmom]


class Shell(NamedTuple):
@attr.s(auto_attribs=True, slots=True,
on_setattr=[attr.setters.validate, attr.setters.convert])
class Shell:
"""A shell in a molecular basis representing (generalized) contractions with the same exponents.
Attributes
Expand All @@ -107,10 +113,10 @@ class Shell(NamedTuple):
"""

icenter: int
angmoms: List[int]
kinds: List[str]
exponents: np.ndarray
coeffs: np.ndarray
angmoms: List[int] = attr.ib(validator=validate_shape(("coeffs", 1)))
kinds: List[str] = attr.ib(validator=validate_shape(("coeffs", 1)))
exponents: np.ndarray = attr.ib(validator=validate_shape(("coeffs", 0)))
coeffs: np.ndarray = attr.ib(validator=validate_shape(("exponents", 0), ("kinds", 0)))

@property
def nbasis(self) -> int: # noqa: D401
Expand All @@ -136,7 +142,9 @@ def ncon(self) -> int: # noqa: D401
return len(self.angmoms)


class MolecularBasis(NamedTuple):
@attr.s(auto_attribs=True, slots=True,
on_setattr=[attr.setters.validate, attr.setters.convert])
class MolecularBasis:
"""A complete molecular orbital or density basis set.
Attributes
Expand Down Expand Up @@ -184,7 +192,7 @@ class MolecularBasis(NamedTuple):
"""

shells: tuple
shells: List[Shell]
conventions: Dict[str, str]
primitive_normalization: str

Expand All @@ -201,7 +209,7 @@ def get_segmented(self):
shells.append(Shell(shell.icenter, [angmom], [kind],
shell.exponents, coeffs.reshape(-1, 1)))
# pylint: disable=no-member
return self._replace(shells=shells)
return attr.evolve(self, shells=shells)


def convert_convention_shell(conv1: List[str], conv2: List[str], reverse=False) \
Expand Down
2 changes: 1 addition & 1 deletion iodata/formats/chgcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _load_vasp_grid(lit: LineIterator) -> dict:
cube_data[i0, i1, i2] = float(words.pop(0))

cube = Cube(origin=np.zeros(3), axes=cellvecs / shape.reshape(-1, 1),
shape=shape, data=cube_data)
data=cube_data)

return {
'title': title,
Expand Down
2 changes: 1 addition & 1 deletion iodata/formats/cp2klog.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _read_cp2k_uncontracted_obasis(lit: LineIterator) -> MolecularBasis:
# read the exponent
exponent = float(words[-1])
exponents.append(exponent)
coeffs.append([1.0 / _get_cp2k_norm_corrections(angmom, exponent)])
coeffs.append(1.0 / _get_cp2k_norm_corrections(angmom, exponent))
line = next(lit)
# Build the shell
kind = 'c' if angmom < 2 else 'p'
Expand Down
1 change: 1 addition & 0 deletions iodata/formats/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def load_one(lit: LineIterator) -> dict:
"""Do not edit this docstring. It will be overwritten."""
title, atcoords, atnums, cellvecs, cube, atcorenums = _read_cube_header(lit)
_read_cube_data(lit, cube)
del cube["shape"]
return {
'title': title,
'atcoords': atcoords,
Expand Down
3 changes: 2 additions & 1 deletion iodata/formats/molden.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import Tuple, Union, TextIO
import copy

import attr
import numpy as np

from ..basis import (angmom_its, angmom_sti, MolecularBasis, Shell,
Expand Down Expand Up @@ -502,7 +503,7 @@ def _fix_obasis_normalize_contractions(obasis: MolecularBasis) -> MolecularBasis
fixed_shells = []
for shell in obasis.shells:
shell_obasis = MolecularBasis(
[shell._replace(icenter=0)],
[attr.evolve(shell, icenter=0)],
obasis.conventions,
obasis.primitive_normalization
)
Expand Down
Loading

0 comments on commit e0d7f6a

Please sign in to comment.