utils

mlflow_assistant.utils

Utility modules for MLflow Assistant.

config

Configuration management utilities for MLflow Assistant.

This module provides functions for loading, saving, and accessing configuration settings for MLflow Assistant, including MLflow URI and AI provider settings. Configuration is stored in YAML format in the user's home directory.
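
The functions below read and write a single YAML file. As an illustration only, an on-disk layout compatible with the accessors in this module might look like the sketch below; the real file location and key names come from the CONFIG_DIR, CONFIG_FILE, and CONFIG_KEY_* constants, so the literal keys here are assumptions.

import yaml

# Assumed layout; the real key names are defined by the CONFIG_KEY_* constants.
sample = """
mlflow_uri: http://localhost:5000
provider:
  type: openai
  model: gpt-3.5-turbo
"""
config = yaml.safe_load(sample)
print(config["provider"]["type"])  # -> "openai"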

ensure_config_dir()

Ensure the configuration directory exists.

Source code in src/mlflow_assistant/utils/config.py
def ensure_config_dir():
    """Ensure the configuration directory exists."""
    if not CONFIG_DIR.exists():
        CONFIG_DIR.mkdir(parents=True)
        logger.info(f"Created configuration directory at {CONFIG_DIR}")

get_mlflow_uri()

Get the MLflow URI from config or environment.

Returns:

str | None: The MLflow URI, or None if not configured.

Source code in src/mlflow_assistant/utils/config.py
def get_mlflow_uri() -> str | None:
    """Get the MLflow URI from config or environment.

    Returns:
        Optional[str]: The MLflow URI or None if not configured

    """
    # Environment variable should take precedence
    if mlflow_uri_env := os.environ.get(MLFLOW_URI_ENV):
        return mlflow_uri_env

    # Fall back to config
    config = load_config()
    return config.get(CONFIG_KEY_MLFLOW_URI)
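
A hedged usage sketch: the environment variable consulted is whatever name MLFLOW_URI_ENV holds ("MLFLOW_URI" below is an assumption), and setting it overrides the value stored in the config file.

import os
from mlflow_assistant.utils.config import get_mlflow_uri

# Assumed variable name; the real one is defined by MLFLOW_URI_ENV.
os.environ["MLFLOW_URI"] = "http://tracking.internal:5000"
print(get_mlflow_uri())  # the environment value wins over the config file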

get_provider_config()

Get the AI provider configuration.

Returns:

dict[str, Any]: The provider configuration.

Source code in src/mlflow_assistant/utils/config.py
def get_provider_config() -> dict[str, Any]:
    """Get the AI provider configuration.

    Returns:
        Dict[str, Any]: The provider configuration

    """
    config = load_config()
    provider = config.get(CONFIG_KEY_PROVIDER, {})

    provider_type = provider.get(CONFIG_KEY_TYPE)

    if provider_type == Provider.OPENAI.value:
        # Environment variable should take precedence
        api_key = (os.environ.get(OPENAI_API_KEY_ENV) or
                   provider.get(CONFIG_KEY_API_KEY))
        return {
            CONFIG_KEY_TYPE: Provider.OPENAI.value,
            CONFIG_KEY_API_KEY: api_key,
            CONFIG_KEY_MODEL: provider.get(
                CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OPENAI),
            ),
        }

    if provider_type == Provider.OLLAMA.value:
        return {
            CONFIG_KEY_TYPE: Provider.OLLAMA.value,
            CONFIG_KEY_URI: provider.get(CONFIG_KEY_URI, DEFAULT_OLLAMA_URI),
            CONFIG_KEY_MODEL: provider.get(
                CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OLLAMA),
            ),
        }

    if provider_type == Provider.DATABRICKS.value:
        # Set environment variables for Databricks profile
        _set_environment_variables(provider.get(CONFIG_KEY_PROFILE))

        return {
            CONFIG_KEY_TYPE: Provider.DATABRICKS.value,
            CONFIG_KEY_PROFILE: provider.get(CONFIG_KEY_PROFILE),
            CONFIG_KEY_MODEL: provider.get(CONFIG_KEY_MODEL),
        }

    return {CONFIG_KEY_TYPE: None}
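
Because the function always returns a dict whose type key is None when nothing is configured, callers can dispatch on it. A minimal sketch, assuming CONFIG_KEY_TYPE and CONFIG_KEY_MODEL are the literal strings "type" and "model":

from mlflow_assistant.utils.config import get_provider_config

provider = get_provider_config()
if provider["type"] is None:  # "type" is assumed to be CONFIG_KEY_TYPE
    raise SystemExit("No AI provider configured")
print(f"Using {provider['type']} with model {provider.get('model')}")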

load_config()

Load configuration from file.

Source code in src/mlflow_assistant/utils/config.py
def load_config() -> dict[str, Any]:
    """Load configuration from file."""
    if not CONFIG_FILE.exists():
        logger.info(f"No configuration file found at {CONFIG_FILE}")
        return {}

    try:
        with open(CONFIG_FILE) as f:
            config = yaml.safe_load(f) or {}
            logger.debug(f"Loaded configuration: {config}")
            return config
    except Exception as e:
        logger.error(f"Error loading configuration: {e}")
        return {}

save_config(config)

Save configuration to file.

Parameters:

config (dict[str, Any], required): Configuration dictionary to save.

Returns:

bool: True if successful, False otherwise.

Source code in src/mlflow_assistant/utils/config.py
def save_config(config: dict[str, Any]) -> bool:
    """Save configuration to file.

    Args:
        config: Configuration dictionary to save

    Returns:
        bool: True if successful, False otherwise

    """
    ensure_config_dir()

    try:
        with open(CONFIG_FILE, "w") as f:
            yaml.dump(config, f)
        logger.info(f"Configuration saved to {CONFIG_FILE}")
        return True
    except Exception as e:
        logger.error(f"Error saving configuration: {e}")
        return False
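
A round-trip sketch: save_config() returns True on success, and load_config() should read the same dict back (YAML-serializable values only; the key name below is assumed).

from mlflow_assistant.utils.config import load_config, save_config

original = {"mlflow_uri": "http://localhost:5000"}  # assumed key name
assert save_config(original)       # True if the YAML file was written
assert load_config() == original   # the saved dict reads back unchanged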

constants

Constants and enumerations for MLflow Assistant.

This module defines configuration keys, default values, API endpoints, model definitions, and other constants used throughout MLflow Assistant. It includes enums for AI providers (OpenAI, Ollama, Databricks) and their supported models.

Command

Bases: Enum

Special commands for interactive chat sessions.

description property

Get the description for a command.
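
Since Command is an Enum and every member carries a description, the full command list can be rendered without knowing the member names. A minimal sketch:

from mlflow_assistant.utils.constants import Command

for cmd in Command:
    print(f"{cmd.value}: {cmd.description}")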

DatabricksModel

Bases: Enum

Databricks models supported by MLflow Assistant.

choices() classmethod

Get all available Databricks model choices.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def choices(cls):
    """Get all available Databricks model choices."""
    return [model.value for model in cls]

OllamaModel

Bases: Enum

Default Ollama models supported by MLflow Assistant.

choices() classmethod

Get all available Ollama model choices.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def choices(cls):
    """Get all available Ollama model choices."""
    return [model.value for model in cls]

OpenAIModel

Bases: Enum

OpenAI models supported by MLflow Assistant.

choices() classmethod

Get all available OpenAI model choices.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def choices(cls):
    """Get all available OpenAI model choices."""
    return [model.value for model in cls]

Provider

Bases: Enum

AI provider types supported by MLflow Assistant.

get_default_model(provider) classmethod

Get the default model for a provider.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def get_default_model(cls, provider):
    """Get the default model for a provider."""
    defaults = {
        cls.OPENAI: OpenAIModel.GPT35.value,
        cls.OLLAMA: OllamaModel.LLAMA32.value,
        cls.DATABRICKS: DatabricksModel.DATABRICKS_META_LLAMA3.value,
    }
    return defaults.get(provider)

get_default_temperature(provider) classmethod

Get the default temperature for a provider.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def get_default_temperature(cls, provider):
    """Get the default temperature for a provider."""
    defaults = {
        cls.OPENAI: 0.7,
        cls.DATABRICKS: 0.7,
        cls.OLLAMA: 0.7,
    }
    return defaults.get(provider)
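
A short sketch of these helpers together; Provider.OPENAI and OpenAIModel come from the code above, and the printed values depend on the enum definitions:

from mlflow_assistant.utils.constants import OpenAIModel, Provider

print(OpenAIModel.choices())                              # all OpenAI model ids
print(Provider.get_default_model(Provider.OPENAI))        # the GPT35 value
print(Provider.get_default_temperature(Provider.OPENAI))  # 0.7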

definitions

Constants and definitions for the MLflow Assistant.

MLflowConnectionConfig(tracking_uri) dataclass

Configuration for MLflow connection.

connection_type property

Return the connection type (local or remote).
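
A hedged usage sketch for the dataclass; only the tracking_uri field and the connection_type property are documented here, and classifying a URI as local or remote is assumed to be the property's job.

from mlflow_assistant.utils.definitions import MLflowConnectionConfig

config = MLflowConnectionConfig(tracking_uri="http://localhost:5000")
print(config.connection_type)  # "local" or "remote", per the property's logic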

exceptions

Custom exceptions for the MLflow Assistant.

MLflowConnectionError

Bases: Exception

Exception raised when there is an issue connecting to the MLflow Tracking Server.
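
Because it subclasses Exception, it can be raised and caught like any built-in exception. A minimal sketch:

from mlflow_assistant.utils.exceptions import MLflowConnectionError

try:
    raise MLflowConnectionError("tracking server at http://localhost:5000 is unreachable")
except MLflowConnectionError as e:
    print(f"Could not reach the MLflow Tracking Server: {e}")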