From 2f9747fa995e55fe4034308d9d23133523827e09 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Fri, 21 Feb 2025 12:34:56 -0800 Subject: [PATCH] ManagedIdentityClient sends xms_cc and token_sha256_to_refresh to SF --- msal/managed_identity.py | 58 ++++++++++++++++++++++++++++------ tests/test_mi.py | 67 ++++++++++++++++++++++++++++++++++------ 2 files changed, 106 insertions(+), 19 deletions(-) diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 6f85571d..230ff2f9 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -2,6 +2,7 @@ # All rights reserved. # # This code is licensed under the MIT License. +import hashlib import json import logging import os @@ -10,7 +11,7 @@ import time from urllib.parse import urlparse # Python 3+ from collections import UserDict # Python 3+ -from typing import Optional, Union # Needed in Python 3.7 & 3.8 +from typing import List, Optional, Union # Needed in Python 3.7 & 3.8 from .token_cache import TokenCache from .individual_cache import _IndividualCache as IndividualCache from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser @@ -162,6 +163,7 @@ def __init__( http_client, token_cache=None, http_cache=None, + client_capabilities: Optional[List[str]] = None, ): """Create a managed identity client. @@ -192,6 +194,17 @@ def __init__( Optional. It has the same characteristics as the :paramref:`msal.ClientApplication.http_cache`. + :param list[str] client_capabilities: (optional) + Allows configuration of one or more client capabilities, e.g. ["CP1"]. + + Client capability is meant to inform the Microsoft identity platform + (STS) what this client is capable for, + so STS can decide to turn on certain features. + + Implementation details: + Client capability in Managed Identity is relayed as-is + via ``xms_cc`` parameter on the wire. + Recipe 1: Hard code a managed identity for your app:: import msal, requests @@ -238,6 +251,7 @@ def __init__( http_cache=http_cache, ) self._token_cache = token_cache or TokenCache() + self._client_capabilities = client_capabilities def _get_instance(self): if self.__instance is None: @@ -266,8 +280,7 @@ def acquire_token_for_client( and then a *claims challenge* will be returned by the target resource, as a `claims_challenge` directive in the `www-authenticate` header, even if the app developer did not opt in for the "CP1" client capability. - Upon receiving a `claims_challenge`, MSAL will skip a token cache read, - and will attempt to acquire a new token. + Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token. .. note:: @@ -278,11 +291,13 @@ def acquire_token_for_client( This is a service-side behavior that cannot be changed by this library. `Azure VM docs `_ """ + access_token_to_refresh = None # This could become a public parameter in the future access_token_from_cache = None client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") now = time.time() - if not claims_challenge: # Then attempt token cache search + if True: # Attempt cache search even if receiving claims_challenge, + # because we want to locate the existing token (if any) and refresh it matches = self._token_cache.find( self._token_cache.CredentialType.ACCESS_TOKEN, target=[resource], @@ -297,6 +312,11 @@ def acquire_token_for_client( expires_in = int(entry["expires_on"]) - now if expires_in < 5*60: # Then consider it expired continue # Removal is not necessary, it will be overwritten + if claims_challenge and not access_token_to_refresh: + # Since caller did not pinpoint the token causing claims challenge, + # we have to assume it is the first token we found in cache. + access_token_to_refresh = entry["secret"] + break logger.debug("Cache hit an AT") access_token_from_cache = { # Mimic a real response "access_token": entry["secret"], @@ -310,7 +330,13 @@ def acquire_token_for_client( break # With a fallback in hand, we break here to go refresh return access_token_from_cache # It is still good as new try: - result = _obtain_token(self._http_client, self._managed_identity, resource) + result = _obtain_token( + self._http_client, self._managed_identity, resource, + access_token_sha256_to_refresh=hashlib.sha256( + access_token_to_refresh.encode("utf-8")).hexdigest() + if access_token_to_refresh else None, + client_capabilities=self._client_capabilities, + ) if "access_token" in result: expires_in = result.get("expires_in", 3600) if "refresh_in" not in result and expires_in >= 7200: @@ -385,8 +411,12 @@ def get_managed_identity_source(): return DEFAULT_TO_VM -def _obtain_token(http_client, managed_identity, resource): - # A unified low-level API that talks to different Managed Identity +def _obtain_token( + http_client, managed_identity, resource, + *, + access_token_sha256_to_refresh: Optional[str] = None, + client_capabilities: Optional[List[str]] = None, +): if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ and "IDENTITY_SERVER_THUMBPRINT" in os.environ ): @@ -402,6 +432,8 @@ def _obtain_token(http_client, managed_identity, resource): os.environ["IDENTITY_HEADER"], os.environ["IDENTITY_SERVER_THUMBPRINT"], resource, + access_token_sha256_to_refresh=access_token_sha256_to_refresh, + client_capabilities=client_capabilities, ) if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ: return _obtain_token_on_app_service( @@ -553,6 +585,9 @@ def _obtain_token_on_machine_learning( def _obtain_token_on_service_fabric( http_client, endpoint, identity_header, server_thumbprint, resource, + *, + access_token_sha256_to_refresh: str = None, + client_capabilities: Optional[List[str]] = None, ): """Obtains token for `Service Fabric `_ @@ -563,7 +598,12 @@ def _obtain_token_on_service_fabric( logger.debug("Obtaining token via managed identity on Azure Service Fabric") resp = http_client.get( endpoint, - params={"api-version": "2019-07-01-preview", "resource": resource}, + params={k: v for k, v in { + "api-version": "2019-07-01-preview", + "resource": resource, + "token_sha256_to_refresh": access_token_sha256_to_refresh, + "xms_cc": ",".join(client_capabilities) if client_capabilities else None, + }.items() if v is not None}, headers={"Secret": identity_header}, ) try: @@ -584,7 +624,7 @@ def _obtain_token_on_service_fabric( "ArgumentNullOrEmpty": "invalid_scope", } return { - "error": error_mapping.get(payload["error"]["code"], "invalid_request"), + "error": error_mapping.get(error.get("code"), "invalid_request"), "error_description": resp.text, } except json.decoder.JSONDecodeError: diff --git a/tests/test_mi.py b/tests/test_mi.py index a7c2cb6c..0fd432a5 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -1,7 +1,9 @@ +import hashlib import json import os import sys import time +from typing import List, Optional import unittest try: from unittest.mock import patch, ANY, mock_open, Mock @@ -52,15 +54,23 @@ def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_f class ClientTestCase(unittest.TestCase): maxDiff = None - def setUp(self): - self.app = ManagedIdentityClient( + def _build_app( + self, + *, + client_capabilities: Optional[List[str]] = None, + ): + return ManagedIdentityClient( { # Here we test it with the raw dict form, to test that # the client has no hard dependency on ManagedIdentity object "ManagedIdentityIdType": "SystemAssigned", "Id": None, }, http_client=requests.Session(), + client_capabilities=client_capabilities, ) + def setUp(self): + self.app = self._build_app() + def test_error_out_on_invalid_input(self): with self.assertRaises(ManagedIdentityError): ManagedIdentityClient({"foo": "bar"}, http_client=requests.Session()) @@ -79,7 +89,13 @@ def assertCacheStatus(self, app): "Should have expected client_id") self.assertEqual("managed_identity", at["realm"], "Should have expected realm") - def _test_happy_path(self, app, mocked_http, expires_in, resource="R"): + def _test_happy_path( + self, app, mocked_http, expires_in, *, resource="R", claims_challenge=None, + ): + """It tests a normal token request that is expected to hit IdP, + a subsequent same token request that is expected to hit cache, + and then a request with claims_challenge that shall hit IdP again. + """ result = app.acquire_token_for_client(resource=resource) mocked_http.assert_called() call_count = mocked_http.call_count @@ -115,7 +131,8 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"): expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on, "Should have a refresh_on time around the middle of the token's life") - result = app.acquire_token_for_client(resource=resource, claims_challenge="foo") + result = app.acquire_token_for_client( + resource=resource, claims_challenge=claims_challenge or "placeholder") self.assertEqual("identity_provider", result["token_source"], "Should miss cache") @@ -132,6 +149,9 @@ def _test_happy_path(self) -> callable: def test_happy_path_of_vm(self): self._test_happy_path().assert_called_with( + # The last call contained claims_challenge + # but since IMDS doesn't support token_sha256_to_refresh, + # the request shall remain the same as before 'http://169.254.169.254/metadata/identity/oauth2/token', params={'api-version': '2018-02-01', 'resource': 'R'}, headers={'Metadata': 'true'}, @@ -244,19 +264,46 @@ def test_machine_learning_error_should_be_normalized(self): "IDENTITY_SERVER_THUMBPRINT": "bar", }) class ServiceFabricTestCase(ClientTestCase): + access_token = "AT" + access_token_sha256 = hashlib.sha256(access_token.encode()).hexdigest() - def _test_happy_path(self, app): + def _test_happy_path(self, app, *, claims_challenge=None) -> callable: expires_in = 1234 with patch.object(app._http_client, "get", return_value=MinimalResponse( status_code=200, - text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % ( - int(time.time()) + expires_in), + text='{"access_token": "%s", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % ( + self.access_token, int(time.time()) + expires_in), )) as mocked_method: super(ServiceFabricTestCase, self)._test_happy_path( - app, mocked_method, expires_in) + app, mocked_method, expires_in, claims_challenge=claims_challenge) + return mocked_method - def test_happy_path(self): - self._test_happy_path(self.app) + def test_happy_path_with_client_capabilities_should_relay_capabilities(self): + self._test_happy_path(self._build_app(client_capabilities=["foo", "bar"])).assert_called_with( + 'http://localhost', + params={ + 'api-version': '2019-07-01-preview', + 'resource': 'R', + 'token_sha256_to_refresh': self.access_token_sha256, + "xms_cc": "foo,bar", + }, + headers={'Secret': 'foo'}, + ) + + def test_happy_path_with_claim_challenge_should_send_sha256_to_provider(self): + self._test_happy_path( + self._build_app(client_capabilities=[]), # Test empty client_capabilities + claims_challenge='{"access_token": {"nbf": {"essential": true, "value": "1563308371"}}}', + ).assert_called_with( + 'http://localhost', + params={ + 'api-version': '2019-07-01-preview', + 'resource': 'R', + 'token_sha256_to_refresh': self.access_token_sha256, + # There is no xms_cc in this case + }, + headers={'Secret': 'foo'}, + ) def test_unified_api_service_should_ignore_unnecessary_client_id(self): self._test_happy_path(ManagedIdentityClient(