fixing up pydantic things (#1885)

This commit is contained in:
James Hodgkinson 2023-07-24 00:09:43 -07:00 committed by GitHub
parent e17dcc0ddb
commit 8981478d76
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 62 additions and 86 deletions

View file

@ -77,7 +77,7 @@ class KanidmClient:
if not isinstance(config_file, Path):
config_file = Path(config_file)
config_data = load_config(config_file.expanduser().resolve())
self.config = self.config.parse_obj(config_data)
self.config = self.config.model_validate(config_data)
if self.config.uri is None:
raise ValueError("Please initialize this with a server URI")
@ -108,7 +108,7 @@ class KanidmClient:
) -> None:
"""hand it a config dict and it'll configure the client"""
try:
self.config.parse_obj(config_data)
self.config.model_validate(config_data)
except ValidationError as validation_error:
# pylint: disable=raise-missing-from
raise ValueError(f"Failed to validate configuration: {validation_error}")
@ -191,7 +191,7 @@ class KanidmClient:
"status_code": request.status,
}
logging.debug(json_lib.dumps(response_input, default=str, indent=4))
response = ClientResponse.parse_obj(response_input)
response = ClientResponse.model_validate(response_input)
return response
async def call_get(
@ -239,7 +239,7 @@ class KanidmClient:
raise ValueError(
f"Missing x-kanidm-auth-session-id header in init auth response: {response.headers}"
)
retval = AuthInitResponse.parse_obj(response.data)
retval = AuthInitResponse.model_validate(response.data)
retval.response = response
return retval
@ -262,7 +262,7 @@ class KanidmClient:
# TODO: mock test for auth_begin raises AuthBeginFailed
raise AuthBeginFailed(response.content)
retobject = AuthBeginResponse.parse_obj(response.data)
retobject = AuthBeginResponse.model_validate(response.data)
retobject.response = response
return response
@ -296,7 +296,7 @@ class KanidmClient:
raise AuthMechUnknown(f"No auth mechanisms for {username}")
auth_begin = await self.auth_begin(method="password", sessionid=sessionid)
# does a little bit of validation
auth_begin_object = AuthBeginResponse.parse_obj(auth_begin.data)
auth_begin_object = AuthBeginResponse.model_validate(auth_begin.data)
auth_begin_object.response = auth_begin
return await self.auth_step_password(password=password, sessionid=sessionid)
@ -324,7 +324,7 @@ class KanidmClient:
logging.debug("Failed to authenticate, response: %s", response.content)
raise AuthCredFailed("Failed password authentication!")
result = AuthStepPasswordResponse.parse_obj(response.data)
result = AuthStepPasswordResponse.model_validate(response.data)
result.response = response
# pull the token out and set it

View file

@ -107,7 +107,7 @@ def authorize(
tok = None
try:
loop = asyncio.get_event_loop()
tok = RadiusTokenResponse.parse_obj(
tok = RadiusTokenResponse.model_validate(
loop.run_until_complete(_get_radius_token(username=user_id))
)
logging.debug("radius information token: %s", tok)

View file

@ -10,7 +10,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from authlib.jose import JsonWebSignature # type: ignore
from pydantic import BaseModel
from pydantic import ConfigDict, BaseModel, RootModel
from . import TOKEN_PATH
@ -31,11 +31,7 @@ class JWSHeader(BaseModel):
alg: str
typ: str
jwk: JWSHeaderJWK
class Config:
"""Configure the pydantic class"""
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
class JWSPayload(BaseModel):
@ -50,7 +46,7 @@ class JWSPayload(BaseModel):
name: str
displayname: str
spn: str
mail_primary: Optional[str]
mail_primary: Optional[str] = None
lim_uidx: bool
lim_rmax: int
lim_pmax: int
@ -93,13 +89,13 @@ class JWS:
padded_header = raw_header + "=" * divmod(len(raw_header), 4)[0]
decoded_header = base64.urlsafe_b64decode(padded_header)
logging.debug("decoded_header=%s", decoded_header)
header = JWSHeader.parse_obj(json.loads(decoded_header.decode("utf-8")))
header = JWSHeader.model_validate(json.loads(decoded_header.decode("utf-8")))
logging.debug("header: %s", header)
raw_payload = split_raw[1]
logging.debug("Parsing payload: %s", raw_payload)
padded_payload = raw_payload + "=" * divmod(len(raw_payload), 4)[1]
payload = JWSPayload.parse_raw(base64.urlsafe_b64decode(padded_payload))
payload = JWSPayload.model_validate_json(base64.urlsafe_b64decode(padded_payload))
raw_signature = split_raw[2]
logging.debug("Parsing signature: %s", raw_signature)
@ -109,32 +105,31 @@ class JWS:
return header, payload, signature
class TokenStore(BaseModel):
"""Represents the user auth tokens, can load them from the user store"""
__root__: Dict[str, str] = {}
class TokenStore(RootModel[Dict[str, str]]):
"""Represents the user auth tokens, so we can load them from the user store"""
root: Dict[str, str]
# TODO: one day work out how to type the __iter__ on TokenStore properly. It's some kind of iter() that makes mypy unhappy.
def __iter__(self) -> Any:
"""overloading the default function"""
for key in self.__root__.keys():
for key in self.root.keys():
yield key
def __getitem__(self, item: str) -> str:
"""overloading the default function"""
return self.__root__[item]
return self.root[item]
def __delitem__(self, item: str) -> None:
"""overloading the default function"""
del self.__root__[item]
del self.root[item]
def __setitem__(self, key: str, value: str) -> None:
"""overloading the default function"""
self.__root__[key] = value
self.root[key] = value
def save(self, filepath: Path = TOKEN_PATH) -> None:
"""saves the cached tokens to disk"""
data = json.dumps(self.__root__, indent=2)
data = json.dumps(self.root, indent=2)
with filepath.expanduser().resolve().open(
mode="w", encoding="utf-8"
) as file_handle:
@ -157,19 +152,19 @@ class TokenStore(BaseModel):
tokens = json.load(file_handle)
if overwrite:
self.__root__ = tokens
self.root = tokens
else:
for user in tokens:
self.__root__[user] = tokens[user]
self.root[user] = tokens[user]
self.validate_tokens()
logging.debug(json.dumps(tokens, indent=4))
return self.__root__
return self.root
def validate_tokens(self) -> None:
"""validates the JWS tokens for format, not their signature - PRs welcome"""
for username in self.__root__:
for username in self.root:
logging.debug("Parsing %s", username)
# TODO: Work out how to get the validation working. We probably shouldn't be worried about this since we're using it for auth...
logging.debug(
@ -184,4 +179,4 @@ class TokenStore(BaseModel):
s=self[username], key=None
)
logging.debug(parsed_object)
return JWSPayload.parse_raw(parsed_object.payload)
return JWSPayload.model_validate_json(parsed_object.payload)

View file

@ -7,7 +7,7 @@ import socket
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from pydantic import BaseModel, Field, validator
from pydantic import field_validator, ConfigDict, BaseModel, Field
import toml
@ -19,15 +19,11 @@ class ClientResponse(BaseModel):
status_code: int
"""
content: Optional[str]
data: Optional[Dict[str, Any]]
content: Optional[str] = None
data: Optional[Dict[str, Any]] = None
headers: Dict[str, Any]
status_code: int
class Config:
"""Configuration"""
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
class AuthInitResponse(BaseModel):
@ -35,18 +31,13 @@ class AuthInitResponse(BaseModel):
class _AuthInitState(BaseModel):
"""sub-class for the AuthInitResponse model"""
# TODO: can we add validation for AuthInitResponse.state.choose?
choose: List[str]
sessionid: str
state: _AuthInitState
response: Optional[ClientResponse]
class Config:
"""config class"""
arbitrary_types_allowed = True
response: Optional[ClientResponse] = None
# model_config = ConfigDict(arbitrary_types_allowed=True)
class AuthBeginResponse(BaseModel):
@ -65,11 +56,7 @@ class AuthBeginResponse(BaseModel):
sessionid: str
state: _AuthBeginState
response: Optional[ClientResponse]
class Config:
"""config class"""
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
class AuthStepPasswordResponse(BaseModel):
@ -78,16 +65,12 @@ class AuthStepPasswordResponse(BaseModel):
class _AuthStepPasswordState(BaseModel):
"""subclass to help parse the response from the auth 'step password' stage"""
success: Optional[str]
success: Optional[str] = None
sessionid: str
state: _AuthStepPasswordState
response: Optional[ClientResponse]
class Config:
"""config class"""
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
class RadiusGroup(BaseModel):
@ -96,7 +79,8 @@ class RadiusGroup(BaseModel):
spn: str
vlan: int
@validator("vlan")
@field_validator("vlan")
@classmethod
def validate_vlan(cls, value: int) -> int:
"""validate the vlan option is above 0"""
if not value > 0:
@ -120,11 +104,7 @@ class RadiusTokenResponse(BaseModel):
uuid: str
groups: List[RadiusTokenGroup]
class Config:
"""config for RadiusTokenGroupList"""
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
class RadiusClient(BaseModel):
@ -146,7 +126,8 @@ class RadiusClient(BaseModel):
ipaddr: str
secret: str # TODO: this should probably be renamed to token
@validator("ipaddr")
@field_validator("ipaddr")
@classmethod
def validate_ipaddr(cls, value: str) -> str:
"""validates the ipaddr field is an IP address, CIDR or valid hostname"""
for typedef in (IPv6Network, IPv6Address, IPv4Address, IPv4Network):
@ -198,9 +179,10 @@ class KanidmClientConfig(BaseModel):
@classmethod
def parse_toml(cls, input_string: str) -> Any:
"""loads from a string"""
return super().parse_obj(toml.loads(input_string))
return super().model_validate(toml.loads(input_string))
@validator("uri")
@field_validator("uri")
@classmethod
def validate_uri(cls, value: Optional[str]) -> Optional[str]:
"""validator for the uri field"""
if value is not None:

View file

@ -64,7 +64,7 @@ async def test_auth_begin(client_configfile: KanidmClient) -> None:
retval["response"] = begin_result
assert AuthBeginResponse.parse_obj(retval)
assert AuthBeginResponse.model_validate(retval)
@pytest.mark.network
@ -170,7 +170,7 @@ async def test_authenticate_with_token(client_configfile: KanidmClient) -> None:
f"Using username {test_username} by default - set KANIDM_TEST_USERNAME env var if you want to change this."
)
tokens = TokenStore()
tokens = TokenStore.model_validate({})
tokens.load()
if test_username not in tokens:

View file

@ -80,7 +80,7 @@ def test_config_invalid_uri() -> None:
"uri": "asdfsadfasd",
}
with pytest.raises(pydantic.ValidationError):
KanidmClientConfig.parse_obj(test_input)
KanidmClientConfig.model_validate(test_input)
def test_config_none_uri() -> None:

View file

@ -7,7 +7,7 @@ import pytest
from kanidm.tokens import JWS, TokenStore
TEST_TOKEN = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiZjExOTg2NzMtNGI5MC00NjE4LWJkZTctMTBiY2M2YzhjOGE0IiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwyODM2Niw4MDI1MjUwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.cln3gRV3NdgbGqYeD26mBSHFGOaFXak2UA5umvj_Xw30dMS8ECTnJU7lvLyepRTW_VzqUJHbRatPkQ1TEuK99Q" # noqa: E501 pylint: disable=line-too-long
TEST_TOKEN = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiZjExOTg2NzMtNGI5MC00NjE4LWJkZTctMTBiY2M2YzhjOGE0IiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwyODM2Niw4MDI1MjUwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.cln3gRV3NdgbGqYeD26mBSHFGOaFXak2UA5umvj_Xw30dMS8ECTnJU7lvLyepRTW_VzqUJHbRatPkQ1TEuK99Q" # noqa: E501 pylint: disable=line-too-long
def test_jws_parser() -> None:
@ -43,16 +43,16 @@ def test_jws_parser() -> None:
test_jws = JWS(TEST_TOKEN)
assert test_jws.header.dict() == expected_header
assert test_jws.payload.dict() == expected_payload
assert test_jws.header.model_dump() == expected_header
assert test_jws.payload.model_dump() == expected_payload
def test_tokenstuff() -> None:
"""tests stuff"""
token_store = TokenStore()
token_store = TokenStore.model_validate({})
token_store[
"idm_admin"
] = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiMTBmZDJjYzMtM2UxZS00MjM1LTk4NjEtNWQyNjQ3NTAyMmVkIiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwzMzkyMywyOTQyNTQwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.rq1y7YNS9iCBWMmAu-FSa4-o4jrSSnMO_18zafgvLRtZFlB7j-Q68CzxceNN9C_1EWnc9uf4fOyeaSNUwGyaIQ" # noqa: E501 pylint: disable=line-too-long
] = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJrdHkiOiJFQyIsImNydiI6IlAtMjU2IiwieCI6Im1KQTgtTURfeFRxQXBmSU9nbFptNXJ6RWhoQ3hDdjRxZFNpeGxjV1Q3ZmsiLCJ5IjoiNy0yVkNuY0h3NEF1WVJpYVpYT2FoVXRGMUE2SDd3eUxrUW1FekduS0pKcyIsImFsZyI6IkVTMjU2IiwidXNlIjoic2lnIn0sInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX2lkIjoiMTBmZDJjYzMtM2UxZS00MjM1LTk4NjEtNWQyNjQ3NTAyMmVkIiwiYXV0aF90eXBlIjoiZ2VuZXJhdGVkcGFzc3dvcmQiLCJleHBpcnkiOlsyMDIyLDI2NSwzMzkyMywyOTQyNTQwMDBdLCJ1dWlkIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDE4IiwibmFtZSI6ImlkbV9hZG1pbiIsImRpc3BsYXluYW1lIjoiSURNIEFkbWluaXN0cmF0b3IiLCJzcG4iOiJpZG1fYWRtaW5AbG9jYWxob3N0IiwibWFpbF9wcmltYXJ5IjpudWxsLCJsaW1fdWlkeCI6ZmFsc2UsImxpbV9ybWF4IjoxMjgsImxpbV9wbWF4IjoyNTYsImxpbV9mbWF4IjozMn0.rq1y7YNS9iCBWMmAu-FSa4-o4jrSSnMO_18zafgvLRtZFlB7j-Q68CzxceNN9C_1EWnc9uf4fOyeaSNUwGyaIQ" # noqa: E501 pylint: disable=line-too-long
info = token_store.token_info("idm_admin")
print(f"Parsed token: {info}")

View file

@ -19,7 +19,7 @@ def test_load_config_file() -> None:
print("Can't find client config file", file=sys.stderr)
pytest.skip()
config = load_config(EXAMPLE_CONFIG_FILE)
kanidm_config = KanidmClientConfig.parse_obj(config)
kanidm_config = KanidmClientConfig.model_validate(config)
assert kanidm_config.uri == "https://idm.example.com/"
print(f"{kanidm_config.uri=}")
print(kanidm_config)
@ -36,7 +36,7 @@ radius_groups = [
"""
config_parsed = toml.loads(config_toml)
print(config_parsed)
kanidm_config = KanidmClientConfig.parse_obj(config_parsed)
kanidm_config = KanidmClientConfig.model_validate(config_parsed)
for group in kanidm_config.radius_groups:
print(group.spn)
assert group.spn == "hello world"
@ -52,7 +52,7 @@ radius_clients = [ { name = "hello world", ipaddr = "10.0.0.5", secret = "cr4bj0
"""
config_parsed = toml.loads(config_toml)
print(config_parsed)
kanidm_config = KanidmClientConfig.parse_obj(config_parsed)
kanidm_config = KanidmClientConfig.model_validate(config_parsed)
client = kanidm_config.radius_clients[0]
print(client.name)
assert client.name == "hello world"

View file

@ -1,7 +1,7 @@
""" tests types """
import pytest
import pydantic.error_wrappers
from pydantic import ValidationError
from kanidm.types import AuthInitResponse, KanidmClientConfig, RadiusGroup, RadiusClient
@ -15,19 +15,19 @@ def test_auth_init_response() -> None:
},
}
testval = AuthInitResponse.parse_obj(testobj)
testval = AuthInitResponse.model_validate(testobj)
assert testval.sessionid == "crabzrool"
def test_radiusgroup_vlan_negative() -> None:
"""tests RadiusGroup's vlan validator"""
with pytest.raises(pydantic.error_wrappers.ValidationError):
with pytest.raises(ValidationError):
RadiusGroup(vlan=-1, spn="crabzrool@foo")
def test_radiusgroup_vlan_zero() -> None:
"""tests RadiusGroup's vlan validator"""
with pytest.raises(pydantic.error_wrappers.ValidationError):
with pytest.raises(ValidationError):
RadiusGroup(vlan=0, spn="crabzrool@foo")
@ -38,10 +38,9 @@ def test_radiusgroup_vlan_4096() -> None:
def test_radiusgroup_vlan_no_name() -> None:
"""tests RadiusGroup's vlan validator"""
with pytest.raises(
pydantic.error_wrappers.ValidationError, match="spn\n.*field required"
):
RadiusGroup(vlan=4096) # type: ignore[call-arg]
with pytest.raises(ValidationError, match="(?i)spn\n.*Field required"):
RadiusGroup(vlan=4096) # type: ignore[call-arg]
def test_kanidmconfig_parse_toml() -> None:
"""tests KanidmClientConfig.parse_toml()"""
@ -53,7 +52,7 @@ def test_kanidmconfig_parse_toml() -> None:
@pytest.mark.network
def test_radius_client_bad_hostname() -> None:
"""tests with a bad hostname"""
with pytest.raises(pydantic.error_wrappers.ValidationError):
with pytest.raises(ValidationError):
RadiusClient(
name="test",
ipaddr="thiscannotpossiblywork.kanidm.example.com",