engine

mlflow_assistant.engine

MLflow Assistant Engine - Provides workflow functionality for processing user queries.

definitions

Constants for the MLflow Assistant engine.

processor

Query processor that leverages the workflow engine for processing user queries and generating responses using an AI provider.

process_query(query, provider_config, verbose=False) async

Process a query through the MLflow Assistant workflow.

Parameters:

    query (str, required): The query to process
    provider_config (dict[str, Any], required): AI provider configuration
    verbose (bool, default False): Whether to show verbose output

Returns:

    dict[str, Any]: Dict containing the response

Source code in src/mlflow_assistant/engine/processor.py
async def process_query(
    query: str, provider_config: dict[str, Any], verbose: bool = False,
) -> dict[str, Any]:
    """Process a query through the MLflow Assistant workflow.

    Args:
        query: The query to process
        provider_config: AI provider configuration
        verbose: Whether to show verbose output

    Returns:
        Dict containing the response

    """
    import time

    from .workflow import create_workflow

    # Track start time for duration calculation
    start_time = time.time()

    try:
        # Create workflow
        workflow = create_workflow()

        # Run workflow with provider config
        initial_state = {
            STATE_KEY_MESSAGES: [HumanMessage(content=query)],
            STATE_KEY_PROVIDER_CONFIG: provider_config,
        }

        if verbose:
            logger.info(f"Running workflow with query: {query}")
            logger.info(f"Using provider: {provider_config.get(CONFIG_KEY_TYPE)}")
            logger.info(
                f"Using model: {provider_config.get(CONFIG_KEY_MODEL, 'default')}",
            )

        result = await workflow.ainvoke(initial_state)

        # Calculate duration
        duration = time.time() - start_time

        return {
            "original_query": query,
            "response": result.get(STATE_KEY_MESSAGES)[-1],
            "duration": duration,  # Add duration to response
        }

    except Exception as e:
        # Calculate duration even for errors
        duration = time.time() - start_time

        logger.error(f"Error processing query: {e}")

        return {
            "error": str(e),
            "original_query": query,
            "response": f"Error processing query: {e!s}",
            "duration": duration,  # Include duration on the error path too
        }
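
Example: a minimal usage sketch. The provider configuration keys ("type", "model") mirror the CONFIG_KEY_TYPE / CONFIG_KEY_MODEL lookups above, but the values here are placeholders for whatever your AIProvider setup expects.

import asyncio

from mlflow_assistant.engine.processor import process_query

async def main():
    # Placeholder provider config; adjust to your configured AI provider.
    config = {"type": "openai", "model": "gpt-4o"}
    result = await process_query(
        "Which registered models were updated this week?",
        provider_config=config,
        verbose=True,
    )
    response = result["response"]
    # On success this is the final AIMessage; on failure it is an error string.
    print(getattr(response, "content", response))
    print(f"Completed in {result['duration']:.2f}s")

asyncio.run(main())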

tools

LangGraph tools for MLflow interactions.

MLflowTools

Collection of helper utilities for MLflow interactions.

format_timestamp(timestamp_ms) staticmethod

Convert a millisecond timestamp to a human-readable string.

Source code in src/mlflow_assistant/engine/tools.py
@staticmethod
def format_timestamp(timestamp_ms: int) -> str:
    """Convert a millisecond timestamp to a human-readable string."""
    if not timestamp_ms:
        return NA
    dt = datetime.fromtimestamp(timestamp_ms / 1000.0)
    return dt.strftime(TIME_FORMAT)
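
For example:

from mlflow_assistant.engine.tools import MLflowTools

# 1,700,000,000,000 ms is 2023-11-14 22:13:20 UTC; the exact string depends
# on the TIME_FORMAT constant and the machine's local timezone.
print(MLflowTools.format_timestamp(1700000000000))
# Falsy inputs (0, None) return the NA constant instead of raising.
print(MLflowTools.format_timestamp(0))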

get_model_details(model_name)

Get detailed information about a specific registered model.

Parameters:

    model_name (str, required): The name of the registered model

Returns:

    str: A JSON string containing detailed information about the model.

Source code in src/mlflow_assistant/engine/tools.py
@tool
def get_model_details(model_name: str) -> str:
    """Get detailed information about a specific registered model.

    Args:
        model_name: The name of the registered model

    Returns:
        A JSON string containing detailed information about the model.

    """
    logger.debug(f"Fetching details for model: {model_name}")

    try:
        # Get the registered model
        model = client.get_registered_model(model_name)

        model_info = {
            "name": model.name,
            "creation_timestamp": MLflowTools.format_timestamp(
                model.creation_timestamp,
            ),
            "last_updated_timestamp": MLflowTools.format_timestamp(
                model.last_updated_timestamp,
            ),
            "description": model.description or "",
            "tags": {tag.key: tag.value for tag in model.tags}
            if hasattr(model, "tags")
            else {},
            "versions": [],
        }

        # Get all versions for this model
        versions = client.search_model_versions(f"name='{model_name}'")

        for version in versions:
            version_info = {
                "version": version.version,
                "status": version.status,
                "stage": version.current_stage,
                "creation_timestamp": MLflowTools.format_timestamp(
                    version.creation_timestamp,
                ),
                "source": version.source,
                "run_id": version.run_id,
            }

            # Get additional information about the run if available
            if version.run_id:
                try:
                    run = client.get_run(version.run_id)
                    # Extract only essential run information to avoid serialization issues
                    run_metrics = {}
                    for k, v in run.data.metrics.items():
                        try:
                            run_metrics[k] = float(v)
                        except (TypeError, ValueError):
                            run_metrics[k] = str(v)

                    version_info["run"] = {
                        "status": run.info.status,
                        "start_time": MLflowTools.format_timestamp(
                            run.info.start_time,
                        ),
                        "end_time": MLflowTools.format_timestamp(run.info.end_time)
                        if run.info.end_time
                        else None,
                        "metrics": run_metrics,
                    }
                except Exception as e:
                    logger.warning(
                        f"Error getting run details for {version.run_id}: {e!s}",
                    )
                    version_info["run"] = "Error retrieving run details"

            model_info["versions"].append(version_info)

        return json.dumps(model_info, indent=2)

    except Exception as e:
        error_msg = f"Error getting model details: {e!s}"
        logger.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})
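
Because get_model_details is a LangChain @tool, it is invoked with a dict of arguments rather than called directly. A short sketch (the model name is a placeholder):

import json

from mlflow_assistant.engine.tools import get_model_details

# "churn-classifier" is hypothetical; use a model from your own registry.
raw = get_model_details.invoke({"model_name": "churn-classifier"})
details = json.loads(raw)
for version in details.get("versions", []):
    print(version["version"], version["stage"], version["run_id"])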

get_system_info()

Get information about the MLflow tracking server and system.

Returns:

    str: A JSON string containing system information.

Source code in src/mlflow_assistant/engine/tools.py
@tool
def get_system_info() -> str:
    """Get information about the MLflow tracking server and system.

    Returns:
        A JSON string containing system information.

    """
    logger.debug("Getting MLflow system information")

    try:
        info = {
            "mlflow_version": mlflow.__version__,
            "tracking_uri": mlflow.get_tracking_uri(),
            "registry_uri": mlflow.get_registry_uri(),
            "artifact_uri": mlflow.get_artifact_uri(),
            "python_version": sys.version,
            "server_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

        # Get experiment count
        try:
            experiments = client.search_experiments()
            info["experiment_count"] = len(experiments)
        except Exception as e:
            logger.warning(f"Error getting experiment count: {e!s}")
            info["experiment_count"] = "Error retrieving count"

        # Get model count
        try:
            models = client.search_registered_models()
            info["model_count"] = len(models)
        except Exception as e:
            logger.warning(f"Error getting model count: {e!s}")
            info["model_count"] = "Error retrieving count"

        # Get active run count
        try:
            active_runs = 0
            for exp in experiments:
                runs = client.search_runs(
                    experiment_ids=[exp.experiment_id],
                    filter_string="attributes.status = 'RUNNING'",
                    max_results=1000,
                )
                active_runs += len(runs)

            info["active_runs"] = active_runs
        except Exception as e:
            logger.warning(f"Error getting active run count: {e!s}")
            info["active_runs"] = "Error retrieving count"

        return json.dumps(info, indent=2)

    except Exception as e:
        error_msg = f"Error getting system info: {e!s}"
        logger.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})
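
The tool takes no arguments, so it can be invoked with an empty dict:

import json

from mlflow_assistant.engine.tools import get_system_info

info = json.loads(get_system_info.invoke({}))
print(info["mlflow_version"], info["tracking_uri"])
print("experiments:", info["experiment_count"], "models:", info["model_count"])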

list_experiments(name_contains='', max_results=MLFLOW_MAX_RESULTS)

List all experiments in the MLflow tracking server, with optional filtering.

Parameters:

    name_contains (str, default ''): Optional filter to only include experiments whose names contain this string
    max_results (int, default MLFLOW_MAX_RESULTS): Maximum number of results to return (default: 100)

Returns:

    str: A JSON string containing all experiments matching the criteria.

Source code in src/mlflow_assistant/engine/tools.py
@tool
def list_experiments(
    name_contains: str = "", max_results: int = MLFLOW_MAX_RESULTS,
) -> str:
    """List all experiments in the MLflow tracking server, with optional filtering.

    Args:
        name_contains: Optional filter to only include experiments whose names contain this string
        max_results: Maximum number of results to return (default: 100)

    Returns:
        A JSON string containing all experiments matching the criteria.

    """
    logger.debug(f"Fetching experiments (filter: '{name_contains}', max: {max_results})")

    try:
        # Get all experiments
        experiments = client.search_experiments()

        # Filter by name if specified
        if name_contains:
            experiments = [
                exp for exp in experiments if name_contains.lower() in exp.name.lower()
            ]

        # Limit to max_results
        experiments = experiments[:max_results]

        # Create a list to hold experiment information
        experiments_info = []

        # Extract relevant information for each experiment
        for exp in experiments:
            exp_info = {
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "artifact_location": exp.artifact_location,
                "lifecycle_stage": exp.lifecycle_stage,
                "creation_time": MLflowTools.format_timestamp(exp.creation_time)
                if hasattr(exp, "creation_time")
                else None,
                "tags": {tag.key: tag.value for tag in exp.tags}
                if hasattr(exp, "tags")
                else {},
            }

            # Get the run count for this experiment (bounded at 1000 runs)
            try:
                runs = client.search_runs(
                    experiment_ids=[exp.experiment_id], max_results=1000,
                )
                exp_info["run_count"] = len(runs)
            except Exception as e:
                logger.warning(
                    f"Error getting run count for experiment {exp.experiment_id}: {e!s}",
                )
                exp_info["run_count"] = "Error getting count"

            experiments_info.append(exp_info)

        result = {
            "total_experiments": len(experiments_info),
            "experiments": experiments_info,
        }

        return json.dumps(result, indent=2)

    except Exception as e:
        error_msg = f"Error listing experiments: {e!s}"
        logger.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})
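
A usage sketch; the filter string and limit are illustrative:

import json

from mlflow_assistant.engine.tools import list_experiments

result = json.loads(
    list_experiments.invoke({"name_contains": "prod", "max_results": 10}),
)
print(f"{result['total_experiments']} experiment(s) matched")
for exp in result["experiments"]:
    print(exp["experiment_id"], exp["name"], exp["run_count"])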

list_models(name_contains='', max_results=MLFLOW_MAX_RESULTS)

List all registered models in the MLflow model registry, with optional filtering.

Parameters:

    name_contains (str, default ''): Optional filter to only include models whose names contain this string
    max_results (int, default MLFLOW_MAX_RESULTS): Maximum number of results to return (default: 100)

Returns:

    str: A JSON string containing all registered models matching the criteria.

Source code in src/mlflow_assistant/engine/tools.py
@tool
def list_models(name_contains: str = "", max_results: int = MLFLOW_MAX_RESULTS) -> str:
    """List all registered models in the MLflow model registry, with optional filtering.

    Args:
        name_contains: Optional filter to only include models whose names contain this string
        max_results: Maximum number of results to return (default: 100)

    Returns:
        A JSON string containing all registered models matching the criteria.

    """
    logger.debug(
        f"Fetching registered models (filter: '{name_contains}', max: {max_results})",
    )

    try:
        # Get all registered models
        registered_models = client.search_registered_models(max_results=max_results)

        # Filter by name if specified
        if name_contains:
            registered_models = [
                model
                for model in registered_models
                if name_contains.lower() in model.name.lower()
            ]

        # Create a list to hold model information
        models_info = []

        # Extract relevant information for each model
        for model in registered_models:
            model_info = {
                "name": model.name,
                "creation_timestamp": MLflowTools.format_timestamp(
                    model.creation_timestamp,
                ),
                "last_updated_timestamp": MLflowTools.format_timestamp(
                    model.last_updated_timestamp,
                ),
                "description": model.description or "",
                "tags": {tag.key: tag.value for tag in model.tags}
                if hasattr(model, "tags")
                else {},
                "latest_versions": [],
            }

            # Add the latest versions if available
            if model.latest_versions:
                for version in model.latest_versions:
                    version_info = {
                        "version": version.version,
                        "status": version.status,
                        "stage": version.current_stage,
                        "creation_timestamp": MLflowTools.format_timestamp(
                            version.creation_timestamp,
                        ),
                        "run_id": version.run_id,
                    }
                    model_info["latest_versions"].append(version_info)

            models_info.append(model_info)

        result = {"total_models": len(models_info), "models": models_info}

        return json.dumps(result, indent=2)

    except Exception as e:
        error_msg = f"Error listing models: {e!s}"
        logger.error(error_msg, exc_info=True)
        return json.dumps({"error": error_msg})
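
Invocation mirrors list_experiments; the filter is a placeholder:

import json

from mlflow_assistant.engine.tools import list_models

result = json.loads(list_models.invoke({"name_contains": "classifier"}))
print(f"{result['total_models']} model(s) matched")
for model in result["models"]:
    print(model["name"], f"({len(model['latest_versions'])} latest version(s))")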

workflow

Core LangGraph-based workflow engine for processing user queries and generating responses using an AI provider.

This workflow supports tool-augmented generation: tool calls are detected and executed in a loop until a final AI response is produced.

State

Bases: TypedDict

State schema for the workflow engine.

create_workflow()

Create and return a compiled LangGraph workflow.

Source code in src/mlflow_assistant/engine/workflow.py
def create_workflow():
    """Create and return a compiled LangGraph workflow."""
    graph_builder = StateGraph(State)

    def call_model(state: State) -> State:
        """Call the AI model and return updated state with response."""
        messages = state[STATE_KEY_MESSAGES]
        provider_config = state.get(STATE_KEY_PROVIDER_CONFIG, {})
        try:
            provider = AIProvider.create(provider_config)
            model = provider.langchain_model().bind_tools(tools)
            response = model.invoke(messages)
            return {**state, STATE_KEY_MESSAGES: [response]}
        except Exception as e:
            logger.error(f"Error generating response: {e}", exc_info=True)
            return {**state, STATE_KEY_MESSAGES: messages}

    # Add nodes
    graph_builder.add_node("tools", ToolNode(tools))
    graph_builder.add_node("model", call_model)

    # Define graph transitions
    graph_builder.add_edge("tools", "model")
    graph_builder.add_conditional_edges("model", tools_condition)
    graph_builder.set_entry_point("model")

    return graph_builder.compile()
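
Driving the compiled graph directly mirrors what process_query does for you. A sketch, assuming the state-key constants are importable from the definitions module described above (the provider config is a placeholder):

import asyncio

from langchain_core.messages import HumanMessage

from mlflow_assistant.engine.definitions import (
    STATE_KEY_MESSAGES,
    STATE_KEY_PROVIDER_CONFIG,
)
from mlflow_assistant.engine.workflow import create_workflow

async def main():
    workflow = create_workflow()
    state = {
        STATE_KEY_MESSAGES: [HumanMessage(content="List my experiments")],
        # Placeholder provider config; adjust to your AIProvider setup.
        STATE_KEY_PROVIDER_CONFIG: {"type": "openai", "model": "gpt-4o"},
    }
    final_state = await workflow.ainvoke(state)
    # The last message is the model's final answer after any tool-call loops.
    print(final_state[STATE_KEY_MESSAGES][-1].content)

asyncio.run(main())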