The MMD-Critic Method, Explained

A powerful but under-the-radar method for data summarization and explainable AI

by Matthew Chak, August 2024

Despite being a powerful tool for data summarization, the MMD-Critic method has a surprising lack of both usage and "coverage". Perhaps this is because simpler and more established methods for data summarization exist (e.g. K-medoids, see [1] or, more simply, the Wikipedia page), or perhaps it is because no Python package for the method existed (until now). Regardless, the results presented in the original paper [2] warrant more use than MMD-Critic currently sees. As such, I'll explain the MMD-Critic method here with as much clarity as possible. I've also published an open-source Python package with an implementation of the method so you can use it easily.

Before jumping into the MMD-Critic method itself, it's worth discussing what exactly we're trying to accomplish. Ultimately, we wish to take a dataset and find examples that are representative of the data (prototypes), as well as edge-case examples that may confound our machine learning models (criticisms).

Prototypes and criticisms for the MNIST dataset, taken from [2].

There are many reasons why this may be useful:

  • We can get a nicely summarized view of our dataset by seeing both stereotypical and atypical examples
  • We can test models on the criticisms to see how they handle edge cases (this is, for obvious reasons, important)
  • Though perhaps not as useful, we can use prototypes to create a naturally explainable K-means-esque algorithm in which the closest prototype to a new data point is used to label it. Explanations are then simple: we just show the user the most similar data point
  • And more

You can see section 6.3 in this book for more info on the applications of this (and for a decent explanation of MMD-Critic as well), but it suffices to say that finding these examples is useful for a wide variety of reasons. MMD-Critic allows us to do this.

I unfortunately can't claim to have a hyper-rigorous understanding of Maximal Mean Discrepancy (MMD), as such an understanding would require a strong background in functional analysis. If you have such a background, you can find the paper that introduced the measure here.

In simple terms though, MMD is a way to quantify the difference between two probability distributions. Formally, for two probability distributions P and Q, we define the MMD of the two as

$$\mathrm{MMD}(\mathcal{F}, P, Q) = \sup_{f \in \mathcal{F}} \left( \mathbb{E}_{x \sim P}[f(x)] - \mathbb{E}_{y \sim Q}[f(y)] \right)$$

The formula for the MMD of two distributions P, Q

Here, F is any function space, that is, any set of functions with the same domain and codomain. Note also that the notation x ~ P indicates that we are treating x as a random variable drawn from the distribution P, that is, x is described by P. This formula thus finds the greatest difference between the expected values of f(x) under P and under Q, over all functions f in our space F.

This can be a little hard to wrap your head around, so here's an example. Suppose that X is Uniform(0, 1) (i.e. a distribution equivalent to picking a random number from 0 to 1), and Y is Uniform(-1, 1). Let's also let F be a fairly simple family containing just three functions: f(x) = 0, f(x) = x, and f(x) = x². Iterating over each function in our family, we get:

  1. In the f(x) = 0 case, E[f(x)] when x ~ P is 0, since no matter which x we pick, f(x) will be 0. The same holds when x ~ Q. Thus, we get a mean discrepancy of 0
  2. In the f(x) = x case, we have E[f(x)] = 0.5 in the P case and 0 in the Q case, so our mean discrepancy is 0.5
  3. In the f(x) = x² case, we observe that
$$\mathbb{E}_{x \sim P}[f(x)] = \int f(x) \, p(x) \, dx$$

Formula for the expected value of a random variable x transformed by a function f (here p is the density of P)

thus in the P case, we get

$$\mathbb{E}_{x \sim P}[x^2] = \int_0^1 x^2 \, dx = \frac{1}{3}$$

Expected value of f(x) under the distribution P

and in the Q case, we get

$$\mathbb{E}_{x \sim Q}[x^2] = \int_{-1}^{1} \frac{x^2}{2} \, dx = \frac{1}{3}$$

Expected value of f(x) under the distribution Q

so our discrepancy in this case is also 0. The supremum over our function space is thus 0.5, so that's our MMD.
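If you want to sanity-check these three numbers numerically, a quick Monte Carlo estimate agrees. This is my own snippet (the sample size is arbitrary), not part of anything that follows:

import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0, 1, 100_000)   # samples from P = Uniform(0, 1)
y = rng.uniform(-1, 1, 100_000)  # samples from Q = Uniform(-1, 1)

# Estimate E[f(x)] - E[f(y)] for each function in our small family F
for name, f in [("f(x) = 0", lambda v: np.zeros_like(v)),
                ("f(x) = x", lambda v: v),
                ("f(x) = x^2", lambda v: v ** 2)]:
    print(name, abs(f(x).mean() - f(y).mean()))  # prints ~0, ~0.5, ~0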

You may now notice a few problems with our MMD. It seems highly dependent on our choice of function space, and it also looks extremely expensive (or even impossible) to compute for a large or infinite function space. Not only that, but it also requires us to know our distributions P and Q, which isn't realistic.

The latter problem is easily solvable, as we can rewrite our MMD metric to use estimates of P and Q based on samples from our dataset:

$$\mathrm{MMD}(\mathcal{F}, X, Y) = \sup_{f \in \mathcal{F}} \left( \frac{1}{|X|} \sum_{x \in X} f(x) - \frac{1}{|Y|} \sum_{y \in Y} f(y) \right)$$

MMD using estimates of P and Q

Here, the x's are our samples from the dataset drawn from P, and the y's are the samples drawn from Q.

The first two problems are solvable with a bit of extra math. Without going into too much detail, it turns out that if F is something called a Reproducing Kernel Hilbert Space (RKHS), we know in advance which function is going to give us our MMD. Namely, it's the following function, called the witness function:

$$f(x) = \mathbb{E}_{x' \sim P}[k(x, x')] - \mathbb{E}_{x' \sim Q}[k(x, x')]$$

Our optimal f(x) in an RKHS

where k is the kernel (inner product) associated with the RKHS¹. Intuitively, this function "witnesses" the discrepancy between P and Q at the point x.

We thus only need to choose a sufficiently expressive RKHS/kernel. Usually, the RBF kernel is used, which has the kernel function

$$k(x, x') = \exp\left( -\frac{\lVert x - x' \rVert^2}{2\sigma^2} \right)$$

The RBF kernel, where sigma is a hyperparameter

This generally yields fairly intuitive results. Here, for instance, is a plot of the witness function with the RBF kernel, estimated in the same way as mentioned before (that is, replacing expectations with sums) on two datasets drawn from Uniform(-0.5, 0.5) and Uniform(-1, 1):

Values of the witness function at different points for two uniform distributions

The code for generating the above graph is here:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def rbf(v1, v2, sigma=0.5):
    return np.exp(-(v2 - v1) ** 2 / (2 * sigma ** 2))

def comp_wit_fn(x, d1, d2):
    # Estimated witness function: mean kernel similarity to d1 minus to d2
    return 1 / len(d1) * sum([rbf(x, dp) for dp in d1]) - 1 / len(d2) * sum([rbf(x, dp) for dp in d2])

low1, high1 = -0.5, 0.5  # Range for the first uniform distribution
low2, high2 = -1, 1      # Range for the second uniform distribution

# Generate data for the uniform distributions
data1 = np.random.uniform(low1, high1, 10000)
data2 = np.random.uniform(low2, high2, 10000)

# Generate a range of x values at which to compute comp_wit_fn
x_values = np.linspace(min(low1 * 2, low2 * 2), max(high1 * 2, high2 * 2), 100)

comp_wit_values = [comp_wit_fn(x, data1, data2) for x in x_values]
sns.kdeplot(data1, label=f'Uniform({low1}, {high1})', color='blue', fill=True)
sns.kdeplot(data2, label=f'Uniform({low2}, {high2})', color='red', fill=True)
plt.plot(x_values, comp_wit_values, label='Witness Function', color='green')

plt.xlabel('Value')
plt.ylabel('Density / Witness Fn')
plt.legend()
plt.show()

The idea behind MMD-Critic is now fairly simple: if we want to find k prototypes, we need to find the set of prototypes that best matches the distribution of the original dataset, as measured by the squared MMD. In other words, we wish to find a subset P of our dataset, of cardinality k, that minimizes MMD²(F, X, P). Without going into too much detail about why, the squared MMD is given by

$$\mathrm{MMD}^2(\mathcal{F}, X, Y) = \frac{1}{|X|^2} \sum_{x, x' \in X} k(x, x') - \frac{2}{|X| |Y|} \sum_{x \in X, \, y \in Y} k(x, y) + \frac{1}{|Y|^2} \sum_{y, y' \in Y} k(y, y')$$

The squared MMD metric, with X ~ P, Y ~ Q, and k the kernel for our RKHS F

After finding these prototypes, we then select as criticisms the points where the hypothetical distribution of our prototypes differs most from our dataset distribution. As we've seen before, the difference between two distributions at a point can be measured by our witness function, so we simply find the points that maximize its absolute value in the context of X and P. In other words, we define our criticism "score" as

$$S(c) = \left| \mathbb{E}_{x \sim X}[k(x, c)] - \mathbb{E}_{p \sim P}[k(p, c)] \right|$$

The "score" for a criticism c

Or, in the more usable approximate form,

$$S(c) = \left| \frac{1}{|X|} \sum_{x \in X} k(x, c) - \frac{1}{|P|} \sum_{p \in P} k(p, c) \right|$$

The approximated S(c) for a criticism c

Then, to find our desired number of criticisms, say m of them, we simply need to find the set C of size m that maximizes

$$\sum_{c \in C} S(c)$$

To promote picking more varied criticisms, the paper also suggests adding a regularizer term that encourages the chosen criticisms to be as far apart as possible. The suggested regularizer in the paper is the log determinant regularizer, though this is not required. I won't go into much detail here since it isn't critical, but the paper suggests reading [6]².
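For the curious, the idea is simple to state: the log determinant regularizer scores a candidate set of criticisms by the log determinant of the kernel (Gram) matrix restricted to those points, which is larger when the points are dissimilar under the kernel. Here is a minimal sketch of that term (log_det_regularizer is my own hypothetical helper, with K the precomputed Gram matrix of the dataset; the package handles this internally):

import numpy as np

def log_det_regularizer(K, idxs):
    # Log-determinant of the kernel submatrix over the chosen criticisms.
    # Larger values mean the chosen points are more spread out under the kernel.
    sub = K[np.ix_(idxs, idxs)]
    sign, logdet = np.linalg.slogdet(sub)
    return logdet if sign > 0 else -np.inf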

We can thus implement an extremely naive MMD-Critic without criticism regularization as follows (do NOT use this):

import math
import itertools

def euc_distance(p1, p2):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(p1, p2)))

def rbf(v1, v2, sigma=0.5):
    return math.exp(-euc_distance(v1, v2) ** 2 / (2 * sigma ** 2))

def mmd_sq(X, Y, sigma=0.5):
    # Naive O(max(|X|, |Y|)^2) computation of the squared MMD
    sm_xx = 0
    for x in X:
        for x2 in X:
            sm_xx += rbf(x, x2, sigma)

    sm_xy = 0
    for x in X:
        for y in Y:
            sm_xy += rbf(x, y, sigma)

    sm_yy = 0
    for y in Y:
        for y2 in Y:
            sm_yy += rbf(y, y2, sigma)

    return (1 / (len(X) ** 2) * sm_xx
            - 2 / (len(X) * len(Y)) * sm_xy
            + 1 / (len(Y) ** 2) * sm_yy)

def select_protos(X, n, sigma=0.5):
    # Exhaustively check every size-n subset for the one minimizing MMD^2
    min_score, min_sub = math.inf, None
    for subset in itertools.combinations(X, n):
        new_mmd = mmd_sq(X, subset, sigma)
        if new_mmd < min_score:
            min_score = new_mmd
            min_sub = subset
    return min_sub

def criticism_score(criticism, prototypes, X, sigma=0.5):
    # Absolute value of the (estimated) witness function at the candidate point
    return abs(1 / len(X) * sum(rbf(criticism, x, sigma) for x in X)
               - 1 / len(prototypes) * sum(rbf(criticism, p, sigma) for p in prototypes))

def select_criticisms(X, P, n, sigma=0.5):
    # Exhaustively check every size-n subset of non-prototype points
    candidates = [c for c in X if c not in P]
    max_score, crits = -math.inf, []
    for subset in itertools.combinations(candidates, n):
        new_score = sum(criticism_score(c, P, X, sigma) for c in subset)
        if new_score > max_score:
            max_score = new_score
            crits = subset
    return crits

The above implementation is so impractical that, when I ran it, I failed to find 5 prototypes in a dataset of 25 points in a reasonable amount of time. This is because our MMD calculation is O(max(|X|, |Y|)²), and iterating over every length-n subset is O(C(|X|, n)) (where C is the choose function), which gives us a horrendous runtime complexity.

Setting aside more efficient computational methods (e.g. using pure numpy/numexpr/matrix calculations instead of loops) and caching repeated calculations, there are a few optimizations we can make at the theoretical level. First, the most obvious slowdown is looping over all C(|X|, n) subsets in our prototype and criticism methods. Instead, we can use an approximation that loops n times, greedily selecting the best prototype at each step. This allows us to change our prototype selection code to

def select_protos(X, n, sigma=0.5):
    protos = []
    for _ in range(n):
        # Greedily add whichever candidate lowers the squared MMD the most
        min_score, min_proto = math.inf, None
        for cand in X:
            if cand in protos:
                continue
            new_score = mmd_sq(X, protos + [cand], sigma)
            if new_score < min_score:
                min_score = new_score
                min_proto = cand
        protos.append(min_proto)
    return protos

and similarly for the criticisms.
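To make that concrete, here is what the analogous greedy loop for criticisms might look like, reusing criticism_score from the naive implementation above (my sketch, not the package's code):

def select_criticisms_greedy(X, P, n, sigma=0.5):
    crits = []
    for _ in range(n):
        # Greedily add the remaining point with the highest criticism score
        max_score, best = -math.inf, None
        for cand in X:
            if cand in P or cand in crits:
                continue
            new_score = criticism_score(cand, P, X, sigma)
            if new_score > max_score:
                max_score = new_score
                best = cand
        crits.append(best)
    return crits

Note that without a regularizer the scores are independent of one another, so this is equivalent to simply taking the top n scorers; the greedy structure only starts to matter once something like the log determinant regularizer couples the choices.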

There is one other important lemma that makes this problem much more optimizable. It turns out that by turning our prototype selection into a minimization problem and adding a regularization term to the cost, we can compute the cost function very efficiently with matrix operations. I won't go into much detail here, but you can check out the original paper for the specifics.
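To give a flavor of what matrix operations buy us, here is the squared MMD from earlier rewritten with numpy broadcasting instead of Python loops. This is my own sketch (rbf_matrix and mmd_sq_np are hypothetical names; X and Y are float arrays of shape (n, d) and (m, d)):

import numpy as np

def rbf_matrix(A, B, sigma=0.5):
    # Pairwise RBF kernel values between the rows of A and the rows of B
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2 * sigma ** 2))

def mmd_sq_np(X, Y, sigma=0.5):
    # The same quantity as the loop-based mmd_sq above, in three matrix means
    return (rbf_matrix(X, X, sigma).mean()
            - 2 * rbf_matrix(X, Y, sigma).mean()
            + rbf_matrix(Y, Y, sigma).mean())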

Now that we understand the MMD-Critic method, we can finally play with it! You can install it by running

pip install mmd-critic

The implementation in the package itself is much faster than the ones presented here, so don't worry.

We can run a fairly simple example using blobs like so:

from sklearn.datasets import make_blobs
from mmd_critic import MMDCritic
from mmd_critic.kernels import RBFKernel

n_samples = 50   # Total number of samples
centers = 4      # Number of clusters
cluster_std = 1  # Standard deviation of the clusters

X, _ = make_blobs(n_samples=n_samples, centers=centers, cluster_std=cluster_std, n_features=2, random_state=42)
X = X.tolist()

# MMD-Critic with the kernel used for the prototypes being an RBF with sigma=1,
# for the criticisms one with sigma=0.025
critic = MMDCritic(X, RBFKernel(1), RBFKernel(0.025))
protos, _ = critic.select_prototypes(centers)
criticisms, _ = critic.select_criticisms(10, protos)

Plotting the points and criticisms then gets us

Plotting the found prototypes (green) and criticisms (red)
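The plotting code itself isn't shown above; a minimal matplotlib sketch, assuming protos and criticisms come back as lists of 2D points, might look like:

import numpy as np
import matplotlib.pyplot as plt

X_arr, P_arr, C_arr = np.array(X), np.array(protos), np.array(criticisms)
plt.scatter(X_arr[:, 0], X_arr[:, 1], color='gray', alpha=0.5, label='Data')
plt.scatter(P_arr[:, 0], P_arr[:, 1], color='green', label='Prototypes')
plt.scatter(C_arr[:, 0], C_arr[:, 1], color='red', label='Criticisms')
plt.legend()
plt.show()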

You'll notice that I provided the option to use a separate kernel for prototype and criticism selection. This is because I've found that results for criticisms in particular can be extremely sensitive to the sigma hyperparameter. This is an unfortunate limitation of the MMD-Critic method and of kernel methods in general. Overall, I've found good results using a large sigma for prototypes and a smaller one for criticisms.
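If you want to build intuition for this sensitivity yourself, a cheap experiment is to sweep the criticism kernel's sigma on the blobs data above and watch how the selections change (a sketch; the sigma values are arbitrary):

for sigma in (0.01, 0.025, 0.1, 0.5, 1.0):
    critic = MMDCritic(X, RBFKernel(1), RBFKernel(sigma))
    protos, _ = critic.select_prototypes(centers)
    criticisms, _ = critic.select_criticisms(10, protos)
    print(sigma, criticisms)  # compare which points get flagged at each sigma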

We can also, of course, use a more complicated dataset. Here, for instance, is the method applied to MNIST³:

from sklearn.datasets import fetch_openml
import numpy as np
from mmd_critic import MMDCritic
from mmd_critic.kernels import RBFKernel

# Load MNIST data
mnist = fetch_openml('mnist_784', version=1)
images = (mnist['data'].astype(np.float32)).to_numpy() / 255.0
labels = mnist['target'].astype(np.int64)

critic = MMDCritic(images[:15000], RBFKernel(2.5), RBFKernel(0.025))
protos, _ = critic.select_prototypes(40)
criticisms, _ = critic.select_criticisms(40, protos)

which gets us the following prototypes

Prototypes found by MMD-Critic for MNIST. MNIST is free for commercial use under the GPL-3.0 License.

and criticisms

Criticisms found by the MMD-Critic method
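As before, the rendering code isn't shown; a small helper to draw the selections as digit grids, assuming each selected item is a flat 784-length vector, might look like:

import numpy as np
import matplotlib.pyplot as plt

def show_digit_grid(vectors, rows=5, cols=8):
    # Draw each 784-length vector as a 28x28 grayscale digit in a grid
    _, axes = plt.subplots(rows, cols, figsize=(cols, rows))
    for ax, vec in zip(axes.flat, vectors):
        ax.imshow(np.array(vec).reshape(28, 28), cmap='gray')
        ax.axis('off')
    plt.show()

show_digit_grid(protos)      # the 40 prototypes
show_digit_grid(criticisms)  # the 40 criticisms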

Pretty neat, huh?

And that's about it for the MMD-Critic method! It's quite simple at its core, and it's nice to use, save for having to fiddle with the sigma hyperparameter. I hope that the newly released Python package gives it more use.

Please contact [email protected] for any inquiries. All images by the author unless stated otherwise.