Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 49 additions & 9 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# All rights reserved.
#
# This code is licensed under the MIT License.
import hashlib
import json
import logging
import os
Expand All @@ -10,7 +11,7 @@
import time
from urllib.parse import urlparse # Python 3+
from collections import UserDict # Python 3+
from typing import Optional, Union # Needed in Python 3.7 & 3.8
from typing import List, Optional, Union # Needed in Python 3.7 & 3.8
from .token_cache import TokenCache
from .individual_cache import _IndividualCache as IndividualCache
from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser
Expand Down Expand Up @@ -162,6 +163,7 @@ def __init__(
http_client,
token_cache=None,
http_cache=None,
client_capabilities: Optional[List[str]] = None,
):
"""Create a managed identity client.

Expand Down Expand Up @@ -192,6 +194,17 @@ def __init__(
Optional. It has the same characteristics as the
:paramref:`msal.ClientApplication.http_cache`.

:param list[str] client_capabilities: (optional)
Allows configuration of one or more client capabilities, e.g. ["CP1"].

Client capability is meant to inform the Microsoft identity platform
(STS) what this client is capable for,
so STS can decide to turn on certain features.

Implementation details:
Client capability in Managed Identity is relayed as-is
via ``xms_cc`` parameter on the wire.

Recipe 1: Hard code a managed identity for your app::

import msal, requests
Expand Down Expand Up @@ -238,6 +251,7 @@ def __init__(
http_cache=http_cache,
)
self._token_cache = token_cache or TokenCache()
self._client_capabilities = client_capabilities

def _get_instance(self):
if self.__instance is None:
Expand Down Expand Up @@ -266,8 +280,7 @@ def acquire_token_for_client(
and then a *claims challenge* will be returned by the target resource,
as a `claims_challenge` directive in the `www-authenticate` header,
even if the app developer did not opt in for the "CP1" client capability.
Upon receiving a `claims_challenge`, MSAL will skip a token cache read,
and will attempt to acquire a new token.
Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token.

.. note::

Expand All @@ -278,11 +291,13 @@ def acquire_token_for_client(
This is a service-side behavior that cannot be changed by this library.
`Azure VM docs <https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>`_
"""
access_token_to_refresh = None # This could become a public parameter in the future
access_token_from_cache = None
client_id_in_cache = self._managed_identity.get(
ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
now = time.time()
if not claims_challenge: # Then attempt token cache search
if True: # Attempt cache search even if receiving claims_challenge,
# because we want to locate the existing token (if any) and refresh it
matches = self._token_cache.find(
self._token_cache.CredentialType.ACCESS_TOKEN,
target=[resource],
Expand All @@ -297,6 +312,11 @@ def acquire_token_for_client(
expires_in = int(entry["expires_on"]) - now
if expires_in < 5*60: # Then consider it expired
continue # Removal is not necessary, it will be overwritten
if claims_challenge and not access_token_to_refresh:
# Since caller did not pinpoint the token causing claims challenge,
# we have to assume it is the first token we found in cache.
access_token_to_refresh = entry["secret"]
break
logger.debug("Cache hit an AT")
access_token_from_cache = { # Mimic a real response
"access_token": entry["secret"],
Expand All @@ -310,7 +330,13 @@ def acquire_token_for_client(
break # With a fallback in hand, we break here to go refresh
return access_token_from_cache # It is still good as new
try:
result = _obtain_token(self._http_client, self._managed_identity, resource)
result = _obtain_token(
self._http_client, self._managed_identity, resource,
access_token_sha256_to_refresh=hashlib.sha256(
access_token_to_refresh.encode("utf-8")).hexdigest()
if access_token_to_refresh else None,
client_capabilities=self._client_capabilities,
)
if "access_token" in result:
expires_in = result.get("expires_in", 3600)
if "refresh_in" not in result and expires_in >= 7200:
Expand Down Expand Up @@ -385,8 +411,12 @@ def get_managed_identity_source():
return DEFAULT_TO_VM


def _obtain_token(http_client, managed_identity, resource):
# A unified low-level API that talks to different Managed Identity
def _obtain_token(
http_client, managed_identity, resource,
*,
access_token_sha256_to_refresh: Optional[str] = None,
client_capabilities: Optional[List[str]] = None,
):
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
):
Expand All @@ -402,6 +432,8 @@ def _obtain_token(http_client, managed_identity, resource):
os.environ["IDENTITY_HEADER"],
os.environ["IDENTITY_SERVER_THUMBPRINT"],
resource,
access_token_sha256_to_refresh=access_token_sha256_to_refresh,
client_capabilities=client_capabilities,
)
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
return _obtain_token_on_app_service(
Expand Down Expand Up @@ -553,6 +585,9 @@ def _obtain_token_on_machine_learning(

def _obtain_token_on_service_fabric(
http_client, endpoint, identity_header, server_thumbprint, resource,
*,
access_token_sha256_to_refresh: str = None,
client_capabilities: Optional[List[str]] = None,
):
"""Obtains token for
`Service Fabric <https://learn.microsoft.com/en-us/azure/service-fabric/>`_
Expand All @@ -563,7 +598,12 @@ def _obtain_token_on_service_fabric(
logger.debug("Obtaining token via managed identity on Azure Service Fabric")
resp = http_client.get(
endpoint,
params={"api-version": "2019-07-01-preview", "resource": resource},
params={k: v for k, v in {
"api-version": "2019-07-01-preview",
"resource": resource,
"token_sha256_to_refresh": access_token_sha256_to_refresh,
"xms_cc": ",".join(client_capabilities) if client_capabilities else None,
}.items() if v is not None},
headers={"Secret": identity_header},
)
try:
Expand All @@ -584,7 +624,7 @@ def _obtain_token_on_service_fabric(
"ArgumentNullOrEmpty": "invalid_scope",
}
return {
"error": error_mapping.get(payload["error"]["code"], "invalid_request"),
"error": error_mapping.get(error.get("code"), "invalid_request"),
"error_description": resp.text,
}
except json.decoder.JSONDecodeError:
Expand Down
67 changes: 57 additions & 10 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import hashlib
import json
import os
import sys
import time
from typing import List, Optional
import unittest
try:
from unittest.mock import patch, ANY, mock_open, Mock
Expand Down Expand Up @@ -52,15 +54,23 @@ def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_f
class ClientTestCase(unittest.TestCase):
maxDiff = None

def setUp(self):
self.app = ManagedIdentityClient(
def _build_app(
self,
*,
client_capabilities: Optional[List[str]] = None,
):
return ManagedIdentityClient(
{ # Here we test it with the raw dict form, to test that
# the client has no hard dependency on ManagedIdentity object
"ManagedIdentityIdType": "SystemAssigned", "Id": None,
},
http_client=requests.Session(),
client_capabilities=client_capabilities,
)

def setUp(self):
self.app = self._build_app()

def test_error_out_on_invalid_input(self):
with self.assertRaises(ManagedIdentityError):
ManagedIdentityClient({"foo": "bar"}, http_client=requests.Session())
Expand All @@ -79,7 +89,13 @@ def assertCacheStatus(self, app):
"Should have expected client_id")
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")

def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
def _test_happy_path(
self, app, mocked_http, expires_in, *, resource="R", claims_challenge=None,
):
"""It tests a normal token request that is expected to hit IdP,
a subsequent same token request that is expected to hit cache,
and then a request with claims_challenge that shall hit IdP again.
"""
result = app.acquire_token_for_client(resource=resource)
mocked_http.assert_called()
call_count = mocked_http.call_count
Expand Down Expand Up @@ -115,7 +131,8 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on,
"Should have a refresh_on time around the middle of the token's life")

result = app.acquire_token_for_client(resource=resource, claims_challenge="foo")
result = app.acquire_token_for_client(
resource=resource, claims_challenge=claims_challenge or "placeholder")
self.assertEqual("identity_provider", result["token_source"], "Should miss cache")


Expand All @@ -132,6 +149,9 @@ def _test_happy_path(self) -> callable:

def test_happy_path_of_vm(self):
self._test_happy_path().assert_called_with(
# The last call contained claims_challenge
# but since IMDS doesn't support token_sha256_to_refresh,
# the request shall remain the same as before
'http://169.254.169.254/metadata/identity/oauth2/token',
params={'api-version': '2018-02-01', 'resource': 'R'},
headers={'Metadata': 'true'},
Expand Down Expand Up @@ -244,19 +264,46 @@ def test_machine_learning_error_should_be_normalized(self):
"IDENTITY_SERVER_THUMBPRINT": "bar",
})
class ServiceFabricTestCase(ClientTestCase):
access_token = "AT"
access_token_sha256 = hashlib.sha256(access_token.encode()).hexdigest()

def _test_happy_path(self, app):
def _test_happy_path(self, app, *, claims_challenge=None) -> callable:
expires_in = 1234
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
int(time.time()) + expires_in),
text='{"access_token": "%s", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
self.access_token, int(time.time()) + expires_in),
)) as mocked_method:
super(ServiceFabricTestCase, self)._test_happy_path(
app, mocked_method, expires_in)
app, mocked_method, expires_in, claims_challenge=claims_challenge)
return mocked_method

def test_happy_path(self):
self._test_happy_path(self.app)
def test_happy_path_with_client_capabilities_should_relay_capabilities(self):
self._test_happy_path(self._build_app(client_capabilities=["foo", "bar"])).assert_called_with(
'http://localhost',
params={
'api-version': '2019-07-01-preview',
'resource': 'R',
'token_sha256_to_refresh': self.access_token_sha256,
"xms_cc": "foo,bar",
},
headers={'Secret': 'foo'},
)

def test_happy_path_with_claim_challenge_should_send_sha256_to_provider(self):
self._test_happy_path(
self._build_app(client_capabilities=[]), # Test empty client_capabilities
claims_challenge='{"access_token": {"nbf": {"essential": true, "value": "1563308371"}}}',
).assert_called_with(
'http://localhost',
params={
'api-version': '2019-07-01-preview',
'resource': 'R',
'token_sha256_to_refresh': self.access_token_sha256,
# There is no xms_cc in this case
},
headers={'Secret': 'foo'},
)

def test_unified_api_service_should_ignore_unnecessary_client_id(self):
self._test_happy_path(ManagedIdentityClient(
Expand Down