From 8b283af9f0b2ec076ba469431847e21d1db9048a Mon Sep 17 00:00:00 2001 From: jiasli <4003950+jiasli@users.noreply.github.com> Date: Thu, 26 Jun 2025 19:17:07 +0800 Subject: [PATCH] get_msal_token --- src/azure-cli-core/azure/cli/core/_profile.py | 27 ++++++++++++------- .../azure/cli/core/tests/test_profile.py | 27 +++++++++++++++++++ 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index 0c4583c078b..e2c4e4e905f 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -299,7 +299,8 @@ def logout_all(self): identity.logout_all_users() identity.logout_all_service_principal() - def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, aux_tenants=None): + def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, aux_tenants=None, + sdk_credential=True): """Get a credential compatible with Track 2 SDK.""" if aux_tenants and aux_subscriptions: raise CLIError("Please specify only one of aux_subscriptions and aux_tenants, not both") @@ -307,18 +308,15 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au account = self.get_subscription(subscription_id) managed_identity_type, managed_identity_id = Profile._parse_managed_identity_account(account) - + external_credentials = None if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): # Cloud Shell from .auth.msal_credentials import CloudShellCredential - # The credential must be wrapped by CredentialAdaptor so that it can work with SDK. - sdk_cred = CredentialAdaptor(CloudShellCredential()) + cred = CloudShellCredential() elif managed_identity_type: # managed identity - # The credential must be wrapped by CredentialAdaptor so that it can work with SDK. cred = ManagedIdentityAuth.credential_factory(managed_identity_type, managed_identity_id) - sdk_cred = CredentialAdaptor(cred) else: # user and service principal @@ -332,13 +330,15 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au if sub[_TENANT_ID] != account[_TENANT_ID]: external_tenants.append(sub[_TENANT_ID]) - credential = self._create_credential(account) + cred = self._create_credential(account) external_credentials = [] for external_tenant in external_tenants: external_credentials.append(self._create_credential(account, tenant_id=external_tenant)) - sdk_cred = CredentialAdaptor(credential, auxiliary_credentials=external_credentials) - return (sdk_cred, + # Wrapping the credential with CredentialAdaptor makes it compatible with SDK. + cred_result = CredentialAdaptor(cred, auxiliary_credentials=external_credentials) if sdk_credential else cred + + return (cred_result, str(account[_SUBSCRIPTION_ID]), str(account[_TENANT_ID])) @@ -401,6 +401,15 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No None if tenant else str(account[_SUBSCRIPTION_ID]), str(tenant if tenant else account[_TENANT_ID])) + def get_msal_token(self, scopes, data): + """Get VM SSH certificate. DO NOT use it for other purposes. To get an access token, use get_raw_token instead. + """ + credential, _, _ = self.get_login_credentials(sdk_credential=False) + from .auth.constants import ACCESS_TOKEN + certificate_string = credential.acquire_token(scopes, data=data)[ACCESS_TOKEN] + # The first value used to be username, but it is no longer used. + return None, certificate_string + def _normalize_properties(self, user, subscriptions, is_service_principal, cert_sn_issuer_auth=None, assigned_identity_info=None): consolidated = [] diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index c5e7d70a815..061954f1766 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -50,10 +50,12 @@ def __init__(self, *args, **kwargs): # If acquire_token_scopes is checked, make sure to create a new instance of MsalCredentialStub # to avoid interference from other tests. self.acquire_token_scopes = None + self.acquire_token_data=None super().__init__() def acquire_token(self, scopes, **kwargs): self.acquire_token_scopes = scopes + self.acquire_token_data = kwargs.get('data') return { 'access_token': MOCK_ACCESS_TOKEN, 'token_type': 'Bearer', @@ -1287,6 +1289,31 @@ def cloud_shell_credential_factory(): with self.assertRaisesRegex(CLIError, 'Cloud Shell'): profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential') + def test_get_msal_token(self, get_user_credential_mock): + credential_mock_temp = MsalCredentialStub() + get_user_credential_mock.return_value = credential_mock_temp + cli = DummyCli() + + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) + consolidated = profile._normalize_properties(self.user1, + [self.subscription1], + False, None, None) + profile._set_subscriptions(consolidated) + + MOCK_DATA = { + 'key_id': 'test', + 'req_cnf': 'test', + 'token_type': 'ssh-cert' + } + result = profile.get_msal_token(['https://pas.windows.net/CheckMyAccess/Linux/.default'], + MOCK_DATA) + + assert result == (None, MOCK_ACCESS_TOKEN) + assert credential_mock_temp.acquire_token_scopes == ['https://pas.windows.net/CheckMyAccess/Linux/.default'] + assert credential_mock_temp.acquire_token_data == MOCK_DATA + @mock.patch('azure.cli.core.auth.identity.Identity.logout_service_principal') @mock.patch('azure.cli.core.auth.identity.Identity.logout_user') def test_logout(self, logout_user_mock, logout_service_principal_mock):