providers

mlflow_assistant.providers

Provider module for MLflow Assistant.

base

Base class for AI providers.

AIProvider

Bases: ABC

Abstract base class for AI providers.

langchain_model abstractmethod property

Get the underlying LangChain model.

__init_subclass__(**kwargs)

Auto-register provider subclasses.

Source code in src/mlflow_assistant/providers/base.py
def __init_subclass__(cls, **kwargs):
    """Auto-register provider subclasses."""
    super().__init_subclass__(**kwargs)
    # Register the provider using the class name
    provider_type = cls.__name__.lower().replace(CONFIG_KEY_PROVIDER, "")
    AIProvider._providers[provider_type] = cls
    logger.debug(f"Registered provider: {provider_type}")

create(config) classmethod

Create an AI provider based on configuration.

Source code in src/mlflow_assistant/providers/base.py
@classmethod
def create(cls, config: dict[str, Any]) -> "AIProvider":
    """Create an AI provider based on configuration."""
    provider_type = config.get(CONFIG_KEY_TYPE)

    if not provider_type:
        error_msg = "Provider type not specified in configuration"
        raise ValueError(error_msg)

    provider_type = provider_type.lower()

    # Extract common parameters
    kwargs = {}
    for param in ParameterKeys.PARAMETERS_ALL:
        if param in config:
            kwargs[param] = config[param]

    # Import providers dynamically to avoid circular imports
    if provider_type == Provider.OPENAI.value:
        from .openai_provider import OpenAIProvider

        logger.debug(
            f"Creating OpenAI provider with model {config.get(CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OPENAI))}",
        )
        return OpenAIProvider(
            api_key=config.get(CONFIG_KEY_API_KEY),
            model=config.get(
                CONFIG_KEY_MODEL, Provider.get_default_model(Provider.OPENAI),
            ),
            temperature=config.get(
                ParameterKeys.TEMPERATURE,
                Provider.get_default_temperature(Provider.OPENAI),
            ),
            **kwargs,
        )
    if provider_type == Provider.OLLAMA.value:
        from .ollama_provider import OllamaProvider

        logger.debug(
            f"Creating Ollama provider with model {config.get(CONFIG_KEY_MODEL)}",
        )
        return OllamaProvider(
            uri=config.get(CONFIG_KEY_URI),
            model=config.get(CONFIG_KEY_MODEL),
            temperature=config.get(
                ParameterKeys.TEMPERATURE,
                Provider.get_default_temperature(Provider.OLLAMA),
            ),
            **kwargs,
        )
    if provider_type == Provider.DATABRICKS.value:
        from .databricks_provider import DatabricksProvider

        logger.debug(
            f"Creating Databricks provider with model {config.get(CONFIG_KEY_MODEL)}",
        )
        return DatabricksProvider(
            model=config.get(CONFIG_KEY_MODEL),
            temperature=config.get(
                ParameterKeys.TEMPERATURE,
                Provider.get_default_temperature(Provider.DATABRICKS),
            ),
            **kwargs,
        )
    if provider_type not in cls._providers:
        error_msg = f"Unknown provider type: {provider_type}. Available types: {', '.join(cls._providers.keys())}"
        raise ValueError(error_msg)
    # Generic initialization for future providers
    provider_class = cls._providers[provider_type]
    return provider_class(config)
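
A hedged usage sketch for create(). The configuration key names below ("type", "api_key", "model", "temperature") are assumed values of CONFIG_KEY_TYPE, CONFIG_KEY_API_KEY, CONFIG_KEY_MODEL, and ParameterKeys.TEMPERATURE, which are not expanded in this listing.

from mlflow_assistant.providers.base import AIProvider

config = {
    "type": "openai",        # assumed value of CONFIG_KEY_TYPE
    "api_key": "sk-...",     # assumed value of CONFIG_KEY_API_KEY (placeholder key)
    "model": "gpt-4o-mini",  # assumed value of CONFIG_KEY_MODEL
    "temperature": 0.2,
}
provider = AIProvider.create(config)   # returns an OpenAIProvider instance
chat_model = provider.langchain_model  # underlying LangChain chat model (property on the base class)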

databricks_provider

Databricks provider for MLflow Assistant.

DatabricksProvider(model=None, temperature=None, **kwargs)

Bases: AIProvider

Databricks provider implementation.

Initialize the Databricks provider with model.

Source code in src/mlflow_assistant/providers/databricks_provider.py
def __init__(
    self,
    model: str | None = None,
    temperature: float | None = None,
    **kwargs,
):
    """Initialize the Databricks provider with model."""
    self.model_name = (
        model or Provider.get_default_model(Provider.DATABRICKS.value)
    )
    self.temperature = (
        temperature or Provider.get_default_temperature(Provider.DATABRICKS.value)
    )
    self.kwargs = kwargs

    for var in DATABRICKS_CREDENTIALS:
        if var not in os.environ:
            logger.warning(
                f"Missing environment variable: {var}. "
                "Responses may fail if you are running outside Databricks.",
            )

    # Build parameters dict with only non-None values
    model_params = {"endpoint": self.model_name, "temperature": temperature}

    # Only add optional parameters if they're not None
    for param in ParameterKeys.get_parameters(Provider.DATABRICKS.value):
        if param in kwargs and kwargs[param] is not None:
            model_params[param] = kwargs[param]

    # Initialize with parameters matching the documentation
    self.model = ChatDatabricks(**model_params)

    logger.debug(f"Databricks provider initialized with model {self.model_name}")

langchain_model()

Get the underlying LangChain model.

Source code in src/mlflow_assistant/providers/databricks_provider.py
def langchain_model(self):
    """Get the underlying LangChain model."""
    return self.model
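
A usage sketch, assuming the code runs with Databricks credentials available (the DATABRICKS_CREDENTIALS environment variables checked above) and that the endpoint name below exists in the workspace; both are illustrative assumptions.

from mlflow_assistant.providers.databricks_provider import DatabricksProvider

provider = DatabricksProvider(
    model="databricks-meta-llama-3-1-70b-instruct",  # hypothetical serving endpoint
    temperature=0.1,
)
chat = provider.langchain_model()  # the ChatDatabricks instance, via the method shown above
print(chat.invoke("Summarize my latest MLflow run.").content)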

definitions

Constants for the MLflow Assistant providers.

ParameterKeys

Keys and default parameter groupings for supported providers.

get_parameters(provider) classmethod

Return the list of parameters for the given provider name.

Source code in src/mlflow_assistant/providers/definitions.py
@classmethod
def get_parameters(cls, provider: str) -> list[str]:
    """Return the list of parameters for the given provider name."""
    provider_map = {
        "openai": cls.PARAMETERS_OPENAI,
        "ollama": cls.PARAMETERS_OLLAMA,
        "databricks": cls.PARAMETERS_DATABRICKS,
    }
    return provider_map.get(provider.lower(), [])
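
An illustrative sketch of the kwargs-filtering pattern the providers build on top of get_parameters; the parameter name "num_predict" is assumed, since the PARAMETERS_* lists are not expanded in this listing.

from mlflow_assistant.providers.definitions import ParameterKeys

extra = {"num_predict": 256, "top_p": None}       # hypothetical user-supplied kwargs
allowed = ParameterKeys.get_parameters("Ollama")  # lookup is case-insensitive
model_params = {k: v for k, v in extra.items() if k in allowed and v is not None}

# Unknown provider names fall back to an empty list, so nothing passes the filter.
assert ParameterKeys.get_parameters("some_future_provider") == []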

ollama_provider

Ollama provider for MLflow Assistant.

OllamaProvider(uri=None, model=None, temperature=None, **kwargs)

Bases: AIProvider

Ollama provider implementation.

Initialize the Ollama provider with URI and model.

Source code in src/mlflow_assistant/providers/ollama_provider.py
def __init__(self, uri=None, model=None, temperature=None, **kwargs):
    """Initialize the Ollama provider with URI and model."""
    # Handle None URI case to prevent attribute errors
    if uri is None:
        logger.warning(
            f"Ollama URI is None. Using default URI: {DEFAULT_OLLAMA_URI}",
        )
        self.uri = DEFAULT_OLLAMA_URI
    else:
        self.uri = uri.rstrip("/")

    self.model_name = model or OllamaModel.LLAMA32.value
    self.temperature = (
        temperature or Provider.get_default_temperature(Provider.OLLAMA.value)
    )

    # Store kwargs for later use when creating specialized models
    self.kwargs = kwargs

    # Build parameters dict with only non-None values
    model_params = {
        "base_url": self.uri,
        "model": self.model_name,
        "temperature": temperature,
    }

    # Only add optional parameters if they're not None
    for param in ParameterKeys.get_parameters(Provider.OLLAMA.value):
        if param in kwargs and kwargs[param] is not None:
            model_params[param] = kwargs[param]

    # Use langchain-ollama's dedicated ChatOllama class
    self.model = ChatOllama(**model_params)

    logger.debug(
        f"Ollama provider initialized with model {self.model_name} at {self.uri}",
    )

langchain_model()

Get the underlying LangChain model.

Source code in src/mlflow_assistant/providers/ollama_provider.py
def langchain_model(self):
    """Get the underlying LangChain model."""
    return self.model
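
A local usage sketch, assuming an Ollama server is reachable at the standard local URI and that a "llama3.2" model has been pulled; both values are illustrative.

from mlflow_assistant.providers.ollama_provider import OllamaProvider

provider = OllamaProvider(uri="http://localhost:11434", model="llama3.2")
chat = provider.langchain_model()  # the ChatOllama instance, via the method shown above
print(chat.invoke("List three MLflow tracking APIs.").content)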

openai_provider

OpenAI provider for MLflow Assistant.

OpenAIProvider(api_key=None, model=OpenAIModel.GPT35.value, temperature=None, **kwargs)

Bases: AIProvider

OpenAI provider implementation.

Initialize the OpenAI provider with API key and model.

Source code in src/mlflow_assistant/providers/openai_provider.py
def __init__(
    self,
    api_key=None,
    model=OpenAIModel.GPT35.value,
    temperature: float | None = None,
    **kwargs,
):
    """Initialize the OpenAI provider with API key and model."""
    self.api_key = api_key
    self.model_name = model or OpenAIModel.GPT35.value
    self.temperature = (
        temperature or Provider.get_default_temperature(Provider.OPENAI.value)
    )
    self.kwargs = kwargs

    if not self.api_key:
        logger.warning("No OpenAI API key provided. Responses may fail.")

    # Build parameters dict with only non-None values
    model_params = {
        "api_key": api_key,
        "model": self.model_name,
        "temperature": temperature,
    }

    # Only add optional parameters if they're not None
    for param in ParameterKeys.get_parameters(Provider.OPENAI.value):
        if param in kwargs and kwargs[param] is not None:
            model_params[param] = kwargs[param]

    # Initialize with parameters matching the documentation
    self.model = ChatOpenAI(**model_params)

    logger.debug(f"OpenAI provider initialized with model {self.model_name}")

langchain_model()

Get the underlying LangChain model.

Source code in src/mlflow_assistant/providers/openai_provider.py
def langchain_model(self):
    """Get the underlying LangChain model."""
    return self.model
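
A usage sketch, assuming a valid OpenAI API key is available in the environment; the model name is illustrative and falls back to OpenAIModel.GPT35.value when omitted.

import os

from mlflow_assistant.providers.openai_provider import OpenAIProvider

provider = OpenAIProvider(
    api_key=os.environ.get("OPENAI_API_KEY"),
    model="gpt-4o-mini",  # hypothetical model name
    temperature=0.2,
)
chat = provider.langchain_model()  # the ChatOpenAI instance, via the method shown above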

utilities

Providers utilities.

get_ollama_models(uri=DEFAULT_OLLAMA_URI)

Fetch the list of available Ollama models.

Source code in src/mlflow_assistant/providers/utilities.py
def get_ollama_models(uri: str = DEFAULT_OLLAMA_URI) -> list:
    """Fetch the list of available Ollama models."""
    # Try using direct API call first
    try:
        response = requests.get(f"{uri}/api/tags", timeout=10)
        if response.status_code == 200:
            data = response.json()
            models = [m.get("name") for m in data.get("models", [])]
            if models:
                return models
    except Exception as e:
        logger.debug(f"Failed to get Ollama models from API: {e}")

    try:
        # Execute the Ollama list command
        ollama_path = shutil.which("ollama")
        if ollama_path is None:
            logger.warning("Ollama executable not found on PATH")
            return FALLBACK_MODELS

        result = subprocess.run(  # noqa: S603
            [ollama_path, "list"], capture_output=True, text=True, check=False,
        )

        # Check if command executed successfully
        if result.returncode != 0:
            logger.warning(f"ollama list failed: {result.stderr}")
            return FALLBACK_MODELS

        # Parse the output to extract model names
        lines = result.stdout.strip().split("\n")
        if len(lines) <= 1:  # Only header line or empty
            return FALLBACK_MODELS

        # Skip header line and extract the first column (model name)
        models = [line.split()[0] for line in lines[1:]]
        return models or FALLBACK_MODELS

    except (subprocess.SubprocessError, FileNotFoundError, IndexError) as e:
        logger.warning(f"Error fetching Ollama models: {e!s}")
        return FALLBACK_MODELS
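
A minimal sketch of calling the helper; the URI below assumes the standard local Ollama address. Because the function always returns a list (falling back to FALLBACK_MODELS when neither the HTTP API nor the CLI is usable), the result can be used directly.

from mlflow_assistant.providers.utilities import get_ollama_models

models = get_ollama_models("http://localhost:11434")  # assumed local Ollama URI
print(f"Available Ollama models: {models}")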

verify_ollama_running(uri=DEFAULT_OLLAMA_URI)

Verify if Ollama is running at the given URI.

Source code in src/mlflow_assistant/providers/utilities.py
def verify_ollama_running(uri: str = DEFAULT_OLLAMA_URI) -> bool:
    """Verify if Ollama is running at the given URI."""
    try:
        response = requests.get(f"{uri}/api/tags", timeout=2)
        return response.status_code == 200
    except Exception:
        return False
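
A guard-style sketch combining the two utilities, assuming the standard local Ollama URI (the actual DEFAULT_OLLAMA_URI value is not expanded in this listing).

from mlflow_assistant.providers.utilities import get_ollama_models, verify_ollama_running

uri = "http://localhost:11434"  # assumed value of DEFAULT_OLLAMA_URI
if verify_ollama_running(uri):
    print(get_ollama_models(uri))
else:
    print(f"Ollama is not reachable at {uri}")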