Neural Network Quickstart#

TLDR:

  • All statistical models essentially have three parts:

    • A model that estimates things.

    • A way to measure how “bad” their estimates are.

    • A way to update the model parameters to make better estimates (hopefully).

  • These days, most of this is done in software.

    • Enables both scaling and iteration speeds never before possible.

    • Enables techniques for “guessing” new parameters through tools like autograd.

  • Neural networks are a type of model; at their core, they are a lot of linear regressions added together.

    • They’re put together in a way that can flexibly model many sorts of data.

(Over)Simplifying Neural Networks#

Neural network models are at the core of modern GenAI systems. To really understand what’s going on, you have to grasp the fundamentals of these models and their training routines, and see that they’re not all that complex at their core. This takes the hype out of the crazy claims, but that’s what we’re looking for here: a grounded, first-principles understanding of what’s happening in these systems.

Starting with Linear Regression#

In this notebook, we’re going to implement a basic regression in Flax. By using a simple model, you can focus on everything around it: how the model is represented in Flax, the pieces outside the model such as the optimizer and loss function, and how to iterate quickly without burning too much compute.

I encourage you to run this notebook to get hands-on practice and reinforce your learning.

So before we dive in, here are the three takeaways from this notebook:

  1. The fundamental components of linear regression and neural networks are the same.

  2. Modern LLMs scale due to great code and computational architecture.

  3. It’s the ability to scale to massive models and datasets that is driving the modern LLM revolution, even if the concepts are simple.

Words, Then Code#

I’ll talk through the fundamental concepts in words first. Then we’ll go through the same concepts end-to-end in code, with the actual libraries that are used in production-grade GenAI systems. Quite literally, if you took the code in this notebook and scaled it up, you would have the makings of a ChatGPT/Bard-style model. Now there are some additional complexities, but generally speaking, that’s what’s happening.

Mathematical Models and “Self-Learning” Models#

The purpose of all mathematical models is to output a number. That might be the temperature tomorrow, the number of people visiting a store on the weekend, or a number that represents the next word in this sentence.

Now, you could say, “What’s the big deal? I can produce numbers by just randomly rolling dice,” and you’d be right. A random number generator is a model, but frankly, it’s not one that most folks are impressed by these days.

Linear regression is another model that’s often taught in grade school. Most people aren’t impressed if you draw a line through dots, though.

In the last decade, the class of models typically referred to as “Artificial Intelligence” or “Machine Learning” models are the ones that capture people’s imagination. These are models that can “automagically” predict what you want to watch next on Netflix, play Atari games by themselves, or beat anyone at Go.

The biggest models in the world right now are these chatbots. These models are insanely big and keep getting bigger and more capable.

They may feel like magic, but at their core, they are not much more than basic multiplication and addition. The magic comes from their architecture, that is, how their parameters and mathematical operations are structured, how their predictions are scored, and how those parameters are updated.

Model Architecture and Parameter Estimates#

Here’s an example of a mathematical model.

\[ y = mx + b \]

This model takes \(x\) as input and outputs \(y\). \(m\) and \(b\) could literally be any number; in fact, write down two numbers right now. There are no wrong answers here. What people care about, though, is figuring out which \(m\) and which \(b\) produce a nice reasonable \(y\) when given a reasonable \(x\).

Another thing we care about is having a computer learn this relationship itself, rather than us telling it. To make this a self-learning model:

  • Get many examples of \(x_{observed}\) and \(y_{observed}\).

  • Ask the model to estimate \(y_{estimate}\) after we input a value of \(x_{observed}\).

  • Note the difference between \(y_{estimate}\) and \(y_{observed}\).

  • Update the parameters \(m\) and \(b\) to lessen the difference between \(y_{estimate}\) and \(y_{observed}\).
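
For example (with made-up numbers): suppose the current guesses are \(m = 2\) and \(b = 1\), and we observe \(x_{observed} = 3\) with \(y_{observed} = 10\). The model estimates \(y_{estimate} = 2 \cdot 3 + 1 = 7\), so it is off by \(3\), and the update step nudges \(m\) and \(b\) upward to shrink that gap before the whole cycle repeats.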

Fig. 1 How all AI models work#

This is literally how all fancy AI/ML/DS models, including chatbot models, are trained. They’re very VERY efficient guess-and-check routines.

Now, a model as simple as \( y = mx + b \) can’t predict much by itself, but if you repeat this literally millions of times and structure it into the types of models now referred to as Neural Nets, you can do crazy things like classify dog breeds with superhuman accuracy, or make cars self-driving.

Fig. 2 The core of all AI models#

If you structure them into this thing called a Transformer, a special type of Neural Network, you get the T in ChatGPT.

It’s actually kind of mindboggling how such basic ideas can be scaled (to ridiculous levels) and the end result is Artificial Intelligence.

Reimplementing Everything Above in Code#

So now that you’ve seen it in words, below we’ll do everything again in code. The libraries and code below are the same ones that are used to train massive image generators and chat models that exploded in popularity.

In the code, we do the following:

  1. Generate some toy data to work with

  2. Create a linear regression model in Flax (a production Neural Network library that’s been used to train massive chat models)

  3. Implement the parameter learning loop

  4. Train the model

Then at the bottom, we show how Neural Networks are able to flexibly fit data by adding more parameters.

Simple Linear Regression in Code#

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

Our data generating function#

In this notebook we’re going to generate data to fit against.

We do this because in real life we never know the “true” answers; there might not even be one. By generating the data ourselves we “know” the answers, and we can verify that our model got them right.

We’re going to use a simple linear regression. You’ve probably seen this in grade school before.

\(m\) and \(b\) are the parameters we’re going to need to estimate. \(x\) and \(y\) are what we observed and we’re going to use these to figure out what \(m\) and \(b\) should be.

\[ y = mx + b \]

In this example we’re going to expand this a bit and add one more coefficient.

\[ y = m_0 * x_0 + m_1 * x_1 + b \]

Set Coefficients and Constants#

These are the numbers that are the “ground truth” for our simulation. In this simple case we can verify that our learning process worked correctly by seeing if the model infers these parameters. Again, in real life we never really know these parameters, but we can use parameter recovery like this to debug our models.

m_0, m_1 = 1.1, 2.1
intercept = bias = 3.1
# The shape needs to be (1,2). You'll see why below
coefficients = np.array([[m_0, m_1]])

Generate some observed y data#

So now with \(m_0\), \(m_1\) and \(b\) set we can calculate what \(y\) will be if we pick any value of \(x_0\) and \(x_1\). Let’s pick some easy values first

x_0, x_1 = 0, 1
coefficients[0][0] * x_0 + coefficients[0][1] * x_1 + bias
5.2
x_0, x_1 = 1, 0
coefficients[0][0] * x_0 + coefficients[0][1] * x_1 + bias
4.2

Let’s pick some other values now to see what we get. Note how the values of the coefficients don’t change; only x does.

x_0, x_1 = 1.3, -2.5
coefficients[0][0] * x_0 + coefficients[0][1] * x_1 + bias
-0.7199999999999998

Plot our hyperplane#

Because we have two x variables, we can plot a plane showing the value of y at every combination of \(x_0\) and \(x_1\).

from mpl_toolkits.mplot3d import Axes3D
# Define grid range
x_0_range = np.linspace(-5, 5, 20)
x_1_range = np.linspace(-5, 5, 20)
X_0, X_1 = np.meshgrid(x_0_range, x_1_range)

# Calculate the corresponding Y values

Y = coefficients[0][0] * X_0 + coefficients[0][1] * X_1 + bias

# Create 3D plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Plot plane
ax.plot_surface(X_0, X_1, Y)

# Set axis labels
ax.set_xlabel('X_0')
ax.set_ylabel('X_1')
ax.set_zlabel('Y')

ax.mouse_init()
[Plot: 3D surface of Y over the X_0, X_1 grid]

Generate Random X_0, X_1 points#

In real life we won’t always see numbers on a grid. To simulate this, we generate x values at random places. We’ll then use these x values to calculate our observed y values.

rng = np.random.default_rng(12345)
x_obs = rng.uniform(-10, 10, size=(200,2))
x_obs[:10]
array([[-5.45327955, -3.66483321],
       [ 5.94730915,  3.52509342],
       [-2.17780899, -3.34372144],
       [ 1.96617507, -6.26531629],
       [ 3.45512088,  8.83605731],
       [-5.03508571,  8.97762304],
       [ 3.34474906, -8.08204129],
       [-1.16320668,  7.72959839],
       [ 3.94907   , -3.47054272],
       [ 4.67856327, -5.59730089]])

Side Track: Fancy matrix multiplication#

If we want to do this with our first pair of x values and the coefficients, we could do it manually like this. But this is ugly, error prone, and inefficient.

y_obs = coefficients[0][0] * x_obs[0][0] + coefficients[0][1] * x_obs[0][1] + bias
y_obs
-10.594757237912647

Matrix Multiplication#

So here’s a practical tip: if you’re going to work with Neural Nets, you’re going to check the shapes of things quite often. Let’s do that here.

coefficients.shape, x_obs.shape
((1, 2), (200, 2))

With these two shapes we can do a matrix multiplication. But if you remember from Linear Algebra class, the inner dimensions have to match, so we transpose the coefficients:

(200, 2) @ (2, 1) = (200, 1)

Let’s go ahead and do that in code

# Typically done this way in NN literature so we get a column vector 
y_obs = (x_obs @ coefficients.T) + bias
y_obs[:5]
array([[-10.59475724],
       [ 17.04473623],
       [ -6.31740492],
       [ -7.89437163],
       [ 25.45635331]])

Einsum is really nice#

We can also do the calculation above using einsum. This is an underrated feature of most tensor libraries, so I want to take a moment to show it to you here.

y_obs = np.einsum('ij,kj->ki', coefficients, x_obs) + bias
y_obs[:5]
array([[-10.59475724],
       [ 17.04473623],
       [ -6.31740492],
       [ -7.89437163],
       [ 25.45635331]])
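
If the index notation looks cryptic: 'ij,kj->ki' says to multiply coefficients (indexed i, j) against x_obs (indexed k, j), sum over the shared j axis, and return a result indexed (k, i), i.e. a (200, 1) column vector. As a quick sanity check (a small sketch using the arrays defined above), both approaches agree:

# Both the transpose-and-matmul and the einsum versions give the same result
np.allclose(x_obs @ coefficients.T + bias,
            np.einsum('ij,kj->ki', coefficients, x_obs) + bias)
# -> True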

Let’s add some noise to Y#

In real life we never get perfectly exact measurements, so we should add some noise to our observations.

noise_sd = .5
y_obs_noisy = y_obs + stats.norm(0, noise_sd).rvs(y_obs.shape)
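
As a side note, the same kind of noise could also be drawn with the NumPy Generator we created earlier (a small variation on the cell above, not what the original code does), which keeps all the randomness on one seed:

# Equivalent noise using the existing NumPy Generator instead of scipy.stats
y_obs_noisy = y_obs + rng.normal(0, noise_sd, size=y_obs.shape)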

What we’ve done up to this point#

So far we haven’t done any modeling yet. We’ve just generated some simulated data to work with.

  • We decided our data generating function is a two-coefficient regression with a bias

    • \( y = m_0 * x_0 + m_1 * x_1 + b \)

  • Picked some arbitrary parameters

    • m_0, m_1 = 1.1, 2.1

    • intercept = bias = 3.1

  • Generated some random x_0 and x_1 points and used them to calculate y

    • In the real world we’d have observed these, e.g. weight = m_0*height + m_1*width + b

  • Added some random noise to y

    • The real world is messy

Actually Building a Model in Jax#

Let’s now make a model. The goal here is to figure out the coefficients m_0, m_1, and the bias using our observed x_0, x_1, and y data.

Let’s do that using Jax and gradient descent. Gradient descent is important so let’s explain it here.

Estimating our parameters using gradient descent#

The way this works is basically

  1. Start with some parameters that are probably bad

  2. Make a prediction

  3. Measure how bad that prediction is.

  4. Figure out which way to adjust the parameters so the prediction is less bad

  5. Change the parameters in that direction

  6. Go to Step 2 again

Repeat this until your guesses stop getting less bad.

Implementation in Jax#

The code that does all the above is in this cell. This block contains

  1. A model

  2. A loss function

  3. A gradient computation function

import jax
import jax.numpy as jnp

# Generate sample data
X = jnp.array(x_obs)
Y = jnp.array(y_obs)

# Model definition
def model(params, X):
    weights = params["weights"]
    bias = params["bias"]
    return jnp.dot(X, weights) + bias

# Loss function
def mean_squared_error(params, X, Y):
    predictions = model(params, X)
    return jnp.mean((predictions - Y) ** 2)

# Initialize parameters
params = {'weights': jnp.zeros([2,1]), 'bias': 0.}

# Gradient descent optimizer
@jax.jit
def update(params, X, Y, learning_rate=0.001):
    loss_val, grads = jax.value_and_grad(mean_squared_error)(params, X, Y)
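    # Because of @jax.jit, the print below only runs once, while the function is traced, not on every update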
    print(X.shape)
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return loss_val, params

# Train the model
loss_vals = []
for _ in range(100):
    loss_val, params = update(params, X, Y)
    loss_vals.append(loss_val)

# Print final parameters
print("Learned weights:", params['weights'])
print("Learned bias:", params['bias'])
(200, 2)
Learned weights: [[1.0531881]
 [2.0895174]]
Learned bias: 0.5365327
fig, ax = plt.subplots()
ax.plot(range(100), loss_vals)
ax.set_ylabel("Loss Value")
ax.set_xlabel("Epoch or Step");
[Plot: loss value vs. epoch]

And that’s it. Notice that the weights are already close to the true 1.1 and 2.1, while the bias is still far from 3.1; with plain gradient descent, this learning rate, and only 100 steps, the bias converges much more slowly than the weights. Let’s do it again in Flax.

Neural Nets (in Flax)#

With simple things like linear regression, writing the code by hand is fine. But when writing larger neural networks, doing this by hand becomes tedious and error prone. So in this tutorial we’ll use Flax, which is a neural network library. It’s overkill here, but it will prepare us to read larger models with ease.

import flax.linen as nn
class LinearRegression(nn.Module):
    # Define the neural network architecture
    def setup(self):
        """Single output, I dont need to specify shape"""
        self.dense = nn.Dense(features=1)

    # Define the forward pass
    def __call__(self, x):
        y_pred = self.dense(x)
        return y_pred
model = LinearRegression()
key = jax.random.PRNGKey(0) 

Initialize Random Parameters#

params = model.init(key, x_obs)
params
{'params': {'dense': {'kernel': Array([[ 0.23232014],
          [-1.5953745 ]], dtype=float32),
   'bias': Array([0.], dtype=float32)}}}

Estimate y from our random parameters#

Let’s just take an example x and see what y estimate we get.

x_0, x_1 = 1, 0
model.apply(params, jnp.array([x_0, x_1]))
Array([0.23232014], dtype=float32)

L2 Loss by Hand#

We’re going to use L2 loss, also known as squared error (optax’s version includes a factor of 1/2, which is why we divide by 2 below). This is a common error metric, but definitely not the only one.

m = params["params"]["dense"]["kernel"]
m
Array([[ 0.23232014],
       [-1.5953745 ]], dtype=float32)
bias = params["params"]["dense"]["bias"]
m = params["params"]["dense"]["kernel"]
y_0_pred = m[0] * x_obs[0][0] + m[1] * x_obs[0][1] + bias
y_0_pred
Array([4.5798745], dtype=float32)
y_obs[0]
array([-10.59475724])
(y_0_pred - y_obs[0])**2 / 2
Array([115.13471], dtype=float32)

Use Optax instead#

Common loss functions like L2 are already defined in libraries like optax, so rather than having to handcode them we can just import and run them.

import optax
y_pred = model.apply(params, x_obs[0])
y_pred
Array([4.5798745], dtype=float32)
y_pred = model.apply(params, x_obs[0])

optax.l2_loss(y_pred[0], y_obs[0][0])
Array(115.13471, dtype=float32)

Training#

Training basically means repeating our guess, check, update routine thousands of times until our estimates are not so bad.

from flax.training import train_state  # Useful dataclass to keep train state
# Note: jax will differentiate this function with respect to its first argument (params)
@jax.jit
def flax_l2_loss(params, x, y_true): 
    y_pred = model.apply(params, x)
    total_loss = optax.l2_loss(y_pred, y_true).sum()
    
    return total_loss

flax_l2_loss(params, x_obs[0], y_obs[0])
Array(115.13471, dtype=float32)

How far do we step each time?#

optimizer = optax.adam(learning_rate=0.001)
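
The learning rate controls how far each update moves the parameters, and the optimizer decides exactly how each step is scaled. Adam is a common default, but it’s not the only choice; a plain stochastic gradient descent optimizer (shown commented out below as a sketch, not what the rest of this notebook uses) would also work:

# An equally valid, if slower-converging, alternative:
# optimizer = optax.sgd(learning_rate=0.001)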

Storing training state#

state = train_state.TrainState.create(apply_fn=model, params=params, tx=optimizer)
type(state)
flax.training.train_state.TrainState
state
TrainState(step=0, apply_fn=LinearRegression(), params={'params': {'dense': {'kernel': Array([[ 0.23232014],
       [-1.5953745 ]], dtype=float32), 'bias': Array([0.], dtype=float32)}}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7ff70c54a340>, update=<function chain.<locals>.update_fn at 0x7ff70c54a200>), opt_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu={'params': {'dense': {'bias': Array([0.], dtype=float32), 'kernel': Array([[0.],
       [0.]], dtype=float32)}}}, nu={'params': {'dense': {'bias': Array([0.], dtype=float32), 'kernel': Array([[0.],
       [0.]], dtype=float32)}}}), EmptyState()))

Actual training#

epochs = 10000
_loss = []

for epoch in range(epochs):
    # Calculate the gradient
    loss, grads = jax.value_and_grad(flax_l2_loss)(state.params, x_obs, y_obs_noisy)
    _loss.append(loss)
    # Update the model parameters
    state = state.apply_gradients(grads=grads)

Training Loss#

fig, ax = plt.subplots()
ax.plot(np.arange(epochs), _loss)
ax.set_xlabel("Step or Epoch")
ax.set_ylabel("Trainng loss");
[Plot: training loss vs. step]

Final Parameters#

state.params
{'params': {'dense': {'bias': Array([3.1378474], dtype=float32),
   'kernel': Array([[1.1043512],
          [2.0921109]], dtype=float32)}}}
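
Compare these with the ground truth we set at the top: \(m_0 = 1.1\), \(m_1 = 2.1\), and bias \(= 3.1\). The learned values match to within a few hundredths, which is exactly the parameter recovery check we said we’d use to verify that the learning process works.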

What if data is non linear?#

In our previous case we knew the relationship between x and y was linear because we generated it. In reality, though, most relationships are quite nonlinear: for example, the relationship between the pixels of an image and whether the image is a cat, or the relationship between the words in a sentence.

x_non_linear = np.linspace(-10, 10, 100)
m = 2
y_non_linear = m*x_non_linear**2
fig, ax = plt.subplots()
ax.plot(x_non_linear,y_non_linear);
[Plot: the parabola y = 2x²]

Initialize Params for our model#

params = model.init(key, x_non_linear)
params["params"]["dense"]["kernel"].shape
(100, 1)

We need to reshape X#

Because x_non_linear has shape (100,), Flax treated it as a single example with 100 input features, which is why the kernel came out as (100, 1). Adding a trailing axis turns it into 100 examples with 1 feature each.

x_non_linear[..., None][:5]
array([[-10.        ],
       [ -9.7979798 ],
       [ -9.5959596 ],
       [ -9.39393939],
       [ -9.19191919]])
params = model.init(key, x_non_linear[..., None])
params
{'params': {'dense': {'kernel': Array([[-0.92962414]], dtype=float32),
   'bias': Array([0.], dtype=float32)}}}
state = train_state.TrainState.create(apply_fn=model, params=params, tx=optimizer)
type(state)
flax.training.train_state.TrainState
epochs = 10000
_loss = []

for epoch in range(epochs):
    # Calculate the gradient, shapes are really annoying
    loss, grads = jax.value_and_grad(flax_l2_loss)(state.params, x_non_linear[..., None], y_non_linear[..., None])
    _loss.append(loss)
    
    # Update the model parameters
    state = state.apply_gradients(grads=grads)
state.params
{'params': {'dense': {'bias': Array([9.870098], dtype=float32),
   'kernel': Array([[-7.460363e-08]], dtype=float32)}}}
y_pred = model.apply(state.params, x_non_linear[..., None])
y_pred[:5]
Array([[9.870099],
       [9.870099],
       [9.870099],
       [9.870099],
       [9.870099]], dtype=float32)

This is a terrible fit#

If we plot our original data against what our model is able to estimate, we see that it’s not close at all. Our model can only really fit a line, which is too simplistic to fit a parabola.

fig, ax = plt.subplots()
ax.plot(x_non_linear, y_non_linear, label="Actual")
ax.plot(x_non_linear, y_pred, label="Predicted")
ax.legend();
[Plot: actual parabola vs. the linear model’s flat prediction]

Nonlinear Regression#

Let’s expand our neural network so it can fit data more flexibly. We do this by adding hidden layers, and we increase the number of parameters using the features keyword. We also include a relu layer. This is known as an activation function in a Neural Network; it’s what lets individual neurons turn themselves “on” or “off”.
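
If relu is unfamiliar, it simply zeroes out negative values and passes positive ones through, i.e. \(relu(x) = \max(0, x)\). A quick standalone illustration (not part of the model code):

# relu zeroes negatives and keeps positives
nn.relu(jnp.array([-2.0, 0.0, 3.0]))
# -> Array([0., 0., 3.], dtype=float32)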

class NonLinearRegression(nn.Module):
    # Define the neural network architecture
    def setup(self):
        """Single output, I dont need to specify shape"""
        self.hidden_layer_1 = nn.Dense(features=4)
        self.hidden_layer_2 = nn.Dense(features=4)
        self.dense_out = nn.Dense(features=1)

    # Define the forward pass
    def __call__(self, x):
        hidden_x_1 = self.hidden_layer_1(x)
        hidden_x_2 = self.hidden_layer_2(hidden_x_1)
        x = nn.relu(hidden_x_2)
        y_pred = self.dense_out(x)
        return y_pred
model = NonLinearRegression()
params = model.init(key, x_non_linear[..., None])
params
{'params': {'hidden_layer_1': {'kernel': Array([[-1.3402965 ,  0.60816777, -0.06039568, -0.73402834]], dtype=float32),
   'bias': Array([0., 0., 0., 0.], dtype=float32)},
  'hidden_layer_2': {'kernel': Array([[ 0.4993827 , -0.15674856,  0.19190995, -0.7276566 ],
          [-0.16007023,  0.4871689 , -0.32242206,  0.20559888],
          [ 1.0811065 , -0.07486992, -0.2228159 ,  0.18825608],
          [-1.068188  , -0.01827471, -0.34845108, -0.7157423 ]],      dtype=float32),
   'bias': Array([0., 0., 0., 0.], dtype=float32)},
  'dense_out': {'kernel': Array([[-0.4820224 ],
          [-0.11561313],
          [-1.0410513 ],
          [-0.37862116]], dtype=float32),
   'bias': Array([0.], dtype=float32)}}}
state = train_state.TrainState.create(apply_fn=model, params=params, tx=optimizer)
epochs = 3000
_loss = []

for epoch in range(epochs):
    # Calculate the gradient. Also shapes are really annoying
    loss, grads = jax.value_and_grad(flax_l2_loss)(state.params, x_non_linear[..., None], y_non_linear[..., None])
    _loss.append(loss)
    # Update the model parameters
    state = state.apply_gradients(grads=grads)
fig, ax = plt.subplots()
ax.plot(np.arange(epochs), _loss)
ax.set_xlabel("Step or Epoch")
ax.set_ylabel("Trainng loss");
[Plot: training loss vs. step for the nonlinear model]

Calculate Predictions#

y_pred_non_linear = model.apply(state.params, x_non_linear[..., None])

Plot Predictions#

So here’s our fit. It’s better, but it’s not great: with only four hidden features, and with the two stacked linear layers before the relu collapsing into a single linear transformation, the model can only produce a piecewise-linear curve with a handful of kinks.

fig, ax = plt.subplots()
ax.plot(x_non_linear, y_non_linear, label="Actual");
ax.plot(x_non_linear, y_pred, label="Simple Model Predictions", ls='--');
ax.plot(x_non_linear, y_pred_non_linear, label="More Complex Predictions", ls='--', lw=2)
plt.legend();
[Plot: actual data vs. the simple and more complex model predictions]

Non Linear (but more complex)#

Let’s now fit another model, but we’ll increase the complexity by adding many more features, 128 in this case. Notice how the model is now able to fit the parabola “better”. It’s not perfect, but this is the idea: get a really big model with a lot of parameters and it can, in theory, find the structure in anything, even the words on the internet in every language, including code.

class NonLinearRegression_2(nn.Module):
    
    # Define the neural network architecture
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)              # create inline Flax Module submodules
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)       # shape inference
        return x
model = NonLinearRegression_2()
params = model.init(key, x_non_linear[..., None])
state = train_state.TrainState.create(apply_fn=model, params=params, tx=optimizer)
epochs = 3000
_loss = []

for epoch in range(epochs):
    # Calculate the gradient. Also shapes are really annoying
    loss, grads = jax.value_and_grad(flax_l2_loss)(state.params, x_non_linear[..., None], y_non_linear[..., None])
    _loss.append(loss)
    # Update the model parameters
    state = state.apply_gradients(grads=grads)
y_pred_non_linear_2 = model.apply(state.params, x_non_linear[..., None])
fig, ax = plt.subplots()
ax.plot(x_non_linear, y_non_linear, label="Actual");
ax.plot(x_non_linear, y_pred, label="Simple Model Predictions", ls='--');
ax.plot(x_non_linear, y_pred_non_linear, label="More Complex Predictions", ls='--', lw=2)
ax.plot(x_non_linear, y_pred_non_linear_2, label="Most Complex Predicted", ls='--', lw=2)

plt.legend();
[Plot: actual data vs. all three models’ predictions]

That’s it!#

That’s all there is to creating Neural Nets and LLMs. With that, you’re ready for the next notebook, where we recreate a Shakespeare LLM from scratch using the same tools used here.
