mlflow_assistant

MLflow Assistant: Interact with MLflow using LLMs.

cli

CLI modules for MLflow Assistant.

commands

CLI commands for MLflow Assistant.

This module contains the main CLI commands for interacting with MLflow using natural language queries through various AI providers.

cli(verbose)

MLflow Assistant: Interact with MLflow using LLMs.

This CLI tool helps you to interact with MLflow using natural language.

Source code in src/mlflow_assistant/cli/commands.py
@click.group()
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging")
def cli(verbose):
    """MLflow Assistant: Interact with MLflow using LLMs.

    This CLI tool helps you to interact with MLflow using natural language.
    """
    # Configure logging
    log_level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(level=log_level, format=LOG_FORMAT)
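
For reference, a minimal sketch of exercising the command group programmatically with Click's test runner; the subcommand name comes from the commands documented below, and the import path follows the source location shown above.

from click.testing import CliRunner

from mlflow_assistant.cli.commands import cli

runner = CliRunner()
# Group-level flags such as --verbose go before the subcommand name.
result = runner.invoke(cli, ["--verbose", "version"])
print(result.output)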
mock_process_query(query, provider_config, verbose=False)

Mock function that simulates the query processing workflow.

This will be replaced with the actual implementation later.

Parameters:

    query (str): The user's query. Required.
    provider_config (dict[str, Any]): The AI provider configuration. Required.
    verbose (bool): Whether to show verbose output. Default: False.

Returns:

    dict[str, Any]: Dictionary with mock response information.

Source code in src/mlflow_assistant/cli/commands.py
def mock_process_query(
    query: str, provider_config: dict[str, Any], verbose: bool = False,
) -> dict[str, Any]:
    """Mock function that simulates the query processing workflow.

    This will be replaced with the actual implementation later.

    Args:
        query: The user's query
        provider_config: The AI provider configuration
        verbose: Whether to show verbose output

    Returns:
        Dictionary with mock response information

    """
    # Create a mock response
    provider_type = provider_config.get(
        CONFIG_KEY_TYPE, DEFAULT_STATUS_NOT_CONFIGURED,
    )
    model = provider_config.get(
        CONFIG_KEY_MODEL, DEFAULT_STATUS_NOT_CONFIGURED,
    )

    response_text = (
        f"This is a mock response to: '{query}'\n\n"
        f"The MLflow integration will be implemented soon!"
    )

    if verbose:
        response_text += f"\n\nDebug: Using {provider_type} with {model}"

    return {
        "original_query": query,
        "provider_config": {
            CONFIG_KEY_TYPE: provider_type,
            CONFIG_KEY_MODEL: model,
        },
        "enhanced": False,
        "response": response_text,
    }
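
A quick sketch of calling the mock directly; the dictionary keys below are illustrative stand-ins for the CONFIG_KEY_TYPE and CONFIG_KEY_MODEL constants.

# Illustrative config; assumed key names "type" and "model".
provider_config = {"type": "openai", "model": "gpt-4o-mini"}

result = mock_process_query("What are my best runs?", provider_config, verbose=True)
print(result["response"])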
setup()

Run the interactive setup wizard.

This wizard helps you configure MLflow Assistant.

Source code in src/mlflow_assistant/cli/commands.py
@cli.command()
def setup():
    """Run the interactive setup wizard.

    This wizard helps you configure MLflow Assistant.
    """
    setup_wizard()
start(verbose)

Start an interactive chat session with MLflow Assistant.

This opens an interactive chat session where you can ask questions about your MLflow experiments, models, and data. Type /bye to exit the session.

Examples of questions you can ask:

- What are my best performing models for classification?
- Show me details of experiment 'customer_churn'
- Compare runs abc123 and def456
- Which hyperparameters should I try next for my regression model?

Commands:

- /bye: Exit the chat session
- /help: Show help about available commands
- /clear: Clear the screen

Source code in src/mlflow_assistant/cli/commands.py
@cli.command()
@click.option("--verbose", "-v", is_flag=True, help="Show verbose output")
def start(verbose):
    """Start an interactive chat session with MLflow Assistant.

    This opens an interactive chat session where you can ask questions about
    your MLflow experiments, models, and data. Type /bye to exit the session.

    Examples of questions you can ask:
    - What are my best performing models for classification?
    - Show me details of experiment 'customer_churn'
    - Compare runs abc123 and def456
    - Which hyperparameters should I try next for my regression model?

    Commands:
    - /bye: Exit the chat session
    - /help: Show help about available commands
    - /clear: Clear the screen
    """
    # Use validation function to check setup
    is_valid, error_message = validate_setup()
    if not is_valid:
        click.echo(f"❌ Error: {error_message}")
        return

    # Get provider config
    provider_config = get_provider_config()

    # Print welcome message and instructions
    provider_type = provider_config.get(
        CONFIG_KEY_TYPE, DEFAULT_STATUS_NOT_CONFIGURED,
        )
    model = provider_config.get(
        CONFIG_KEY_MODEL, DEFAULT_STATUS_NOT_CONFIGURED,
        )

    click.echo("\n🤖 MLflow Assistant Chat Session")
    click.echo(f"Connected to MLflow at: {get_mlflow_uri()}")
    click.echo(f"Using {provider_type.upper()} with model: {model}")
    click.echo("\nType your questions and press Enter.")
    click.echo(f"Type {Command.EXIT.value} to exit.")
    click.echo("=" * 70)

    # Start interactive loop
    while True:
        # Get user input with a prompt
        try:
            query = click.prompt("\n🧑", prompt_suffix="").strip()
        except (KeyboardInterrupt, EOFError):
            click.echo("\nExiting chat session...")
            break

        # Handle special commands
        action = _handle_special_commands(query)
        if action == "exit":
            break
        if action == "continue":
            continue

        # Process the query
        asyncio.run(_process_user_query(query, provider_config, verbose))
version()

Show MLflow Assistant version information.

Source code in src/mlflow_assistant/cli/commands.py
@cli.command()
def version():
    """Show MLflow Assistant version information."""
    from mlflow_assistant import __version__

    click.echo(f"MLflow Assistant version: {__version__}")

    # Show configuration
    config = load_config()
    mlflow_uri = config.get(
        CONFIG_KEY_MLFLOW_URI, DEFAULT_STATUS_NOT_CONFIGURED,
        )
    provider = config.get(CONFIG_KEY_PROVIDER, {}).get(
        CONFIG_KEY_TYPE, DEFAULT_STATUS_NOT_CONFIGURED,
    )
    model = config.get(CONFIG_KEY_PROVIDER, {}).get(
        CONFIG_KEY_MODEL, DEFAULT_STATUS_NOT_CONFIGURED,
    )

    click.echo(f"MLflow URI: {mlflow_uri}")
    click.echo(f"Provider: {provider}")
    click.echo(f"Model: {model}")

setup

Setup wizard for MLflow Assistant configuration.

This module provides an interactive setup wizard that guides users through configuring MLflow Assistant, including MLflow connection settings and AI provider configuration (OpenAI or Ollama).

setup_wizard()

Interactive setup wizard for mlflow-assistant.

Source code in src/mlflow_assistant/cli/setup.py
def setup_wizard():
    """Interactive setup wizard for mlflow-assistant."""
    click.echo("┌──────────────────────────────────────────────────────┐")
    click.echo("│             MLflow Assistant Setup Wizard            │")
    click.echo("└──────────────────────────────────────────────────────┘")

    click.echo("\nThis wizard will help you configure MLflow Assistant.")

    # Initialize config
    config = load_config()
    previous_provider = config.get(
        CONFIG_KEY_PROVIDER, {}).get(CONFIG_KEY_TYPE)

    # MLflow URI
    mlflow_uri = click.prompt(
        "Enter your MLflow URI",
        default=config.get(CONFIG_KEY_MLFLOW_URI, DEFAULT_MLFLOW_URI),
    )

    if not validate_mlflow_uri(mlflow_uri):
        click.echo("\n⚠️  Warning: Could not connect to MLflow URI.")
        click.echo(
            "    Please ensure MLflow is running.",
        )
        click.echo(
            "    Common MLflow URLs: http://localhost:5000, "
            "http://localhost:8080",
        )
        if not click.confirm(
            "Continue anyway? (Choose Yes if you're sure MLflow is running)",
        ):
            click.echo(
                "Setup aborted. "
                "Please ensure MLflow is running and try again.")
            return
        click.echo("Continuing with setup using the provided MLflow URI.")
    else:
        click.echo("✅ Successfully connected to MLflow!")

    config[CONFIG_KEY_MLFLOW_URI] = mlflow_uri

    # AI Provider
    provider_options = [p.value.capitalize() for p in Provider]
    provider_choice = click.prompt(
        "\nWhich AI provider would you like to use?",
        type=click.Choice(provider_options, case_sensitive=False),
        default=config.get(CONFIG_KEY_PROVIDER, {})
        .get(CONFIG_KEY_TYPE, Provider.OPENAI.value)
        .capitalize(),
    )

    current_provider_type = provider_choice.lower()
    provider_config = {}

    # Check if provider is changing and handle default models
    provider_changed = (previous_provider and
                        previous_provider != current_provider_type)

    if current_provider_type == Provider.OPENAI.value:
        # If switching from another provider, show a message
        if provider_changed:
            click.echo("\n✅ Switching to OpenAI provider")

        # Initialize provider config
        provider_config = {
            CONFIG_KEY_TYPE: Provider.OPENAI.value,
            CONFIG_KEY_MODEL: Provider.get_default_model(
                Provider.OPENAI,
            ),  # Will be updated after user selection
        }

        # Check for OpenAI API key
        api_key = os.environ.get(OPENAI_API_KEY_ENV)
        if not api_key:
            click.echo(
                "\n⚠️  OpenAI API key not found in environment variables.",
            )
            click.echo(
                f"Please export your OpenAI API key as {OPENAI_API_KEY_ENV}.",
            )
            click.echo(f"Example: export {OPENAI_API_KEY_ENV}='your-key-here'")
            if not click.confirm("Continue without API key?"):
                click.echo(
                    "Setup aborted. Please set the API key and try again.",
                )
                return
        else:
            click.echo("✅ Found OpenAI API key in environment!")

        # Always ask for model choice
        model_choices = OpenAIModel.choices()

        # If changing providers, suggest the default,
        # otherwise use previous config
        if provider_changed:
            suggested_model = Provider.get_default_model(Provider.OPENAI)
        else:
            current_model = config.get(CONFIG_KEY_PROVIDER, {}).get(
                CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OPENAI),
            )
            suggested_model = (
                current_model
                if current_model in model_choices
                else Provider.get_default_model(Provider.OPENAI)
            )

        model = click.prompt(
            "Choose an OpenAI model",
            type=click.Choice(model_choices, case_sensitive=False),
            default=suggested_model,
        )
        provider_config[CONFIG_KEY_MODEL] = model

    elif current_provider_type == Provider.OLLAMA.value:
        # If switching from another provider, automatically set defaults
        if provider_changed:
            click.echo(
                "\n✅ Switching to Ollama provider with default URI and model",
            )

        # Ollama configuration - always ask for URI
        ollama_uri = click.prompt(
            "\nEnter your Ollama server URI",
            default=config.get(CONFIG_KEY_PROVIDER, {}).get(
                CONFIG_KEY_URI, DEFAULT_OLLAMA_URI,
            ),
        )

        # Initialize provider config with default model and user-specified URI
        provider_config = {
            CONFIG_KEY_TYPE: Provider.OLLAMA.value,
            CONFIG_KEY_URI: ollama_uri,
            CONFIG_KEY_MODEL: Provider.get_default_model(
                Provider.OLLAMA,
            ),  # Will be updated if user selects a different model
        }

        # Check if Ollama is running
        is_connected, ollama_data = validate_ollama_connection(ollama_uri)
        if is_connected:
            click.echo("✅ Ollama server is running!")

            # Get available models
            available_models = ollama_data.get("models", [])

            if available_models:
                click.echo(
                    f"\nAvailable Ollama models: {', '.join(available_models)}",
                )

                # If changing providers, suggest the default,
                # otherwise use previous config
                default_model = Provider.get_default_model(Provider.OLLAMA)
                if provider_changed:
                    suggested_model = (
                        default_model
                        if default_model in available_models
                        else available_models[0]
                    )
                else:
                    current_model = config.get(CONFIG_KEY_PROVIDER, {}).get(
                        CONFIG_KEY_MODEL,
                    )
                    suggested_model = (
                        current_model
                        if current_model in available_models
                        else default_model
                    )

                ollama_model = click.prompt(
                    "Choose an Ollama model",
                    type=click.Choice(available_models, case_sensitive=True),
                    default=suggested_model,
                )
                provider_config[CONFIG_KEY_MODEL] = ollama_model
            else:
                click.echo("\nNo models found. Using default model.")
                ollama_model = click.prompt(
                    "Enter the Ollama model to use",
                    default=config.get(CONFIG_KEY_PROVIDER, {}).get(
                        CONFIG_KEY_MODEL, Provider.get_default_model(
                            Provider.OLLAMA,
                        ),
                    ),
                )
                provider_config[CONFIG_KEY_MODEL] = ollama_model
        else:
            click.echo(
                "\n⚠️  Warning: Ollama server not running or"
                " not accessible at this URI.",
            )
            if not click.confirm("Continue anyway?"):
                click.echo(
                    "Setup aborted. Please start Ollama server and try again.",
                )
                return

            # Still prompt for model name
            ollama_model = click.prompt(
                "Enter the Ollama model to use",
                default=config.get(CONFIG_KEY_PROVIDER, {}).get(
                    CONFIG_KEY_MODEL, Provider.get_default_model(
                        Provider.OLLAMA,
                    ),
                ),
            )
            provider_config[CONFIG_KEY_MODEL] = ollama_model

    elif current_provider_type == Provider.DATABRICKS.value:
        config_path = Path(DEFAULT_DATABRICKS_CONFIG_FILE).expanduser()
        # Verify Databricks configuration file path
        click.echo(f"Checking Databricks configuration file at: {config_path}")
        if not os.path.isfile(config_path):
            # File does not exist, abort and ask the user to create it
            click.echo(
                    "Setup aborted. Please setup Databricks config file and try again.",
                )
            return

        # Get Databricks configuration file
        config_string = Path(config_path).read_text()

        # Get profiles from the Databricks configuration file
        # Parse the config string
        databricks_config = configparser.ConfigParser()
        databricks_config.read_string(config_string)

        # Manually include DEFAULT section
        all_sections = ['DEFAULT', *databricks_config.sections()]

        profile_options = [section for section in all_sections if 'token' in databricks_config[section]]

        if not profile_options:
            click.echo(
                "\n⚠️  No valid profiles found in Databricks configuration file.",
            )
            click.echo(
                "Please ensure your Databricks config file contains a profile with a 'token'.",
            )
            click.echo(
                "Setup aborted. Please fix the configuration and try again.",
            )
            return

        profile = click.prompt(
            "\nWhich databricks profile would you like to use?",
            type=click.Choice(profile_options, case_sensitive=False),
            default=profile_options[0],
        )

        # Prompt for model name
        databricks_model = click.prompt(
            "Enter the Databricks model to use",
        )

        provider_config = {
            CONFIG_KEY_TYPE: Provider.DATABRICKS.value,
            CONFIG_KEY_PROFILE: profile,
            CONFIG_KEY_MODEL: databricks_model,
        }

    config[CONFIG_KEY_PROVIDER] = provider_config

    # Save the configuration
    save_config(config)

    click.echo("\n✅ Configuration saved successfully!")
    click.echo("\n┌──────────────────────────────────────────────────┐")
    click.echo("│               Getting Started                    │")
    click.echo("└──────────────────────────────────────────────────┘")
    click.echo(
        "\nYou can now use MLflow Assistant with the following commands:")
    click.echo(
        "  mlflow-assistant start     - Start an interactive chat "
        "session.",
    )
    click.echo(
        "  mlflow-assistant version   - Show version "
        "information.",
    )

    click.echo("\nFor more information, use 'mlflow-assistant --help'")

validation

Validation utilities for MLflow Assistant configuration.

This module provides validation functions to check MLflow connections, AI provider configurations, and overall system setup to ensure proper operation of MLflow Assistant.

validate_mlflow_uri(uri)

Validate MLflow URI by attempting to connect.

Parameters:

    uri (str): MLflow server URI. Required.

Returns:

    bool: True if connection successful, False otherwise.

Source code in src/mlflow_assistant/cli/validation.py
def validate_mlflow_uri(uri: str) -> bool:
    """Validate MLflow URI by attempting to connect.

    Args:
        uri: MLflow server URI

    Returns:
        bool: True if connection successful, False otherwise

    """
    for endpoint in MLFLOW_VALIDATION_ENDPOINTS:
        try:
            # Try with trailing slash trimmed
            clean_uri = uri.rstrip("/")
            url = f"{clean_uri}{endpoint}"
            logger.debug(f"Trying to connect to MLflow at: {url}")

            response = requests.get(url, timeout=MLFLOW_CONNECTION_TIMEOUT)
            if response.status_code == 200:
                logger.info(f"Successfully connected to MLflow at {url}")
                return True
            logger.debug(f"Response from {url}: {response.status_code}")
        except Exception as e:
            logger.debug(f"Failed to connect to {endpoint}: {e!s}")

    # If we get here, none of the endpoints worked
    logger.warning(
        f"Could not validate MLflow at {uri} on any standard endpoint",
    )
    return False
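
A minimal usage sketch, assuming an MLflow server on the common local port:

from mlflow_assistant.cli.validation import validate_mlflow_uri

if validate_mlflow_uri("http://localhost:5000"):
    print("MLflow is reachable")
else:
    print("Could not reach MLflow on any standard endpoint")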
validate_ollama_connection(uri)

Validate Ollama connection and get available models.

Parameters:

    uri (str): Ollama server URI. Required.

Returns:

    tuple[bool, dict[str, Any]]: (is_valid, response_data)

Source code in src/mlflow_assistant/cli/validation.py
def validate_ollama_connection(uri: str) -> tuple[bool, dict[str, Any]]:
    """Validate Ollama connection and get available models.

    Args:
        uri: Ollama server URI

    Returns:
        Tuple[bool, Dict[str, Any]]: (is_valid, response_data)

    """
    try:
        response = requests.get(
            f"{uri}{OLLAMA_TAGS_ENDPOINT}", timeout=OLLAMA_CONNECTION_TIMEOUT,
        )
        if response.status_code == 200:
            try:
                models_data = response.json()
                available_models = [
                    m.get("name") for m in models_data.get("models", [])
                ]
                return True, {"models": available_models}
            except Exception as e:
                logger.debug(f"Error parsing Ollama models: {e}")
                return True, {"models": []}
        else:
            return False, {}
    except Exception as e:
        logger.debug(f"Error connecting to Ollama: {e}")
        return False, {}
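
A short sketch of checking a local Ollama server; the URI is the usual Ollama default and is assumed here.

from mlflow_assistant.cli.validation import validate_ollama_connection

is_running, data = validate_ollama_connection("http://localhost:11434")
if is_running:
    print("Available Ollama models:", data.get("models", []))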
validate_setup(check_api_key=True)

Validate that MLflow Assistant is properly configured.

Parameters:

    check_api_key (bool): Whether to check for API key if using OpenAI. Default: True.

Returns:

    tuple[bool, str]: (is_valid, error_message)

Source code in src/mlflow_assistant/cli/validation.py
def validate_setup(check_api_key: bool = True) -> tuple[bool, str]:
    """Validate that MLflow Assistant is properly configured.

    Args:
        check_api_key: Whether to check for API key if using OpenAI

    Returns:
        Tuple[bool, str]: (is_valid, error_message)

    """
    # Check MLflow URI
    mlflow_uri = get_mlflow_uri()
    if not mlflow_uri:
        return (
            False,
            "MLflow URI not configured. "
            "Run 'mlflow-assistant setup' first.",
        )

    # Get provider config
    provider_config = get_provider_config()
    if not provider_config or not provider_config.get(CONFIG_KEY_TYPE):
        return (
            False,
            "AI provider not configured. "
            "Run 'mlflow-assistant setup' first.",
        )

    # Ensure OpenAI has an API key if that's the configured provider
    if (
        check_api_key
        and provider_config.get(CONFIG_KEY_TYPE) == Provider.OPENAI.value
        and not provider_config.get(CONFIG_KEY_API_KEY)
    ):
        return (
            False,
            f"OpenAI API key not found in environment. "
            f"Set {OPENAI_API_KEY_ENV}.",
        )

    return True, ""
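
For example, a caller can gate any command on this check:

from mlflow_assistant.cli.validation import validate_setup

is_valid, error_message = validate_setup(check_api_key=True)
if not is_valid:
    raise SystemExit(f"Configuration problem: {error_message}")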

core

Core functionality for MLflow Assistant.

This subpackage contains the core modules for managing connections, workflows, and interactions with the MLflow Tracking Server.

cli

Command-line interface (CLI) for MLflow Assistant.

This module provides the CLI entry points for interacting with the MLflow Assistant, allowing users to manage connections, workflows, and other operations via the command line.

connection

MLflow connection module for handling connections to MLflow Tracking Server.

This module provides functionality to connect to both local and remote MLflow Tracking Servers using environment variables or direct configuration.

MLflowConnection(tracking_uri=None, client_factory=None)

MLflow connection class to handle connections to MLflow Tracking Server.

This class provides functionality to connect to both local and remote MLflow Tracking Servers.

Initialize MLflow connection.

Parameters:

    tracking_uri (str | None): URI of the MLflow Tracking Server. If None, will try to get from environment. Default: None.
    client_factory (Any): A callable to create the MlflowClient instance. Defaults to MlflowClient. Default: None.

Source code in src/mlflow_assistant/core/connection.py
def __init__(self, tracking_uri: str | None = None, client_factory: Any = None):
    """Initialize MLflow connection.

    Args:
        tracking_uri: URI of the MLflow Tracking Server. If None, will try to get from environment.
        client_factory: A callable to create the MlflowClient instance. Defaults to MlflowClient.

    """
    self.config = self._load_config(tracking_uri=tracking_uri)
    self.client = None
    self.is_connected_flag = False
    self.client_factory = client_factory or MlflowClient
connect()

Connect to MLflow Tracking Server.

Returns:

    Tuple[bool, str]: (success, message)
Source code in src/mlflow_assistant/core/connection.py
def connect(self) -> tuple[bool, str]:
    """Connect to MLflow Tracking Server.

    Returns
    -------
        Tuple[bool, str]: (success, message)

    """
    try:
        logger.debug(f"Connecting to MLflow Tracking Server at {self.config.tracking_uri}")
        mlflow.set_tracking_uri(self.config.tracking_uri)
        self.client = self.client_factory(tracking_uri=self.config.tracking_uri)
        self.client.search_experiments()  # Trigger connection attempt
        self.is_connected_flag = True
        logger.debug(f"Successfully connected to MLflow Tracking Server at {self.config.tracking_uri}")
        return True, f"Successfully connected to MLflow Tracking Server at {self.config.tracking_uri}"
    except Exception as e:
        self.is_connected_flag = False
        logger.exception(f"Failed to connect to MLflow Tracking Server: {e}")
        return False, f"Failed to connect to MLflow Tracking Server: {e!s}"
get_client()

Get MLflow client instance.

Returns:

    MlflowClient: MLflow client instance.

Raises:

    MLflowConnectionError: If not connected to MLflow Tracking Server.
Source code in src/mlflow_assistant/core/connection.py
def get_client(self) -> MlflowClient:
    """Get MLflow client instance.

    Returns
    -------
        MlflowClient: MLflow client instance.

    Raises
    ------
        MLflowConnectionError: If not connected to MLflow Tracking Server.

    """
    if self.client is None:
        msg = "Not connected to MLflow Tracking Server. Call connect() first."
        raise MLflowConnectionError(msg)
    return self.client
get_connection_info()

Get connection information.

Returns:

    Dict[str, Any]: Connection information.
Source code in src/mlflow_assistant/core/connection.py
def get_connection_info(self) -> dict[str, Any]:
    """Get connection information.

    Returns
    -------
        Dict[str, Any]: Connection information.

    """
    return {
        "tracking_uri": self.config.tracking_uri,
        "connection_type": self.config.connection_type,
        "is_connected": self.is_connected_flag,
    }
is_connected()

Check if connected to MLflow Tracking Server.

Returns:

    bool: True if connected, False otherwise.
Source code in src/mlflow_assistant/core/connection.py
def is_connected(self) -> bool:
    """Check if connected to MLflow Tracking Server.

    Returns
    -------
        bool: True if connected, False otherwise.

    """
    return self.is_connected_flag
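
A typical connection flow, sketched under the assumption of a local tracking server:

from mlflow_assistant.core.connection import MLflowConnection

conn = MLflowConnection(tracking_uri="http://localhost:5000")
ok, message = conn.connect()
print(message)
if ok:
    client = conn.get_client()          # MlflowClient, ready for tracking/registry calls
    print(conn.get_connection_info())   # tracking_uri, connection_type, is_connected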

core

Core utilities and functionality for MLflow Assistant.

This module provides foundational classes, functions, and utilities used across the MLflow Assistant project, including shared logic for managing workflows and interactions with the MLflow Tracking Server.

get_mlflow_client()

Initialize and return an MLflow client instance.

Returns:

    MlflowClient: An instance of the MLflow client.
Source code in src/mlflow_assistant/core/core.py
def get_mlflow_client():
    """Initialize and return an MLflow client instance.

    Returns
    -------
        MlflowClient: An instance of the MLflow client.

    """
    return MlflowClient()

provider

Provider integrations for MLflow Assistant.

This module defines the interfaces and implementations for integrating with various large language model (LLM) providers, such as OpenAI and Ollama.

workflow

Workflow management for LangGraph in MLflow Assistant.

This module provides functionality for defining, managing, and executing workflows using LangGraph, enabling seamless integration with MLflow for tracking and managing machine learning workflows.

engine

MLflow Assistant Engine - Provides workflow functionality to process user queries.

definitions

Constants for the MLflow Assistant engine.

processor

Query processor that leverages the workflow engine for processing user queries and generating responses using an AI provider.

process_query(query, provider_config, verbose=False) async

Process a query through the MLflow Assistant workflow.

Parameters:

    query (str): The query to process. Required.
    provider_config (dict[str, Any]): AI provider configuration. Required.
    verbose (bool): Whether to show verbose output. Default: False.

Returns:

    dict[str, Any]: Dict containing the response.

Source code in src/mlflow_assistant/engine/processor.py
async def process_query(
    query: str, provider_config: dict[str, Any], verbose: bool = False,
) -> dict[str, Any]:
    """Process a query through the MLflow Assistant workflow.

    Args:
        query: The query to process
        provider_config: AI provider configuration
        verbose: Whether to show verbose output

    Returns:
        Dict containing the response

    """
    import time

    from .workflow import create_workflow

    # Track start time for duration calculation
    start_time = time.time()

    try:
        # Create workflow
        workflow = create_workflow()

        # Run workflow with provider config
        initial_state = {
            STATE_KEY_MESSAGES: [HumanMessage(content=query)],
            STATE_KEY_PROVIDER_CONFIG: provider_config,
        }

        if verbose:
            logger.info(f"Running workflow with query: {query}")
            logger.info(f"Using provider: {provider_config.get(CONFIG_KEY_TYPE)}")
            logger.info(
                f"Using model: {provider_config.get(CONFIG_KEY_MODEL, 'default')}",
            )

        result = await workflow.ainvoke(initial_state)

        # Calculate duration
        duration = time.time() - start_time

        return {
            "original_query": query,
            "response": result.get(STATE_KEY_MESSAGES)[-1],
            "duration": duration,  # Add duration to response
        }

    except Exception as e:
        # Calculate duration even for errors
        duration = time.time() - start_time

        logger.error(f"Error processing query: {e}")

        return {
            "error": str(e),
            "original_query": query,
            "response": f"Error processing query: {e!s}",
        }
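
Since the function is async, callers run it through an event loop. A sketch, with the provider-config key names written as illustrative stand-ins for the CONFIG_KEY_* constants:

import asyncio

from mlflow_assistant.engine.processor import process_query

# Assumed key names and values, for illustration only.
provider_config = {"type": "ollama", "model": "llama3.2", "uri": "http://localhost:11434"}

result = asyncio.run(process_query("List my experiments", provider_config))
print(result["response"])                       # last message produced by the workflow
print(f"Completed in {result.get('duration', 0):.2f}s")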

tools

LangGraph tools for MLflow interactions.

MLflowTools

Collection of helper utilities for MLflow interactions.

format_timestamp(timestamp_ms) staticmethod

Convert a millisecond timestamp to a human-readable string.

Source code in src/mlflow_assistant/engine/tools.py
@staticmethod
def format_timestamp(timestamp_ms: int) -> str:
    """Convert a millisecond timestamp to a human-readable string."""
    if not timestamp_ms:
        return NA
    dt = datetime.fromtimestamp(timestamp_ms / 1000.0)
    return dt.strftime(TIME_FORMAT)
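
For example:

# 1700000000000 ms is 2023-11-14T22:13:20 UTC; the rendered string uses the
# local timezone and the module's TIME_FORMAT.
print(MLflowTools.format_timestamp(1_700_000_000_000))
print(MLflowTools.format_timestamp(0))  # falsy input returns the NA constant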
get_model_details(model_name)

Get detailed information about a specific registered model.

Parameters:

    model_name (str): The name of the registered model. Required.

Returns:

    str: A JSON string containing detailed information about the model.

Source code in src/mlflow_assistant/engine/tools.py
@tool
def get_model_details(model_name: str) -> str:
    """Get detailed information about a specific registered model.

    Args:
        model_name: The name of the registered model

    Returns:
        A JSON string containing detailed information about the model.

    """
    logger.debug(f"Fetching details for model: {model_name}")

    try:
        # Get the registered model
        model = client.get_registered_model(model_name)

        model_info = {
            "name": model.name,
            "creation_timestamp": MLflowTools.format_timestamp(
                model.creation_timestamp,
            ),
            "last_updated_timestamp": MLflowTools.format_timestamp(
                model.last_updated_timestamp,
            ),
            "description": model.description or "",
            "tags": {tag.key: tag.value for tag in model.tags}
            if hasattr(model, "tags")
            else {},
            "versions": [],
        }

        # Get all versions for this model
        versions = client.search_model_versions(f"name='{model_name}'")

        for version in versions:
            version_info = {
                "version": version.version,
                "status": version.status,
                "stage": version.current_stage,
                "creation_timestamp": MLflowTools.format_timestamp(
                    version.creation_timestamp,
                ),
                "source": version.source,
                "run_id": version.run_id,
            }

            # Get additional information about the run if available
            if version.run_id:
                try:
                    run = client.get_run(version.run_id)
                    # Extract only essential run information to avoid serialization issues
                    run_metrics = {}
                    for k, v in run.data.metrics.items():
                        try:
                            run_metrics[k] = float(v)
                        except ValueError:
                            run_metrics[k] = str(v)

                    version_info["run"] = {
                        "status": run.info.status,
                        "start_time": MLflowTools.format_timestamp(
                            run.info.start_time,
                        ),
                        "end_time": MLflowTools.format_timestamp(run.info.end_time)
                        if run.info.end_time
                        else None,
                        "metrics": run_metrics,
                    }
                except Exception as e:
                    logger.warning(
                        f"Error getting run details for {version.run_id}: {e!s}",
                    )
                    version_info["run"] = "Error retrieving run details"

            model_info["versions"].append(version_info)

        return json.dumps(model_info, indent=2)

    except Exception as e:
        error_msg = f"Error getting model details: {e!s}"
        logger.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})
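
Because this is a @tool-decorated function, it is a LangChain tool and is called via .invoke() with a dict of arguments rather than as a plain function. A sketch with a hypothetical model name:

import json

details_json = get_model_details.invoke({"model_name": "customer_churn_model"})  # hypothetical name
details = json.loads(details_json)
print(details.get("name"), "versions:", len(details.get("versions", [])))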
get_system_info()

Get information about the MLflow tracking server and system.

Returns:

    str: A JSON string containing system information.

Source code in src/mlflow_assistant/engine/tools.py
@tool
def get_system_info() -> str:
    """Get information about the MLflow tracking server and system.

    Returns:
        A JSON string containing system information.

    """
    logger.debug("Getting MLflow system information")

    try:
        info = {
            "mlflow_version": mlflow.__version__,
            "tracking_uri": mlflow.get_tracking_uri(),
            "registry_uri": mlflow.get_registry_uri(),
            "artifact_uri": mlflow.get_artifact_uri(),
            "python_version": sys.version,
            "server_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

        # Get experiment count
        try:
            experiments = client.search_experiments()
            info["experiment_count"] = len(experiments)
        except Exception as e:
            logger.warning(f"Error getting experiment count: {e!s}")
            info["experiment_count"] = "Error retrieving count"

        # Get model count
        try:
            models = client.search_registered_models()
            info["model_count"] = len(models)
        except Exception as e:
            logger.warning(f"Error getting model count: {e!s}")
            info["model_count"] = "Error retrieving count"

        # Get active run count
        try:
            active_runs = 0
            for exp in experiments:
                runs = client.search_runs(
                    experiment_ids=[exp.experiment_id],
                    filter_string="attributes.status = 'RUNNING'",
                    max_results=1000,
                )
                active_runs += len(runs)

            info["active_runs"] = active_runs
        except Exception as e:
            logger.warning(f"Error getting active run count: {e!s}")
            info["active_runs"] = "Error retrieving count"

        return json.dumps(info, indent=2)

    except Exception as e:
        error_msg = f"Error getting system info: {e!s}"
        logger.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})
list_experiments(name_contains='', max_results=MLFLOW_MAX_RESULTS)

List all experiments in the MLflow tracking server, with optional filtering.

Parameters:

    name_contains (str): Optional filter to only include experiments whose names contain this string. Default: ''.
    max_results (int): Maximum number of results to return (default: 100). Default: MLFLOW_MAX_RESULTS.

Returns:

    str: A JSON string containing all experiments matching the criteria.

Source code in src/mlflow_assistant/engine/tools.py
@tool
def list_experiments(
    name_contains: str = "", max_results: int = MLFLOW_MAX_RESULTS,
) -> str:
    """List all experiments in the MLflow tracking server, with optional filtering.

    Args:
        name_contains: Optional filter to only include experiments whose names contain this string
        max_results: Maximum number of results to return (default: 100)

    Returns:
        A JSON string containing all experiments matching the criteria.

    """
    logger.debug(f"Fetching experiments (filter: '{name_contains}', max: {max_results})")

    try:
        # Get all experiments
        experiments = client.search_experiments()

        # Filter by name if specified
        if name_contains:
            experiments = [
                exp for exp in experiments if name_contains.lower() in exp.name.lower()
            ]

        # Limit to max_results
        experiments = experiments[:max_results]

        # Create a list to hold experiment information
        experiments_info = []

        # Extract relevant information for each experiment
        for exp in experiments:
            exp_info = {
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "artifact_location": exp.artifact_location,
                "lifecycle_stage": exp.lifecycle_stage,
                "creation_time": MLflowTools.format_timestamp(exp.creation_time)
                if hasattr(exp, "creation_time")
                else None,
                "tags": {tag.key: tag.value for tag in exp.tags}
                if hasattr(exp, "tags")
                else {},
            }

            # Get the run count for this experiment
            try:
                runs = client.search_runs(
                    experiment_ids=[exp.experiment_id], max_results=1,
                )
                if runs:
                    # Just get the count of runs, not the actual runs
                    run_count = client.search_runs(
                        experiment_ids=[exp.experiment_id], max_results=1000,
                    )
                    exp_info["run_count"] = len(run_count)
                else:
                    exp_info["run_count"] = 0
            except Exception as e:
                logger.warning(
                    f"Error getting run count for experiment {exp.experiment_id}: {e!s}",
                )
                exp_info["run_count"] = "Error getting count"

            experiments_info.append(exp_info)

        result = {
            "total_experiments": len(experiments_info),
            "experiments": experiments_info,
        }

        return json.dumps(result, indent=2)

    except Exception as e:
        error_msg = f"Error listing experiments: {e!s}"
        logger.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})
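
As with the other tools, it is invoked with a dict of arguments; the filter value here is illustrative.

import json

experiments = json.loads(list_experiments.invoke({"name_contains": "churn", "max_results": 10}))
print(experiments["total_experiments"])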
list_models(name_contains='', max_results=MLFLOW_MAX_RESULTS)

List all registered models in the MLflow model registry, with optional filtering.

Parameters:

    name_contains (str): Optional filter to only include models whose names contain this string. Default: ''.
    max_results (int): Maximum number of results to return (default: 100). Default: MLFLOW_MAX_RESULTS.

Returns:

    str: A JSON string containing all registered models matching the criteria.

Source code in src/mlflow_assistant/engine/tools.py
@tool
def list_models(name_contains: str = "", max_results: int = MLFLOW_MAX_RESULTS) -> str:
    """List all registered models in the MLflow model registry, with optional filtering.

    Args:
        name_contains: Optional filter to only include models whose names contain this string
        max_results: Maximum number of results to return (default: 100)

    Returns:
        A JSON string containing all registered models matching the criteria.

    """
    logger.debug(
        f"Fetching registered models (filter: '{name_contains}', max: {max_results})",
    )

    try:
        # Get all registered models
        registered_models = client.search_registered_models(max_results=max_results)

        # Filter by name if specified
        if name_contains:
            registered_models = [
                model
                for model in registered_models
                if name_contains.lower() in model.name.lower()
            ]

        # Create a list to hold model information
        models_info = []

        # Extract relevant information for each model
        for model in registered_models:
            model_info = {
                "name": model.name,
                "creation_timestamp": MLflowTools.format_timestamp(
                    model.creation_timestamp,
                ),
                "last_updated_timestamp": MLflowTools.format_timestamp(
                    model.last_updated_timestamp,
                ),
                "description": model.description or "",
                "tags": {tag.key: tag.value for tag in model.tags}
                if hasattr(model, "tags")
                else {},
                "latest_versions": [],
            }

            # Add the latest versions if available
            if model.latest_versions and len(model.latest_versions) > 0:
                for version in model.latest_versions:
                    version_info = {
                        "version": version.version,
                        "status": version.status,
                        "stage": version.current_stage,
                        "creation_timestamp": MLflowTools.format_timestamp(
                            version.creation_timestamp,
                        ),
                        "run_id": version.run_id,
                    }
                    model_info["latest_versions"].append(version_info)

            models_info.append(model_info)

        result = {"total_models": len(models_info), "models": models_info}

        return json.dumps(result, indent=2)

    except Exception as e:
        error_msg = f"Error listing models: {e!s}"
        logger.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})

workflow

Core LangGraph-based workflow engine for processing user queries and generating responses using an AI provider.

This workflow supports tool-augmented generation: tool calls are detected and executed in a loop until a final AI response is produced.

State

Bases: TypedDict

State schema for the workflow engine.

create_workflow()

Create and return a compiled LangGraph workflow.

Source code in src/mlflow_assistant/engine/workflow.py
def create_workflow():
    """Create and return a compiled LangGraph workflow."""
    graph_builder = StateGraph(State)

    def call_model(state: State) -> State:
        """Call the AI model and return updated state with response."""
        messages = state[STATE_KEY_MESSAGES]
        provider_config = state.get(STATE_KEY_PROVIDER_CONFIG, {})
        try:
            provider = AIProvider.create(provider_config)
            model = provider.langchain_model().bind_tools(tools)
            response = model.invoke(messages)
            return {**state, STATE_KEY_MESSAGES: [response]}
        except Exception as e:
            logger.error(f"Error generating response: {e}", exc_info=True)
            return {**state, STATE_KEY_MESSAGES: messages}

    # Add nodes
    graph_builder.add_node("tools", ToolNode(tools))
    graph_builder.add_node("model", call_model)

    # Define graph transitions
    graph_builder.add_edge("tools", "model")
    graph_builder.add_conditional_edges("model", tools_condition)
    graph_builder.set_entry_point("model")

    return graph_builder.compile()
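
The compiled graph can be invoked directly, which is essentially what process_query does. A sketch; the state keys are written as plain strings for illustration (the real code uses the STATE_KEY_* constants), and the provider-config key names are assumed:

import asyncio

from langchain_core.messages import HumanMessage

workflow = create_workflow()

state = {
    "messages": [HumanMessage(content="How many experiments do I have?")],   # assumed key name
    "provider_config": {"type": "openai", "model": "gpt-4o-mini"},           # assumed keys/values
}

result = asyncio.run(workflow.ainvoke(state))
print(result["messages"][-1].content)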

main

Main entry point for executing the MLflow Assistant package directly.

providers

Provider module for MLflow Assistant.

AIProvider

Bases: ABC

Abstract base class for AI providers.

langchain_model abstractmethod property

Get the underlying LangChain model.

__init_subclass__(**kwargs)

Auto-register provider subclasses.

Source code in src/mlflow_assistant/providers/base.py
def __init_subclass__(cls, **kwargs):
    """Auto-register provider subclasses."""
    super().__init_subclass__(**kwargs)
    # Register the provider using the class name
    provider_type = cls.__name__.lower().replace(CONFIG_KEY_PROVIDER, "")
    AIProvider._providers[provider_type] = cls
    logger.debug(f"Registered provider: {provider_type}")
create(config) classmethod

Create an AI provider based on configuration.

Source code in src/mlflow_assistant/providers/base.py
@classmethod
def create(cls, config: dict[str, Any]) -> "AIProvider":
    """Create an AI provider based on configuration."""
    provider_type = config.get(CONFIG_KEY_TYPE)

    if not provider_type:
        error_msg = "Provider type not specified in configuration"
        raise ValueError(error_msg)

    provider_type = provider_type.lower()

    # Extract common parameters
    kwargs = {}
    for param in ParameterKeys.PARAMETERS_ALL:
        if param in config:
            kwargs[param] = config[param]

    # Import providers dynamically to avoid circular imports
    if provider_type == Provider.OPENAI.value:
        from .openai_provider import OpenAIProvider

        logger.debug(
            f"Creating OpenAI provider with model {config.get(CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OPENAI))}",
        )
        return OpenAIProvider(
            api_key=config.get(CONFIG_KEY_API_KEY),
            model=config.get(
                CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OPENAI),
            ),
            temperature=config.get(
                ParameterKeys.TEMPERATURE,
                Provider.get_default_temperature(Provider.OPENAI),
            ),
            **kwargs,
        )
    if provider_type == Provider.OLLAMA.value:
        from .ollama_provider import OllamaProvider

        logger.debug(
            f"Creating Ollama provider with model {config.get(CONFIG_KEY_MODEL)}",
        )
        return OllamaProvider(
            uri=config.get(CONFIG_KEY_URI),
            model=config.get(CONFIG_KEY_MODEL),
            temperature=config.get(
                ParameterKeys.TEMPERATURE,
                Provider.get_default_temperature(Provider.OLLAMA),
            ),
            **kwargs,
        )
    if provider_type == Provider.DATABRICKS.value:
        from .databricks_provider import DatabricksProvider

        logger.debug(
            f"Creating Databricks provider with model {config.get(CONFIG_KEY_MODEL)}",
        )
        return DatabricksProvider(
            model=config.get(CONFIG_KEY_MODEL),
            temperature=config.get(
                ParameterKeys.TEMPERATURE,
                Provider.get_default_temperature(Provider.DATABRICKS),
            ),
            **kwargs,
        )
    if provider_type not in cls._providers:
        error_msg = f"Unknown provider type: {provider_type}. Available types: {', '.join(cls._providers.keys())}"
        raise ValueError(error_msg)
    # Generic initialization for future providers
    provider_class = cls._providers[provider_type]
    return provider_class(config)
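
A sketch of building a provider from a configuration dict; the key names and values are illustrative stand-ins for the CONFIG_KEY_* constants used by the configuration module.

config = {"type": "ollama", "uri": "http://localhost:11434", "model": "llama3.2"}  # assumed keys/values

provider = AIProvider.create(config)
chat_model = provider.langchain_model()   # LangChain chat model, as used by the workflow engine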

get_ollama_models(uri=DEFAULT_OLLAMA_URI)

Fetch the list of available Ollama models.

Source code in src/mlflow_assistant/providers/utilities.py
def get_ollama_models(uri: str = DEFAULT_OLLAMA_URI) -> list:
    """Fetch the list of available Ollama models."""
    # Try using direct API call first
    try:
        response = requests.get(f"{uri}/api/tags", timeout=10)
        if response.status_code == 200:
            data = response.json()
            models = [m.get("name") for m in data.get("models", [])]
            if models:
                return models
    except Exception as e:
        logger.debug(f"Failed to get Ollama models from API: {e}")

    try:
        # Execute the Ollama list command
        ollama_path = shutil.which("ollama")

        result = subprocess.run(  # noqa: S603
            [ollama_path, "list"], capture_output=True, text=True, check=False,
        )

        # Check if command executed successfully
        if result.returncode != 0:
            logger.warning(f"ollama list failed: {result.stderr}")
            return FALLBACK_MODELS

        # Parse the output to extract model names
        lines = result.stdout.strip().split("\n")
        if len(lines) <= 1:  # Only header line or empty
            return FALLBACK_MODELS

        # Skip header line and extract the first column (model name)
        models = [line.split()[0] for line in lines[1:]]
        return models or FALLBACK_MODELS

    except (subprocess.SubprocessError, FileNotFoundError, IndexError) as e:
        logger.warning(f"Error fetching Ollama models: {e!s}")
        return FALLBACK_MODELS
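
A short usage sketch, pairing the model lookup with the connectivity check below and assuming the usual local Ollama URI:

from mlflow_assistant.providers.utilities import get_ollama_models, verify_ollama_running

if verify_ollama_running("http://localhost:11434"):
    print(get_ollama_models("http://localhost:11434"))
else:
    print("Ollama not reachable; get_ollama_models would fall back to FALLBACK_MODELS")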

verify_ollama_running(uri=DEFAULT_OLLAMA_URI)

Verify if Ollama is running at the given URI.

Source code in src/mlflow_assistant/providers/utilities.py
def verify_ollama_running(uri: str = DEFAULT_OLLAMA_URI) -> bool:
    """Verify if Ollama is running at the given URI."""
    try:
        response = requests.get(f"{uri}/api/tags", timeout=2)
        return response.status_code == 200
    except Exception:
        return False

base

Base class for AI providers.

AIProvider

Bases: ABC

Abstract base class for AI providers.

langchain_model abstractmethod property

Get the underlying LangChain model.

__init_subclass__(**kwargs)

Auto-register provider subclasses.

Source code in src/mlflow_assistant/providers/base.py
def __init_subclass__(cls, **kwargs):
    """Auto-register provider subclasses."""
    super().__init_subclass__(**kwargs)
    # Register the provider using the class name
    provider_type = cls.__name__.lower().replace(CONFIG_KEY_PROVIDER, "")
    AIProvider._providers[provider_type] = cls
    logger.debug(f"Registered provider: {provider_type}")
create(config) classmethod

Create an AI provider based on configuration.

Source code in src/mlflow_assistant/providers/base.py
@classmethod
def create(cls, config: dict[str, Any]) -> "AIProvider":
    """Create an AI provider based on configuration."""
    provider_type = config.get(CONFIG_KEY_TYPE)

    if not provider_type:
        error_msg = "Provider type not specified in configuration"
        raise ValueError(error_msg)

    provider_type = provider_type.lower()

    # Extract common parameters
    kwargs = {}
    for param in ParameterKeys.PARAMETERS_ALL:
        if param in config:
            kwargs[param] = config[param]

    # Import providers dynamically to avoid circular imports
    if provider_type == Provider.OPENAI.value:
        from .openai_provider import OpenAIProvider

        logger.debug(
            f"Creating OpenAI provider with model {config.get(CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OPENAI))}",
        )
        return OpenAIProvider(
            api_key=config.get(CONFIG_KEY_API_KEY),
            model=config.get(
                CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OPENAI),
            ),
            temperature=config.get(
                ParameterKeys.TEMPERATURE,
                Provider.get_default_temperature(Provider.OPENAI),
            ),
            **kwargs,
        )
    if provider_type == Provider.OLLAMA.value:
        from .ollama_provider import OllamaProvider

        logger.debug(
            f"Creating Ollama provider with model {config.get(CONFIG_KEY_MODEL)}",
        )
        return OllamaProvider(
            uri=config.get(CONFIG_KEY_URI),
            model=config.get(CONFIG_KEY_MODEL),
            temperature=config.get(
                ParameterKeys.TEMPERATURE,
                Provider.get_default_temperature(Provider.OLLAMA),
            ),
            **kwargs,
        )
    if provider_type == Provider.DATABRICKS.value:
        from .databricks_provider import DatabricksProvider

        logger.debug(
            f"Creating Databricks provider with model {config.get(CONFIG_KEY_MODEL)}",
        )
        return DatabricksProvider(
            model=config.get(CONFIG_KEY_MODEL),
            temperature=config.get(
                ParameterKeys.TEMPERATURE,
                Provider.get_default_temperature(Provider.DATABRICKS),
            ),
            **kwargs,
        )
    if provider_type not in cls._providers:
        error_msg = f"Unknown provider type: {provider_type}. Available types: {', '.join(cls._providers.keys())}"
        raise ValueError(error_msg)
    # Generic initialization for future providers
    provider_class = cls._providers[provider_type]
    return provider_class(config)

databricks_provider

Databricks provider for MLflow Assistant.

DatabricksProvider(model=None, temperature=None, **kwargs)

Bases: AIProvider

Databricks provider implementation.

Initialize the Databricks provider with model.

Source code in src/mlflow_assistant/providers/databricks_provider.py
def __init__(
    self,
    model: str | None = None,
    temperature: float | None = None,
    **kwargs,
):
    """Initialize the Databricks provider with model."""
    self.model_name = (
        model or Provider.get_default_model(Provider.DATABRICKS.value)
    )
    self.temperature = (
        temperature or Provider.get_default_temperature(Provider.DATABRICKS.value)
    )
    self.kwargs = kwargs

    for var in DATABRICKS_CREDENTIALS:
        if var not in os.environ:
            logger.warning(
                f"Missing environment variable: {var}. "
                "Responses may fail if you are running outside Databricks.",
            )

    # Build parameters dict with only non-None values
    model_params = {"endpoint": self.model_name, "temperature": temperature}

    # Only add optional parameters if they're not None
    for param in ParameterKeys.get_parameters(Provider.DATABRICKS.value):
        if param in kwargs and kwargs[param] is not None:
            model_params[param] = kwargs[param]

    # Initialize with parameters matching the documentation
    self.model = ChatDatabricks(**model_params)

    logger.debug(f"Databricks provider initialized with model {self.model_name}")
langchain_model()

Get the underlying LangChain model.

Source code in src/mlflow_assistant/providers/databricks_provider.py
def langchain_model(self):
    """Get the underlying LangChain model."""
    return self.model

definitions

Constants for the MLflow Assistant providers.

ParameterKeys

Keys and default parameter groupings for supported providers.

get_parameters(provider) classmethod

Return the list of parameters for the given provider name.

Source code in src/mlflow_assistant/providers/definitions.py
@classmethod
def get_parameters(cls, provider: str) -> list[str]:
    """Return the list of parameters for the given provider name."""
    provider_map = {
        "openai": cls.PARAMETERS_OPENAI,
        "ollama": cls.PARAMETERS_OLLAMA,
        "databricks": cls.PARAMETERS_DATABRICKS,
    }
    return provider_map.get(provider.lower(), [])
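
For example:

# Provider names are matched case-insensitively; unknown providers return an empty list.
print(ParameterKeys.get_parameters("OpenAI"))
print(ParameterKeys.get_parameters("unknown"))  # -> []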

ollama_provider

Ollama provider for MLflow Assistant.

OllamaProvider(uri=None, model=None, temperature=None, **kwargs)

Bases: AIProvider

Ollama provider implementation.

Initialize the Ollama provider with URI and model.

Source code in src/mlflow_assistant/providers/ollama_provider.py
def __init__(self, uri=None, model=None, temperature=None, **kwargs):
    """Initialize the Ollama provider with URI and model."""
    # Handle None URI case to prevent attribute errors
    if uri is None:
        logger.warning(
            f"Ollama URI is None. Using default URI: {DEFAULT_OLLAMA_URI}",
        )
        self.uri = DEFAULT_OLLAMA_URI
    else:
        self.uri = uri.rstrip("/")

    self.model_name = model or OllamaModel.LLAMA32.value
    self.temperature = (
        temperature or Provider.get_default_temperature(Provider.OLLAMA.value)
    )

    # Store kwargs for later use when creating specialized models
    self.kwargs = kwargs

    # Build parameters dict with only non-None values
    model_params = {
        "base_url": self.uri,
        "model": self.model_name,
        "temperature": temperature,
    }

    # Only add optional parameters if they're not None
    for param in ParameterKeys.get_parameters(Provider.OLLAMA.value):
        if param in kwargs and kwargs[param] is not None:
            model_params[param] = kwargs[param]

    # Use langchain-ollama's dedicated ChatOllama class
    self.model = ChatOllama(**model_params)

    logger.debug(
        f"Ollama provider initialized with model {self.model_name} at {self.uri}",
    )
langchain_model()

Get the underlying LangChain model.

Source code in src/mlflow_assistant/providers/ollama_provider.py
def langchain_model(self):
    """Get the underlying LangChain model."""
    return self.model
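
Example (a minimal usage sketch, assuming a local Ollama server is reachable at the default URI and that the named model has already been pulled; the model name is illustrative):

from mlflow_assistant.providers.ollama_provider import OllamaProvider

provider = OllamaProvider(uri="http://localhost:11434", model="llama3.2", temperature=0.1)

chat_model = provider.langchain_model()
print(chat_model.invoke("List three MLflow tracking concepts.").content)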

openai_provider

OpenAI provider for MLflow Assistant.

OpenAIProvider(api_key=None, model=OpenAIModel.GPT35.value, temperature=None, **kwargs)

Bases: AIProvider

OpenAI provider implementation.

Initialize the OpenAI provider with API key and model.

Source code in src/mlflow_assistant/providers/openai_provider.py
def __init__(
    self,
    api_key=None,
    model=OpenAIModel.GPT35.value,
    temperature: float | None = None,
    **kwargs,
):
    """Initialize the OpenAI provider with API key and model."""
    self.api_key = api_key
    self.model_name = model or OpenAIModel.GPT35.value
    self.temperature = (
        temperature or Provider.get_default_temperature(Provider.OPENAI.value)
    )
    self.kwargs = kwargs

    if not self.api_key:
        logger.warning("No OpenAI API key provided. Responses may fail.")

    # Build parameters dict with only non-None values
    model_params = {
        "api_key": api_key,
        "model": self.model_name,
        "temperature": temperature,
    }

    # Only add optional parameters if they're not None
    for param in ParameterKeys.get_parameters(Provider.OPENAI.value):
        if param in kwargs and kwargs[param] is not None:
            model_params[param] = kwargs[param]

    # Initialize with parameters matching the documentation
    self.model = ChatOpenAI(**model_params)

    logger.debug(f"OpenAI provider initialized with model {self.model_name}")
langchain_model()

Get the underlying LangChain model.

Source code in src/mlflow_assistant/providers/openai_provider.py
def langchain_model(self):
    """Get the underlying LangChain model."""
    return self.model
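
Example (a minimal usage sketch; the model name is illustrative and the API key is read from the environment):

import os

from mlflow_assistant.providers.openai_provider import OpenAIProvider

provider = OpenAIProvider(
    api_key=os.environ.get("OPENAI_API_KEY"),
    model="gpt-3.5-turbo",
    temperature=0.0,
)

chat_model = provider.langchain_model()
print(chat_model.invoke("What does MLflow autologging capture?").content)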

utilities

Providers utilities.

get_ollama_models(uri=DEFAULT_OLLAMA_URI)

Fetch the list of available Ollama models.

Source code in src/mlflow_assistant/providers/utilities.py
def get_ollama_models(uri: str = DEFAULT_OLLAMA_URI) -> list:
    """Fetch the list of available Ollama models."""
    # Try using direct API call first
    try:
        response = requests.get(f"{uri}/api/tags", timeout=10)
        if response.status_code == 200:
            data = response.json()
            models = [m.get("name") for m in data.get("models", [])]
            if models:
                return models
    except Exception as e:
        logger.debug(f"Failed to get Ollama models from API: {e}")

    try:
        # Locate the Ollama binary; bail out early if it is not installed
        ollama_path = shutil.which("ollama")
        if ollama_path is None:
            logger.warning("Ollama executable not found on PATH")
            return FALLBACK_MODELS

        # Execute the Ollama list command
        result = subprocess.run(  # noqa: S603
            [ollama_path, "list"], capture_output=True, text=True, check=False,
        )

        # Check if command executed successfully
        if result.returncode != 0:
            logger.warning(f"ollama list failed: {result.stderr}")
            return FALLBACK_MODELS

        # Parse the output to extract model names
        lines = result.stdout.strip().split("\n")
        if len(lines) <= 1:  # Only header line or empty
            return FALLBACK_MODELS

        # Skip header line and extract the first column (model name)
        models = [line.split()[0] for line in lines[1:]]
        return models or FALLBACK_MODELS

    except (subprocess.SubprocessError, FileNotFoundError, IndexError) as e:
        logger.warning(f"Error fetching Ollama models: {e!s}")
        return FALLBACK_MODELS
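
Example (a usage sketch; if neither the HTTP API nor the ollama CLI is available, the module-level FALLBACK_MODELS list is returned instead):

from mlflow_assistant.providers.utilities import get_ollama_models

models = get_ollama_models("http://localhost:11434")
print(f"Available Ollama models: {models}")
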
verify_ollama_running(uri=DEFAULT_OLLAMA_URI)

Verify if Ollama is running at the given URI.

Source code in src/mlflow_assistant/providers/utilities.py
def verify_ollama_running(uri: str = DEFAULT_OLLAMA_URI) -> bool:
    """Verify if Ollama is running at the given URI."""
    try:
        response = requests.get(f"{uri}/api/tags", timeout=2)
        return response.status_code == 200
    except Exception:
        return False
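
Example (gating an interactive session on server availability):

from mlflow_assistant.providers.utilities import verify_ollama_running

if not verify_ollama_running("http://localhost:11434"):
    raise SystemExit("Ollama is not running; start it with `ollama serve` and try again.")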

utils

Utility modules for MLflow Assistant.

Command

Bases: Enum

Special commands for interactive chat sessions.

description property

Get the description for a command.

OllamaModel

Bases: Enum

Default Ollama models supported by MLflow Assistant.

choices() classmethod

Get all available Ollama model choices.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def choices(cls):
    """Get all available Ollama model choices."""
    return [model.value for model in cls]

OpenAIModel

Bases: Enum

OpenAI models supported by MLflow Assistant.

choices() classmethod

Get all available OpenAI model choices.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def choices(cls):
    """Get all available OpenAI model choices."""
    return [model.value for model in cls]

Provider

Bases: Enum

AI provider types supported by MLflow Assistant.

get_default_model(provider) classmethod

Get the default model for a provider.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def get_default_model(cls, provider):
    """Get the default model for a provider."""
    defaults = {
        cls.OPENAI: OpenAIModel.GPT35.value,
        cls.OLLAMA: OllamaModel.LLAMA32.value,
        cls.DATABRICKS: DatabricksModel.DATABRICKS_META_LLAMA3.value,
    }
    return defaults.get(provider)
get_default_temperature(provider) classmethod

Get the default temperature for a provider.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def get_default_temperature(cls, provider):
    """Get the default temperature for a provider."""
    defaults = {
        cls.OPENAI: 0.7,
        cls.DATABRICKS: 0.7,
        cls.OLLAMA: 0.7,
    }
    return defaults.get(provider)
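
Example (a small sketch of resolving the defaults; it assumes the enum member itself is passed, matching the keys of the defaults dictionaries above):

from mlflow_assistant.utils.constants import Provider

default_model = Provider.get_default_model(Provider.OPENAI)              # the GPT35 value
default_temperature = Provider.get_default_temperature(Provider.OLLAMA)  # 0.7
print(default_model, default_temperature)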

get_mlflow_uri()

Get the MLflow URI from config or environment.

Returns:

Type Description
str | None

Optional[str]: The MLflow URI or None if not configured

Source code in src/mlflow_assistant/utils/config.py
def get_mlflow_uri() -> str | None:
    """Get the MLflow URI from config or environment.

    Returns:
        Optional[str]: The MLflow URI or None if not configured

    """
    # Environment variable should take precedence
    if mlflow_uri_env := os.environ.get(MLFLOW_URI_ENV):
        return mlflow_uri_env

    # Fall back to config
    config = load_config()
    return config.get(CONFIG_KEY_MLFLOW_URI)
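
Example (a usage sketch; the actual environment variable name behind MLFLOW_URI_ENV is defined in the constants module):

from mlflow_assistant.utils.config import get_mlflow_uri

uri = get_mlflow_uri()
if uri is None:
    print("MLflow URI is not configured; run the setup wizard first.")
else:
    print(f"Using MLflow tracking server at {uri}")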

get_provider_config()

Get the AI provider configuration.

Returns:

Type Description
dict[str, Any]

Dict[str, Any]: The provider configuration

Source code in src/mlflow_assistant/utils/config.py
def get_provider_config() -> dict[str, Any]:
    """Get the AI provider configuration.

    Returns:
        Dict[str, Any]: The provider configuration

    """
    config = load_config()
    provider = config.get(CONFIG_KEY_PROVIDER, {})

    provider_type = provider.get(CONFIG_KEY_TYPE)

    if provider_type == Provider.OPENAI.value:
        # Environment variable should take precedence
        api_key = (os.environ.get(OPENAI_API_KEY_ENV) or
                   provider.get(CONFIG_KEY_API_KEY))
        return {
            CONFIG_KEY_TYPE: Provider.OPENAI.value,
            CONFIG_KEY_API_KEY: api_key,
            CONFIG_KEY_MODEL: provider.get(
                CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OPENAI),
            ),
        }

    if provider_type == Provider.OLLAMA.value:
        return {
            CONFIG_KEY_TYPE: Provider.OLLAMA.value,
            CONFIG_KEY_URI: provider.get(CONFIG_KEY_URI, DEFAULT_OLLAMA_URI),
            CONFIG_KEY_MODEL: provider.get(
                CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OLLAMA),
            ),
        }

    if provider_type == Provider.DATABRICKS.value:
        # Set environment variables for Databricks profile
        _set_environment_variables(provider.get(CONFIG_KEY_PROFILE))

        return {
            CONFIG_KEY_TYPE: Provider.DATABRICKS.value,
            CONFIG_KEY_PROFILE: provider.get(CONFIG_KEY_PROFILE),
            CONFIG_KEY_MODEL: provider.get(CONFIG_KEY_MODEL),
        }

    return {CONFIG_KEY_TYPE: None}
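
Example (a sketch of branching on the returned configuration; it assumes the CONFIG_KEY_* constants are importable from the constants module documented below):

from mlflow_assistant.utils.config import get_provider_config
from mlflow_assistant.utils.constants import CONFIG_KEY_MODEL, CONFIG_KEY_TYPE

provider_config = get_provider_config()
if provider_config.get(CONFIG_KEY_TYPE) is None:
    print("No AI provider configured yet.")
else:
    print(f"Provider: {provider_config[CONFIG_KEY_TYPE]}, model: {provider_config.get(CONFIG_KEY_MODEL)}")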

load_config()

Load configuration from file.

Source code in src/mlflow_assistant/utils/config.py
def load_config() -> dict[str, Any]:
    """Load configuration from file."""
    if not CONFIG_FILE.exists():
        logger.info(f"No configuration file found at {CONFIG_FILE}")
        return {}

    try:
        with open(CONFIG_FILE) as f:
            config = yaml.safe_load(f) or {}
            logger.debug(f"Loaded configuration: {config}")
            return config
    except Exception as e:
        logger.error(f"Error loading configuration: {e}")
        return {}

save_config(config)

Save configuration to file.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary to save

required

Returns:

Name Type Description
bool bool

True if successful, False otherwise

Source code in src/mlflow_assistant/utils/config.py
def save_config(config: dict[str, Any]) -> bool:
    """Save configuration to file.

    Args:
        config: Configuration dictionary to save

    Returns:
        bool: True if successful, False otherwise

    """
    ensure_config_dir()

    try:
        with open(CONFIG_FILE, "w") as f:
            yaml.dump(config, f)
        logger.info(f"Configuration saved to {CONFIG_FILE}")
        return True
    except Exception as e:
        logger.error(f"Error saving configuration: {e}")
        return False
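
Example (a round-trip sketch; the nested keys and value are illustrative, the real key names come from the CONFIG_KEY_* constants):

from mlflow_assistant.utils.config import load_config, save_config

config = load_config()                                 # {} when no file exists yet
config.setdefault("provider", {})["model"] = "gpt-4o"  # illustrative keys and value
if save_config(config):
    print("Configuration updated.")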

config

Configuration management utilities for MLflow Assistant.

This module provides functions for loading, saving, and accessing configuration settings for MLflow Assistant, including MLflow URI and AI provider settings. Configuration is stored in YAML format in the user's home directory.

ensure_config_dir()

Ensure the configuration directory exists.

Source code in src/mlflow_assistant/utils/config.py
def ensure_config_dir():
    """Ensure the configuration directory exists."""
    if not CONFIG_DIR.exists():
        CONFIG_DIR.mkdir(parents=True)
        logger.info(f"Created configuration directory at {CONFIG_DIR}")
get_mlflow_uri()

Get the MLflow URI from config or environment.

Returns:

Type Description
str | None

Optional[str]: The MLflow URI or None if not configured

Source code in src/mlflow_assistant/utils/config.py
def get_mlflow_uri() -> str | None:
    """Get the MLflow URI from config or environment.

    Returns:
        Optional[str]: The MLflow URI or None if not configured

    """
    # Environment variable should take precedence
    if mlflow_uri_env := os.environ.get(MLFLOW_URI_ENV):
        return mlflow_uri_env

    # Fall back to config
    config = load_config()
    return config.get(CONFIG_KEY_MLFLOW_URI)
get_provider_config()

Get the AI provider configuration.

Returns:

Type Description
dict[str, Any]

Dict[str, Any]: The provider configuration

Source code in src/mlflow_assistant/utils/config.py
def get_provider_config() -> dict[str, Any]:
    """Get the AI provider configuration.

    Returns:
        Dict[str, Any]: The provider configuration

    """
    config = load_config()
    provider = config.get(CONFIG_KEY_PROVIDER, {})

    provider_type = provider.get(CONFIG_KEY_TYPE)

    if provider_type == Provider.OPENAI.value:
        # Environment variable should take precedence
        api_key = (os.environ.get(OPENAI_API_KEY_ENV) or
                   provider.get(CONFIG_KEY_API_KEY))
        return {
            CONFIG_KEY_TYPE: Provider.OPENAI.value,
            CONFIG_KEY_API_KEY: api_key,
            CONFIG_KEY_MODEL: provider.get(
                CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OPENAI),
            ),
        }

    if provider_type == Provider.OLLAMA.value:
        return {
            CONFIG_KEY_TYPE: Provider.OLLAMA.value,
            CONFIG_KEY_URI: provider.get(CONFIG_KEY_URI, DEFAULT_OLLAMA_URI),
            CONFIG_KEY_MODEL: provider.get(
                CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OLLAMA),
            ),
        }

    if provider_type == Provider.DATABRICKS.value:
        # Set environment variables for Databricks profile
        _set_environment_variables(provider.get(CONFIG_KEY_PROFILE))

        return {
            CONFIG_KEY_TYPE: Provider.DATABRICKS.value,
            CONFIG_KEY_PROFILE: provider.get(CONFIG_KEY_PROFILE),
            CONFIG_KEY_MODEL: provider.get(CONFIG_KEY_MODEL),
        }

    return {CONFIG_KEY_TYPE: None}
load_config()

Load configuration from file.

Source code in src/mlflow_assistant/utils/config.py
def load_config() -> dict[str, Any]:
    """Load configuration from file."""
    if not CONFIG_FILE.exists():
        logger.info(f"No configuration file found at {CONFIG_FILE}")
        return {}

    try:
        with open(CONFIG_FILE) as f:
            config = yaml.safe_load(f) or {}
            logger.debug(f"Loaded configuration: {config}")
            return config
    except Exception as e:
        logger.error(f"Error loading configuration: {e}")
        return {}
save_config(config)

Save configuration to file.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary to save

required

Returns:

Name Type Description
bool bool

True if successful, False otherwise

Source code in src/mlflow_assistant/utils/config.py
def save_config(config: dict[str, Any]) -> bool:
    """Save configuration to file.

    Args:
        config: Configuration dictionary to save

    Returns:
        bool: True if successful, False otherwise

    """
    ensure_config_dir()

    try:
        with open(CONFIG_FILE, "w") as f:
            yaml.dump(config, f)
        logger.info(f"Configuration saved to {CONFIG_FILE}")
        return True
    except Exception as e:
        logger.error(f"Error saving configuration: {e}")
        return False

constants

Constants and enumerations for MLflow Assistant.

This module defines configuration keys, default values, API endpoints, model definitions, and other constants used throughout MLflow Assistant. It includes enums for the supported AI providers (OpenAI, Ollama, Databricks) and their supported models.

Command

Bases: Enum

Special commands for interactive chat sessions.

description property

Get the description for a command.

DatabricksModel

Bases: Enum

Databricks models supported by MLflow Assistant.

choices() classmethod

Get all available Databricks model choices.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def choices(cls):
    """Get all available Databricks model choices."""
    return [model.value for model in cls]
OllamaModel

Bases: Enum

Default Ollama models supported by MLflow Assistant.

choices() classmethod

Get all available Ollama model choices.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def choices(cls):
    """Get all available Ollama model choices."""
    return [model.value for model in cls]
OpenAIModel

Bases: Enum

OpenAI models supported by MLflow Assistant.

choices() classmethod

Get all available OpenAI model choices.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def choices(cls):
    """Get all available OpenAI model choices."""
    return [model.value for model in cls]
Provider

Bases: Enum

AI provider types supported by MLflow Assistant.

get_default_model(provider) classmethod

Get the default model for a provider.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def get_default_model(cls, provider):
    """Get the default model for a provider."""
    defaults = {
        cls.OPENAI: OpenAIModel.GPT35.value,
        cls.OLLAMA: OllamaModel.LLAMA32.value,
        cls.DATABRICKS: DatabricksModel.DATABRICKS_META_LLAMA3.value,
    }
    return defaults.get(provider)
get_default_temperature(provider) classmethod

Get the default temperature for a provider.

Source code in src/mlflow_assistant/utils/constants.py
@classmethod
def get_default_temperature(cls, provider):
    """Get the default temperature for a provider."""
    defaults = {
        cls.OPENAI: 0.7,
        cls.DATABRICKS: 0.7,
        cls.OLLAMA: 0.7,
    }
    return defaults.get(provider)

definitions

Constants and definitions for the MLflow Assistant.

MLflowConnectionConfig(tracking_uri) dataclass

Configuration for MLflow connection.

connection_type property

Return the connection type (local or remote).
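
Example (a small sketch based only on the field and property shown here; the import path and the exact classification rule for "local" vs "remote" are not shown on this page):

# from <definitions module> import MLflowConnectionConfig  # adjust to the real module path

config = MLflowConnectionConfig(tracking_uri="http://localhost:5000")
print(config.tracking_uri)       # "http://localhost:5000"
print(config.connection_type)    # "local" or "remote", per the property above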

exceptions

Custom exceptions for the MLflow Assistant.

MLflowConnectionError

Bases: Exception

Exception raised when there is an issue connecting to the MLflow Tracking Server.