How Tiny Neural Networks Represent Basic Functions | by Amir Taubenfeld | Sep, 2024

A gentle introduction to mechanistic interpretability through simple algorithmic examples

This article shows how small Artificial Neural Networks (NNs) can represent basic functions. The goal is to provide fundamental intuition about how NNs work and to serve as a gentle introduction to Mechanistic Interpretability, a field that seeks to reverse engineer NNs.

I present three examples of elementary functions, describe each using a simple algorithm, and show how the algorithm can be "coded" into the weights of a neural network. Then, I explore whether the network can learn the algorithm using backpropagation. I encourage readers to treat each example as a riddle and take a minute before reading the solution.

This article attempts to break NNs into discrete operations and describe them as algorithms. An alternative approach, perhaps more common and natural, is to look at the continuous topological interpretations of the linear transformations in the different layers.

The following are some great resources for strengthening your topological intuition:

In all the following examples, I use the term "neuron" for a single node in the NN computation graph. Each neuron is used only once (no cycles; e.g., not an RNN), and it performs three operations in the following order:

  1. Inner product with the input vector.
  2. Adding a bias term.
  3. Applying a (non-linear) activation function.

I provide only minimal code snippets so that reading stays fluent. The accompanying Colab notebook includes the entire code.
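The snippets below assume the following standard imports (listed here so the snippets are self-contained; the Colab notebook has the full setup):

import random

import torch
from torch import nn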

How many neurons are required to learn the function "x < 10"? Write an NN that returns 1 when the input is smaller than 10 and 0 otherwise.

Solution

Let's start by creating a sample dataset that follows the pattern we want to learn

X = [[i] for i in range(-20, 40)]
Y = [1 if z[0] < 10 else 0 for z in X]
Creating and visualizing the training data for the "<" operator

This classification task can be solved using logistic regression with a Sigmoid as the output activation. Using a single neuron, we can write the function as Sigmoid(ax + b). b, the bias term, can be thought of as the neuron's threshold. Intuitively, we can set b = 10 and a = -1 and get F = Sigmoid(10 - x)

Let's implement and run F using PyTorch

model = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())
d = model.state_dict()
d["0.weight"] = torch.tensor([[-1]]).float()
d["0.bias"] = torch.tensor([10]).float()
model.load_state_dict(d)

x = torch.tensor(X).float()
y_pred = model(x).detach().reshape(-1)
Sigmoid(10-x)

This looks like the right pattern, but can we make a tighter approximation? For example, F(9.5) = 0.62; we would prefer it to be closer to 1.

For the Sigmoid function, as the input approaches -∞ / ∞, the output approaches 0 / 1 respectively. Therefore, we need to make our 10 - x function return large numbers, which can be achieved by multiplying it by a larger number, say 100, to get F = Sigmoid(100(10 - x)); now we get F(9.5) ≈ 1.

Sigmoid(100(10-x))
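As a quick sanity check, here is a minimal sketch that reuses the single-neuron model from above, scales its hardwired weight and bias by 100, and evaluates F(9.5) (the scaling factor is the one discussed above; the rest is my own illustration):

# Scale the hardwired weights by 100 to sharpen the transition.
d = model.state_dict()
d["0.weight"] = torch.tensor([[-100]]).float()  # a = -100
d["0.bias"] = torch.tensor([1000]).float()      # b = 100 * 10
model.load_state_dict(d)

print(model(torch.tensor([[9.5]])).item())  # ≈ 1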

Indeed, when training a network with one neuron, it converges to F = Sigmoid(M(10 - x)), where M is a scalar that keeps growing during training to make the approximation tighter.

TensorBoard graph: the X-axis represents the number of training epochs, and the Y-axis represents the value of the bias and the weight of the network. The bias and the weight increase/decrease in opposite proportion. That is, the network can be written as M(10 - x), where M is a parameter that keeps growing during training.

To clarify, our single-neuron model is only an approximation of the "< 10" function. We will never be able to reach a loss of zero, because the neuron is a continuous function while "< 10" is not.
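For completeness, here is a minimal sketch of the kind of training loop described above (binary cross-entropy loss and plain SGD are my assumptions; the exact setup is in the Colab notebook):

model = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())
x = torch.tensor(X).float()
y = torch.tensor(Y).float().reshape(-1, 1)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.BCELoss()

for epoch in range(10_000):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

# Per the TensorBoard graph above, the weight and bias grow in magnitude with a ratio close to -1 : 10.
print(model.state_dict())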

Write a neural network that takes two numbers and returns the minimum of the two.

Solution

Like before, let's start by creating a test dataset and visualizing it

X_2D = [
    [random.randrange(-50, 50),
     random.randrange(-50, 50)]
    for i in range(1000)
]
Y = [min(a, b) for a, b in X_2D]
Visualizing the training data for min(a, b). The two horizontal axes represent the coordinates of the input. The vertical axis labeled "Ground Truth" is the expected output, i.e., the minimum of the two input coordinates

In this case, the ReLU activation is a good candidate because it is essentially a maximum function (ReLU(x) = max(0, x)). Indeed, using ReLU one can write the min function as follows

min(a, b) = 0.5 (a + b - |a - b|) = 0.5 (a + b - ReLU(b - a) - ReLU(a - b))

[Equation 1]
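Equation 1 is easy to verify numerically; the following small check (my own helper functions, not part of the original code) compares it against Python's built-in min:

import random

def relu(x):
    return max(0.0, x)

def min_via_relu(a, b):
    # Equation 1: min(a, b) = 0.5 * (a + b - ReLU(b - a) - ReLU(a - b))
    return 0.5 * (a + b - relu(b - a) - relu(a - b))

for _ in range(1000):
    a, b = random.uniform(-50, 50), random.uniform(-50, 50)
    assert abs(min_via_relu(a, b) - min(a, b)) < 1e-9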

Now let's build a small network that is capable of learning Equation 1, and try to train it using gradient descent

class MinModel(nn.Module):
    def __init__(self):
        super(MinModel, self).__init__()

        # For ReLU(a-b)
        self.fc1 = nn.Linear(2, 1)
        self.relu1 = nn.ReLU()
        # For ReLU(b-a)
        self.fc2 = nn.Linear(2, 1)
        self.relu2 = nn.ReLU()
        # Takes 4 inputs:
        # [a, b, ReLU(a-b), ReLU(b-a)]
        self.output_layer = nn.Linear(4, 1)

    def forward(self, x):
        relu_output1 = self.relu1(self.fc1(x))
        relu_output2 = self.relu2(self.fc2(x))
        return self.output_layer(
            torch.cat(
                (x, relu_output1, relu_output2),
                dim=-1
            )
        )

Visualization of the MinModel computation graph. The drawing was created using the torchview library
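The actual training loop lives in the Colab notebook; a minimal sketch of what it might look like (MSE loss and the Adam optimizer are my assumptions) is:

model = MinModel()
x = torch.tensor(X_2D).float()
y = torch.tensor(Y).float().reshape(-1, 1)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

for epoch in range(300):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()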

Training for 300 epochs is enough to converge. Let's look at the model's parameters

>> for k, v in model.state_dict().items():
>>   print(k, ": ", torch.round(v, decimals=2).numpy())

fc1.weight : [[-0. -0.]]
fc1.bias : [0.]
fc2.weight : [[ 0.71 -0.71]]
fc2.bias : [-0.]
output_layer.weight : [[ 1. 0. 0. -1.41]]
output_layer.bias : [0.]

Many of the weights zero out, and we are left with the nicely looking

model([a, b]) = a - 1.41 * 0.71 * ReLU(a - b) ≈ a - ReLU(a - b)

This is not the solution we expected, but it is a valid solution and even cleaner than Equation 1! By looking at the network, we learned a new nice-looking formula!

Proof:

  • If a ≤ b: model([a, b]) = a - ReLU(a - b) = a - 0 = a
  • If a > b: model([a, b]) = a - ReLU(a - b) = a - (a - b) = b

Create a neural network that takes an integer x as input and returns x mod 2. That is, 0 if x is even and 1 if x is odd.

This one looks quite simple, but surprisingly it is impossible to create a finite-size network that correctly classifies every integer in (-∞, ∞) (using a standard non-periodic activation function such as ReLU).

Theorem: is_even needs at least log neurons

A network with ReLU activations requires at least n neurons to correctly classify each of 2^n consecutive natural numbers as even or odd (i.e., to solve is_even).

Proof: By induction

Base case (n = 2): Intuitively, a single neuron (of the form ReLU(ax + b)) cannot solve S = [i + 1, i + 2, i + 3, i + 4], as it is not linearly separable. For example, without loss of generality, assume a > 0 and that i + 2 is even. If ReLU(a(i + 2) + b) = 0, then also ReLU(a(i + 1) + b) = 0 (a monotonic function), but i + 1 is odd.
More details can be found in the classic Perceptrons book.

Induction step (assume for n, and look at n + 1): Let S = [i + 1, …, i + 2^(n + 1)], and assume, for the sake of contradiction, that S can be solved using a network of size n. Take a neuron f(x) = ReLU(ax + b) from the first layer, where x is the input to the network. WLOG a > 0. Based on the definition of ReLU, there exists a j such that:
S' = [i + 1, …, i + j], S'' = [i + j + 1, …, i + 2^(n + 1)]
f(x) = 0 for every x in S'
f(x) = ax + b for every x in S''

There are two cases to consider:

  • Case |S'| ≥ 2^n: dropping f and all its edges won't change the classification results of the network on S' (f is identically zero there). Hence, there is a network of size n - 1 that solves S', contradicting the induction hypothesis.
  • Case |S''| ≥ 2^n: for every neuron g that takes f as an input, g(x) = ReLU(c f(x) + d + …) = ReLU(c ReLU(ax + b) + d + …). Drop the neuron f and wire x directly into g to get ReLU(cax + cb + d + …), which computes the same value on S''. A network of size n - 1 solves S'', contradicting the induction hypothesis.

Logarithmic Algorithm

How many neurons are sufficient to classify [1, 2^n]? I have proven that n neurons are necessary. Next, I will show that O(n) neurons are also sufficient.

One simple implementation is a network that repeatedly adds/subtracts 2 and checks whether at some point it reaches 0. This would require O(2^n) neurons. A more efficient algorithm is to add/subtract powers of 2, which requires only O(n) neurons. More formally (a quick numerical check of this construction follows the proof below):
f_i(x) := |x - i|
f(x) := f_1 ∘ f_1 ∘ f_2 ∘ f_4 ∘ … ∘ f_(2^(n-1)) (|x|)

Proof:

  • By definition: ∀ x ∈ [0, 2^i]: f_(2^(i-1))(x) ≤ 2^(i-1).
    I.e., each step cuts the interval in half.
  • Recursively, f_1 ∘ f_1 ∘ f_2 ∘ … ∘ f_(2^(n-1)) (|x|) ≤ 1
  • For every even i: is_even(f_i(x)) = is_even(x)
  • Similarly, is_even(f_1(f_1(x))) = is_even(x)
  • We get f(x) ∈ {0, 1} and is_even(x) = is_even(f(x)). QED.
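Before wiring this into a network, the composition itself can be checked in plain Python (my own script, using n = 4, the domain size used in the next section):

# Check f = f_1 ∘ f_1 ∘ f_2 ∘ ... ∘ f_(2^(n-1)) applied to |x|, for n = 4.
def f(x, n=4):
    y = abs(x)
    for power in reversed(range(n)):  # apply f_8, f_4, f_2, f_1 in turn
        y = abs(y - 2 ** power)
    return abs(y - 1)                 # the final, extra f_1

for x in range(2 ** 4):
    assert f(x) in (0, 1)
    assert f(x) == x % 2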

Implementation

Let's try to implement this algorithm using a neural network over a small domain. We start again by defining the data.

X = [[i] for i in range(0, 16)]
Y = [z[0] % 2 for z in X]
is_even data and labels on the small domain [0, 15]

Because the domain contains 2⁴ integers, we need 6 layers: 5 for f_1 ∘ f_1 ∘ f_2 ∘ f_4 ∘ f_8 (each |·| implemented with a pair of ReLU neurons), plus 1 output layer. Let's build the network and hardwire the weights

def create_sequential_model(layers_list=[1, 2, 2, 2, 2, 2, 1]):
    layers = []
    for i in range(1, len(layers_list)):
        layers.append(nn.Linear(layers_list[i - 1], layers_list[i]))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)

model = create_sequential_model()

# This weight matrix implements abs using a pair of ReLU neurons:
# |x - b| = ReLU(-(x - b)) + ReLU(x - b)
abs_weight_matrix = torch.tensor([[-1, -1],
                                  [1, 1]])
# Returns the pair of biases used for each of the ReLUs.
get_relu_bias = lambda b: torch.tensor([b, -b])

d = model.state_dict()
d['0.weight'], d['0.bias'] = torch.tensor([[-1], [1]]), get_relu_bias(8)
d['2.weight'], d['2.bias'] = abs_weight_matrix, get_relu_bias(4)
d['4.weight'], d['4.bias'] = abs_weight_matrix, get_relu_bias(2)
d['6.weight'], d['6.bias'] = abs_weight_matrix, get_relu_bias(1)
d['8.weight'], d['8.bias'] = abs_weight_matrix, get_relu_bias(1)
d['10.weight'], d['10.bias'] = torch.tensor([[1, 1]]), torch.tensor([0])
model.load_state_dict(d)
model.state_dict()

As expected, we can see that this model makes a perfect prediction on [0, 15]

And, as expected, it does not generalize to new data points
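Both observations can be reproduced with the hardwired model above (this check is my own; the article shows the corresponding plots):

# In-domain: the hardwired network reproduces x mod 2 exactly on [0, 15].
x_in = torch.tensor([[i] for i in range(0, 16)]).float()
print(model(x_in).detach().reshape(-1))   # ≈ [0, 1, 0, 1, ...]

# Out of domain: the outputs soon grow beyond {0, 1} and no longer encode parity.
x_out = torch.tensor([[i] for i in range(16, 32)]).float()
print(model(x_out).detach().reshape(-1))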

We saw that we can hardwire the model, but would the model converge to the same solution using gradient descent?

The answer is: not so easily! Instead, it gets stuck at a local minimum, predicting the mean.

This is a known phenomenon, where gradient descent gets stuck at a local minimum. It is especially prevalent for the non-smooth error surfaces of highly nonlinear functions (such as is_even).

More details are beyond the scope of this article, but to build more intuition one can look at the many works that investigated the classic XOR problem. Even for such a simple problem, we can see that gradient descent can struggle to find a solution. In particular, I recommend Richard Bland's short book "Learning XOR: exploring the space of a classic problem", a rigorous analysis of the error surface of the XOR problem.

Final Words

I hope this article has helped you understand the basic structure of small neural networks. Analyzing Large Language Models is far more complex, but it is an area of research that is advancing rapidly and is full of intriguing challenges.

When working with Large Language Models, it is easy to focus on supplying data and computing power to achieve impressive results without understanding how they operate. However, interpretability offers crucial insights that can help address issues like fairness, inclusivity, and accuracy, which are becoming increasingly vital as we rely more on LLMs in decision-making.

For further exploration, I recommend following the AI Alignment Forum.

*All of the images were created by the author. The intro image was created using ChatGPT and the rest were created using Python libraries.