fixing up pydantic things (#1885)

This commit is contained in:
James Hodgkinson 2023-07-24 00:09:43 -07:00 committed by GitHub
parent e17dcc0ddb
commit 8981478d76
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 62 additions and 86 deletions

View file

@ -77,7 +77,7 @@ class KanidmClient:
if not isinstance(config_file, Path): if not isinstance(config_file, Path):
config_file = Path(config_file) config_file = Path(config_file)
config_data = load_config(config_file.expanduser().resolve()) config_data = load_config(config_file.expanduser().resolve())
self.config = self.config.parse_obj(config_data) self.config = self.config.model_validate(config_data)
if self.config.uri is None: if self.config.uri is None:
raise ValueError("Please initialize this with a server URI") raise ValueError("Please initialize this with a server URI")
@ -108,7 +108,7 @@ class KanidmClient:
) -> None: ) -> None:
"""hand it a config dict and it'll configure the client""" """hand it a config dict and it'll configure the client"""
try: try:
self.config.parse_obj(config_data) self.config.model_validate(config_data)
except ValidationError as validation_error: except ValidationError as validation_error:
# pylint: disable=raise-missing-from # pylint: disable=raise-missing-from
raise ValueError(f"Failed to validate configuration: {validation_error}") raise ValueError(f"Failed to validate configuration: {validation_error}")
@ -191,7 +191,7 @@ class KanidmClient:
"status_code": request.status, "status_code": request.status,
} }
logging.debug(json_lib.dumps(response_input, default=str, indent=4)) logging.debug(json_lib.dumps(response_input, default=str, indent=4))
response = ClientResponse.parse_obj(response_input) response = ClientResponse.model_validate(response_input)
return response return response
async def call_get( async def call_get(
@ -239,7 +239,7 @@ class KanidmClient:
raise ValueError( raise ValueError(
f"Missing x-kanidm-auth-session-id header in init auth response: {response.headers}" f"Missing x-kanidm-auth-session-id header in init auth response: {response.headers}"
) )
retval = AuthInitResponse.parse_obj(response.data) retval = AuthInitResponse.model_validate(response.data)
retval.response = response retval.response = response
return retval return retval
@ -262,7 +262,7 @@ class KanidmClient:
# TODO: mock test for auth_begin raises AuthBeginFailed # TODO: mock test for auth_begin raises AuthBeginFailed
raise AuthBeginFailed(response.content) raise AuthBeginFailed(response.content)
retobject = AuthBeginResponse.parse_obj(response.data) retobject = AuthBeginResponse.model_validate(response.data)
retobject.response = response retobject.response = response
return response return response
@ -296,7 +296,7 @@ class KanidmClient:
raise AuthMechUnknown(f"No auth mechanisms for {username}") raise AuthMechUnknown(f"No auth mechanisms for {username}")
auth_begin = await self.auth_begin(method="password", sessionid=sessionid) auth_begin = await self.auth_begin(method="password", sessionid=sessionid)
# does a little bit of validation # does a little bit of validation
auth_begin_object = AuthBeginResponse.parse_obj(auth_begin.data) auth_begin_object = AuthBeginResponse.model_validate(auth_begin.data)
auth_begin_object.response = auth_begin auth_begin_object.response = auth_begin
return await self.auth_step_password(password=password, sessionid=sessionid) return await self.auth_step_password(password=password, sessionid=sessionid)
@ -324,7 +324,7 @@ class KanidmClient:
logging.debug("Failed to authenticate, response: %s", response.content) logging.debug("Failed to authenticate, response: %s", response.content)
raise AuthCredFailed("Failed password authentication!") raise AuthCredFailed("Failed password authentication!")
result = AuthStepPasswordResponse.parse_obj(response.data) result = AuthStepPasswordResponse.model_validate(response.data)
result.response = response result.response = response
# pull the token out and set it # pull the token out and set it

View file

@ -107,7 +107,7 @@ def authorize(
tok = None tok = None
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
tok = RadiusTokenResponse.parse_obj( tok = RadiusTokenResponse.model_validate(
loop.run_until_complete(_get_radius_token(username=user_id)) loop.run_until_complete(_get_radius_token(username=user_id))
) )
logging.debug("radius information token: %s", tok) logging.debug("radius information token: %s", tok)

View file

@ -10,7 +10,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from authlib.jose import JsonWebSignature # type: ignore from authlib.jose import JsonWebSignature # type: ignore
from pydantic import BaseModel from pydantic import ConfigDict, BaseModel, RootModel
from . import TOKEN_PATH from . import TOKEN_PATH
@ -31,11 +31,7 @@ class JWSHeader(BaseModel):
alg: str alg: str
typ: str typ: str
jwk: JWSHeaderJWK jwk: JWSHeaderJWK
model_config = ConfigDict(arbitrary_types_allowed=True)
class Config:
"""Configure the pydantic class"""
arbitrary_types_allowed = True
class JWSPayload(BaseModel): class JWSPayload(BaseModel):
@ -50,7 +46,7 @@ class JWSPayload(BaseModel):
name: str name: str
displayname: str displayname: str
spn: str spn: str
mail_primary: Optional[str] mail_primary: Optional[str] = None
lim_uidx: bool lim_uidx: bool
lim_rmax: int lim_rmax: int
lim_pmax: int lim_pmax: int
@ -93,13 +89,13 @@ class JWS:
padded_header = raw_header + "=" * divmod(len(raw_header), 4)[0] padded_header = raw_header + "=" * divmod(len(raw_header), 4)[0]
decoded_header = base64.urlsafe_b64decode(padded_header) decoded_header = base64.urlsafe_b64decode(padded_header)
logging.debug("decoded_header=%s", decoded_header) logging.debug("decoded_header=%s", decoded_header)
header = JWSHeader.parse_obj(json.loads(decoded_header.decode("utf-8"))) header = JWSHeader.model_validate(json.loads(decoded_header.decode("utf-8")))
logging.debug("header: %s", header) logging.debug("header: %s", header)
raw_payload = split_raw[1] raw_payload = split_raw[1]
logging.debug("Parsing payload: %s", raw_payload) logging.debug("Parsing payload: %s", raw_payload)
padded_payload = raw_payload + "=" * divmod(len(raw_payload), 4)[1] padded_payload = raw_payload + "=" * divmod(len(raw_payload), 4)[1]
payload = JWSPayload.parse_raw(base64.urlsafe_b64decode(padded_payload)) payload = JWSPayload.model_validate_json(base64.urlsafe_b64decode(padded_payload))
raw_signature = split_raw[2] raw_signature = split_raw[2]
logging.debug("Parsing signature: %s", raw_signature) logging.debug("Parsing signature: %s", raw_signature)
@ -109,32 +105,31 @@ class JWS:
return header, payload, signature return header, payload, signature
class TokenStore(BaseModel): class TokenStore(RootModel[Dict[str, str]]):
"""Represents the user auth tokens, can load them from the user store""" """Represents the user auth tokens, so we can load them from the user store"""
root: Dict[str, str]
__root__: Dict[str, str] = {}
# TODO: one day work out how to type the __iter__ on TokenStore properly. It's some kind of iter() that makes mypy unhappy. # TODO: one day work out how to type the __iter__ on TokenStore properly. It's some kind of iter() that makes mypy unhappy.
def __iter__(self) -> Any: def __iter__(self) -> Any:
"""overloading the default function""" """overloading the default function"""
for key in self.__root__.keys(): for key in self.root.keys():
yield key yield key
def __getitem__(self, item: str) -> str: def __getitem__(self, item: str) -> str:
"""overloading the default function""" """overloading the default function"""
return self.__root__[item] return self.root[item]
def __delitem__(self, item: str) -> None: def __delitem__(self, item: str) -> None:
"""overloading the default function""" """overloading the default function"""
del self.__root__[item] del self.root[item]
def __setitem__(self, key: str, value: str) -> None: def __setitem__(self, key: str, value: str) -> None:
"""overloading the default function""" """overloading the default function"""
self.__root__[key] = value self.root[key] = value
def save(self, filepath: Path = TOKEN_PATH) -> None: def save(self, filepath: Path = TOKEN_PATH) -> None:
"""saves the cached tokens to disk""" """saves the cached tokens to disk"""
data = json.dumps(self.__root__, indent=2) data = json.dumps(self.root, indent=2)
with filepath.expanduser().resolve().open( with filepath.expanduser().resolve().open(
mode="w", encoding="utf-8" mode="w", encoding="utf-8"
) as file_handle: ) as file_handle:
@ -157,19 +152,19 @@ class TokenStore(BaseModel):
tokens = json.load(file_handle) tokens = json.load(file_handle)
if overwrite: if overwrite:
self.__root__ = tokens self.root = tokens
else: else:
for user in tokens: for user in tokens:
self.__root__[user] = tokens[user] self.root[user] = tokens[user]
self.validate_tokens() self.validate_tokens()
logging.debug(json.dumps(tokens, indent=4)) logging.debug(json.dumps(tokens, indent=4))
return self.__root__ return self.root
def validate_tokens(self) -> None: def validate_tokens(self) -> None:
"""validates the JWS tokens for format, not their signature - PRs welcome""" """validates the JWS tokens for format, not their signature - PRs welcome"""
for username in self.__root__: for username in self.root:
logging.debug("Parsing %s", username) logging.debug("Parsing %s", username)
# TODO: Work out how to get the validation working. We probably shouldn't be worried about this since we're using it for auth... # TODO: Work out how to get the validation working. We probably shouldn't be worried about this since we're using it for auth...
logging.debug( logging.debug(
@ -184,4 +179,4 @@ class TokenStore(BaseModel):
s=self[username], key=None s=self[username], key=None
) )
logging.debug(parsed_object) logging.debug(parsed_object)
return JWSPayload.parse_raw(parsed_object.payload) return JWSPayload.model_validate_json(parsed_object.payload)

View file

@ -7,7 +7,7 @@ import socket
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from pydantic import BaseModel, Field, validator from pydantic import field_validator, ConfigDict, BaseModel, Field
import toml import toml
@ -19,15 +19,11 @@ class ClientResponse(BaseModel):
status_code: int status_code: int
""" """
content: Optional[str] content: Optional[str] = None
data: Optional[Dict[str, Any]] data: Optional[Dict[str, Any]] = None
headers: Dict[str, Any] headers: Dict[str, Any]
status_code: int status_code: int
model_config = ConfigDict(arbitrary_types_allowed=True)
class Config:
"""Configuration"""
arbitrary_types_allowed = True
class AuthInitResponse(BaseModel): class AuthInitResponse(BaseModel):
@ -35,18 +31,13 @@ class AuthInitResponse(BaseModel):
class _AuthInitState(BaseModel): class _AuthInitState(BaseModel):
"""sub-class for the AuthInitResponse model""" """sub-class for the AuthInitResponse model"""
# TODO: can we add validation for AuthInitResponse.state.choose? # TODO: can we add validation for AuthInitResponse.state.choose?
choose: List[str] choose: List[str]
sessionid: str sessionid: str
state: _AuthInitState state: _AuthInitState
response: Optional[ClientResponse] response: Optional[ClientResponse] = None
# model_config = ConfigDict(arbitrary_types_allowed=True)
class Config:
"""config class"""
arbitrary_types_allowed = True
class AuthBeginResponse(BaseModel): class AuthBeginResponse(BaseModel):
@ -65,11 +56,7 @@ class AuthBeginResponse(BaseModel):
sessionid: str sessionid: str
state: _AuthBeginState state: _AuthBeginState
response: Optional[ClientResponse] response: Optional[ClientResponse]
model_config = ConfigDict(arbitrary_types_allowed=True)
class Config:
"""config class"""
arbitrary_types_allowed = True
class AuthStepPasswordResponse(BaseModel): class AuthStepPasswordResponse(BaseModel):
@ -78,16 +65,12 @@ class AuthStepPasswordResponse(BaseModel):
class _AuthStepPasswordState(BaseModel): class _AuthStepPasswordState(BaseModel):
"""subclass to help parse the response from the auth 'step password' stage""" """subclass to help parse the response from the auth 'step password' stage"""
success: Optional[str] success: Optional[str] = None
sessionid: str sessionid: str
state: _AuthStepPasswordState state: _AuthStepPasswordState
response: Optional[ClientResponse] response: Optional[ClientResponse]
model_config = ConfigDict(arbitrary_types_allowed=True)
class Config:
"""config class"""
arbitrary_types_allowed = True
class RadiusGroup(BaseModel): class RadiusGroup(BaseModel):
@ -96,7 +79,8 @@ class RadiusGroup(BaseModel):
spn: str spn: str
vlan: int vlan: int
@validator("vlan") @field_validator("vlan")
@classmethod
def validate_vlan(cls, value: int) -> int: def validate_vlan(cls, value: int) -> int:
"""validate the vlan option is above 0""" """validate the vlan option is above 0"""
if not value > 0: if not value > 0:
@ -120,11 +104,7 @@ class RadiusTokenResponse(BaseModel):
uuid: str uuid: str
groups: List[RadiusTokenGroup] groups: List[RadiusTokenGroup]
model_config = ConfigDict(arbitrary_types_allowed=True)
class Config:
"""config for RadiusTokenGroupList"""
arbitrary_types_allowed = True
class RadiusClient(BaseModel): class RadiusClient(BaseModel):
@ -146,7 +126,8 @@ class RadiusClient(BaseModel):
ipaddr: str ipaddr: str
secret: str # TODO: this should probably be renamed to token secret: str # TODO: this should probably be renamed to token
@validator("ipaddr") @field_validator("ipaddr")
@classmethod
def validate_ipaddr(cls, value: str) -> str: def validate_ipaddr(cls, value: str) -> str:
"""validates the ipaddr field is an IP address, CIDR or valid hostname""" """validates the ipaddr field is an IP address, CIDR or valid hostname"""
for typedef in (IPv6Network, IPv6Address, IPv4Address, IPv4Network): for typedef in (IPv6Network, IPv6Address, IPv4Address, IPv4Network):
@ -198,9 +179,10 @@ class KanidmClientConfig(BaseModel):
@classmethod @classmethod
def parse_toml(cls, input_string: str) -> Any: def parse_toml(cls, input_string: str) -> Any:
"""loads from a string""" """loads from a string"""
return super().parse_obj(toml.loads(input_string)) return super().model_validate(toml.loads(input_string))
@validator("uri") @field_validator("uri")
@classmethod
def validate_uri(cls, value: Optional[str]) -> Optional[str]: def validate_uri(cls, value: Optional[str]) -> Optional[str]:
"""validator for the uri field""" """validator for the uri field"""
if value is not None: if value is not None:

View file

@ -64,7 +64,7 @@ async def test_auth_begin(client_configfile: KanidmClient) -> None:
retval["response"] = begin_result retval["response"] = begin_result
assert AuthBeginResponse.parse_obj(retval) assert AuthBeginResponse.model_validate(retval)
@pytest.mark.network @pytest.mark.network
@ -170,7 +170,7 @@ async def test_authenticate_with_token(client_configfile: KanidmClient) -> None:
f"Using username {test_username} by default - set KANIDM_TEST_USERNAME env var if you want to change this." f"Using username {test_username} by default - set KANIDM_TEST_USERNAME env var if you want to change this."
) )
tokens = TokenStore() tokens = TokenStore.model_validate({})
tokens.load() tokens.load()
if test_username not in tokens: if test_username not in tokens:

View file

@ -80,7 +80,7 @@ def test_config_invalid_uri() -> None:
"uri": "asdfsadfasd", "uri": "asdfsadfasd",
} }
with pytest.raises(pydantic.ValidationError): with pytest.raises(pydantic.ValidationError):
KanidmClientConfig.parse_obj(test_input) KanidmClientConfig.model_validate(test_input)
def test_config_none_uri() -> None: def test_config_none_uri() -> None:

View file

@ -7,7 +7,7 @@ import pytest
from kanidm.tokens import JWS, TokenStore from kanidm.tokens import JWS, TokenStore
TEST_TOKEN = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiZjExOTg2NzMtNGI5MC00NjE4LWJkZTctMTBiY2M2YzhjOGE0IiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwyODM2Niw4MDI1MjUwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.cln3gRV3NdgbGqYeD26mBSHFGOaFXak2UA5umvj_Xw30dMS8ECTnJU7lvLyepRTW_VzqUJHbRatPkQ1TEuK99Q" # noqa: E501 pylint: disable=line-too-long TEST_TOKEN = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiZjExOTg2NzMtNGI5MC00NjE4LWJkZTctMTBiY2M2YzhjOGE0IiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwyODM2Niw4MDI1MjUwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.cln3gRV3NdgbGqYeD26mBSHFGOaFXak2UA5umvj_Xw30dMS8ECTnJU7lvLyepRTW_VzqUJHbRatPkQ1TEuK99Q" # noqa: E501 pylint: disable=line-too-long
def test_jws_parser() -> None: def test_jws_parser() -> None:
@ -43,16 +43,16 @@ def test_jws_parser() -> None:
test_jws = JWS(TEST_TOKEN) test_jws = JWS(TEST_TOKEN)
assert test_jws.header.dict() == expected_header assert test_jws.header.model_dump() == expected_header
assert test_jws.payload.dict() == expected_payload assert test_jws.payload.model_dump() == expected_payload
def test_tokenstuff() -> None: def test_tokenstuff() -> None:
"""tests stuff""" """tests stuff"""
token_store = TokenStore() token_store = TokenStore.model_validate({})
token_store[ token_store[
"idm_admin" "idm_admin"
] = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiMTBmZDJjYzMtM2UxZS00MjM1LTk4NjEtNWQyNjQ3NTAyMmVkIiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwzMzkyMywyOTQyNTQwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.rq1y7YNS9iCBWMmAu-FSa4-o4jrSSnMO_18zafgvLRtZFlB7j-Q68CzxceNN9C_1EWnc9uf4fOyeaSNUwGyaIQ" # noqa: E501 pylint: disable=line-too-long ] = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiMTBmZDJjYzMtM2UxZS00MjM1LTk4NjEtNWQyNjQ3NTAyMmVkIiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwzMzkyMywyOTQyNTQwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.rq1y7YNS9iCBWMmAu-FSa4-o4jrSSnMO_18zafgvLRtZFlB7j-Q68CzxceNN9C_1EWnc9uf4fOyeaSNUwGyaIQ" # noqa: E501 pylint: disable=line-too-long
info = token_store.token_info("idm_admin") info = token_store.token_info("idm_admin")
print(f"Parsed token: {info}") print(f"Parsed token: {info}")

View file

@ -19,7 +19,7 @@ def test_load_config_file() -> None:
print("Can't find client config file", file=sys.stderr) print("Can't find client config file", file=sys.stderr)
pytest.skip() pytest.skip()
config = load_config(EXAMPLE_CONFIG_FILE) config = load_config(EXAMPLE_CONFIG_FILE)
kanidm_config = KanidmClientConfig.parse_obj(config) kanidm_config = KanidmClientConfig.model_validate(config)
assert kanidm_config.uri == "https://idm.example.com/" assert kanidm_config.uri == "https://idm.example.com/"
print(f"{kanidm_config.uri=}") print(f"{kanidm_config.uri=}")
print(kanidm_config) print(kanidm_config)
@ -36,7 +36,7 @@ radius_groups = [
""" """
config_parsed = toml.loads(config_toml) config_parsed = toml.loads(config_toml)
print(config_parsed) print(config_parsed)
kanidm_config = KanidmClientConfig.parse_obj(config_parsed) kanidm_config = KanidmClientConfig.model_validate(config_parsed)
for group in kanidm_config.radius_groups: for group in kanidm_config.radius_groups:
print(group.spn) print(group.spn)
assert group.spn == "hello world" assert group.spn == "hello world"
@ -52,7 +52,7 @@ radius_clients = [ { name = "hello world", ipaddr = "10.0.0.5", secret = "cr4bj0
""" """
config_parsed = toml.loads(config_toml) config_parsed = toml.loads(config_toml)
print(config_parsed) print(config_parsed)
kanidm_config = KanidmClientConfig.parse_obj(config_parsed) kanidm_config = KanidmClientConfig.model_validate(config_parsed)
client = kanidm_config.radius_clients[0] client = kanidm_config.radius_clients[0]
print(client.name) print(client.name)
assert client.name == "hello world" assert client.name == "hello world"

View file

@ -1,7 +1,7 @@
""" tests types """ """ tests types """
import pytest import pytest
import pydantic.error_wrappers from pydantic import ValidationError
from kanidm.types import AuthInitResponse, KanidmClientConfig, RadiusGroup, RadiusClient from kanidm.types import AuthInitResponse, KanidmClientConfig, RadiusGroup, RadiusClient
@ -15,19 +15,19 @@ def test_auth_init_response() -> None:
}, },
} }
testval = AuthInitResponse.parse_obj(testobj) testval = AuthInitResponse.model_validate(testobj)
assert testval.sessionid == "crabzrool" assert testval.sessionid == "crabzrool"
def test_radiusgroup_vlan_negative() -> None: def test_radiusgroup_vlan_negative() -> None:
"""tests RadiusGroup's vlan validator""" """tests RadiusGroup's vlan validator"""
with pytest.raises(pydantic.error_wrappers.ValidationError): with pytest.raises(ValidationError):
RadiusGroup(vlan=-1, spn="crabzrool@foo") RadiusGroup(vlan=-1, spn="crabzrool@foo")
def test_radiusgroup_vlan_zero() -> None: def test_radiusgroup_vlan_zero() -> None:
"""tests RadiusGroup's vlan validator""" """tests RadiusGroup's vlan validator"""
with pytest.raises(pydantic.error_wrappers.ValidationError): with pytest.raises(ValidationError):
RadiusGroup(vlan=0, spn="crabzrool@foo") RadiusGroup(vlan=0, spn="crabzrool@foo")
@ -38,10 +38,9 @@ def test_radiusgroup_vlan_4096() -> None:
def test_radiusgroup_vlan_no_name() -> None: def test_radiusgroup_vlan_no_name() -> None:
"""tests RadiusGroup's vlan validator""" """tests RadiusGroup's vlan validator"""
with pytest.raises( with pytest.raises(ValidationError, match="(?i)spn\n.*Field required"):
pydantic.error_wrappers.ValidationError, match="spn\n.*field required" RadiusGroup(vlan=4096) # type: ignore[call-arg]
):
RadiusGroup(vlan=4096) # type: ignore[call-arg]
def test_kanidmconfig_parse_toml() -> None: def test_kanidmconfig_parse_toml() -> None:
"""tests KanidmClientConfig.parse_toml()""" """tests KanidmClientConfig.parse_toml()"""
@ -53,7 +52,7 @@ def test_kanidmconfig_parse_toml() -> None:
@pytest.mark.network @pytest.mark.network
def test_radius_client_bad_hostname() -> None: def test_radius_client_bad_hostname() -> None:
"""tests with a bad hostname""" """tests with a bad hostname"""
with pytest.raises(pydantic.error_wrappers.ValidationError): with pytest.raises(ValidationError):
RadiusClient( RadiusClient(
name="test", name="test",
ipaddr="thiscannotpossiblywork.kanidm.example.com", ipaddr="thiscannotpossiblywork.kanidm.example.com",