Skip to content

databricks_provider

mlflow_assistant.providers.databricks_provider

Databricks provider for MLflow Assistant.

DatabricksProvider(model=None, temperature=None, **kwargs)

Bases: AIProvider

Databricks provider implementation.

Initialize the Databricks provider with model.

Source code in src/mlflow_assistant/providers/databricks_provider.py
def __init__(
    self,
    model: str | None = None,
    temperature: float | None = None,
    **kwargs,
):
    """Initialize the Databricks provider with model.

    Args:
        model: Databricks serving endpoint name. Falls back to the
            provider default when None.
        temperature: Sampling temperature. Falls back to the provider
            default when None (an explicit 0.0 is respected).
        **kwargs: Extra model parameters; only keys recognized for the
            Databricks provider, with non-None values, are forwarded to
            ``ChatDatabricks``.
    """
    self.model_name = (
        model or Provider.get_default_model(Provider.DATABRICKS.value)
    )
    # Use an explicit None check so a caller-supplied temperature of 0.0
    # is honored instead of being treated as falsy and replaced by the
    # provider default (bug in the original `temperature or ...` form).
    self.temperature = (
        temperature
        if temperature is not None
        else Provider.get_default_temperature(Provider.DATABRICKS.value)
    )
    self.kwargs = kwargs

    # Warn rather than fail: inside a Databricks runtime the credentials
    # are typically injected automatically.
    for var in DATABRICKS_CREDENTIALS:
        if var not in os.environ:
            logger.warning(
                f"Missing environment variable: {var}. "
                "Responses may fail if you are running outside Databricks.",
            )

    # Build parameters dict with only non-None values. Pass the RESOLVED
    # temperature (the original passed the raw `temperature` argument,
    # which could be None, contradicting this contract).
    model_params = {"endpoint": self.model_name, "temperature": self.temperature}

    # Only add optional parameters if they're not None
    for param in ParameterKeys.get_parameters(Provider.DATABRICKS.value):
        if param in kwargs and kwargs[param] is not None:
            model_params[param] = kwargs[param]

    # Initialize with parameters matching the ChatDatabricks documentation
    self.model = ChatDatabricks(**model_params)

    logger.debug(f"Databricks provider initialized with model {self.model_name}")

langchain_model()

Get the underlying LangChain model.

Source code in src/mlflow_assistant/providers/databricks_provider.py
def langchain_model(self):
    """Return the wrapped ``ChatDatabricks`` instance for direct LangChain use."""
    return self.model