diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 230ff2f9..866d379a 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -298,7 +298,7 @@ def acquire_token_for_client( now = time.time() if True: # Attempt cache search even if receiving claims_challenge, # because we want to locate the existing token (if any) and refresh it - matches = self._token_cache.find( + matches = self._token_cache.search( self._token_cache.CredentialType.ACCESS_TOKEN, target=[resource], query=dict( diff --git a/tests/test_application.py b/tests/test_application.py index 0c7f2d29..16e512c4 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -153,8 +153,8 @@ def tester(url, data=None, **kwargs): return MinimalResponse(status_code=400, text=error_response) app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( app.authority, self.scopes, self.account, post=tester) - self.assertNotEqual([], app.token_cache.find( - msal.TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": self.frt}), + self.assertIsNotNone(next(app.token_cache.search( + msal.TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": self.frt}), None), "The FRT should not be removed from the cache") def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): @@ -187,11 +187,11 @@ def tester(url, data=None, **kwargs): app.authority, self.scopes, self.account, post=tester) logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) self.assertEqual("at", at.get("access_token"), "New app should get a new AT") - app_metadata = app.token_cache.find( + app_metadata = next(app.token_cache.search( msal.TokenCache.CredentialType.APP_METADATA, - query={"client_id": app.client_id}) - self.assertNotEqual([], app_metadata, "Should record new app's metadata") - self.assertEqual("1", app_metadata[0].get("family_id"), + query={"client_id": app.client_id}), None) + self.assertIsNotNone(app_metadata, "Should record new app's metadata") + self.assertEqual("1", app_metadata.get("family_id"), "The new family app should be recorded as in the same family") # Known family app will simply use FRT, which is largely the same as this one @@ -218,25 +218,25 @@ def test_family_app_remove_account(self): account = app.get_accounts()[0] mine = {"home_account_id": account["home_account_id"]} - self.assertNotEqual([], self.cache.find( - self.cache.CredentialType.ACCESS_TOKEN, query=mine)) - self.assertNotEqual([], self.cache.find( - self.cache.CredentialType.REFRESH_TOKEN, query=mine)) - self.assertNotEqual([], self.cache.find( - self.cache.CredentialType.ID_TOKEN, query=mine)) - self.assertNotEqual([], self.cache.find( - self.cache.CredentialType.ACCOUNT, query=mine)) + self.assertIsNotNone(next(self.cache.search( + self.cache.CredentialType.ACCESS_TOKEN, query=mine), None)) + self.assertIsNotNone(next(self.cache.search( + self.cache.CredentialType.REFRESH_TOKEN, query=mine), None)) + self.assertIsNotNone(next(self.cache.search( + self.cache.CredentialType.ID_TOKEN, query=mine), None)) + self.assertIsNotNone(next(self.cache.search( + self.cache.CredentialType.ACCOUNT, query=mine), None)) app.remove_account(account) - self.assertEqual([], self.cache.find( - self.cache.CredentialType.ACCESS_TOKEN, query=mine)) - self.assertEqual([], self.cache.find( - self.cache.CredentialType.REFRESH_TOKEN, query=mine)) - self.assertEqual([], self.cache.find( - self.cache.CredentialType.ID_TOKEN, query=mine)) - self.assertEqual([], self.cache.find( - self.cache.CredentialType.ACCOUNT, query=mine)) + self.assertIsNone(next(self.cache.search( + self.cache.CredentialType.ACCESS_TOKEN, query=mine), None)) + self.assertIsNone(next(self.cache.search( + self.cache.CredentialType.REFRESH_TOKEN, query=mine), None)) + self.assertIsNone(next(self.cache.search( + self.cache.CredentialType.ID_TOKEN, query=mine), None)) + self.assertIsNone(next(self.cache.search( + self.cache.CredentialType.ACCOUNT, query=mine), None)) class TestClientApplicationForAuthorityMigration(unittest.TestCase): @@ -711,14 +711,14 @@ def test_remove_tokens_for_client_should_remove_client_tokens_only(self): cca = msal.ConfidentialClientApplication( "client_id", client_credential="secret", authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com") - self.assertEqual( - 0, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN))) + self.assertIsNone(next(cca.token_cache.search( + msal.TokenCache.CredentialType.ACCESS_TOKEN), None)) cca.acquire_token_for_client( ["scope"], post=lambda url, **kwargs: MinimalResponse( status_code=200, text=json.dumps({"access_token": "AT for client"}))) - self.assertEqual( - 1, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN))) + self.assertEqual(1, len(list(cca.token_cache.search( + msal.TokenCache.CredentialType.ACCESS_TOKEN)))) cca.acquire_token_by_username_password( "johndoe", "password", ["scope"], post=lambda url, **kwargs: MinimalResponse( @@ -726,10 +726,11 @@ def test_remove_tokens_for_client_should_remove_client_tokens_only(self): access_token=at_for_user, expires_in=3600, uid="uid", utid="utid", # This populates home_account_id )))) - self.assertEqual( - 2, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN))) + self.assertEqual(2, len(list(cca.token_cache.search( + msal.TokenCache.CredentialType.ACCESS_TOKEN)))) cca.remove_tokens_for_client() - remaining_tokens = cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN) + remaining_tokens = list(cca.token_cache.search( + msal.TokenCache.CredentialType.ACCESS_TOKEN)) self.assertEqual(1, len(remaining_tokens)) self.assertEqual(at_for_user, remaining_tokens[0].get("secret")) diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 494d6daf..5310b789 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -2,6 +2,7 @@ import base64 import json import time +import warnings from msal.token_cache import TokenCache, SerializableTokenCache from tests import unittest @@ -83,10 +84,11 @@ def testAddByAad(self): } self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get( self.at_key_maker(**access_token_entry))) - self.assertIn( - access_token_entry, - self.cache.find(self.cache.CredentialType.ACCESS_TOKEN, now=now), - "find(..., query=None) should not crash, even though MSAL does not use it") + with warnings.catch_warnings(): + self.assertIn( + access_token_entry, + self.cache.find(self.cache.CredentialType.ACCESS_TOKEN, now=now), + "find(..., query=None) should not crash, even though MSAL does not use it") self.assertEqual( { 'client_id': 'my_client_id',