Skip to content

Commit 78907c6

Browse files
RahulHereRahulHere
authored andcommitted
Implement Thread-Safe Cache (#130)
- Replaced std::mutex with std::shared_mutex for read-write locking - Added shared locks for read operations on cache - Added unique locks for write operations on cache - Implemented atomic refresh flag to prevent concurrent refreshes - Added race condition protection in get_jwks_keys - Used swap for atomic cache updates - Made cache mutex mutable for const methods - Added sleep and retry mechanism for concurrent refresh attempts - Protected auto-refresh from race conditions - Ensured RAII lock guards for mutex release
1 parent cab4c0b commit 78907c6

File tree

1 file changed

+42
-14
lines changed

1 file changed

+42
-14
lines changed

src/c_api/mcp_c_auth_api.cc

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
#include <chrono>
1616
#include <cstring>
1717
#include <mutex>
18+
#include <shared_mutex>
1819
#include <algorithm>
1920
#include <random>
2021
#include <thread>
22+
#include <atomic>
2123
#include <openssl/evp.h>
2224
#include <openssl/rsa.h>
2325
#include <openssl/pem.h>
@@ -991,10 +993,13 @@ struct mcp_auth_client {
991993
std::string last_error_context;
992994
mcp_auth_error_t last_error_code = MCP_AUTH_SUCCESS;
993995

994-
// JWKS cache
996+
// JWKS cache with read-write lock for better concurrency
995997
std::vector<jwks_key> cached_keys;
996998
std::chrono::steady_clock::time_point cache_timestamp;
997-
std::mutex cache_mutex;
999+
mutable std::shared_mutex cache_mutex; // mutable for const methods
1000+
1001+
// Auto-refresh state
1002+
std::atomic<bool> refresh_in_progress{false};
9981003

9991004
mcp_auth_client(const char* uri, const char* iss)
10001005
: jwks_uri(uri ? normalize_url(uri) : "")
@@ -1159,8 +1164,9 @@ struct mcp_auth_metadata {
11591164
// ========================================================================
11601165

11611166
// Check if JWKS cache is still valid
1162-
static bool is_cache_valid(mcp_auth_client_t client) {
1163-
std::lock_guard<std::mutex> lock(client->cache_mutex);
1167+
static bool is_cache_valid(const mcp_auth_client_t client) {
1168+
// Use shared lock for read-only access
1169+
std::shared_lock<std::shared_mutex> lock(client->cache_mutex);
11641170

11651171
// Check if we have cached keys
11661172
if (client->cached_keys.empty()) {
@@ -1178,23 +1184,42 @@ static bool is_cache_valid(mcp_auth_client_t client) {
11781184
static bool get_jwks_keys(mcp_auth_client_t client, std::vector<jwks_key>& keys) {
11791185
// Check if cache is valid
11801186
if (is_cache_valid(client)) {
1181-
std::lock_guard<std::mutex> lock(client->cache_mutex);
1187+
// Use shared lock for reading cached data
1188+
std::shared_lock<std::shared_mutex> lock(client->cache_mutex);
11821189
keys = client->cached_keys;
11831190
return true;
11841191
}
11851192

1186-
// Cache is invalid or expired, fetch new keys
1187-
return fetch_and_cache_jwks(client) && get_jwks_keys(client, keys);
1193+
// Prevent multiple simultaneous refreshes using atomic flag
1194+
bool expected = false;
1195+
if (client->refresh_in_progress.compare_exchange_strong(expected, true)) {
1196+
// This thread won the race to refresh
1197+
bool success = fetch_and_cache_jwks(client);
1198+
client->refresh_in_progress = false;
1199+
1200+
if (success) {
1201+
// Read the newly cached keys
1202+
std::shared_lock<std::shared_mutex> lock(client->cache_mutex);
1203+
keys = client->cached_keys;
1204+
return true;
1205+
}
1206+
return false;
1207+
} else {
1208+
// Another thread is refreshing, wait a bit and retry
1209+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
1210+
return get_jwks_keys(client, keys);
1211+
}
11881212
}
11891213

11901214
// Invalidate cache (for when validation fails with unknown kid)
11911215
static void invalidate_cache(mcp_auth_client_t client) {
1192-
std::lock_guard<std::mutex> lock(client->cache_mutex);
1216+
// Use unique lock for write access
1217+
std::unique_lock<std::shared_mutex> lock(client->cache_mutex);
11931218
client->cached_keys.clear();
11941219
client->cache_timestamp = std::chrono::steady_clock::time_point();
11951220
}
11961221

1197-
// Fetch and cache JWKS keys
1222+
// Fetch and cache JWKS keys (thread-safe)
11981223
static bool fetch_and_cache_jwks(mcp_auth_client_t client) {
11991224
std::string jwks_json;
12001225
if (!fetch_jwks_json(client->jwks_uri, jwks_json, client->request_timeout)) {
@@ -1212,10 +1237,13 @@ static bool fetch_and_cache_jwks(mcp_auth_client_t client) {
12121237
return false;
12131238
}
12141239

1215-
// Update cache
1216-
std::lock_guard<std::mutex> lock(client->cache_mutex);
1217-
client->cached_keys = std::move(keys);
1218-
client->cache_timestamp = std::chrono::steady_clock::now();
1240+
// Update cache atomically with exclusive lock
1241+
{
1242+
std::unique_lock<std::shared_mutex> lock(client->cache_mutex);
1243+
// Use swap for atomic update
1244+
client->cached_keys.swap(keys);
1245+
client->cache_timestamp = std::chrono::steady_clock::now();
1246+
}
12191247

12201248
fprintf(stderr, "JWKS cache updated with %zu keys\n", client->cached_keys.size());
12211249
for (const auto& key : client->cached_keys) {
@@ -1435,7 +1463,7 @@ mcp_auth_error_t mcp_auth_client_destroy(mcp_auth_client_t client) {
14351463

14361464
// Clean up cached JWKS keys
14371465
{
1438-
std::lock_guard<std::mutex> lock(client->cache_mutex);
1466+
std::unique_lock<std::shared_mutex> lock(client->cache_mutex);
14391467

14401468
// Clear PEM strings in cached keys
14411469
for (auto& key : client->cached_keys) {

0 commit comments

Comments
 (0)