Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[9.0] Fix pydantic and pkg_resources deprecations #7662

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ install_requires =
psutil
pyasn1
pyasn1-modules
pydantic
pydantic >=2.4
pyparsing
python-dateutil
pytz
Expand Down
2 changes: 1 addition & 1 deletion src/DIRAC/Core/Utilities/test/Test_JDL.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_jdlToBaseJobDescriptionModel_valid(jdl_monkey_business):
res = jdlToBaseJobDescriptionModel(ClassAd(jdl))
assert res["OK"], res["Message"]

data = res["Value"].dict()
data = res["Value"].model_dump()
assert JobDescriptionModel(owner="owner", ownerGroup="ownerGroup", vo="lhcb", **data)


Expand Down
2 changes: 1 addition & 1 deletion src/DIRAC/Resources/Computing/BatchSystems/SLURM.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _generateSrunWrapper(self, executableFile):
content = f.read()

# Need to escape environment variables of the executable file
content = re.sub("\$", "\\$", content)
content = re.sub(r"\$", r"\\$", content)

# Build the script to run the executable in parallel multiple times
# - Embed the content of executableFile inside the parallel library wrapper script
Expand Down
151 changes: 72 additions & 79 deletions src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
# pylint: disable=no-self-argument, no-self-use, invalid-name, missing-function-docstring

from collections.abc import Iterable
from typing import Any, Annotated
from typing import Any, Annotated, TypeAlias, Self

import pydantic
from packaging.version import Version
from pydantic import BaseModel, root_validator, validator
from pydantic import BaseModel, BeforeValidator, model_validator, field_validator, ConfigDict

from DIRAC import gLogger
from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
Expand All @@ -16,71 +14,71 @@

# HACK: Convert appropriate iterables into sets
def default_set_validator(value):
if not isinstance(value, Iterable):
if value is None:
return set()
elif not isinstance(value, Iterable):
return value
elif isinstance(value, (str, bytes, bytearray)):
return value
else:
return set(value)


if Version(pydantic.__version__) > Version("2.0.0a0"):
CoercibleSetStr = Annotated[
set[str] | None, pydantic.BeforeValidator(default_set_validator) # pylint: disable=no-member
]
else:
CoercibleSetStr = set[str]
CoercibleSetStr: TypeAlias = Annotated[set[str], BeforeValidator(default_set_validator)]


class BaseJobDescriptionModel(BaseModel):
"""Base model for the job description (not parametric)"""

class Config:
validate_assignment = True
model_config = ConfigDict(validate_assignment=True)

arguments: str = None
bannedSites: CoercibleSetStr = None
arguments: str = ""
bannedSites: CoercibleSetStr = set()
# TODO: This should use a field factory
cpuTime: int = Operations().getValue("JobDescription/DefaultCPUTime", 86400)
executable: str
executionEnvironment: dict = None
gridCE: str = None
inputSandbox: CoercibleSetStr = None
inputData: CoercibleSetStr = None
inputDataPolicy: str = None
jobConfigArgs: str = None
jobGroup: str = None
gridCE: str = ""
inputSandbox: CoercibleSetStr = set()
inputData: CoercibleSetStr = set()
inputDataPolicy: str = ""
jobConfigArgs: str = ""
jobGroup: str = ""
jobType: str = "User"
jobName: str = "Name"
# TODO: This should be an StrEnum
logLevel: str = "INFO"
# TODO: This can't be None with this type hint
maxNumberOfProcessors: int = None
minNumberOfProcessors: int = 1
outputData: CoercibleSetStr = None
outputPath: str = None
outputSandbox: CoercibleSetStr = None
outputSE: str = None
platform: str = None
outputData: CoercibleSetStr = set()
outputPath: str = ""
outputSandbox: CoercibleSetStr = set()
outputSE: str = ""
platform: str = ""
# TODO: This should use a field factory
priority: int = Operations().getValue("JobDescription/DefaultPriority", 1)
sites: CoercibleSetStr = None
sites: CoercibleSetStr = set()
stderr: str = "std.err"
stdout: str = "std.out"
tags: CoercibleSetStr = None
extraFields: dict[str, Any] = None
tags: CoercibleSetStr = set()
extraFields: dict[str, Any] = {}

@validator("cpuTime")
@field_validator("cpuTime")
def checkCPUTimeBounds(cls, v):
minCPUTime = Operations().getValue("JobDescription/MinCPUTime", 100)
maxCPUTime = Operations().getValue("JobDescription/MaxCPUTime", 500000)
if not minCPUTime <= v <= maxCPUTime:
raise ValueError(f"cpuTime out of bounds (must be between {minCPUTime} and {maxCPUTime})")
return v

@validator("executable")
@field_validator("executable")
def checkExecutableIsNotAnEmptyString(cls, v: str):
if not v:
raise ValueError("executable must not be an empty string")
return v

@validator("jobType")
@field_validator("jobType")
def checkJobTypeIsAllowed(cls, v: str):
jobTypes = Operations().getValue("JobDescription/AllowedJobTypes", ["User", "Test", "Hospital"])
transformationTypes = Operations().getValue("Transformations/DataProcessing", [])
Expand All @@ -89,15 +87,15 @@ def checkJobTypeIsAllowed(cls, v: str):
raise ValueError(f"jobType '{v}' is not allowed for this kind of user (must be in {allowedTypes})")
return v

@validator("inputData")
@field_validator("inputData")
def checkInputDataDoesntContainDoubleSlashes(cls, v):
if v:
for lfn in v:
if lfn.find("//") > -1:
raise ValueError("Input data contains //")
return v

@validator("inputData")
@field_validator("inputData")
def addLFNPrefixIfStringStartsWithASlash(cls, v: set[str]):
if v:
v = {lfn.strip() for lfn in v if lfn.strip()}
Expand All @@ -108,30 +106,30 @@ def addLFNPrefixIfStringStartsWithASlash(cls, v: set[str]):
raise ValueError("Input data files must start with LFN:/")
return v

@root_validator(skip_on_failure=True)
def checkNumberOfInputDataFiles(cls, values):
if "inputData" in values and values["inputData"]:
@model_validator(mode="after")
def checkNumberOfInputDataFiles(self) -> Self:
if self.inputData:
maxInputDataFiles = Operations().getValue("JobDescription/MaxInputData", 500)
if values["jobType"] == "User" and len(values["inputData"]) >= maxInputDataFiles:
if self.jobType == "User" and len(self.inputData) >= maxInputDataFiles:
raise ValueError(f"inputData contains too many files (must contain at most {maxInputDataFiles})")
return values
return self

@validator("inputSandbox")
@field_validator("inputSandbox")
def checkLFNSandboxesAreWellFormated(cls, v: set[str]):
for inputSandbox in v:
if inputSandbox.startswith("LFN:") and not inputSandbox.startswith("LFN:/"):
raise ValueError("LFN files must start by LFN:/")
return v

@validator("logLevel")
@field_validator("logLevel")
def checkLogLevelIsValid(cls, v: str):
v = v.upper()
possibleLogLevels = gLogger.getAllPossibleLevels()
if v not in possibleLogLevels:
raise ValueError(f"Log level {v} not in {possibleLogLevels}")
return v

@validator("minNumberOfProcessors")
@field_validator("minNumberOfProcessors")
def checkMinNumberOfProcessorsBounds(cls, v):
minNumberOfProcessors = Operations().getValue("JobDescription/MinNumberOfProcessors", 1)
maxNumberOfProcessors = Operations().getValue("JobDescription/MaxNumberOfProcessors", 1024)
Expand All @@ -141,7 +139,7 @@ def checkMinNumberOfProcessorsBounds(cls, v):
)
return v

@validator("maxNumberOfProcessors")
@field_validator("maxNumberOfProcessors")
def checkMaxNumberOfProcessorsBounds(cls, v):
minNumberOfProcessors = Operations().getValue("JobDescription/MinNumberOfProcessors", 1)
maxNumberOfProcessors = Operations().getValue("JobDescription/MaxNumberOfProcessors", 1024)
Expand All @@ -151,27 +149,22 @@ def checkMaxNumberOfProcessorsBounds(cls, v):
)
return v

@root_validator(skip_on_failure=True)
def checkThatMaxNumberOfProcessorsIsGreaterThanMinNumberOfProcessors(cls, values):
if "maxNumberOfProcessors" in values and values["maxNumberOfProcessors"]:
if values["maxNumberOfProcessors"] < values["minNumberOfProcessors"]:
@model_validator(mode="after")
def checkThatMaxNumberOfProcessorsIsGreaterThanMinNumberOfProcessors(self) -> Self:
if self.maxNumberOfProcessors:
if self.maxNumberOfProcessors < self.minNumberOfProcessors:
raise ValueError("maxNumberOfProcessors must be greater than minNumberOfProcessors")
return values

@root_validator(skip_on_failure=True)
def addTagsDependingOnNumberOfProcessors(cls, values):
if "maxNumberOfProcessors" in values and values["minNumberOfProcessors"] == values["maxNumberOfProcessors"]:
if values["tags"] is None:
values["tags"] = set()
values["tags"].add(f"{values['minNumberOfProcessors']}Processors")
if values["minNumberOfProcessors"] > 1:
if values["tags"] is None:
values["tags"] = set()
values["tags"].add("MultiProcessor")

return values

@validator("sites")
return self

@model_validator(mode="after")
def addTagsDependingOnNumberOfProcessors(self) -> Self:
if self.minNumberOfProcessors == self.maxNumberOfProcessors:
self.tags.add(f"{self.minNumberOfProcessors}Processors")
if self.minNumberOfProcessors > 1:
self.tags.add("MultiProcessor")
return self

@field_validator("sites")
def checkSites(cls, v: set[str]):
if v:
res = getSites()
Expand All @@ -182,16 +175,16 @@ def checkSites(cls, v: set[str]):
raise ValueError(f"Invalid sites: {' '.join(invalidSites)}")
return v

@root_validator(skip_on_failure=True)
def checkThatSitesAndBannedSitesAreNotMutuallyExclusive(cls, values):
if "sites" in values and values["sites"] and "bannedSites" in values and values["bannedSites"]:
values["sites"] -= values["bannedSites"]
values["bannedSites"] = None
if not values["sites"]:
@model_validator(mode="after")
def checkThatSitesAndBannedSitesAreNotMutuallyExclusive(self) -> Self:
if self.sites and self.bannedSites:
while self.bannedSites:
self.sites.discard(self.bannedSites.pop())
if not self.sites:
raise ValueError("sites and bannedSites are mutually exclusive")
return values
return self

@validator("platform")
@field_validator("platform")
def checkPlatform(cls, v: str):
if v:
res = getDIRACPlatforms()
Expand All @@ -201,7 +194,7 @@ def checkPlatform(cls, v: str):
raise ValueError("Invalid platform")
return v

@validator("priority")
@field_validator("priority")
def checkPriorityBounds(cls, v):
minPriority = Operations().getValue("JobDescription/MinPriority", 0)
maxPriority = Operations().getValue("JobDescription/MaxPriority", 10)
Expand All @@ -217,10 +210,10 @@ class JobDescriptionModel(BaseJobDescriptionModel):
ownerGroup: str
vo: str

@root_validator(skip_on_failure=True)
def checkLFNMatchesREGEX(cls, values):
if "inputData" in values and values["inputData"]:
for lfn in values["inputData"]:
if not lfn.startswith(f"LFN:/{values['vo']}/"):
raise ValueError(f"Input data not correctly specified (must start with LFN:/{values['vo']}/)")
return values
@model_validator(mode="after")
def checkLFNMatchesREGEX(self) -> Self:
if self.inputData:
for lfn in self.inputData:
if not lfn.startswith(f"LFN:/{self.vo}/"):
raise ValueError(f"Input data not correctly specified (must start with LFN:/{self.vo}/)")
return self
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,9 @@ def test_sitesValidator_invalid(validSites, selectedSites):
@pytest.mark.parametrize(
"sites, bannedSites, parsedSites, parsedBannedSites",
[
({"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, None, {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, None),
(None, {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, None, {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}),
({"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, {"LCG.PIC.es", "LCG.CNAF.it"}, {"LCG.IN2P3.fr"}, None),
({"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, None, {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, set()),
(None, {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, set(), {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}),
({"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, {"LCG.PIC.es", "LCG.CNAF.it"}, {"LCG.IN2P3.fr"}, set()),
],
)
def test_checkThatSitesAndBannedSitesAreNotMutuallyExclusive_valid(
Expand Down
6 changes: 3 additions & 3 deletions src/DIRAC/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@

"""
import os
import importlib.metadata
import re
import sys
import warnings
from pkgutil import extend_path
from typing import Any, Optional, Union
from pkg_resources import get_distribution, DistributionNotFound


__path__ = extend_path(__path__, __name__)
Expand All @@ -81,9 +81,9 @@

# Define Version
try:
__version__ = get_distribution(__name__).version
__version__ = importlib.metadata.version(__name__)
version = __version__
except DistributionNotFound:
except importlib.metadata.PackageNotFoundError:
# package is not installed
version = "Unknown"

Expand Down
Loading