1515#include < chrono>
1616#include < cstring>
1717#include < mutex>
18+ #include < shared_mutex>
1819#include < algorithm>
1920#include < random>
2021#include < thread>
22+ #include < atomic>
2123#include < openssl/evp.h>
2224#include < openssl/rsa.h>
2325#include < openssl/pem.h>
@@ -991,10 +993,13 @@ struct mcp_auth_client {
991993 std::string last_error_context;
992994 mcp_auth_error_t last_error_code = MCP_AUTH_SUCCESS;
993995
994- // JWKS cache
996+ // JWKS cache with read-write lock for better concurrency
995997 std::vector<jwks_key> cached_keys;
996998 std::chrono::steady_clock::time_point cache_timestamp;
997- std::mutex cache_mutex;
999+ mutable std::shared_mutex cache_mutex; // mutable for const methods
1000+
1001+ // Auto-refresh state
1002+ std::atomic<bool > refresh_in_progress{false };
9981003
9991004 mcp_auth_client (const char * uri, const char * iss)
10001005 : jwks_uri(uri ? normalize_url(uri) : " " )
@@ -1159,8 +1164,9 @@ struct mcp_auth_metadata {
11591164// ========================================================================
11601165
11611166// Check if JWKS cache is still valid
1162- static bool is_cache_valid (mcp_auth_client_t client) {
1163- std::lock_guard<std::mutex> lock (client->cache_mutex );
1167+ static bool is_cache_valid (const mcp_auth_client_t client) {
1168+ // Use shared lock for read-only access
1169+ std::shared_lock<std::shared_mutex> lock (client->cache_mutex );
11641170
11651171 // Check if we have cached keys
11661172 if (client->cached_keys .empty ()) {
@@ -1178,23 +1184,42 @@ static bool is_cache_valid(mcp_auth_client_t client) {
11781184static bool get_jwks_keys (mcp_auth_client_t client, std::vector<jwks_key>& keys) {
11791185 // Check if cache is valid
11801186 if (is_cache_valid (client)) {
1181- std::lock_guard<std::mutex> lock (client->cache_mutex );
1187+ // Use shared lock for reading cached data
1188+ std::shared_lock<std::shared_mutex> lock (client->cache_mutex );
11821189 keys = client->cached_keys ;
11831190 return true ;
11841191 }
11851192
1186- // Cache is invalid or expired, fetch new keys
1187- return fetch_and_cache_jwks (client) && get_jwks_keys (client, keys);
1193+ // Prevent multiple simultaneous refreshes using atomic flag
1194+ bool expected = false ;
1195+ if (client->refresh_in_progress .compare_exchange_strong (expected, true )) {
1196+ // This thread won the race to refresh
1197+ bool success = fetch_and_cache_jwks (client);
1198+ client->refresh_in_progress = false ;
1199+
1200+ if (success) {
1201+ // Read the newly cached keys
1202+ std::shared_lock<std::shared_mutex> lock (client->cache_mutex );
1203+ keys = client->cached_keys ;
1204+ return true ;
1205+ }
1206+ return false ;
1207+ } else {
1208+ // Another thread is refreshing, wait a bit and retry
1209+ std::this_thread::sleep_for (std::chrono::milliseconds (100 ));
1210+ return get_jwks_keys (client, keys);
1211+ }
11881212}
11891213
11901214// Invalidate cache (for when validation fails with unknown kid)
11911215static void invalidate_cache (mcp_auth_client_t client) {
1192- std::lock_guard<std::mutex> lock (client->cache_mutex );
1216+ // Use unique lock for write access
1217+ std::unique_lock<std::shared_mutex> lock (client->cache_mutex );
11931218 client->cached_keys .clear ();
11941219 client->cache_timestamp = std::chrono::steady_clock::time_point ();
11951220}
11961221
1197- // Fetch and cache JWKS keys
1222+ // Fetch and cache JWKS keys (thread-safe)
11981223static bool fetch_and_cache_jwks (mcp_auth_client_t client) {
11991224 std::string jwks_json;
12001225 if (!fetch_jwks_json (client->jwks_uri , jwks_json, client->request_timeout )) {
@@ -1212,10 +1237,13 @@ static bool fetch_and_cache_jwks(mcp_auth_client_t client) {
12121237 return false ;
12131238 }
12141239
1215- // Update cache
1216- std::lock_guard<std::mutex> lock (client->cache_mutex );
1217- client->cached_keys = std::move (keys);
1218- client->cache_timestamp = std::chrono::steady_clock::now ();
1240+ // Update cache atomically with exclusive lock
1241+ {
1242+ std::unique_lock<std::shared_mutex> lock (client->cache_mutex );
1243+ // Use swap for atomic update
1244+ client->cached_keys .swap (keys);
1245+ client->cache_timestamp = std::chrono::steady_clock::now ();
1246+ }
12191247
12201248 fprintf (stderr, " JWKS cache updated with %zu keys\n " , client->cached_keys .size ());
12211249 for (const auto & key : client->cached_keys ) {
@@ -1435,7 +1463,7 @@ mcp_auth_error_t mcp_auth_client_destroy(mcp_auth_client_t client) {
14351463
14361464 // Clean up cached JWKS keys
14371465 {
1438- std::lock_guard <std::mutex > lock (client->cache_mutex );
1466+ std::unique_lock <std::shared_mutex > lock (client->cache_mutex );
14391467
14401468 // Clear PEM strings in cached keys
14411469 for (auto & key : client->cached_keys ) {
0 commit comments