Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,26 +299,24 @@ def logout_all(self):
identity.logout_all_users()
identity.logout_all_service_principal()

def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, aux_tenants=None):
def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, aux_tenants=None,
sdk_credential=True):
"""Get a credential compatible with Track 2 SDK."""
if aux_tenants and aux_subscriptions:
raise CLIError("Please specify only one of aux_subscriptions and aux_tenants, not both")

account = self.get_subscription(subscription_id)

managed_identity_type, managed_identity_id = Profile._parse_managed_identity_account(account)

external_credentials = None
if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
# Cloud Shell
from .auth.msal_credentials import CloudShellCredential
# The credential must be wrapped by CredentialAdaptor so that it can work with SDK.
sdk_cred = CredentialAdaptor(CloudShellCredential())
cred = CloudShellCredential()

elif managed_identity_type:
# managed identity
# The credential must be wrapped by CredentialAdaptor so that it can work with SDK.
cred = ManagedIdentityAuth.credential_factory(managed_identity_type, managed_identity_id)
sdk_cred = CredentialAdaptor(cred)

else:
# user and service principal
Expand All @@ -332,13 +330,15 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
if sub[_TENANT_ID] != account[_TENANT_ID]:
external_tenants.append(sub[_TENANT_ID])

credential = self._create_credential(account)
cred = self._create_credential(account)
external_credentials = []
for external_tenant in external_tenants:
external_credentials.append(self._create_credential(account, tenant_id=external_tenant))
sdk_cred = CredentialAdaptor(credential, auxiliary_credentials=external_credentials)

return (sdk_cred,
# Wrapping the credential with CredentialAdaptor makes it compatible with SDK.
cred_result = CredentialAdaptor(cred, auxiliary_credentials=external_credentials) if sdk_credential else cred

return (cred_result,
str(account[_SUBSCRIPTION_ID]),
str(account[_TENANT_ID]))

Expand Down Expand Up @@ -401,6 +401,15 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
None if tenant else str(account[_SUBSCRIPTION_ID]),
str(tenant if tenant else account[_TENANT_ID]))

def get_msal_token(self, scopes, data):
"""Get VM SSH certificate. DO NOT use it for other purposes. To get an access token, use get_raw_token instead.
"""
credential, _, _ = self.get_login_credentials(sdk_credential=False)
from .auth.constants import ACCESS_TOKEN
certificate_string = credential.acquire_token(scopes, data=data)[ACCESS_TOKEN]
# The first value used to be username, but it is no longer used.
return None, certificate_string

def _normalize_properties(self, user, subscriptions, is_service_principal, cert_sn_issuer_auth=None,
assigned_identity_info=None):
consolidated = []
Expand Down
27 changes: 27 additions & 0 deletions src/azure-cli-core/azure/cli/core/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@ def __init__(self, *args, **kwargs):
# If acquire_token_scopes is checked, make sure to create a new instance of MsalCredentialStub
# to avoid interference from other tests.
self.acquire_token_scopes = None
self.acquire_token_data=None
super().__init__()

def acquire_token(self, scopes, **kwargs):
self.acquire_token_scopes = scopes
self.acquire_token_data = kwargs.get('data')
return {
'access_token': MOCK_ACCESS_TOKEN,
'token_type': 'Bearer',
Expand Down Expand Up @@ -1287,6 +1289,31 @@ def cloud_shell_credential_factory():
with self.assertRaisesRegex(CLIError, 'Cloud Shell'):
profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id)

@mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential')
def test_get_msal_token(self, get_user_credential_mock):
credential_mock_temp = MsalCredentialStub()
get_user_credential_mock.return_value = credential_mock_temp
cli = DummyCli()

storage_mock = {'subscriptions': None}
profile = Profile(cli_ctx=cli, storage=storage_mock)
consolidated = profile._normalize_properties(self.user1,
[self.subscription1],
False, None, None)
profile._set_subscriptions(consolidated)

MOCK_DATA = {
'key_id': 'test',
'req_cnf': 'test',
'token_type': 'ssh-cert'
}
result = profile.get_msal_token(['https://pas.windows.net/CheckMyAccess/Linux/.default'],
MOCK_DATA)

assert result == (None, MOCK_ACCESS_TOKEN)
assert credential_mock_temp.acquire_token_scopes == ['https://pas.windows.net/CheckMyAccess/Linux/.default']
assert credential_mock_temp.acquire_token_data == MOCK_DATA

@mock.patch('azure.cli.core.auth.identity.Identity.logout_service_principal')
@mock.patch('azure.cli.core.auth.identity.Identity.logout_user')
def test_logout(self, logout_user_mock, logout_service_principal_mock):
Expand Down
Loading