Architecture
MLflow Secrets Auth plugs into MLflow's request authentication mechanism and supplies credentials fetched from an external secret manager. Caching, retries, and host allowlisting are handled once, in a shared base class, rather than separately in each provider.
Overview
The plugin combines a factory with provider-specific implementations. The factory picks the first enabled provider (HashiCorp Vault, AWS Secrets Manager, or Azure Key Vault), and that provider retrieves credentials from its secret store. The design emphasizes security, caching, and extensibility while staying compatible with MLflow's existing authentication framework.
High-Level Architecture
```mermaid
graph TB
    A[MLflow Client] -->|HTTP Request| B[MLflow Request Handler]
    B -->|Auth Provider| C[SecretsAuthProviderFactory]
    C -->|Provider Selection| D{First Available Provider}
    D -->|Vault Enabled| E[VaultAuthProvider]
    D -->|AWS Enabled| F[AWSSecretsManagerAuthProvider]
    D -->|Azure Enabled| G[AzureKeyVaultAuthProvider]
    E -->|Fetch Secret| H[Cache Layer]
    F -->|Fetch Secret| H
    G -->|Fetch Secret| H
    H -->|Cache Hit| I[Return Cached Auth]
    H -->|Cache Miss| J[Provider Secret Fetch]
    J -->|Success| K[Parse & Validate Secret]
    K -->|Valid| L[Create Auth Object]
    L -->|Cache| H
    L -->|Return| M[Inject Auth Header]
    M -->|Authenticated Request| N[MLflow Server]
```
Core Components
SecretsAuthProviderFactory
The main factory class. It implements MLflow's RequestAuthProvider interface and manages provider selection.
Responsibilities:
- Provider discovery and instantiation
- Request routing to the active provider
- Fallback behavior when no provider is available
- Integration with MLflow's authentication lifecycle
Key Methods:
- get_request_auth(url): main entry point for authentication
- _get_actual_provider(): lazy provider instantiation
- _is_enabled(): provider availability check
SecretsBackedAuthProvider (Base Class)
Abstract base class that defines the common interface and shared functionality for all provider implementations.
Core Features:
- Caching with configurable TTL
- Retry logic with exponential backoff and jitter
- Host allowlisting validation
- Credential parsing and validation
- Error handling and logging
Abstract Methods:
```python
def _fetch_secret(self) -> str | None: ...  # raw secret, typically a JSON string
def _get_cache_key(self) -> str: ...
def _get_auth_mode(self) -> str: ...        # "bearer" or "basic"
def _get_ttl(self) -> int: ...              # cache TTL in seconds
```
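Because the base class is abstract, a subclass that leaves out one of these methods fails when it is instantiated, not later at request time. The sketch below shows this behavior with Python's abc module. The class bodies, `IncompleteProvider`, `StaticProvider`, and `can_instantiate` are hypothetical stand-ins, not the plugin's code.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class SecretsBackedAuthProvider(ABC):
    """Minimal stand-in for the plugin's base class (hypothetical)."""

    @abstractmethod
    def _fetch_secret(self) -> str | None: ...

    @abstractmethod
    def _get_cache_key(self) -> str: ...

    @abstractmethod
    def _get_auth_mode(self) -> str: ...

    @abstractmethod
    def _get_ttl(self) -> int: ...


class IncompleteProvider(SecretsBackedAuthProvider):
    # Implements only one of the four abstract methods
    def _fetch_secret(self) -> str | None:
        return None


class StaticProvider(SecretsBackedAuthProvider):
    # Implements all four methods, so it can be instantiated
    def _fetch_secret(self) -> str | None:
        return '{"token": "example-token"}'

    def _get_cache_key(self) -> str:
        return "static"

    def _get_auth_mode(self) -> str:
        return "bearer"

    def _get_ttl(self) -> int:
        return 300


def can_instantiate(cls) -> bool:
    """Return True if cls has no remaining abstract methods."""
    try:
        cls()
    except TypeError:
        return False
    return True
```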
Provider Implementations
VaultAuthProvider
- Authentication: Token or AppRole
- Secret Formats: KV v1 and KV v2 support
- Features: Auto-detection of KV version, graceful fallback
AWSSecretsManagerAuthProvider
- Authentication: IAM credentials, roles, profiles
- Secret Formats: SecretString and SecretBinary
- Features: Multi-region support, version handling
AzureKeyVaultAuthProvider
- Authentication: DefaultAzureCredential chain
- Secret Formats: Key Vault secrets
- Features: Managed identity support, certificate auth
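AWS Secrets Manager returns a secret either as text (SecretString) or as bytes (SecretBinary). The sketch below shows one way to normalize both into key/value credentials. It assumes that both forms hold a UTF-8 JSON object. The response dict mirrors the shape that boto3's `get_secret_value` returns, and `secret_payload_from_response` is a hypothetical name.

```python
from __future__ import annotations

import json


def secret_payload_from_response(response: dict) -> dict[str, str]:
    """Extract key/value credentials from a get_secret_value-style response."""
    if response.get("SecretString") is not None:
        raw = response["SecretString"]
    elif response.get("SecretBinary") is not None:
        # boto3 returns SecretBinary as bytes; assume UTF-8 encoded JSON
        raw = response["SecretBinary"].decode("utf-8")
    else:
        raise ValueError("Response contains neither SecretString nor SecretBinary")
    data = json.loads(raw)
    if not isinstance(data, dict):
        raise ValueError("Secret payload must be a JSON object")
    # Coerce keys and values to str so downstream auth code sees uniform types
    return {str(k): str(v) for k, v in data.items()}
```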
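Secret Resolution Flow
1. Request Interception
```python
import requests

# MLflow issues an HTTP request; the plugin supplies the auth object for it.
# auth_provider is the SecretsAuthProviderFactory instance used by MLflow.
url = "https://mlflow.company.com/api/2.0/mlflow/experiments/list"
response = requests.get(url, auth=auth_provider.get_request_auth(url))
```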
2. Provider Selection
```python
def _get_actual_provider(self) -> SecretsBackedAuthProvider | None:
    """Select first available provider from priority list."""
    for name, provider_cls in self._PROVIDERS.items():
        if is_provider_enabled(name):
            try:
                return provider_cls()
            except Exception:
                continue  # Try next provider
    return None
```
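Priority therefore comes from the iteration order of `_PROVIDERS`, since dicts preserve insertion order. A provider whose constructor raises is skipped. The standalone sketch below shows both effects. It replaces `is_provider_enabled` with an `enabled` set and uses hypothetical provider classes.

```python
from __future__ import annotations


class BrokenProvider:
    def __init__(self):
        # Simulates a provider whose client library or configuration is missing
        raise RuntimeError("misconfigured")


class WorkingProvider:
    pass


def select_provider(providers: dict[str, type], enabled: set[str]):
    """Return an instance of the first enabled provider that constructs cleanly."""
    for name, provider_cls in providers.items():  # dicts keep insertion order
        if name in enabled:
            try:
                return provider_cls()
            except Exception:
                continue  # fall through to the next provider
    return None
```

Because failures are swallowed, a misconfigured higher-priority provider silently hands over to a lower-priority one. Logging inside the `except` branch keeps that fallback visible.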
3. Host Validation
```python
import fnmatch

def _is_host_allowed(self, hostname: str) -> bool:
    """Validate hostname against allowlist patterns."""
    allowed_hosts = get_allowed_hosts()
    if not allowed_hosts:
        return True  # no allowlist configured: every host is allowed
    return any(
        fnmatch.fnmatch(hostname, pattern)
        for pattern in allowed_hosts
    )
```
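Allowlist patterns are shell-style, so `*` also matches dots. For example, `*.company.com` admits `a.b.company.com` but not the bare `company.com`. In addition, `fnmatch.fnmatch` normalizes case through `os.path.normcase`, so matching is case-insensitive on Windows but case-sensitive on POSIX. Below is a sketch of a platform-independent variant. It assumes hostnames should compare case-insensitively, and the function signature is illustrative.

```python
from __future__ import annotations

import fnmatch


def is_host_allowed(hostname: str, allowed_hosts: list[str]) -> bool:
    """Match a hostname against shell-style patterns, case-insensitively."""
    if not allowed_hosts:
        return True  # empty allowlist means no restriction (as above)
    host = hostname.lower()
    # fnmatchcase skips fnmatch's platform-dependent case handling;
    # lowercasing both sides makes matching case-insensitive everywhere
    return any(fnmatch.fnmatchcase(host, p.lower()) for p in allowed_hosts)
```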
4. Cache Check
```python
import json

def _fetch_secret_cached(self) -> dict[str, str] | None:
    """Check the cache first; fetch from the provider only on a miss or expiry."""
    cache_key = self._get_cache_key()
    cached = self._cache.get(cache_key)
    if cached and not cached.is_expired():
        return cached.data
    # Cache miss or expired entry: fetch from the provider
    raw = self._fetch_secret()
    if not raw:
        return None
    # _fetch_secret returns a JSON string; parse it before caching
    secret = json.loads(raw)
    self._cache.set(cache_key, secret, self._get_ttl())
    return secret
```
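5. Secret Fetching
Each provider implements _fetch_secret to retrieve the secret from its configured source. For Vault with a KV v2 engine:
```python
import json

# Vault example (KV v2)
def _fetch_secret(self) -> str | None:
    client = self._get_vault_client()
    secret_path = get_env_var("MLFLOW_VAULT_SECRET_PATH", "")
    response = client.secrets.kv.v2.read_secret_version(path=secret_path)
    # KV v2 nests the key/value pairs under data.data
    return json.dumps(response["data"]["data"])
```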
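KV v1 and KV v2 return differently shaped responses, and the KV-version auto-detection mentioned for VaultAuthProvider has to reconcile the two. The sketch below normalizes them by inspecting the response shape. It assumes shape-based detection is acceptable; the plugin may instead query the mount's version. `extract_kv_data` is a hypothetical helper.

```python
def extract_kv_data(response: dict) -> dict:
    """Return the key/value pairs from a Vault KV v1 or v2 read response.

    KV v2 wraps the pairs as {"data": {"data": {...}, "metadata": {...}}};
    KV v1 returns them directly as {"data": {...}}.
    """
    data = response.get("data") or {}
    if isinstance(data.get("data"), dict) and "metadata" in data:
        return data["data"]  # KV v2 shape
    return data  # KV v1 shape
```

Shape-based detection is a heuristic. A v1 secret that happens to contain both `data` and `metadata` keys would be misread as v2, so checking the mount's configured version is more robust.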
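6. Authentication Object Creation
```python
from requests.auth import AuthBase, HTTPBasicAuth

def _create_auth(self, secret_data: dict[str, str]) -> AuthBase:
    """Create the auth object that matches the configured auth mode."""
    auth_mode = self._get_auth_mode()
    if auth_mode == "bearer":
        token = secret_data.get("token")
        if not token:
            raise ValueError("Secret is missing the 'token' field")
        return BearerAuth(token)  # plugin's bearer-token auth class
    if auth_mode == "basic":
        username = secret_data.get("username")
        password = secret_data.get("password")
        if not username or not password:
            raise ValueError("Secret is missing 'username' or 'password'")
        return HTTPBasicAuth(username, password)
    raise ValueError(f"Unsupported auth mode: {auth_mode!r}")
```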
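_create_auth uses a BearerAuth class that this page does not show. The sketch below shows what such a class could look like. It relies only on the fact that requests calls the auth object on the prepared request and uses the return value, which holds for any callable, including AuthBase subclasses. This version is a hypothetical stand-in; the plugin's class presumably subclasses AuthBase.

```python
class BearerAuth:
    """Attach 'Authorization: Bearer <token>' to an outgoing request."""

    def __init__(self, token: str):
        if not token:
            raise ValueError("token must be a non-empty string")
        self.token = token

    def __call__(self, request):
        # requests passes the PreparedRequest and sends whatever we return
        request.headers["Authorization"] = f"Bearer {self.token}"
        return request

    def __repr__(self) -> str:
        # Keep the token out of reprs that might end up in logs
        return "BearerAuth(token=***)"
```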
Caching Architecture
Cache Implementation
The caching layer provides in-memory storage with TTL expiration and automatic cache busting.
Features:
- Per-configuration caching (each distinct configuration gets its own cache entry)
- Configurable TTL per provider
- Automatic expiration based on timestamps
- Cache busting on authentication failures (401/403)
- Thread-safe operations
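The sketch below shows a minimal cache with these properties. It assumes a monotonic clock and a lock around every operation. `CacheEntry` and `TTLCache` are illustrative names. Their `get`/`set` methods and the `is_expired()`/`data` members mirror how `_fetch_secret_cached` uses its cache, and `bust()` is a hypothetical hook for the 401/403 case.

```python
from __future__ import annotations

import threading
import time
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class CacheEntry:
    data: Any
    expires_at: float
    clock: Callable[[], float] = time.monotonic

    def is_expired(self) -> bool:
        return self.clock() >= self.expires_at


class TTLCache:
    """Thread-safe in-memory cache with per-entry TTLs (illustrative sketch)."""

    def __init__(self, clock: Callable[[], float] = time.monotonic):
        self._entries: dict[str, CacheEntry] = {}
        self._lock = threading.Lock()
        self._clock = clock  # injectable so tests can control time

    def get(self, key: str) -> CacheEntry | None:
        with self._lock:
            entry = self._entries.get(key)
            if entry is not None and entry.is_expired():
                del self._entries[key]  # drop expired entries lazily on read
                return None
            return entry

    def set(self, key: str, data: Any, ttl: float) -> None:
        with self._lock:
            self._entries[key] = CacheEntry(data, self._clock() + ttl, self._clock)

    def bust(self, key: str) -> None:
        """Invalidate one entry, e.g. after the server answers 401 or 403."""
        with self._lock:
            self._entries.pop(key, None)
```

time.monotonic is used instead of time.time so that wall-clock adjustments cannot stretch or shorten a TTL.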
Cache Key Generation
Cache keys are generated based on provider-specific configuration to ensure proper isolation:
```python
# Vault example
def _get_cache_key(self) -> str:
    vault_addr = get_env_var("VAULT_ADDR", "")
    secret_path = get_env_var("MLFLOW_VAULT_SECRET_PATH", "")
    return f"{vault_addr}:{secret_path}"
```
Cache Lifecycle
```mermaid
graph LR
    A[Request] --> B{Cache Hit?}
    B -->|Yes| C{Expired?}
    B -->|No| D[Fetch Secret]
    C -->|No| E[Return Cached]
    C -->|Yes| D
    D --> F{Success?}
    F -->|Yes| G[Update Cache]
    F -->|No| H[Return Error]
    G --> I[Return Secret]
```
Error Handling and Resilience
Retry Logic
The plugin retries failed operations with exponential backoff and jitter:
```python
import random
import time

def retry_with_jitter(func, max_retries=3, base_delay=1.0, max_delay=30.0):
    """Call func, retrying failures with exponential backoff and jitter."""
    for attempt in range(max_retries + 1):
        try:
            return func()
        except Exception:
            if attempt == max_retries:
                raise  # out of retries: re-raise the original exception
            # Exponential backoff (1s, 2s, 4s, ...) capped at max_delay,
            # plus up to 10% random jitter to spread out concurrent retries
            delay = min(base_delay * (2 ** attempt), max_delay)
            jitter = random.uniform(0, delay * 0.1)
            time.sleep(delay + jitter)
```
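Catching every exception also retries errors that cannot succeed on a second attempt, such as a missing secret or a malformed configuration. The hard-coded time.sleep also makes the helper slow to test. The sketch below is a variant that addresses both. It assumes callers know which exception types are transient. The `retry_on` and `sleep` parameters are hypothetical additions, not part of the documented helper.

```python
import random
import time


def retry_with_jitter(func, max_retries=3, base_delay=1.0, max_delay=30.0,
                      retry_on=(Exception,), sleep=time.sleep):
    """Retry only exceptions listed in retry_on; others propagate at once."""
    for attempt in range(max_retries + 1):
        try:
            return func()
        except retry_on:
            if attempt == max_retries:
                raise
            delay = min(base_delay * (2 ** attempt), max_delay)
            # Same backoff as above, with the sleep function injected
            sleep(delay + random.uniform(0, delay * 0.1))
```

With the defaults, a call that keeps failing waits roughly 1 s, 2 s, and 4 s (each plus up to 10% jitter) before the final exception propagates.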
Failure Modes
The architecture handles the following failure scenarios:
- Provider Unavailable: the factory skips a provider that fails to initialize and selects the next enabled one
- Network Failures: Retry with exponential backoff
- Authentication Failures: Cache busting and re-authentication
- Configuration Errors: Detailed error messages and fallback behavior
- Secret Not Found: Graceful degradation with logging
Circuit Breaker Pattern
For production resilience, the plugin implements circuit breaker behavior. The breaker tracks the following state:
```python
class ProviderCircuitBreaker:
    def __init__(self, failure_threshold=5, timeout=60):
        self.failure_count = 0                      # failures recorded so far
        self.failure_threshold = failure_threshold  # failures before opening
        self.timeout = timeout                      # seconds before retrying
        self.last_failure_time = None
        self.state = "CLOSED"  # CLOSED, OPEN, HALF_OPEN
```
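Only the breaker's state is shown above. The sketch below shows the transitions those fields imply, following the standard CLOSED/OPEN/HALF_OPEN pattern. The `allow_request`, `record_success`, and `record_failure` methods and the injectable clock are hypothetical names, not the plugin's API.

```python
from __future__ import annotations

import time
from typing import Callable


class ProviderCircuitBreaker:
    """CLOSED -> OPEN after failure_threshold failures; OPEN -> HALF_OPEN
    once timeout seconds have passed; in HALF_OPEN a success closes the
    breaker and a failure re-opens it."""

    def __init__(self, failure_threshold=5, timeout=60,
                 clock: Callable[[], float] = time.monotonic):
        self.failure_count = 0
        self.failure_threshold = failure_threshold
        self.timeout = timeout
        self.last_failure_time: float | None = None
        self.state = "CLOSED"
        self._clock = clock

    def allow_request(self) -> bool:
        if self.state == "OPEN":
            if self._clock() - self.last_failure_time >= self.timeout:
                self.state = "HALF_OPEN"  # timeout elapsed: allow trial requests
                return True
            return False  # still open: fail fast without calling the provider
        return True

    def record_success(self) -> None:
        self.failure_count = 0
        self.state = "CLOSED"

    def record_failure(self) -> None:
        self.failure_count += 1
        self.last_failure_time = self._clock()
        if self.state == "HALF_OPEN" or self.failure_count >= self.failure_threshold:
            self.state = "OPEN"
```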
Security Architecture
Host Allowlisting
Host validation prevents credential leakage to unauthorized servers:
```python
import fnmatch
from urllib.parse import urlparse

def validate_host(self, url: str) -> bool:
    """Validate URL host against allowlist."""
    parsed = urlparse(url)
    hostname = parsed.hostname
    if not hostname:
        return False  # unparseable URL: refuse to attach credentials
    allowed_hosts = get_allowed_hosts()
    if not allowed_hosts:
        return True  # No restrictions if not configured
    return any(
        fnmatch.fnmatch(hostname, pattern)
        for pattern in allowed_hosts
    )
```
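Checking `.hostname` from urlparse, rather than matching the raw URL string, matters for credential safety. A userinfo trick such as `https://mlflow.company.com@evil.example/` resolves to the host `evil.example`. The standalone sketch below takes the allowlist as a parameter instead of calling get_allowed_hosts(). The function signature is illustrative.

```python
from __future__ import annotations

import fnmatch
from urllib.parse import urlparse


def validate_host(url: str, allowed_hosts: list[str]) -> bool:
    """Allowlist check on the URL's real host (standalone sketch)."""
    hostname = urlparse(url).hostname  # lowercased; port and userinfo removed
    if not hostname:
        return False
    if not allowed_hosts:
        return True
    return any(fnmatch.fnmatchcase(hostname, p.lower()) for p in allowed_hosts)
```

Because `.hostname` is lowercased and excludes the port, allowlist entries should name hosts only, without `:port` suffixes.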
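Credential Redaction
All logging output automatically redacts sensitive information:
```python
import re

def mask_secret(value: str) -> str:
    # Illustrative stand-in for the plugin's masking helper
    return "***" if len(value) <= 8 else f"{value[:3]}***{value[-2:]}"

def redact_sensitive_data(text: str) -> str:
    """Redact common credential patterns from text."""
    patterns = [
        r"(Bearer\s+)([A-Za-z0-9._\-]+)",
        r"(Basic\s+)([A-Za-z0-9+/=]+)",
        r'("(?:token|password|secret|key)"\s*:\s*")([^"]+)(")',
    ]
    for pattern in patterns:
        # Keep the prefix (and closing quote, if captured); mask only group 2
        text = re.sub(
            pattern,
            lambda m: f"{m.group(1)}{mask_secret(m.group(2))}{m.group(3) if len(m.groups()) > 2 else ''}",
            text,
        )
    return text
```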
Memory Security
- No persistence of credentials to disk
- Secure memory handling for sensitive data
- Automatic cleanup of expired cache entries
- Zero-copy operations where possible
Extensibility
Adding New Providers
A new secret management provider subclasses SecretsBackedAuthProvider and implements its four abstract methods:
```python
class CustomProvider(SecretsBackedAuthProvider):
    def __init__(self):
        super().__init__("custom", default_ttl=300)

    def _fetch_secret(self) -> str | None:
        # Fetch the secret (typically a JSON string) from the backing store
        ...

    def _get_cache_key(self) -> str:
        # Build a key from the provider's configuration
        ...

    def _get_auth_mode(self) -> str:
        # Return "bearer" or "basic"
        ...

    def _get_ttl(self) -> int:
        # Return the cache TTL in seconds
        ...
```
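The sketch below fills in the template for a provider that reads its secret from an environment variable. It is self-contained: the base class is a reduced stand-in. `EnvVarProvider` and the variable name `MY_PLUGIN_SECRET_JSON` are hypothetical. This example does not cover how the factory would discover the new class, for example through an entry in `_PROVIDERS`.

```python
from __future__ import annotations

import os
from abc import ABC, abstractmethod


class SecretsBackedAuthProvider(ABC):
    """Stand-in for the plugin's base class, reduced to what this sketch needs."""

    def __init__(self, name: str, default_ttl: int = 300):
        self.name = name
        self.default_ttl = default_ttl

    @abstractmethod
    def _fetch_secret(self) -> str | None: ...
    @abstractmethod
    def _get_cache_key(self) -> str: ...
    @abstractmethod
    def _get_auth_mode(self) -> str: ...
    @abstractmethod
    def _get_ttl(self) -> int: ...


class EnvVarProvider(SecretsBackedAuthProvider):
    """Hypothetical provider that reads a JSON secret from an environment variable."""

    SECRET_VAR = "MY_PLUGIN_SECRET_JSON"  # hypothetical variable name

    def __init__(self):
        super().__init__("env-var", default_ttl=60)

    def _fetch_secret(self) -> str | None:
        return os.environ.get(self.SECRET_VAR)  # None if unset

    def _get_cache_key(self) -> str:
        # Key on the variable name so differently configured providers don't collide
        return f"env-var:{self.SECRET_VAR}"

    def _get_auth_mode(self) -> str:
        return "bearer"

    def _get_ttl(self) -> int:
        return self.default_ttl
```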
Configuration System
The configuration system uses environment variables for consistency and container-friendliness:
```python
# Provider enablement
def is_provider_enabled(provider_name: str) -> bool:
    # 1. Global comma-separated list of provider names (whitespace ignored)
    global_enable = get_env_var("MLFLOW_SECRETS_AUTH_ENABLE", "")
    enabled = {item.strip().lower() for item in global_enable.split(",") if item.strip()}
    if provider_name.lower() in enabled:
        return True
    # 2. Provider-specific flag: MLFLOW_SECRETS_AUTH_ENABLE_<NAME>, dashes -> underscores
    env_key = f"MLFLOW_SECRETS_AUTH_ENABLE_{provider_name.upper().replace('-', '_')}"
    return get_env_bool(env_key, False)
```
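get_env_var and get_env_bool are plugin helpers, and this page does not specify their parsing rules. The sketch below is a minimal boolean parser. It assumes the common spellings 1/true/yes/on and 0/false/no/off, compared case-insensitively with surrounding whitespace ignored. These accepted spellings are an assumption, not the plugin's documented behavior.

```python
import os


def get_env_bool(name: str, default: bool = False) -> bool:
    """Parse a boolean environment variable (accepted spellings are assumed)."""
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    value = raw.strip().lower()
    if value in {"1", "true", "yes", "on"}:
        return True
    if value in {"0", "false", "no", "off"}:
        return False
    return default  # unrecognized value: fall back rather than guess
```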
Performance Considerations
Lazy Initialization
Providers are instantiated only when needed:
```python
def _get_actual_provider(self) -> SecretsBackedAuthProvider | None:
    if self._actual_provider is not None:
        return self._actual_provider
    # Lazy instantiation
    for name, provider_cls in self._PROVIDERS.items():
        if is_provider_enabled(name):
            try:
                self._actual_provider = provider_cls()
                return self._actual_provider
            except Exception:
                continue
    return None
```
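The unguarded check-then-assign above can construct a provider more than once if several threads make their first request at the same time. Whether that matters depends on how expensive and side-effect-free construction is. If it does matter, a double-checked lock is one option. The sketch below uses a hypothetical holder class.

```python
import threading


class LazyProviderHolder:
    """Hypothetical holder: builds the provider at most once, even under concurrency."""

    def __init__(self, factory):
        self._factory = factory  # callable returning a provider or None
        self._provider = None
        self._lock = threading.Lock()

    def get(self):
        if self._provider is not None:  # fast path: no locking once initialized
            return self._provider
        with self._lock:
            if self._provider is None:  # re-check: another thread may have won
                self._provider = self._factory()
            # If the factory returned None, the next call will try again
            return self._provider
```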
Connection Pooling
Provider implementations reuse connections when possible:
```python
import hvac

class VaultAuthProvider:
    def __init__(self):
        self._vault_client = None  # Cached client instance

    def _get_vault_client(self):
        if self._vault_client is not None:
            return self._vault_client
        # Create the client once and reuse it for later fetches
        vault_addr = get_env_var("VAULT_ADDR", "")
        self._vault_client = hvac.Client(url=vault_addr)
        return self._vault_client
```
Async Considerations
The current implementation is synchronous, but the architecture leaves room for async operations in the future:
```python
# Future async support (hypothetical, not implemented)
async def _fetch_secret_async(self) -> str | None:
    async with self._get_async_client() as client:
        response = await client.get_secret(self.secret_path)
        return response.value
```
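Until providers have native async clients, an intermediate step could wrap the existing synchronous fetch so it does not block an event loop. The sketch below uses asyncio.to_thread (Python 3.9+). AsyncSecretFetcher is a hypothetical name, and the wrapper is not part of the current plugin.

```python
from __future__ import annotations

import asyncio


class AsyncSecretFetcher:
    """Hypothetical async facade over a synchronous provider."""

    def __init__(self, provider):
        self._provider = provider

    async def fetch_secret(self) -> str | None:
        # Run the blocking _fetch_secret in a worker thread so the loop stays free
        return await asyncio.to_thread(self._provider._fetch_secret)
```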
Integration Points
MLflow Integration
The plugin registers itself with MLflow through the mlflow.request_auth_provider entry point group:
```toml
# Entry point registration in pyproject.toml
[tool.poetry.plugins."mlflow.request_auth_provider"]
mlflow_secrets_auth = "mlflow_secrets_auth:SecretsAuthProviderFactory"
```
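Monitoring Integration
The architecture supports monitoring and observability, for example through per-provider cache-hit counts and secret-fetch latencies:
```python
from collections import defaultdict

# Metrics collection
class MetricsCollector:
    def __init__(self):
        self.cache_hits = defaultdict(int)        # provider -> hit count
        self.fetch_durations = defaultdict(list)  # provider -> durations (s)

    def record_cache_hit(self, provider: str):
        self.cache_hits[provider] += 1

    def record_secret_fetch_duration(self, provider: str, duration: float):
        self.fetch_durations[provider].append(duration)
```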
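Timing each fetch by hand is easy to get wrong on error paths. The sketch below adds a context manager that always records the duration, even when the fetch raises. It uses the same MetricsCollector shape, and time_fetch is a hypothetical addition.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class MetricsCollector:
    """Per-provider counters and latency samples (in-memory, illustrative)."""

    def __init__(self):
        self.cache_hits = defaultdict(int)
        self.fetch_durations = defaultdict(list)

    def record_cache_hit(self, provider: str):
        self.cache_hits[provider] += 1

    def record_secret_fetch_duration(self, provider: str, duration: float):
        self.fetch_durations[provider].append(duration)

    @contextmanager
    def time_fetch(self, provider: str):
        """Record how long the enclosed block took, even if it raises."""
        start = time.perf_counter()
        try:
            yield
        finally:
            self.record_secret_fetch_duration(provider, time.perf_counter() - start)
```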
Next Steps
- Security Concepts - Detailed security model and threat analysis
- Caching and Retries - Deep dive into caching and retry mechanisms
- Provider Documentation - Provider-specific implementation details
- Configuration Reference - Complete configuration options