Caching and Retries¶
MLflow Secrets Auth implements sophisticated caching and retry mechanisms to optimize performance, ensure reliability, and handle transient failures gracefully.
Caching Architecture¶
Overview¶
The caching system provides in-memory storage of retrieved credentials with configurable TTL (Time To Live) expiration. Caching reduces latency, minimizes load on secret management systems, and improves user experience.
Cache Implementation¶
Cache Structure¶
class CacheEntry:
    """A cached secret value paired with its absolute expiration time."""

    def __init__(self, data: dict[str, str], ttl: int):
        # Snapshot the clock once so created_at and expiry stay consistent.
        now = time.time()
        self.data = data
        self.created_at = now
        self.expiry = now + ttl

    def is_expired(self) -> bool:
        """Return True once the wall clock has passed the expiry instant."""
        return time.time() > self.expiry

    def remaining_ttl(self) -> int:
        """Seconds left before expiry, floored at zero."""
        seconds_left = self.expiry - time.time()
        return int(seconds_left) if seconds_left > 0 else 0
class SecretCache:
    """Thread-safe in-memory cache for secrets."""

    def __init__(self):
        self._cache: dict[str, CacheEntry] = {}
        self._lock = threading.RLock()

    def get(self, key: str) -> dict[str, str] | None:
        """Return the cached secret for *key*, or None; evicts expired entries."""
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            if entry.is_expired():
                # Lazy eviction: stale entries are removed on access.
                del self._cache[key]
                return None
            return entry.data

    def set(self, key: str, data: dict[str, str], ttl: int) -> None:
        """Cache *data* under *key* for *ttl* seconds."""
        with self._lock:
            self._cache[key] = CacheEntry(data, ttl)

    def delete(self, key: str) -> None:
        """Drop a single entry; missing keys are ignored."""
        with self._lock:
            self._cache.pop(key, None)

    def clear(self) -> None:
        """Drop every cached entry."""
        with self._lock:
            self._cache.clear()
Cache Key Generation¶
Cache keys are generated based on provider-specific configuration to ensure proper isolation:
# Vault provider cache key
def _get_cache_key(self) -> str:
    """Build an isolation key from the Vault address and secret path."""
    addr = get_env_var("VAULT_ADDR", "")
    path = get_env_var("MLFLOW_VAULT_SECRET_PATH", "")
    return ":".join(("vault", addr, path))
# AWS provider cache key
def _get_cache_key(self) -> str:
    """Build an isolation key from the AWS region and secret name."""
    region = get_env_var("AWS_REGION", "")
    name = get_env_var("MLFLOW_AWS_SECRET_NAME", "")
    return ":".join(("aws", region, name))
# Azure provider cache key
def _get_cache_key(self) -> str:
    """Build an isolation key from the Key Vault URL and secret name."""
    url = get_env_var("MLFLOW_AZURE_KEY_VAULT_URL", "")
    name = get_env_var("MLFLOW_AZURE_SECRET_NAME", "")
    return ":".join(("azure", url, name))
Cache Lifecycle¶
Cache Hit Flow¶
graph TB
A[Request Auth] --> B[Generate Cache Key]
B --> C{Cache Hit?}
C -->|Yes| D{Expired?}
C -->|No| E[Cache Miss]
D -->|No| F[Return Cached Auth]
D -->|Yes| G[Remove Expired Entry]
G --> E
E --> H[Fetch from Provider]
H --> I{Success?}
I -->|Yes| J[Parse Secret]
J --> K[Create Auth Object]
K --> L[Store in Cache]
L --> M[Return Auth]
I -->|No| N[Return Error]
Cache Operations¶
def _fetch_secret_cached(self) -> dict[str, str] | None:
"""Fetch secret with caching support."""
cache_key = self._get_cache_key()
# Check cache first
cached_secret = self._cache.get(cache_key)
if cached_secret is not None:
self.logger.debug("Cache hit for key: %s", cache_key)
return cached_secret
# Cache miss - fetch from provider
self.logger.debug("Cache miss for key: %s", cache_key)
secret_json = self._fetch_secret()
if secret_json:
try:
secret_data = json.loads(secret_json)
# Store in cache with configured TTL
self._cache.set(cache_key, secret_data, self._get_ttl())
return secret_data
except json.JSONDecodeError as e:
self.logger.error("Failed to parse secret JSON: %s", e)
return None
TTL Configuration¶
Per-Provider TTL Settings¶
Each provider supports independent TTL configuration:
# Vault TTL (default: 300 seconds; shown at the default)
MLFLOW_VAULT_TTL_SEC=300
# AWS TTL (default: 300 seconds; this example overrides to 600)
MLFLOW_AWS_TTL_SEC=600
# Azure TTL (default: 300 seconds; this example overrides to 900)
MLFLOW_AZURE_TTL_SEC=900
TTL Validation¶
TTL values are validated and clamped to reasonable ranges:
def validate_ttl(ttl: int, min_ttl: int = 60, max_ttl: int = 3600) -> int:
    """Validate and clamp TTL to reasonable bounds.

    Args:
        ttl: Requested TTL in seconds.
        min_ttl: Lower bound; smaller values are clamped up.
        max_ttl: Upper bound; larger values are clamped down.

    Returns:
        The TTL clamped into [min_ttl, max_ttl].
    """
    if ttl < min_ttl:
        # Lazy %-style args match the logging convention used elsewhere in this module.
        logger.warning("TTL %s too low, using minimum %s", ttl, min_ttl)
        return min_ttl
    if ttl > max_ttl:
        logger.warning("TTL %s too high, using maximum %s", ttl, max_ttl)
        return max_ttl
    return ttl
Environment-Specific TTL Recommendations¶
# Development: Short TTL for rapid iteration
MLFLOW_VAULT_TTL_SEC=60 # 1 minute
# Staging: Medium TTL for testing
MLFLOW_VAULT_TTL_SEC=300 # 5 minutes
# Production: Longer TTL for performance
MLFLOW_VAULT_TTL_SEC=900 # 15 minutes
# High-security: shortest practical TTL
MLFLOW_VAULT_TTL_SEC=60 # 60 seconds — values below the validation minimum (60s) are clamped up to it
Cache Busting¶
Automatic Cache Busting¶
The cache automatically invalidates entries on authentication failures:
def _handle_auth_failure(self, response: requests.Response, url: str) -> None:
    """Invalidate the cached secret when the server rejects our credentials."""
    status = response.status_code
    # Only credential-related rejections bust the cache.
    if status not in (401, 403):
        return
    cache_key = self._get_cache_key()
    self._cache.delete(cache_key)
    self.logger.info(
        "Cleared cache due to auth failure",
        extra={
            "status_code": status,
            "url": redact_sensitive_data(url),
            "cache_key": cache_key,
        },
    )
Manual Cache Operations¶
Future CLI support for manual cache management:
# Clear all cached secrets
mlflow-secrets-auth cache clear
# Clear specific provider cache
mlflow-secrets-auth cache clear --provider vault
# Show cache status
mlflow-secrets-auth cache status
Cache Security¶
Memory-Only Storage¶
- No Disk Persistence: Cache exists only in memory
- Process Isolation: Cache cleared on process termination
- Secure Cleanup: Automatic cleanup of expired entries
Cache Isolation¶
# Different configurations use different cache keys
# Production Vault
cache_key_prod = "vault:https://vault.prod.com:secret/mlflow/prod"
# Staging Vault
cache_key_staging = "vault:https://vault.staging.com:secret/mlflow/staging"
# These are isolated in separate cache entries
Thread Safety¶
class ThreadSafeCache:
    """Cache whose operations are serialized by a reentrant lock."""

    def __init__(self):
        self._cache = {}
        # Reentrant so a lock holder may re-enter without deadlocking.
        self._lock = threading.RLock()

    def get(self, key: str):
        """Return the entry stored under *key*, or None."""
        with self._lock:
            return self._cache.get(key)

    def set(self, key: str, value, ttl: int):
        """Wrap *value* in a CacheEntry and store it under *key*."""
        with self._lock:
            entry = CacheEntry(value, ttl)
            self._cache[key] = entry
Retry Mechanisms¶
Retry Strategy¶
The plugin implements exponential backoff with jitter to handle transient failures gracefully.
Retry Implementation¶
def retry_with_jitter(
    func: Callable[[], T],
    max_retries: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter_factor: float = 0.1
) -> T:
    """Execute *func*, retrying with exponential backoff and jitter.

    Args:
        func: Zero-argument callable to invoke.
        max_retries: Retries after the first attempt (max_retries + 1 calls total).
        base_delay: Delay before the first retry, in seconds.
        max_delay: Ceiling on the backoff delay (jitter is added on top).
        backoff_factor: Multiplier applied to the delay each attempt.
        jitter_factor: Fraction of the delay used as the jitter range.

    Returns:
        The value returned by the first successful call.

    Raises:
        The exception from the final attempt when all retries are exhausted.
    """
    for attempt in range(max_retries + 1):
        try:
            return func()
        except Exception as e:
            if attempt == max_retries:
                # Final attempt failed: bare raise preserves the original traceback.
                raise
            # Exponential backoff, capped at max_delay.
            delay = min(base_delay * (backoff_factor ** attempt), max_delay)
            # Random jitter spreads out simultaneous retries (avoids thundering herd).
            jitter = random.uniform(0, delay * jitter_factor)
            total_delay = delay + jitter
            # Lazy %-style args, consistent with logging elsewhere in this module.
            logger.debug(
                "Attempt %d failed, retrying in %.2fs",
                attempt + 1,
                total_delay,
                extra={"error": str(e), "attempt": attempt + 1},
            )
            time.sleep(total_delay)
    raise RuntimeError("Retry logic error")  # Unreachable: loop always returns or raises.
Retry Configuration¶
# Default retry settings
DEFAULT_RETRY_CONFIG = {
    "max_retries": 3,       # retries after the first attempt
    "base_delay": 1.0,      # seconds before the first retry
    "max_delay": 30.0,      # ceiling on the backoff delay
    "backoff_factor": 2.0,  # delay multiplier per attempt
    "jitter_factor": 0.1,   # fraction of delay added as jitter
}

# Environment variable overrides
def get_retry_config() -> dict:
    """Build the retry configuration, honoring environment overrides."""
    retries = get_env_int("MLFLOW_SECRETS_MAX_RETRIES", 3)
    base = get_env_float("MLFLOW_SECRETS_BASE_DELAY", 1.0)
    ceiling = get_env_float("MLFLOW_SECRETS_MAX_DELAY", 30.0)
    return {"max_retries": retries, "base_delay": base, "max_delay": ceiling}
Retry Examples¶
Timing Examples:
Attempt 1: Immediate
Attempt 2: 1.0s + jitter (0.0-0.1s) = ~1.05s
Attempt 3: 2.0s + jitter (0.0-0.2s) = ~2.15s
Attempt 4: 4.0s + jitter (0.0-0.4s) = ~4.25s
Retryable Conditions¶
Network Errors¶
# Exception types considered transient and therefore worth retrying.
RETRYABLE_EXCEPTIONS = (
    requests.exceptions.ConnectionError,
    requests.exceptions.Timeout,
    requests.exceptions.HTTPError,  # 5xx only (filtered in is_retryable_error)
    socket.timeout,
    socket.gaierror,
)

def is_retryable_error(exception: Exception) -> bool:
    """Determine if error is worth retrying.

    Server errors (5xx) and transient network failures are retryable;
    client errors (4xx) are not.
    """
    if isinstance(exception, requests.exceptions.HTTPError):
        response = exception.response
        # Fix: HTTPError may carry no attached response (e.g. raised manually);
        # without this guard the attribute access below raised AttributeError.
        if response is None:
            return False
        # Only retry server errors (5xx), not client errors (4xx)
        return 500 <= response.status_code < 600
    return isinstance(exception, RETRYABLE_EXCEPTIONS)
Provider-Specific Retry Logic¶
Vault Retries:
def _fetch_secret(self) -> str | None:
    """Fetch secret with Vault-specific retry logic.

    Returns the secret payload serialized as a JSON string, trying the
    KV v2 engine first and falling back to KV v1 on an invalid path.
    """
    def _vault_fetch():
        client = self._get_vault_client()
        if not client or not client.is_authenticated():
            raise VaultAuthenticationError("Vault not authenticated")
        # Fix: `path` was previously an undefined name in this snippet;
        # resolve it from configuration the same way the cache key does.
        path = get_env_var("MLFLOW_VAULT_SECRET_PATH", "")
        # Try KV v2 first, then v1
        try:
            response = client.secrets.kv.v2.read_secret_version(path=path)
            # KV v2 nests the payload under data.data.
            return json.dumps(response["data"]["data"])
        except hvac.exceptions.InvalidPath:
            # Fallback to KV v1, where the payload sits directly under data.
            response = client.secrets.kv.v1.read_secret(path=path)
            return json.dumps(response["data"])
    return retry_with_jitter(_vault_fetch)
AWS Retries:
def _fetch_secret(self) -> str | None:
    """Fetch secret with AWS-specific retry logic.

    Returns SecretString verbatim, or SecretBinary re-encoded as base64 text.
    """
    def _aws_fetch():
        client = self._get_secrets_client()
        response = client.get_secret_value(SecretId=self.secret_name)
        # Handle both SecretString and SecretBinary
        if 'SecretString' in response:
            return response['SecretString']
        # Convert binary to base64 string
        import base64
        binary = response['SecretBinary']
        return base64.b64encode(binary).decode('utf-8')
    return retry_with_jitter(_aws_fetch)
Circuit Breaker Pattern¶
For production resilience, implement circuit breaker pattern:
class CircuitBreaker:
    """Circuit breaker for provider calls.

    States: CLOSED (normal operation), OPEN (rejecting calls after
    repeated failures), HALF_OPEN (probing once the timeout elapses).
    """

    def __init__(self, failure_threshold: int = 5, timeout: int = 60):
        self.failure_threshold = failure_threshold
        self.timeout = timeout
        self.failure_count = 0
        self.last_failure_time = None
        self.state = "CLOSED"  # CLOSED, OPEN, HALF_OPEN

    def call(self, func: Callable) -> Any:
        """Execute function with circuit breaker protection."""
        if self.state == "OPEN":
            elapsed = time.time() - self.last_failure_time
            if elapsed < self.timeout:
                raise CircuitBreakerOpenError("Circuit breaker is open")
            # Cool-down elapsed: allow a single probe call through.
            self.state = "HALF_OPEN"
        try:
            result = func()
        except Exception as e:
            self.failure_count += 1
            self.last_failure_time = time.time()
            if self.failure_count >= self.failure_threshold:
                self.state = "OPEN"
            raise e
        # A successful probe closes the breaker and resets the counter.
        if self.state == "HALF_OPEN":
            self.state = "CLOSED"
            self.failure_count = 0
        return result
Performance Optimization¶
Cache Hit Optimization¶
Metrics Collection¶
class CacheMetrics:
    """Collect cache performance metrics."""

    def __init__(self):
        self.hits = 0
        self.misses = 0
        self.total_requests = 0

    def record_hit(self):
        """Count one request served from cache."""
        self.hits += 1
        self.total_requests += 1

    def record_miss(self):
        """Count one request that fell through to the provider."""
        self.misses += 1
        self.total_requests += 1

    def hit_rate(self) -> float:
        """Fraction of requests served from cache (0.0 when no requests yet)."""
        return self.hits / self.total_requests if self.total_requests else 0.0

    def stats(self) -> dict:
        """Snapshot of all counters plus the derived hit rate."""
        return {
            "hits": self.hits,
            "misses": self.misses,
            "total_requests": self.total_requests,
            "hit_rate": self.hit_rate(),
        }
Cache Warming¶
For high-performance scenarios, implement cache warming:
def warm_cache(self) -> None:
    """Pre-populate cache with secrets.

    Best-effort: a failed fetch is logged as a warning, never raised.
    """
    try:
        # The fetch is triggered purely for its cache-filling side effect.
        self._fetch_secret_cached()
        logger.info("Cache warmed successfully")
    except Exception as e:
        logger.warning(f"Cache warming failed: {e}")
Connection Pooling¶
Provider Connection Reuse¶
class VaultAuthProvider:
    """Vault provider with connection pooling.

    A single Vault client is created lazily and reused across requests;
    it is rebuilt only when missing or no longer authenticated.
    """

    def __init__(self):
        # NOTE(review): no base class is visible in this snippet — the
        # super().__init__("vault", default_ttl=300) call assumes a
        # provider base class; confirm against the actual hierarchy.
        super().__init__("vault", default_ttl=300)
        self._vault_client = None
        self._client_lock = threading.Lock()

    def _get_vault_client(self):
        """Get or create Vault client with connection reuse."""
        client = self._vault_client
        if client is not None:
            # Lock-free fast path: reuse the client if it still authenticates.
            try:
                if client.is_authenticated():
                    return client
            except Exception:
                pass  # Client no longer valid
        with self._client_lock:
            # Double-check under the lock; another thread may have refreshed it.
            client = self._vault_client
            usable = False
            if client is not None:
                try:
                    usable = client.is_authenticated()
                except Exception:
                    # Fix: the locked re-check previously let is_authenticated()
                    # raise (e.g. dead connection) and escape the method instead
                    # of triggering a rebuild.
                    usable = False
            if not usable:
                self._vault_client = self._create_vault_client()
            return self._vault_client
Async Considerations¶
While the current implementation is synchronous, the architecture supports future async operations:
# Future async implementation
class AsyncSecretsBackedAuthProvider:
    """Async version of secrets auth provider."""

    async def get_request_auth_async(self, url: str) -> requests.auth.AuthBase | None:
        """Async version of get_request_auth."""
        # Guard clauses: bail out when the plugin is disabled or the
        # target host is not on the allowlist.
        if not self._is_enabled():
            return None
        if not self._is_host_allowed(url):
            return None
        secret_data = await self._fetch_secret_cached_async()
        return self._create_auth(secret_data) if secret_data else None

    async def _fetch_secret_cached_async(self) -> dict[str, str] | None:
        """Async version with cache support."""
        # Implementation would use async cache and HTTP clients
        pass
Configuration Best Practices¶
Cache Configuration¶
# High-performance environment
MLFLOW_VAULT_TTL_SEC=1800 # 30 minutes
MLFLOW_SECRETS_MAX_RETRIES=5 # More retries
# High-security environment
MLFLOW_VAULT_TTL_SEC=60 # 1 minute
MLFLOW_SECRETS_MAX_RETRIES=2 # Fewer retries
# Development environment
MLFLOW_VAULT_TTL_SEC=300 # 5 minutes
MLFLOW_SECRETS_MAX_RETRIES=3 # Standard retries
Monitoring Configuration¶
# Enable detailed logging for performance monitoring
MLFLOW_SECRETS_LOG_LEVEL=DEBUG
# Monitor cache performance
# (Future feature: metrics endpoint)
Troubleshooting Performance¶
Cache Issues¶
# Check cache status
mlflow-secrets-auth info # Shows cache hit rate
# Clear cache if stale
mlflow-secrets-auth cache clear
# Test with cache disabled (future feature)
MLFLOW_SECRETS_CACHE_DISABLED=true mlflow-secrets-auth doctor
Retry Issues¶
# Reduce retries for faster failures
MLFLOW_SECRETS_MAX_RETRIES=1
# Increase retries for resilience
MLFLOW_SECRETS_MAX_RETRIES=5
# Debug retry behavior
MLFLOW_SECRETS_LOG_LEVEL=DEBUG mlflow-secrets-auth doctor
Network Performance¶
# Test network connectivity
mlflow-secrets-auth doctor --dry-run https://your-server.com
# Monitor secret fetch latency
# (Logged at DEBUG level)
Next Steps¶
- Architecture - Overall system design and flow
- Security - Security model and threat mitigation
- Provider Documentation - Provider-specific performance considerations
- Troubleshooting - Performance troubleshooting guide