Skip to content

connection

mlflow_assistant.core.connection

MLflow connection module for handling connections to MLflow Tracking Server.

This module provides functionality to connect to both local and remote MLflow Tracking Servers using environment variables or direct configuration.

MLflowConnection(tracking_uri=None, client_factory=None)

MLflow connection class to handle connections to MLflow Tracking Server.

This class provides functionality to connect to both local and remote MLflow Tracking Servers.

Initialize MLflow connection.

Parameters:

Name Type Description Default
tracking_uri str | None

URI of the MLflow Tracking Server. If None, will try to get from environment.

None
client_factory Any

A callable to create the MlflowClient instance. Defaults to MlflowClient.

None
Source code in src/mlflow_assistant/core/connection.py
def __init__(self, tracking_uri: str | None = None, client_factory: Any = None):
    """Initialize MLflow connection.

    Args:
        tracking_uri: URI of the MLflow Tracking Server. If None, will try to get from environment.
        client_factory: A callable to create the MlflowClient instance. Defaults to MlflowClient.

    """
    self.config = self._load_config(tracking_uri=tracking_uri)
    self.client = None
    self.is_connected_flag = False
    self.client_factory = client_factory or MlflowClient

connect()

Connect to MLflow Tracking Server.

Returns
Tuple[bool, str]: (success, message)
Source code in src/mlflow_assistant/core/connection.py
def connect(self) -> tuple[bool, str]:
    """Connect to MLflow Tracking Server.

    Returns
    -------
        Tuple[bool, str]: (success, message)

    """
    try:
        logger.debug(f"Connecting to MLflow Tracking Server at {self.config.tracking_uri}")
        mlflow.set_tracking_uri(self.config.tracking_uri)
        self.client = self.client_factory(tracking_uri=self.config.tracking_uri)
        self.client.search_experiments()  # Trigger connection attempt
        self.is_connected_flag = True
        logger.debug(f"Successfully connected to MLflow Tracking Server at {self.config.tracking_uri}")
        return True, f"Successfully connected to MLflow Tracking Server at {self.config.tracking_uri}"
    except Exception as e:
        self.is_connected_flag = False
        logger.exception(f"Failed to connect to MLflow Tracking Server: {e}")
        return False, f"Failed to connect to MLflow Tracking Server: {e!s}"

get_client()

Get MLflow client instance.

Returns
MlflowClient: MLflow client instance.
Raises
MLflowConnectionError: If not connected to MLflow Tracking Server.
Source code in src/mlflow_assistant/core/connection.py
def get_client(self) -> MlflowClient:
    """Get MLflow client instance.

    Returns
    -------
        MlflowClient: MLflow client instance.

    Raises
    ------
        MLflowConnectionError: If not connected to MLflow Tracking Server.

    """
    if self.client is None:
        msg = "Not connected to MLflow Tracking Server. Call connect() first."
        raise MLflowConnectionError(msg)
    return self.client

get_connection_info()

Get connection information.

Returns
Dict[str, Any]: Connection information.
Source code in src/mlflow_assistant/core/connection.py
def get_connection_info(self) -> dict[str, Any]:
    """Get connection information.

    Returns
    -------
        Dict[str, Any]: Connection information.

    """
    return {
        "tracking_uri": self.config.tracking_uri,
        "connection_type": self.config.connection_type,
        "is_connected": self.is_connected_flag,
    }

is_connected()

Check if connected to MLflow Tracking Server.

Returns
bool: True if connected, False otherwise.
Source code in src/mlflow_assistant/core/connection.py
def is_connected(self) -> bool:
    """Check if connected to MLflow Tracking Server.

    Returns
    -------
        bool: True if connected, False otherwise.

    """
    return self.is_connected_flag