diff --git a/pykanidm/kanidm/__init__.py b/pykanidm/kanidm/__init__.py index 94e270d2e..c4f9530cc 100644 --- a/pykanidm/kanidm/__init__.py +++ b/pykanidm/kanidm/__init__.py @@ -77,7 +77,7 @@ class KanidmClient: if not isinstance(config_file, Path): config_file = Path(config_file) config_data = load_config(config_file.expanduser().resolve()) - self.config = self.config.parse_obj(config_data) + self.config = self.config.model_validate(config_data) if self.config.uri is None: raise ValueError("Please initialize this with a server URI") @@ -108,7 +108,7 @@ class KanidmClient: ) -> None: """hand it a config dict and it'll configure the client""" try: - self.config.parse_obj(config_data) + self.config.model_validate(config_data) except ValidationError as validation_error: # pylint: disable=raise-missing-from raise ValueError(f"Failed to validate configuration: {validation_error}") @@ -191,7 +191,7 @@ class KanidmClient: "status_code": request.status, } logging.debug(json_lib.dumps(response_input, default=str, indent=4)) - response = ClientResponse.parse_obj(response_input) + response = ClientResponse.model_validate(response_input) return response async def call_get( @@ -239,7 +239,7 @@ class KanidmClient: raise ValueError( f"Missing x-kanidm-auth-session-id header in init auth response: {response.headers}" ) - retval = AuthInitResponse.parse_obj(response.data) + retval = AuthInitResponse.model_validate(response.data) retval.response = response return retval @@ -262,7 +262,7 @@ class KanidmClient: # TODO: mock test for auth_begin raises AuthBeginFailed raise AuthBeginFailed(response.content) - retobject = AuthBeginResponse.parse_obj(response.data) + retobject = AuthBeginResponse.model_validate(response.data) retobject.response = response return response @@ -296,7 +296,7 @@ class KanidmClient: raise AuthMechUnknown(f"No auth mechanisms for {username}") auth_begin = await self.auth_begin(method="password", sessionid=sessionid) # does a little bit of validation - auth_begin_object = AuthBeginResponse.parse_obj(auth_begin.data) + auth_begin_object = AuthBeginResponse.model_validate(auth_begin.data) auth_begin_object.response = auth_begin return await self.auth_step_password(password=password, sessionid=sessionid) @@ -324,7 +324,7 @@ class KanidmClient: logging.debug("Failed to authenticate, response: %s", response.content) raise AuthCredFailed("Failed password authentication!") - result = AuthStepPasswordResponse.parse_obj(response.data) + result = AuthStepPasswordResponse.model_validate(response.data) result.response = response # pull the token out and set it diff --git a/pykanidm/kanidm/radius/__init__.py b/pykanidm/kanidm/radius/__init__.py index 57972ce02..ba66b5cfd 100644 --- a/pykanidm/kanidm/radius/__init__.py +++ b/pykanidm/kanidm/radius/__init__.py @@ -107,7 +107,7 @@ def authorize( tok = None try: loop = asyncio.get_event_loop() - tok = RadiusTokenResponse.parse_obj( + tok = RadiusTokenResponse.model_validate( loop.run_until_complete(_get_radius_token(username=user_id)) ) logging.debug("radius information token: %s", tok) diff --git a/pykanidm/kanidm/tokens.py b/pykanidm/kanidm/tokens.py index 20006639d..57435bed0 100644 --- a/pykanidm/kanidm/tokens.py +++ b/pykanidm/kanidm/tokens.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from authlib.jose import JsonWebSignature # type: ignore -from pydantic import BaseModel +from pydantic import ConfigDict, BaseModel, RootModel from . import TOKEN_PATH @@ -31,11 +31,7 @@ class JWSHeader(BaseModel): alg: str typ: str jwk: JWSHeaderJWK - - class Config: - """Configure the pydantic class""" - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) class JWSPayload(BaseModel): @@ -50,7 +46,7 @@ class JWSPayload(BaseModel): name: str displayname: str spn: str - mail_primary: Optional[str] + mail_primary: Optional[str] = None lim_uidx: bool lim_rmax: int lim_pmax: int @@ -93,13 +89,13 @@ class JWS: padded_header = raw_header + "=" * divmod(len(raw_header), 4)[0] decoded_header = base64.urlsafe_b64decode(padded_header) logging.debug("decoded_header=%s", decoded_header) - header = JWSHeader.parse_obj(json.loads(decoded_header.decode("utf-8"))) + header = JWSHeader.model_validate(json.loads(decoded_header.decode("utf-8"))) logging.debug("header: %s", header) raw_payload = split_raw[1] logging.debug("Parsing payload: %s", raw_payload) padded_payload = raw_payload + "=" * divmod(len(raw_payload), 4)[1] - payload = JWSPayload.parse_raw(base64.urlsafe_b64decode(padded_payload)) + payload = JWSPayload.model_validate_json(base64.urlsafe_b64decode(padded_payload)) raw_signature = split_raw[2] logging.debug("Parsing signature: %s", raw_signature) @@ -109,32 +105,31 @@ class JWS: return header, payload, signature -class TokenStore(BaseModel): - """Represents the user auth tokens, can load them from the user store""" - - __root__: Dict[str, str] = {} +class TokenStore(RootModel[Dict[str, str]]): + """Represents the user auth tokens, so we can load them from the user store""" + root: Dict[str, str] # TODO: one day work out how to type the __iter__ on TokenStore properly. It's some kind of iter() that makes mypy unhappy. def __iter__(self) -> Any: """overloading the default function""" - for key in self.__root__.keys(): + for key in self.root.keys(): yield key def __getitem__(self, item: str) -> str: """overloading the default function""" - return self.__root__[item] + return self.root[item] def __delitem__(self, item: str) -> None: """overloading the default function""" - del self.__root__[item] + del self.root[item] def __setitem__(self, key: str, value: str) -> None: """overloading the default function""" - self.__root__[key] = value + self.root[key] = value def save(self, filepath: Path = TOKEN_PATH) -> None: """saves the cached tokens to disk""" - data = json.dumps(self.__root__, indent=2) + data = json.dumps(self.root, indent=2) with filepath.expanduser().resolve().open( mode="w", encoding="utf-8" ) as file_handle: @@ -157,19 +152,19 @@ class TokenStore(BaseModel): tokens = json.load(file_handle) if overwrite: - self.__root__ = tokens + self.root = tokens else: for user in tokens: - self.__root__[user] = tokens[user] + self.root[user] = tokens[user] self.validate_tokens() logging.debug(json.dumps(tokens, indent=4)) - return self.__root__ + return self.root def validate_tokens(self) -> None: """validates the JWS tokens for format, not their signature - PRs welcome""" - for username in self.__root__: + for username in self.root: logging.debug("Parsing %s", username) # TODO: Work out how to get the validation working. We probably shouldn't be worried about this since we're using it for auth... logging.debug( @@ -184,4 +179,4 @@ class TokenStore(BaseModel): s=self[username], key=None ) logging.debug(parsed_object) - return JWSPayload.parse_raw(parsed_object.payload) + return JWSPayload.model_validate_json(parsed_object.payload) diff --git a/pykanidm/kanidm/types.py b/pykanidm/kanidm/types.py index 362709c24..57afbe0a2 100644 --- a/pykanidm/kanidm/types.py +++ b/pykanidm/kanidm/types.py @@ -7,7 +7,7 @@ import socket from typing import Any, Dict, List, Optional from urllib.parse import urlparse -from pydantic import BaseModel, Field, validator +from pydantic import field_validator, ConfigDict, BaseModel, Field import toml @@ -19,15 +19,11 @@ class ClientResponse(BaseModel): status_code: int """ - content: Optional[str] - data: Optional[Dict[str, Any]] + content: Optional[str] = None + data: Optional[Dict[str, Any]] = None headers: Dict[str, Any] status_code: int - - class Config: - """Configuration""" - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) class AuthInitResponse(BaseModel): @@ -35,18 +31,13 @@ class AuthInitResponse(BaseModel): class _AuthInitState(BaseModel): """sub-class for the AuthInitResponse model""" - # TODO: can we add validation for AuthInitResponse.state.choose? choose: List[str] sessionid: str state: _AuthInitState - response: Optional[ClientResponse] - - class Config: - """config class""" - - arbitrary_types_allowed = True + response: Optional[ClientResponse] = None + # model_config = ConfigDict(arbitrary_types_allowed=True) class AuthBeginResponse(BaseModel): @@ -65,11 +56,7 @@ class AuthBeginResponse(BaseModel): sessionid: str state: _AuthBeginState response: Optional[ClientResponse] - - class Config: - """config class""" - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) class AuthStepPasswordResponse(BaseModel): @@ -78,16 +65,12 @@ class AuthStepPasswordResponse(BaseModel): class _AuthStepPasswordState(BaseModel): """subclass to help parse the response from the auth 'step password' stage""" - success: Optional[str] + success: Optional[str] = None sessionid: str state: _AuthStepPasswordState response: Optional[ClientResponse] - - class Config: - """config class""" - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) class RadiusGroup(BaseModel): @@ -96,7 +79,8 @@ class RadiusGroup(BaseModel): spn: str vlan: int - @validator("vlan") + @field_validator("vlan") + @classmethod def validate_vlan(cls, value: int) -> int: """validate the vlan option is above 0""" if not value > 0: @@ -120,11 +104,7 @@ class RadiusTokenResponse(BaseModel): uuid: str groups: List[RadiusTokenGroup] - - class Config: - """config for RadiusTokenGroupList""" - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) class RadiusClient(BaseModel): @@ -146,7 +126,8 @@ class RadiusClient(BaseModel): ipaddr: str secret: str # TODO: this should probably be renamed to token - @validator("ipaddr") + @field_validator("ipaddr") + @classmethod def validate_ipaddr(cls, value: str) -> str: """validates the ipaddr field is an IP address, CIDR or valid hostname""" for typedef in (IPv6Network, IPv6Address, IPv4Address, IPv4Network): @@ -198,9 +179,10 @@ class KanidmClientConfig(BaseModel): @classmethod def parse_toml(cls, input_string: str) -> Any: """loads from a string""" - return super().parse_obj(toml.loads(input_string)) + return super().model_validate(toml.loads(input_string)) - @validator("uri") + @field_validator("uri") + @classmethod def validate_uri(cls, value: Optional[str]) -> Optional[str]: """validator for the uri field""" if value is not None: diff --git a/pykanidm/tests/test_authenticate.py b/pykanidm/tests/test_authenticate.py index d3a5140cf..7b02d1883 100644 --- a/pykanidm/tests/test_authenticate.py +++ b/pykanidm/tests/test_authenticate.py @@ -64,7 +64,7 @@ async def test_auth_begin(client_configfile: KanidmClient) -> None: retval["response"] = begin_result - assert AuthBeginResponse.parse_obj(retval) + assert AuthBeginResponse.model_validate(retval) @pytest.mark.network @@ -170,7 +170,7 @@ async def test_authenticate_with_token(client_configfile: KanidmClient) -> None: f"Using username {test_username} by default - set KANIDM_TEST_USERNAME env var if you want to change this." ) - tokens = TokenStore() + tokens = TokenStore.model_validate({}) tokens.load() if test_username not in tokens: diff --git a/pykanidm/tests/test_config_loader.py b/pykanidm/tests/test_config_loader.py index 3d1407337..40560ee55 100644 --- a/pykanidm/tests/test_config_loader.py +++ b/pykanidm/tests/test_config_loader.py @@ -80,7 +80,7 @@ def test_config_invalid_uri() -> None: "uri": "asdfsadfasd", } with pytest.raises(pydantic.ValidationError): - KanidmClientConfig.parse_obj(test_input) + KanidmClientConfig.model_validate(test_input) def test_config_none_uri() -> None: diff --git a/pykanidm/tests/test_jwt.py b/pykanidm/tests/test_jwt.py index 8c85d500d..0fe2ea81a 100644 --- a/pykanidm/tests/test_jwt.py +++ b/pykanidm/tests/test_jwt.py @@ -7,7 +7,7 @@ import pytest from kanidm.tokens import JWS, TokenStore -TEST_TOKEN = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiZjExOTg2NzMtNGI5MC00NjE4LWJkZTctMTBiY2M2YzhjOGE0IiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwyODM2Niw4MDI1MjUwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.cln3gRV3NdgbGqYeD26mBSHFGOaFXak2UA5umvj_Xw30dMS8ECTnJU7lvLyepRTW_VzqUJHbRatPkQ1TEuK99Q" # noqa: E501 pylint: disable=line-too-long +TEST_TOKEN = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiZjExOTg2NzMtNGI5MC00NjE4LWJkZTctMTBiY2M2YzhjOGE0IiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwyODM2Niw4MDI1MjUwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.cln3gRV3NdgbGqYeD26mBSHFGOaFXak2UA5umvj_Xw30dMS8ECTnJU7lvLyepRTW_VzqUJHbRatPkQ1TEuK99Q" # noqa: E501 pylint: disable=line-too-long def test_jws_parser() -> None: @@ -43,16 +43,16 @@ def test_jws_parser() -> None: test_jws = JWS(TEST_TOKEN) - assert test_jws.header.dict() == expected_header - assert test_jws.payload.dict() == expected_payload + assert test_jws.header.model_dump() == expected_header + assert test_jws.payload.model_dump() == expected_payload def test_tokenstuff() -> None: """tests stuff""" - token_store = TokenStore() + token_store = TokenStore.model_validate({}) token_store[ "idm_admin" - ] = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiMTBmZDJjYzMtM2UxZS00MjM1LTk4NjEtNWQyNjQ3NTAyMmVkIiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwzMzkyMywyOTQyNTQwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.rq1y7YNS9iCBWMmAu-FSa4-o4jrSSnMO_18zafgvLRtZFlB7j-Q68CzxceNN9C_1EWnc9uf4fOyeaSNUwGyaIQ" # noqa: E501 pylint: disable=line-too-long + ] = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiMTBmZDJjYzMtM2UxZS00MjM1LTk4NjEtNWQyNjQ3NTAyMmVkIiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwzMzkyMywyOTQyNTQwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.rq1y7YNS9iCBWMmAu-FSa4-o4jrSSnMO_18zafgvLRtZFlB7j-Q68CzxceNN9C_1EWnc9uf4fOyeaSNUwGyaIQ" # noqa: E501 pylint: disable=line-too-long info = token_store.token_info("idm_admin") print(f"Parsed token: {info}") diff --git a/pykanidm/tests/test_radius_config.py b/pykanidm/tests/test_radius_config.py index 2a2497be2..4ef260c69 100644 --- a/pykanidm/tests/test_radius_config.py +++ b/pykanidm/tests/test_radius_config.py @@ -19,7 +19,7 @@ def test_load_config_file() -> None: print("Can't find client config file", file=sys.stderr) pytest.skip() config = load_config(EXAMPLE_CONFIG_FILE) - kanidm_config = KanidmClientConfig.parse_obj(config) + kanidm_config = KanidmClientConfig.model_validate(config) assert kanidm_config.uri == "https://idm.example.com/" print(f"{kanidm_config.uri=}") print(kanidm_config) @@ -36,7 +36,7 @@ radius_groups = [ """ config_parsed = toml.loads(config_toml) print(config_parsed) - kanidm_config = KanidmClientConfig.parse_obj(config_parsed) + kanidm_config = KanidmClientConfig.model_validate(config_parsed) for group in kanidm_config.radius_groups: print(group.spn) assert group.spn == "hello world" @@ -52,7 +52,7 @@ radius_clients = [ { name = "hello world", ipaddr = "10.0.0.5", secret = "cr4bj0 """ config_parsed = toml.loads(config_toml) print(config_parsed) - kanidm_config = KanidmClientConfig.parse_obj(config_parsed) + kanidm_config = KanidmClientConfig.model_validate(config_parsed) client = kanidm_config.radius_clients[0] print(client.name) assert client.name == "hello world" diff --git a/pykanidm/tests/test_types.py b/pykanidm/tests/test_types.py index 582209164..8555e2c41 100644 --- a/pykanidm/tests/test_types.py +++ b/pykanidm/tests/test_types.py @@ -1,7 +1,7 @@ """ tests types """ import pytest -import pydantic.error_wrappers +from pydantic import ValidationError from kanidm.types import AuthInitResponse, KanidmClientConfig, RadiusGroup, RadiusClient @@ -15,19 +15,19 @@ def test_auth_init_response() -> None: }, } - testval = AuthInitResponse.parse_obj(testobj) + testval = AuthInitResponse.model_validate(testobj) assert testval.sessionid == "crabzrool" def test_radiusgroup_vlan_negative() -> None: """tests RadiusGroup's vlan validator""" - with pytest.raises(pydantic.error_wrappers.ValidationError): + with pytest.raises(ValidationError): RadiusGroup(vlan=-1, spn="crabzrool@foo") def test_radiusgroup_vlan_zero() -> None: """tests RadiusGroup's vlan validator""" - with pytest.raises(pydantic.error_wrappers.ValidationError): + with pytest.raises(ValidationError): RadiusGroup(vlan=0, spn="crabzrool@foo") @@ -38,10 +38,9 @@ def test_radiusgroup_vlan_4096() -> None: def test_radiusgroup_vlan_no_name() -> None: """tests RadiusGroup's vlan validator""" - with pytest.raises( - pydantic.error_wrappers.ValidationError, match="spn\n.*field required" - ): - RadiusGroup(vlan=4096) # type: ignore[call-arg] + with pytest.raises(ValidationError, match="(?i)spn\n.*Field required"): + RadiusGroup(vlan=4096) # type: ignore[call-arg] + def test_kanidmconfig_parse_toml() -> None: """tests KanidmClientConfig.parse_toml()""" @@ -53,7 +52,7 @@ def test_kanidmconfig_parse_toml() -> None: @pytest.mark.network def test_radius_client_bad_hostname() -> None: """tests with a bad hostname""" - with pytest.raises(pydantic.error_wrappers.ValidationError): + with pytest.raises(ValidationError): RadiusClient( name="test", ipaddr="thiscannotpossiblywork.kanidm.example.com",