JuliaCon 2022
CounterfactualExplanations.jl
From human to data-driven decision-making …
… where black boxes are a recipe for disaster.
“You cannot appeal to (algorithms). They do not listen. Nor do they bend.”
— Cathy O’Neil in Weapons of Math Destruction, 2016
Ground Truthing
Probabilistic Models
Counterfactual Reasoning
We typically want to maximize the likelihood of observing \(\mathcal{D}_n\) under given parameters (Murphy 2022):
\[ \theta^* = \arg \max_{\theta} p(\mathcal{D}_n|\theta) \qquad(1)\]
Compute an MLE (or MAP) point estimate \(\hat\theta = \mathbb{E} \theta^*\) and use the plug-in approximation for prediction:
\[ p(y|x,\mathcal{D}_n) \approx p(y|x,\hat\theta) \qquad(2)\]
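To make the plug-in idea concrete, here is a minimal sketch that is not part of the original slides (it assumes Distributions.jl and a toy Bernoulli model):
using Distributions

data = rand(Bernoulli(0.7), 100)                # observed 𝒟ₙ (toy data)
loglik(θ) = sum(logpdf.(Bernoulli(θ), data))    # log p(𝒟ₙ | θ)
θ̂ = argmax(loglik, 0.01:0.01:0.99)              # crude grid-search MLE, cf. Eq. (1)
p_new = pdf(Bernoulli(θ̂), 1)                    # plug-in prediction p(y = 1 | θ̂), cf. Eq. (2)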
Probabilistic Models
[…] deep neural networks are typically very underspecified by the available data, and […] parameters [therefore] correspond to a diverse variety of compelling explanations for the data. (Wilson 2020)
In this setting it is often crucial to treat models probabilistically!
\[ p(y|x,\mathcal{D}_n) = \int p(y|x,\theta)p(\theta|\mathcal{D}_n)d\theta \qquad(3)\]
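In practice this integral is rarely available in closed form; a standard approximation (not spelled out on the slide) is Monte Carlo averaging over posterior draws:
\[ p(y|x,\mathcal{D}_n) \approx \frac{1}{K}\sum_{k=1}^{K} p(y|x,\theta_k) \ \ \ , \ \ \ \theta_k \sim p(\theta|\mathcal{D}_n) \]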
Probabilistic models covered briefly today. More in my other talk on Laplace Redux …
Counterfactual Reasoning
We can now make predictions – great! But do we know how the predictions are actually being made?
With the model trained for its task, we are interested in understanding how its predictions change in response to input changes.
\[ \nabla_x p(y|x,\mathcal{D}_n;\hat\theta) \qquad(4)\]
Important to realize that we are keeping \(\hat\theta\) constant!
Even though […] interpretability is of great importance and should be pursued, explanations can, in principle, be offered without opening the “black box”. (Wachter, Mittelstadt, and Russell 2017)
The objective originally proposed by Wachter, Mittelstadt, and Russell (2017) is as follows:
\[ \min_{x^\prime \in \mathcal{X}} h(x^\prime) \ \ \ \mbox{s. t.} \ \ \ M(x^\prime) = t \qquad(5)\]
where \(h\) relates to the complexity of the counterfactual, \(M\) denotes the classifier and \(t\) the target class.
Typically this is approximated through regularization:
\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda h(x^\prime) \qquad(6)\]
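As a rough illustration of Eq. (6), a bare-bones gradient-descent search could look as follows (not part of the original slides; the toy classifier, Flux, Zygote and all hyperparameters are assumptions, not the package implementation):
using Flux, Zygote

M = Chain(Dense(2, 1))                       # toy classifier with logit output (untrained)
x = Float32[1, 1]                            # factual instance
t = 1f0                                      # target label
λ = 0.1f0                                    # regularization strength

ℓ(x′) = Flux.Losses.logitbinarycrossentropy(M(x′)[1], t)   # loss towards the target
h(x′) = sum(abs, x′ .- x)                                   # complexity: distance to the factual
obj(x′) = ℓ(x′) + λ * h(x′)                                 # Eq. (6)

x′ = copy(x)
for _ in 1:100
    x′ .-= 0.1f0 .* Zygote.gradient(obj, x′)[1]             # plain gradient descent
end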
Are counterfactual explanations just adversarial examples, then? Yes and no!
While both are methodologically very similar, adversarial examples are meant to go undetected while CEs ought to be meaningful.
Effective counterfactuals should meet certain criteria ✅
Do counterfactual explanations establish real-world causality? NO!
Causal inference: counterfactuals are thought of as unobserved states of the world that we would like to observe in order to establish causality.
Counterfactual Explanations: involve perturbing features after some model has been trained.
The number of ostensibly pro data scientists confusing themselves into believing that "counterfactual explanations" capture real-world causality is just staggering🤦♀️. Where do we go from here? How can a community that doesn't even understand what's already known make advances?
— Zachary Lipton (@zacharylipton) June 20, 2022
When people say that counterfactuals should look realistic or plausible, they really mean that counterfactuals should be generated by the same Data Generating Process (DGP) as the factuals:
\[ x^\prime \sim p(x) \]
But how do we estimate \(p(x)\)? Two probabilistic approaches …
Schut et al. (2021) note that by maximizing predictive probabilities \(\sigma(M(x^\prime))\) for probabilistic models \(M\in\mathcal{\widetilde{M}}\) one implicitly minimizes epistemic and aleatoric uncertainty.
\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) \ \ \ , \ \ \ M\in\mathcal{\widetilde{M}} \qquad(7)\]
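Reusing ℓ, x′ and Zygote from the sketch above, a single step of such a greedy search could look like this (the fixed step size δ is an assumption; this is not the package's GreedyGenerator):
δ = 0.1f0
g = Zygote.gradient(ℓ, x′)[1]        # gradient of the loss only -- no complexity penalty
i = argmax(abs.(g))                  # the single most influential feature
x′[i] -= δ * sign(g[i])              # perturb just that feature by a fixed amount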
Instead of perturbing samples directly, some have proposed to instead traverse a lower-dimensional latent embedding learned through a generative model (Joshi et al. 2019).
\[ z^\prime = \arg \min_{z^\prime} \ell(M(dec(z^\prime)),t) + \lambda h(x^\prime) \qquad(8)\]
and
\[ x^\prime = dec(z^\prime) \]
where \(dec(\cdot)\) is the decoder function.
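Here is a minimal latent-space sketch of Eq. (8), with a toy decoder standing in for a fitted VAE (all names, sizes and hyperparameters are assumptions, not the package implementation):
using Flux, Zygote

dec = Chain(Dense(2, 8, relu), Dense(8, 2))   # toy decoder z ↦ x (stand-in for a VAE decoder)
M   = Chain(Dense(2, 1))                      # toy classifier with logit output
x   = Float32[1, 1]                           # factual instance
t, λ = 1f0, 0.1f0

loss(z′) = Flux.Losses.logitbinarycrossentropy(M(dec(z′))[1], t) + λ * sum(abs, dec(z′) .- x)

z′ = randn(Float32, 2)
for _ in 1:200
    z′ .-= 0.1f0 .* Zygote.gradient(loss, z′)[1]   # descend in the latent space
end
x′ = dec(z′)                                       # map the solution back to feature space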
Work currently scattered across different GitHub repositories …
CounterfactualExplanations.jl
📦… until now!
Julia has an edge with respect to Trustworthy AI: it’s open-source, uniquely transparent and interoperable 🔴🟢🟣
Modular, composable, scalable!
Figure 6: Overview of package architecture. Modules are shown in red, structs in green and functions in blue.
Figure 7: Type tree for AbstractGenerator.
Figure 8: Type tree for AbstractFittedModel.
We begin by instantiating the fitted model …
… then based on its prediction for \(x\) we choose the opposite label as our target …
… and finally generate the counterfactual.
… et voilà!
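In code, the workflow looks roughly like this (a hedged sketch: constructor names such as LogisticModel, the counterfactual_data argument and the exact argument order of generate_counterfactual are assumptions and may differ across package versions):
using CounterfactualExplanations
using CounterfactualExplanations.Models: probs

M = LogisticModel(w, b)                 # wrap an already fitted classifier (w, b assumed given)
ŷ = round(probs(M, x)[1])               # current prediction for the factual x
target = ŷ == 1.0 ? 0.0 : 1.0           # choose the opposite label as the target
generator = GenericGenerator()          # Wachter-style generic generator
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)  # counterfactual_data: training data wrapper (assumption)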
Figure: counterfactual search with the GenericGenerator. The contour (left) shows the predicted probabilities of the classifier (Logistic Regression).
This time we use a Bayesian classifier …
… and once again choose our target label as before …
… and finally use greedy search to find a counterfactual.
In this case the Bayesian approach yields a similar outcome.
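Relative to the previous sketch only the model wrapper and the generator change (again, constructor names and arguments are assumptions, not verbatim from the talk):
M = BayesianLogisticModel(μ, Σ)     # posterior mean and covariance of the coefficients (assumed given)
generator = GreedyGenerator()       # greedy search in the spirit of Schut et al. (2021)
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)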
Figure: counterfactual search with the GreedyGenerator. The contour (left) shows the predicted probabilities of the classifier (Bayesian Logistic Regression).
Using the same classifier as before we can either use the dedicated REVISEGenerator …
… or realize that REVISE (Joshi et al. 2019) just boils down to generic search in a latent space:
We have essentially combined latent search with a probabilistic classifier (as in Antorán et al. (2020)).
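In terms of the package this is mostly a choice of generator (default constructors here are an assumption):
generator = REVISEGenerator()       # dedicated REVISE generator ...
generator = GenericGenerator()      # ... or generic search, run in the latent space of the attached generative model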
Figure: counterfactual search with the REVISEGenerator.
Loading pre-trained classifiers and VAE …
… instantiating model and attaching VAE.
The results in Figure 13 look great!
But things can also go wrong …
The VAE used to generate the counterfactual in Figure 14 is not expressive enough.
The counterfactual in Figure 15 is also valid … what to do?
Step 1: add a composite type as a subtype of AbstractFittedModel.
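A minimal version of such a composite type could look like this (the field layout and the import path are assumptions; the original listing is not reproduced here):
import CounterfactualExplanations.Models: AbstractFittedModel

struct FittedEnsemble <: AbstractFittedModel
    ensemble::Vector{Any}   # a collection of independently fitted Flux models
end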
Step 2: dispatch the logits and probs methods for the new model type.
using Flux
using Statistics
import CounterfactualExplanations.Models: logits, probs

# Average the logits and probabilities over all members of the ensemble:
logits(M::FittedEnsemble, X::AbstractArray) = mean(Flux.stack([nn(X) for nn in M.ensemble], 3), dims=3)
probs(M::FittedEnsemble, X::AbstractArray) = mean(Flux.stack([softmax(nn(X)) for nn in M.ensemble], 3), dims=3)

M = FittedEnsemble(ensemble)
Results for a simple deep ensemble also look convincing!
Adding support for torch models was easy! Here's how I implemented it for torch classifiers trained in R.
Step 1: add a composite type as a subtype of AbstractFittedModel. Implemented here.
Step 2: dispatch the logits and probs methods for the new model type. Implemented here.
Step 3: add gradient access.
Implemented here.
Figure: counterfactual search with the GenericGenerator and an RTorchModel.
Idea 💡: let’s implement a generic generator with dropout!
Step 1: create a subtype of AbstractGradientBasedGenerator (adhering to some basic rules).
# Abstract type and concrete generator:
abstract type AbstractDropoutGenerator <: AbstractGradientBasedGenerator end
struct DropoutGenerator <: AbstractDropoutGenerator
    loss::Symbol # loss function
    complexity::Function # complexity function
    mutability::Union{Nothing,Vector{Symbol}} # mutability constraints
    λ::AbstractFloat # strength of penalty
    ϵ::AbstractFloat # step size
    τ::AbstractFloat # tolerance for convergence
    p_dropout::AbstractFloat # dropout rate
end
Step 2: implement logic for generating perturbations.
import CounterfactualExplanations.Generators: generate_perturbations, ∇
using StatsBase
function generate_perturbations(generator::AbstractDropoutGenerator, counterfactual_state::State)
    𝐠ₜ = ∇(generator, counterfactual_state.M, counterfactual_state) # gradient
    # Dropout: randomly zero out a share of the gradient entries
    set_to_zero = sample(1:length(𝐠ₜ), Int(round(generator.p_dropout * length(𝐠ₜ))), replace=false)
    𝐠ₜ[set_to_zero] .= 0
    Δx′ = -(generator.ϵ .* 𝐠ₜ) # gradient step
    return Δx′
end
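A hypothetical instantiation, matching the field order of the struct above (all values here are assumptions):
using LinearAlgebra

generator = DropoutGenerator(
    :logitbinarycrossentropy,   # loss
    x′ -> norm(x′, 1),          # complexity function
    nothing,                    # no mutability constraints
    0.1,                        # λ: strength of penalty
    0.1,                        # ϵ: step size
    1e-5,                       # τ: tolerance for convergence
    0.5,                        # p_dropout: drop 50% of the gradient entries
)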
Figure: counterfactual search with the DropoutGenerator and an RTorchModel.
Develop package, register and submit to JuliaCon 2022.
Native support for deep learning models (Flux, torch).
Add latent space search.
MLJ, GLM, …
Flux optimizers.
Explaining Black-Box Models through Counterfactuals – JuliaCon 2022 – Patrick Altmeyer