Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix array comparison in core.Structure.merge_sites, also allow int property to be merged instead of float alone, mode only allow full name #4198

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions src/pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4717,43 +4717,56 @@ def scale_lattice(self, volume: float) -> Self:
return self

def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average"] = "sum") -> Self:
"""Merges sites (adding occupancies) within tol of each other.
Removes site properties.
"""Merges sites (by adding occupancies) within tolerance and optionally removes
site properties in "sum/delete" modes.

Args:
tol (float): Tolerance for distance to merge sites.
mode ("sum" | "delete" | "average"): "delete" means duplicate sites are
deleted. "sum" means the occupancies are summed for the sites.
"average" means that the site is deleted but the properties are averaged
Only first letter is considered.
mode ("sum" | "delete" | "average"): Only first letter is considered at this moment.
- "delete" means duplicate sites are deleted.
- "sum" means the occupancies are summed for the sites.
- "average" means that the site is deleted but the properties are averaged.

Returns:
Structure: self with merged sites.
Structure: Structure with merged sites.
"""
dist_mat = self.distance_matrix
# TODO: change the code the allow full name after 2025-12-01
# TODO2: add a test for mode value, currently it only checks if first letter is "s/a"
if mode.lower() not in {"sum", "delete", "average"} and mode.lower()[0] in {"s", "d", "a"}:
Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Want to hear more opinions on this, I guess it's beneficial to allow only full name ("sum", "delete", "average") instead of checking first letter only to make use of the IDE auto-complete feature and facilitate typing

Also using the full name would be more readable: mode="sum" instead of mode="s"

warnings.warn(
"mode would only allow full name sum/delete/average after 2025-12-01", DeprecationWarning, stacklevel=2
)

if mode.lower()[0] not in {"s", "d", "a"}:
raise ValueError(f"Illegal {mode=}, should start with a/d/s.")

dist_mat: NDArray = self.distance_matrix
np.fill_diagonal(dist_mat, 0)
clusters = fcluster(linkage(squareform((dist_mat + dist_mat.T) / 2)), tol, "distance")
sites = []

sites: list[PeriodicSite] = []
for cluster in np.unique(clusters):
inds = np.where(clusters == cluster)[0]
species = self[inds[0]].species
coords = self[inds[0]].frac_coords
props = self[inds[0]].properties

for n, i in enumerate(inds[1:]):
sp = self[i].species
if mode.lower()[0] == "s":
species += sp
offset = self[i].frac_coords - coords
coords += ((offset - np.round(offset)) / (n + 2)).astype(coords.dtype)
for key in props:
if props[key] is not None and self[i].properties[key] != props[key]:
Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO tag: Using array_equal may not be a good idea either as the property could be (sequence of) floats

TODO2: test failure

# Test that we can average the site properties that are floats
lattice = Lattice.hexagonal(3.587776, 19.622793)
species = ["Na", "V", "S", "S"]
coords = [
[0.333333, 0.666667, 0.165000],
[0, 0, 0.998333],
[0.333333, 0.666667, 0.399394],
[0.666667, 0.333333, 0.597273],
]
site_props = {"prop1": [3.0, 5.0, 7.0, 11.0]}
navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props)
navs2.insert(0, "Na", coords[0], properties={"prop1": 100.0})
navs2.merge_sites(mode="a")
assert len(navs2) == 12
assert 51.5 in [itr.properties["prop1"] for itr in navs2]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Property can actually be anything, including value supplied by user.

Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for stepping in, I'm here adding a note for myself to work on later.

Yes in this case using == for comparison may not be ideal when the value could be float or np.array, this is very similar to #4092 where we want to compare the equality of two dict whose value could be np.array

But I assume we would not be able to use np.allclose as the value could be non-numerical, so as far as I'm aware, using array_equal would be the best approach (though handling of float would be sub-optimal)

if props[key] is not None and not np.array_equal(self[i].properties[key], props[key]):
if mode.lower()[0] == "a" and isinstance(props[key], float):
Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess only allowing float to be averaged may not be a good idea, we should include int as well, otherwise properties={"prop1": 100} would lead to the site property being reset to None while properties={"prop1": 100.0} doesn't, which is super confusing

navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props)
navs2.insert(0, "Na", coords[0], properties={"prop1": 100.0})
navs2.merge_sites(mode="a")

Also the docstring would be clarified to explain this behaviour.

# update a running total
props[key] = props[key] * (n + 1) / (n + 2) + self[i].properties[key] / (n + 2)
else:
props[key] = None
warnings.warn(
f"Sites with different site property {key} are merged. So property is set to none"
f"Sites with different site property {key} are merged. So property is set to none",
stacklevel=2,
)
sites.append(PeriodicSite(species, coords, self.lattice, properties=props))

Expand Down
12 changes: 8 additions & 4 deletions tests/core/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,11 +1633,15 @@ def test_merge_sites(self):
[0.5, 0.5, 1.501],
]
struct = Structure(Lattice.cubic(1), species, coords)
struct.merge_sites(mode="s")
struct.merge_sites(mode="sum")
assert struct[0].specie.symbol == "Ag"
assert struct[1].species == Composition({"Cl": 0.35, "F": 0.25})
assert_allclose(struct[1].frac_coords, [0.5, 0.5, 0.5005])

# Test illegal mode
with pytest.raises(ValueError, match="Illegal mode='illegal', should start with a/d/s"):
struct.merge_sites(mode="illegal")

# Test for TaS2 with spacegroup 166 in 160 setting.
lattice = Lattice.hexagonal(3.374351, 20.308941)
species = ["Ta", "S", "S"]
Expand All @@ -1648,7 +1652,7 @@ def test_merge_sites(self):
]
tas2 = Structure.from_spacegroup(160, lattice, species, coords)
assert len(tas2) == 13
tas2.merge_sites(mode="d")
tas2.merge_sites(mode="delete")
assert len(tas2) == 9

lattice = Lattice.hexagonal(3.587776, 19.622793)
Expand All @@ -1661,7 +1665,7 @@ def test_merge_sites(self):
]
navs2 = Structure.from_spacegroup(160, lattice, species, coords)
assert len(navs2) == 18
navs2.merge_sites(mode="d")
navs2.merge_sites(mode="delete")
assert len(navs2) == 12

# Test that we can average the site properties that are floats
Expand All @@ -1676,7 +1680,7 @@ def test_merge_sites(self):
site_props = {"prop1": [3.0, 5.0, 7.0, 11.0]}
navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props)
navs2.insert(0, "Na", coords[0], properties={"prop1": 100.0})
navs2.merge_sites(mode="a")
navs2.merge_sites(mode="average")
assert len(navs2) == 12
assert 51.5 in [itr.properties["prop1"] for itr in navs2]

Expand Down
24 changes: 18 additions & 6 deletions tests/util/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@ def test_nested_arrays(self):

def test_diff_dtype(self):
"""Make sure it also works for other data types as value."""

@dataclass
class CustomClass:
name: str
value: int

# Test with bool values
dict1 = {"a": True}
dict2 = {"a": True}
Expand All @@ -69,13 +63,31 @@ class CustomClass:
assert not is_np_dict_equal(dict4, dict6)

# Test with a custom data class
@dataclass
class CustomClass:
name: str
value: int

dict7 = {"a": CustomClass(name="test", value=1)}
dict8 = {"a": CustomClass(name="test", value=1)}
assert is_np_dict_equal(dict7, dict8)

dict9 = {"a": CustomClass(name="test", value=2)}
assert not is_np_dict_equal(dict7, dict9)

# Test __eq__ method being used
@dataclass
class NewCustomClass:
name: str
value: int

def __eq__(self, other):
return True

dict7_1 = {"a": NewCustomClass(name="test", value=1)}
dict8_1 = {"a": NewCustomClass(name="hello", value=2)}
assert is_np_dict_equal(dict7_1, dict8_1)

# Test with None
dict10 = {"a": None}
dict11 = {"a": None}
Expand Down
Loading