From 65fb051bec4d03cc2ca73c59c66a49dfaa59aede Mon Sep 17 00:00:00 2001 From: jiasli <4003950+jiasli@users.noreply.github.com> Date: Fri, 25 Oct 2024 11:47:31 +0800 Subject: [PATCH] cloud-shell-auth --- src/azure-cli-core/azure/cli/core/_profile.py | 67 ++++---- .../cli/core/auth/adal_authentication.py | 2 +- .../azure/cli/core/auth/constants.py | 6 + .../azure/cli/core/auth/identity.py | 3 +- .../azure/cli/core/auth/msal_credentials.py | 23 +++ .../azure/cli/core/tests/test_profile.py | 154 +++++++++--------- 6 files changed, 144 insertions(+), 111 deletions(-) create mode 100644 src/azure-cli-core/azure/cli/core/auth/constants.py diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index 38829a46af8..a4e71416808 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -284,17 +284,16 @@ def login_with_managed_identity(self, identity_id=None, allow_no_subscriptions=N def login_in_cloud_shell(self): import jwt - from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper + from .auth.msal_credentials import CloudShellCredential - msi_creds = MSIAuthenticationWrapper(resource=self.cli_ctx.cloud.endpoints.active_directory_resource_id) - token_entry = msi_creds.token - token = token_entry['access_token'] - logger.info('MSI: token was retrieved. Now trying to initialize local accounts...') + cred = CloudShellCredential() + token = cred.get_token(*self._arm_scope).token + logger.info('Cloud Shell token was retrieved. Now trying to initialize local accounts...') decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False}) tenant = decode['tid'] subscription_finder = SubscriptionFinder(self.cli_ctx) - subscriptions = subscription_finder.find_using_specific_tenant(tenant, msi_creds) + subscriptions = subscription_finder.find_using_specific_tenant(tenant, cred) if not subscriptions: raise CLIError('No subscriptions were found in the cloud shell') user = decode.get('unique_name', 'N/A') @@ -351,11 +350,19 @@ def get_login_credentials(self, resource=None, client_id=None, subscription_id=N managed_identity_type, managed_identity_id = Profile._try_parse_msi_account_name(account) - # Cloud Shell is just a system assignment managed identity if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): - managed_identity_type = MsiAccountTypes.system_assigned + # Cloud Shell + from .auth.msal_credentials import CloudShellCredential + from azure.cli.core.auth.credential_adaptor import CredentialAdaptor + cs_cred = CloudShellCredential() + # The cloud shell credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs. + cred = CredentialAdaptor(cs_cred, resource=resource) - if managed_identity_type is None: + elif managed_identity_type: + # managed identity + cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource) + + else: # user and service principal external_tenants = [] if aux_tenants: @@ -375,9 +382,7 @@ def get_login_credentials(self, resource=None, client_id=None, subscription_id=N cred = CredentialAdaptor(credential, auxiliary_credentials=external_credentials, resource=resource) - else: - # managed identity - cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource) + return (cred, str(account[_SUBSCRIPTION_ID]), str(account[_TENANT_ID])) @@ -397,27 +402,27 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No account = self.get_subscription(subscription) - identity_type, identity_id = Profile._try_parse_msi_account_name(account) - if identity_type: + managed_identity_type, managed_identity_id = Profile._try_parse_msi_account_name(account) + + if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): + # Cloud Shell + if tenant: + raise CLIError("Tenant shouldn't be specified for Cloud Shell account") + from .auth.msal_credentials import CloudShellCredential + cred = CloudShellCredential() + + elif managed_identity_type: # managed identity if tenant: raise CLIError("Tenant shouldn't be specified for managed identity account") from .auth.util import scopes_to_resource - msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, - scopes_to_resource(scopes)) - sdk_token = msi_creds.get_token(*scopes) - elif in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): - # Cloud Shell, which is just a system-assigned managed identity. - if tenant: - raise CLIError("Tenant shouldn't be specified for Cloud Shell account") - from .auth.util import scopes_to_resource - msi_creds = MsiAccountTypes.msi_auth_factory(MsiAccountTypes.system_assigned, identity_id, - scopes_to_resource(scopes)) - sdk_token = msi_creds.get_token(*scopes) + cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, + scopes_to_resource(scopes)) + else: - credential = self._create_credential(account, tenant) - sdk_token = credential.get_token(*scopes) + cred = self._create_credential(account, tenant) + sdk_token = cred.get_token(*scopes) # Convert epoch int 'expires_on' to datetime string 'expiresOn' for backward compatibility # WARNING: expiresOn is deprecated and will be removed in future release. import datetime @@ -429,11 +434,11 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No 'expiresOn': expiresOn # datetime string, like "2020-11-12 13:50:47.114324" } - # (tokenType, accessToken, tokenEntry) - creds = 'Bearer', sdk_token.token, token_entry + # Build a tuple of (token_type, token, token_entry) + token_tuple = 'Bearer', sdk_token.token, token_entry - # (cred, subscription, tenant) - return (creds, + # Return a tuple of (token_tuple, subscription, tenant) + return (token_tuple, None if tenant else str(account[_SUBSCRIPTION_ID]), str(tenant if tenant else account[_TENANT_ID])) diff --git a/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py b/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py index 154db9e05d3..35174dafb0f 100644 --- a/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py +++ b/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py @@ -24,7 +24,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # Use MSAL to get VM SSH certificate import msal from .util import check_result, build_sdk_access_token - from .identity import AZURE_CLI_CLIENT_ID + from .constants import AZURE_CLI_CLIENT_ID app = msal.PublicClientApplication( AZURE_CLI_CLIENT_ID, # Use a real client_id, so that cache would work # TODO: This PoC does not currently maintain a token cache; diff --git a/src/azure-cli-core/azure/cli/core/auth/constants.py b/src/azure-cli-core/azure/cli/core/auth/constants.py new file mode 100644 index 00000000000..b011d7816ec --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/constants.py @@ -0,0 +1,6 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46' diff --git a/src/azure-cli-core/azure/cli/core/auth/identity.py b/src/azure-cli-core/azure/cli/core/auth/identity.py index 49894903fc0..595102fd915 100644 --- a/src/azure-cli-core/azure/cli/core/auth/identity.py +++ b/src/azure-cli-core/azure/cli/core/auth/identity.py @@ -13,12 +13,11 @@ from knack.util import CLIError from msal import PublicClientApplication, ConfidentialClientApplication +from .constants import AZURE_CLI_CLIENT_ID from .msal_credentials import UserCredential, ServicePrincipalCredential from .persistence import load_persisted_token_cache, file_extensions, load_secret_store from .util import check_result -AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46' - # Service principal entry properties. Names are taken from OAuth 2.0 client credentials flow parameters: # https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow _TENANT = 'tenant' diff --git a/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py index 8c6dfd0daf3..3b58ecdaa48 100644 --- a/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py +++ b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py @@ -20,6 +20,7 @@ from knack.util import CLIError from msal import PublicClientApplication, ConfidentialClientApplication +from .constants import AZURE_CLI_CLIENT_ID from .util import check_result, build_sdk_access_token logger = get_logger(__name__) @@ -108,3 +109,25 @@ def get_token(self, *scopes, **kwargs): result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs) check_result(result) return build_sdk_access_token(result) + + +class CloudShellCredential: # pylint: disable=too-few-public-methods + # Cloud Shell acts as a "broker" to obtain access token for the user account, so even though it uses + # managed identity protocol, it returns a user token. + # That's why MSAL uses acquire_token_interactive to retrieve an access token in Cloud Shell. + # See https://github.com/Azure/azure-cli/pull/29637 + + def __init__(self): + self._msal_app = PublicClientApplication( + AZURE_CLI_CLIENT_ID, # Use a real client_id, so that cache would work + # TODO: We currently don't maintain an MSAL token cache as Cloud Shell already has its own token cache. + # Ideally we should also use an MSAL token cache. + # token_cache=... + ) + + def get_token(self, *scopes, **kwargs): + logger.debug("CloudShellCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs) + # kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL + result = self._msal_app.acquire_token_interactive(list(scopes), prompt="none", **kwargs) + check_result(result, scopes=scopes) + return build_sdk_access_token(result) diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index 2084148df9a..bc5da923953 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------------------------- # pylint: disable=protected-access +import base64 import json import datetime import unittest @@ -29,6 +30,15 @@ MOCK_TENANT_DEFAULT_DOMAIN = 'test.onmicrosoft.com' +def _build_test_jwt(claims): + parts = [ + '{"typ":"JWT","alg":"RS256"}', + json.dumps(claims, separators=(',', ':')), + 'test_sig' + ] + return '.'.join(base64.urlsafe_b64encode(p.encode('utf-8')).decode('utf-8').replace('=', '') for p in parts) + + class CredentialMock: def __init__(self, *args, **kwargs): @@ -81,11 +91,20 @@ def get_token(self, *args, **kwargs): return AccessToken(self.token['access_token'], int(self.token['expires_on'])) +class CloudShellCredentialStub: + def __init__(self): + self.get_token_scopes = None + + def get_token(self, *scopes, **kwargs): + self.get_token_scopes = scopes + return AccessToken(TestProfile.test_cloud_shell_access_token, MOCK_EXPIRES_ON_INT) + + class TestProfile(unittest.TestCase): @classmethod def setUpClass(cls): - cls.tenant_id = 'microsoft.com' + cls.tenant_id = 'test.onmicrosoft.com' cls.tenant_display_name = MOCK_TENANT_DISPLAY_NAME cls.tenant_default_domain = MOCK_TENANT_DEFAULT_DOMAIN @@ -112,14 +131,14 @@ def setUpClass(cls): managed_by_tenants=cls.managed_by_tenants) cls.subscription1_output = [{'environmentName': 'AzureCloud', - 'homeTenantId': 'microsoft.com', + 'homeTenantId': 'test.onmicrosoft.com', 'id': '1', 'isDefault': True, 'managedByTenants': [{'tenantId': '00000003-0000-0000-0000-000000000000'}, {'tenantId': '00000004-0000-0000-0000-000000000000'}], 'name': 'foo account', 'state': 'Enabled', - 'tenantId': 'microsoft.com', + 'tenantId': 'test.onmicrosoft.com', 'user': { 'name': 'foo@foo.com', 'type': 'user' @@ -127,14 +146,14 @@ def setUpClass(cls): cls.subscription1_with_tenant_info_output = [{ 'environmentName': 'AzureCloud', - 'homeTenantId': 'microsoft.com', + 'homeTenantId': 'test.onmicrosoft.com', 'id': '1', 'isDefault': True, 'managedByTenants': [{'tenantId': '00000003-0000-0000-0000-000000000000'}, {'tenantId': '00000004-0000-0000-0000-000000000000'}], 'name': 'foo account', 'state': 'Enabled', - 'tenantId': 'microsoft.com', + 'tenantId': 'test.onmicrosoft.com', 'tenantDisplayName': MOCK_TENANT_DISPLAY_NAME, 'tenantDefaultDomain': MOCK_TENANT_DEFAULT_DOMAIN, 'user': { @@ -225,48 +244,18 @@ def setUpClass(cls): 'homeTenantId': cls.tenant_id, 'managedByTenants': [], } - cls.test_msi_tenant = '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a' - cls.test_msi_access_token = ('eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6IlZXVkljMVdEMVRrc2JiMzAxc2FzTTVrT3E1' - 'USIsImtpZCI6IlZXVkljMVdEMVRrc2JiMzAxc2FzTTVrT3E1USJ9.eyJhdWQiOiJodHRwczovL21hbmF' - 'nZW1lbnQuY29yZS53aW5kb3dzLm5ldC8iLCJpc3MiOiJodHRwczovL3N0cy53aW5kb3dzLm5ldC81NDg' - 'yNmIyMi0zOGQ2LTRmYjItYmFkOS1iN2I5M2EzZTljNWEvIiwiaWF0IjoxNTAzMzU0ODc2LCJuYmYiOjE' - '1MDMzNTQ4NzYsImV4cCI6MTUwMzM1ODc3NiwiYWNyIjoiMSIsImFpbyI6IkFTUUEyLzhFQUFBQTFGL1k' - '0VVR3bFI1Y091QXJxc1J0OU5UVVc2MGlsUHZna0daUC8xczVtdzg9IiwiYW1yIjpbInB3ZCJdLCJhcHB' - 'pZCI6IjA0YjA3Nzk1LThkZGItNDYxYS1iYmVlLTAyZjllMWJmN2I0NiIsImFwcGlkYWNyIjoiMCIsImV' - 'fZXhwIjoyNjI4MDAsImZhbWlseV9uYW1lIjoic2RrIiwiZ2l2ZW5fbmFtZSI6ImFkbWluMyIsImdyb3V' - 'wcyI6WyJlNGJiMGI1Ni0xMDE0LTQwZjgtODhhYi0zZDhhOGNiMGUwODYiLCI4YTliMTYxNy1mYzhkLTR' - 'hYTktYTQyZi05OTg2OGQzMTQ2OTkiLCI1NDgwMzkxNy00YzcxLTRkNmMtOGJkZi1iYmQ5MzEwMTBmOGM' - 'iXSwiaXBhZGRyIjoiMTY3LjIyMC4xLjIzNCIsIm5hbWUiOiJhZG1pbjMiLCJvaWQiOiJlN2UxNThkMy0' - '3Y2RjLTQ3Y2QtODgyNS01ODU5ZDdhYjJiNTUiLCJwdWlkIjoiMTAwMzNGRkY5NUQ0NEU4NCIsInNjcCI' - '6InVzZXJfaW1wZXJzb25hdGlvbiIsInN1YiI6ImhRenl3b3FTLUEtRzAySTl6ZE5TRmtGd3R2MGVwZ2l' - 'WY1Vsdm1PZEZHaFEiLCJ0aWQiOiI1NDgyNmIyMi0zOGQ2LTRmYjItYmFkOS1iN2I5M2EzZTljNWEiLCJ' - '1bmlxdWVfbmFtZSI6ImFkbWluM0BBenVyZVNES1RlYW0ub25taWNyb3NvZnQuY29tIiwidXBuIjoiYWR' - 'taW4zQEF6dXJlU0RLVGVhbS5vbm1pY3Jvc29mdC5jb20iLCJ1dGkiOiJuUEROYm04UFkwYUdELWhNeWx' - 'rVEFBIiwidmVyIjoiMS4wIiwid2lkcyI6WyI2MmU5MDM5NC02OWY1LTQyMzctOTE5MC0wMTIxNzcxNDV' - 'lMTAiXX0.Pg4cq0MuP1uGhY_h51ZZdyUYjGDUFgTW2EfIV4DaWT9RU7GIK_Fq9VGBTTbFZA0pZrrmP-z' - '7DlN9-U0A0nEYDoXzXvo-ACTkm9_TakfADd36YlYB5aLna-yO0B7rk5W9ANelkzUQgRfidSHtCmV6i4V' - 'e-lOym1sH5iOcxfIjXF0Tp2y0f3zM7qCq8Cp1ZxEwz6xYIgByoxjErNXrOME5Ld1WizcsaWxTXpwxJn_' - 'Q8U2g9kXHrbYFeY2gJxF_hnfLvNKxUKUBnftmyYxZwKi0GDS0BvdJnJnsqSRSpxUx__Ra9QJkG1IaDzj' - 'ZcSZPHK45T6ohK9Hk9ktZo0crVl7Tmw') - cls.test_user_msi_access_token = ('eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6IlNzWnNCTmhaY0YzUTlTNHRycFFCVE' - 'J5TlJSSSIsImtpZCI6IlNzWnNCTmhaY0YzUTlTNHRycFFCVEJ5TlJSSSJ9.eyJhdWQiOiJodHR' - 'wczovL21hbmFnZW1lbnQuY29yZS53aW5kb3dzLm5ldCIsImlzcyI6Imh0dHBzOi8vc3RzLndpbm' - 'Rvd3MubmV0LzU0ODI2YjIyLTM4ZDYtNGZiMi1iYWQ5LWI3YjkzYTNlOWM1YS8iLCJpYXQiOjE1O' - 'TE3ODM5MDQsIm5iZiI6MTU5MTc4MzkwNCwiZXhwIjoxNTkxODcwNjA0LCJhaW8iOiI0MmRnWUZE' - 'd2JsZmR0WmYxck8zeGlMcVdtOU5MQVE9PSIsImFwcGlkIjoiNjJhYzQ5ZTYtMDQzOC00MTJjLWJ' - 'kZjUtNDg0ZTdkNDUyOTM2IiwiYXBwaWRhY3IiOiIyIiwiaWRwIjoiaHR0cHM6Ly9zdHMud2luZG' - '93cy5uZXQvNTQ4MjZiMjItMzhkNi00ZmIyLWJhZDktYjdiOTNhM2U5YzVhLyIsIm9pZCI6ImQ4M' - 'zRjNjZmLTNhZjgtNDBiNy1iNDYzLWViZGNlN2YzYTgyNyIsInN1YiI6ImQ4MzRjNjZmLTNhZjgt' - 'NDBiNy1iNDYzLWViZGNlN2YzYTgyNyIsInRpZCI6IjU0ODI2YjIyLTM4ZDYtNGZiMi1iYWQ5LWI' - '3YjkzYTNlOWM1YSIsInV0aSI6Ild2YjFyVlBQT1V5VjJDYmNyeHpBQUEiLCJ2ZXIiOiIxLjAiLC' - 'J4bXNfbWlyaWQiOiIvc3Vic2NyaXB0aW9ucy8wYjFmNjQ3MS0xYmYwLTRkZGEtYWVjMy1jYjkyNz' - 'JmMDk1OTAvcmVzb3VyY2Vncm91cHMvcWlhbndlbnMvcHJvdmlkZXJzL01pY3Jvc29mdC5NYW5hZ2' - 'VkSWRlbnRpdHkvdXNlckFzc2lnbmVkSWRlbnRpdGllcy9xaWFud2VuaWRlbnRpdHkifQ.nAxWA5_' - 'qTs_uwGoziKtDFAqxlmYSlyPGqAKZ8YFqFfm68r5Ouo2x2PztAv2D71L-j8B3GykNgW-2yhbB-z2' - 'h53dgjG2TVoeZjhV9DOpSJ06kLAeH-nskGxpBFf7se1qohlU7uyctsUMQWjXVUQbTEanJzj_IH-Y' - '47O3lvM4Yrliz5QUApm63VF4EhqNpNvb5w0HkuB72SJ0MKJt5VdQqNcG077NQNoiTJ34XVXkyNDp' - 'I15y0Cj504P_xw-Dpvg-hmEbykjFMIaB8RoSrp3BzYjNtJh2CHIuWhXF0ngza2SwN2CXK0Vpn5Za' - 'EvZdD57j3h8iGE0Tw5IzG86uNS2AQ0A') + + # A random GUID generated by uuid.uuid4() + cls.test_cloud_shell_tenant = 'ee59da2c-4d2c-4cfb-8753-ff9df4f31556' + # Cloud Shell returns a user token which contains the unique_name claim + cls.test_cloud_shell_access_token = _build_test_jwt({ + 'tid': cls.test_cloud_shell_tenant, + 'unique_name': 'foo@foo.com' + }) + + # A random GUID generated by uuid.uuid4() + cls.test_msi_tenant = 'b6f04d88-9bff-45da-a9b4-a0b6d3cb1b2a' + cls.test_msi_access_token = _build_test_jwt({'tid': cls.test_msi_tenant}) cls.msal_accounts = [ { @@ -456,14 +445,14 @@ def test_login_with_service_principal(self, login_with_service_principal_mock, subs = profile.login(False, 'my app', {'secret': 'very_secret'}, True, self.tenant_id, use_device_code=True, allow_no_subscriptions=False) output = [{'environmentName': 'AzureCloud', - 'homeTenantId': 'microsoft.com', + 'homeTenantId': 'test.onmicrosoft.com', 'id': '1', 'isDefault': True, 'managedByTenants': [{'tenantId': '00000003-0000-0000-0000-000000000000'}, {'tenantId': '00000004-0000-0000-0000-000000000000'}], 'name': 'foo account', 'state': 'Enabled', - 'tenantId': 'microsoft.com', + 'tenantId': 'test.onmicrosoft.com', 'user': { 'name': 'my app', 'type': 'servicePrincipal'}}] @@ -473,9 +462,9 @@ def test_login_with_service_principal(self, login_with_service_principal_mock, self.assertEqual(output, subs) @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) - @mock.patch('azure.cli.core.auth.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_login_in_cloud_shell(self, msi_auth_mock, create_subscription_client_mock): - msi_auth_mock.return_value = MSRestAzureAuthStub() + @mock.patch('azure.cli.core.auth.msal_credentials.CloudShellCredential', autospec=True) + def test_login_in_cloud_shell(self, cloud_shell_credential_mock, create_subscription_client_mock): + cloud_shell_credential_mock.return_value = CloudShellCredentialStub() cli = DummyCli() mock_subscription_client = mock.MagicMock() @@ -487,13 +476,14 @@ def test_login_in_cloud_shell(self, msi_auth_mock, create_subscription_client_mo subscriptions = profile.login_in_cloud_shell() - # Check correct token is used - assert create_subscription_client_mock.call_args[0][1].token['access_token'] == TestProfile.test_msi_access_token + # Verify correct scopes are passed to get_token + credential_instance = create_subscription_client_mock.call_args.args[1] + assert credential_instance.get_token_scopes == ('https://management.core.windows.net//.default',) self.assertEqual(len(subscriptions), 1) s = subscriptions[0] - self.assertEqual(s['user']['name'], 'admin3@AzureSDKTeam.onmicrosoft.com') - self.assertEqual(s['tenantId'], '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a') + self.assertEqual(s['user']['name'], 'foo@foo.com') + self.assertEqual(s['tenantId'], self.test_cloud_shell_tenant) self.assertEqual(s['user']['cloudShellID'], True) self.assertEqual(s['user']['type'], 'user') self.assertEqual(s['name'], self.display_name1) @@ -530,7 +520,7 @@ def test_find_subscriptions_in_vm_with_msi_system_assigned(self, create_subscrip self.assertEqual(s['user']['assignedIdentityInfo'], 'MSI') self.assertEqual(s['name'], self.display_name1) self.assertEqual(s['id'], self.id1.split('/')[-1]) - self.assertEqual(s['tenantId'], '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a') + self.assertEqual(s['tenantId'], self.test_msi_tenant) @mock.patch('requests.get', autospec=True) @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) @@ -595,7 +585,7 @@ def test_find_subscriptions_in_vm_with_msi_user_assigned_with_client_id(self, cr s = subscriptions[0] self.assertEqual(s['name'], self.display_name1) self.assertEqual(s['id'], self.id1.split('/')[-1]) - self.assertEqual(s['tenantId'], '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a') + self.assertEqual(s['tenantId'], self.test_msi_tenant) self.assertEqual(s['user']['name'], 'userAssignedIdentity') self.assertEqual(s['user']['type'], 'servicePrincipal') @@ -1245,11 +1235,11 @@ def mi_auth_factory(*args, **kwargs): cred, subscription_id, _ = profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) @mock.patch('azure.cli.core._profile.in_cloud_console', autospec=True) - @mock.patch('azure.cli.core.auth.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_raw_token_in_cloud_console(self, mock_msi_auth, mock_in_cloud_console): + @mock.patch('azure.cli.core.auth.msal_credentials.CloudShellCredential', autospec=True) + def test_get_raw_token_in_cloud_shell(self, cloud_shell_credential_mock, mock_in_cloud_console): mock_in_cloud_console.return_value = True - # setup an existing msi subscription + # Set up an existing Cloud Shell account profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}) test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' @@ -1261,35 +1251,45 @@ def test_get_raw_token_in_cloud_console(self, mock_msi_auth, mock_in_cloud_conso consolidated[0]['user']['cloudShellID'] = True profile._set_subscriptions(consolidated) - mi_auth_instance = None + # The below code creates a credential instance and checks it. + # + # We can define a normal variable `credential_instance` here and use `nonlocal` to assign the credential + # instance to it, but using a mutable list also allows us to check how many instances are created. + # See https://stackoverflow.com/a/8448011/2199657 + # + # test_login_in_cloud_shell retrieves the credential instance from + # create_subscription_client_mock.call_args.args[1], so another possible way to retrieve the credential + # instance is to create a hook in get_raw_token and patch that hook during tests. + credential_instances = [] - def mi_auth_factory(*args, **kwargs): - nonlocal mi_auth_instance - mi_auth_instance = MSRestAzureAuthStub(*args, **kwargs) - return mi_auth_instance + def cloud_shell_credential_factory(): + credential = CloudShellCredentialStub() + credential_instances.append(credential) + return credential - mock_msi_auth.side_effect = mi_auth_factory + cloud_shell_credential_mock.side_effect = cloud_shell_credential_factory # action - cred, subscription_id, tenant_id = profile.get_raw_token(resource=self.adal_resource) + token_tuple, subscription_id, tenant_id = profile.get_raw_token(scopes=self.msal_scopes) - # Make sure resource/scopes are passed to MSIAuthenticationWrapper - assert mi_auth_instance.resource == self.adal_resource - assert list(mi_auth_instance.get_token_scopes) == self.msal_scopes + # Verify only one credential is created + assert len(credential_instances) == 1 + # Verify correct scopes are passed to get_token + assert list(credential_instances[0].get_token_scopes) == self.msal_scopes self.assertEqual(subscription_id, test_subscription_id) - self.assertEqual(cred[0], 'Bearer') - self.assertEqual(cred[1], TestProfile.test_msi_access_token) + self.assertEqual(token_tuple[0], 'Bearer') + self.assertEqual(token_tuple[1], TestProfile.test_cloud_shell_access_token) # Make sure expires_on and expiresOn are set - self.assertEqual(cred[2]['expires_on'], MOCK_EXPIRES_ON_INT) - self.assertEqual(cred[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) + self.assertEqual(token_tuple[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(token_tuple[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(tenant_id, test_tenant_id) - # verify tenant shouldn't be specified for Cloud Shell account + # Verify tenant shouldn't be specified for Cloud Shell account with self.assertRaisesRegex(CLIError, 'Cloud Shell'): - cred, subscription_id, _ = profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) + profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) @mock.patch('azure.cli.core.auth.identity.Identity.logout_service_principal') @mock.patch('azure.cli.core.auth.identity.Identity.logout_user')