Trustworthy AI in JuliA

%%{
  init: {
    'theme': 'base',
    'themeVariables': {
      'primaryColor': '#BB2528',
      'primaryTextColor': '#fff',
      'primaryBorderColor': '#7C0000',
      'lineColor': '#F8B229',
      'secondaryColor': '#006100',
      'tertiaryColor': '#e9edfb',
      'fontFamily': "avenir"
    }
  }
}%%

flowchart TB

    classDef taija fill:#389836,stroke:#333,color:#fff;
    classDef core fill:#CB3C33,stroke:#333,color:#fff;
    classDef base fill:#9558B2,stroke:#333,color:#fff;

    %% Base
    base["TaijaBase.jl"]

    %% Meta
    interop["TaijaInteroperability.jl"]
    data["TaijaData.jl"]
    parallel["TaijaParallel.jl"]
    plotting["TaijaPlotting.jl"]

    %% Core
    ce["CounterfactualExplanations.jl"]
    cp["ConformalPrediction.jl"]
    lr["LaplaceRedux.jl"]
    jem["JointEnergyModels.jl"]

    %% External
    mlj["MLJ.jl"]
    flux["Flux.jl"]

    class base base;
    class interop,data,parallel,plotting taija;
    class ce,cp,lr,jem core;

    %% Graph
    base --> ce & cp & lr & jem

    subgraph "Core Packages"
        ce & cp & lr & jem 
    end

    subgraph "Meta Packages"
        data & plotting & parallel & interop
    end 

    subgraph "External Packages"
        mlj & flux
    end
Figure 1: Overview of the Taija ecosystem. Early-stage packages omitted.

TaijaParallel.jl

This package adds custom support for parallelization for certain Taija packages.

  • Intuitive user interface.
  • Support for multi-threading and multi-processing.
  • Easy to extend to new parallelization backends and functions.

Why Supercomputing?

Efforts towards trustworthy AI tend to increase the computational burden involved in training or inference:

  • Explainable AI: we are often required to generate many local explanations for many individuals.
  • Conformal Prediction: many techniques involve cross-validation or bootstrapping.
  • (Quasi-)Bayesian deep learning: methods such as deep ensembles require training many models.

Good news: all of these tasks can be parallelized!

Motivating Example

Generate ALL the Counterfactual Explanations!

  • The package was developed to power research presented at AAAI 2024 (Altmeyer et al. 2024).
  • The project involved large benchmarks of counterfactual explanations that had to be run on a supercomputer.

Benchmarking Explanations

Goal: Generate faithful counterfactual explanations that reflect model quality.

  • Final benchmark: total of ~10 million counterfactuals across 8 datasets and different DL models.
  • Parallelized across 50 to 300 CPUs on DelftBlue using a combination of multi-threading and multi-processing.

Counterfactual explanations for different models. Source: Altmeyer et al. (2024)

Source Code

Code is open-sourced and available on GitHub.

The Package

User Interface

We aim to minimize the burden on users.

  • Users will mostly interact with the custom macro @with_parallelizer.
  • In Figure 2, mpi is an instance of type MPIParallelizer.
Figure 2: Generating counterfactuals in parallel using MPI. See docs for details.

High-level Architecture

  • TaijaBase.jl ships basic parallelization functions and symbols to make them available to all Taija packages.
    • CounterfactualExplanations.benchmark, for example, accepts a parallelizer argument of type Union{Nothing,AbstractParallelizer}.
  • TaijaParallel.jl adds out-of-the-box support for parallelization through multi-threading.
  • Multi-processing through MPI.jl is handled through an extension.
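
The pattern described above, where a function accepts a parallelizer argument of type Union{Nothing,AbstractParallelizer}, can be sketched in plain Julia. This is a minimal illustration and not the actual TaijaBase.jl code: the abstract type is redefined locally so the snippet is self-contained, and ToyThreads, run_all and toy_benchmark are hypothetical names invented for this example.

```julia
# Local stand-in for the abstract type shipped by TaijaBase.jl (assumption:
# redefined here only so that the sketch runs on its own).
abstract type AbstractParallelizer end

# Hypothetical multi-threading backend.
struct ToyThreads <: AbstractParallelizer end

# Serial fallback: used when no parallelizer is supplied.
run_all(f, xs, ::Nothing) = map(f, xs)

# Multi-threaded variant that relies only on Base.Threads.
function run_all(f, xs, ::ToyThreads)
    results = Vector{Any}(undef, length(xs))
    Threads.@threads for i in eachindex(xs)
        results[i] = f(xs[i])
    end
    return results
end

# Hypothetical benchmark-like function: the caller chooses the backend,
# and `nothing` means "run serially".
function toy_benchmark(xs; parallelizer::Union{Nothing,AbstractParallelizer}=nothing)
    return run_all(x -> x^2, xs, parallelizer)
end

serial_result = toy_benchmark([1, 2, 3])
threaded_result = toy_benchmark([1, 2, 3]; parallelizer=ToyThreads())
```

Because the backend is selected by the type of the argument, downstream code does not need to change when a different parallelizer is passed in.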

Backend

  • The @with_parallelizer macro defined here parses inputs and calls the parallelize function.
  • parallelize is dispatched on the type of parallelizer and the function to be parallelized.
  • Easy to add support for new parallelization backends and functions by overloading parallelize.
  • Possible to combine different forms of parallelization, e.g., multi-threading and multi-processing (see here for an example).
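
The mechanism above, a macro that parses a call and forwards it to a function dispatched on both the parallelizer type and the function being parallelized, can be sketched as follows. This is an assumption-laden toy and not the package's implementation: @toy_with, toy_parallelize, ToySerial, ToyThreaded and square are all hypothetical names, and the macro handles positional arguments only.

```julia
abstract type ToyParallelizer end
struct ToySerial <: ToyParallelizer end
struct ToyThreaded <: ToyParallelizer end

# Hypothetical function that we want to run in parallel.
square(x) = x^2

# Methods are selected by the parallelizer type AND the function type.
toy_parallelize(::ToySerial, f::typeof(square), xs) = map(f, xs)

function toy_parallelize(::ToyThreaded, f::typeof(square), xs)
    out = similar(xs)
    Threads.@threads for i in eachindex(xs)
        out[i] = f(xs[i])
    end
    return out
end

# Rewrites `@toy_with par f(args...)` into `toy_parallelize(par, f, args...)`.
macro toy_with(parallelizer, call)
    Meta.isexpr(call, :call) || error("expected a function call")
    f = call.args[1]
    args = call.args[2:end]
    return esc(:(toy_parallelize($parallelizer, $f, $(args...))))
end

serial_out = @toy_with ToySerial() square([1, 2, 3])
threaded_out = @toy_with ToyThreaded() square([1, 2, 3])
```

Supporting a new backend or a new function in this toy amounts to adding another toy_parallelize method, which mirrors the idea of extending the package by overloading parallelize.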

Caveats and Future Work

  • The functions to be parallelized must be broadcastable: generate_counterfactual, for example, can be broadcast over a batch of inputs.
  • Currently, the package only supports CounterfactualExplanations.jl.
  • Support for ConformalPrediction.jl is in progress. It hinges on the ability to parallelize cross-validation, which requires changes to ConformalPrediction.jl itself.
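
To illustrate why broadcastability matters, the sketch below splits a batch into chunks and broadcasts a function over each chunk on its own task. Broadcasting over the chunks gives the same result as broadcasting over the full batch, which is what makes this kind of splitting safe. The function toy_explain is a hypothetical stand-in and has nothing to do with the real generate_counterfactual.

```julia
# Hypothetical per-instance function: moves x halfway towards a target.
toy_explain(x::Real, target::Real) = x + 0.5 * (target - x)

xs = [0.0, 1.0, 2.0, 3.0]

# Serial: broadcast over the whole batch at once.
serial = toy_explain.(xs, 4.0)

# Parallel: split the batch into chunks of two and broadcast within each
# chunk on a separate task.
chunks = Iterators.partition(xs, 2)
tasks = map(chunks) do c
    Threads.@spawn toy_explain.(c, 4.0)
end

# Collect the per-chunk results and stitch them back together in order.
parallel = reduce(vcat, fetch.(tasks))
```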

ALL the Counterfactuals!

Trustworthy AI may be slow but …

Julia go vroom vroom!

Generating 10,000 counterfactuals for MNIST in parallel in under 2s on a MacBook.

Questions?

References

Altmeyer, Patrick, Mojtaba Farmanbar, Arie van Deursen, and Cynthia CS Liem. 2024. “Faithful Model Explanations Through Energy-Constrained Conformal Counterfactuals.” In Proceedings of the AAAI Conference on Artificial Intelligence, 38:10829–37. 10.