Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def acquire_token_for_client(
now = time.time()
if True: # Attempt cache search even if receiving claims_challenge,
# because we want to locate the existing token (if any) and refresh it
matches = self._token_cache.find(
matches = self._token_cache.search(
self._token_cache.CredentialType.ACCESS_TOKEN,
target=[resource],
query=dict(
Expand Down
59 changes: 30 additions & 29 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def tester(url, data=None, **kwargs):
return MinimalResponse(status_code=400, text=error_response)
app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
app.authority, self.scopes, self.account, post=tester)
self.assertNotEqual([], app.token_cache.find(
msal.TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": self.frt}),
self.assertIsNotNone(next(app.token_cache.search(
msal.TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": self.frt}), None),
"The FRT should not be removed from the cache")

def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self):
Expand Down Expand Up @@ -187,11 +187,11 @@ def tester(url, data=None, **kwargs):
app.authority, self.scopes, self.account, post=tester)
logger.debug("%s.cache = %s", self.id(), self.cache.serialize())
self.assertEqual("at", at.get("access_token"), "New app should get a new AT")
app_metadata = app.token_cache.find(
app_metadata = next(app.token_cache.search(
msal.TokenCache.CredentialType.APP_METADATA,
query={"client_id": app.client_id})
self.assertNotEqual([], app_metadata, "Should record new app's metadata")
self.assertEqual("1", app_metadata[0].get("family_id"),
query={"client_id": app.client_id}), None)
self.assertIsNotNone(app_metadata, "Should record new app's metadata")
self.assertEqual("1", app_metadata.get("family_id"),
"The new family app should be recorded as in the same family")
# Known family app will simply use FRT, which is largely the same as this one

Expand All @@ -218,25 +218,25 @@ def test_family_app_remove_account(self):
account = app.get_accounts()[0]
mine = {"home_account_id": account["home_account_id"]}

self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.ACCESS_TOKEN, query=mine))
self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.REFRESH_TOKEN, query=mine))
self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.ID_TOKEN, query=mine))
self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.ACCOUNT, query=mine))
self.assertIsNotNone(next(self.cache.search(
self.cache.CredentialType.ACCESS_TOKEN, query=mine), None))
self.assertIsNotNone(next(self.cache.search(
self.cache.CredentialType.REFRESH_TOKEN, query=mine), None))
self.assertIsNotNone(next(self.cache.search(
self.cache.CredentialType.ID_TOKEN, query=mine), None))
self.assertIsNotNone(next(self.cache.search(
self.cache.CredentialType.ACCOUNT, query=mine), None))

app.remove_account(account)

self.assertEqual([], self.cache.find(
self.cache.CredentialType.ACCESS_TOKEN, query=mine))
self.assertEqual([], self.cache.find(
self.cache.CredentialType.REFRESH_TOKEN, query=mine))
self.assertEqual([], self.cache.find(
self.cache.CredentialType.ID_TOKEN, query=mine))
self.assertEqual([], self.cache.find(
self.cache.CredentialType.ACCOUNT, query=mine))
self.assertIsNone(next(self.cache.search(
self.cache.CredentialType.ACCESS_TOKEN, query=mine), None))
self.assertIsNone(next(self.cache.search(
self.cache.CredentialType.REFRESH_TOKEN, query=mine), None))
self.assertIsNone(next(self.cache.search(
self.cache.CredentialType.ID_TOKEN, query=mine), None))
self.assertIsNone(next(self.cache.search(
self.cache.CredentialType.ACCOUNT, query=mine), None))


class TestClientApplicationForAuthorityMigration(unittest.TestCase):
Expand Down Expand Up @@ -711,25 +711,26 @@ def test_remove_tokens_for_client_should_remove_client_tokens_only(self):
cca = msal.ConfidentialClientApplication(
"client_id", client_credential="secret",
authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com")
self.assertEqual(
0, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN)))
self.assertIsNone(next(cca.token_cache.search(
msal.TokenCache.CredentialType.ACCESS_TOKEN), None))
cca.acquire_token_for_client(
["scope"],
post=lambda url, **kwargs: MinimalResponse(
status_code=200, text=json.dumps({"access_token": "AT for client"})))
self.assertEqual(
1, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN)))
self.assertEqual(1, len(list(cca.token_cache.search(
msal.TokenCache.CredentialType.ACCESS_TOKEN))))
cca.acquire_token_by_username_password(
"johndoe", "password", ["scope"],
post=lambda url, **kwargs: MinimalResponse(
status_code=200, text=json.dumps(build_response(
access_token=at_for_user, expires_in=3600,
uid="uid", utid="utid", # This populates home_account_id
))))
self.assertEqual(
2, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN)))
self.assertEqual(2, len(list(cca.token_cache.search(
msal.TokenCache.CredentialType.ACCESS_TOKEN))))
cca.remove_tokens_for_client()
remaining_tokens = cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN)
remaining_tokens = list(cca.token_cache.search(
msal.TokenCache.CredentialType.ACCESS_TOKEN))
self.assertEqual(1, len(remaining_tokens))
self.assertEqual(at_for_user, remaining_tokens[0].get("secret"))

Expand Down
10 changes: 6 additions & 4 deletions tests/test_token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import base64
import json
import time
import warnings

from msal.token_cache import TokenCache, SerializableTokenCache
from tests import unittest
Expand Down Expand Up @@ -83,10 +84,11 @@ def testAddByAad(self):
}
self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get(
self.at_key_maker(**access_token_entry)))
self.assertIn(
access_token_entry,
self.cache.find(self.cache.CredentialType.ACCESS_TOKEN, now=now),
"find(..., query=None) should not crash, even though MSAL does not use it")
with warnings.catch_warnings():
self.assertIn(
access_token_entry,
self.cache.find(self.cache.CredentialType.ACCESS_TOKEN, now=now),
"find(..., query=None) should not crash, even though MSAL does not use it")
self.assertEqual(
{
'client_id': 'my_client_id',
Expand Down