From 9141639149b2f8346b1b4d89ebe1a147f743dad6 Mon Sep 17 00:00:00 2001 From: Antonis Angelakis Date: Wed, 21 Oct 2020 17:47:06 +0300 Subject: [PATCH 001/150] Add more tests for PKCE The tests added reflect current bugs that were not covered until now and were missed. --- tests/test_33_pkce.py | 51 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/test_33_pkce.py b/tests/test_33_pkce.py index b407520..d11dc65 100644 --- a/tests/test_33_pkce.py +++ b/tests/test_33_pkce.py @@ -371,3 +371,54 @@ def test_no_code_verifier(self): assert isinstance(resp, TokenErrorResponse) assert resp["error"] == "invalid_grant" assert resp["error_description"] == "Missing code_verifier" + + def test_no_authorization_endpoint(self, conf, caplog): + """ + Test that PKCE configuration does not crash when there is no authorization + endpoint and a warning is logged. + """ + del conf["endpoint"]["authorization"] + create_endpoint(conf) + assert "WARNING" in caplog.text + assert ( + "No authorization endpoint found, skipping PKCE configuration" + in caplog.text + ) + + def test_no_token_endpoint(self, conf, caplog): + """ + Test that PKCE configuration does not crash when there is no token endpoint + and a warning is logged. + """ + del conf["endpoint"]["token"] + create_endpoint(conf) + assert "WARNING" in caplog.text + assert "No token endpoint found, skipping PKCE configuration" in caplog.text + + def test_plain_challenge_method_not_supported_and_PKCE_not_essential(self, conf): + """ + Test that an authentication request without PKCE parameters does not fail when + "plain" code_challenge_method is not supported and PKCE is not essential. + """ + conf["add_on"]["pkce"]["kwargs"]["code_challenge_methods"] = ["S256"] + conf["add_on"]["pkce"]["kwargs"]["essential"] = False + endpoint_context = create_endpoint(conf) + authn_endpoint = endpoint_context.endpoint["authorization"] + token_endpoint = endpoint_context.endpoint["token"] + + authentication_request = AUTH_REQ.copy() + + parsed_request = authn_endpoint.parse_request(authentication_request.to_dict()) + + assert not isinstance(parsed_request, AuthorizationErrorResponse) + assert isinstance(parsed_request, AuthorizationRequest) + + response = authn_endpoint.process_request(parsed_request) + + assert isinstance(response["response_args"], AuthorizationResponse) + + token_request = TOKEN_REQ.copy() + token_request["code"] = response["response_args"]["code"] + parsed_token_request = token_endpoint.parse_request(token_request) + + assert isinstance(parsed_token_request, AccessTokenRequest) From 82c87305f9447bc8885ae7aa9d11063e10924a39 Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 22 Oct 2020 09:10:26 +0200 Subject: [PATCH 002/150] Handle different types of input. There might not be any client information at this time. This is an effect of automatic client registration as defined in the OIDC federation specification. --- src/oidcendpoint/common/authorization.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/oidcendpoint/common/authorization.py b/src/oidcendpoint/common/authorization.py index 06f0158..2529b72 100755 --- a/src/oidcendpoint/common/authorization.py +++ b/src/oidcendpoint/common/authorization.py @@ -5,6 +5,7 @@ from oidcmsg.exception import ParameterError from oidcmsg.exception import URIError +from oidcmsg.message import Message from oidcmsg.oauth2 import AuthorizationErrorResponse from oidcmsg.oidc import AuthorizationResponse from oidcmsg.oidc import verified_claim_name @@ -172,20 +173,33 @@ def get_uri(endpoint_context, request, uri_type): def authn_args_gather(request, authn_class_ref, cinfo, **kwargs): """ Gather information to be used by the authentication method + + :param request: The request either as a dictionary or as a Message instance + :param authn_class_ref: Authentication class reference + :param cinfo: Client information + :param kwargs: Extra keyword arguments + :return: Authentication arguments """ authn_args = { "authn_class_ref": authn_class_ref, - "query": request.to_urlencoded(), "return_uri": request["redirect_uri"], } + if isinstance(request, Message): + authn_args["query"] = request.to_urlencoded() + elif isinstance(request, dict): + authn_args["query"] = urlencode(request) + else: + ValueError("Wrong request format") + if "req_user" in kwargs: authn_args["as_user"] = (kwargs["req_user"],) # Below are OIDC specific. Just ignore if OAuth2 - for attr in ["policy_uri", "logo_uri", "tos_uri"]: - if cinfo.get(attr): - authn_args[attr] = cinfo[attr] + if cinfo: + for attr in ["policy_uri", "logo_uri", "tos_uri"]: + if cinfo.get(attr): + authn_args[attr] = cinfo[attr] for attr in ["ui_locales", "acr_values", "login_hint"]: if request.get(attr): From 5d52f0bf0187a2255017fc3f5b660135310e1b02 Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 22 Oct 2020 09:12:22 +0200 Subject: [PATCH 003/150] Allow for adding extra response arguments. Not used here but used in classes that are built on this. --- src/oidcendpoint/oidc/authorization.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/oidcendpoint/oidc/authorization.py b/src/oidcendpoint/oidc/authorization.py index 8408084..0955534 100755 --- a/src/oidcendpoint/oidc/authorization.py +++ b/src/oidcendpoint/oidc/authorization.py @@ -9,6 +9,7 @@ from cryptojwt.utils import as_unicode from cryptojwt.utils import b64d from cryptojwt.utils import b64e + from oidcendpoint.token_handler import UnknownToken from oidcmsg import oidc from oidcmsg.exception import ParameterError @@ -393,6 +394,9 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): return {"authn_event": authn_event, "identity": identity, "user": user} + def extra_response_args(self, aresp): + return aresp + def create_authn_response(self, request, sid): """ @@ -474,6 +478,8 @@ def create_authn_response(self, request, sid): ) return {"response_args": resp, "fragment_enc": fragment_enc} + aresp = self.extra_response_args(aresp) + return {"response_args": aresp, "fragment_enc": fragment_enc} def aresp_check(self, aresp, request): @@ -679,7 +685,7 @@ def authz_part2(self, user, authn_event, request, **kwargs): def process_request(self, request_info=None, **kwargs): """ The AuthorizationRequest endpoint - :param request_info: The authorization request as a dictionary + :param request_info: The authorization request as a Message instance :return: dictionary """ @@ -691,7 +697,8 @@ def process_request(self, request_info=None, **kwargs): logger.debug("client {}: {}".format(_cid, cinfo)) # this apply the default optionally deny_unknown_scopes policy - check_unknown_scopes_policy(request_info, cinfo, self.endpoint_context) + if cinfo: + check_unknown_scopes_policy(request_info, cinfo, self.endpoint_context) cookie = kwargs.get("cookie", "") if cookie: From dc6f366def220e15579654b5420c912353bd39b4 Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 22 Oct 2020 09:12:47 +0200 Subject: [PATCH 004/150] Bumped version. --- src/oidcendpoint/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/oidcendpoint/__init__.py b/src/oidcendpoint/__init__.py index 653aef2..ef93b0a 100755 --- a/src/oidcendpoint/__init__.py +++ b/src/oidcendpoint/__init__.py @@ -1,7 +1,7 @@ import string from secrets import choice -__version__ = "1.1.1" +__version__ = "1.1.2" DEF_SIGN_ALG = { "id_token": "RS256", From 6ddedb20770a74ad5dfad9cf925f8dbfaed308f0 Mon Sep 17 00:00:00 2001 From: Antonis Angelakis Date: Wed, 21 Oct 2020 17:01:56 +0300 Subject: [PATCH 005/150] Check for token and authorization when adding PKCE A warning is logged, as there can't be PKCE functionality without authorization and token endpoints. --- src/oidcendpoint/oidc/add_on/pkce.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/oidcendpoint/oidc/add_on/pkce.py b/src/oidcendpoint/oidc/add_on/pkce.py index 07bc9f7..a7d850c 100644 --- a/src/oidcendpoint/oidc/add_on/pkce.py +++ b/src/oidcendpoint/oidc/add_on/pkce.py @@ -119,7 +119,21 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): def add_pkce_support(endpoint, **kwargs): - endpoint["authorization"].post_parse_request.append(post_authn_parse) + authn_endpoint = endpoint.get("authorization") + if authn_endpoint is None: + LOGGER.warning( + "No authorization endpoint found, skipping PKCE configuration" + ) + return + + token_endpoint = endpoint.get("token") + if token_endpoint is None: + LOGGER.warning( + "No token endpoint found, skipping PKCE configuration" + ) + return + + authn_endpoint.post_parse_request.append(post_authn_parse) if "essential" not in kwargs: kwargs["essential"] = False @@ -134,6 +148,6 @@ def add_pkce_support(endpoint, **kwargs): raise ValueError("Unsupported method: {}".format(method)) kwargs["code_challenge_methods"][method] = CC_METHOD[method] - endpoint["authorization"].endpoint_context.args["pkce"] = kwargs + authn_endpoint.endpoint_context.args["pkce"] = kwargs - endpoint["token"].post_parse_request.append(post_token_parse) + token_endpoint.post_parse_request.append(post_token_parse) From 58f5c5556832927da6a9f2810dcbade753906713 Mon Sep 17 00:00:00 2001 From: Antonis Angelakis Date: Wed, 21 Oct 2020 17:06:41 +0300 Subject: [PATCH 006/150] Check for code_challenge before checking method There was a bug when plain code challenge method is not supported based on the configuration and PKCE is not essential, where all flows authentication requests without PKCE would fail because plain is not supported and is the default method. It's fixed now, but maybe there is an underlying issue here, concerning this use case; the PKCE RFC states that plain is the default method and we follow it, however we provide the option to not support plain. As a result each authentication request which has a code_challenge and omits code_challenge_method will fail. Is that behaviour expected or we should default in a code challenge method we support. Or should we always support plain? --- src/oidcendpoint/oidc/add_on/pkce.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/oidcendpoint/oidc/add_on/pkce.py b/src/oidcendpoint/oidc/add_on/pkce.py index a7d850c..f51ffb4 100644 --- a/src/oidcendpoint/oidc/add_on/pkce.py +++ b/src/oidcendpoint/oidc/add_on/pkce.py @@ -47,8 +47,11 @@ def post_authn_parse(request, client_id, endpoint_context, **kwargs): request["code_challenge_method"] = "plain" if ( - request["code_challenge_method"] - not in endpoint_context.args["pkce"]["code_challenge_methods"] + "code_challenge" in request + and ( + request["code_challenge_method"] + not in endpoint_context.args["pkce"]["code_challenge_methods"] + ) ): return AuthorizationErrorResponse( error="invalid_request", From d5bc366cdb66bf8e06ca27adcb4f74c59b321f68 Mon Sep 17 00:00:00 2001 From: Nikos Sklikas Date: Thu, 29 Oct 2020 14:02:07 +0200 Subject: [PATCH 007/150] Add add_scope option for JWT access token Add add_scope option JWT access tokens, enabling this will add a list with the allowed scopes that were requested in the returned JWT. --- src/oidcendpoint/jwt_token.py | 10 +++++++++- src/oidcendpoint/scopes.py | 4 ++++ src/oidcendpoint/userinfo.py | 5 ++--- tests/test_27_jwt_token.py | 15 +++++++++++++++ 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index d211eea..e3f4ec8 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -16,6 +16,7 @@ class JWTToken(Token): init_args = { "add_claims_by_scope": False, "enable_claims_per_client": False, + "add_scope": False, "add_claims": {}, } @@ -49,6 +50,7 @@ def __init__( self.add_claims = self.init_args["add_claims"] self.add_claims_by_scope = self.init_args["add_claims_by_scope"] + self.add_scope = self.init_args["add_scope"] self.enable_claims_per_client = self.init_args["enable_claims_per_client"] for param, default in self.init_args.items(): @@ -83,6 +85,7 @@ def __call__( :return: """ payload = {"sid": sid, "ttype": self.type, "sub": sinfo["sub"]} + scopes = sinfo["authn_req"]["scope"] if self.add_claims: self.do_add_claims(payload, uinfo, self.add_claims) @@ -94,11 +97,16 @@ def __call__( payload, uinfo, convert_scopes2claims( - sinfo["authn_req"]["scope"], + scopes, _allowed_claims, map=self.scope_claims_map, ).keys(), ) + if self.add_scope: + payload["scope"] = self.cntx.scopes_handler.filter_scopes( + client_id, self.cntx, scopes + ) + # Add claims if is access token if self.type == "T" and self.enable_claims_per_client: client = self.cdb.get(client_id, {}) diff --git a/src/oidcendpoint/scopes.py b/src/oidcendpoint/scopes.py index b3413ae..8a04544 100644 --- a/src/oidcendpoint/scopes.py +++ b/src/oidcendpoint/scopes.py @@ -82,6 +82,10 @@ def allowed_scopes(self, client_id, endpoint_context): return available_scopes(endpoint_context) return [] + def filter_scopes(self, client_id, endpoint_context, scopes): + allowed_scopes = self.allowed_scopes(client_id, endpoint_context) + return [s for s in scopes if s in allowed_scopes] + class Claims: def __init__(self): diff --git a/src/oidcendpoint/userinfo.py b/src/oidcendpoint/userinfo.py index 56684fd..590c2ad 100755 --- a/src/oidcendpoint/userinfo.py +++ b/src/oidcendpoint/userinfo.py @@ -127,10 +127,9 @@ def collect_user_info( if scope_to_claims is None: scope_to_claims = endpoint_context.scope2claims - _allowed = endpoint_context.scopes_handler.allowed_scopes( - authn_req["client_id"], endpoint_context + supported_scopes = endpoint_context.scopes_handler.filter_scopes( + authn_req["client_id"], endpoint_context, authn_req["scope"] ) - supported_scopes = [s for s in authn_req["scope"] if s in _allowed] if userinfo_claims is None: _allowed_claims = endpoint_context.claims_handler.allowed_claims( authn_req["client_id"], endpoint_context diff --git a/tests/test_27_jwt_token.py b/tests/test_27_jwt_token.py index cd37c27..da0336d 100644 --- a/tests/test_27_jwt_token.py +++ b/tests/test_27_jwt_token.py @@ -205,6 +205,21 @@ def test_client_claims(self, enable_claims_per_client): res = _jwt.unpack(token) assert enable_claims_per_client is ("address" in res) + @pytest.mark.parametrize("add_scope", [True, False]) + def test_add_scopes(self, add_scope): + ec = self.endpoint.endpoint_context + handler = ec.sdb.handler.handler["access_token"] + auth_req = dict(AUTH_REQ) + auth_req["scope"] = ["openid", "profile", "aba"] + session_id = setup_session(ec, auth_req, uid="diana") + handler.add_scope = add_scope + _dic = ec.sdb.upgrade_to_token(key=session_id) + + token = _dic["access_token"] + _jwt = JWT(key_jar=KEYJAR, iss="client_1") + res = _jwt.unpack(token) + assert add_scope is (res.get("scope") == ["openid", "profile"]) + def test_is_expired(self): session_id = setup_session( self.endpoint.endpoint_context, AUTH_REQ, uid="diana" From 7f383d702e76b56ecf6d8df9dac1d79f84211745 Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 3 Nov 2020 10:50:15 +0100 Subject: [PATCH 008/150] clean up --- src/oidcendpoint/endpoint.py | 5 +++-- src/oidcendpoint/oauth2/authorization.py | 2 +- src/oidcendpoint/oauth2/introspection.py | 2 +- src/oidcendpoint/oidc/authorization.py | 3 +-- src/oidcendpoint/oidc/userinfo.py | 2 +- src/oidcendpoint/token_handler.py | 2 +- src/oidcendpoint/userinfo.py | 1 + tests/test_08_session.py | 4 ++-- tests/test_23_oidc_registration_endpoint.py | 2 +- tests/test_24_oauth2_authorization_endpoint.py | 4 ++-- tests/test_24_oidc_authorization_endpoint.py | 2 +- tests/test_30_oidc_end_session.py | 2 +- tests/test_31_introspection.py | 3 --- 13 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/oidcendpoint/endpoint.py b/src/oidcendpoint/endpoint.py index f60ab1d..ba91b6f 100755 --- a/src/oidcendpoint/endpoint.py +++ b/src/oidcendpoint/endpoint.py @@ -4,17 +4,18 @@ from cryptojwt import jwe from cryptojwt.jws.jws import SIGNER_ALGS -from oidcendpoint.token_handler import UnknownToken from oidcmsg.exception import MissingRequiredAttribute from oidcmsg.exception import MissingRequiredValue from oidcmsg.message import Message -from oidcmsg.oauth2 import ResponseMessage, AuthorizationErrorResponse +from oidcmsg.oauth2 import AuthorizationErrorResponse +from oidcmsg.oauth2 import ResponseMessage from oidcendpoint import sanitize from oidcendpoint.client_authn import UnknownOrNoAuthnMethod from oidcendpoint.client_authn import client_auth_setup from oidcendpoint.client_authn import verify_client from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.util import OAUTH2_NOCACHE_HEADERS __author__ = "Roland Hedberg" diff --git a/src/oidcendpoint/oauth2/authorization.py b/src/oidcendpoint/oauth2/authorization.py index 6e72b17..c4f6c23 100755 --- a/src/oidcendpoint/oauth2/authorization.py +++ b/src/oidcendpoint/oauth2/authorization.py @@ -7,7 +7,6 @@ from cryptojwt.utils import as_unicode from cryptojwt.utils import b64d from cryptojwt.utils import b64e -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oauth2 from oidcmsg.exception import ParameterError from oidcmsg.oidc import AuthorizationResponse @@ -35,6 +34,7 @@ from oidcendpoint.exception import UnAuthorizedClientScope from oidcendpoint.exception import UnknownClient from oidcendpoint.session import setup_session +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import pick_auth logger = logging.getLogger(__name__) diff --git a/src/oidcendpoint/oauth2/introspection.py b/src/oidcendpoint/oauth2/introspection.py index 4211bd8..8e83a4c 100644 --- a/src/oidcendpoint/oauth2/introspection.py +++ b/src/oidcendpoint/oauth2/introspection.py @@ -1,11 +1,11 @@ """Implements RFC7662""" import logging -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oauth2 from oidcmsg.time_util import utc_time_sans_frac from oidcendpoint.endpoint import Endpoint +from oidcendpoint.token_handler import UnknownToken LOGGER = logging.getLogger(__name__) diff --git a/src/oidcendpoint/oidc/authorization.py b/src/oidcendpoint/oidc/authorization.py index 0955534..050c56e 100755 --- a/src/oidcendpoint/oidc/authorization.py +++ b/src/oidcendpoint/oidc/authorization.py @@ -9,8 +9,6 @@ from cryptojwt.utils import as_unicode from cryptojwt.utils import b64d from cryptojwt.utils import b64e - -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oidc from oidcmsg.exception import ParameterError from oidcmsg.oidc import Claims @@ -38,6 +36,7 @@ from oidcendpoint.exception import UnknownClient from oidcendpoint.oauth2.authorization import check_unknown_scopes_policy from oidcendpoint.session import setup_session +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import pick_auth logger = logging.getLogger(__name__) diff --git a/src/oidcendpoint/oidc/userinfo.py b/src/oidcendpoint/oidc/userinfo.py index ffe57e7..a63cdc5 100755 --- a/src/oidcendpoint/oidc/userinfo.py +++ b/src/oidcendpoint/oidc/userinfo.py @@ -4,12 +4,12 @@ from cryptojwt.exception import MissingValue from cryptojwt.jwt import JWT from cryptojwt.jwt import utc_time_sans_frac -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oidc from oidcmsg.message import Message from oidcmsg.oauth2 import ResponseMessage from oidcendpoint.endpoint import Endpoint +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.userinfo import collect_user_info from oidcendpoint.util import OAUTH2_NOCACHE_HEADERS diff --git a/src/oidcendpoint/token_handler.py b/src/oidcendpoint/token_handler.py index 692d7d5..3f04a1e 100755 --- a/src/oidcendpoint/token_handler.py +++ b/src/oidcendpoint/token_handler.py @@ -304,7 +304,7 @@ def factory(ec, code=None, token=None, refresh=None, jwks_def=None, **kwargs): _add_passwd(kj, token, "token") args["access_token_handler"] = init_token_handler(ec, token, TTYPE["token"]) - if refresh: + if refresh is not None: _add_passwd(kj, refresh, "refresh") args["refresh_token_handler"] = init_token_handler( ec, refresh, TTYPE["refresh"] diff --git a/src/oidcendpoint/userinfo.py b/src/oidcendpoint/userinfo.py index 56684fd..bd2fa01 100755 --- a/src/oidcendpoint/userinfo.py +++ b/src/oidcendpoint/userinfo.py @@ -3,6 +3,7 @@ from oidcmsg.oidc import Claims from oidcendpoint import sanitize +from oidcendpoint.exception import FailedAuthentication from oidcendpoint.exception import ImproperlyConfigured from oidcendpoint.scopes import convert_scopes2claims diff --git a/tests/test_08_session.py b/tests/test_08_session.py index 3952c55..15ca52c 100644 --- a/tests/test_08_session.py +++ b/tests/test_08_session.py @@ -2,11 +2,10 @@ import shutil import time -from oidcendpoint.token_handler import UnknownToken +import pytest from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import OpenIDRequest from oidcmsg.storage.init import storage_factory -import pytest from oidcendpoint import rndstr from oidcendpoint import token_handler @@ -17,6 +16,7 @@ from oidcendpoint.session import SessionDB from oidcendpoint.session import setup_session from oidcendpoint.sso_db import SSODb +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.token_handler import WrongTokenType from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo diff --git a/tests/test_23_oidc_registration_endpoint.py b/tests/test_23_oidc_registration_endpoint.py index 0d59e45..90d7969 100755 --- a/tests/test_23_oidc_registration_endpoint.py +++ b/tests/test_23_oidc_registration_endpoint.py @@ -1,9 +1,9 @@ # -*- coding: latin-1 -*- import json -from cryptojwt.key_jar import init_key_jar import pytest import responses +from cryptojwt.key_jar import init_key_jar from oidcmsg.oidc import RegistrationRequest from oidcmsg.oidc import RegistrationResponse diff --git a/tests/test_24_oauth2_authorization_endpoint.py b/tests/test_24_oauth2_authorization_endpoint.py index 599269d..8d4cbcd 100755 --- a/tests/test_24_oauth2_authorization_endpoint.py +++ b/tests/test_24_oauth2_authorization_endpoint.py @@ -29,9 +29,9 @@ from oidcendpoint.exception import NoSuchAuthentication from oidcendpoint.exception import RedirectURIError from oidcendpoint.exception import ToOld -from oidcendpoint.exception import UnknownClient -from oidcendpoint.exception import UnAuthorizedClientScope from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.exception import UnAuthorizedClientScope +from oidcendpoint.exception import UnknownClient from oidcendpoint.id_token import IDToken from oidcendpoint.oauth2.authorization import Authorization from oidcendpoint.session import SessionInfo diff --git a/tests/test_24_oidc_authorization_endpoint.py b/tests/test_24_oidc_authorization_endpoint.py index 9de55a9..2130b5d 100755 --- a/tests/test_24_oidc_authorization_endpoint.py +++ b/tests/test_24_oidc_authorization_endpoint.py @@ -32,8 +32,8 @@ from oidcendpoint.exception import NoSuchAuthentication from oidcendpoint.exception import RedirectURIError from oidcendpoint.exception import ToOld -from oidcendpoint.exception import UnknownClient from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.exception import UnknownClient from oidcendpoint.id_token import IDToken from oidcendpoint.login_hint import LoginHint2Acrs from oidcendpoint.oidc import userinfo diff --git a/tests/test_30_oidc_end_session.py b/tests/test_30_oidc_end_session.py index 57125de..088e591 100644 --- a/tests/test_30_oidc_end_session.py +++ b/tests/test_30_oidc_end_session.py @@ -4,7 +4,6 @@ from urllib.parse import parse_qs from urllib.parse import urlparse -from oidcendpoint.token_handler import UnknownToken import pytest import responses from cryptojwt.key_jar import build_keyjar @@ -26,6 +25,7 @@ from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.session import do_front_channel_logout_iframe from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo diff --git a/tests/test_31_introspection.py b/tests/test_31_introspection.py index 4327eb8..2a31542 100644 --- a/tests/test_31_introspection.py +++ b/tests/test_31_introspection.py @@ -13,9 +13,6 @@ from oidcmsg.oidc import AuthorizationRequest from oidcmsg.time_util import utc_time_sans_frac -from oidcendpoint.client_authn import ClientSecretPost -from oidcendpoint.client_authn import UnknownOrNoAuthnMethod -from oidcendpoint.client_authn import WrongAuthnMethod from oidcendpoint.client_authn import verify_client from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.exception import UnAuthorizedClient From 2f2e388499eb196322de6e88b9f9d4f586c71e9d Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 3 Nov 2020 10:51:24 +0100 Subject: [PATCH 009/150] Change so the method definition is the same. --- src/oidcendpoint/jwt_token.py | 70 +++++++++++++++-------------------- 1 file changed, 29 insertions(+), 41 deletions(-) diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index d211eea..28b39b1 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -1,7 +1,3 @@ -from typing import Any -from typing import Dict -from typing import Optional - from cryptojwt import JWT from cryptojwt.jws.exception import JWSException @@ -20,16 +16,16 @@ class JWTToken(Token): } def __init__( - self, - typ, - keyjar=None, - issuer=None, - aud=None, - alg="ES256", - lifetime=300, - ec=None, - token_type="Bearer", - **kwargs + self, + typ, + keyjar=None, + issuer=None, + aud=None, + alg="ES256", + lifetime=300, + ec=None, + token_type="Bearer", + **kwargs ): Token.__init__(self, typ, **kwargs) self.token_type = token_type @@ -64,47 +60,38 @@ def do_add_claims(self, payload, uinfo, claims): pass def __call__( - self, - sid: str, - uinfo: Dict, - sinfo: Dict, - aud: Optional[Any], - client_id: Optional[str], - **kwargs + self, + sid: str, + **kwargs ): """ Return a token. :param sid: Session id - :param uinfo: User information - :param sinfo: Session information - :param aud: audience - :param client_id: client_id - :return: + :return: Signed JSON Web Token """ - payload = {"sid": sid, "ttype": self.type, "sub": sinfo["sub"]} + + payload = {"sid": sid, "ttype": self.type, "sub": kwargs["sinfo"]["sub"]} + + _user_claims = kwargs.get('user_claims') + _client_id = kwargs.get('client_id') + _scopes = kwargs.get('scope') if self.add_claims: - self.do_add_claims(payload, uinfo, self.add_claims) + self.do_add_claims(payload, _user_claims, self.add_claims) if self.add_claims_by_scope: - _allowed_claims = self.cntx.claims_handler.allowed_claims( - client_id, self.cntx - ) + _allowed_claims = self.cntx.claims_handler.allowed_claims(_client_id, self.cntx) self.do_add_claims( payload, - uinfo, - convert_scopes2claims( - sinfo["authn_req"]["scope"], - _allowed_claims, - map=self.scope_claims_map, - ).keys(), + _user_claims, + convert_scopes2claims(_scopes, _allowed_claims, map=self.scope_claims_map).keys(), ) # Add claims if is access token if self.type == "T" and self.enable_claims_per_client: - client = self.cdb.get(client_id, {}) + client = self.cdb.get(_client_id, {}) client_claims = client.get("access_token_claims") if client_claims: - self.do_add_claims(payload, uinfo, client_claims) + self.do_add_claims(payload, _user_claims, client_claims) payload.update(kwargs) signer = JWT( @@ -114,10 +101,11 @@ def __call__( sign_alg=self.alg, ) - if aud is None: + _aud = kwargs.get('aud') + if _aud is None: _aud = self.def_aud else: - _aud = aud if isinstance(aud, list) else [aud] + _aud = _aud if isinstance(_aud, list) else [_aud] _aud.extend(self.def_aud) return signer.pack(payload, aud=_aud) From 5cca27bd1bd223ad004d993b24fd5cad8f2af11e Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 3 Nov 2020 11:00:08 +0100 Subject: [PATCH 010/150] Basic functionality. --- src/oidcendpoint/grant.py | 231 ++++++++++++ src/oidcendpoint/session_management.py | 230 ++++++++++++ tests/test_70_grant.py | 61 ++++ tests/test_71_identity_db.py | 67 ++++ tests/test_72_session_life.py | 464 +++++++++++++++++++++++++ 5 files changed, 1053 insertions(+) create mode 100644 src/oidcendpoint/grant.py create mode 100644 src/oidcendpoint/session_management.py create mode 100644 tests/test_70_grant.py create mode 100644 tests/test_71_identity_db.py create mode 100644 tests/test_72_session_life.py diff --git a/src/oidcendpoint/grant.py b/src/oidcendpoint/grant.py new file mode 100644 index 0000000..83e2f0a --- /dev/null +++ b/src/oidcendpoint/grant.py @@ -0,0 +1,231 @@ +import json +import time +from typing import Optional +from uuid import uuid1 + +from oidcmsg.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS +from oidcmsg.message import OPTIONAL_LIST_OF_STRINGS +from oidcmsg.message import SINGLE_OPTIONAL_JSON +from oidcmsg.message import Message +from oidcmsg.time_util import utc_time_sans_frac + + +class GrantMessage(Message): + c_param = { + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, # As defined in RFC6749 + "authorization_details": SINGLE_OPTIONAL_JSON, # As defined in draft-lodderstedt-oauth-rar + "claims": SINGLE_OPTIONAL_JSON, # As defined in OIDC core + "resources": OPTIONAL_LIST_OF_STRINGS, # As defined in RFC8707 + } + + +GRANT_TYPE_MAP = { + "authorization_code": "code", + "access_token": "access_token", + "refresh_token": "refresh_token" +} + + +def find_token(issued, id): + for iss in issued: + if iss.id == id: + return iss + return None + + +class Item: + def __init__(self, + usage_rules: Optional[dict] = None, + issued_at: int = 0, + expires_at: int = 0, + not_before: int = 0 + ): + self.issued_at = issued_at or utc_time_sans_frac() + self.not_before = not_before + self.expires_at = expires_at + self.revoked = False + self.used = 0 + self.usage_rules = usage_rules or {} + + def max_usage_reached(self): + if "max_usage" in self.usage_rules: + return self.used >= self.usage_rules['max_usage'] + else: + return False + + def is_active(self): + if self.max_usage_reached(): + return False + + if self.revoked: + return False + + if self.not_before: + if time.time() < self.not_before: + return False + + if self.expires_at: + if time.time() > self.expires_at: + return False + + return True + + +class Token(Item): + attributes = ["type", "issued_at", "not_before", "expires_at", "revoked", "value", + "usage_rules", "used", "based_on", "id"] + + def __init__(self, + typ: str = '', + based_on: Optional[str] = None, + usage_rules: Optional[dict] = None, + value: Optional[str] = '', + issued_at: int = 0, + expires_at: int = 0, + not_before: int = 0 + ): + Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, expires_at=expires_at, + not_before=not_before) + + self.type = typ + self.value = value + self.based_on = based_on + self.id = uuid1().hex + + self.set_defaults() + + def set_defaults(self): + pass + + def register_usage(self): + self.used += 1 + + def has_been_used(self): + return self.used != 0 + + def to_json(self): + d = { + "type": self.type, + "issued_at": self.issued_at, + "not_before": self.not_before, + "expires_at": self.expires_at, + "revoked": self.revoked, + "value": self.value, + "usage_rules": self.usage_rules, + "used": self.used, + "based_on": self.based_on, + "id": self.id + } + return json.dumps(d) + + def from_json(self, json_str): + d = json.loads(json_str) + for attr in self.attributes: + if attr in d: + setattr(self, attr, d[attr]) + return self + + def supports_minting(self, token_type): + return token_type in self.usage_rules['supports_minting'] + + +class AuthorizationCode(Token): + def set_defaults(self): + if "supports_minting" not in self.usage_rules: + self.usage_rules['supports_minting'] = ["access_token", "refresh_token"] + + self.usage_rules['max_usage'] = 1 + + +class RefreshToken(Token): + def set_defaults(self): + if "supports_minting" not in self.usage_rules: + self.usage_rules['supports_minting'] = ["access_token", "refresh_token"] + + +TOKEN_MAP = { + "authorization_code": AuthorizationCode, + "access_token": Token, + "refresh_token": RefreshToken +} + + +class Grant(Item): + def __init__(self, + scopes: Optional[list] = None, + claims: Optional[dict] = None, + resources: Optional[list] = None, + authorization_details: Optional[dict] = None, + token: Optional[list] = None, + usage_rules: Optional[dict] = None, + issued_at: int = 0, + expires_at: int = 0): + Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, expires_at=expires_at) + self.scope = scopes or [] + self.authorization_details = authorization_details or None + self.claims = claims or None + self.resources = resources or [] + self.issued_token = token or [] + self.id = uuid1().hex + + def update(self, item: dict): + for attr in ['scope', 'authorization_details', 'claims', 'resources']: + val = item.get(attr) + if val: + setattr(self, attr, val) + + def replace(self, item: dict): + for attr in ['scope', 'authorization_details', 'claims', 'resources']: + setattr(self, attr, item.get(attr)) + + def revoke(self): + self.revoked = True + for t in self.issued_token: + t.revoked = True + + def get(self): + return GrantMessage(scope=self.scope, claims=self.claims, + authorization_details=self.authorization_details, + resources=self.resources) + + def to_json(self): + d = { + "scope": self.scope, + "authorization_details": self.authorization_details, + "claims": self.claims, + "resources": self.resources, + "issued_at": self.issued_at, + "not_before": self.not_before, + "expires_at": self.expires_at, + "revoked": self.revoked, + "issued_token": [t.to_json for t in self.issued_token], + "id": self.id + } + return json.dumps(d) + + def from_json(self, json_str): + d = json.loads(json_str) + for attr in ["scope", "authorization_details", "claims", "resources", "issued_at", + "not_before", + "expires_at", "revoked", "id"]: + if attr in d: + setattr(self, attr, d[attr]) + if "issued_token" in d: + setattr(self, "issued_token", [Token(**t) for t in d['issued_token']]) + + def mint_token(self, token_type, **kwargs): + item = TOKEN_MAP[token_type](typ=token_type, **kwargs) + self.issued_token.append(item) + return item + + def revoke_all_based_on(self, id): + for t in self.issued_token: + if t.based_on == id: + t.revoked = True + self.revoke_all_based_on(t.id) + + def get_token(self, val): + for t in self.issued_token: + if t.value == val: + return t + return None diff --git a/src/oidcendpoint/session_management.py b/src/oidcendpoint/session_management.py new file mode 100644 index 0000000..44979d6 --- /dev/null +++ b/src/oidcendpoint/session_management.py @@ -0,0 +1,230 @@ +import hashlib +import logging + +logger = logging.getLogger(__name__) + + +def db_key(*args): + return ':'.join(args) + + +def unpack_db_key(key): + return key.split(':') + + +def pairwise_id(uid, sector_identifier, salt, **kwargs): + return hashlib.sha256(("%s%s%s" % (uid, sector_identifier, salt)).encode("utf-8")).hexdigest() + + +def public_id(uid, salt="", **kwargs): + return hashlib.sha256("{}{}".format(uid, salt).encode("utf-8")).hexdigest() + + +class Info(object): + def __init__(self, **kwargs): + self._db = kwargs or {} + if "subordinate" not in self._db: + self._db["subordinate"] = [] + + def set(self, key, value): + self._db[key] = value + + def get(self, key): + return self._db[key] + + def update(self, ava): + self._db.update(ava) + return self + + def add_subordinate(self, value): + self._db["subordinate"].append(value) + return self + + def remove_subordinate(self, value): + self._db["subordinate"].remove(value) + return self + + def __setitem__(self, key, value): + self._db[key] = value + + def __getitem__(self, key): + return self._db[key] + + def keys(self): + return self._db.keys() + + def values(self): + return self._db.values() + + def items(self): + return self._db.items() + + +class UserInfo(Info): + pass + + +class ClientInfo(Info): + def find_grant(self, val): + for grant in self._db["subordinate"]: + token = grant.get_token(val) + if token: + return grant, token + + +class Database(object): + def __init__(self, storage=None): + self._db = storage or {} + + def eval_path(self, path): + uid = path[0] + client_id = None + grant_id = None + if len(path) > 1: + client_id = path[1] + if len(path) > 2: + grant_id = path[2] + + return uid, client_id, grant_id + + def set(self, path: list, value: object): + """ + + :param path: a list of identifiers + :param value: Class instance to be stored + """ + # Try loading the key, that's a good place to put a debugger to + # import pdb; pdb.set_trace() + uid, client_id, grant_id = self.eval_path(path) + + _userinfo = self._db.get(uid) + if _userinfo: + if client_id: + if client_id in _userinfo['subordinate']: + _cid_key = db_key(uid, client_id) + _cid_info = self._db[db_key(uid, client_id)] + if _cid_info: + if grant_id: + _gid_key = db_key(uid, client_id, grant_id) + if grant_id in _cid_info['subordinate']: + _gid_info = self._db[_gid_key] + if not _gid_info: + self._db[_cid_key] = _cid_info.add_subordinate(grant_id) + self._db[_gid_key] = value + else: + self._db[_cid_key] = _cid_info.add_subordinate(grant_id) + self._db[_gid_key] = value + else: + self._db.set[_cid_key] = value + else: + _userinfo.add_subordinate(client_id) + if grant_id: + _cid_info = ClientInfo() + _cid_info.add_subordinate(grant_id) + self._db[_cid_key] = _cid_info + self._db[db_key(uid, client_id, grant_id)] = value + else: + _cid_info = ClientInfo() + self._db[_cid_key] = _cid_info + self._db[uid] = _userinfo + else: + _userinfo.add_subordinate(client_id) + self._db[uid] = _userinfo + if grant_id: + _cid_info = ClientInfo() + _cid_info.add_subordinate(grant_id) + self._db[db_key(uid, client_id, grant_id)] = value + else: + _cid_info = value + + _cid_key = db_key(uid, client_id) + self._db[_cid_key] = _cid_info + else: + self._db[uid] = value + else: + if client_id: + _user_info = UserInfo() + _user_info.add_subordinate(client_id) + if grant_id: + _cid_info = ClientInfo() + _cid_info.add_subordinate(grant_id) + self._db[db_key(uid, client_id, grant_id)] = value + else: + _cid_info = value + self._db[db_key(uid, client_id)] = _cid_info + else: + _user_info = value + + self._db[uid] = _user_info + + def get(self, path: list): + uid, client_id, grant_id = self.eval_path(path) + try: + user_info = self._db[uid] + except KeyError: + raise KeyError('No such UserID') + + if client_id is None: + return user_info + else: + if client_id not in user_info['subordinate']: + raise ValueError('No session from that client for that user') + else: + try: + client_session_info = self._db[db_key(uid, client_id)] + except KeyError: + return {} + else: + if grant_id is None: + return client_session_info + + if grant_id not in client_session_info['subordinate']: + raise ValueError('No such grant for that user and client') + else: + try: + return self._db[db_key(uid, client_id, grant_id)] + except KeyError: + return {} + + def delete(self, path): + uid, client_id, grant_id = self.eval_path(path) + try: + _dic = self._db[uid] + except KeyError: + pass + else: + if client_id: + if client_id in _dic['client_id']: + try: + _cinfo = self._db[db_key(uid, client_id)] + except KeyError: + pass + else: + if grant_id: + if grant_id in _cinfo['grant_id']: + self._db.__delitem__(db_key(uid, client_id, grant_id)) + else: + self._db.__delitem__(db_key(uid, client_id)) + else: + pass + else: + self._db.__delitem__(uid) + + +class SessionManager(Database): + def __init__(self, handler, storage=None): + Database.__init__(self, storage) + self.token_handler = handler + + def get_user(self, uid): + user = self.get(uid) + + def find_grant(self, user_id, client_id, token_value): + client_info = self.get([user_id, client_id]) + for grant_id in client_info["subordinate"]: + grant = self.get([user_id, client_id, grant_id]) + for token in grant.issued_token: + if token.value == token_value: + return grant, token + + return None diff --git a/tests/test_70_grant.py b/tests/test_70_grant.py new file mode 100644 index 0000000..4635541 --- /dev/null +++ b/tests/test_70_grant.py @@ -0,0 +1,61 @@ +from oidcendpoint.grant import Grant +from oidcendpoint.grant import Token +from oidcendpoint.grant import find_token + + +def test_access_code(): + token = Token('access_code', value="ABCD") + assert token.issued_at + assert token.type == "access_code" + assert token.value == "ABCD" + + token.register_usage() + # max_usage == 1 + assert token.max_usage_reached() is True + + +def test_access_token(): + code = Token('access_code', value="ABCD") + token = Token('access_token', value="1234", based_on=code.id) + assert token.issued_at + assert token.type == "access_token" + assert token.value == "1234" + + token.register_usage() + # max_usage - undefined + assert token.max_usage_reached() is False + + token.max_usage = 2 + token.register_usage() + assert token.max_usage_reached() is True + + t = find_token([code, token], token.based_on) + assert t.value == "ABCD" + + token.revoked = True + assert token.revoked is True + + +def test_grant(): + grant = Grant() + code = grant.mint_token("authorization_code", value="ABCD") + access_token = grant.mint_token("access_token", value="1234", based_on=code.id) + refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code.id) + grant.revoke() + assert code.revoked is True + assert access_token.revoked is True + assert refresh_token.revoked is True + +def test_grant_revoked_based_on(): + grant = Grant() + code = grant.mint_token("authorization_code", value="ABCD") + access_token = grant.mint_token("access_token", value="1234", based_on=code.id) + refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code.id) + + code.register_usage() + if code.max_usage_reached(): + grant.revoke_all_based_on(code.id) + + assert code.is_active() is False + assert access_token.is_active() is False + assert refresh_token.is_active() is False diff --git a/tests/test_71_identity_db.py b/tests/test_71_identity_db.py new file mode 100644 index 0000000..4f1a15f --- /dev/null +++ b/tests/test_71_identity_db.py @@ -0,0 +1,67 @@ +# Database is organized in 3 layers. User-session-grant. +import pytest + +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.grant import Grant +from oidcendpoint.grant import Token +from oidcendpoint.session_management import ClientInfo +from oidcendpoint.session_management import Database +from oidcendpoint.session_management import UserInfo +from oidcendpoint.session_management import public_id + + +class TestDB: + @pytest.fixture(autouse=True) + def setup_environment(self): + self.db = Database() + + def test_user_info(self): + with pytest.raises(KeyError): + self.db.get(['diana']) + + user_info = UserInfo(foo="bar") + self.db.set(['diana'], user_info) + user_info = self.db.get(['diana']) + assert user_info["foo"] == "bar" + + def test_client_info(self): + user_info = UserInfo(foo="bar") + self.db.set(['diana'], user_info) + client_info = ClientInfo(sid= "abcdef") + self.db.set(['diana', "client_1"], client_info) + + user_info = self.db.get(['diana']) + assert user_info['client_id'] == ['client_1'] + client_info = self.db.get(['diana', "client_1"]) + assert client_info['sid'] == "abcdef" + + def test_jump_ahead(self): + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + + user_info = self.db.get(['diana']) + assert user_info['client_id'] == ['client_1'] + client_info = self.db.get(['diana', "client_1"]) + assert client_info['grant_id'] == ["G1"] + grant_info = self.db.get(['diana', 'client_1', 'G1']) + assert grant_info.issued_at + assert len(grant_info.issued_token) == 1 + token = grant_info.issued_token[0] + assert token.value == '1234567890' + assert token.type == "access_code" + + def test_step_wise(self): + salt = "natriumklorid" + # store user info + self.db.set(['diana'], UserInfo(authn_event = create_authn_event('diana', salt))) + # Client specific information + self.db.set(['diana', 'client_1'], ClientInfo(sub= public_id('diana', salt))) + # Grant + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', 'client_1', 'G1'], grant) diff --git a/tests/test_72_session_life.py b/tests/test_72_session_life.py new file mode 100644 index 0000000..dcfa961 --- /dev/null +++ b/tests/test_72_session_life.py @@ -0,0 +1,464 @@ +import os + +import pytest +from cryptojwt.key_jar import init_key_jar +from oidcmsg.oidc import AccessTokenRequest +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import RefreshAccessTokenRequest +from oidcmsg.time_util import time_sans_frac + +from oidcendpoint import user_info +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.client_authn import verify_client +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.grant import Grant +from oidcendpoint.id_token import IDToken +from oidcendpoint.session_management import ClientInfo +from oidcendpoint.session_management import SessionManager +from oidcendpoint.session_management import UserInfo +from oidcendpoint.session_management import db_key +from oidcendpoint.session_management import public_id +from oidcendpoint.session_management import unpack_db_key +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.session import Session +from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.token_handler import DefaultToken +from oidcendpoint.token_handler import TokenHandler +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD + + +class TestSession(): + @pytest.fixture(autouse=True) + def setup_token_handler(self): + password = "The longer the better. Is this close to enough ?" + grant_expires_in = 600 + token_expires_in = 900 + refresh_token_expires_in = 86400 + + code_handler = DefaultToken(password, typ="A", lifetime=grant_expires_in) + access_token_handler = DefaultToken( + password, typ="T", lifetime=token_expires_in + ) + refresh_token_handler = DefaultToken( + password, typ="R", lifetime=refresh_token_expires_in + ) + + handler = TokenHandler( + code_handler=code_handler, + access_token_handler=access_token_handler, + refresh_token_handler=refresh_token_handler, + ) + + self.session_manager = SessionManager(handler) + + def auth(self): + # Start with an authentication request + # The client ID appears in the request + AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid", "mail", "address", "offline_access"], + state="STATE", + response_type="code", + ) + + # The authentication returns a user ID + user_id = "diana" + + # User info is stored in the Session DB + + user_info = UserInfo() + self.session_manager.set([user_id], user_info) + + # Now for client session information + salt = "natriumklorid" + authn_event = create_authn_event( + user_id, + salt, + authn_info=INTERNETPROTOCOLPASSWORD, + authn_time=time_sans_frac(), + ) + + client_info = ClientInfo( + authorization_request=AUTH_REQ, + authenticationEvent=authn_event, + sub=public_id(user_id, salt) + ) + self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) + + # The user consent module produces a Grant instance + + grant = Grant(scopes=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) + + # the grant is assigned to a session (user_id, client_id) + + self.session_manager.set([user_id, AUTH_REQ['client_id'], grant.id], grant) + + # Constructing an authorization code is now done by + + code = grant.mint_token( + 'authorization_code', + value=self.session_manager.token_handler["code"](user_id), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + + return code + + def test_code_flow(self): + # code is a Token instance + code = self.auth() + + # next step is access token request + + TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", + code=code.value + ) + + # parse the token + user_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) + + # Now given I have the client_id from the request and the user_id from the + # token I can easily find the grant + + # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) + grant, tok = self.session_manager.find_grant(user_id, TOKEN_REQ['client_id'], + TOKEN_REQ['code']) + + # Verify that it's of the correct type and can be used + assert tok.type == "authorization_code" + assert tok.is_active() + + # Mint an access token and a refresh token and mark the code as used + + assert tok.supports_minting("access_token") + + access_token = grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"](user_id), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=tok.id # Means the token (tok) was used to mint this token + ) + + assert tok.supports_minting("refresh_token") + + refresh_token = grant.mint_token( + 'refresh_token', + value=self.session_manager.token_handler["refresh_token"](user_id), + based_on=tok.id + ) + + tok.register_usage() + + assert tok.max_usage_reached() is True + + # A bit later a refresh token is used to mint a new access token + + REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", + client_id="client_1", + client_secret="hemligt", + refresh_token=refresh_token.value, + scope=["openid", "mail", "offline_access"] + ) + + grant, reftok = self.session_manager.find_grant(user_id, + REFRESH_TOKEN_REQ['client_id'], + REFRESH_TOKEN_REQ['refresh_token']) + + assert reftok.supports_minting("access_token") + + access_token_2 = grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"](user_id), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=reftok.id # Means the token (tok) was used to mint this token + ) + + assert access_token_2.is_active() + + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +ISSUER = "https://example.com/" + +KEYJAR = init_key_jar(key_defs=KEYDEFS, issuer_id=ISSUER) +KEYJAR.import_jwks(KEYJAR.export_jwks(True, ISSUER), "") +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] +CAPABILITIES = { + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + ], + "claim_types_supported": ["normal", "aggregated", "distributed"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, +} +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +class TestSessionJWTToken(): + @pytest.fixture(autouse=True) + def setup_session_manager(self): + conf = { + "issuer": ISSUER, + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "token_handler_args": { + "jwks_def": { + "private_path": "private/token_jwks.json", + "read_only": False, + "key_defs": [ + {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}, + {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "refresh"} + ], + }, + "code": {"lifetime": 600}, + "token": { + "class": "oidcendpoint.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims": [ + "email", + "email_verified", + "phone_number", + "phone_number_verified", + ], + "add_claim_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": {}, + }, + "endpoint": { + "provider_config": { + "path": "{}/.well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {}, + }, + "registration": { + "path": "{}/registration", + "class": Registration, + "kwargs": {}, + }, + "authorization": { + "path": "{}/authorization", + "class": Authorization, + "kwargs": {}, + }, + "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, + "session": {"path": "{}/end_session", "class": Session}, + }, + "client_authn": verify_client, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "template_dir": "template", + "userinfo": { + "class": user_info.UserInfo, + "kwargs": {"db_file": full_path("users.json")}, + }, + "id_token": {"class": IDToken}, + } + + self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) + self.session_manager = SessionManager(self.endpoint_context.sdb.handler) + + def auth(self): + # Start with an authentication request + # The client ID appears in the request + AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid", "mail", "address", "offline_access"], + state="STATE", + response_type="code", + ) + + # The authentication returns a user ID + user_id = "diana" + + # User info is stored in the Session DB + + user_info = UserInfo() + self.session_manager.set([user_id], user_info) + + # Now for client session information + salt = "natriumklorid" + authn_event = create_authn_event( + user_id, + salt, + authn_info=INTERNETPROTOCOLPASSWORD, + authn_time=time_sans_frac(), + ) + + client_info = ClientInfo( + authorization_request=AUTH_REQ, + authenticationEvent=authn_event, + sub=public_id(user_id, salt) + ) + self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) + + # The user consent module produces a Grant instance + + grant = Grant(scopes=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) + + # the grant is assigned to a session (user_id, client_id) + + self.session_manager.set([user_id, AUTH_REQ['client_id'], grant.id], grant) + + # Constructing an authorization code is now done by + + code = grant.mint_token( + 'authorization_code', + value=self.session_manager.token_handler["code"]( + db_key(user_id, AUTH_REQ['client_id'])), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + + return code + + def test_code_flow(self): + # code is a Token instance + code = self.auth() + + # next step is access token request + + TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", + code=code.value + ) + + # parse the token + session_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) + user_id, client_id = unpack_db_key(session_id) + + # Now given I have the client_id from the request and the user_id from the + # token I can easily find the grant + + # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) + grant, tok = self.session_manager.find_grant(user_id, TOKEN_REQ['client_id'], + TOKEN_REQ['code']) + + # Verify that it's of the correct type and can be used + assert tok.type == "authorization_code" + assert tok.is_active() + + # Mint an access token and a refresh token and mark the code as used + + assert tok.supports_minting("access_token") + + client_info = self.session_manager.get([user_id, TOKEN_REQ["client_id"]]) + + assert tok.supports_minting("access_token") + + user_claims = self.endpoint_context.userinfo(user_id, client_id=TOKEN_REQ["client_id"], + user_info_claims=grant.claims) + + access_token = grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"]( + db_key(user_id, client_id), + sinfo=client_info, + client_id=TOKEN_REQ['client_id'], + aud=grant.resources, + uinfo=user_claims + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=tok.id # Means the token (tok) was used to mint this token + ) + + assert tok.supports_minting("refresh_token") + + refresh_token = grant.mint_token( + 'refresh_token', + value=self.session_manager.token_handler["refresh_token"](db_key(user_id, client_id)), + based_on=tok.id + ) + + tok.register_usage() + + assert tok.max_usage_reached() is True + + # A bit later a refresh token is used to mint a new access token + + REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", + client_id="client_1", + client_secret="hemligt", + refresh_token=refresh_token.value, + scope=["openid", "mail", "offline_access"] + ) + + grant, reftok = self.session_manager.find_grant(user_id, + REFRESH_TOKEN_REQ['client_id'], + REFRESH_TOKEN_REQ['refresh_token']) + + # Can I use this token to mint another token ? + assert reftok.supports_minting("access_token") + assert reftok.is_active() + assert grant.is_active() + + user_claims = self.endpoint_context.userinfo(user_id, client_id=TOKEN_REQ["client_id"], + user_info_claims=grant.claims) + + access_token_2 = grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"]( + db_key(user_id, client_id), + sinfo=client_info, + client_id=TOKEN_REQ['client_id'], + aud=grant.resources, + uinfo=user_claims + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=reftok.id # Means the refresh token (reftok) was used to mint this token + ) + + assert access_token_2.is_active() From acbe78004c3951ce9bf3284c2a8008d38ca3a8f5 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Tue, 3 Nov 2020 11:30:00 +0100 Subject: [PATCH 011/150] Fixed tests. --- src/oidcendpoint/__init__.py | 2 +- src/oidcendpoint/jwt_token.py | 4 +- tests/test_70_grant.py | 13 ++- ...identity_db.py => test_71_sess_mngm_db.py} | 6 +- tests/test_72_session_life.py | 103 +++++++++--------- 5 files changed, 65 insertions(+), 63 deletions(-) rename tests/{test_71_identity_db.py => test_71_sess_mngm_db.py} (93%) diff --git a/src/oidcendpoint/__init__.py b/src/oidcendpoint/__init__.py index ef93b0a..2bb7941 100755 --- a/src/oidcendpoint/__init__.py +++ b/src/oidcendpoint/__init__.py @@ -1,7 +1,7 @@ import string from secrets import choice -__version__ = "1.1.2" +__version__ = '2.0.0' DEF_SIGN_ALG = { "id_token": "RS256", diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index 28b39b1..a33d765 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -71,7 +71,7 @@ def __call__( :return: Signed JSON Web Token """ - payload = {"sid": sid, "ttype": self.type, "sub": kwargs["sinfo"]["sub"]} + payload = {"sid": sid, "ttype": self.type, "sub": kwargs['sub']} _user_claims = kwargs.get('user_claims') _client_id = kwargs.get('client_id') @@ -93,7 +93,7 @@ def __call__( if client_claims: self.do_add_claims(payload, _user_claims, client_claims) - payload.update(kwargs) + # payload.update(kwargs) signer = JWT( key_jar=self.key_jar, iss=self.issuer, diff --git a/tests/test_70_grant.py b/tests/test_70_grant.py index 4635541..e3bce23 100644 --- a/tests/test_70_grant.py +++ b/tests/test_70_grant.py @@ -1,12 +1,13 @@ +from oidcendpoint.grant import AuthorizationCode +from oidcendpoint.grant import find_token from oidcendpoint.grant import Grant from oidcendpoint.grant import Token -from oidcendpoint.grant import find_token def test_access_code(): - token = Token('access_code', value="ABCD") + token = AuthorizationCode('authorization_code', value="ABCD") assert token.issued_at - assert token.type == "access_code" + assert token.type == "authorization_code" assert token.value == "ABCD" token.register_usage() @@ -15,8 +16,8 @@ def test_access_code(): def test_access_token(): - code = Token('access_code', value="ABCD") - token = Token('access_token', value="1234", based_on=code.id) + code = AuthorizationCode('authorization_code', value="ABCD") + token = Token('access_token', value="1234", based_on=code.id, usage_rules={"max_usage": 2}) assert token.issued_at assert token.type == "access_token" assert token.value == "1234" @@ -25,7 +26,6 @@ def test_access_token(): # max_usage - undefined assert token.max_usage_reached() is False - token.max_usage = 2 token.register_usage() assert token.max_usage_reached() is True @@ -46,6 +46,7 @@ def test_grant(): assert access_token.revoked is True assert refresh_token.revoked is True + def test_grant_revoked_based_on(): grant = Grant() code = grant.mint_token("authorization_code", value="ABCD") diff --git a/tests/test_71_identity_db.py b/tests/test_71_sess_mngm_db.py similarity index 93% rename from tests/test_71_identity_db.py rename to tests/test_71_sess_mngm_db.py index 4f1a15f..be93cb0 100644 --- a/tests/test_71_identity_db.py +++ b/tests/test_71_sess_mngm_db.py @@ -31,7 +31,7 @@ def test_client_info(self): self.db.set(['diana', "client_1"], client_info) user_info = self.db.get(['diana']) - assert user_info['client_id'] == ['client_1'] + assert user_info['subordinate'] == ['client_1'] client_info = self.db.get(['diana', "client_1"]) assert client_info['sid'] == "abcdef" @@ -43,9 +43,9 @@ def test_jump_ahead(self): self.db.set(['diana', "client_1", "G1"], grant) user_info = self.db.get(['diana']) - assert user_info['client_id'] == ['client_1'] + assert user_info['subordinate'] == ['client_1'] client_info = self.db.get(['diana', "client_1"]) - assert client_info['grant_id'] == ["G1"] + assert client_info['subordinate'] == ["G1"] grant_info = self.db.get(['diana', 'client_1', 'G1']) assert grant_info.issued_at assert len(grant_info.issued_token) == 1 diff --git a/tests/test_72_session_life.py b/tests/test_72_session_life.py index dcfa961..42f6ebf 100644 --- a/tests/test_72_session_life.py +++ b/tests/test_72_session_life.py @@ -13,17 +13,17 @@ from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.grant import Grant from oidcendpoint.id_token import IDToken -from oidcendpoint.session_management import ClientInfo -from oidcendpoint.session_management import SessionManager -from oidcendpoint.session_management import UserInfo -from oidcendpoint.session_management import db_key -from oidcendpoint.session_management import public_id -from oidcendpoint.session_management import unpack_db_key from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session_management import ClientInfo +from oidcendpoint.session_management import db_key +from oidcendpoint.session_management import public_id +from oidcendpoint.session_management import SessionManager +from oidcendpoint.session_management import unpack_db_key +from oidcendpoint.session_management import UserInfo from oidcendpoint.token_handler import DefaultToken from oidcendpoint.token_handler import TokenHandler from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -40,16 +40,16 @@ def setup_token_handler(self): code_handler = DefaultToken(password, typ="A", lifetime=grant_expires_in) access_token_handler = DefaultToken( password, typ="T", lifetime=token_expires_in - ) + ) refresh_token_handler = DefaultToken( password, typ="R", lifetime=refresh_token_expires_in - ) + ) handler = TokenHandler( code_handler=code_handler, access_token_handler=access_token_handler, refresh_token_handler=refresh_token_handler, - ) + ) self.session_manager = SessionManager(handler) @@ -62,7 +62,7 @@ def auth(self): scope=["openid", "mail", "address", "offline_access"], state="STATE", response_type="code", - ) + ) # The authentication returns a user ID user_id = "diana" @@ -79,13 +79,13 @@ def auth(self): salt, authn_info=INTERNETPROTOCOLPASSWORD, authn_time=time_sans_frac(), - ) + ) client_info = ClientInfo( authorization_request=AUTH_REQ, authenticationEvent=authn_event, sub=public_id(user_id, salt) - ) + ) self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) # The user consent module produces a Grant instance @@ -102,7 +102,7 @@ def auth(self): 'authorization_code', value=self.session_manager.token_handler["code"](user_id), expires_at=time_sans_frac() + 300 # 5 minutes from now - ) + ) return code @@ -119,7 +119,7 @@ def test_code_flow(self): grant_type="authorization_code", client_secret="hemligt", code=code.value - ) + ) # parse the token user_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) @@ -144,7 +144,7 @@ def test_code_flow(self): value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=tok.id # Means the token (tok) was used to mint this token - ) + ) assert tok.supports_minting("refresh_token") @@ -152,7 +152,7 @@ def test_code_flow(self): 'refresh_token', value=self.session_manager.token_handler["refresh_token"](user_id), based_on=tok.id - ) + ) tok.register_usage() @@ -166,7 +166,7 @@ def test_code_flow(self): client_secret="hemligt", refresh_token=refresh_token.value, scope=["openid", "mail", "offline_access"] - ) + ) grant, reftok = self.session_manager.find_grant(user_id, REFRESH_TOKEN_REQ['client_id'], @@ -179,7 +179,7 @@ def test_code_flow(self): value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=reftok.id # Means the token (tok) was used to mint this token - ) + ) assert access_token_2.is_active() @@ -187,7 +187,7 @@ def test_code_flow(self): KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] + ] ISSUER = "https://example.com/" @@ -202,7 +202,7 @@ def test_code_flow(self): ["id_token", "token"], ["code", "token", "id_token"], ["none"], -] + ] CAPABILITIES = { "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "token_endpoint_auth_methods_supported": [ @@ -210,19 +210,19 @@ def test_code_flow(self): "client_secret_basic", "client_secret_jwt", "private_key_jwt", - ], + ], "response_modes_supported": ["query", "fragment", "form_post"], "subject_types_supported": ["public", "pairwise"], "grant_types_supported": [ "authorization_code", "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", - ], + ], "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, "request_uri_parameter_supported": True, -} + } BASEDIR = os.path.abspath(os.path.dirname(__file__)) @@ -249,8 +249,8 @@ def setup_session_manager(self): "key_defs": [ {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}, {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "refresh"} - ], - }, + ], + }, "code": {"lifetime": 600}, "token": { "class": "oidcendpoint.jwt_token.JWTToken", @@ -261,47 +261,47 @@ def setup_session_manager(self): "email_verified", "phone_number", "phone_number_verified", - ], + ], "add_claim_by_scope": True, "aud": ["https://example.org/appl"], + }, }, - }, "refresh": {}, - }, + }, "endpoint": { "provider_config": { "path": "{}/.well-known/openid-configuration", "class": ProviderConfiguration, "kwargs": {}, - }, + }, "registration": { "path": "{}/registration", "class": Registration, "kwargs": {}, - }, + }, "authorization": { "path": "{}/authorization", "class": Authorization, "kwargs": {}, - }, + }, "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, "session": {"path": "{}/end_session", "class": Session}, - }, + }, "client_authn": verify_client, "authentication": { "anon": { "acr": INTERNETPROTOCOLPASSWORD, "class": "oidcendpoint.user_authn.user.NoAuthn", "kwargs": {"user": "diana"}, - } - }, + } + }, "template_dir": "template", "userinfo": { "class": user_info.UserInfo, "kwargs": {"db_file": full_path("users.json")}, - }, + }, "id_token": {"class": IDToken}, - } + } self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) self.session_manager = SessionManager(self.endpoint_context.sdb.handler) @@ -315,7 +315,7 @@ def auth(self): scope=["openid", "mail", "address", "offline_access"], state="STATE", response_type="code", - ) + ) # The authentication returns a user ID user_id = "diana" @@ -332,13 +332,13 @@ def auth(self): salt, authn_info=INTERNETPROTOCOLPASSWORD, authn_time=time_sans_frac(), - ) + ) client_info = ClientInfo( authorization_request=AUTH_REQ, authenticationEvent=authn_event, sub=public_id(user_id, salt) - ) + ) self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) # The user consent module produces a Grant instance @@ -356,7 +356,7 @@ def auth(self): value=self.session_manager.token_handler["code"]( db_key(user_id, AUTH_REQ['client_id'])), expires_at=time_sans_frac() + 300 # 5 minutes from now - ) + ) return code @@ -373,7 +373,7 @@ def test_code_flow(self): grant_type="authorization_code", client_secret="hemligt", code=code.value - ) + ) # parse the token session_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) @@ -405,14 +405,15 @@ def test_code_flow(self): 'access_token', value=self.session_manager.token_handler["access_token"]( db_key(user_id, client_id), - sinfo=client_info, client_id=TOKEN_REQ['client_id'], aud=grant.resources, - uinfo=user_claims - ), + user_claims=user_claims, + scope=grant.scope, + sub=client_info['sub'] + ), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=tok.id # Means the token (tok) was used to mint this token - ) + ) assert tok.supports_minting("refresh_token") @@ -420,7 +421,7 @@ def test_code_flow(self): 'refresh_token', value=self.session_manager.token_handler["refresh_token"](db_key(user_id, client_id)), based_on=tok.id - ) + ) tok.register_usage() @@ -434,7 +435,7 @@ def test_code_flow(self): client_secret="hemligt", refresh_token=refresh_token.value, scope=["openid", "mail", "offline_access"] - ) + ) grant, reftok = self.session_manager.find_grant(user_id, REFRESH_TOKEN_REQ['client_id'], @@ -452,13 +453,13 @@ def test_code_flow(self): 'access_token', value=self.session_manager.token_handler["access_token"]( db_key(user_id, client_id), - sinfo=client_info, + sub=client_info['sub'], client_id=TOKEN_REQ['client_id'], aud=grant.resources, - uinfo=user_claims - ), + user_claims=user_claims + ), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=reftok.id # Means the refresh token (reftok) was used to mint this token - ) + ) assert access_token_2.is_active() From 35e667e299502446563239f6e23eb940cc8aa85f Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 6 Nov 2020 14:20:34 +0100 Subject: [PATCH 012/150] Refactoring and adding more usecases. --- src/oidcendpoint/grant.py | 43 +++++-- src/oidcendpoint/session_management.py | 158 ++++++++++++++++++++++--- src/oidcendpoint/user_authn/user.py | 3 +- src/oidcendpoint/userinfo.py | 2 +- tests/test_70_grant.py | 12 +- tests/test_71_sess_mngm_db.py | 26 ++-- tests/test_72_session_life.py | 65 +++++----- 7 files changed, 233 insertions(+), 76 deletions(-) diff --git a/src/oidcendpoint/grant.py b/src/oidcendpoint/grant.py index 83e2f0a..9a20c83 100644 --- a/src/oidcendpoint/grant.py +++ b/src/oidcendpoint/grant.py @@ -3,13 +3,17 @@ from typing import Optional from uuid import uuid1 +from oidcmsg.message import Message from oidcmsg.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS from oidcmsg.message import OPTIONAL_LIST_OF_STRINGS from oidcmsg.message import SINGLE_OPTIONAL_JSON -from oidcmsg.message import Message from oidcmsg.time_util import utc_time_sans_frac +class MintingNotAllowed(Exception): + pass + + class GrantMessage(Message): c_param = { "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, # As defined in RFC6749 @@ -126,7 +130,11 @@ def from_json(self, json_str): return self def supports_minting(self, token_type): - return token_type in self.usage_rules['supports_minting'] + _supports_minting = self.usage_rules.get("supports_minting") + if _supports_minting is None: + return False + else: + return token_type in _supports_minting class AuthorizationCode(Token): @@ -152,8 +160,8 @@ def set_defaults(self): class Grant(Item): def __init__(self, - scopes: Optional[list] = None, - claims: Optional[dict] = None, + scope: Optional[list] = None, + claim: Optional[dict] = None, resources: Optional[list] = None, authorization_details: Optional[dict] = None, token: Optional[list] = None, @@ -161,9 +169,9 @@ def __init__(self, issued_at: int = 0, expires_at: int = 0): Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, expires_at=expires_at) - self.scope = scopes or [] + self.scope = scope or [] self.authorization_details = authorization_details or None - self.claims = claims or None + self.claims = claim or None self.resources = resources or [] self.issued_token = token or [] self.id = uuid1().hex @@ -178,7 +186,7 @@ def replace(self, item: dict): for attr in ['scope', 'authorization_details', 'claims', 'resources']: setattr(self, attr, item.get(attr)) - def revoke(self): + def revoke_all(self): self.revoked = True for t in self.issued_token: t.revoked = True @@ -206,16 +214,24 @@ def to_json(self): def from_json(self, json_str): d = json.loads(json_str) for attr in ["scope", "authorization_details", "claims", "resources", "issued_at", - "not_before", - "expires_at", "revoked", "id"]: + "not_before", "expires_at", "revoked", "id"]: if attr in d: setattr(self, attr, d[attr]) if "issued_token" in d: setattr(self, "issued_token", [Token(**t) for t in d['issued_token']]) - def mint_token(self, token_type, **kwargs): - item = TOKEN_MAP[token_type](typ=token_type, **kwargs) + def mint_token(self, token_type: str, based_on: Optional[Token] = None, **kwargs) -> Token: + if based_on: + if based_on.supports_minting(token_type) and based_on.is_active(): + _base_on_ref = based_on.value + else: + raise MintingNotAllowed() + else: + _base_on_ref = None + + item = TOKEN_MAP[token_type](typ=token_type, based_on=_base_on_ref, **kwargs) self.issued_token.append(item) + return item def revoke_all_based_on(self, id): @@ -229,3 +245,8 @@ def get_token(self, val): if t.value == val: return t return None + + def revoke_token(self, val): + for t in self.issued_token: + if t.value == val: + t.revoked = True diff --git a/src/oidcendpoint/session_management.py b/src/oidcendpoint/session_management.py index 44979d6..cc65f2f 100644 --- a/src/oidcendpoint/session_management.py +++ b/src/oidcendpoint/session_management.py @@ -1,9 +1,15 @@ import hashlib import logging +from oidcendpoint import rndstr + logger = logging.getLogger(__name__) +class Revoked(Exception): + pass + + def db_key(*args): return ':'.join(args) @@ -12,7 +18,7 @@ def unpack_db_key(key): return key.split(':') -def pairwise_id(uid, sector_identifier, salt, **kwargs): +def pairwise_id(uid, sector_identifier, salt="", **kwargs): return hashlib.sha256(("%s%s%s" % (uid, sector_identifier, salt)).encode("utf-8")).hexdigest() @@ -20,9 +26,10 @@ def public_id(uid, salt="", **kwargs): return hashlib.sha256("{}{}".format(uid, salt).encode("utf-8")).hexdigest() -class Info(object): +class SessionInfo(object): def __init__(self, **kwargs): self._db = kwargs or {} + self._revoked = False if "subordinate" not in self._db: self._db["subordinate"] = [] @@ -59,12 +66,21 @@ def values(self): def items(self): return self._db.items() + def __contains__(self, item): + return item in self._db + + def revoke(self): + self._revoked = True -class UserInfo(Info): + def is_revoked(self): + return self._revoked + + +class UserSessionInfo(SessionInfo): pass -class ClientInfo(Info): +class ClientSessionInfo(SessionInfo): def find_grant(self, val): for grant in self._db["subordinate"]: token = grant.get_token(val) @@ -103,6 +119,8 @@ def set(self, path: list, value: object): if client_id in _userinfo['subordinate']: _cid_key = db_key(uid, client_id) _cid_info = self._db[db_key(uid, client_id)] + if _cid_info.is_revoked(): + raise Revoked("Session is revoked") if _cid_info: if grant_id: _gid_key = db_key(uid, client_id, grant_id) @@ -115,23 +133,23 @@ def set(self, path: list, value: object): self._db[_cid_key] = _cid_info.add_subordinate(grant_id) self._db[_gid_key] = value else: - self._db.set[_cid_key] = value + self._db[_cid_key] = value else: _userinfo.add_subordinate(client_id) if grant_id: - _cid_info = ClientInfo() + _cid_info = ClientSessionInfo() _cid_info.add_subordinate(grant_id) self._db[_cid_key] = _cid_info self._db[db_key(uid, client_id, grant_id)] = value else: - _cid_info = ClientInfo() + _cid_info = ClientSessionInfo() self._db[_cid_key] = _cid_info self._db[uid] = _userinfo else: _userinfo.add_subordinate(client_id) self._db[uid] = _userinfo if grant_id: - _cid_info = ClientInfo() + _cid_info = ClientSessionInfo() _cid_info.add_subordinate(grant_id) self._db[db_key(uid, client_id, grant_id)] = value else: @@ -143,10 +161,10 @@ def set(self, path: list, value: object): self._db[uid] = value else: if client_id: - _user_info = UserInfo() + _user_info = UserSessionInfo() _user_info.add_subordinate(client_id) if grant_id: - _cid_info = ClientInfo() + _cid_info = ClientSessionInfo() _cid_info.add_subordinate(grant_id) self._db[db_key(uid, client_id, grant_id)] = value else: @@ -175,6 +193,9 @@ def get(self, path: list): except KeyError: return {} else: + if client_session_info.is_revoked(): + raise Revoked("Session is revoked") + if grant_id is None: return client_session_info @@ -194,32 +215,56 @@ def delete(self, path): pass else: if client_id: - if client_id in _dic['client_id']: + if client_id in _dic['subordinate']: try: _cinfo = self._db[db_key(uid, client_id)] except KeyError: pass else: if grant_id: - if grant_id in _cinfo['grant_id']: + if grant_id in _cinfo['subordinate']: self._db.__delitem__(db_key(uid, client_id, grant_id)) else: + for grant_id in _cinfo['subordinate']: + self._db.__delitem__(db_key(uid, client_id, grant_id)) self._db.__delitem__(db_key(uid, client_id)) + + _dic["subordinate"].remove(client_id) + self._db[uid] = _dic else: pass else: self._db.__delitem__(uid) + def update(self, path, new_info): + _info = self.get(path) + _info.update(new_info) + self.set(path, _info) + class SessionManager(Database): - def __init__(self, handler, storage=None): - Database.__init__(self, storage) + def __init__(self, db, handler, userinfo=None, sub_func=None): + Database.__init__(self, db) self.token_handler = handler + self.userinfo = userinfo + self.salt = rndstr(32) + + # this allows the subject identifier minters to be defined by someone + # else then me. + if sub_func is None: + self.sub_func = {"public": public_id, "pairwise": pairwise_id} + else: + self.sub_func = sub_func + if "public" not in sub_func: + self.sub_func["public"] = public_id + if "pairwise" not in sub_func: + self.sub_func["pairwise"] = pairwise_id - def get_user(self, uid): - user = self.get(uid) + def get_user_info(self, uid): + return self.get(uid) - def find_grant(self, user_id, client_id, token_value): + def find_grant(self, session_id, token_value): + user_id, client_id = unpack_db_key(session_id) client_info = self.get([user_id, client_id]) for grant_id in client_info["subordinate"]: grant = self.get([user_id, client_id, grant_id]) @@ -228,3 +273,82 @@ def find_grant(self, user_id, client_id, token_value): return grant, token return None + + def create_session(self, authn_event, auth_req, user_id, client_id="", + sub_type="public", sector_identifier='', **kwargs): + """ + + :param authn_event: + :param auth_req: Authorization Request + :param client_id: Client ID + :param user_id: User ID + :param kwargs: extra keyword arguments + :return: + """ + + try: + _ = self.get([user_id]) + except KeyError: + user_info = UserSessionInfo(authentication_event=authn_event) + self.set([user_id], user_info) + + client_info = ClientSessionInfo( + authorization_request=auth_req, + sub=self.sub_func[sub_type](user_id, salt=self.salt, + sector_identifier=sector_identifier) + ) + + if not client_id: + client_id = auth_req['client_id'] + + self.set([user_id, client_id], client_info) + + def _update_client_info(self, session_id, new_information): + _path = unpack_db_key(session_id) + _client_info = self.get(_path) + _client_info.update(new_information) + self.set(_path, _client_info) + + def do_sub(self, session_id, sector_id="", subject_type="public"): + """ + Create and store a subject identifier + + :param session_id: Session ID + :param sector_id: For pairwise identifiers, an Identifier for the RP group + :param subject_type: 'pairwise'/'public' + :return: + """ + _path = unpack_db_key(session_id) + sub = self.sub_func[subject_type](_path[0], salt=self.salt, sector_identifier=sector_id) + self._update_client_info(session_id, {'sub': sub}) + return sub + + def __getitem__(self, item): + return self.get(unpack_db_key(item)) + + def revoke_token(self, session_id, token_value): + grant, token = self.find_grant(session_id, token_value) + token.revoked = True + + def get_sids_by_user_id(self, user_id): + user_info = self.get([user_id]) + return [db_key(user_id, c) for c in user_info['subordinate']] + + def get_authentication_event(self, user_id): + try: + user_info = self.get([user_id]) + except KeyError: + return None + + return user_info["authentication_event"] + + def revoke_session(self, session_id): + _path = unpack_db_key(session_id) + _info = self.get(_path) + _info.revoke() + self.set(_path, _info) + + def grants(self, session_id): + uid, cid = unpack_db_key(session_id) + _csi = self.get([uid, cid]) + return [self.get([uid, cid, gid]) for gid in _csi['subordinate']] diff --git a/src/oidcendpoint/user_authn/user.py b/src/oidcendpoint/user_authn/user.py index e197ef6..dce3116 100755 --- a/src/oidcendpoint/user_authn/user.py +++ b/src/oidcendpoint/user_authn/user.py @@ -4,13 +4,14 @@ import logging import sys import time -import warnings from urllib.parse import unquote +import warnings from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptojwt.jwt import JWT from oidcendpoint import sanitize +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.exception import FailedAuthentication from oidcendpoint.exception import ImproperlyConfigured from oidcendpoint.exception import InvalidCookieSign diff --git a/src/oidcendpoint/userinfo.py b/src/oidcendpoint/userinfo.py index bd2fa01..33b7160 100755 --- a/src/oidcendpoint/userinfo.py +++ b/src/oidcendpoint/userinfo.py @@ -124,7 +124,7 @@ def collect_user_info( :param userinfo_claims: user info claims :return: User info """ - authn_req = session["authn_req"] + authn_req = session["authorization_request"] if scope_to_claims is None: scope_to_claims = endpoint_context.scope2claims diff --git a/tests/test_70_grant.py b/tests/test_70_grant.py index e3bce23..1d75043 100644 --- a/tests/test_70_grant.py +++ b/tests/test_70_grant.py @@ -39,9 +39,9 @@ def test_access_token(): def test_grant(): grant = Grant() code = grant.mint_token("authorization_code", value="ABCD") - access_token = grant.mint_token("access_token", value="1234", based_on=code.id) - refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code.id) - grant.revoke() + access_token = grant.mint_token("access_token", value="1234", based_on=code) + refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code) + grant.revoke_all() assert code.revoked is True assert access_token.revoked is True assert refresh_token.revoked is True @@ -50,12 +50,12 @@ def test_grant(): def test_grant_revoked_based_on(): grant = Grant() code = grant.mint_token("authorization_code", value="ABCD") - access_token = grant.mint_token("access_token", value="1234", based_on=code.id) - refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code.id) + access_token = grant.mint_token("access_token", value="1234", based_on=code) + refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code) code.register_usage() if code.max_usage_reached(): - grant.revoke_all_based_on(code.id) + grant.revoke_all_based_on(code.value) assert code.is_active() is False assert access_token.is_active() is False diff --git a/tests/test_71_sess_mngm_db.py b/tests/test_71_sess_mngm_db.py index be93cb0..8f88d06 100644 --- a/tests/test_71_sess_mngm_db.py +++ b/tests/test_71_sess_mngm_db.py @@ -4,9 +4,9 @@ from oidcendpoint.authn_event import create_authn_event from oidcendpoint.grant import Grant from oidcendpoint.grant import Token -from oidcendpoint.session_management import ClientInfo +from oidcendpoint.session_management import ClientSessionInfo from oidcendpoint.session_management import Database -from oidcendpoint.session_management import UserInfo +from oidcendpoint.session_management import UserSessionInfo from oidcendpoint.session_management import public_id @@ -19,15 +19,15 @@ def test_user_info(self): with pytest.raises(KeyError): self.db.get(['diana']) - user_info = UserInfo(foo="bar") + user_info = UserSessionInfo(foo="bar") self.db.set(['diana'], user_info) user_info = self.db.get(['diana']) assert user_info["foo"] == "bar" def test_client_info(self): - user_info = UserInfo(foo="bar") + user_info = UserSessionInfo(foo="bar") self.db.set(['diana'], user_info) - client_info = ClientInfo(sid= "abcdef") + client_info = ClientSessionInfo(sid="abcdef") self.db.set(['diana', "client_1"], client_info) user_info = self.db.get(['diana']) @@ -56,12 +56,24 @@ def test_jump_ahead(self): def test_step_wise(self): salt = "natriumklorid" # store user info - self.db.set(['diana'], UserInfo(authn_event = create_authn_event('diana', salt))) + self.db.set(['diana'], + UserSessionInfo(authentication_event=create_authn_event('diana', salt))) # Client specific information - self.db.set(['diana', 'client_1'], ClientInfo(sub= public_id('diana', salt))) + self.db.set(['diana', 'client_1'], ClientSessionInfo(sub=public_id( + 'diana', salt))) # Grant grant = Grant() access_code = Token('access_code', value='1234567890') grant.issued_token.append(access_code) self.db.set(['diana', 'client_1', 'G1'], grant) + + def test_removed(self): + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + self.db.delete(['diana', 'client_1']) + with pytest.raises(ValueError): + self.db.get(['diana', "client_1", "G1"]) diff --git a/tests/test_72_session_life.py b/tests/test_72_session_life.py index 42f6ebf..d93cf4e 100644 --- a/tests/test_72_session_life.py +++ b/tests/test_72_session_life.py @@ -18,12 +18,12 @@ from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session_management import ClientInfo +from oidcendpoint.session_management import ClientSessionInfo from oidcendpoint.session_management import db_key from oidcendpoint.session_management import public_id from oidcendpoint.session_management import SessionManager from oidcendpoint.session_management import unpack_db_key -from oidcendpoint.session_management import UserInfo +from oidcendpoint.session_management import UserSessionInfo from oidcendpoint.token_handler import DefaultToken from oidcendpoint.token_handler import TokenHandler from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -51,7 +51,7 @@ def setup_token_handler(self): refresh_token_handler=refresh_token_handler, ) - self.session_manager = SessionManager(handler) + self.session_manager = SessionManager({}, handler=handler) def auth(self): # Start with an authentication request @@ -68,29 +68,27 @@ def auth(self): user_id = "diana" # User info is stored in the Session DB - - user_info = UserInfo() - self.session_manager.set([user_id], user_info) - - # Now for client session information - salt = "natriumklorid" authn_event = create_authn_event( user_id, - salt, + self.session_manager.salt, authn_info=INTERNETPROTOCOLPASSWORD, authn_time=time_sans_frac(), ) - client_info = ClientInfo( + user_info = UserSessionInfo(authenticationEvent=authn_event) + self.session_manager.set([user_id], user_info) + + # Now for client session information + + client_info = ClientSessionInfo( authorization_request=AUTH_REQ, - authenticationEvent=authn_event, - sub=public_id(user_id, salt) + sub=public_id(user_id, self.session_manager.salt) ) self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) # The user consent module produces a Grant instance - grant = Grant(scopes=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) + grant = Grant(scope=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) # the grant is assigned to a session (user_id, client_id) @@ -128,7 +126,8 @@ def test_code_flow(self): # token I can easily find the grant # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) - grant, tok = self.session_manager.find_grant(user_id, TOKEN_REQ['client_id'], + session_id = db_key(user_id, TOKEN_REQ['client_id']) + grant, tok = self.session_manager.find_grant(session_id, TOKEN_REQ['code']) # Verify that it's of the correct type and can be used @@ -143,7 +142,7 @@ def test_code_flow(self): 'access_token', value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now - based_on=tok.id # Means the token (tok) was used to mint this token + based_on=tok # Means the token (tok) was used to mint this token ) assert tok.supports_minting("refresh_token") @@ -151,7 +150,7 @@ def test_code_flow(self): refresh_token = grant.mint_token( 'refresh_token', value=self.session_manager.token_handler["refresh_token"](user_id), - based_on=tok.id + based_on=tok ) tok.register_usage() @@ -168,8 +167,8 @@ def test_code_flow(self): scope=["openid", "mail", "offline_access"] ) - grant, reftok = self.session_manager.find_grant(user_id, - REFRESH_TOKEN_REQ['client_id'], + session_id = db_key(user_id,REFRESH_TOKEN_REQ['client_id']) + grant, reftok = self.session_manager.find_grant(session_id, REFRESH_TOKEN_REQ['refresh_token']) assert reftok.supports_minting("access_token") @@ -178,7 +177,7 @@ def test_code_flow(self): 'access_token', value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now - based_on=reftok.id # Means the token (tok) was used to mint this token + based_on=reftok # Means the token (tok) was used to mint this token ) assert access_token_2.is_active() @@ -304,7 +303,7 @@ def setup_session_manager(self): } self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) - self.session_manager = SessionManager(self.endpoint_context.sdb.handler) + self.session_manager = SessionManager({}, handler=self.endpoint_context.sdb.handler) def auth(self): # Start with an authentication request @@ -322,7 +321,7 @@ def auth(self): # User info is stored in the Session DB - user_info = UserInfo() + user_info = UserSessionInfo() self.session_manager.set([user_id], user_info) # Now for client session information @@ -334,7 +333,7 @@ def auth(self): authn_time=time_sans_frac(), ) - client_info = ClientInfo( + client_info = ClientSessionInfo( authorization_request=AUTH_REQ, authenticationEvent=authn_event, sub=public_id(user_id, salt) @@ -343,7 +342,7 @@ def auth(self): # The user consent module produces a Grant instance - grant = Grant(scopes=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) + grant = Grant(scope=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) # the grant is assigned to a session (user_id, client_id) @@ -383,7 +382,8 @@ def test_code_flow(self): # token I can easily find the grant # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) - grant, tok = self.session_manager.find_grant(user_id, TOKEN_REQ['client_id'], + session_id = db_key(user_id, TOKEN_REQ['client_id']) + grant, tok = self.session_manager.find_grant(session_id, TOKEN_REQ['code']) # Verify that it's of the correct type and can be used @@ -412,15 +412,16 @@ def test_code_flow(self): sub=client_info['sub'] ), expires_at=time_sans_frac() + 900, # 15 minutes from now - based_on=tok.id # Means the token (tok) was used to mint this token + based_on=tok # Means the token (tok) was used to mint this token ) - assert tok.supports_minting("refresh_token") + # this test is include in the mint_token methods + # assert tok.supports_minting("refresh_token") refresh_token = grant.mint_token( 'refresh_token', value=self.session_manager.token_handler["refresh_token"](db_key(user_id, client_id)), - based_on=tok.id + based_on=tok ) tok.register_usage() @@ -437,13 +438,11 @@ def test_code_flow(self): scope=["openid", "mail", "offline_access"] ) - grant, reftok = self.session_manager.find_grant(user_id, - REFRESH_TOKEN_REQ['client_id'], + session_id = db_key(user_id, REFRESH_TOKEN_REQ['client_id']) + grant, reftok = self.session_manager.find_grant(session_id, REFRESH_TOKEN_REQ['refresh_token']) # Can I use this token to mint another token ? - assert reftok.supports_minting("access_token") - assert reftok.is_active() assert grant.is_active() user_claims = self.endpoint_context.userinfo(user_id, client_id=TOKEN_REQ["client_id"], @@ -459,7 +458,7 @@ def test_code_flow(self): user_claims=user_claims ), expires_at=time_sans_frac() + 900, # 15 minutes from now - based_on=reftok.id # Means the refresh token (reftok) was used to mint this token + based_on=reftok # Means the refresh token (reftok) was used to mint this token ) assert access_token_2.is_active() From cbd4d32c6e83e933f9af3ef2fc638529a37685ec Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Fri, 6 Nov 2020 14:23:24 +0100 Subject: [PATCH 013/150] Deal with commonality. --- src/oidcendpoint/jwt_token.py | 124 ++++++++++++++++++++-------------- 1 file changed, 75 insertions(+), 49 deletions(-) diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index a33d765..4cc494c 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -1,38 +1,91 @@ +from typing import Optional + from cryptojwt import JWT from cryptojwt.jws.exception import JWSException from oidcendpoint.exception import ToOld from oidcendpoint.scopes import convert_scopes2claims +from oidcendpoint.token_handler import is_expired from oidcendpoint.token_handler import Token from oidcendpoint.token_handler import UnknownToken -from oidcendpoint.token_handler import is_expired -class JWTToken(Token): +class ClaimsInterface: init_args = { "add_claims_by_scope": False, - "enable_claims_per_client": False, - "add_claims": {}, - } + "enable_claims_per_client": False + } + def __init__(self, endpoint_context, **kwargs): + self.endpoint_context = endpoint_context + self.scope_claims_map = kwargs.get("scope_claims_map", endpoint_context.scope2claims) + self.add_claims_by_scope = kwargs.get("add_claims_by_scope", + self.init_args["add_claims_by_scope"]) + self.enable_claims_per_client = kwargs.get("enable_claims_per_client", + self.init_args["enable_claims_per_client"]) + + def _get_client_claims(self, client_id): + if self.enable_claims_per_client: + client_info = self.endpoint_context.cdb.get(client_id, {}) + return client_info.get("introspection_claims") + else: + return [] + + def _get_user_info(self, token_info): + user_id = self.endpoint_context.sdb.sso_db.get_uid_by_sid(token_info["sid"]) + return self.endpoint_context.userinfo(user_id, client_id=None) + + def add_claims(self, client_id, user_id, payload, scopes, claims_restriction): + if claims_restriction is None: + user_info = self.endpoint_context.userinfo(user_id, client_id=None) + payload.update(user_info) + elif claims_restriction == {}: # Nothing is allowed + pass + else: + possible_claims = self._get_client_claims(client_id) + if self.add_claims_by_scope: + _claims = convert_scopes2claims(scopes, map=self.scope_claims_map).keys() + possible_claims = list(set(possible_claims).union(_claims)) + + if possible_claims: + _claims = {c: None for c in + set(possible_claims).intersection(set(claims_restriction.key()))} + _claims.update(claims_restriction) + else: + _claims = claims_restriction + + if _claims: + user_info = self.endpoint_context.userinfo(user_id, client_id=None, + user_info_claims=_claims) + for attr in _claims: + try: + payload[attr] = user_info[attr] + except KeyError: + pass + + +class JWTToken(Token): def __init__( self, typ, keyjar=None, - issuer=None, - aud=None, - alg="ES256", - lifetime=300, + issuer: str = None, + aud: Optional[list] = None, + alg: str = "ES256", + lifetime: int = 300, ec=None, - token_type="Bearer", + token_type: str = "Bearer", + add_claims: bool = False, **kwargs - ): + ): Token.__init__(self, typ, **kwargs) self.token_type = token_type self.lifetime = lifetime + self.claims_interface = ClaimsInterface(ec, **kwargs) + self.args = { - (k, v) for k, v in kwargs.items() if k not in self.init_args.keys() - } + (k, v) for k, v in kwargs.items() if k not in self.claims_interface.init_args.keys() + } self.key_jar = keyjar or ec.keyjar self.issuer = issuer or ec.issuer @@ -41,29 +94,13 @@ def __init__( self.def_aud = aud or [] self.alg = alg - self.scope_claims_map = kwargs.get("scope_claims_map", ec.scope2claims) - - self.add_claims = self.init_args["add_claims"] - self.add_claims_by_scope = self.init_args["add_claims_by_scope"] - self.enable_claims_per_client = self.init_args["enable_claims_per_client"] - - for param, default in self.init_args.items(): - setattr(self, param, kwargs.get(param, default)) - - def do_add_claims(self, payload, uinfo, claims): - for attr in claims: - if attr == "sub": - continue - try: - payload[attr] = uinfo[attr] - except KeyError: - pass + self.add_claims = add_claims def __call__( self, sid: str, **kwargs - ): + ): """ Return a token. @@ -73,25 +110,14 @@ def __call__( payload = {"sid": sid, "ttype": self.type, "sub": kwargs['sub']} - _user_claims = kwargs.get('user_claims') - _client_id = kwargs.get('client_id') _scopes = kwargs.get('scope') + _client_id = kwargs.get('client_id') + _user_id = kwargs.get('user_id') + _claims = kwargs.get('claims') if self.add_claims: - self.do_add_claims(payload, _user_claims, self.add_claims) - if self.add_claims_by_scope: - _allowed_claims = self.cntx.claims_handler.allowed_claims(_client_id, self.cntx) - self.do_add_claims( - payload, - _user_claims, - convert_scopes2claims(_scopes, _allowed_claims, map=self.scope_claims_map).keys(), - ) - # Add claims if is access token - if self.type == "T" and self.enable_claims_per_client: - client = self.cdb.get(_client_id, {}) - client_claims = client.get("access_token_claims") - if client_claims: - self.do_add_claims(payload, _user_claims, client_claims) + self.claims_interface.add_claims(_client_id, _user_id, payload, claims=_claims, + scopes=_scopes) # payload.update(kwargs) signer = JWT( @@ -99,7 +125,7 @@ def __call__( iss=self.issuer, lifetime=self.lifetime, sign_alg=self.alg, - ) + ) _aud = kwargs.get('aud') if _aud is None: @@ -132,7 +158,7 @@ def info(self, token): "type": _payload["ttype"], "exp": _payload["exp"], "handler": self, - } + } return _res def is_expired(self, token, when=0): From 4c09c3a90182359a038e55136a59e382ef11e61b Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 10 Nov 2020 15:25:38 +0100 Subject: [PATCH 014/150] Getting closer --- src/oidcendpoint/authz/__init__.py | 46 +- src/oidcendpoint/authz/old.init.py | 78 ++ src/oidcendpoint/common/authorization.py | 52 +- src/oidcendpoint/endpoint_context.py | 22 +- src/oidcendpoint/grant.py | 2 +- src/oidcendpoint/id_token.py | 149 ++- src/oidcendpoint/jwt_token.py | 77 +- src/oidcendpoint/oauth2/authorization.py | 115 ++- src/oidcendpoint/oauth2/introspection.py | 23 +- src/oidcendpoint/oauth2/old_introspection.py | 117 +++ src/oidcendpoint/oidc/add_on/pkce.py | 9 +- src/oidcendpoint/oidc/authorization.py | 227 +++-- src/oidcendpoint/oidc/old_authorization.py | 740 ++++++++++++++ src/oidcendpoint/oidc/refresh_token.py | 37 +- src/oidcendpoint/oidc/session.py | 75 +- src/oidcendpoint/oidc/token.py | 130 ++- src/oidcendpoint/oidc/userinfo.py | 34 +- src/oidcendpoint/old_id_token.py | 288 ++++++ .../{session.py => old_session.py} | 0 src/oidcendpoint/scopes.py | 20 +- src/oidcendpoint/session_management.py | 106 +- src/oidcendpoint/user_authn/user.py | 1 - src/oidcendpoint/userinfo.py | 219 ++-- src/oidcendpoint/util.py | 3 +- tests/{test_70_grant.py => test_01_grant.py} | 0 ...ess_mngm_db.py => test_01_sess_mngm_db.py} | 0 ...ession_life.py => test_01_session_life.py} | 137 +-- tests/test_03_id_token.py | 356 ++++--- tests/test_05_sso_db.py | 135 --- tests/test_07_userinfo.py | 251 ++--- tests/test_08_session.py | 520 ---------- ...oidc_authz.py => test_10_oidc_authz.py.no} | 0 .../test_24_oauth2_authorization_endpoint.py | 113 ++- ...st_24_oauth2_authorization_endpoint_jar.py | 2 + tests/test_24_oidc_authorization_endpoint.py | 37 +- .../test_24_oidc_authorization_endpoint.py.no | 933 ++++++++++++++++++ tests/test_25_oidc_token_endpoint.py | 118 ++- tests/test_25_oidc_token_endpoint.py.no | 255 +++++ tests/test_26_oidc_userinfo_endpoint.py | 285 +++--- tests/test_26_oidc_userinfo_endpoint.py.no | 353 +++++++ tests/test_30_oidc_end_session.py | 60 +- tests/test_30_oidc_end_session.py.no | 573 +++++++++++ tests/users.json | 3 - 43 files changed, 4736 insertions(+), 1965 deletions(-) create mode 100755 src/oidcendpoint/authz/old.init.py create mode 100644 src/oidcendpoint/oauth2/old_introspection.py create mode 100755 src/oidcendpoint/oidc/old_authorization.py create mode 100755 src/oidcendpoint/old_id_token.py rename src/oidcendpoint/{session.py => old_session.py} (100%) rename tests/{test_70_grant.py => test_01_grant.py} (100%) rename tests/{test_71_sess_mngm_db.py => test_01_sess_mngm_db.py} (100%) rename tests/{test_72_session_life.py => test_01_session_life.py} (89%) delete mode 100644 tests/test_05_sso_db.py delete mode 100644 tests/test_08_session.py rename tests/{test_10_oidc_authz.py => test_10_oidc_authz.py.no} (100%) create mode 100755 tests/test_24_oidc_authorization_endpoint.py.no create mode 100755 tests/test_25_oidc_token_endpoint.py.no create mode 100755 tests/test_26_oidc_userinfo_endpoint.py.no create mode 100644 tests/test_30_oidc_end_session.py.no diff --git a/src/oidcendpoint/authz/__init__.py b/src/oidcendpoint/authz/__init__.py index 3869cc2..e25eb25 100755 --- a/src/oidcendpoint/authz/__init__.py +++ b/src/oidcendpoint/authz/__init__.py @@ -2,8 +2,7 @@ import logging import sys -from oidcendpoint import sanitize -from oidcendpoint.cookie import cookie_value +from oidcendpoint.grant import Grant logger = logging.getLogger(__name__) @@ -14,50 +13,19 @@ class AuthzHandling(object): def __init__(self, endpoint_context, **kwargs): self.endpoint_context = endpoint_context self.cookie_dealer = endpoint_context.cookie_dealer - self.permdb = {} self.kwargs = kwargs - def __call__(self, *args, **kwargs): - return "" - - def set(self, uid, client_id, permission): - try: - self.permdb[uid][client_id] = permission - except KeyError: - self.permdb[uid] = {client_id: permission} - - def permissions(self, cookie=None, **kwargs): - if cookie is None: - return None - else: - logger.debug("kwargs: %s" % sanitize(kwargs)) - - val = self.cookie_dealer.get_cookie_value(cookie) - if val is None: - return None - else: - b64, _ts, typ = val - - info = cookie_value(b64) - return self.get(info["sub"], info["client_id"]) - - def get(self, uid, client_id): - try: - return self.permdb[uid][client_id] - except KeyError: - return None + def __call__(self, user_id, client_id, request): + permission = {k: v for k, v in request.items() if k in ["scope", "claims"]} + return Grant(**permission) class Implicit(AuthzHandling): - def __init__(self, endpoint_context, permission="implicit"): + def __init__(self, endpoint_context): AuthzHandling.__init__(self, endpoint_context) - self.permission = permission - - def permissions(self, cookie=None, **kwargs): - return self.permission - def get(self, uid, client_id): - return self.permission + def __call__(self, user_id, client_id, request): + return Grant() def factory(msgtype, endpoint_context, **kwargs): diff --git a/src/oidcendpoint/authz/old.init.py b/src/oidcendpoint/authz/old.init.py new file mode 100755 index 0000000..3869cc2 --- /dev/null +++ b/src/oidcendpoint/authz/old.init.py @@ -0,0 +1,78 @@ +import inspect +import logging +import sys + +from oidcendpoint import sanitize +from oidcendpoint.cookie import cookie_value + +logger = logging.getLogger(__name__) + + +class AuthzHandling(object): + """ Class that allow an entity to manage authorization """ + + def __init__(self, endpoint_context, **kwargs): + self.endpoint_context = endpoint_context + self.cookie_dealer = endpoint_context.cookie_dealer + self.permdb = {} + self.kwargs = kwargs + + def __call__(self, *args, **kwargs): + return "" + + def set(self, uid, client_id, permission): + try: + self.permdb[uid][client_id] = permission + except KeyError: + self.permdb[uid] = {client_id: permission} + + def permissions(self, cookie=None, **kwargs): + if cookie is None: + return None + else: + logger.debug("kwargs: %s" % sanitize(kwargs)) + + val = self.cookie_dealer.get_cookie_value(cookie) + if val is None: + return None + else: + b64, _ts, typ = val + + info = cookie_value(b64) + return self.get(info["sub"], info["client_id"]) + + def get(self, uid, client_id): + try: + return self.permdb[uid][client_id] + except KeyError: + return None + + +class Implicit(AuthzHandling): + def __init__(self, endpoint_context, permission="implicit"): + AuthzHandling.__init__(self, endpoint_context) + self.permission = permission + + def permissions(self, cookie=None, **kwargs): + return self.permission + + def get(self, uid, client_id): + return self.permission + + +def factory(msgtype, endpoint_context, **kwargs): + """ + Factory method that can be used to easily instantiate a class instance + + :param msgtype: The name of the class + :param kwargs: Keyword arguments + :return: An instance of the class or None if the name doesn't match any + known class. + """ + for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj) and issubclass(obj, AuthzHandling): + try: + if obj.__name__ == msgtype: + return obj(endpoint_context, **kwargs) + except AttributeError: + pass diff --git a/src/oidcendpoint/common/authorization.py b/src/oidcendpoint/common/authorization.py index 2529b72..6564869 100755 --- a/src/oidcendpoint/common/authorization.py +++ b/src/oidcendpoint/common/authorization.py @@ -3,6 +3,9 @@ from urllib.parse import urlencode from urllib.parse import urlparse +from oidcmsg.time_util import time_sans_frac + +from oidcendpoint.session_management import unpack_db_key from oidcmsg.exception import ParameterError from oidcmsg.exception import URIError from oidcmsg.message import Message @@ -225,7 +228,8 @@ def create_authn_response(endpoint, request, sid): fragment_enc = False else: _context = endpoint.endpoint_context - _sinfo = _context.sdb[sid] + _mngr = endpoint.endpoint_context.session_manager + _session_info = _mngr[sid] if request.get("scope"): aresp["scope"] = request["scope"] @@ -234,27 +238,41 @@ def create_authn_response(endpoint, request, sid): handled_response_type = [] fragment_enc = True + if len(rtype) == 1 and "code" in rtype: fragment_enc = False - if "code" in request["response_type"]: - _code = aresp["code"] = _context.sdb[sid]["code"] - handled_response_type.append("code") - else: - _context.sdb.update(sid, code=None) - _code = None - - if "token" in rtype: - _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) + grant = _mngr.grants(sid)[0] + user_id, client_id = unpack_db_key(sid) - logger.debug("_dic: %s" % sanitize(_dic)) - for key, val in _dic.items(): - if key in aresp.parameters() and val is not None: - aresp[key] = val - - handled_response_type.append("token") + if "code" in request["response_type"]: + _code = grant.mint_token( + 'authorization_code', + value=_mngr.token_handler["code"](user_id), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + aresp["code"] = _code.value + handled_response_type.append("code") + else: + _code = None + + if "token" in rtype: + _access_token = grant.mint_token( + "access_token", + value=_mngr.token_handler["access_token"]( + sid, + client_id=client_id, + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_session_info['sub'], + based_on=_code + ), + expires_at=time_sans_frac() + 900 + ) - _access_token = aresp.get("access_token", None) + aresp['token'] = _access_token + handled_response_type.append("token") not_handled = rtype.difference(handled_response_type) if not_handled: diff --git a/src/oidcendpoint/endpoint_context.py b/src/oidcendpoint/endpoint_context.py index 10b6430..591bffd 100755 --- a/src/oidcendpoint/endpoint_context.py +++ b/src/oidcendpoint/endpoint_context.py @@ -14,8 +14,7 @@ from oidcendpoint.scopes import STANDARD_CLAIMS from oidcendpoint.scopes import Claims from oidcendpoint.scopes import Scopes -from oidcendpoint.session import create_session_db -from oidcendpoint.sso_db import SSODb +from oidcendpoint.session_management import create_session_manager from oidcendpoint.template_handler import Jinja2TemplateHandler from oidcendpoint.user_authn.authn_context import populate_authn_broker from oidcendpoint.util import allow_refresh_token @@ -100,21 +99,15 @@ def __init__( self.conf = conf # For my Dev environment - self.sso_db = None - self.session_db = None - self.state_db = None self.cdb = None self.jti_db = None self.registration_access_token = None self.add_boxes( { - "state": "state_db", "client": "cdb", "jti": "jti_db", "registration_access_token": "registration_access_token", - "sso": "sso_db", - "session": "session_db", }, self.db_conf, ) @@ -279,10 +272,9 @@ def set_claims_handler(self): self.claims_handler = Claims() def set_session_db(self): - self.do_session_db(SSODb(db=self.sso_db), self.session_db) + self.do_session_manager() # append userinfo db to the session db self.do_userinfo() - logger.debug("Session DB: {}".format(self.sdb.__dict__)) def do_add_on(self): if self.conf.get("add_on"): @@ -311,9 +303,9 @@ def do_login_hint_lookup(self): def do_userinfo(self): _conf = self.conf.get("userinfo") if _conf: - if self.sdb: + if self.session_manager: self.userinfo = init_user_info(_conf, self.cwd) - self.sdb.userinfo = self.userinfo + self.session_manager.userinfo = self.userinfo else: logger.warning("Cannot init_user_info if no session_db was provided.") @@ -357,9 +349,9 @@ def do_sub_func(self): else: self._sub_func[key] = args["function"] - def do_session_db(self, sso_db, db=None): - self.sdb = create_session_db( - self, self.th_args, db=db, sso_db=sso_db, sub_func=self._sub_func + def do_session_manager(self, db=None): + self.session_manager = create_session_manager( + self, self.th_args, db=db, sub_func=self._sub_func ) def do_endpoints(self): diff --git a/src/oidcendpoint/grant.py b/src/oidcendpoint/grant.py index 9a20c83..8f85ef5 100644 --- a/src/oidcendpoint/grant.py +++ b/src/oidcendpoint/grant.py @@ -171,7 +171,7 @@ def __init__(self, Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, expires_at=expires_at) self.scope = scope or [] self.authorization_details = authorization_details or None - self.claims = claim or None + self.claims = claim or {} # default is to not release any user information self.resources = resources or [] self.issued_token = token or [] self.id = uuid1().hex diff --git a/src/oidcendpoint/id_token.py b/src/oidcendpoint/id_token.py index a9747bb..984a19d 100755 --- a/src/oidcendpoint/id_token.py +++ b/src/oidcendpoint/id_token.py @@ -1,11 +1,17 @@ import logging +import uuid from cryptojwt.jws.utils import left_hash from cryptojwt.jwt import JWT +from oidcendpoint.session_management import unpack_db_key +from oidcendpoint.session_management import SessionInfo + +from oidcendpoint import rndstr from oidcendpoint.endpoint import construct_endpoint_info -from oidcendpoint.userinfo import collect_user_info -from oidcendpoint.userinfo import userinfo_in_id_token_claims +from oidcendpoint.grant import Item +from oidcendpoint.session_management import db_key +from oidcendpoint.userinfo import ClaimsInterface logger = logging.getLogger(__name__) @@ -50,7 +56,7 @@ def include_session_id(endpoint_context, client_id, where): def get_sign_and_encrypt_algorithms( - endpoint_context, client_info, payload_type, sign=False, encrypt=False + endpoint_context, client_info, payload_type, sign=False, encrypt=False ): args = {"sign": sign, "encrypt": encrypt} if sign: @@ -112,23 +118,19 @@ class IDToken(object): def __init__(self, endpoint_context, **kwargs): self.endpoint_context = endpoint_context self.kwargs = kwargs - self.enable_claims_per_client = kwargs.get("enable_claims_per_client", False) self.scope_to_claims = None self.provider_info = construct_endpoint_info( self.default_capabilities, **kwargs ) + self.claims_interface = ClaimsInterface(endpoint_context, "id_token", **kwargs) def payload( - self, - session, - acr="", - alg="RS256", - code=None, - access_token=None, - user_info=None, - auth_time=0, - lifetime=None, - extra_claims=None, + self, + session_id, + alg="RS256", + code=None, + access_token=None, + extra_claims=None, ): """ @@ -144,15 +146,18 @@ def payload( :return: IDToken instance """ - _args = {"sub": session["sub"]} - - if lifetime is None: - lifetime = DEF_LIFETIME + _mngr = self.endpoint_context.session_manager + session_information = _mngr.get_session_info(session_id) + _args = {"sub": session_information["client_session_info"]["sub"]} + for claim, attr in {"authn_time": "auth_time", "acr": "acr"}.items(): + _val = session_information["user_session_info"]["authentication_event"].get(claim) + if _val: + _args[attr] = _val - if auth_time: - _args["auth_time"] = auth_time - if acr: - _args["acr"] = acr + grant = _mngr.grants(session_id)[0] + _claims_restriction = grant.claims.get(self.claims_interface.usage) + user_info = self.claims_interface.get_user_claims(user_id=session_information["user_id"], + claims_restriction=_claims_restriction) if user_info: try: @@ -179,35 +184,34 @@ def payload( if access_token: _args["at_hash"] = left_hash(access_token.encode("utf-8"), halg) - authn_req = session["authn_req"] + authn_req = session_information["client_session_info"]["authorization_request"] if authn_req: try: _args["nonce"] = authn_req["nonce"] except KeyError: pass - return {"payload": _args, "lifetime": lifetime} + return _args def sign_encrypt( - self, - session_info, - client_id, - code=None, - access_token=None, - user_info=None, - sign=True, - encrypt=False, - lifetime=None, - extra_claims=None, + self, + session_id, + client_id, + code=None, + access_token=None, + sign=True, + encrypt=False, + lifetime=None, + extra_claims=None, ): """ Signed and or encrypt a IDToken + :param lifetime: How long the ID Token should be valid :param session_info: Session information :param client_id: Client ID :param code: Access grant :param access_token: Access Token - :param user_info: User information :param sign: If the JWT should be signed :param encrypt: If the JWT should be encrypted :param extra_claims: Extra claims to be added to the ID Token @@ -221,69 +225,52 @@ def sign_encrypt( _cntx, client_info, "id_token", sign=sign, encrypt=encrypt ) - _authn_event = session_info["authn_event"] - - _idt_info = self.payload( - session_info, - acr=_authn_event["authn_info"], + _payload = self.payload( + session_id=session_id, alg=alg_dict["sign_alg"], code=code, access_token=access_token, - user_info=user_info, - auth_time=_authn_event["authn_time"], - lifetime=lifetime, - extra_claims=extra_claims, + extra_claims=extra_claims ) + if lifetime is None: + lifetime = DEF_LIFETIME + _jwt = JWT( - _cntx.keyjar, iss=_cntx.issuer, lifetime=_idt_info["lifetime"], **alg_dict + _cntx.keyjar, iss=_cntx.issuer, lifetime=lifetime, **alg_dict ) - return _jwt.pack(_idt_info["payload"], recv=client_id) + return _jwt.pack(_payload, recv=client_id) - def make(self, req, sess_info, authn_req=None, user_claims=False, **kwargs): + def make(self, session_id, **kwargs): _context = self.endpoint_context - if authn_req: - _client_id = authn_req["client_id"] + user_id, client_id, grant_id = unpack_db_key(session_id) + + # Should I add session ID. This is about Single Logout. + if include_session_id(_context, client_id, "back") or include_session_id( + _context, client_id, "front"): + + # Note that this session ID is not the session ID the session manager is using. + # It must be possible to map from one to the other. + logout_session_id = uuid.uuid4().get_hex() + _item = SessionInfo() + _item.set("user_id", user_id) + _item.set("client_id", client_id) + # Store the map + _mngr = self.endpoint_context.session_manager + _mngr.set([logout_session_id], _item) + # add identifier to extra arguments + xargs = {"sid": logout_session_id} else: - _client_id = req["client_id"] - - _cinfo = _context.cdb[_client_id] + xargs = {} - idtoken_claims = dict(self.kwargs.get("available_claims", {})) - if self.enable_claims_per_client: - idtoken_claims.update(_cinfo.get("id_token_claims", {})) lifetime = self.kwargs.get("lifetime") - userinfo = userinfo_in_id_token_claims(_context, sess_info, idtoken_claims) - - if user_claims: - info = collect_user_info(_context, sess_info) - if userinfo is None: - userinfo = info - else: - userinfo.update(info) - - # Should I add session ID - req_sid = include_session_id( - _context, _client_id, "back" - ) or include_session_id(_context, _client_id, "front") - - if req_sid: - xargs = { - "sid": _context.sdb.get_sid_by_sub_and_client_id( - sess_info["sub"], _client_id - ) - } - else: - xargs = {} - return self.sign_encrypt( - sess_info, - _client_id, + session_id, + client_id, sign=True, - user_info=userinfo, lifetime=lifetime, extra_claims=xargs, **kwargs diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index 4cc494c..25464b3 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -4,64 +4,11 @@ from cryptojwt.jws.exception import JWSException from oidcendpoint.exception import ToOld -from oidcendpoint.scopes import convert_scopes2claims -from oidcendpoint.token_handler import is_expired +from oidcendpoint.session_management import db_key from oidcendpoint.token_handler import Token from oidcendpoint.token_handler import UnknownToken - - -class ClaimsInterface: - init_args = { - "add_claims_by_scope": False, - "enable_claims_per_client": False - } - - def __init__(self, endpoint_context, **kwargs): - self.endpoint_context = endpoint_context - self.scope_claims_map = kwargs.get("scope_claims_map", endpoint_context.scope2claims) - self.add_claims_by_scope = kwargs.get("add_claims_by_scope", - self.init_args["add_claims_by_scope"]) - self.enable_claims_per_client = kwargs.get("enable_claims_per_client", - self.init_args["enable_claims_per_client"]) - - def _get_client_claims(self, client_id): - if self.enable_claims_per_client: - client_info = self.endpoint_context.cdb.get(client_id, {}) - return client_info.get("introspection_claims") - else: - return [] - - def _get_user_info(self, token_info): - user_id = self.endpoint_context.sdb.sso_db.get_uid_by_sid(token_info["sid"]) - return self.endpoint_context.userinfo(user_id, client_id=None) - - def add_claims(self, client_id, user_id, payload, scopes, claims_restriction): - if claims_restriction is None: - user_info = self.endpoint_context.userinfo(user_id, client_id=None) - payload.update(user_info) - elif claims_restriction == {}: # Nothing is allowed - pass - else: - possible_claims = self._get_client_claims(client_id) - if self.add_claims_by_scope: - _claims = convert_scopes2claims(scopes, map=self.scope_claims_map).keys() - possible_claims = list(set(possible_claims).union(_claims)) - - if possible_claims: - _claims = {c: None for c in - set(possible_claims).intersection(set(claims_restriction.key()))} - _claims.update(claims_restriction) - else: - _claims = claims_restriction - - if _claims: - user_info = self.endpoint_context.userinfo(user_id, client_id=None, - user_info_claims=_claims) - for attr in _claims: - try: - payload[attr] = user_info[attr] - except KeyError: - pass +from oidcendpoint.token_handler import is_expired +from oidcendpoint.userinfo import ClaimsInterface class JWTToken(Token): @@ -77,15 +24,15 @@ def __init__( token_type: str = "Bearer", add_claims: bool = False, **kwargs - ): + ): Token.__init__(self, typ, **kwargs) self.token_type = token_type self.lifetime = lifetime - self.claims_interface = ClaimsInterface(ec, **kwargs) + self.claims_interface = ClaimsInterface(ec, "jwt_token", **kwargs) self.args = { (k, v) for k, v in kwargs.items() if k not in self.claims_interface.init_args.keys() - } + } self.key_jar = keyjar or ec.keyjar self.issuer = issuer or ec.issuer @@ -100,7 +47,7 @@ def __call__( self, sid: str, **kwargs - ): + ): """ Return a token. @@ -116,8 +63,10 @@ def __call__( _claims = kwargs.get('claims') if self.add_claims: - self.claims_interface.add_claims(_client_id, _user_id, payload, claims=_claims, - scopes=_scopes) + grant = self.claims_interface.endpoint_context.session_manager.grants(sid)[0] + user_info = self.claims_interface.get_user_claims( + _user_id, grant.claims.get("id_token", {})) + payload.update(user_info) # payload.update(kwargs) signer = JWT( @@ -125,7 +74,7 @@ def __call__( iss=self.issuer, lifetime=self.lifetime, sign_alg=self.alg, - ) + ) _aud = kwargs.get('aud') if _aud is None: @@ -158,7 +107,7 @@ def info(self, token): "type": _payload["ttype"], "exp": _payload["exp"], "handler": self, - } + } return _res def is_expired(self, token, when=0): diff --git a/src/oidcendpoint/oauth2/authorization.py b/src/oidcendpoint/oauth2/authorization.py index c4f6c23..162ad03 100755 --- a/src/oidcendpoint/oauth2/authorization.py +++ b/src/oidcendpoint/oauth2/authorization.py @@ -7,6 +7,14 @@ from cryptojwt.utils import as_unicode from cryptojwt.utils import b64d from cryptojwt.utils import b64e +from oidcendpoint.session_management import Revoked + +from oidcendpoint.session_management import db_key + +from oidcendpoint.session_management import ClientSessionInfo +from oidcmsg.time_util import time_sans_frac + +from oidcendpoint.session_management import unpack_db_key from oidcmsg import oauth2 from oidcmsg.exception import ParameterError from oidcmsg.oidc import AuthorizationResponse @@ -33,7 +41,6 @@ from oidcendpoint.exception import ToOld from oidcendpoint.exception import UnAuthorizedClientScope from oidcendpoint.exception import UnknownClient -from oidcendpoint.session import setup_session from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import pick_auth @@ -292,9 +299,13 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): else: identity = json.loads(as_unicode(_id)) - session = self.endpoint_context.sdb[identity.get("sid")] - if not session or "revoked" in session: + try: + _csi = self.endpoint_context.session_manager[identity.get("sid")] + except Revoked: identity = None + else: + if _csi.is_active() is False: + identity = None authn_args = authn_args_gather(request, authn_class_ref, cinfo, **kwargs) @@ -316,16 +327,14 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): # demand re-authentication return {"function": authn, "args": authn_args} else: + _mngr = self.endpoint_context.session_manager # I get back a dictionary user = identity["uid"] if "req_user" in kwargs: - sids = self.endpoint_context.sdb.get_sids_by_sub(kwargs["req_user"]) + sids = _mngr.get_sids_by_user_id(kwargs["req_user"]) if ( - sids - and user - != self.endpoint_context.sdb.get_authentication_event( - sids[-1] - ).uid + sids + and user != _mngr.get_authentication_event(sids[-1]).uid ): logger.debug("Wanted to be someone else!") if "prompt" in request and "none" in request["prompt"]: @@ -365,8 +374,9 @@ def create_authn_response(self, request, sid): if "response_type" in request and request["response_type"] == ["none"]: fragment_enc = False else: + _mngr = self.endpoint_context.session_manager _context = self.endpoint_context - _sinfo = _context.sdb[sid] + _session_info = _mngr.get_session_info(sid) if request.get("scope"): aresp["scope"] = request["scope"] @@ -378,25 +388,37 @@ def create_authn_response(self, request, sid): if len(rtype) == 1 and "code" in rtype: fragment_enc = False + grant = _mngr[sid] + if "code" in request["response_type"]: - _code = aresp["code"] = _context.sdb[sid]["code"] + _code = grant.mint_token( + 'authorization_code', + value=_mngr.token_handler["code"](_session_info["user_id"]), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + aresp["code"] = _code.value handled_response_type.append("code") else: - _context.sdb.update(sid, code=None) _code = None if "token" in rtype: - _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) - - logger.debug("_dic: %s" % sanitize(_dic)) - for key, val in _dic.items(): - if key in aresp.parameters() and val is not None: - aresp[key] = val + _access_token = grant.mint_token( + "access_token", + value=_mngr.token_handler["access_token"]( + sid, + client_id=_session_info["client_id"], + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_session_info["client_session_info"]['sub'], + based_on=_code + ), + expires_at=time_sans_frac() + 900 + ) + aresp['token'] = _access_token handled_response_type.append("token") - _access_token = aresp.get("access_token", None) - not_handled = rtype.difference(handled_response_type) if not_handled: resp = self.error_cls( @@ -446,13 +468,13 @@ def error_response(self, response_info, error, error_description): response_info["response_args"] = resp return response_info - def post_authentication(self, user, request, sid, **kwargs): + def post_authentication(self, user, request, pre_sid, **kwargs): """ Things that are done after a successful authentication. :param user: :param request: - :param sid: + :param pre_sid: :param kwargs: :return: A dictionary with 'response_args' """ @@ -461,8 +483,8 @@ def post_authentication(self, user, request, sid, **kwargs): # Do the authorization try: - permission = self.endpoint_context.authz( - user, client_id=request["client_id"] + grant = self.endpoint_context.authz( + user, client_id=request["client_id"], request=request ) except ToOld as err: return self.error_response( @@ -475,8 +497,10 @@ def post_authentication(self, user, request, sid, **kwargs): response_info, "access_denied", "{}".format(err.args) ) else: + session_id = db_key(user, request["client_id"], grant.id) try: - self.endpoint_context.sdb.update(sid, permission=permission) + self.endpoint_context.session_manager.set([user, request["client_id"], + grant.id], grant) except Exception as err: return self.error_response( response_info, "server_error", "{}".format(err.args) @@ -484,12 +508,7 @@ def post_authentication(self, user, request, sid, **kwargs): logger.debug("response type: %s" % request["response_type"]) - if self.endpoint_context.sdb.is_session_revoked(sid): - return self.error_response( - response_info, "access_denied", "Session is revoked" - ) - - response_info = self.create_authn_response(request, sid) + response_info = self.create_authn_response(request, session_id) try: redirect_uri = get_uri(self.endpoint_context, request, "redirect_uri") @@ -507,10 +526,8 @@ def post_authentication(self, user, request, sid, **kwargs): _cookie = new_cookie( self.endpoint_context, - sub=user, - sid=sid, + sid=session_id, state=request["state"], - client_id=request["client_id"], cookie_name=self.endpoint_context.cookie_name["session"], ) @@ -528,7 +545,22 @@ def post_authentication(self, user, request, sid, **kwargs): response_info["cookie"] = [_cookie] - return response_info + return response_info, session_id + + def setup_client_session(self, user_id: str, request: dict) -> str: + _mngr = self.endpoint_context.session_manager + client_id = request['client_id'] + + _client_info = self.endpoint_context.cdb[client_id] + sub_type = _client_info.get("subject_type") + + client_info = ClientSessionInfo( + authorization_request=request, + sub=_mngr.sub_func['public'](user_id, salt=_mngr.salt) + ) + + _mngr.set([user_id, client_id], client_info) + return db_key(user_id, client_id) def authz_part2(self, user, authn_event, request, **kwargs): """ @@ -540,22 +572,19 @@ def authz_part2(self, user, authn_event, request, **kwargs): :param kwargs: possible other parameters :return: A redirect to the redirect_uri of the client """ - sid = setup_session( - self.endpoint_context, request, user, authn_event=authn_event - ) + pre_sid = self.setup_client_session(user, request) try: - resp_info = self.post_authentication(user, request, sid, **kwargs) + resp_info, session_id = self.post_authentication(user, request, pre_sid, **kwargs) except Exception as err: return self.error_response({}, "server_error", err) if "check_session_iframe" in self.endpoint_context.provider_info: ec = self.endpoint_context salt = rndstr() - if not ec.sdb.is_session_revoked(sid): - authn_event = ec.sdb.get_authentication_event( - sid - ) # use the last session + grant = ec.session_manager[session_id] + if grant.is_active() is False: + authn_event = ec.session_manager.get_authentication_event(session_id) _state = b64e( as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) ) diff --git a/src/oidcendpoint/oauth2/introspection.py b/src/oidcendpoint/oauth2/introspection.py index 8e83a4c..124c9b0 100644 --- a/src/oidcendpoint/oauth2/introspection.py +++ b/src/oidcendpoint/oauth2/introspection.py @@ -2,10 +2,9 @@ import logging from oidcmsg import oauth2 -from oidcmsg.time_util import utc_time_sans_frac from oidcendpoint.endpoint import Endpoint -from oidcendpoint.token_handler import UnknownToken +from oidcendpoint.session_management import unpack_db_key LOGGER = logging.getLogger(__name__) @@ -54,18 +53,12 @@ def _add_claims(self, token_info, claims, payload): except KeyError: pass - def _introspect(self, token): - try: - info = self.endpoint_context.sdb[token] - except (KeyError, UnknownToken): - return None - + def _introspect(self, token, grant): # Make sure that the token is an access_token or a refresh_token - if token != info.get("access_token") and token != info.get("refresh_token"): + if token.type not in ["access_token", "refresh_token"]: return None - eat = info.get("expires_at") - if eat and eat < utc_time_sans_frac(): + if not token.is_active(): return None if info: # Now what can be returned ? @@ -88,10 +81,14 @@ def process_request(self, request=None, **kwargs): if "error" in _introspect_request: return _introspect_request - _token = _introspect_request["token"] + request_token = _introspect_request["token"] + session_id = self.endpoint_context.session_manager.token_handler.sid(request_token) + grant, token = self.endpoint_context.session_manager.find_grant(session_id, + request_token) + _resp = self.response_cls(active=False) - _info = self._introspect(_token) + _info = self._introspect(token) if _info is None: return {"response_args": _resp} diff --git a/src/oidcendpoint/oauth2/old_introspection.py b/src/oidcendpoint/oauth2/old_introspection.py new file mode 100644 index 0000000..8e83a4c --- /dev/null +++ b/src/oidcendpoint/oauth2/old_introspection.py @@ -0,0 +1,117 @@ +"""Implements RFC7662""" +import logging + +from oidcmsg import oauth2 +from oidcmsg.time_util import utc_time_sans_frac + +from oidcendpoint.endpoint import Endpoint +from oidcendpoint.token_handler import UnknownToken + +LOGGER = logging.getLogger(__name__) + + +class Introspection(Endpoint): + """Implements RFC 7662""" + + request_cls = oauth2.TokenIntrospectionRequest + response_cls = oauth2.TokenIntrospectionResponse + request_format = "urlencoded" + response_format = "json" + endpoint_name = "introspection_endpoint" + name = "introspection" + + def __init__(self, **kwargs): + Endpoint.__init__(self, **kwargs) + self.offset = kwargs.get("offset", 0) + self.enable_claims_per_client = kwargs.get("enable_claims_per_client", False) + + def get_client_id_from_token(self, endpoint_context, token, request=None): + """ + Will try to match tokens against information in the session DB. + + :param endpoint_context: + :param token: + :param request: + :return: client_id if there was a match + """ + sinfo = endpoint_context.sdb[token] + return sinfo["authn_req"]["client_id"] + + def _get_client_claims(self, token): + client_id = self.get_client_id_from_token(self.endpoint_context, token) + client = self.endpoint_context.cdb.get(client_id, {}) + return client.get("introspection_claims") + + def _get_user_info(self, token_info): + user_id = self.endpoint_context.sdb.sso_db.get_uid_by_sid(token_info["sid"]) + return self.endpoint_context.userinfo(user_id, client_id=None) + + def _add_claims(self, token_info, claims, payload): + user_info = self._get_user_info(token_info) + for attr in claims: + try: + payload[attr] = user_info[attr] + except KeyError: + pass + + def _introspect(self, token): + try: + info = self.endpoint_context.sdb[token] + except (KeyError, UnknownToken): + return None + + # Make sure that the token is an access_token or a refresh_token + if token != info.get("access_token") and token != info.get("refresh_token"): + return None + + eat = info.get("expires_at") + if eat and eat < utc_time_sans_frac(): + return None + + if info: # Now what can be returned ? + ret = info.to_dict() + ret["iss"] = self.endpoint_context.issuer + + if "scope" not in ret: + ret["scope"] = " ".join(info["authn_req"]["scope"]) + + return ret + + def process_request(self, request=None, **kwargs): + """ + + :param request: The authorization request as a dictionary + :param kwargs: + :return: + """ + _introspect_request = self.request_cls(**request) + if "error" in _introspect_request: + return _introspect_request + + _token = _introspect_request["token"] + _resp = self.response_cls(active=False) + + _info = self._introspect(_token) + if _info is None: + return {"response_args": _resp} + + if "release" in self.kwargs: + if "username" in self.kwargs["release"]: + try: + _info["username"] = self.endpoint_context.userinfo.search( + sub=_info["sub"] + ) + except KeyError: + pass + + _resp.update(_info) + _resp.weed() + + if self.enable_claims_per_client: + client_claims = self._get_client_claims(_token) + if client_claims: + self._add_claims(_info, client_claims, _resp) + + _resp["active"] = True + + return {"response_args": _resp} diff --git a/src/oidcendpoint/oidc/add_on/pkce.py b/src/oidcendpoint/oidc/add_on/pkce.py index 07bc9f7..f53b7e4 100644 --- a/src/oidcendpoint/oidc/add_on/pkce.py +++ b/src/oidcendpoint/oidc/add_on/pkce.py @@ -90,12 +90,13 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): return request try: - _info = endpoint_context.sdb[request["code"]] + _session_info = endpoint_context.session_manager.get_session_info_by_token(request["code"]) except KeyError: return TokenErrorResponse( error="invalid_grant", error_description="Unknown access grant" ) - _authn_req = _info["authn_req"] + + _authn_req = _session_info["client_session_info"]["authorization_request"] if "code_challenge" in _authn_req: if "code_verifier" not in request: @@ -104,11 +105,11 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): error_description="Missing code_verifier", ) - _method = _info["authn_req"]["code_challenge_method"] + _method = _authn_req["code_challenge_method"] if not verify_code_challenge( request["code_verifier"], - _info["authn_req"]["code_challenge"], + _authn_req["code_challenge"], _method, ): return TokenErrorResponse( diff --git a/src/oidcendpoint/oidc/authorization.py b/src/oidcendpoint/oidc/authorization.py index 050c56e..9f1a5c5 100755 --- a/src/oidcendpoint/oidc/authorization.py +++ b/src/oidcendpoint/oidc/authorization.py @@ -1,5 +1,6 @@ import json import logging +from urllib.parse import urlsplit from cryptojwt import BadSyntax from cryptojwt.jwe.exception import JWEException @@ -13,12 +14,12 @@ from oidcmsg.exception import ParameterError from oidcmsg.oidc import Claims from oidcmsg.oidc import verified_claim_name +from oidcmsg.time_util import time_sans_frac from oidcendpoint import rndstr -from oidcendpoint import sanitize from oidcendpoint.authn_event import create_authn_event -from oidcendpoint.common.authorization import FORM_POST from oidcendpoint.common.authorization import AllowedAlgorithms +from oidcendpoint.common.authorization import FORM_POST from oidcendpoint.common.authorization import authn_args_gather from oidcendpoint.common.authorization import get_uri from oidcendpoint.common.authorization import inputs @@ -35,7 +36,10 @@ from oidcendpoint.exception import ToOld from oidcendpoint.exception import UnknownClient from oidcendpoint.oauth2.authorization import check_unknown_scopes_policy -from oidcendpoint.session import setup_session +from oidcendpoint.session_management import ClientSessionInfo +from oidcendpoint.session_management import UserSessionInfo +from oidcendpoint.session_management import db_key +from oidcendpoint.session_management import unpack_db_key from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import pick_auth @@ -68,6 +72,11 @@ def acr_claims(request): return acrdef["values"] +def host_component(url): + res = urlsplit(url) + return "{}://{}".format(res.scheme, res.netloc) + + ALG_PARAMS = { "sign": [ "request_object_signing_alg", @@ -290,7 +299,6 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): authn = res["method"] authn_class_ref = res["acr"] - session = None try: _auth_info = kwargs.get("authn", "") @@ -322,15 +330,8 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): else: identity = json.loads(as_unicode(_id)) - try: - session = self.endpoint_context.sdb[identity.get("sid")] - except UnknownToken: - identity= None - else: - if not session or "revoked" in session: - identity = None - authn_args = authn_args_gather(request, authn_class_ref, cinfo, **kwargs) + _mngr = self.endpoint_context.session_manager # To authenticate or Not if identity is None: # No! @@ -357,13 +358,10 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): # I get back a dictionary user = identity["uid"] if "req_user" in kwargs: - sids = self.endpoint_context.sdb.get_sids_by_sub(kwargs["req_user"]) + sids = _mngr.get_sids_by_user_id(kwargs["req_user"]) if ( sids - and user - != self.endpoint_context.sdb.get_authentication_event( - sids[-1] - ).uid + and user != _mngr.get_authentication_event(sids[-1]).uid ): logger.debug("Wanted to be someone else!") if "prompt" in request and "none" in request["prompt"]: @@ -375,17 +373,16 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): else: return {"function": authn, "args": authn_args} - authn_event = None - if session: - authn_event = session.get('authn_event') + authn_event = _mngr.get_authentication_event(user) if authn_event is None: authn_event = create_authn_event( identity["uid"], - identity.get("salt", ""), + _mngr.salt, authn_info=authn_class_ref, time_stamp=_ts, ) + _mngr.set([identity["uid"]], UserSessionInfo(authentication_event=authn_event)) _exp_in = authn.kwargs.get("expires_in") if _exp_in and "valid_until" in authn_event: @@ -413,7 +410,8 @@ def create_authn_response(self, request, sid): fragment_enc = False else: _context = self.endpoint_context - _sinfo = _context.sdb[sid] + _mngr = self.endpoint_context.session_manager + _sinfo = _mngr[sid] if request.get("scope"): aresp["scope"] = request["scope"] @@ -425,24 +423,39 @@ def create_authn_response(self, request, sid): if len(rtype) == 1 and "code" in rtype: fragment_enc = False + grant = _mngr.grants(sid)[0] + user_id, client_id = unpack_db_key(sid) + if "code" in request["response_type"]: - _code = aresp["code"] = _context.sdb[sid]["code"] + _code = grant.mint_token( + 'authorization_code', + value=_mngr.token_handler["code"](user_id), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + aresp["code"] = _code.value handled_response_type.append("code") else: - _context.sdb.update(sid, code=None) _code = None if "token" in rtype: - _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) - - logger.debug("_dic: %s" % sanitize(_dic)) - for key, val in _dic.items(): - if key in aresp.parameters() and val is not None: - aresp[key] = val + _access_token = grant.mint_token( + "access_token", + value=_mngr.token_handler["access_token"]( + sid, + client_id=client_id, + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_sinfo['sub'], + based_on=_code + ), + expires_at=time_sans_frac() + 900 + ) + aresp['token'] = _access_token handled_response_type.append("token") - - _access_token = aresp.get("access_token", None) + else: + _access_token = None if "id_token" in request["response_type"]: kwargs = {} @@ -453,11 +466,11 @@ def create_authn_response(self, request, sid): elif {"id_token", "token"}.issubset(rtype): kwargs = {"access_token": _access_token} - if request["response_type"] == ["id_token"]: - kwargs["user_claims"] = True + # if request["response_type"] == ["id_token"]: + # kwargs["user_claims"] = True try: - id_token = _context.idtoken.make(request, _sinfo, **kwargs) + id_token = _context.idtoken.make(user_id=user_id, client_id=client_id, **kwargs) except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) resp = self.error_cls( @@ -467,7 +480,7 @@ def create_authn_response(self, request, sid): return {"response_args": resp, "fragment_enc": fragment_enc} aresp["id_token"] = id_token - _sinfo["id_token"] = id_token + _mngr.update([user_id, client_id], {"id_token": id_token}) handled_response_type.append("id_token") not_handled = rtype.difference(handled_response_type) @@ -521,24 +534,24 @@ def error_response(self, response_info, error, error_description): response_info["response_args"] = resp return response_info - def post_authentication(self, user, request, sid, **kwargs): + def post_authentication(self, user, request, pre_sid, **kwargs): """ Things that are done after a successful authentication. :param user: - :param request: + :param request: The authorization request :param sid: :param kwargs: :return: A dictionary with 'response_args' """ response_info = {} + _mngr = self.endpoint_context.session_manager + user_id, client_id = unpack_db_key(pre_sid) # Do the authorization try: - permission = self.endpoint_context.authz( - user, client_id=request["client_id"] - ) + grant = self.endpoint_context.authz(user_id, client_id, request=request) except ToOld as err: return self.error_response( response_info, @@ -551,20 +564,17 @@ def post_authentication(self, user, request, sid, **kwargs): ) else: try: - self.endpoint_context.sdb.update(sid, permission=permission) + _mngr.set([user_id, client_id, grant.id], grant) except Exception as err: return self.error_response( response_info, "server_error", "{}".format(err.args) ) + else: + session_id = db_key(user_id, client_id, grant.id) logger.debug("response type: %s" % request["response_type"]) - if self.endpoint_context.sdb.is_session_revoked(sid): - return self.error_response( - response_info, "access_denied", "Session is revoked" - ) - - response_info = self.create_authn_response(request, sid) + response_info = self.create_authn_response(request, session_id) logger.debug("Known clients: {}".format(list(self.endpoint_context.cdb.keys()))) @@ -585,7 +595,6 @@ def post_authentication(self, user, request, sid, **kwargs): _cookie = new_cookie( self.endpoint_context, uid=user, - sid=sid, state=request["state"], client_id=request["client_id"], cookie_name=self.endpoint_context.cookie_name["session"], @@ -605,9 +614,34 @@ def post_authentication(self, user, request, sid, **kwargs): response_info["cookie"] = [_cookie] - return response_info + return response_info, session_id + + def setup_client_session(self, user_id: str, request: dict) -> str: + _mngr = self.endpoint_context.session_manager + client_id = request['client_id'] + + _client_info = self.endpoint_context.cdb[client_id] + sub_type = _client_info.get("subject_type") + if sub_type and sub_type == "pairwise": + sector_identifier_uri = _client_info.get("sector_identifier_uri") + if sector_identifier_uri is None: + sector_identifier_uri = host_component(_client_info["redirect_uris"][0]) + + client_info = ClientSessionInfo( + authorization_request=request, + sub=_mngr.sub_func[sub_type](user_id, salt=_mngr.salt, + sector_identifier=sector_identifier_uri) + ) + else: + client_info = ClientSessionInfo( + authorization_request=request, + sub=_mngr.sub_func['public'](user_id, salt=_mngr.salt) + ) + + _mngr.set([user_id, client_id], client_info) + return db_key(user_id, client_id) - def authz_part2(self, user, authn_event, request, **kwargs): + def authz_part2(self, user, request, **kwargs): """ After the authentication this is where you should end up @@ -617,63 +651,67 @@ def authz_part2(self, user, authn_event, request, **kwargs): :param kwargs: possible other parameters :return: A redirect to the redirect_uri of the client """ - sid = setup_session( - self.endpoint_context, request, user, authn_event=authn_event - ) + + pre_sid = self.setup_client_session(user, request) try: - resp_info = self.post_authentication(user, request, sid, **kwargs) + resp_info, session_id = self.post_authentication(user, request, pre_sid, **kwargs) except Exception as err: return self.error_response({}, "server_error", err) if "check_session_iframe" in self.endpoint_context.provider_info: ec = self.endpoint_context salt = rndstr() - if not ec.sdb.is_session_revoked(sid): - authn_event = ec.sdb.get_authentication_event( - sid - ) # use the last session - _state = b64e( - as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) - ) + try: + authn_event = ec.session_manager.get_authentication_event(session_id) + except KeyError: + return self.error_response({}, "server_error", "No such session") + else: + if authn_event.is_active() is False: + return self.error_response({}, "server_error", "Authentication has timed out") - opbs_value = '' - if hasattr(ec.cookie_dealer, 'create_cookie'): - session_cookie = ec.cookie_dealer.create_cookie( - as_unicode(_state), - typ="session", - cookie_name=ec.cookie_name["session_management"], - same_site="None", - http_only=False, - ) + _state = b64e( + as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) + ) - opbs = session_cookie[ec.cookie_name["session_management"]] - opbs_value = opbs.value - else: - logger.debug("Failed to set Cookie, that's not configured in main configuration.") + opbs_value = '' + if hasattr(ec.cookie_dealer, 'create_cookie'): + session_cookie = ec.cookie_dealer.create_cookie( + as_unicode(_state), + typ="session", + cookie_name=ec.cookie_name["session_management"], + same_site="None", + http_only=False, + ) + opbs = session_cookie[ec.cookie_name["session_management"]] + opbs_value = opbs.value + else: logger.debug( - "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s", - request["client_id"], - resp_info["return_uri"], - opbs_value, - salt, - ) + "Failed to set Cookie, that's not configured in main configuration.") - _session_state = compute_session_state( - opbs_value, salt, request["client_id"], resp_info["return_uri"] - ) + logger.debug( + "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s", + request["client_id"], + resp_info["return_uri"], + opbs_value, + salt, + ) - if opbs_value: - if "cookie" in resp_info: - if isinstance(resp_info["cookie"], list): - resp_info["cookie"].append(session_cookie) - else: - append_cookie(resp_info["cookie"], session_cookie) + _session_state = compute_session_state( + opbs_value, salt, request["client_id"], resp_info["return_uri"] + ) + + if opbs_value: + if "cookie" in resp_info: + if isinstance(resp_info["cookie"], list): + resp_info["cookie"].append(session_cookie) else: - resp_info["cookie"] = session_cookie + append_cookie(resp_info["cookie"], session_cookie) + else: + resp_info["cookie"] = session_cookie - resp_info["response_args"]["session_state"] = _session_state + resp_info["response_args"]["session_state"] = _session_state # Mix-Up mitigation resp_info["response_args"]["iss"] = self.endpoint_context.issuer @@ -724,10 +762,7 @@ def process_request(self, request_info=None, **kwargs): if not _function: logger.debug("- authenticated -") logger.debug("AREQ keys: %s" % request_info.keys()) - res = self.authz_part2( - info["user"], info["authn_event"], request_info, cookie=cookie - ) - return res + return self.authz_part2(user=info["user"], request=request_info, cookie=cookie) try: # Run the authentication function diff --git a/src/oidcendpoint/oidc/old_authorization.py b/src/oidcendpoint/oidc/old_authorization.py new file mode 100755 index 0000000..050c56e --- /dev/null +++ b/src/oidcendpoint/oidc/old_authorization.py @@ -0,0 +1,740 @@ +import json +import logging + +from cryptojwt import BadSyntax +from cryptojwt.jwe.exception import JWEException +from cryptojwt.jws.exception import NoSuitableSigningKeys +from cryptojwt.jwt import utc_time_sans_frac +from cryptojwt.utils import as_bytes +from cryptojwt.utils import as_unicode +from cryptojwt.utils import b64d +from cryptojwt.utils import b64e +from oidcmsg import oidc +from oidcmsg.exception import ParameterError +from oidcmsg.oidc import Claims +from oidcmsg.oidc import verified_claim_name + +from oidcendpoint import rndstr +from oidcendpoint import sanitize +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.common.authorization import FORM_POST +from oidcendpoint.common.authorization import AllowedAlgorithms +from oidcendpoint.common.authorization import authn_args_gather +from oidcendpoint.common.authorization import get_uri +from oidcendpoint.common.authorization import inputs +from oidcendpoint.common.authorization import max_age +from oidcendpoint.cookie import append_cookie +from oidcendpoint.cookie import compute_session_state +from oidcendpoint.cookie import new_cookie +from oidcendpoint.endpoint import Endpoint +from oidcendpoint.exception import InvalidRequest +from oidcendpoint.exception import NoSuchAuthentication +from oidcendpoint.exception import RedirectURIError +from oidcendpoint.exception import ServiceError +from oidcendpoint.exception import TamperAllert +from oidcendpoint.exception import ToOld +from oidcendpoint.exception import UnknownClient +from oidcendpoint.oauth2.authorization import check_unknown_scopes_policy +from oidcendpoint.session import setup_session +from oidcendpoint.token_handler import UnknownToken +from oidcendpoint.user_authn.authn_context import pick_auth + +logger = logging.getLogger(__name__) + + +def proposed_user(request): + cn = verified_claim_name("it_token_hint") + if request.get(cn): + return request[cn].get("sub", "") + return "" + + +def acr_claims(request): + acrdef = None + + _claims = request.get("claims") + if isinstance(_claims, str): + _claims = Claims().from_json(_claims) + + if _claims: + _id_token_claim = _claims.get("id_token") + if _id_token_claim: + acrdef = _id_token_claim.get("acr") + + if isinstance(acrdef, dict): + if acrdef.get("value"): + return [acrdef["value"]] + elif acrdef.get("values"): + return acrdef["values"] + + +ALG_PARAMS = { + "sign": [ + "request_object_signing_alg", + "request_object_signing_alg_values_supported", + ], + "enc_alg": [ + "request_object_encryption_alg", + "request_object_encryption_alg_values_supported", + ], + "enc_enc": [ + "request_object_encryption_enc", + "request_object_encryption_enc_values_supported", + ], +} + + +def re_authenticate(request, authn): + if "prompt" in request and "login" in request["prompt"]: + if authn.done(request): + return True + + return False + + +class Authorization(Endpoint): + request_cls = oidc.AuthorizationRequest + response_cls = oidc.AuthorizationResponse + error_cls = oidc.AuthorizationErrorResponse + request_format = "urlencoded" + response_format = "urlencoded" + response_placement = "url" + endpoint_name = "authorization_endpoint" + name = "authorization" + default_capabilities = { + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, + "response_types_supported": [ + "code", + "token", + "id_token", + "code token", + "code id_token", + "id_token token", + "code id_token token", + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "request_object_signing_alg_values_supported": None, + "request_object_encryption_alg_values_supported": None, + "request_object_encryption_enc_values_supported": None, + "grant_types_supported": ["authorization_code", "implicit"], + "claim_types_supported": ["normal", "aggregated", "distributed"], + } + + def __init__(self, endpoint_context, **kwargs): + Endpoint.__init__(self, endpoint_context, **kwargs) + # self.pre_construct.append(self._pre_construct) + self.post_parse_request.append(self._do_request_uri) + self.post_parse_request.append(self._post_parse_request) + self.allowed_request_algorithms = AllowedAlgorithms(ALG_PARAMS) + + def filter_request(self, endpoint_context, req): + return req + + def verify_response_type(self, request, cinfo): + # Checking response types + _registered = [set(rt.split(" ")) for rt in cinfo.get("response_types", [])] + if not _registered: + # If no response_type is registered by the client then we'll + # code which it the default according to the OIDC spec. + _registered = [{"code"}] + + # Is the asked for response_type among those that are permitted + return set(request["response_type"]) in _registered + + def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): + _request_uri = request.get("request_uri") + if _request_uri: + # Do I do pushed authorization requests ? + if "pushed_authorization" in endpoint_context.endpoint: + # Is it a UUID urn + if _request_uri.startswith("urn:uuid:"): + _req = endpoint_context.par_db.get(_request_uri) + if _req: + del endpoint_context.par_db[_request_uri] # One time + # usage + return _req + else: + raise ValueError("Got a request_uri I can not resolve") + + # Do I support request_uri ? + _supported = endpoint_context.provider_info.get( + "request_uri_parameter_supported", True + ) + _registered = endpoint_context.cdb[client_id].get("request_uris") + # Not registered should be handled else where + if _registered: + # Before matching remove a possible fragment + _p = _request_uri.split("#") + # ignore registered fragments for now. + if _p[0] not in [l[0] for l in _registered]: + raise ValueError("A request_uri outside the registered") + + # Fetch the request + _resp = endpoint_context.httpc.get( + _request_uri, **endpoint_context.httpc_params + ) + if _resp.status_code == 200: + args = {"keyjar": endpoint_context.keyjar, "issuer": client_id} + _ver_request = self.request_cls().from_jwt(_resp.text, **args) + self.allowed_request_algorithms( + client_id, + endpoint_context, + _ver_request.jws_header.get("alg", "RS256"), + "sign", + ) + if _ver_request.jwe_header is not None: + self.allowed_request_algorithms( + client_id, + endpoint_context, + _ver_request.jws_header.get("alg"), + "enc_alg", + ) + self.allowed_request_algorithms( + client_id, + endpoint_context, + _ver_request.jws_header.get("enc"), + "enc_enc", + ) + # The protected info overwrites the non-protected + for k, v in _ver_request.items(): + request[k] = v + + request[verified_claim_name("request")] = _ver_request + else: + raise ServiceError("Got a %s response", _resp.status) + + return request + + def _post_parse_request(self, request, client_id, endpoint_context, **kwargs): + """ + Verify the authorization request. + + :param endpoint_context: + :param request: + :param client_id: + :param kwargs: + :return: + """ + if not request: + logger.debug("No AuthzRequest") + return self.error_cls( + error="invalid_request", error_description="Can not parse AuthzRequest" + ) + + request = self.filter_request(endpoint_context, request) + + _cinfo = endpoint_context.cdb.get(client_id) + if not _cinfo: + logger.error( + "Client ID ({}) not in client database".format(request["client_id"]) + ) + return self.error_cls( + error="unauthorized_client", error_description="unknown client" + ) + + # Is the asked for response_type among those that are permitted + if not self.verify_response_type(request, _cinfo): + return self.error_cls( + error="invalid_request", + error_description="Trying to use unregistered response_type", + ) + + # Get a verified redirect URI + try: + redirect_uri = get_uri(endpoint_context, request, "redirect_uri") + except (RedirectURIError, ParameterError, UnknownClient) as err: + return self.error_cls( + error="invalid_request", + error_description="{}:{}".format(err.__class__.__name__, err), + ) + else: + request["redirect_uri"] = redirect_uri + + return request + + def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs): + auth_id = kwargs.get("auth_method_id") + if auth_id: + return self.endpoint_context.authn_broker[auth_id] + + if acr: + res = self.endpoint_context.authn_broker.pick(acr) + else: + res = pick_auth(self.endpoint_context, request) + + if res: + return res + else: + return { + "error": "access_denied", + "error_description": "ACR I do not support", + "return_uri": redirect_uri, + "return_type": request["response_type"], + } + + def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): + """ + + :param request: The authorization/authentication request + :param redirect_uri: + :param cinfo: client info + :param cookie: + :param acr: Default ACR, if nothing else is specified + :param kwargs: + :return: + """ + + res = self.pick_authn_method(request, redirect_uri, acr, **kwargs) + + authn = res["method"] + authn_class_ref = res["acr"] + session = None + + try: + _auth_info = kwargs.get("authn", "") + if "upm_answer" in request and request["upm_answer"] == "true": + _max_age = 0 + else: + _max_age = max_age(request) + + identity, _ts = authn.authenticated_as( + cookie, authorization=_auth_info, max_age=_max_age + ) + except (NoSuchAuthentication, TamperAllert): + identity = None + _ts = 0 + except ToOld: + logger.info("Too old authentication") + identity = None + _ts = 0 + except UnknownToken: + logger.info("Unknown Token") + identity = None + _ts = 0 + else: + if identity: + try: # If identity['uid'] is in fact a base64 encoded JSON string + _id = b64d(as_bytes(identity["uid"])) + except BadSyntax: + pass + else: + identity = json.loads(as_unicode(_id)) + + try: + session = self.endpoint_context.sdb[identity.get("sid")] + except UnknownToken: + identity= None + else: + if not session or "revoked" in session: + identity = None + + authn_args = authn_args_gather(request, authn_class_ref, cinfo, **kwargs) + + # To authenticate or Not + if identity is None: # No! + logger.info("No active authentication") + logger.debug( + "Known clients: {}".format(list(self.endpoint_context.cdb.keys())) + ) + + if "prompt" in request and "none" in request["prompt"]: + # Need to authenticate but not allowed + return { + "error": "login_required", + "return_uri": redirect_uri, + "return_type": request["response_type"], + } + else: + return {"function": authn, "args": authn_args} + else: + logger.info("Active authentication") + if re_authenticate(request, authn): + # demand re-authentication + return {"function": authn, "args": authn_args} + else: + # I get back a dictionary + user = identity["uid"] + if "req_user" in kwargs: + sids = self.endpoint_context.sdb.get_sids_by_sub(kwargs["req_user"]) + if ( + sids + and user + != self.endpoint_context.sdb.get_authentication_event( + sids[-1] + ).uid + ): + logger.debug("Wanted to be someone else!") + if "prompt" in request and "none" in request["prompt"]: + # Need to authenticate but not allowed + return { + "error": "login_required", + "return_uri": redirect_uri, + } + else: + return {"function": authn, "args": authn_args} + + authn_event = None + if session: + authn_event = session.get('authn_event') + + if authn_event is None: + authn_event = create_authn_event( + identity["uid"], + identity.get("salt", ""), + authn_info=authn_class_ref, + time_stamp=_ts, + ) + + _exp_in = authn.kwargs.get("expires_in") + if _exp_in and "valid_until" in authn_event: + authn_event["valid_until"] = utc_time_sans_frac() + _exp_in + + return {"authn_event": authn_event, "identity": identity, "user": user} + + def extra_response_args(self, aresp): + return aresp + + def create_authn_response(self, request, sid): + """ + + :param self: + :param request: + :param sid: + :return: + """ + # create the response + aresp = self.response_cls() + if request.get("state"): + aresp["state"] = request["state"] + + if "response_type" in request and request["response_type"] == ["none"]: + fragment_enc = False + else: + _context = self.endpoint_context + _sinfo = _context.sdb[sid] + + if request.get("scope"): + aresp["scope"] = request["scope"] + + rtype = set(request["response_type"][:]) + handled_response_type = [] + + fragment_enc = True + if len(rtype) == 1 and "code" in rtype: + fragment_enc = False + + if "code" in request["response_type"]: + _code = aresp["code"] = _context.sdb[sid]["code"] + handled_response_type.append("code") + else: + _context.sdb.update(sid, code=None) + _code = None + + if "token" in rtype: + _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) + + logger.debug("_dic: %s" % sanitize(_dic)) + for key, val in _dic.items(): + if key in aresp.parameters() and val is not None: + aresp[key] = val + + handled_response_type.append("token") + + _access_token = aresp.get("access_token", None) + + if "id_token" in request["response_type"]: + kwargs = {} + if {"code", "id_token", "token"}.issubset(rtype): + kwargs = {"code": _code, "access_token": _access_token} + elif {"code", "id_token"}.issubset(rtype): + kwargs = {"code": _code} + elif {"id_token", "token"}.issubset(rtype): + kwargs = {"access_token": _access_token} + + if request["response_type"] == ["id_token"]: + kwargs["user_claims"] = True + + try: + id_token = _context.idtoken.make(request, _sinfo, **kwargs) + except (JWEException, NoSuitableSigningKeys) as err: + logger.warning(str(err)) + resp = self.error_cls( + error="invalid_request", + error_description="Could not sign/encrypt id_token", + ) + return {"response_args": resp, "fragment_enc": fragment_enc} + + aresp["id_token"] = id_token + _sinfo["id_token"] = id_token + handled_response_type.append("id_token") + + not_handled = rtype.difference(handled_response_type) + if not_handled: + resp = self.error_cls( + error="invalid_request", error_description="unsupported_response_type" + ) + return {"response_args": resp, "fragment_enc": fragment_enc} + + aresp = self.extra_response_args(aresp) + + return {"response_args": aresp, "fragment_enc": fragment_enc} + + def aresp_check(self, aresp, request): + return "" + + def response_mode(self, request, **kwargs): + resp_mode = request["response_mode"] + if resp_mode == "form_post": + msg = FORM_POST.format( + inputs=inputs(kwargs["response_args"].to_dict()), + action=kwargs["return_uri"], + ) + kwargs.update( + { + "response_msg": msg, + "content_type": "text/html", + "response_placement": "body", + } + ) + elif resp_mode == "fragment": + if "fragment_enc" in kwargs: + if not kwargs["fragment_enc"]: + # Can't be done + raise InvalidRequest("wrong response_mode") + else: + kwargs["fragment_enc"] = True + elif resp_mode == "query": + if "fragment_enc" in kwargs: + if kwargs["fragment_enc"]: + # Can't be done + raise InvalidRequest("wrong response_mode") + else: + raise InvalidRequest("Unknown response_mode") + return kwargs + + def error_response(self, response_info, error, error_description): + resp = self.error_cls( + error=error, error_description=str(error_description) + ) + response_info["response_args"] = resp + return response_info + + def post_authentication(self, user, request, sid, **kwargs): + """ + Things that are done after a successful authentication. + + :param user: + :param request: + :param sid: + :param kwargs: + :return: A dictionary with 'response_args' + """ + + response_info = {} + + # Do the authorization + try: + permission = self.endpoint_context.authz( + user, client_id=request["client_id"] + ) + except ToOld as err: + return self.error_response( + response_info, + "access_denied", + "Authentication to old {}".format(err.args), + ) + except Exception as err: + return self.error_response( + response_info, "access_denied", "{}".format(err.args) + ) + else: + try: + self.endpoint_context.sdb.update(sid, permission=permission) + except Exception as err: + return self.error_response( + response_info, "server_error", "{}".format(err.args) + ) + + logger.debug("response type: %s" % request["response_type"]) + + if self.endpoint_context.sdb.is_session_revoked(sid): + return self.error_response( + response_info, "access_denied", "Session is revoked" + ) + + response_info = self.create_authn_response(request, sid) + + logger.debug("Known clients: {}".format(list(self.endpoint_context.cdb.keys()))) + + try: + redirect_uri = get_uri(self.endpoint_context, request, "redirect_uri") + except (RedirectURIError, ParameterError) as err: + return self.error_response( + response_info, "invalid_request", "{}".format(err.args) + ) + else: + response_info["return_uri"] = redirect_uri + + # Must not use HTTP unless implicit grant type and native application + # info = self.aresp_check(response_info['response_args'], request) + # if isinstance(info, ResponseMessage): + # return info + + _cookie = new_cookie( + self.endpoint_context, + uid=user, + sid=sid, + state=request["state"], + client_id=request["client_id"], + cookie_name=self.endpoint_context.cookie_name["session"], + ) + + # Now about the response_mode. Should not be set if it's obvious + # from the response_type. Knows about 'query', 'fragment' and + # 'form_post'. + + if "response_mode" in request: + try: + response_info = self.response_mode(request, **response_info) + except InvalidRequest as err: + return self.error_response( + response_info, "invalid_request", "{}".format(err.args) + ) + + response_info["cookie"] = [_cookie] + + return response_info + + def authz_part2(self, user, authn_event, request, **kwargs): + """ + After the authentication this is where you should end up + + :param user: + :param request: The Authorization Request + :param sid: Session key + :param kwargs: possible other parameters + :return: A redirect to the redirect_uri of the client + """ + sid = setup_session( + self.endpoint_context, request, user, authn_event=authn_event + ) + + try: + resp_info = self.post_authentication(user, request, sid, **kwargs) + except Exception as err: + return self.error_response({}, "server_error", err) + + if "check_session_iframe" in self.endpoint_context.provider_info: + ec = self.endpoint_context + salt = rndstr() + if not ec.sdb.is_session_revoked(sid): + authn_event = ec.sdb.get_authentication_event( + sid + ) # use the last session + _state = b64e( + as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) + ) + + opbs_value = '' + if hasattr(ec.cookie_dealer, 'create_cookie'): + session_cookie = ec.cookie_dealer.create_cookie( + as_unicode(_state), + typ="session", + cookie_name=ec.cookie_name["session_management"], + same_site="None", + http_only=False, + ) + + opbs = session_cookie[ec.cookie_name["session_management"]] + opbs_value = opbs.value + else: + logger.debug("Failed to set Cookie, that's not configured in main configuration.") + + logger.debug( + "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s", + request["client_id"], + resp_info["return_uri"], + opbs_value, + salt, + ) + + _session_state = compute_session_state( + opbs_value, salt, request["client_id"], resp_info["return_uri"] + ) + + if opbs_value: + if "cookie" in resp_info: + if isinstance(resp_info["cookie"], list): + resp_info["cookie"].append(session_cookie) + else: + append_cookie(resp_info["cookie"], session_cookie) + else: + resp_info["cookie"] = session_cookie + + resp_info["response_args"]["session_state"] = _session_state + + # Mix-Up mitigation + resp_info["response_args"]["iss"] = self.endpoint_context.issuer + resp_info["response_args"]["client_id"] = request["client_id"] + + return resp_info + + def process_request(self, request_info=None, **kwargs): + """ The AuthorizationRequest endpoint + + :param request_info: The authorization request as a Message instance + :return: dictionary + """ + + if isinstance(request_info, self.error_cls): + return request_info + + _cid = request_info["client_id"] + cinfo = self.endpoint_context.cdb[_cid] + logger.debug("client {}: {}".format(_cid, cinfo)) + + # this apply the default optionally deny_unknown_scopes policy + if cinfo: + check_unknown_scopes_policy(request_info, cinfo, self.endpoint_context) + + cookie = kwargs.get("cookie", "") + if cookie: + del kwargs["cookie"] + + if proposed_user(request_info): + kwargs["req_user"] = proposed_user(request_info) + else: + if request_info.get("login_hint"): + _login_hint = request_info["login_hint"] + if self.endpoint_context.login_hint_lookup: + kwargs["req_user"] = self.endpoint_context.login_hint_lookup[ + _login_hint + ] + + info = self.setup_auth( + request_info, request_info["redirect_uri"], cinfo, cookie, **kwargs + ) + + if "error" in info: + return info + + _function = info.get("function") + if not _function: + logger.debug("- authenticated -") + logger.debug("AREQ keys: %s" % request_info.keys()) + res = self.authz_part2( + info["user"], info["authn_event"], request_info, cookie=cookie + ) + return res + + try: + # Run the authentication function + return { + "http_response": _function(**info["args"]), + "return_uri": request_info["redirect_uri"], + } + except Exception as err: + logger.exception(err) + return {"http_response": "Internal error: {}".format(err)} diff --git a/src/oidcendpoint/oidc/refresh_token.py b/src/oidcendpoint/oidc/refresh_token.py index a2ffea1..759a120 100755 --- a/src/oidcendpoint/oidc/refresh_token.py +++ b/src/oidcendpoint/oidc/refresh_token.py @@ -2,16 +2,14 @@ from oidcmsg import oidc from oidcmsg.oauth2 import ResponseMessage -from oidcmsg.oidc import AccessTokenResponse from oidcmsg.oidc import RefreshAccessTokenRequest from oidcmsg.oidc import TokenErrorResponse +from oidcmsg.time_util import time_sans_frac from oidcendpoint import sanitize from oidcendpoint.client_authn import verify_client from oidcendpoint.cookie import new_cookie from oidcendpoint.endpoint import Endpoint -from oidcendpoint.token_handler import ExpiredToken -from oidcendpoint.userinfo import by_schema logger = logging.getLogger(__name__) @@ -32,9 +30,7 @@ def __init__(self, endpoint_context, **kwargs): self.post_parse_request.append(self._post_parse_request) def _refresh_access_token(self, req, **kwargs): - _sdb = self.endpoint_context.sdb - - # client_id = str(req["client_id"]) + _mngr = self.endpoint_context.session_manager if req["grant_type"] != "refresh_token": return self.error_cls( @@ -42,14 +38,29 @@ def _refresh_access_token(self, req, **kwargs): ) rtoken = req["refresh_token"] - try: - _info = _sdb.refresh_token(rtoken) - except ExpiredToken: + _session_info = _mngr.get_session_info_by_token(rtoken) + grant, token = _mngr.find_grant(_session_info["session_id"], rtoken) + if token.is_active is False: return self.error_cls( - error="invalid_request", error_description="Refresh token is expired" + error="invalid_request", error_description="Refresh token inactive" ) - return by_schema(AccessTokenResponse, **_info) + access_token = grant.mint_token( + 'access_token', + value=_mngr.token_handler["access_token"]( + _session_info["session_id"], + client_id=_session_info["client_id"], + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_session_info["client_session_info"]['sub'] + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=rtoken # Means the token (tok) was used to mint this token + ) + + return {"access_token": access_token, "token_type": "Bearer", + "expires_in": 900, "scope": grant.scope} def client_authentication(self, request, auth=None, **kwargs): """ @@ -114,10 +125,12 @@ def process_request(self, request=None, **kwargs): if isinstance(response_args, ResponseMessage): return response_args + _token = request["refresh_token"].replace(" ", "+") + _session_info = self.endpoint_context.session_manager.get_session_info_by_token(_token) _cookie = new_cookie( self.endpoint_context, - sub=self.endpoint_context.sdb[_token]["sub"], + sub=_session_info["client_session_info"]["sub"], cookie_name=self.endpoint_context.cookie_name["session"], ) _headers = [("Content-type", "application/json")] diff --git a/src/oidcendpoint/oidc/session.py b/src/oidcendpoint/oidc/session.py index e49a648..16b7c8a 100644 --- a/src/oidcendpoint/oidc/session.py +++ b/src/oidcendpoint/oidc/session.py @@ -11,6 +11,9 @@ from cryptojwt.jws.utils import alg2keytype from cryptojwt.jwt import JWT from cryptojwt.utils import as_bytes +from oidcendpoint.session_management import db_key + +from oidcendpoint.session_management import unpack_db_key from oidcmsg.exception import InvalidRequest from oidcmsg.message import Message from oidcmsg.oauth2 import ResponseMessage @@ -120,53 +123,40 @@ def do_back_channel_logout(self, cinfo, sub, sid): def clean_sessions(self, usids): # Clean out all sessions - _sdb = self.endpoint_context.sdb - _sso_db = self.endpoint_context.sdb.sso_db + _mngr = self.endpoint_context.session_manager + for sid in usids: - _state = _sdb[sid]["authn_req"]["state"] - # remove session information - del _sdb[sid] - # remove all states connected to this session id - _sdb.delete_kv2sid(_state, "state") - _sso_db.remove_session_id(sid) + _mngr.revoke_session(sid) def logout_all_clients(self, sid, client_id): - _sdb = self.endpoint_context.sdb - _sso_db = self.endpoint_context.sdb.sso_db - + _mngr = self.endpoint_context.session_manager + _user_id, _client_id = unpack_db_key(sid) # Find all RPs this user has logged it from - uid = _sso_db.get_uid_by_sid(sid) - if uid is None: - logger.debug("Can not translate sid:%s into a user id", sid) - return {} - - _client_sid = {} - usids = _sso_db.get_sids_by_uid(uid) - if usids is None: - logger.debug("No sessions found for uid: %s", uid) - return {} - - for usid in usids: - _client_sid[_sdb[usid]["authn_req"]["client_id"]] = usid + _user_session_info = _mngr.get([_user_id]) # Front-/Backchannel logout ? _cdb = self.endpoint_context.cdb _iss = self.endpoint_context.issuer bc_logouts = {} fc_iframes = {} - for _cid, _csid in _client_sid.items(): - if "backchannel_logout_uri" in _cdb[_cid]: - _sub = _sso_db.get_sub_by_sid(_csid) - _spec = self.do_back_channel_logout(_cdb[_cid], _sub, _csid) + sids = [] + for _client_id in _user_session_info["subordinate"]: + if "backchannel_logout_uri" in _cdb[_client_id]: + _sid = db_key(_user_id, _client_id) + _sub = _mngr.get([_user_id, _client_id])["sub"] + sids.append(_sid) + _spec = self.do_back_channel_logout(_cdb[_client_id], _sub, _sid) if _spec: - bc_logouts[_cid] = _spec - elif "frontchannel_logout_uri" in _cdb[_cid]: + bc_logouts[_client_id] = _spec + elif "frontchannel_logout_uri" in _cdb[_client_id]: # Construct an IFrame - _spec = do_front_channel_logout_iframe(_cdb[_cid], _iss, _csid) + _sid = db_key(_user_id, _client_id) + sids.append(_sid) + _spec = do_front_channel_logout_iframe(_cdb[_client_id], _iss, _sid) if _spec: - fc_iframes[_cid] = _spec + fc_iframes[_client_id] = _spec - self.clean_sessions(usids) + self.clean_sessions(sids) res = {} if bc_logouts: @@ -191,15 +181,14 @@ def unpack_signed_jwt(self, sjwt, sig_alg=""): def logout_from_client(self, sid, client_id): _cdb = self.endpoint_context.cdb - _sso_db = self.endpoint_context.sdb.sso_db + _mngr = self.endpoint_context.session_manager # Kill the session - _sdb = self.endpoint_context.sdb - _sdb.revoke_session(sid=sid) + _mngr.revoke_session(sid) res = {} if "backchannel_logout_uri" in _cdb[client_id]: - _sub = _sso_db.get_sub_by_sid(sid) + _sub = _mngr[sid]["sub"] _spec = self.do_back_channel_logout(_cdb[client_id], _sub, sid) if _spec: res["blu"] = {client_id: _spec} @@ -224,7 +213,7 @@ def process_request(self, request=None, cookie=None, **kwargs): :return: """ _cntx = self.endpoint_context - _sdb = _cntx.sdb + _mngr = _cntx.session_manager if "post_logout_redirect_uri" in request: if "id_token_hint" not in request: @@ -246,6 +235,7 @@ def process_request(self, request=None, cookie=None, **kwargs): _cookie_info = json.loads(as_unicode(b64d(as_bytes(part[0])))) logger.debug("Cookie info: {}".format(_cookie_info)) _sid = _cookie_info["sid"] + _user_id, _client_id = unpack_db_key(_sid) else: logger.debug("No relevant cookie") _sid = "" @@ -283,12 +273,15 @@ def process_request(self, request=None, cookie=None, **kwargs): else: auds = [] + if not _sid: + raise KeyError("Unknown session") + try: - session = _sdb[_sid] + session = _mngr[_sid] except KeyError: raise ValueError("Can't find any corresponding session") - client_id = session["authn_req"]["client_id"] + client_id = session["authorization_request"]["client_id"] # Does this match what's in the cookie ? if _cookie_info: if client_id != _cookie_info["client_id"]: @@ -321,7 +314,7 @@ def process_request(self, request=None, cookie=None, **kwargs): payload = { "sid": _sid, "client_id": client_id, - "user": session["authn_event"]["uid"], + "user": _user_id } # redirect user to OP logout verification page diff --git a/src/oidcendpoint/oidc/token.py b/src/oidcendpoint/oidc/token.py index 74eb7f9..a53dad6 100755 --- a/src/oidcendpoint/oidc/token.py +++ b/src/oidcendpoint/oidc/token.py @@ -4,15 +4,13 @@ from cryptojwt.jws.exception import NoSuitableSigningKeys from oidcmsg import oidc from oidcmsg.oauth2 import ResponseMessage -from oidcmsg.oidc import AccessTokenResponse from oidcmsg.oidc import TokenErrorResponse +from oidcmsg.time_util import time_sans_frac from oidcendpoint import sanitize from oidcendpoint.cookie import new_cookie from oidcendpoint.endpoint import Endpoint -from oidcendpoint.exception import MultipleCodeUsage -from oidcendpoint.token_handler import AccessCodeUsed -from oidcendpoint.userinfo import by_schema +from oidcendpoint.session_management import unpack_db_key logger = logging.getLogger(__name__) @@ -39,7 +37,7 @@ def __init__(self, endpoint_context, **kwargs): def _access_token(self, req, **kwargs): _context = self.endpoint_context - _sdb = _context.sdb + _mngr = _context.session_manager _log_debug = logger.debug if req["grant_type"] != "authorization_code": @@ -54,22 +52,17 @@ def _access_token(self, req, **kwargs): error="invalid_request", error_description="Missing code" ) - # Session might not exist or _access_code malformed - try: - _info = _sdb[_access_code] - except KeyError: - return self.error_cls( - error="invalid_grant", error_description="Code is invalid" - ) - - _authn_req = _info["authn_req"] + _session_info = _mngr.get_session_info_by_token(_access_code) + grant, code = _mngr.find_grant(_session_info["session_id"], _access_code) # assert that the code is valid - if _context.sdb.is_session_revoked(_access_code): + if code.is_active() is False: return self.error_cls( - error="invalid_grant", error_description="Session is revoked" + error="invalid_grant", error_description="Code is invalid" ) + _authn_req = _session_info["client_session_info"]["authorization_request"] + # If redirect_uri was in the initial authorization request # verify that the one given here is the correct one. if "redirect_uri" in _authn_req: @@ -83,26 +76,53 @@ def _access_token(self, req, **kwargs): issue_refresh = False if "issue_refresh" in kwargs: issue_refresh = kwargs["issue_refresh"] + else: + if "offline_access" in grant.scope: + issue_refresh = True + + token = grant.mint_token( + "access_token", + value=_mngr.token_handler["access_token"]( + _session_info["session_id"], + client_id=_session_info["client_id"], + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_session_info["client_session_info"]['sub'] + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=code + ) - # offline_access the default if nothing is specified - permissions = _info.get("permission", ["offline_access"]) - - if "offline_access" in _authn_req["scope"] and "offline_access" in permissions: - issue_refresh = True - - try: - _info = _sdb.upgrade_to_token(_access_code, issue_refresh=issue_refresh) - except AccessCodeUsed as err: - logger.error("%s" % err) - # Should revoke the token issued to this access code - _sdb.revoke_all_tokens(_access_code) - return self.error_cls( - error="access_denied", error_description="Access Code already used" + _response = { + "access_token": token.value, + "token_type": "Bearer", + "expires_in": 900, + "scope": grant.scope, + "state": _authn_req["state"] + } + + if issue_refresh: + refresh_token = grant.mint_token( + "refresh_token", + value=_mngr.token_handler["refresh_token"]( + _session_info["session_id"], + client_id=_session_info["client_id"], + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_session_info["client_session_info"]['sub'] + ), + based_on=code ) + _response["refresh_token"] = refresh_token.value + + code.register_usage() if "openid" in _authn_req["scope"]: try: - _idtoken = _context.idtoken.make(req, _info, _authn_req) + _idtoken = _context.idtoken.make(_session_info["user_id"], + _session_info["client_id"]) except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) resp = self.error_cls( @@ -111,10 +131,9 @@ def _access_token(self, req, **kwargs): ) return resp - _sdb.update_by_token(_access_code, id_token=_idtoken) - _info = _sdb[_info["sid"]] + _response["id_token"] = _idtoken - return by_schema(AccessTokenResponse, **_info) + return _response def get_client_id_from_token(self, endpoint_context, token, request=None): sinfo = endpoint_context.sdb[token] @@ -129,26 +148,34 @@ def _post_parse_request(self, request, client_id="", **kwargs): :returns: """ - if "state" in request: - try: - sinfo = self.endpoint_context.sdb[request["code"]] - except KeyError: - logger.error("Code not present in SessionDB") - return self.error_cls(error="access_denied") - except MultipleCodeUsage: - logger.error("Access Code reused") - # Remove any access tokens issued - self.endpoint_context.sdb.revoke_all_tokens(request["code"]) - return self.error_cls(error="invalid_grant") - else: - state = sinfo["authn_req"]["state"] + _mngr = self.endpoint_context.session_manager + try: + _session_info = _mngr.get_session_info_by_token(request["code"]) + except KeyError: + logger.error("Access Code invalid") + return self.error_cls(error="invalid_grant") + + grant, code = _mngr.find_grant(_session_info["session_id"], request["code"]) + _auth_req = _session_info["client_session_info"]["authorization_request"] + if code.is_active(): + state = _auth_req["state"] + else: + logger.error("Access Code inactive") + # Remove any access tokens issued + if code.max_usage_reached(): + _mngr.revoke_token(_session_info["session_id"], + code.value, recursive=True) + return self.error_cls(error="invalid_grant") + if "state" in request: + # verify that state in this request is the same as the one in the + # authorization request if state != request["state"]: logger.error("State value mismatch") return self.error_cls(error="invalid_request") if "client_id" not in request: # Optional for access token request - request["client_id"] = client_id + request["client_id"] = _auth_req["client_id"] logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) @@ -171,10 +198,13 @@ def process_request(self, request=None, **kwargs): if isinstance(response_args, ResponseMessage): return response_args - _access_token = response_args["access_token"] + _mngr = self.endpoint_context.session_manager + _tinfo = _mngr.token_handler.info(request["code"]) + _cs_info = _mngr[_tinfo["sid"]] + _cookie = new_cookie( self.endpoint_context, - sub=self.endpoint_context.sdb[_access_token]["sub"], + sub=_cs_info["sub"], cookie_name=self.endpoint_context.cookie_name["session"], ) _headers = [("Content-type", "application/json")] diff --git a/src/oidcendpoint/oidc/userinfo.py b/src/oidcendpoint/oidc/userinfo.py index a63cdc5..6b23495 100755 --- a/src/oidcendpoint/oidc/userinfo.py +++ b/src/oidcendpoint/oidc/userinfo.py @@ -9,8 +9,10 @@ from oidcmsg.oauth2 import ResponseMessage from oidcendpoint.endpoint import Endpoint +from oidcendpoint.session_management import unpack_db_key from oidcendpoint.token_handler import UnknownToken -from oidcendpoint.userinfo import collect_user_info +# from oidcendpoint.userinfo import collect_user_info +from oidcendpoint.userinfo import ClaimsInterface from oidcendpoint.util import OAUTH2_NOCACHE_HEADERS logger = logging.getLogger(__name__) @@ -34,13 +36,14 @@ class UserInfo(Endpoint): def __init__(self, endpoint_context, **kwargs): Endpoint.__init__(self, endpoint_context, **kwargs) - self.scope_to_claims = None # Add the issuer ID as an allowed JWT target self.allowed_targets.append("") + self.claims_interface = ClaimsInterface(endpoint_context, "userinfo", **kwargs) def get_client_id_from_token(self, endpoint_context, token, request=None): - sinfo = self.endpoint_context.sdb[token] - return sinfo["authn_req"]["client_id"] + _info = endpoint_context.session_manager.token_handler.info(token) + sinfo = self.endpoint_context.session_manager[_info["sid"]] + return sinfo["authorization_request"]["client_id"] def do_response(self, response_args=None, request=None, client_id="", **kwargs): @@ -96,23 +99,24 @@ def do_response(self, response_args=None, request=None, client_id="", **kwargs): return {"response": resp, "http_headers": http_headers} def process_request(self, request=None, **kwargs): - _sdb = self.endpoint_context.sdb - + _mngr = self.endpoint_context.session_manager + _info = _mngr.token_handler.info(request["access_token"]) + grant, token = _mngr.find_grant(_info['sid'], request["access_token"]) # should be an access token - if not _sdb.is_token_valid(request["access_token"]): + if token.is_active() is False: return self.error_cls( error="invalid_token", error_description="Invalid Token" ) - - session = _sdb.read(request["access_token"]) - + _user_id, _client_id = unpack_db_key(_info['sid']) + _cs_info = _mngr.get([_user_id, _client_id]) + _us_info = _mngr.get([_user_id]) allowed = True # if the authenticate is still active or offline_access is granted. - if session["authn_event"]["valid_until"] > utc_time_sans_frac(): + if _us_info["authentication_event"]["valid_until"] > utc_time_sans_frac(): pass else: logger.debug("authentication not valid: {} > {}".format( - session["authn_event"]["valid_until"], utc_time_sans_frac() + _us_info["authentication_event"]["valid_until"], utc_time_sans_frac() )) allowed = False @@ -122,14 +126,16 @@ def process_request(self, request=None, **kwargs): if allowed: # Scope can translate to userinfo_claims - info = collect_user_info(self.endpoint_context, session) + _restrictions = grant.claims.get("userinfo") + info = self.claims_interface.get_user_claims( + user_id=_user_id, claims_restriction=_restrictions) else: info = { "error": "invalid_request", "error_description": "Access not granted", } - return {"response_args": info, "client_id": session["authn_req"]["client_id"]} + return {"response_args": info, "client_id": _client_id} def parse_request(self, request, auth=None, **kwargs): """ diff --git a/src/oidcendpoint/old_id_token.py b/src/oidcendpoint/old_id_token.py new file mode 100755 index 0000000..293a5fe --- /dev/null +++ b/src/oidcendpoint/old_id_token.py @@ -0,0 +1,288 @@ +import logging + +from cryptojwt.jws.utils import left_hash +from cryptojwt.jwt import JWT + +from oidcendpoint.endpoint import construct_endpoint_info + +logger = logging.getLogger(__name__) + +DEF_SIGN_ALG = { + "id_token": "RS256", + "userinfo": "RS256", + "request_object": "RS256", + "client_secret_jwt": "HS256", + "private_key_jwt": "RS256", +} +DEF_LIFETIME = 300 + + +def include_session_id(endpoint_context, client_id, where): + """ + + :param endpoint_context: + :param client_id: + :param dir: front or back + :return: + """ + _pinfo = endpoint_context.provider_info + + # Am the OP supposed to support {dir}-channel log out and if so can + # it pass sid in logout token and ID Token + for param in ["{}channel_logout_supported", "{}channel_logout_session_supported"]: + try: + _supported = _pinfo[param.format(where)] + except KeyError: + return False + else: + if not _supported: + return False + + # Does the client support back-channel logout ? + try: + _val = endpoint_context.cdb[client_id]["{}channel_logout_uri".format(where)] + except KeyError: + return False + + return True + + +def get_sign_and_encrypt_algorithms( + endpoint_context, client_info, payload_type, sign=False, encrypt=False +): + args = {"sign": sign, "encrypt": encrypt} + if sign: + try: + args["sign_alg"] = client_info[ + "{}_signed_response_alg".format(payload_type) + ] + except KeyError: # Fall back to default + try: + args["sign_alg"] = endpoint_context.jwx_def["signing_alg"][payload_type] + except KeyError: + _def_sign_alg = DEF_SIGN_ALG[payload_type] + _supported = endpoint_context.provider_info[ + "{}_signing_alg_values_supported".format(payload_type) + ] + + if _def_sign_alg in _supported: + args["sign_alg"] = _def_sign_alg + else: + args["sign_alg"] = _supported[0] + + if encrypt: + try: + args["enc_alg"] = client_info["%s_encrypted_response_alg" % payload_type] + except KeyError: + try: + args["enc_alg"] = endpoint_context.jwx_def["encryption_alg"][ + payload_type + ] + except KeyError: + _supported = endpoint_context.provider_info[ + "{}_encryption_alg_values_supported".format(payload_type) + ] + args["enc_alg"] = _supported[0] + + try: + args["enc_enc"] = client_info["%s_encrypted_response_enc" % payload_type] + except KeyError: + try: + args["enc_enc"] = endpoint_context.jwx_def["encryption_enc"][ + payload_type + ] + except KeyError: + _supported = endpoint_context.provider_info[ + "{}_encryption_enc_values_supported".format(payload_type) + ] + args["enc_enc"] = _supported[0] + + return args + + +class IDToken(object): + default_capabilities = { + "id_token_signing_alg_values_supported": None, + "id_token_encryption_alg_values_supported": None, + "id_token_encryption_enc_values_supported": None, + } + + def __init__(self, endpoint_context, **kwargs): + self.endpoint_context = endpoint_context + self.kwargs = kwargs + self.enable_claims_per_client = kwargs.get("enable_claims_per_client", False) + self.scope_to_claims = None + self.provider_info = construct_endpoint_info( + self.default_capabilities, **kwargs + ) + + def payload( + self, + session, + acr="", + alg="RS256", + code=None, + access_token=None, + user_info=None, + auth_time=0, + lifetime=None, + extra_claims=None, + ): + """ + + :param session: Session information + :param acr: Default Assurance/Authentication context class reference + :param alg: Which signing algorithm to use for the IdToken + :param code: Access grant + :param access_token: Access Token + :param user_info: If user info are to be part of the IdToken + :param auth_time: + :param lifetime: Life time of the ID Token + :param extra_claims: extra claims to be added to the ID Token + :return: IDToken instance + """ + + _args = {"sub": session["sub"]} + + if lifetime is None: + lifetime = DEF_LIFETIME + + if auth_time: + _args["auth_time"] = auth_time + if acr: + _args["acr"] = acr + + if user_info: + try: + user_info = user_info.to_dict() + except AttributeError: + pass + + # Make sure that there are no name clashes + for key in ["iss", "sub", "aud", "exp", "acr", "nonce", "auth_time"]: + try: + del user_info[key] + except KeyError: + pass + + _args.update(user_info) + + if extra_claims is not None: + _args.update(extra_claims) + + # Left hashes of code and/or access_token + halg = "HS%s" % alg[-3:] + if code: + _args["c_hash"] = left_hash(code.encode("utf-8"), halg) + if access_token: + _args["at_hash"] = left_hash(access_token.encode("utf-8"), halg) + + authn_req = session["authn_req"] + if authn_req: + try: + _args["nonce"] = authn_req["nonce"] + except KeyError: + pass + + return {"payload": _args, "lifetime": lifetime} + + def sign_encrypt( + self, + session_info, + client_id, + code=None, + access_token=None, + user_info=None, + sign=True, + encrypt=False, + lifetime=None, + extra_claims=None, + ): + """ + Signed and or encrypt a IDToken + + :param session_info: Session information + :param client_id: Client ID + :param code: Access grant + :param access_token: Access Token + :param user_info: User information + :param sign: If the JWT should be signed + :param encrypt: If the JWT should be encrypted + :param extra_claims: Extra claims to be added to the ID Token + :return: IDToken as a signed and/or encrypted JWT + """ + + _cntx = self.endpoint_context + + client_info = _cntx.cdb[client_id] + alg_dict = get_sign_and_encrypt_algorithms( + _cntx, client_info, "id_token", sign=sign, encrypt=encrypt + ) + + _authn_event = session_info["authn_event"] + + _idt_info = self.payload( + session_info, + acr=_authn_event["authn_info"], + alg=alg_dict["sign_alg"], + code=code, + access_token=access_token, + user_info=user_info, + auth_time=_authn_event["authn_time"], + lifetime=lifetime, + extra_claims=extra_claims, + ) + + _jwt = JWT( + _cntx.keyjar, iss=_cntx.issuer, lifetime=_idt_info["lifetime"], **alg_dict + ) + + return _jwt.pack(_idt_info["payload"], recv=client_id) + + def make(self, req, sess_info, authn_req=None, user_claims=False, **kwargs): + _context = self.endpoint_context + + if authn_req: + _client_id = authn_req["client_id"] + else: + _client_id = req["client_id"] + + _cinfo = _context.cdb[_client_id] + + idtoken_claims = dict(self.kwargs.get("available_claims", {})) + if self.enable_claims_per_client: + idtoken_claims.update(_cinfo.get("id_token_claims", {})) + lifetime = self.kwargs.get("lifetime") + + userinfo = userinfo_in_id_token_claims(_context, sess_info, idtoken_claims) + + if user_claims: + info = collect_user_info(_context, sess_info) + if userinfo is None: + userinfo = info + else: + userinfo.update(info) + + # Should I add session ID + req_sid = include_session_id( + _context, _client_id, "back" + ) or include_session_id(_context, _client_id, "front") + + if req_sid: + xargs = { + "sid": _context.sdb.get_sid_by_sub_and_client_id( + sess_info["sub"], _client_id + ) + } + else: + xargs = {} + + return self.sign_encrypt( + sess_info, + _client_id, + sign=True, + user_info=userinfo, + lifetime=lifetime, + extra_claims=xargs, + **kwargs + ) diff --git a/src/oidcendpoint/session.py b/src/oidcendpoint/old_session.py similarity index 100% rename from src/oidcendpoint/session.py rename to src/oidcendpoint/old_session.py diff --git a/src/oidcendpoint/scopes.py b/src/oidcendpoint/scopes.py index b3413ae..e23dfdf 100644 --- a/src/oidcendpoint/scopes.py +++ b/src/oidcendpoint/scopes.py @@ -45,19 +45,23 @@ def available_claims(endpoint_context): return STANDARD_CLAIMS -def convert_scopes2claims(scopes, allowed_claims, map=None): +def convert_scopes2claims(scopes, allowed_claims=None, map=None): if map is None: map = SCOPE2CLAIMS res = {} - for scope in scopes: - try: - claims = dict( - [(name, None) for name in map[scope] if name in allowed_claims] - ) + if allowed_claims is None: + for scope in scopes: + claims = {name: None for name in map[scope]} res.update(claims) - except KeyError: - continue + else: + for scope in scopes: + try: + claims = {name: None for name in map[scope] if name in allowed_claims} + res.update(claims) + except KeyError: + continue + return res diff --git a/src/oidcendpoint/session_management.py b/src/oidcendpoint/session_management.py index cc65f2f..67ef2c1 100644 --- a/src/oidcendpoint/session_management.py +++ b/src/oidcendpoint/session_management.py @@ -2,6 +2,8 @@ import logging from oidcendpoint import rndstr +from oidcendpoint import token_handler +from oidcendpoint.token_handler import UnknownToken logger = logging.getLogger(__name__) @@ -77,7 +79,10 @@ def is_revoked(self): class UserSessionInfo(SessionInfo): - pass + def __init__(self, **kwargs): + SessionInfo.__init__(self, **kwargs) + if "logout_sid" not in self._db: + self._db["logout_sid"] = {} class ClientSessionInfo(SessionInfo): @@ -181,6 +186,9 @@ def get(self, path: list): user_info = self._db[uid] except KeyError: raise KeyError('No such UserID') + else: + if user_info is None: + raise KeyError('No such UserID') if client_id is None: return user_info @@ -243,10 +251,9 @@ def update(self, path, new_info): class SessionManager(Database): - def __init__(self, db, handler, userinfo=None, sub_func=None): + def __init__(self, db, handler, sub_func=None): Database.__init__(self, db) self.token_handler = handler - self.userinfo = userinfo self.salt = rndstr(32) # this allows the subject identifier minters to be defined by someone @@ -263,14 +270,18 @@ def __init__(self, db, handler, userinfo=None, sub_func=None): def get_user_info(self, uid): return self.get(uid) - def find_grant(self, session_id, token_value): - user_id, client_id = unpack_db_key(session_id) - client_info = self.get([user_id, client_id]) - for grant_id in client_info["subordinate"]: - grant = self.get([user_id, client_id, grant_id]) - for token in grant.issued_token: - if token.value == token_value: - return grant, token + def find_token(self, session_id, token_value): + """ + + :param session_id: Based on 3-tuple, user_id, client_id and grant_id + :param token_value: + :return: + """ + user_id, client_id, grant_id = unpack_db_key(session_id) + grant = self.get([user_id, client_id, grant_id]) + for token in grant.issued_token: + if token.value == token_value: + return token return None @@ -282,6 +293,8 @@ def create_session(self, authn_event, auth_req, user_id, client_id="", :param auth_req: Authorization Request :param client_id: Client ID :param user_id: User ID + :param sector_identifier: + :param sub_type: :param kwargs: extra keyword arguments :return: """ @@ -304,10 +317,16 @@ def create_session(self, authn_event, auth_req, user_id, client_id="", self.set([user_id, client_id], client_info) def _update_client_info(self, session_id, new_information): - _path = unpack_db_key(session_id) - _client_info = self.get(_path) + """ + + :param session_id: + :param new_information: + :return: + """ + _user_id, _client_id, _grant_id = unpack_db_key(session_id) + _client_info = self.get([_user_id, _client_id]) _client_info.update(new_information) - self.set(_path, _client_info) + self.set([_user_id, _client_id], _client_info) def do_sub(self, session_id, sector_id="", subject_type="public"): """ @@ -318,37 +337,80 @@ def do_sub(self, session_id, sector_id="", subject_type="public"): :param subject_type: 'pairwise'/'public' :return: """ - _path = unpack_db_key(session_id) - sub = self.sub_func[subject_type](_path[0], salt=self.salt, sector_identifier=sector_id) + _user_id, _client_id, _grant_id = unpack_db_key(session_id) + sub = self.sub_func[subject_type](_user_id, salt=self.salt, sector_identifier=sector_id) self._update_client_info(session_id, {'sub': sub}) return sub def __getitem__(self, item): return self.get(unpack_db_key(item)) - def revoke_token(self, session_id, token_value): - grant, token = self.find_grant(session_id, token_value) + def get_client_session_info(self, session_id): + _user_id, _client_id, _grant_id = unpack_db_key(session_id) + self.get([_user_id, _client_id]) + + def _revoke_dependent(self, grant, token): + for t in grant.issued_token: + if t.based_on == token.value: + t.revoked = True + self._revoke_dependent(grant, t) + + def revoke_token(self, session_id, token_value, recursive=False): + token = self.find_token(session_id, token_value) + if token is None: + raise UnknownToken() + token.revoked = True + if recursive: + grant = self[session_id] + self._revoke_dependent(grant, token) def get_sids_by_user_id(self, user_id): user_info = self.get([user_id]) return [db_key(user_id, c) for c in user_info['subordinate']] - def get_authentication_event(self, user_id): + def get_authentication_event(self, session_id): + _user_id = unpack_db_key(session_id)[0] try: - user_info = self.get([user_id]) + user_info = self.get([_user_id]) except KeyError: return None return user_info["authentication_event"] - def revoke_session(self, session_id): + def revoke_client_session(self, session_id): + _user_id, _client_id, _ = unpack_db_key(session_id) + _info = self.get([_user_id, _client_id]) + _info.revoke() + self.set([_user_id, _client_id], _info) + + def revoke_grant(self, session_id): _path = unpack_db_key(session_id) _info = self.get(_path) _info.revoke() self.set(_path, _info) def grants(self, session_id): - uid, cid = unpack_db_key(session_id) + uid, cid, _gid = unpack_db_key(session_id) _csi = self.get([uid, cid]) return [self.get([uid, cid, gid]) for gid in _csi['subordinate']] + + def get_session_info(self, session_id): + _user_id, _client_id, _grant_id = unpack_db_key(session_id) + return { + "session_id": session_id, + "user_id": _user_id, + "client_id": _client_id, + "user_session_info": self.get([_user_id]), + "client_session_info": self.get([_user_id, _client_id]), + "grant": self.get([_user_id, _client_id, _grant_id]) + } + + def get_session_info_by_token(self, token_value): + _token_info = self.token_handler.info(token_value) + return self.get_session_info(_token_info["sid"]) + + +def create_session_manager(endpoint_context, token_handler_args, db=None, sub_func=None): + _token_handler = token_handler.factory(endpoint_context, **token_handler_args) + return SessionManager(db, _token_handler, sub_func=sub_func) diff --git a/src/oidcendpoint/user_authn/user.py b/src/oidcendpoint/user_authn/user.py index dce3116..0f1a628 100755 --- a/src/oidcendpoint/user_authn/user.py +++ b/src/oidcendpoint/user_authn/user.py @@ -11,7 +11,6 @@ from cryptojwt.jwt import JWT from oidcendpoint import sanitize -from oidcendpoint.authn_event import create_authn_event from oidcendpoint.exception import FailedAuthentication from oidcendpoint.exception import ImproperlyConfigured from oidcendpoint.exception import InvalidCookieSign diff --git a/src/oidcendpoint/userinfo.py b/src/oidcendpoint/userinfo.py index 33b7160..e896adf 100755 --- a/src/oidcendpoint/userinfo.py +++ b/src/oidcendpoint/userinfo.py @@ -1,67 +1,71 @@ import logging -from oidcmsg.oidc import Claims - -from oidcendpoint import sanitize -from oidcendpoint.exception import FailedAuthentication -from oidcendpoint.exception import ImproperlyConfigured from oidcendpoint.scopes import convert_scopes2claims logger = logging.getLogger(__name__) -def id_token_claims(session, provider_info): - """ - Pick the IdToken claims from the request - - :param session: Session information - :return: The IdToken claims - """ - itc = update_claims(session, "id_token", provider_info=provider_info, old_claims={}) - return itc - - -def update_claims(session, about, provider_info, old_claims=None): - """ - - :param session: - :param about: userinfo or id_token - :param old_claims: - :return: claims or None - """ - - if old_claims is None: - old_claims = {} - - req = None - try: - req = session["authn_req"] - except KeyError: - pass - - if req: - try: - _claims = req["claims"][about] - except KeyError: - pass +class ClaimsInterface: + init_args = { + "add_claims_by_scope": False, + "enable_claims_per_client": False + } + + def __init__(self, endpoint_context, usage, **kwargs): + self.usage = usage # for instance introspection, id_token, userinfo + self.endpoint_context = endpoint_context + self.add_claims_by_scope = kwargs.get("add_claims_by_scope", + self.init_args["add_claims_by_scope"]) + self.enable_claims_per_client = kwargs.get("enable_claims_per_client", + self.init_args["enable_claims_per_client"]) + + def request_claims(self, user_id, client_id): + if self.usage in ["id_token", "userinfo"]: + _csi = self.endpoint_context.session_manager.get([user_id, client_id]) + if "claims" in _csi["authorization_request"]: + return _csi["authorization_request"]["claims"].get(self.usage) + + return {} + + def _get_client_claims(self, client_id): + client_info = self.endpoint_context.cdb.get(client_id, {}) + return client_info.get("{}_claims".format(self.usage), {}) + + def get_claims(self, client_id, user_id, scopes): + """ + + :param client_id: + :param user_id: + :param scopes: + :return: + """ + claims = self._get_client_claims(client_id) + if self.add_claims_by_scope: + _supported = self.endpoint_context.provider_info.get("scopes_supported", []) + if _supported: + _scopes = set(_supported).intersection(set(scopes)) + else: + _scopes = scopes + + _claims = convert_scopes2claims(_scopes, map=self.endpoint_context.scope2claims) + claims.update(_claims) + request_claims = self.request_claims(user_id=user_id, client_id=client_id) + claims.update(request_claims) + return claims + + def get_user_claims(self, user_id, claims_restriction): + """ + + :param user_id: User identifier + :param claims_restriction: Specifies the upper limit of what claims can be returned + :return: + """ + if claims_restriction: + user_info = self.endpoint_context.userinfo(user_id, client_id=None) + return {k: user_info.get(k) for k, v in claims_restriction.items() if + claims_match(user_info.get(k), v)} else: - if _claims: - # Deal only with supported claims - _unsup = [ - c - for c in _claims.keys() - if c not in provider_info["claims_supported"] - ] - for _c in _unsup: - del _claims[_c] - - # update with old claims, do not overwrite - for key, val in old_claims.items(): - if key not in _claims: - _claims[key] = val - return _claims - - return old_claims + return {} def claims_match(value, claimspec): @@ -76,6 +80,9 @@ def claims_match(value, claimspec): as key :return: Boolean """ + if value is None: + return False + if claimspec is None: # match anything return True @@ -110,103 +117,3 @@ def by_schema(cls, **kwa): :return: A dictionary with claims (keys) that meets the filter criteria """ return dict([(key, val) for key, val in kwa.items() if key in cls.c_param]) - - -def collect_user_info( - endpoint_context, session, userinfo_claims=None, scope_to_claims=None -): - """ - Collect information about a user. - This can happen in two cases, either when constructing an IdToken or - when returning user info through the UserInfo endpoint - - :param session: Session information - :param userinfo_claims: user info claims - :return: User info - """ - authn_req = session["authorization_request"] - if scope_to_claims is None: - scope_to_claims = endpoint_context.scope2claims - - _allowed = endpoint_context.scopes_handler.allowed_scopes( - authn_req["client_id"], endpoint_context - ) - supported_scopes = [s for s in authn_req["scope"] if s in _allowed] - if userinfo_claims is None: - _allowed_claims = endpoint_context.claims_handler.allowed_claims( - authn_req["client_id"], endpoint_context - ) - uic = convert_scopes2claims( - supported_scopes, _allowed_claims, map=scope_to_claims - ) - - # Get only keys allowed by user and update the dict if such info - # is stored in session - perm_set = session.get("permission") - if perm_set: - uic = {key: uic[key] for key in uic if key in perm_set} - - uic = update_claims( - session, - "userinfo", - provider_info=endpoint_context.provider_info, - old_claims=uic, - ) - - if uic: - userinfo_claims = Claims(**uic) - logger.debug("userinfo_claim: %s" % sanitize(userinfo_claims.to_dict())) - else: - userinfo_claims = None - logger.warning(("Client {} doesn't have any claims " - "belonging to one or more scopes.").format(authn_req["client_id"])) - raise ImproperlyConfigured("Some additional scopes doesn't have any claims.") - - logger.debug("Session info: %s" % sanitize(session)) - - authn_event = session["authn_event"] - if authn_event: - uid = authn_event["uid"] - else: - uid = session["uid"] - - info = endpoint_context.userinfo(uid, authn_req["client_id"], userinfo_claims) - - if "sub" in userinfo_claims: - if not claims_match(session["sub"], userinfo_claims["sub"]): - raise FailedAuthentication("Unmatched sub claim") - - info["sub"] = session["sub"] - try: - logger.debug("user_info_response: {}".format(info)) - except UnicodeEncodeError: - logger.debug("user_info_response: {}".format(info.encode("utf-8"))) - - return info - - -def userinfo_in_id_token_claims(endpoint_context, session, def_itc=None): - """ - Collect user info claims that are to be placed in the id token. - - :param endpoint_context: Endpoint context - :param session: Session information - :param def_itc: Default ID Token claims - :return: User information or None - """ - if def_itc: - itc = def_itc - else: - itc = {} - - itc.update(id_token_claims(session, provider_info=endpoint_context.provider_info)) - - if not itc: - return None - - _claims = by_schema(endpoint_context.id_token_schema, **itc) - - if _claims: - return collect_user_info(endpoint_context, session, _claims) - else: - return None diff --git a/src/oidcendpoint/util.py b/src/oidcendpoint/util.py index 89ac46c..64372e6 100755 --- a/src/oidcendpoint/util.py +++ b/src/oidcendpoint/util.py @@ -174,7 +174,8 @@ def split_uri(uri): def allow_refresh_token(endpoint_context): # Are there a refresh_token handler - refresh_token_handler = endpoint_context.sdb.handler.handler.get("refresh_token") + refresh_token_handler = endpoint_context.session_manager.token_handler.handler[ + "refresh_token"] # Is refresh_token grant type supported _token_supported = False diff --git a/tests/test_70_grant.py b/tests/test_01_grant.py similarity index 100% rename from tests/test_70_grant.py rename to tests/test_01_grant.py diff --git a/tests/test_71_sess_mngm_db.py b/tests/test_01_sess_mngm_db.py similarity index 100% rename from tests/test_71_sess_mngm_db.py rename to tests/test_01_sess_mngm_db.py diff --git a/tests/test_72_session_life.py b/tests/test_01_session_life.py similarity index 89% rename from tests/test_72_session_life.py rename to tests/test_01_session_life.py index d93cf4e..6a6771b 100644 --- a/tests/test_72_session_life.py +++ b/tests/test_01_session_life.py @@ -1,11 +1,11 @@ import os -import pytest from cryptojwt.key_jar import init_key_jar from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import RefreshAccessTokenRequest from oidcmsg.time_util import time_sans_frac +import pytest from oidcendpoint import user_info from oidcendpoint.authn_event import create_authn_event @@ -19,11 +19,11 @@ from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.token import AccessToken from oidcendpoint.session_management import ClientSessionInfo +from oidcendpoint.session_management import SessionManager +from oidcendpoint.session_management import UserSessionInfo from oidcendpoint.session_management import db_key from oidcendpoint.session_management import public_id -from oidcendpoint.session_management import SessionManager from oidcendpoint.session_management import unpack_db_key -from oidcendpoint.session_management import UserSessionInfo from oidcendpoint.token_handler import DefaultToken from oidcendpoint.token_handler import TokenHandler from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -40,16 +40,16 @@ def setup_token_handler(self): code_handler = DefaultToken(password, typ="A", lifetime=grant_expires_in) access_token_handler = DefaultToken( password, typ="T", lifetime=token_expires_in - ) + ) refresh_token_handler = DefaultToken( password, typ="R", lifetime=refresh_token_expires_in - ) + ) handler = TokenHandler( code_handler=code_handler, access_token_handler=access_token_handler, refresh_token_handler=refresh_token_handler, - ) + ) self.session_manager = SessionManager({}, handler=handler) @@ -62,7 +62,7 @@ def auth(self): scope=["openid", "mail", "address", "offline_access"], state="STATE", response_type="code", - ) + ) # The authentication returns a user ID user_id = "diana" @@ -73,9 +73,9 @@ def auth(self): self.session_manager.salt, authn_info=INTERNETPROTOCOLPASSWORD, authn_time=time_sans_frac(), - ) + ) - user_info = UserSessionInfo(authenticationEvent=authn_event) + user_info = UserSessionInfo(authentication_event=authn_event) self.session_manager.set([user_id], user_info) # Now for client session information @@ -83,7 +83,7 @@ def auth(self): client_info = ClientSessionInfo( authorization_request=AUTH_REQ, sub=public_id(user_id, self.session_manager.salt) - ) + ) self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) # The user consent module produces a Grant instance @@ -100,13 +100,13 @@ def auth(self): 'authorization_code', value=self.session_manager.token_handler["code"](user_id), expires_at=time_sans_frac() + 300 # 5 minutes from now - ) + ) - return code + return grant.id, code def test_code_flow(self): # code is a Token instance - code = self.auth() + _grant_id, code = self.auth() # next step is access token request @@ -117,7 +117,7 @@ def test_code_flow(self): grant_type="authorization_code", client_secret="hemligt", code=code.value - ) + ) # parse the token user_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) @@ -126,9 +126,9 @@ def test_code_flow(self): # token I can easily find the grant # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) - session_id = db_key(user_id, TOKEN_REQ['client_id']) - grant, tok = self.session_manager.find_grant(session_id, - TOKEN_REQ['code']) + session_id = db_key(user_id, TOKEN_REQ['client_id'], _grant_id) + tok = self.session_manager.find_token(session_id, + TOKEN_REQ['code']) # Verify that it's of the correct type and can be used assert tok.type == "authorization_code" @@ -138,12 +138,14 @@ def test_code_flow(self): assert tok.supports_minting("access_token") + grant = self.session_manager[session_id] + access_token = grant.mint_token( 'access_token', value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=tok # Means the token (tok) was used to mint this token - ) + ) assert tok.supports_minting("refresh_token") @@ -151,7 +153,7 @@ def test_code_flow(self): 'refresh_token', value=self.session_manager.token_handler["refresh_token"](user_id), based_on=tok - ) + ) tok.register_usage() @@ -165,11 +167,11 @@ def test_code_flow(self): client_secret="hemligt", refresh_token=refresh_token.value, scope=["openid", "mail", "offline_access"] - ) + ) - session_id = db_key(user_id,REFRESH_TOKEN_REQ['client_id']) - grant, reftok = self.session_manager.find_grant(session_id, - REFRESH_TOKEN_REQ['refresh_token']) + session_id = db_key(user_id, REFRESH_TOKEN_REQ['client_id'], _grant_id) + reftok = self.session_manager.find_token(session_id, + REFRESH_TOKEN_REQ['refresh_token']) assert reftok.supports_minting("access_token") @@ -178,7 +180,7 @@ def test_code_flow(self): value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=reftok # Means the token (tok) was used to mint this token - ) + ) assert access_token_2.is_active() @@ -186,7 +188,7 @@ def test_code_flow(self): KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, - ] +] ISSUER = "https://example.com/" @@ -201,7 +203,7 @@ def test_code_flow(self): ["id_token", "token"], ["code", "token", "id_token"], ["none"], - ] +] CAPABILITIES = { "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "token_endpoint_auth_methods_supported": [ @@ -209,19 +211,19 @@ def test_code_flow(self): "client_secret_basic", "client_secret_jwt", "private_key_jwt", - ], + ], "response_modes_supported": ["query", "fragment", "form_post"], "subject_types_supported": ["public", "pairwise"], "grant_types_supported": [ "authorization_code", "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", - ], + ], "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, "request_uri_parameter_supported": True, - } +} BASEDIR = os.path.abspath(os.path.dirname(__file__)) @@ -248,8 +250,8 @@ def setup_session_manager(self): "key_defs": [ {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}, {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "refresh"} - ], - }, + ], + }, "code": {"lifetime": 600}, "token": { "class": "oidcendpoint.jwt_token.JWTToken", @@ -260,50 +262,52 @@ def setup_session_manager(self): "email_verified", "phone_number", "phone_number_verified", - ], + ], "add_claim_by_scope": True, "aud": ["https://example.org/appl"], - }, }, - "refresh": {}, }, + "refresh": {}, + }, "endpoint": { "provider_config": { "path": "{}/.well-known/openid-configuration", "class": ProviderConfiguration, "kwargs": {}, - }, + }, "registration": { "path": "{}/registration", "class": Registration, "kwargs": {}, - }, + }, "authorization": { "path": "{}/authorization", "class": Authorization, "kwargs": {}, - }, + }, "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, "session": {"path": "{}/end_session", "class": Session}, - }, + }, "client_authn": verify_client, "authentication": { "anon": { "acr": INTERNETPROTOCOLPASSWORD, "class": "oidcendpoint.user_authn.user.NoAuthn", "kwargs": {"user": "diana"}, - } - }, + } + }, "template_dir": "template", "userinfo": { "class": user_info.UserInfo, "kwargs": {"db_file": full_path("users.json")}, - }, + }, "id_token": {"class": IDToken}, - } + } self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) - self.session_manager = SessionManager({}, handler=self.endpoint_context.sdb.handler) + self.session_manager = self.endpoint_context.session_manager + # self.session_manager = SessionManager({}, handler=self.endpoint_context.sdb.handler) + # self.endpoint_context.session_manager = self.session_manager def auth(self): # Start with an authentication request @@ -314,7 +318,7 @@ def auth(self): scope=["openid", "mail", "address", "offline_access"], state="STATE", response_type="code", - ) + ) # The authentication returns a user ID user_id = "diana" @@ -331,13 +335,13 @@ def auth(self): salt, authn_info=INTERNETPROTOCOLPASSWORD, authn_time=time_sans_frac(), - ) + ) client_info = ClientSessionInfo( authorization_request=AUTH_REQ, - authenticationEvent=authn_event, + authentication_event=authn_event, sub=public_id(user_id, salt) - ) + ) self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) # The user consent module produces a Grant instance @@ -353,9 +357,9 @@ def auth(self): code = grant.mint_token( 'authorization_code', value=self.session_manager.token_handler["code"]( - db_key(user_id, AUTH_REQ['client_id'])), + db_key(user_id, AUTH_REQ['client_id'], grant.id)), expires_at=time_sans_frac() + 300 # 5 minutes from now - ) + ) return code @@ -372,19 +376,17 @@ def test_code_flow(self): grant_type="authorization_code", client_secret="hemligt", code=code.value - ) + ) # parse the token session_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) - user_id, client_id = unpack_db_key(session_id) + user_id, client_id, grant_id = unpack_db_key(session_id) # Now given I have the client_id from the request and the user_id from the # token I can easily find the grant # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) - session_id = db_key(user_id, TOKEN_REQ['client_id']) - grant, tok = self.session_manager.find_grant(session_id, - TOKEN_REQ['code']) + tok = self.session_manager.find_token(session_id, TOKEN_REQ['code']) # Verify that it's of the correct type and can be used assert tok.type == "authorization_code" @@ -398,31 +400,33 @@ def test_code_flow(self): assert tok.supports_minting("access_token") + grant = self.session_manager[session_id] + user_claims = self.endpoint_context.userinfo(user_id, client_id=TOKEN_REQ["client_id"], user_info_claims=grant.claims) access_token = grant.mint_token( 'access_token', value=self.session_manager.token_handler["access_token"]( - db_key(user_id, client_id), + session_id, client_id=TOKEN_REQ['client_id'], aud=grant.resources, user_claims=user_claims, scope=grant.scope, sub=client_info['sub'] - ), + ), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=tok # Means the token (tok) was used to mint this token - ) + ) # this test is include in the mint_token methods # assert tok.supports_minting("refresh_token") refresh_token = grant.mint_token( 'refresh_token', - value=self.session_manager.token_handler["refresh_token"](db_key(user_id, client_id)), + value=self.session_manager.token_handler["refresh_token"](session_id), based_on=tok - ) + ) tok.register_usage() @@ -436,10 +440,10 @@ def test_code_flow(self): client_secret="hemligt", refresh_token=refresh_token.value, scope=["openid", "mail", "offline_access"] - ) + ) - session_id = db_key(user_id, REFRESH_TOKEN_REQ['client_id']) - grant, reftok = self.session_manager.find_grant(session_id, + session_id = db_key(user_id, REFRESH_TOKEN_REQ['client_id'], grant_id) + reftok = self.session_manager.find_token(session_id, REFRESH_TOKEN_REQ['refresh_token']) # Can I use this token to mint another token ? @@ -451,14 +455,17 @@ def test_code_flow(self): access_token_2 = grant.mint_token( 'access_token', value=self.session_manager.token_handler["access_token"]( - db_key(user_id, client_id), + session_id, sub=client_info['sub'], client_id=TOKEN_REQ['client_id'], aud=grant.resources, user_claims=user_claims - ), + ), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=reftok # Means the refresh token (reftok) was used to mint this token - ) + ) assert access_token_2.is_active() + + token_info = self.session_manager.token_handler.info(access_token_2.value) + assert token_info diff --git a/tests/test_03_id_token.py b/tests/test_03_id_token.py index 7ab1da0..37bfcf0 100644 --- a/tests/test_03_id_token.py +++ b/tests/test_03_id_token.py @@ -2,20 +2,25 @@ import os import time -import pytest from cryptojwt.jws import jws from cryptojwt.jwt import JWT from cryptojwt.key_jar import KeyJar from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import RegistrationResponse +from oidcmsg.time_util import time_sans_frac +import pytest +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.client_authn import verify_client from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.grant import Grant from oidcendpoint.id_token import IDToken from oidcendpoint.id_token import get_sign_and_encrypt_algorithms from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session_management import SessionManager +from oidcendpoint.session_management import db_key from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -34,15 +39,34 @@ def full_path(local_file): USERS = json.loads(open(full_path("users.json")).read()) USERINFO = UserInfo(USERS) -AREQN = AuthorizationRequest( +AREQ = AuthorizationRequest( response_type="code", - client_id="client1", + client_id="client_1", redirect_uri="http://example.com/authz", scope=["openid"], state="state000", nonce="nonce", ) +AREQS = AuthorizationRequest( + response_type="code", + client_id="client_1", + redirect_uri="http://example.com/authz", + scope=["openid", "address", "email"], + state="state000", + nonce="nonce", +) + +AREQRC = AuthorizationRequest( + response_type="code", + client_id="client_1", + redirect_uri="http://example.com/authz", + scope=["openid", "address", "email"], + state="state000", + nonce="nonce", + claims={"id_token": {"nickname": None}} +) + conf = { "issuer": "https://example.com/", "password": "mycket hemligt", @@ -72,12 +96,14 @@ def full_path(local_file): "kwargs": {"user": "diana"}, } }, - "userinfo": {"class": "oidcendpoint.user_info.UserInfo", "kwargs": {"db": USERS},}, + "userinfo": {"class": "oidcendpoint.user_info.UserInfo", "kwargs": {"db": USERS}, }, "client_authn": verify_client, "template_dir": "template", "id_token": {"class": IDToken, "kwargs": {"foo": "bar"}}, } +USER_ID = "diana" + class TestEndpoint(object): @pytest.fixture(autouse=True) @@ -93,112 +119,128 @@ def create_idtoken(self): self.endpoint_context.keyjar.add_symmetric( "client_1", "hemligtochintekort", ["sig", "enc"] ) + self.session_manager = self.endpoint_context.session_manager + self.user_id = USER_ID + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=AREQ['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return db_key(self.user_id, client_id, grant.id) + + def _mint_code(self, grant): + # Constructing an authorization code is now done + return grant.mint_token( + 'authorization_code', + value=self.session_manager.token_handler["code"](self.user_id), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) - def test_id_token_payload_0(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} - info = self.endpoint_context.idtoken.payload(session_info) - assert info["payload"] == {"sub": "1234567890", "nonce": "nonce"} - assert info["lifetime"] == 300 + def _mint_access_token(self, grant, client_id, token_ref): + _csi = self.session_manager.get([self.user_id, client_id]) + return grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"]( + db_key(self.user_id, client_id, grant.id), + client_id=client_id, + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_csi['sub'] + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=token_ref # Means the token (tok) was used to mint this token + ) - def test_id_token_payload_1(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + def test_id_token_payload_0(self): + self._create_session(AREQ) + session_id = self._do_grant(AREQ) - info = self.endpoint_context.idtoken.payload(session_info) - assert info["payload"] == {"nonce": "nonce", "sub": "1234567890"} - assert info["lifetime"] == 300 + payload = self.endpoint_context.idtoken.payload(session_id) + assert set(payload.keys()) == {"sub", "nonce", "auth_time"} def test_id_token_payload_with_code(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] - info = self.endpoint_context.idtoken.payload( - session_info, code="ABCDEFGHIJKLMNOP" + code = self._mint_code(grant) + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], code=code.value ) - assert info["payload"] == { - "nonce": "nonce", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + assert set(payload.keys()) == {"nonce", "c_hash", "sub", "auth_time"} def test_id_token_payload_with_access_token(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] - info = self.endpoint_context.idtoken.payload( - session_info, access_token="012ABCDEFGHIJKLMNOP" + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AREQ['client_id'], code) + + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], access_token=access_token.value ) - assert info["payload"] == { - "nonce": "nonce", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + assert set(payload.keys()) == {"nonce", "at_hash", "sub", "auth_time"} def test_id_token_payload_with_code_and_access_token(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AREQ['client_id'], code) - info = self.endpoint_context.idtoken.payload( - session_info, access_token="012ABCDEFGHIJKLMNOP", code="ABCDEFGHIJKLMNOP" + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], access_token=access_token.value, code=code.value ) - assert info["payload"] == { - "nonce": "nonce", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + assert set(payload.keys()) == {"nonce", "c_hash", "at_hash", "sub", "auth_time"} def test_id_token_payload_with_userinfo(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token": {"given_name": None}} - info = self.endpoint_context.idtoken.payload( - session_info, user_info={"given_name": "Diana"} - ) - assert info["payload"] == { - "nonce": "nonce", - "given_name": "Diana", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + payload = self.endpoint_context.idtoken.payload(session_id=session_id) + assert set(payload.keys()) == {"nonce", "given_name", "sub", "auth_time"} def test_id_token_payload_many_0(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} - - info = self.endpoint_context.idtoken.payload( - session_info, - user_info={"given_name": "Diana"}, - access_token="012ABCDEFGHIJKLMNOP", - code="ABCDEFGHIJKLMNOP", + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token": {"given_name": None}} + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AREQ['client_id'], code) + + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], + access_token=access_token.value, + code=code.value ) - assert info["payload"] == { - "nonce": "nonce", - "given_name": "Diana", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + assert set(payload.keys()) == {"nonce", "c_hash", "at_hash", "sub", "auth_time", + "given_name"} def test_sign_encrypt_id_token(self): - client_info = RegistrationResponse( - id_token_signed_response_alg="RS512", client_id="client_1" - ) - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": {"authn_info": "loa2", "authn_time": time.time()}, - } + self._create_session(AREQ) + session_id = self._do_grant(AREQ) - self.endpoint_context.jwx_def["signing_alg"] = {"id_token": "RS384"} - self.endpoint_context.cdb["client_1"] = client_info.to_dict() - - _token = self.endpoint_context.idtoken.sign_encrypt( - session_info, "client_1", sign=True - ) + _token = self.endpoint_context.idtoken.sign_encrypt(session_id, AREQ['client_id'], sign=True) assert _token _jws = jws.factory(_token) - assert _jws.jwt.headers["alg"] == "RS512" + assert _jws.jwt.headers["alg"] == "RS256" client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() @@ -210,10 +252,9 @@ def test_sign_encrypt_id_token(self): assert res["aud"] == ["client_1"] def test_get_sign_algorithm(self): - client_info = RegistrationResponse() - endpoint_context = EndpointContext(conf) + client_info = self.endpoint_context.cdb[AREQ['client_id']] algs = get_sign_and_encrypt_algorithms( - endpoint_context, client_info, "id_token", sign=True + self.endpoint_context, client_info, "id_token", sign=True ) # default signing alg assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS256"} @@ -260,20 +301,12 @@ def test_get_sign_algorithm_4(self): assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS512"} def test_available_claims(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.idtoken.kwargs["available_claims"] = { - "nickname": {"essential": True} - } - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token": {"nickname": {"essential": True}}} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() @@ -283,39 +316,34 @@ def test_available_claims(self): assert "nickname" in res def test_no_available_claims(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token":{"foobar": None}} + req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(_token) - assert "nickname" not in res + assert "foobar" not in res def test_client_claims(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.idtoken.enable_claims_per_client = True + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + + self.endpoint_context.idtoken.claims_interface.enable_claims_per_client = True self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + + _claims = self.endpoint_context.idtoken.claims_interface.get_claims( + client_id=AREQ["client_id"], user_id=USER_ID, scopes=AREQ["scope"]) + grant.claims = {'id_token': _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() @@ -326,50 +354,68 @@ def test_client_claims(self): assert "nickname" not in res def test_client_claims_with_default(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} - self.endpoint_context.idtoken.kwargs["available_claims"] = { - "nickname": {"essential": True} - } - self.endpoint_context.idtoken.enable_claims_per_client = True - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + + # self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} + # self.endpoint_context.idtoken.enable_claims_per_client = True + + _claims = self.endpoint_context.idtoken.claims_interface.get_claims( + client_id=AREQ["client_id"], user_id=USER_ID, scopes=AREQ["scope"]) + grant.claims = {"id_token": _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(_token) - assert "address" in res - assert "nickname" in res - def test_client_claims_disabled(self): - # enable_claims_per_client defaults to False - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + # No user info claims should be there + assert "address" not in res + assert "nickname" not in res + + def test_client_claims_scopes(self): + self._create_session(AREQS) + session_id = self._do_grant(AREQS) + grant = self.session_manager[session_id] + + self.endpoint_context.idtoken.claims_interface.add_claims_by_scope = True + _claims = self.endpoint_context.idtoken.claims_interface.get_claims( + client_id=AREQS["client_id"], user_id=USER_ID, scopes=AREQS["scope"]) + grant.claims = {"id_token": _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(_token) - assert "address" not in res + assert "address" in res + assert "email" in res assert "nickname" not in res + + def test_client_claims_scopes_and_request_claims(self): + self._create_session(AREQRC) + session_id = self._do_grant(AREQRC) + grant = self.session_manager[session_id] + + self.endpoint_context.idtoken.claims_interface.add_claims_by_scope = True + _claims = self.endpoint_context.idtoken.claims_interface.get_claims( + client_id=AREQRC["client_id"], user_id=USER_ID, scopes=AREQRC["scope"]) + grant.claims = {"id_token": _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) + assert _token + client_keyjar = KeyJar() + _jwks = self.endpoint_context.keyjar.export_jwks() + client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + _jwt = JWT(key_jar=client_keyjar, iss="client_1") + res = _jwt.unpack(_token) + assert "address" in res + assert "email" in res + assert "nickname" in res + diff --git a/tests/test_05_sso_db.py b/tests/test_05_sso_db.py deleted file mode 100644 index 5c8479f..0000000 --- a/tests/test_05_sso_db.py +++ /dev/null @@ -1,135 +0,0 @@ -import shutil - -import pytest - -from oidcendpoint.sso_db import SSODb - -DB_CONF = { - "handler": "oidcmsg.storage.abfile.AbstractFileSystem", - "fdir": "db/sso", - "key_conv": "oidcmsg.storage.converter.QPKey", - "value_conv": "oidcmsg.storage.converter.JSON", -} - - -def rmtree(item): - try: - shutil.rmtree(item) - except FileNotFoundError: - pass - - -class TestSSODB(object): - @pytest.fixture(autouse=True) - def create_sdb(self): - rmtree("db/sso") - self.sso_db = SSODb(DB_CONF) - - def test_map_sid2uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 1"] - - def test_missing_map(self): - assert self.sso_db.get_sids_by_uid("Lizz") == [] - - def test_multiple_map_sid2uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - assert set(self.sso_db.get_sids_by_uid("Lizz")) == { - "session id 1", - "session id 2", - } - - def test_map_unmap_sid2uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - assert set(self.sso_db.get_sids_by_uid("Lizz")) == { - "session id 1", - "session id 2", - } - - self.sso_db.remove_sid2uid("session id 1", "Lizz") - assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 2"] - - def test_get_uid_by_sid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - - assert self.sso_db.get_uid_by_sid("session id 1") == "Lizz" - assert self.sso_db.get_uid_by_sid("session id 2") == "Lizz" - - def test_remove_uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Diana") - - self.sso_db.remove_uid("Lizz") - assert self.sso_db.get_uid_by_sid("session id 1") is None - assert self.sso_db.get_sids_by_uid("Lizz") == [] - - def test_map_sid2sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 1"] - - def test_missing_sid2sub_map(self): - assert self.sso_db.get_sids_by_sub("abcdefgh") == [] - - def test_multiple_map_sid2sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == { - "session id 1", - "session id 2", - } - - def test_map_unmap_sid2sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == { - "session id 1", - "session id 2", - } - - self.sso_db.remove_sid2sub("session id 1", "abcdefgh") - assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 2"] - - def test_get_sub_by_sid(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - - assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == { - "session id 1", - "session id 2", - } - - def test_remove_sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "012346789") - - self.sso_db.remove_sub("abcdefgh") - assert self.sso_db.get_sub_by_sid("session id 1") is None - assert self.sso_db.get_sids_by_sub("abcdefgh") == [] - # have not touched the others - assert self.sso_db.get_sub_by_sid("session id 2") == "012346789" - assert self.sso_db.get_sids_by_sub("012346789") == ["session id 2"] - - def test_get_sub_by_uid_same_sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - - res = self.sso_db.get_subs_by_uid("Lizz") - - assert set(res) == {"abcdefgh"} - - def test_get_sub_by_uid_different_sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "012346789") - - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - - res = self.sso_db.get_subs_by_uid("Lizz") - - assert set(res) == {"abcdefgh", "012346789"} diff --git a/tests/test_07_userinfo.py b/tests/test_07_userinfo.py index d3a0573..f94e859 100644 --- a/tests/test_07_userinfo.py +++ b/tests/test_07_userinfo.py @@ -1,25 +1,23 @@ import json import os -import pytest -from oidcmsg.message import Message from oidcmsg.oidc import OpenIDRequest -from oidcmsg.oidc import OpenIDSchema +import pytest from oidcendpoint.authn_event import create_authn_event from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.grant import Grant from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.scopes import SCOPE2CLAIMS from oidcendpoint.scopes import STANDARD_CLAIMS from oidcendpoint.scopes import convert_scopes2claims +from oidcendpoint.session_management import db_key +from oidcendpoint.session_management import unpack_db_key from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo -from oidcendpoint.userinfo import by_schema -from oidcendpoint.userinfo import claims_match -from oidcendpoint.userinfo import collect_user_info -from oidcendpoint.userinfo import update_claims +from oidcendpoint.userinfo import ClaimsInterface CLAIMS = { "userinfo": { @@ -136,27 +134,27 @@ def test_custom_scopes(): assert set( convert_scopes2claims(["email"], _available_claims, map=_scopes).keys() - ) == {"email", "email_verified",} + ) == {"email", "email_verified", } assert set( convert_scopes2claims(["address"], _available_claims, map=_scopes).keys() ) == {"address"} assert set( convert_scopes2claims(["phone"], _available_claims, map=_scopes).keys() - ) == {"phone_number", "phone_number_verified",} + ) == {"phone_number", "phone_number_verified", } assert set( convert_scopes2claims( ["research_and_scholarship"], _available_claims, map=_scopes ).keys() ) == { - "name", - "given_name", - "family_name", - "email", - "email_verified", - "sub", - "eduperson_scoped_affiliation", - } + "name", + "given_name", + "family_name", + "email", + "email_verified", + "sub", + "eduperson_scoped_affiliation", + } PROVIDER_INFO = { @@ -172,71 +170,6 @@ def test_custom_scopes(): ] } - -def test_update_claims_authn_req_id_token(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "id_token", PROVIDER_INFO) - assert set(claims.keys()) == {"auth_time", "acr"} - - -def test_update_claims_authn_req_userinfo(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "userinfo", PROVIDER_INFO) - assert set(claims.keys()) == { - "given_name", - "nickname", - "email", - "email_verified", - "picture", - "http://example.info/claims/groups", - } - - -def test_update_claims_authzreq_id_token(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "id_token", PROVIDER_INFO) - assert set(claims.keys()) == {"auth_time", "acr"} - - -def test_update_claims_authzreq_userinfo(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "userinfo", PROVIDER_INFO) - assert set(claims.keys()) == { - "given_name", - "nickname", - "email", - "email_verified", - "picture", - "http://example.info/claims/groups", - } - - -def test_clams_value(): - assert claims_match("red", CLAIMS["userinfo"]["http://example.info/claims/groups"]) - - -def test_clams_values(): - assert claims_match("urn:mace:incommon:iap:silver", CLAIMS["id_token"]["acr"]) - - -def test_clams_essential(): - assert claims_match(["foobar@example"], CLAIMS["userinfo"]["email"]) - - -def test_clams_none(): - assert claims_match(["angle"], CLAIMS["userinfo"]["nickname"]) - - -def test_by_schema(): - # There are no requested or optional claims defined for Message - assert by_schema(Message, sub="John") == {} - - assert by_schema(OpenIDSchema, sub="John", given_name="John", age=34) == { - "sub": "John", - "given_name": "John", - } - - KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -300,22 +233,42 @@ def create_endpoint_context(self): ) # Just has to be there self.endpoint_context.cdb["client1"] = {} + self.session_manager = self.endpoint_context.session_manager + self.claims_interface = ClaimsInterface(self.endpoint_context, "userinfo") + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return db_key(self.user_id, client_id, grant.id) def test_collect_user_info(self): _req = OIDR.copy() _req["claims"] = CLAIMS_2 - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(_req) + session_id = self._do_grant(_req) + _uid, _cid, _gid = unpack_db_key(session_id) + + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=OIDR["scope"]) - res = collect_user_info(self.endpoint_context, session) + res = self.claims_interface.get_user_claims("diana", _restriction) assert res == { + 'eduperson_scoped_affiliation': ['staff@example.org'], "nickname": "Dina", - "sub": "doe", "email": "diana@example.org", "email_verified": False, } @@ -325,21 +278,17 @@ def test_collect_user_info_2(self): _req["scope"] = "openid email" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(_req) + session_id = self._do_grant(_req) + _uid, _cid, _gid = unpack_db_key(session_id) - self.endpoint_context.provider_info["scopes_supported"] = [ - "openid", - "email", - "offline_access", - ] - res = collect_user_info(self.endpoint_context, session) + self.claims_interface.add_claims_by_scope = True + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=_req["scope"]) + + res = self.claims_interface.get_user_claims("diana", _restriction) assert res == { - "sub": "doe", "email": "diana@example.org", "email_verified": False, } @@ -349,25 +298,17 @@ def test_collect_user_info_scope_not_supported(self): _req["scope"] = "openid email address" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(_req) + session_id = self._do_grant(_req) + _uid, _cid, _gid = unpack_db_key(session_id) - # Scope address not supported - self.endpoint_context.provider_info["scopes_supported"] = [ - "openid", - "email", - "offline_access", - ] - res = collect_user_info(self.endpoint_context, session) + self.claims_interface.add_claims_by_scope = False + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=_req["scope"]) - assert res == { - "sub": "doe", - "email": "diana@example.org", - "email_verified": False, - } + res = self.claims_interface.get_user_claims("diana", _restriction) + + assert res == {} class TestCollectUserInfoCustomScopes: @@ -443,22 +384,43 @@ def create_endpoint_context(self): } ) self.endpoint_context.cdb["client1"] = {} + self.endpoint_context.cdb["client1"] = {} + self.session_manager = self.endpoint_context.session_manager + self.claims_interface = ClaimsInterface(self.endpoint_context, "userinfo") + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return db_key(self.user_id, client_id, grant.id) def test_collect_user_info(self): - _session_info = {"authn_req": OIDR} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(OIDR) + session_id = self._do_grant(OIDR) + _uid, _cid, _gid = unpack_db_key(session_id) + + self.claims_interface.add_claims_by_scope = False + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=OIDR["scope"]) - res = collect_user_info(self.endpoint_context, session) + res = self.claims_interface.get_user_claims("diana", _restriction) assert res == { "email": "diana@example.org", "email_verified": False, "nickname": "Dina", "given_name": "Diana", - "sub": "doe", } def test_collect_user_info_2(self): @@ -466,40 +428,41 @@ def test_collect_user_info_2(self): _req["scope"] = "openid email" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(_req) + session_id = self._do_grant(_req) + _uid, _cid, _gid = unpack_db_key(session_id) - self.endpoint_context.provider_info["claims_supported"].remove("email") - self.endpoint_context.provider_info["claims_supported"].remove("email_verified") + self.claims_interface.add_claims_by_scope = True + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=_req["scope"]) - res = collect_user_info(self.endpoint_context, session) + res = self.claims_interface.get_user_claims("diana", _restriction) - assert res == {"sub": "doe"} + assert res == {'email': 'diana@example.org', 'email_verified': False} def test_collect_user_info_scope_not_supported(self): _req = OIDR.copy() _req["scope"] = "openid email address" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(_req) + session_id = self._do_grant(_req) + _uid, _cid, _gid = unpack_db_key(session_id) - # Scope address not supported + # Address asked for but not supported self.endpoint_context.provider_info["scopes_supported"] = [ "openid", "email", "offline_access", ] - res = collect_user_info(self.endpoint_context, session) + + self.claims_interface.add_claims_by_scope = True + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=_req["scope"]) + + res = self.claims_interface.get_user_claims("diana", _restriction) assert res == { - "sub": "doe", - "email": "diana@example.org", - "email_verified": False, + 'email': 'diana@example.org', + 'email_verified': False } diff --git a/tests/test_08_session.py b/tests/test_08_session.py deleted file mode 100644 index 15ca52c..0000000 --- a/tests/test_08_session.py +++ /dev/null @@ -1,520 +0,0 @@ -import os -import shutil -import time - -import pytest -from oidcmsg.oidc import AuthorizationRequest -from oidcmsg.oidc import OpenIDRequest -from oidcmsg.storage.init import storage_factory - -from oidcendpoint import rndstr -from oidcendpoint import token_handler -from oidcendpoint.authn_event import create_authn_event -from oidcendpoint.endpoint_context import EndpointContext -from oidcendpoint.oidc.authorization import Authorization -from oidcendpoint.oidc.provider_config import ProviderConfiguration -from oidcendpoint.session import SessionDB -from oidcendpoint.session import setup_session -from oidcendpoint.sso_db import SSODb -from oidcendpoint.token_handler import UnknownToken -from oidcendpoint.token_handler import WrongTokenType -from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD -from oidcendpoint.user_info import UserInfo - -__author__ = "rohe0002" - -AREQ = AuthorizationRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid"], - state="state000", -) - -AREQN = AuthorizationRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid"], - state="state000", - nonce="something", -) - -AREQO = AuthorizationRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid", "offline_access"], - prompt="consent", - state="state000", -) - -OIDR = OpenIDRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid"], - state="state000", -) - -BASEDIR = os.path.abspath(os.path.dirname(__file__)) - - -def full_path(local_file): - return os.path.join(BASEDIR, local_file) - - -SSO_DB_CONF = { - "handler": "oidcmsg.storage.abfile.AbstractFileSystem", - "fdir": "db/sso", - "key_conv": "oidcmsg.storage.converter.QPKey", - "value_conv": "oidcmsg.storage.converter.JSON", -} - -SESSION_DB_CONF = { - "handler": "oidcmsg.storage.abfile.AbstractFileSystem", - "fdir": "db/session", - "key_conv": "oidcmsg.storage.converter.QPKey", - "value_conv": "oidcmsg.storage.converter.JSON", -} - - -def rmtree(item): - try: - shutil.rmtree(item) - except FileNotFoundError: - pass - - -class TestSessionDB(object): - @pytest.fixture(autouse=True) - def create_sdb(self): - rmtree("db/sso") - rmtree("db/session") - - passwd = rndstr(24) - _th_args = { - "code": {"lifetime": 600, "password": passwd}, - "token": {"lifetime": 3600, "password": passwd}, - "refresh": {"lifetime": 86400, "password": passwd}, - } - - _token_handler = token_handler.factory(None, **_th_args) - userinfo = UserInfo(db_file=full_path("users.json")) - self.sdb = SessionDB( - storage_factory(SESSION_DB_CONF), - _token_handler, - SSODb(SSO_DB_CONF), - userinfo, - ) - - def test_create_authz_session(self): - ae = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb.do_sub(sid, uid="user", client_salt="client_salt") - - info = self.sdb[sid] - assert info["client_id"] == "client_id" - assert set(info.keys()) == { - "sid", - "client_id", - "authn_req", - "authn_event", - "sub", - "oauth_state", - "code", - } - - def test_create_authz_session_without_nonce(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - info = self.sdb[sid] - assert info["oauth_state"] == "authz" - - def test_create_authz_session_with_nonce(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae, AREQN, client_id="client_id") - info = self.sdb[sid] - authz_request = info["authn_req"] - assert authz_request["nonce"] == "something" - - def test_create_authz_session_with_id_token(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session( - ae, AREQ, client_id="client_id", id_token="id_token" - ) - - info = self.sdb[sid] - assert info["id_token"] == "id_token" - - def test_create_authz_session_with_oidreq(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session( - ae, AREQ, client_id="client_id", oidreq=OIDR - ) - info = self.sdb[sid] - assert "id_token" not in info - assert "oidreq" in info - - def test_create_authz_session_with_sector_id(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session( - ae, AREQ, client_id="client_id", oidreq=OIDR - ) - self.sdb.do_sub( - sid, "user1", "client_salt", "http://example.com/si.jwt", "pairwise" - ) - - info_1 = self.sdb[sid].copy() - assert "id_token" not in info_1 - assert "oidreq" in info_1 - assert info_1["sub"] != "sub" - - self.sdb.do_sub( - sid, "user2", "client_salt", "http://example.net/si.jwt", "pairwise" - ) - - info_2 = self.sdb[sid] - assert info_2["sub"] != "sub" - assert info_2["sub"] != info_1["sub"] - - def test_upgrade_to_token(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - _dict = self.sdb.upgrade_to_token(grant) - - print(_dict.keys()) - assert set(_dict.keys()) == { - "sid", - "authn_event", - "authn_req", - "access_token", - "token_type", - "client_id", - "oauth_state", - "expires_in", - "expires_at", - "code_is_used" - } - - # can't update again - # with pytest.raises(AccessCodeUsed): - print(self.sdb.upgrade_to_token(grant)) - - def test_upgrade_to_token_refresh(self): - ae1 = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae1, AREQO, client_id="client_id") - self.sdb.do_sub(sid, "user", ae1["salt"]) - grant = self.sdb[sid]["code"] - # Issue an access token trading in the access grant code - _dict = self.sdb.upgrade_to_token(grant, issue_refresh=True) - - print(_dict.keys()) - assert set(_dict.keys()) == { - "sid", - "authn_event", - "authn_req", - "access_token", - "sub", - "token_type", - "client_id", - "oauth_state", - "refresh_token", - "expires_in", - "expires_at", - "code_is_used" - } - - # You can't refresh a token using the token itself - with pytest.raises(WrongTokenType): - self.sdb.refresh_token(_dict["access_token"]) - - def test_upgrade_to_token_with_id_token_and_oidreq(self): - ae2 = create_authn_event("another_user_id", "salt") - sid = self.sdb.create_authz_session(ae2, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - - _dict = self.sdb.upgrade_to_token(grant, id_token="id_token", oidreq=OIDR) - print(_dict.keys()) - assert set(_dict.keys()) == { - "sid", - "authn_event", - "authn_req", - "oidreq", - "access_token", - "id_token", - "token_type", - "client_id", - "oauth_state", - "expires_in", - "expires_at", - "code_is_used" - } - - assert _dict["id_token"] == "id_token" - assert isinstance(_dict["oidreq"], OpenIDRequest) - - def test_refresh_token(self): - ae = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - - dict1 = self.sdb.upgrade_to_token(grant, issue_refresh=True).copy() - rtoken = dict1["refresh_token"] - dict2 = self.sdb.refresh_token(rtoken, AREQ["client_id"]) - - assert dict1["access_token"] != dict2["access_token"] - - with pytest.raises(WrongTokenType): - self.sdb.refresh_token(dict2["access_token"], AREQ["client_id"]) - - def test_refresh_token_cleared_session(self): - ae = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - dict1 = self.sdb.upgrade_to_token(grant, issue_refresh=True) - ac1 = dict1["access_token"] - - # Purge the SessionDB - self.sdb._db = {} - - rtoken = dict1["refresh_token"] - with pytest.raises(UnknownToken): - self.sdb.refresh_token(rtoken, AREQ["client_id"]) - - def test_is_valid(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - - assert self.sdb.is_valid("code", grant) - - sinfo = self.sdb.upgrade_to_token(grant, issue_refresh=True) - assert not self.sdb.is_valid("code", grant) - access_token = sinfo["access_token"] - assert self.sdb.is_valid("access_token", access_token) - - refresh_token = sinfo["refresh_token"] - sinfo = self.sdb.refresh_token(refresh_token, AREQ["client_id"]) - access_token2 = sinfo["access_token"] - assert self.sdb.is_valid("access_token", access_token2) - - # The old access code should be invalid - try: - self.sdb.is_valid("access_token", access_token) - except KeyError: - pass - - def test_valid_grant(self): - ae = create_authn_event("another:user", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - grant = self.sdb[sid]["code"] - - assert self.sdb.is_valid("code", grant) - - def test_revoke_token(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - - grant = self.sdb[sid]["code"] - tokens = self.sdb.upgrade_to_token(grant, issue_refresh=True) - access_token = tokens["access_token"] - refresh_token = tokens["refresh_token"] - - assert self.sdb.is_valid("access_token", access_token) - - self.sdb.revoke_token(sid, "access_token") - assert not self.sdb.is_valid("access_token", access_token) - - sinfo = self.sdb.refresh_token(refresh_token, AREQ["client_id"]) - access_token = sinfo["access_token"] - assert self.sdb.is_valid("access_token", access_token) - - self.sdb.revoke_token(sid, "refresh_token") - assert not self.sdb.is_valid("refresh_token", refresh_token) - - ae2 = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae2, AREQ, client_id="client_2") - - grant = self.sdb[sid]["code"] - self.sdb.revoke_token(sid, "code") - assert not self.sdb.is_valid("code", grant) - - def test_sub_to_authn_event(self): - ae = create_authn_event("sub", "salt", time_stamp=time.time()) - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - sub = self.sdb.do_sub(sid, "user", "client_salt") - - # given the sub find out whether the authn event is still valid - sids = self.sdb.get_sids_by_sub(sub) - ae = self.sdb[sids[0]]["authn_event"] - assert ae.valid() - - def test_do_sub_deterministic(self): - ae = create_authn_event("tester", "random_value") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb.do_sub(sid, "user", "other_random_value") - - info = self.sdb[sid] - assert ( - info["sub"] - == "d657bddf3d30970aa681663978ea84e26553ead03cb6fe8fcfa6523f2bcd0ad2" - ) - - self.sdb.do_sub( - sid, - "user", - "other_random_value", - sector_id="http://example.com", - subject_type="pairwise", - ) - info2 = self.sdb[sid] - assert ( - info2["sub"] - == "1442ceb13a822e802f85832ce93a8fda011e32a3363834dd1db3f9aa211065bd" - ) - - self.sdb.do_sub( - sid, - "user", - "another_random_value", - sector_id="http://other.example.com", - subject_type="pairwise", - ) - - info2 = self.sdb[sid] - assert ( - info2["sub"] - == "56e0a53d41086e7b22d78d52ee461655e9b090d50a0663d16136ea49a56c9bec" - ) - - def test_match_session(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - self.sdb.sso_db.map_sid2uid(sid, "uid") - - res = self.sdb.match_session("uid", client_id="client_id") - assert res == sid - - def test_get_token(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - self.sdb.sso_db.map_sid2uid(sid, "uid") - - grant = self.sdb.get_token(sid) - assert self.sdb.is_valid("code", grant) - assert self.sdb.handler.type(grant) == "A" - - -KEYDEFS = [ - {"type": "RSA", "key": "", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] - -conf = { - "issuer": "https://example.com/", - "password": "mycket hemligt", - "token_expires_in": 600, - "grant_expires_in": 300, - "refresh_token_expires_in": 86400, - "verify_ssl": False, - "capabilities": {}, - "keys": { - "uri_path": "static/jwks.json", - "key_defs": KEYDEFS, - "private_path": "own/jwks.json", - }, - "endpoint": { - "provider_config": { - "path": ".well-known/openid-configuration", - "class": ProviderConfiguration, - "kwargs": {}, - }, - "authorization_endpoint": { - "path": "authorization", - "class": Authorization, - "kwargs": {}, - }, - }, - "authentication": { - "anon": { - "acr": INTERNETPROTOCOLPASSWORD, - "class": "oidcendpoint.user_authn.user.NoAuthn", - "kwargs": {"user": "diana"}, - } - }, - "userinfo": {"class": UserInfo, "kwargs": {"db_file": full_path("users.json")}}, - "template_dir": "template", - "sso_db": SSO_DB_CONF, - "session_db": SESSION_DB_CONF, -} - - -def test_setup_session(): - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert sid - - -def test_setup_session_upgrade_to_token(): - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert sid - code = endpoint_context.sdb[sid]["code"] - assert code - - res = endpoint_context.sdb.upgrade_to_token(code) - assert "access_token" in res - - endpoint_context.sdb.revoke_uid("_user_") - assert endpoint_context.sdb.is_session_revoked(sid) - - -def make_sub_uid(uid, **kwargs): - return uid - - -def test_sub_minting_function(): - conf["sub_func"] = {"public": {"function": make_sub_uid}} - - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert endpoint_context.sdb[sid]["sub"] == uid - - -class SubMinter(object): - def __call__(self, *args, **kwargs): - return args[0] - - -def test_sub_minting_class(): - conf["sub_func"] = {"public": {"class": SubMinter}} - - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert endpoint_context.sdb[sid]["sub"] == uid diff --git a/tests/test_10_oidc_authz.py b/tests/test_10_oidc_authz.py.no similarity index 100% rename from tests/test_10_oidc_authz.py rename to tests/test_10_oidc_authz.py.no diff --git a/tests/test_24_oauth2_authorization_endpoint.py b/tests/test_24_oauth2_authorization_endpoint.py index 8d4cbcd..cbda2fc 100755 --- a/tests/test_24_oauth2_authorization_endpoint.py +++ b/tests/test_24_oauth2_authorization_endpoint.py @@ -1,12 +1,10 @@ +from http.cookies import SimpleCookie import io import json import os -from http.cookies import SimpleCookie from urllib.parse import parse_qs from urllib.parse import urlparse -import pytest -import yaml from cryptojwt import KeyJar from cryptojwt.jwt import utc_time_sans_frac from cryptojwt.utils import as_bytes @@ -17,7 +15,10 @@ from oidcmsg.oauth2 import AuthorizationRequest from oidcmsg.oauth2 import AuthorizationResponse from oidcmsg.time_util import in_a_while +import pytest +import yaml +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.common.authorization import FORM_POST from oidcendpoint.common.authorization import get_uri from oidcendpoint.common.authorization import inputs @@ -29,12 +30,12 @@ from oidcendpoint.exception import NoSuchAuthentication from oidcendpoint.exception import RedirectURIError from oidcendpoint.exception import ToOld -from oidcendpoint.exception import UnAuthorizedClient from oidcendpoint.exception import UnAuthorizedClientScope from oidcendpoint.exception import UnknownClient +from oidcendpoint.grant import Grant from oidcendpoint.id_token import IDToken from oidcendpoint.oauth2.authorization import Authorization -from oidcendpoint.session import SessionInfo +from oidcendpoint.session_management import db_key from oidcendpoint.user_info import UserInfo KEYDEFS = [ @@ -199,6 +200,8 @@ def create_endpoint(self): endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = endpoint_context.endpoint["authorization"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") @@ -206,12 +209,52 @@ def create_endpoint(self): "client_1", "hemligtkodord1234567890" ) + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return db_key(self.user_id, client_id, grant.id) + + # def _mint_code(self, grant, client_id): + # sid = db_key(self.user_id, client_id) + # # Constructing an authorization code is now done + # return grant.mint_token( + # 'authorization_code', + # value=self.session_manager.token_handler["code"](sid), + # expires_at=time_sans_frac() + 300 # 5 minutes from now + # ) + # + # def _mint_access_token(self, grant, client_id, token_ref=None): + # _csi = self.session_manager.get([self.user_id, client_id]) + # return grant.mint_token( + # 'access_token', + # value=self.session_manager.token_handler["access_token"]( + # db_key(self.user_id, client_id), + # client_id=client_id, + # aud=grant.resources, + # user_claims=None, + # scope=grant.scope, + # sub=_csi['sub'] + # ), + # expires_at=time_sans_frac() + 900, # 15 minutes from now + # based_on=token_ref # Means the token (tok) was used to mint this token + # ) + def test_init(self): assert self.endpoint def test_parse(self): _req = self.endpoint.parse_request(AUTH_REQ_DICT) - assert isinstance(_req, AuthorizationRequest) assert set(_req.keys()) == set(AUTH_REQ.keys()) @@ -393,24 +436,16 @@ def test_create_authn_response(self): scope="openid", ) - _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - ) - _ec.cdb["client_id"] = { + self.endpoint.endpoint_context.cdb["client_id"] = { "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", } - resp = self.endpoint.create_authn_response(request, "session_id") + self._create_session(request) + session_id = self._do_grant(request) + + resp = self.endpoint.create_authn_response(request, session_id) assert isinstance(resp["response_args"], AuthorizationErrorResponse) def test_setup_auth(self): @@ -512,21 +547,13 @@ def test_setup_auth_user(self): "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", } - _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - ) - item = _ec.authn_broker.db["anon"] + pre_sid = self._create_session(request) + session_id = self._do_grant(request) + + item = self.endpoint.endpoint_context.authn_broker.db["anon"] item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + as_bytes(json.dumps({"uid": "krall", "sid": session_id})) ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -548,22 +575,18 @@ def test_setup_auth_session_revoked(self): "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", } - _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - revoked=True, - ) + pre_sid = self._create_session(request) + session_id = self._do_grant(request) + + _mngr = self.endpoint.endpoint_context.session_manager + _csi = _mngr[session_id] + _csi.revoked = True + + _ec = self.endpoint.endpoint_context item = _ec.authn_broker.db["anon"] item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + as_bytes(json.dumps({"uid": "krall", "sid": session_id})) ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) diff --git a/tests/test_24_oauth2_authorization_endpoint_jar.py b/tests/test_24_oauth2_authorization_endpoint_jar.py index efc1854..24b7644 100755 --- a/tests/test_24_oauth2_authorization_endpoint_jar.py +++ b/tests/test_24_oauth2_authorization_endpoint_jar.py @@ -182,6 +182,8 @@ def create_endpoint(self): endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = endpoint_context.endpoint["authorization"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") diff --git a/tests/test_24_oidc_authorization_endpoint.py b/tests/test_24_oidc_authorization_endpoint.py index 2130b5d..214b74a 100755 --- a/tests/test_24_oidc_authorization_endpoint.py +++ b/tests/test_24_oidc_authorization_endpoint.py @@ -4,14 +4,17 @@ from urllib.parse import parse_qs from urllib.parse import urlparse -import pytest -import responses -import yaml from cryptojwt import JWT from cryptojwt import KeyJar from cryptojwt.jwt import utc_time_sans_frac from cryptojwt.utils import as_bytes from cryptojwt.utils import b64e + +from oidcendpoint.grant import Grant +from oidcendpoint.session_management import db_key + +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.oidc.authorization import Authorization from oidcmsg.exception import ParameterError from oidcmsg.exception import URIError from oidcmsg.oauth2 import AuthorizationErrorResponse @@ -20,7 +23,11 @@ from oidcmsg.oidc import AuthorizationResponse from oidcmsg.oidc import verified_claim_name from oidcmsg.oidc import verify_id_token +import pytest +import responses +import yaml +from oidcendpoint.authz import AuthzHandling from oidcendpoint.common.authorization import FORM_POST from oidcendpoint.common.authorization import join_query from oidcendpoint.common.authorization import verify_uri @@ -32,12 +39,10 @@ from oidcendpoint.exception import NoSuchAuthentication from oidcendpoint.exception import RedirectURIError from oidcendpoint.exception import ToOld -from oidcendpoint.exception import UnAuthorizedClient from oidcendpoint.exception import UnknownClient from oidcendpoint.id_token import IDToken from oidcendpoint.login_hint import LoginHint2Acrs from oidcendpoint.oidc import userinfo -from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.authorization import acr_claims from oidcendpoint.oidc.authorization import get_uri from oidcendpoint.oidc.authorization import inputs @@ -45,7 +50,6 @@ from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import SessionInfo from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_authn.authn_context import UNSPECIFIED from oidcendpoint.user_authn.authn_context import init_method @@ -110,7 +114,6 @@ def full_path(local_file): USERINFO_db = json.loads(open(full_path("users.json")).read()) - client_yaml = """ oidc_clients: client_1: @@ -222,6 +225,7 @@ def create_endpoint(self): }, "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}}, "template_dir": "template", + "authz": {"class": AuthzHandling, "kwargs": {}}, "cookie_dealer": { "class": CookieDealer, "kwargs": { @@ -240,12 +244,15 @@ def create_endpoint(self): }, } endpoint_context = EndpointContext(conf) + _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["oidc_clients"] endpoint_context.keyjar.import_jwks( endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = endpoint_context.endpoint["authorization"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") @@ -256,6 +263,22 @@ def create_endpoint(self): def test_init(self): assert self.endpoint + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return db_key(self.user_id, client_id, grant.id) + def test_parse(self): _req = self.endpoint.parse_request(AUTH_REQ_DICT) diff --git a/tests/test_24_oidc_authorization_endpoint.py.no b/tests/test_24_oidc_authorization_endpoint.py.no new file mode 100755 index 0000000..2130b5d --- /dev/null +++ b/tests/test_24_oidc_authorization_endpoint.py.no @@ -0,0 +1,933 @@ +import io +import json +import os +from urllib.parse import parse_qs +from urllib.parse import urlparse + +import pytest +import responses +import yaml +from cryptojwt import JWT +from cryptojwt import KeyJar +from cryptojwt.jwt import utc_time_sans_frac +from cryptojwt.utils import as_bytes +from cryptojwt.utils import b64e +from oidcmsg.exception import ParameterError +from oidcmsg.exception import URIError +from oidcmsg.oauth2 import AuthorizationErrorResponse +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import AuthorizationResponse +from oidcmsg.oidc import verified_claim_name +from oidcmsg.oidc import verify_id_token + +from oidcendpoint.common.authorization import FORM_POST +from oidcendpoint.common.authorization import join_query +from oidcendpoint.common.authorization import verify_uri +from oidcendpoint.cookie import CookieDealer +from oidcendpoint.cookie import cookie_value +from oidcendpoint.cookie import new_cookie +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.exception import InvalidRequest +from oidcendpoint.exception import NoSuchAuthentication +from oidcendpoint.exception import RedirectURIError +from oidcendpoint.exception import ToOld +from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.exception import UnknownClient +from oidcendpoint.id_token import IDToken +from oidcendpoint.login_hint import LoginHint2Acrs +from oidcendpoint.oidc import userinfo +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.authorization import acr_claims +from oidcendpoint.oidc.authorization import get_uri +from oidcendpoint.oidc.authorization import inputs +from oidcendpoint.oidc.authorization import re_authenticate +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session import SessionInfo +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from oidcendpoint.user_authn.authn_context import UNSPECIFIED +from oidcendpoint.user_authn.authn_context import init_method +from oidcendpoint.user_authn.user import NoAuthn +from oidcendpoint.user_authn.user import UserAuthnMethod +from oidcendpoint.user_authn.user import UserPassJinja2 +from oidcendpoint.user_info import UserInfo +from oidcendpoint.util import JSONDictDB + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]} + # {"type": "EC", "crv": "P-256", "use": ["sig"]} +] + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], +} + +CLAIMS = {"id_token": {"given_name": {"essential": True}, "nickname": None}} + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +AUTH_REQ_DICT = AUTH_REQ.to_dict() + +AUTH_REQ_2 = AuthorizationRequest( + client_id="client3", + redirect_uri="https://127.0.0.1:8090/authz_cb/bobcat", + scope=["openid"], + state="STATE2", + response_type="code", +) + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO_db = json.loads(open(full_path("users.json")).read()) + + +client_yaml = """ +oidc_clients: + client_1: + "client_secret": 'hemligtkodord' + "redirect_uris": + - ['https://example.com/cb', ''] + "client_salt": "salted" + 'token_endpoint_auth_method': 'client_secret_post' + 'response_types': + - 'code' + - 'token' + - 'code id_token' + - 'id_token' + - 'code id_token token' + client2: + client_secret: "spraket_sr.se" + redirect_uris: + - ['https://app1.example.net/foo', ''] + - ['https://app2.example.net/bar', ''] + response_types: + - code + client3: + client_secret: '2222222222222222222222222222222222222222' + redirect_uris: + - ['https://127.0.0.1:8090/authz_cb/bobcat', ''] + post_logout_redirect_uris: + - ['https://openidconnect.net/', ''] + response_types: + - code +""" + + +class TestEndpoint(object): + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": "https://example.com/", + "password": "mycket hemligt zebra", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, + "id_token": { + "class": IDToken, + "kwargs": { + "available_claims": { + "email": {"essential": True}, + "email_verified": {"essential": True}, + } + }, + }, + "endpoint": { + "provider_config": { + "path": "{}/.well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {}, + }, + "registration": { + "path": "{}/registration", + "class": Registration, + "kwargs": {}, + }, + "authorization": { + "path": "{}/authorization", + "class": Authorization, + "kwargs": { + "response_types_supported": [ + " ".join(x) for x in RESPONSE_TYPES_SUPPORTED + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, + }, + }, + "token": { + "path": "token", + "class": AccessToken, + "kwargs": { + "client_authn_method": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ] + }, + }, + "userinfo": { + "path": "userinfo", + "class": userinfo.UserInfo, + "kwargs": { + "db_file": "users.json", + "claim_types_supported": [ + "normal", + "aggregated", + "distributed", + ], + }, + }, + }, + "authentication": { + "anon": { + "acr": "http://www.swamid.se/policy/assurance/al1", + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}}, + "template_dir": "template", + "cookie_dealer": { + "class": CookieDealer, + "kwargs": { + "sign_key": "ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch", + "default_values": { + "name": "oidcop", + "domain": "127.0.0.1", + "path": "/", + "max_age": 3600, + }, + }, + }, + "login_hint2acrs": { + "class": LoginHint2Acrs, + "kwargs": {"scheme_map": {"email": [INTERNETPROTOCOLPASSWORD]}}, + }, + } + endpoint_context = EndpointContext(conf) + _clients = yaml.safe_load(io.StringIO(client_yaml)) + endpoint_context.cdb = _clients["oidc_clients"] + endpoint_context.keyjar.import_jwks( + endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] + ) + self.endpoint = endpoint_context.endpoint["authorization"] + + self.rp_keyjar = KeyJar() + self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") + self.endpoint.endpoint_context.keyjar.add_symmetric( + "client_1", "hemligtkodord1234567890" + ) + + def test_init(self): + assert self.endpoint + + def test_parse(self): + _req = self.endpoint.parse_request(AUTH_REQ_DICT) + + assert isinstance(_req, AuthorizationRequest) + assert set(_req.keys()) == set(AUTH_REQ.keys()) + + def test_process_request(self): + _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT) + _resp = self.endpoint.process_request(_pr_resp) + assert set(_resp.keys()) == { + "response_args", + "fragment_enc", + "return_uri", + "cookie", + } + + def test_do_response_code(self): + _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT) + _resp = self.endpoint.process_request(_pr_resp) + msg = self.endpoint.do_response(**_resp) + assert isinstance(msg, dict) + _msg = parse_qs(msg["response"]) + assert _msg + part = urlparse(msg["response"]) + assert part.fragment == "" + assert part.query + _query = parse_qs(part.query) + assert _query + assert "code" in _query + + def test_do_response_id_token_no_nonce(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "id_token" + _pr_resp = self.endpoint.parse_request(_orig_req) + # Missing nonce + assert isinstance(_pr_resp, ResponseMessage) + + def test_do_response_id_token(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "id_token" + _orig_req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_orig_req) + _resp = self.endpoint.process_request(_pr_resp) + msg = self.endpoint.do_response(**_resp) + assert isinstance(msg, dict) + part = urlparse(msg["response"]) + assert part.query == "" + assert part.fragment + _frag_msg = parse_qs(part.fragment) + assert _frag_msg + assert "id_token" in _frag_msg + assert "code" not in _frag_msg + assert "token" not in _frag_msg + + def test_do_response_id_token_token(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "id_token token" + _orig_req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_orig_req) + assert isinstance(_pr_resp, AuthorizationErrorResponse) + assert _pr_resp["error"] == "invalid_request" + + def test_do_response_code_token(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "code token" + _pr_resp = self.endpoint.parse_request(_orig_req) + assert isinstance(_pr_resp, AuthorizationErrorResponse) + assert _pr_resp["error"] == "invalid_request" + + def test_do_response_code_id_token(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "code id_token" + _orig_req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_orig_req) + _resp = self.endpoint.process_request(_pr_resp) + msg = self.endpoint.do_response(**_resp) + assert isinstance(msg, dict) + part = urlparse(msg["response"]) + assert part.query == "" + assert part.fragment + _frag_msg = parse_qs(part.fragment) + assert _frag_msg + assert "id_token" in _frag_msg + assert "code" in _frag_msg + assert "access_token" not in _frag_msg + + def test_do_response_code_id_token_token(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "code id_token token" + _orig_req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_orig_req) + _resp = self.endpoint.process_request(_pr_resp) + msg = self.endpoint.do_response(**_resp) + assert isinstance(msg, dict) + part = urlparse(msg["response"]) + assert part.query == "" + assert part.fragment + _frag_msg = parse_qs(part.fragment) + assert _frag_msg + assert "id_token" in _frag_msg + assert "code" in _frag_msg + assert "access_token" in _frag_msg + + def test_id_token_claims(self): + _req = AUTH_REQ_DICT.copy() + _req["claims"] = CLAIMS + _req["response_type"] = "code id_token token" + _req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_req) + _resp = self.endpoint.process_request(_pr_resp) + idt = verify_id_token( + _resp["response_args"], keyjar=self.endpoint.endpoint_context.keyjar + ) + assert idt + # from claims + assert "given_name" in _resp["response_args"]["__verified_id_token"] + # from config + assert "email" in _resp["response_args"]["__verified_id_token"] + + def test_re_authenticate(self): + request = {"prompt": "login"} + authn = UserAuthnMethod(self.endpoint.endpoint_context) + assert re_authenticate(request, authn) + + def test_id_token_acr(self): + _req = AUTH_REQ_DICT.copy() + _req["claims"] = { + "id_token": {"acr": {"value": "http://www.swamid.se/policy/assurance/al1"}} + } + _req["response_type"] = "code id_token token" + _req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_req) + _resp = self.endpoint.process_request(_pr_resp) + res = verify_id_token( + _resp["response_args"], keyjar=self.endpoint.endpoint_context.keyjar + ) + assert res + res = _resp["response_args"][verified_claim_name("id_token")] + assert res["acr"] == "http://www.swamid.se/policy/assurance/al1" + + def test_verify_uri_unknown_client(self): + request = {"redirect_uri": "https://rp.example.com/cb"} + with pytest.raises(UnknownClient): + verify_uri(self.endpoint.endpoint_context, request, "redirect_uri") + + def test_verify_uri_fragment(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = {"redirect_uri": ["https://rp.example.com/auth_cb"]} + request = {"redirect_uri": "https://rp.example.com/cb#foobar"} + with pytest.raises(URIError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_noregistered(self): + _ec = self.endpoint.endpoint_context + request = {"redirect_uri": "https://rp.example.com/cb"} + + with pytest.raises(KeyError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_unregistered(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [("https://rp.example.com/auth_cb", {})] + } + + request = {"redirect_uri": "https://rp.example.com/cb"} + + with pytest.raises(RedirectURIError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_qp_match(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})] + } + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} + + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_qp_mismatch(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})] + } + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bar&foo=kex"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + request = {"redirect_uri": "https://rp.example.com/cb"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bar&level=low"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_qp_missing(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [ + ("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]}) + ] + } + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_qp_missing_val(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar", "low"]})] + } + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_no_registered_qp(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_get_uri(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} + + request = { + "redirect_uri": "https://rp.example.com/cb", + "client_id": "client_id", + } + + assert get_uri(_ec, request, "redirect_uri") == "https://rp.example.com/cb" + + def test_get_uri_no_redirect_uri(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} + + request = {"client_id": "client_id"} + + assert get_uri(_ec, request, "redirect_uri") == "https://rp.example.com/cb" + + def test_get_uri_no_registered(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} + + request = {"client_id": "client_id"} + + with pytest.raises(ParameterError): + get_uri(_ec, request, "post_logout_redirect_uri") + + def test_get_uri_more_then_one_registered(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [ + ("https://rp.example.com/cb", {}), + ("https://rp.example.org/authz_cb", {"foo": "bar"}), + ] + } + + request = {"client_id": "client_id"} + + with pytest.raises(ParameterError): + get_uri(_ec, request, "redirect_uri") + + def test_create_authn_response(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + ) + + _ec = self.endpoint.endpoint_context + _ec.sdb["session_id"] = SessionInfo( + authn_req=request, + uid="diana", + sub="abcdefghijkl", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + }, + ) + _ec.cdb["client_id"] = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "ES256", + } + + resp = self.endpoint.create_authn_response(request, "session_id") + assert isinstance(resp["response_args"], AuthorizationErrorResponse) + + def test_setup_auth(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + ) + redirect_uri = request["redirect_uri"] + cinfo = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "RS256", + } + + kaka = self.endpoint.endpoint_context.cookie_dealer.create_cookie( + "value", "sso" + ) + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, kaka) + assert set(res.keys()) == {"authn_event", "identity", "user"} + + def test_setup_auth_error(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + ) + redirect_uri = request["redirect_uri"] + cinfo = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "RS256", + } + + item = self.endpoint.endpoint_context.authn_broker.db["anon"] + item["method"].fail = NoSuchAuthentication + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + assert set(res.keys()) == {"function", "args"} + + item["method"].fail = ToOld + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + assert set(res.keys()) == {"function", "args"} + + item["method"].file = "" + + def test_setup_auth_user(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + ) + redirect_uri = request["redirect_uri"] + cinfo = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "RS256", + } + _ec = self.endpoint.endpoint_context + _ec.sdb["session_id"] = SessionInfo( + authn_req=request, + uid="diana", + sub="abcdefghijkl", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + }, + ) + + item = _ec.authn_broker.db["anon"] + item["method"].user = b64e( + as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + ) + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + assert set(res.keys()) == {"authn_event", "identity", "user"} + assert res["identity"]["uid"] == "krall" + + def test_setup_auth_session_revoked(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + ) + redirect_uri = request["redirect_uri"] + cinfo = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "RS256", + } + _ec = self.endpoint.endpoint_context + _ec.sdb["session_id"] = SessionInfo( + authn_req=request, + uid="diana", + sub="abcdefghijkl", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + }, + revoked=True, + ) + + item = _ec.authn_broker.db["anon"] + item["method"].user = b64e( + as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + ) + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + assert set(res.keys()) == {"args", "function"} + + def test_response_mode_form_post(self): + request = {"response_mode": "form_post"} + info = { + "response_args": AuthorizationResponse(foo="bar"), + "return_uri": "https://example.com/cb", + } + info = self.endpoint.response_mode(request, **info) + assert set(info.keys()) == { + "response_args", + "return_uri", + "response_msg", + "content_type", + "response_placement", + } + assert info["response_msg"] == FORM_POST.format( + action="https://example.com/cb", + inputs='', + ) + + def test_do_response_code_form_post(self): + _req = AUTH_REQ_DICT.copy() + _req["response_mode"] = "form_post" + _pr_resp = self.endpoint.parse_request(_req) + _resp = self.endpoint.process_request(_pr_resp) + msg = self.endpoint.do_response(**_resp) + assert ("Content-type", "text/html") in msg["http_headers"] + assert "response_placement" in msg + + def test_response_mode_fragment(self): + request = {"response_mode": "fragment"} + self.endpoint.response_mode(request, fragment_enc=True) + + with pytest.raises(InvalidRequest): + self.endpoint.response_mode(request, fragment_enc=False) + + info = self.endpoint.response_mode(request) + assert set(info.keys()) == {"fragment_enc"} + + def test_check_session_iframe(self): + self.endpoint.endpoint_context.provider_info[ + "check_session_iframe" + ] = "https://example.com/csi" + _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT) + _resp = self.endpoint.process_request(_pr_resp) + assert "session_state" in _resp["response_args"] + + def test_setup_auth_login_hint(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + login_hint="tel:0907865204", + ) + redirect_uri = request["redirect_uri"] + cinfo = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "RS256", + } + + item = self.endpoint.endpoint_context.authn_broker.db["anon"] + item["method"].fail = NoSuchAuthentication + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + assert set(res.keys()) == {"function", "args"} + assert "login_hint" in res["args"] + + def test_setup_auth_login_hint2acrs(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + login_hint="email:foo@bar", + ) + redirect_uri = request["redirect_uri"] + + method_spec = { + "acr": INTERNETPROTOCOLPASSWORD, + "kwargs": {"user": "knoll"}, + "class": NoAuthn, + } + self.endpoint.endpoint_context.authn_broker["foo"] = init_method( + method_spec, None + ) + + item = self.endpoint.endpoint_context.authn_broker.db["anon"] + item["method"].fail = NoSuchAuthentication + item = self.endpoint.endpoint_context.authn_broker.db["foo"] + item["method"].fail = NoSuchAuthentication + + res = self.endpoint.pick_authn_method(request, redirect_uri) + assert set(res.keys()) == {"method", "acr"} + assert res["acr"] == INTERNETPROTOCOLPASSWORD + assert isinstance(res["method"], NoAuthn) + assert res["method"].user == "knoll" + + def test_post_logout_uri(self): + pass + + def test_parse_request(self): + _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") + _jws = _jwt.pack( + AUTH_REQ_DICT, aud=self.endpoint.endpoint_context.provider_info["issuer"] + ) + # ----------------- + _req = self.endpoint.parse_request( + { + "request": _jws, + "redirect_uri": AUTH_REQ.get("redirect_uri"), + "response_type": AUTH_REQ.get("response_type"), + "client_id": AUTH_REQ.get("client_id"), + "scope": AUTH_REQ.get("scope"), + } + ) + assert "__verified_request" in _req + + def test_parse_request_uri(self): + _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") + _jws = _jwt.pack( + AUTH_REQ_DICT, aud=self.endpoint.endpoint_context.provider_info["issuer"] + ) + + request_uri = "https://client.example.com/req" + # ----------------- + with responses.RequestsMock() as rsps: + rsps.add("GET", request_uri, body=_jws, status=200) + _req = self.endpoint.parse_request( + { + "request_uri": request_uri, + "redirect_uri": AUTH_REQ.get("redirect_uri"), + "response_type": AUTH_REQ.get("response_type"), + "client_id": AUTH_REQ.get("client_id"), + "scope": AUTH_REQ.get("scope"), + } + ) + + assert "__verified_request" in _req + + +def test_inputs(): + elems = inputs(dict(foo="bar", home="stead")) + test_elems = ( + '', + '', + ) + assert test_elems[0] in elems and test_elems[1] in elems + + +def test_acr_claims(): + assert acr_claims({"claims": {"id_token": {"acr": {"value": "foo"}}}}) == ["foo"] + assert acr_claims( + {"claims": {"id_token": {"acr": {"values": ["foo", "bar"]}}}} + ) == ["foo", "bar"] + assert acr_claims({"claims": {"id_token": {"acr": {"values": ["foo"]}}}}) == ["foo"] + assert acr_claims({"claims": {"id_token": {"acr": {"essential": True}}}}) is None + + +def test_join_query(): + redirect_uris = [("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]})] + uri = join_query(*redirect_uris[0]) + test_uri = ("https://rp.example.com/cb?", "foo=bar", "state=low") + for i in test_uri: + assert i in uri + + +class TestUserAuthn(object): + @pytest.fixture(autouse=True) + def create_endpoint_context(self): + conf = { + "issuer": "https://example.com/", + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "endpoint": {}, + "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, + "authentication": { + "user": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": UserPassJinja2, + "verify_endpoint": "verify/user", + "kwargs": { + "template": "user_pass.jinja2", + "sym_key": "24AA/LR6HighEnergy", + "db": { + "class": JSONDictDB, + "kwargs": {"json_path": full_path("passwd.json")}, + }, + "page_header": "Testing log in", + "submit_btn": "Get me in!", + "user_label": "Nickname", + "passwd_label": "Secret sauce", + }, + }, + "anon": { + "acr": UNSPECIFIED, + "class": NoAuthn, + "kwargs": {"user": "diana"}, + }, + }, + "cookie_dealer": { + "class": "oidcendpoint.cookie.CookieDealer", + "kwargs": { + "sign_key": "ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch", + "default_values": { + "name": "oidc_xx", + "domain": "example.com", + "path": "/", + "max_age": 3600, + }, + }, + }, + "template_dir": "template", + } + self.endpoint_context = EndpointContext(conf) + + def test_authenticated_as_without_cookie(self): + authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) + method = authn_item[0]["method"] + + _info, _time_stamp = method.authenticated_as(None) + assert _info is None + + def test_authenticated_as_with_cookie(self): + authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) + method = authn_item[0]["method"] + + authn_req = {"state": "state_identifier", "client_id": "client 12345"} + _cookie = new_cookie( + self.endpoint_context, + sub="diana", + sid="session_identifier", + state=authn_req["state"], + client_id=authn_req["client_id"], + cookie_name=self.endpoint_context.cookie_name["session"], + ) + + _info, _time_stamp = method.authenticated_as(_cookie) + _info = cookie_value(_info["uid"]) + assert _info["sub"] == "diana" diff --git a/tests/test_25_oidc_token_endpoint.py b/tests/test_25_oidc_token_endpoint.py index dd130c7..9d77be3 100755 --- a/tests/test_25_oidc_token_endpoint.py +++ b/tests/test_25_oidc_token_endpoint.py @@ -1,25 +1,28 @@ import json import os -import pytest from cryptojwt import JWT from cryptojwt.key_jar import build_keyjar from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import RefreshAccessTokenRequest +from oidcmsg.time_util import time_sans_frac +import pytest from oidcendpoint import JWT_BEARER +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.client_authn import verify_client from oidcendpoint.endpoint_context import EndpointContext -from oidcendpoint.exception import MultipleUsage from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.grant import Grant +from oidcendpoint.id_token import IDToken from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.refresh_token import RefreshAccessToken from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import setup_session +from oidcendpoint.session_management import db_key from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -30,7 +33,6 @@ CLIENT_KEYJAR = build_keyjar(KEYDEFS) - RESPONSE_TYPES_SUPPORTED = [ ["code"], ["token"], @@ -145,6 +147,7 @@ def create_endpoint(self): "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, "client_authn": verify_client, "template_dir": "template", + "id_token": {"class": IDToken, "kwargs": {}}, } endpoint_context = EndpointContext(conf) endpoint_context.cdb["client_1"] = { @@ -156,48 +159,87 @@ def create_endpoint(self): } endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") self.endpoint = endpoint_context.endpoint["token"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return grant + + def _mint_code(self, grant, client_id): + sid = db_key(self.user_id, client_id) + # Constructing an authorization code is now done + return grant.mint_token( + 'authorization_code', + value=self.session_manager.token_handler["code"](sid), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + + def _mint_access_token(self, grant, client_id, token_ref=None): + _csi = self.session_manager.get([self.user_id, client_id]) + return grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"]( + db_key(self.user_id, client_id), + client_id=client_id, + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_csi['sub'] + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=token_ref # Means the token (tok) was used to mint this token + ) def test_init(self): assert self.endpoint def test_parse(self): - session_id = setup_session(self.endpoint.endpoint_context, AUTH_REQ, uid="user") + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant, AUTH_REQ['client_id']) + _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + _token_request["code"] = code.value _req = self.endpoint.parse_request(_token_request) assert isinstance(_req, AccessTokenRequest) assert set(_req.keys()) == set(_token_request.keys()) def test_process_request(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant, AUTH_REQ['client_id']) + _token_request = TOKEN_REQ_DICT.copy() - _context = self.endpoint.endpoint_context - _token_request["code"] = _context.sdb[session_id]["code"] - _context.sdb.update(session_id, user="diana") - _req = self.endpoint.parse_request(_token_request) + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) _resp = self.endpoint.process_request(request=_req) assert _resp assert set(_resp.keys()) == {"http_headers", "response_args"} def test_process_request_using_code_twice(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant, AUTH_REQ['client_id']) + _token_request = TOKEN_REQ_DICT.copy() - _context = self.endpoint.endpoint_context - _token_request["code"] = _context.sdb[session_id]["code"] - _context.sdb.update(session_id, user="diana") + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) _resp = self.endpoint.process_request(request=_req) @@ -211,15 +253,12 @@ def test_process_request_using_code_twice(self): assert set(_resp.keys()) == {"error"} def test_do_response(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) - self.endpoint.endpoint_context.sdb.update(session_id, user="diana") + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant, AUTH_REQ['client_id']) + _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + _token_request["code"] = code.value _req = self.endpoint.parse_request(_token_request) _resp = self.endpoint.process_request(request=_req) @@ -227,12 +266,10 @@ def test_do_response(self): assert isinstance(msg, dict) def test_process_request_using_private_key_jwt(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant, AUTH_REQ['client_id']) + _token_request = TOKEN_REQ_DICT.copy() del _token_request["client_id"] del _token_request["client_secret"] @@ -244,9 +281,8 @@ def test_process_request_using_private_key_jwt(self): _token_request.update( {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} ) - _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + _token_request["code"] = code.value - _context.sdb.update(session_id, user="diana") _req = self.endpoint.parse_request(_token_request) _resp = self.endpoint.process_request(request=_req) diff --git a/tests/test_25_oidc_token_endpoint.py.no b/tests/test_25_oidc_token_endpoint.py.no new file mode 100755 index 0000000..dd130c7 --- /dev/null +++ b/tests/test_25_oidc_token_endpoint.py.no @@ -0,0 +1,255 @@ +import json +import os + +import pytest +from cryptojwt import JWT +from cryptojwt.key_jar import build_keyjar +from oidcmsg.oidc import AccessTokenRequest +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import RefreshAccessTokenRequest + +from oidcendpoint import JWT_BEARER +from oidcendpoint.client_authn import verify_client +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.exception import MultipleUsage +from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.oidc import userinfo +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.refresh_token import RefreshAccessToken +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session import setup_session +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from oidcendpoint.user_info import UserInfo + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLIENT_KEYJAR = build_keyjar(KEYDEFS) + + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], +} + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", client_id="client_1", client_secret="hemligt" +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + + +class TestEndpoint(object): + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": "https://example.com/", + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "endpoint": { + "provider_config": { + "path": ".well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {}, + }, + "registration": { + "path": "registration", + "class": Registration, + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": Authorization, + "kwargs": {}, + }, + "token": { + "path": "token", + "class": AccessToken, + "kwargs": { + "client_authn_method": [ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "private_key_jwt", + ] + }, + }, + "refresh_token": { + "path": "token", + "class": RefreshAccessToken, + "kwargs": {}, + }, + "userinfo": { + "path": "userinfo", + "class": userinfo.UserInfo, + "kwargs": {"db_file": "users.json"}, + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, + "client_authn": verify_client, + "template_dir": "template", + } + endpoint_context = EndpointContext(conf) + endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + } + endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + self.endpoint = endpoint_context.endpoint["token"] + + def test_init(self): + assert self.endpoint + + def test_parse(self): + session_id = setup_session(self.endpoint.endpoint_context, AUTH_REQ, uid="user") + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + _req = self.endpoint.parse_request(_token_request) + + assert isinstance(_req, AccessTokenRequest) + assert set(_req.keys()) == set(_token_request.keys()) + + def test_process_request(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="user", + acr=INTERNETPROTOCOLPASSWORD, + ) + _token_request = TOKEN_REQ_DICT.copy() + _context = self.endpoint.endpoint_context + _token_request["code"] = _context.sdb[session_id]["code"] + _context.sdb.update(session_id, user="diana") + _req = self.endpoint.parse_request(_token_request) + + _resp = self.endpoint.process_request(request=_req) + + assert _resp + assert set(_resp.keys()) == {"http_headers", "response_args"} + + def test_process_request_using_code_twice(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="user", + acr=INTERNETPROTOCOLPASSWORD, + ) + _token_request = TOKEN_REQ_DICT.copy() + _context = self.endpoint.endpoint_context + _token_request["code"] = _context.sdb[session_id]["code"] + _context.sdb.update(session_id, user="diana") + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + # 2nd time used + # TODO: There is a bug in _post_parse_request, the returned error + # should be invalid_grant, not invalid_client + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + assert _resp + assert set(_resp.keys()) == {"error"} + + def test_do_response(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="user", + acr=INTERNETPROTOCOLPASSWORD, + ) + self.endpoint.endpoint_context.sdb.update(session_id, user="diana") + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + _req = self.endpoint.parse_request(_token_request) + + _resp = self.endpoint.process_request(request=_req) + msg = self.endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_process_request_using_private_key_jwt(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="user", + acr=INTERNETPROTOCOLPASSWORD, + ) + _token_request = TOKEN_REQ_DICT.copy() + del _token_request["client_id"] + del _token_request["client_secret"] + _context = self.endpoint.endpoint_context + + _jwt = JWT(CLIENT_KEYJAR, iss=AUTH_REQ["client_id"], sign_alg="RS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [_context.endpoint["token"].full_path]}) + _token_request.update( + {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} + ) + _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + + _context.sdb.update(session_id, user="diana") + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + # 2nd time used + with pytest.raises(UnAuthorizedClient): + self.endpoint.parse_request(_token_request) diff --git a/tests/test_26_oidc_userinfo_endpoint.py b/tests/test_26_oidc_userinfo_endpoint.py index 41f2bb9..88fc2d2 100755 --- a/tests/test_26_oidc_userinfo_endpoint.py +++ b/tests/test_26_oidc_userinfo_endpoint.py @@ -1,20 +1,24 @@ import json import os -import pytest -from cryptojwt.jwt import utc_time_sans_frac from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.time_util import time_sans_frac +import pytest from oidcendpoint import user_info +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.grant import Grant from oidcendpoint.id_token import IDToken from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import setup_session +from oidcendpoint.session_management import SessionManager +from oidcendpoint.session_management import db_key +from oidcendpoint.session_management import unpack_db_key from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -165,133 +169,147 @@ def create_endpoint(self): "response_types": ["code", "token", "code id_token", "id_token"], } self.endpoint = endpoint_context.endpoint["userinfo"] + self.session_manager = SessionManager({}, endpoint_context.sdb.handler) + endpoint_context.session_manager = self.session_manager + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return grant + + def _mint_code(self, grant): + # Constructing an authorization code is now done + return grant.mint_token( + 'authorization_code', + value=self.session_manager.token_handler["code"](self.user_id), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + + def _mint_access_token(self, grant, client_id, token_ref=None): + _csi = self.session_manager.get([self.user_id, client_id]) + return grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"]( + db_key(self.user_id, client_id), + client_id=client_id, + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_csi['sub'] + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=token_ref # Means the token (tok) was used to mint this token + ) def test_init(self): assert self.endpoint assert set( self.endpoint.endpoint_context.provider_info["claims_supported"] ) == { - "address", - "birthdate", - "email", - "email_verified", - "eduperson_scoped_affiliation", - "family_name", - "gender", - "given_name", - "locale", - "middle_name", - "name", - "nickname", - "phone_number", - "phone_number_verified", - "picture", - "preferred_username", - "profile", - "sub", - "updated_at", - "website", - "zoneinfo", - } + "address", + "birthdate", + "email", + "email_verified", + "eduperson_scoped_affiliation", + "family_name", + "gender", + "given_name", + "locale", + "middle_name", + "name", + "nickname", + "phone_number", + "phone_number_verified", + "picture", + "preferred_username", + "profile", + "sub", + "updated_at", + "website", + "zoneinfo", + } def test_parse(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) - _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) - ) - + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + # Free standing access token, not based on an authorization code + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], None) + _req = self.endpoint.parse_request({}, auth="Bearer {}".format(access_token.value)) assert set(_req.keys()) == {"client_id", "access_token"} + assert _req["client_id"] == AUTH_REQ['client_id'] + assert _req["access_token"] == access_token.value def test_parse_invalid_token(self): _req = self.endpoint.parse_request({}, auth="Bearer invalid") - assert _req['error'] == "invalid_token" def test_process_request(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], code) + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert args def test_process_request_not_allowed(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac() - 7200, - "valid_until": utc_time_sans_frac() - 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], code) + + _us_info = self.session_manager.get([self.user_id]) + # 2 things can make the request invalid. + # 1) The token is not valid anymore or 2) The event is not valid. + _event = _us_info["authentication_event"] + _event['authn_time'] -= 9000 + _event['valid_until'] -= 9000 + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert set(args["response_args"].keys()) == {"error", "error_description"} - def test_process_request_offline_access(self): - auth_req = AUTH_REQ.copy() - auth_req["scope"] = ["openid", "offline_access"] - session_id = setup_session( - self.endpoint.endpoint_context, - auth_req, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac() , - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) - _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) - ) - args = self.endpoint.process_request(_req) - assert set(args["response_args"].keys()) == {"sub"} + # Offline access is presently not checked. + # + # def test_process_request_offline_access(self): + # auth_req = AUTH_REQ.copy() + # auth_req["scope"] = ["openid", "offline_access"] + # self._create_session(auth_req) + # grant = self._do_grant(auth_req) + # code = self._mint_code(grant) + # access_token = self._mint_access_token(grant, auth_req['client_id'], code) + # + # _req = self.endpoint.parse_request( + # {}, auth="Bearer {}".format(access_token.value) + # ) + # args = self.endpoint.process_request(_req) + # assert set(args["response_args"].keys()) =={'response_args', 'client_id'} def test_do_response(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], code) + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert args @@ -299,24 +317,15 @@ def test_do_response(self): assert res def test_do_signed_response(self): - self.endpoint.endpoint_context.cdb["client_1"][ - "userinfo_signed_response_alg" - ] = "ES256" - - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + self.endpoint.endpoint_context.cdb["client_1"]["userinfo_signed_response_alg"] = "ES256" + + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], code) + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert args @@ -326,28 +335,24 @@ def test_do_signed_response(self): def test_custom_scope(self): _auth_req = AUTH_REQ.copy() _auth_req["scope"] = ["openid", "research_and_scholarship"] - session_id = setup_session( - self.endpoint.endpoint_context, - _auth_req, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + + _sid = self._create_session(_auth_req) + grant = self._do_grant(_auth_req) + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], code) + + user_id, client_id = unpack_db_key(_sid) + self.endpoint.claims_interface.add_claims_by_scope = True + grant.claims = { + "userinfo": self.endpoint.claims_interface.get_claims(client_id=client_id, + user_id=user_id, + scopes=_auth_req["scope"]) + } + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) - assert set(args["response_args"].keys()) == { - "sub", - "name", - "given_name", - "family_name", - "email", - "email_verified", - "eduperson_scoped_affiliation", - } + assert set(args["response_args"].keys()) == {'eduperson_scoped_affiliation', 'given_name', + 'email_verified', 'email', 'family_name', + 'name', 'sub'} diff --git a/tests/test_26_oidc_userinfo_endpoint.py.no b/tests/test_26_oidc_userinfo_endpoint.py.no new file mode 100755 index 0000000..41f2bb9 --- /dev/null +++ b/tests/test_26_oidc_userinfo_endpoint.py.no @@ -0,0 +1,353 @@ +import json +import os + +import pytest +from cryptojwt.jwt import utc_time_sans_frac +from oidcmsg.oidc import AccessTokenRequest +from oidcmsg.oidc import AuthorizationRequest + +from oidcendpoint import user_info +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.id_token import IDToken +from oidcendpoint.oidc import userinfo +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session import setup_session +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from oidcendpoint.user_info import UserInfo + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], +} + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + + +class TestEndpoint(object): + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": "https://example.com/", + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "id_token": {"class": IDToken, "kwargs": {}}, + "endpoint": { + "provider_config": { + "path": ".well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {}, + }, + "registration": { + "path": "registration", + "class": Registration, + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": Authorization, + "kwargs": {}, + }, + "token": { + "path": "token", + "class": AccessToken, + "kwargs": { + "client_authn_methods": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ] + }, + }, + "userinfo": { + "path": "userinfo", + "class": userinfo.UserInfo, + "kwargs": { + "claim_types_supported": [ + "normal", + "aggregated", + "distributed", + ], + "client_authn_method": ["bearer_header"], + }, + }, + }, + "userinfo": { + "class": user_info.UserInfo, + "kwargs": {"db_file": full_path("users.json")}, + }, + # "client_authn": verify_client, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "template_dir": "template", + "add_on": { + "custom_scopes": { + "function": "oidcendpoint.oidc.add_on.custom_scopes.add_custom_scopes", + "kwargs": { + "research_and_scholarship": [ + "name", + "given_name", + "family_name", + "email", + "email_verified", + "sub", + "eduperson_scoped_affiliation", + ] + }, + } + }, + } + endpoint_context = EndpointContext(conf) + endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + } + self.endpoint = endpoint_context.endpoint["userinfo"] + + def test_init(self): + assert self.endpoint + assert set( + self.endpoint.endpoint_context.provider_info["claims_supported"] + ) == { + "address", + "birthdate", + "email", + "email_verified", + "eduperson_scoped_affiliation", + "family_name", + "gender", + "given_name", + "locale", + "middle_name", + "name", + "nickname", + "phone_number", + "phone_number_verified", + "picture", + "preferred_username", + "profile", + "sub", + "updated_at", + "website", + "zoneinfo", + } + + def test_parse(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + + assert set(_req.keys()) == {"client_id", "access_token"} + + def test_parse_invalid_token(self): + _req = self.endpoint.parse_request({}, auth="Bearer invalid") + + assert _req['error'] == "invalid_token" + + def test_process_request(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert args + + def test_process_request_not_allowed(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac() - 7200, + "valid_until": utc_time_sans_frac() - 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert set(args["response_args"].keys()) == {"error", "error_description"} + + def test_process_request_offline_access(self): + auth_req = AUTH_REQ.copy() + auth_req["scope"] = ["openid", "offline_access"] + session_id = setup_session( + self.endpoint.endpoint_context, + auth_req, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac() , + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert set(args["response_args"].keys()) == {"sub"} + + def test_do_response(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert args + res = self.endpoint.do_response(request=_req, **args) + assert res + + def test_do_signed_response(self): + self.endpoint.endpoint_context.cdb["client_1"][ + "userinfo_signed_response_alg" + ] = "ES256" + + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert args + res = self.endpoint.do_response(request=_req, **args) + assert res + + def test_custom_scope(self): + _auth_req = AUTH_REQ.copy() + _auth_req["scope"] = ["openid", "research_and_scholarship"] + session_id = setup_session( + self.endpoint.endpoint_context, + _auth_req, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert set(args["response_args"].keys()) == { + "sub", + "name", + "given_name", + "family_name", + "email", + "email_verified", + "eduperson_scoped_affiliation", + } diff --git a/tests/test_30_oidc_end_session.py b/tests/test_30_oidc_end_session.py index 088e591..6e73bf7 100644 --- a/tests/test_30_oidc_end_session.py +++ b/tests/test_30_oidc_end_session.py @@ -4,14 +4,14 @@ from urllib.parse import parse_qs from urllib.parse import urlparse -import pytest -import responses from cryptojwt.key_jar import build_keyjar from oidcmsg.exception import InvalidRequest from oidcmsg.message import Message from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import verified_claim_name from oidcmsg.oidc import verify_id_token +import pytest +import responses from oidcendpoint.common.authorization import join_query from oidcendpoint.cookie import CookieDealer @@ -25,6 +25,7 @@ from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.session import do_front_channel_logout_iframe from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session_management import db_key from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -192,15 +193,14 @@ def create_endpoint(self): def test_end_session_endpoint(self): # End session not allowed if no cookie and no id_token_hint is sent # (can't determine session) - with pytest.raises(UnknownToken): + with pytest.raises(ValueError): _ = self.session_endpoint.process_request("", cookie="FAIL") - def _create_cookie(self, user, sid, state, client_id): + def _create_cookie(self, session_id, state, client_id): ec = self.session_endpoint.endpoint_context return new_cookie( ec, - sub=user, - sid=sid, + sid=session_id, state=state, client_id=client_id, cookie_name=ec.cookie_name["session"], @@ -228,14 +228,14 @@ def _code_auth2(self, state): _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) _resp = self.authn_endpoint.process_request(_pr_resp) - def _get_sid(self): - _sdb = self.session_endpoint.endpoint_context.sdb - - for _sid in _sdb.keys(): - if _sid.startswith("__state__"): - continue - else: - return _sid + # def _get_sid(self): + # _mngr = self.session_endpoint.endpoint_context.session_manager + # + # for _sid in _sdb.keys(): + # if _sid.startswith("__state__"): + # continue + # else: + # return _sid def _auth_with_id_token(self, state): req = AuthorizationRequest( @@ -253,8 +253,7 @@ def _auth_with_id_token(self, state): def test_end_session_endpoint_with_cookie(self): self._code_auth("1234567") - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "1234567", "client_1") _req_args = self.session_endpoint.parse_request({"state": "1234567"}) resp = self.session_endpoint.process_request(_req_args, cookie=cookie) @@ -272,7 +271,7 @@ def test_end_session_endpoint_with_cookie(self): def test_end_session_endpoint_with_wrong_cookie(self): self._code_auth("1234567") - cookie = self._create_cookie("diana", "client_2", "abcdefg", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "abcdefg", "client_2") with pytest.raises(UnknownToken): self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie) @@ -281,7 +280,7 @@ def test_end_session_endpoint_with_cookie_wrong_user(self): # Need cookie and ID Token to figure this out id_token = self._auth_with_id_token("1234567") - cookie = self._create_cookie("diggins", "_sid_", "1234567", "client_1") + cookie = self._create_cookie(db_key("diggins", "client_1"), "1234567", "client_1") msg = Message(id_token=id_token) verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) @@ -298,7 +297,7 @@ def test_end_session_endpoint_with_cookie_unknown_sid(self): id_token = self._auth_with_id_token("1234567") # Wrong client_id - cookie = self._create_cookie("diana", "_sid_", "state", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "state", "client_1") msg = Message(id_token=id_token) verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) @@ -314,8 +313,7 @@ def test_end_session_endpoint_with_cookie_dual_login(self): self._code_auth("1234567") self._code_auth2("abcdefg") _sdb = self.session_endpoint.endpoint_context.sdb - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "1234567", "client_1") resp = self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie) @@ -334,8 +332,7 @@ def test_end_session_endpoint_with_post_logout_redirect_uri(self): self._code_auth("1234567") self._code_auth2("abcdefg") _sdb = self.session_endpoint.endpoint_context.sdb - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "1234567", "client_1") post_logout_redirect_uri = join_query( *self.session_endpoint.endpoint_context.cdb["client_1"][ @@ -359,8 +356,7 @@ def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): id_token = self._auth_with_id_token("1234567") _sdb = self.session_endpoint.endpoint_context.sdb - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "1234567", "client_1") post_logout_redirect_uri = "https://demo.example.com/log_out" @@ -453,8 +449,8 @@ def test_logout_from_client_bc(self): "backchannel_logout_uri" ] = "https://example.com/bc_logout" self.session_endpoint.endpoint_context.cdb["client_1"]["client_id"] = "client_1" - _sid = self._get_sid() - res = self.session_endpoint.logout_from_client(_sid, "client_1") + _sid = db_key() + res = self.session_endpoint.logout_from_client(db_key(), "client_1") assert set(res.keys()) == {"blu"} assert set(res["blu"].keys()) == {"client_1"} _spec = res["blu"]["client_1"] @@ -475,7 +471,7 @@ def test_logout_from_client_fc(self): "frontchannel_logout_uri" ] = "https://example.com/fc_logout" self.session_endpoint.endpoint_context.cdb["client_1"]["client_id"] = "client_1" - _sid = self._get_sid() + _sid = db_key("diana", "client_1") res = self.session_endpoint.logout_from_client(_sid, "client_1") assert set(res.keys()) == {"flu"} assert set(res["flu"].keys()) == {"client_1"} @@ -499,7 +495,7 @@ def test_logout_from_client(self): ] = "https://example.com/fc_logout" self.session_endpoint.endpoint_context.cdb["client_2"]["client_id"] = "client_2" - _sid = self._get_sid() + _sid = db_key("diana", "client_1") res = self.session_endpoint.logout_all_clients(_sid, "client_1") assert res @@ -527,7 +523,7 @@ def test_do_verified_logout(self): _cdb["client_1"]["backchannel_logout_uri"] = "https://example.com/bc_logout" _cdb["client_1"]["client_id"] = "client_1" - _sid = self._get_sid() + _sid = db_key("diana", "client_1") res = self.session_endpoint.do_verified_logout(_sid, "client_1") assert res == [] @@ -565,9 +561,9 @@ def test_logout_from_client_no_session(self): ] = "https://example.com/fc_logout" self.session_endpoint.endpoint_context.cdb["client_2"]["client_id"] = "client_2" - _sid = self._get_sid() + _sid = db_key("diana", "client_1") - self.session_endpoint.endpoint_context.sdb.sso_db.delete("diana", "sid") + self.session_endpoint.endpoint_context.session_manager.delete("diana", "client_1") res = self.session_endpoint.logout_all_clients(_sid, "client_1") assert res == {} diff --git a/tests/test_30_oidc_end_session.py.no b/tests/test_30_oidc_end_session.py.no new file mode 100644 index 0000000..088e591 --- /dev/null +++ b/tests/test_30_oidc_end_session.py.no @@ -0,0 +1,573 @@ +import copy +import json +import os +from urllib.parse import parse_qs +from urllib.parse import urlparse + +import pytest +import responses +from cryptojwt.key_jar import build_keyjar +from oidcmsg.exception import InvalidRequest +from oidcmsg.message import Message +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import verified_claim_name +from oidcmsg.oidc import verify_id_token + +from oidcendpoint.common.authorization import join_query +from oidcendpoint.cookie import CookieDealer +from oidcendpoint.cookie import new_cookie +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.exception import RedirectURIError +from oidcendpoint.oidc import userinfo +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.session import Session +from oidcendpoint.oidc.session import do_front_channel_logout_iframe +from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.token_handler import UnknownToken +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from oidcendpoint.user_info import UserInfo + +ISS = "https://example.com/" + +CLI1 = "https://client1.example.com/" +CLI2 = "https://client2.example.com/" + +KEYDEFS = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +KEYJAR = build_keyjar(KEYDEFS) +KEYJAR.import_jwks(KEYJAR.export_jwks(private=True), ISS) + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], + "claim_types_supported": ["normal", "aggregated", "distributed"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, +} + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="{}cb".format(ISS), + scope=["openid"], + state="STATE", + response_type="code", + client_secret="hemligt", +) + +AUTH_REQ_DICT = AUTH_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO_db = json.loads(open(full_path("users.json")).read()) + + +class TestEndpoint(object): + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": ISS, + "password": "mycket hemlig zebra", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "endpoint": { + "provider_config": { + "path": "{}/.well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {"client_authn_method": None}, + }, + "registration": { + "path": "{}/registration", + "class": Registration, + "kwargs": {"client_authn_method": None}, + }, + "authorization": { + "path": "{}/authorization", + "class": Authorization, + "kwargs": {"client_authn_method": None}, + }, + "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, + "userinfo": { + "path": "{}/userinfo", + "class": userinfo.UserInfo, + "kwargs": {"db_file": "users.json"}, + }, + "session": { + "path": "{}/end_session", + "class": Session, + "kwargs": { + "post_logout_uri_path": "post_logout", + "signing_alg": "ES256", + "logout_verify_url": "{}/verify_logout".format(ISS), + "client_authn_method": None, + }, + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}}, + "template_dir": "template", + # 'cookie_name':{ + # 'session': 'oidcop', + # 'register': 'oidcreg' + # } + } + cookie_conf = { + "sign_key": "ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch", + "default_values": { + "name": "oidcop", + "domain": "127.0.0.1", + "path": "/", + "max_age": 3600, + }, + } + + self.cd = CookieDealer(**cookie_conf) + endpoint_context = EndpointContext(conf, cookie_dealer=self.cd, keyjar=KEYJAR) + endpoint_context.cdb = { + "client_1": { + "client_secret": "hemligt", + "redirect_uris": [("{}cb".format(CLI1), None)], + "client_salt": "salted", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "post_logout_redirect_uris": [("{}logout_cb".format(CLI1), "")], + }, + "client_2": { + "client_secret": "hemligare", + "redirect_uris": [("{}cb".format(CLI2), None)], + "client_salt": "saltare", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "post_logout_redirect_uris": [("{}logout_cb".format(CLI2), "")], + }, + } + self.authn_endpoint = endpoint_context.endpoint["authorization"] + self.session_endpoint = endpoint_context.endpoint["session"] + self.token_endpoint = endpoint_context.endpoint["token"] + + def test_end_session_endpoint(self): + # End session not allowed if no cookie and no id_token_hint is sent + # (can't determine session) + with pytest.raises(UnknownToken): + _ = self.session_endpoint.process_request("", cookie="FAIL") + + def _create_cookie(self, user, sid, state, client_id): + ec = self.session_endpoint.endpoint_context + return new_cookie( + ec, + sub=user, + sid=sid, + state=state, + client_id=client_id, + cookie_name=ec.cookie_name["session"], + ) + + def _code_auth(self, state): + req = AuthorizationRequest( + state=state, + response_type="code", + redirect_uri="{}cb".format(CLI1), + scope=["openid"], + client_id="client_1", + ) + _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) + _resp = self.authn_endpoint.process_request(_pr_resp) + + def _code_auth2(self, state): + req = AuthorizationRequest( + state=state, + response_type="code", + redirect_uri="{}cb".format(CLI2), + scope=["openid"], + client_id="client_2", + ) + _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) + _resp = self.authn_endpoint.process_request(_pr_resp) + + def _get_sid(self): + _sdb = self.session_endpoint.endpoint_context.sdb + + for _sid in _sdb.keys(): + if _sid.startswith("__state__"): + continue + else: + return _sid + + def _auth_with_id_token(self, state): + req = AuthorizationRequest( + state=state, + response_type="id_token", + redirect_uri="{}cb".format(CLI1), + scope=["openid"], + client_id="client_1", + nonce="_nonce_", + ) + _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) + _resp = self.authn_endpoint.process_request(_pr_resp) + + return _resp["response_args"]["id_token"] + + def test_end_session_endpoint_with_cookie(self): + self._code_auth("1234567") + _sid = self._get_sid() + cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + + _req_args = self.session_endpoint.parse_request({"state": "1234567"}) + resp = self.session_endpoint.process_request(_req_args, cookie=cookie) + + # returns a signed JWT to be put in a verification web page shown to + # the user + + p = urlparse(resp["redirect_location"]) + qs = parse_qs(p.query) + jwt_info = self.session_endpoint.unpack_signed_jwt(qs["sjwt"][0]) + + assert jwt_info["user"] == "diana" + assert jwt_info["client_id"] == "client_1" + assert jwt_info["redirect_uri"] == "https://example.com/post_logout" + + def test_end_session_endpoint_with_wrong_cookie(self): + self._code_auth("1234567") + cookie = self._create_cookie("diana", "client_2", "abcdefg", "client_1") + + with pytest.raises(UnknownToken): + self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie) + + def test_end_session_endpoint_with_cookie_wrong_user(self): + # Need cookie and ID Token to figure this out + id_token = self._auth_with_id_token("1234567") + + cookie = self._create_cookie("diggins", "_sid_", "1234567", "client_1") + + msg = Message(id_token=id_token) + verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) + + msg2 = Message(id_token_hint=id_token) + msg2[verified_claim_name("id_token_hint")] = msg[ + verified_claim_name("id_token") + ] + with pytest.raises(ValueError): + self.session_endpoint.process_request(msg2, cookie=cookie) + + def test_end_session_endpoint_with_cookie_unknown_sid(self): + # Need cookie and ID Token to figure this out + id_token = self._auth_with_id_token("1234567") + + # Wrong client_id + cookie = self._create_cookie("diana", "_sid_", "state", "client_1") + + msg = Message(id_token=id_token) + verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) + + msg2 = Message(id_token_hint=id_token) + msg2[verified_claim_name("id_token_hint")] = msg[ + verified_claim_name("id_token") + ] + with pytest.raises(ValueError): + self.session_endpoint.process_request(msg2, cookie=cookie) + + def test_end_session_endpoint_with_cookie_dual_login(self): + self._code_auth("1234567") + self._code_auth2("abcdefg") + _sdb = self.session_endpoint.endpoint_context.sdb + _sid = self._get_sid() + cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + + resp = self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie) + + # returns a signed JWT to be put in a verification web page shown to + # the user + + p = urlparse(resp["redirect_location"]) + qs = parse_qs(p.query) + jwt_info = self.session_endpoint.unpack_signed_jwt(qs["sjwt"][0]) + + assert jwt_info["user"] == "diana" + assert jwt_info["client_id"] == "client_1" + assert jwt_info["redirect_uri"] == "https://example.com/post_logout" + + def test_end_session_endpoint_with_post_logout_redirect_uri(self): + self._code_auth("1234567") + self._code_auth2("abcdefg") + _sdb = self.session_endpoint.endpoint_context.sdb + _sid = self._get_sid() + cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + + post_logout_redirect_uri = join_query( + *self.session_endpoint.endpoint_context.cdb["client_1"][ + "post_logout_redirect_uris" + ][0] + ) + + with pytest.raises(InvalidRequest): + self.session_endpoint.process_request( + { + "post_logout_redirect_uri": post_logout_redirect_uri, + "state": "abcde", + }, + cookie=cookie, + ) + + def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): + self._code_auth("1234567") + self._code_auth2("abcdefg") + + id_token = self._auth_with_id_token("1234567") + + _sdb = self.session_endpoint.endpoint_context.sdb + _sid = self._get_sid() + cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + + post_logout_redirect_uri = "https://demo.example.com/log_out" + + msg = Message(id_token=id_token) + verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) + + with pytest.raises(RedirectURIError): + self.session_endpoint.process_request( + { + "post_logout_redirect_uri": post_logout_redirect_uri, + "state": "abcde", + "id_token_hint": id_token, + verified_claim_name("id_token_hint"): msg[ + verified_claim_name("id_token") + ], + }, + cookie=cookie, + ) + + def test_back_channel_logout_no_uri(self): + self._code_auth("1234567") + + res = self.session_endpoint.do_back_channel_logout( + self.session_endpoint.endpoint_context.cdb["client_1"], "username", 0 + ) + assert res is None + + def test_back_channel_logout(self): + self._code_auth("1234567") + + _cdb = copy.copy(self.session_endpoint.endpoint_context.cdb["client_1"]) + _cdb["backchannel_logout_uri"] = "https://example.com/bc_logout" + _cdb["client_id"] = "client_1" + res = self.session_endpoint.do_back_channel_logout(_cdb, "username", "_sid_") + assert isinstance(res, tuple) + assert res[0] == "https://example.com/bc_logout" + _jwt = self.session_endpoint.unpack_signed_jwt(res[1], "RS256") + assert _jwt + assert _jwt["iss"] == ISS + assert _jwt["aud"] == ["client_1"] + assert _jwt["sub"] == "username" + assert _jwt["sid"] == "_sid_" + + def test_front_channel_logout(self): + self._code_auth("1234567") + + _cdb = copy.copy(self.session_endpoint.endpoint_context.cdb["client_1"]) + _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout" + _cdb["client_id"] = "client_1" + res = do_front_channel_logout_iframe(_cdb, ISS, "_sid_") + assert res == '