GPT from “scratch” in Flax#

TLDR:

  • This single notebook contains a full GPT that you can train at home.

  • The things to pay attention to are:

    • The tokenizer

    • The model structure and definition

    • How the model performance changes over various checkpoints

The trained model and the initial model#

Before we dive in, here's output from the final model, which was trained on the Tiny Shakespeare dataset.

# A sample generation
idx = bglm.apply(state.params, 600, method='generate')
generation = ts.decode(idx.tolist()[0])
print(generation.strip())
MENENIUS:
COMINIUS:
Why, sir, or hast been so what I did;
Respect is his, and go with holds,
And backs hangment tomb, Camillo true,
To speak no away in this son of roopiness;
Whithe that is all for you was, kind how I am
Hope for you as time our place.

LARTIUS:
Alas, only to love, to ho,
Him it with the cause to one this, bitter her
May a shepherd haze. Angelo.
But one till the officers to this true,
Thus expect us was unto the; for I am every begue;
And come unto our bild, man, then joy in his Pome.

BENVOLIO:
How down seems be my voiciant is upon himself,
And since him out of a diath. whom

And here’s where we started, a bunch of random characters.

idx = bglm.apply(initial_params, 600, method='generate')
print(ts.decode(idx.tolist()[0]).strip())
boysZBZMAhCxFaSKeNyDs,WPHVxVWLoH hus
dV,XbV
Dhwha'rJFolY;rRlhRsL;PQsih$w-jYIWO$KT iyFUhrlYT,! xHZ
gUBslha &mHcYO.KmUDaCLjGqg? eVTYSOlllqJh$?$MlSTwj;jAACkPMNsJ-;!NCzrIhgiSwkb;M!VrQLQ:?gN.m:vNFlVf&OK,X-ygBaXYLkxWSu:o3fkrxBd;y:CMQHRuZWbs
VA,jkUuYzq::xXOeCst;qjLLdEKAAZHOU?:pmnzp:rEZRI&v&lPjJ,BVbmc N:wxAZ$AX-LnawOzE; GoyDOHg&xSuFCKTTkwJCxTMh:oFo,&3?izWwD wh!yEVokSP.qesfhfoIV!&.cwmL:KbRofnrdmXDETA'-'GWPURTzwX.ubveT.XMHY;f-wmz:-EqBHFab;gfJVJL-XIQe'WdCc$nguDpF3pTsB:3jI.bCKvYlJPCRCp,o?3atqGI.NPjRbkvBSex,QPdIekjhx! MUjbSYor:BcDDkff&HcbQN,gkv3BvNMdr-kddiRmubfc.XEGsImf'uqqZJqrLzCAgsUKVQnU$xpAgTVbalnVbKUpM

The rest of this notebook shows how we get from our random initial LLM to our trained one.

Why Flax and Jax#

Similar to PyTorch, Flax and Jax are production-grade tools, used for everything from individual research projects to massive multi-cluster LLM training.

For this guidebook, there are some additional reasons:

  • I find that the separation of state, model definition, and training loops in Flax makes it easier to disambiguate the concepts.

  • Being able to print intermediate layer outputs also helped with understanding and debugging.

  • Jax is deterministic by default, which makes debugging simpler.

  • Now you have both a PyTorch and Flax/Jax reference!

This notebook dives straight into a neural network model. If you need a primer on Neural Networks, start with NN Quickstart.

Additional Notes#

Here are some additional details for easy reference:

B, T, C Meaning#

Karpathy uses the shorthand symbols B, T, C in his code and videos. He explains their meanings here, which I’ll write out for easy reference:

  • B - Batch - How many training examples we’re using at once. This is just for computational efficiency.

  • T - Time - The position, or index, within an input sequence, whether encoded or decoded.

  • C - Channel - The feature dimension at each position: the embedding size, or the vocab size for the final logits. This matches the…
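
To make these concrete, here are the shapes as they appear in this notebook (using the constants defined later: batch_size=128, block_size=64, n_embd=120, and a 65-character vocabulary):

# B, T, C as concrete shapes (values match this notebook's constants)
B, T, C = 128, 64, 120      # batch, time (position within the context window), channels (embedding size)
tokens_shape = (B, T)       # a batch of encoded character sequences going into the model
embedded_shape = (B, T, C)  # the same batch after the token and position embeddings
logits_shape = (B, T, 65)   # the model output, where the channel dimension becomes the vocab size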

Block Size is Context Window#

Block size is the length of the input sequence the model sees when predicting the next character.
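
As a minimal sketch (the string below is just a stand-in, not the notebook's dataset), a block of context maps to a single next-character target:

text = "First Citizen: Before we proceed"
demo_block_size = 8
context = text[0:demo_block_size]  # "First Ci" -- what the model sees
target = text[demo_block_size]     # "t"        -- the next character it should predict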

Suggested Exercises#

For additional hands-on learning, I've included two extra datasets, along with all the code for training and text generation:

  • abcdefg

  • The quick brown fox jumped over the lazy dog

These strings have useful properties. The first has no repeating characters and is easy to encode and decode to integers in your head, so it's easy to get an overfit model with perfect predictions for debugging, and easy to inspect the parameters. The second adds spaces and a few repeated characters (for example, "the" appears twice), introducing just one more layer of complexity.


For additional debugging and understanding, you can perform an ablation study: remove parts of the model, make certain parameters smaller or larger, and see what happens. You can also visualize the learned parameters with heatmaps and other tools to see what the model is learning; see the sketch after the code below.

qb = QuickBrownFox(block_size = block_size, batch_size=batch_size)
ab = Alphabet(block_size = block_size, batch_size=batch_size)
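
Here's a minimal sketch of the heatmap idea mentioned above. It assumes the trained state.params pytree from the training loop later in the notebook (which has the same structure as initial_params) and the matplotlib import from the next section:

# Inspection sketch: assumes the trained `state.params` pytree from later in the notebook
emb = state.params["params"]["TokenEmbedding"]["embedding"]  # shape (vocab_size, n_embd)
plt.imshow(emb, aspect="auto", cmap="viridis")
plt.xlabel("Embedding dimension")
plt.ylabel("Token id")
plt.title("Learned token embeddings")
plt.show()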

Imports#

!pip install flax==0.7.5 jax==0.4.24 optax

seed = 1337
import jax
import jax.numpy as jnp

import flax
from flax import linen as nn
from flax.training import train_state, checkpoints
import optax

from typing import Sequence
import matplotlib.pyplot as plt
import numpy as np

from tqdm.auto import tqdm
import logging
import requests

def create_logger():
    logger = logging.getLogger("notebook")
    logger.setLevel(logging.INFO)

    if not logger.hasHandlers():
        consolehandler = logging.StreamHandler()
        consolehandler.setLevel(logging.DEBUG)
        formatter = logging.Formatter("%(levelname)s - %(message)s")
        consolehandler.setFormatter(formatter)
        logger.addHandler(consolehandler)

    return logger

logger = create_logger()
# Checking if CUDA is loaded
devices = jax.devices()
print(devices)
[cuda(id=0)]

Data Loaders#

This text-processing base class implements the tokenizer: the string encoder and decoder. These are not to be confused with the encoder and decoder portions of a Transformer, which are entirely different things.
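
The tokenizer here is character-level: every distinct character in the corpus gets an integer id. A self-contained sketch of the idea, independent of the class below:

corpus = "hello world"
chars = sorted(set(corpus))                   # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
stoi = {ch: i for i, ch in enumerate(chars)}  # string -> integer
itos = {i: ch for i, ch in enumerate(chars)}  # integer -> string
encode = lambda s: [stoi[c] for c in s]
decode = lambda ids: "".join(itos[i] for i in ids)
assert decode(encode("hello world")) == "hello world"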

import abc
class BaseTextProcesser:
    """Load text, with tokenizer encode and decode"""

    def __init__(self, block_size, batch_size):
        self.block_size = block_size
        self.batch_size = batch_size
        self._data = None

        self.text = self.set_text()
        self.chars = sorted(list(set(self.text)))
        self.vocab_size = len(self.chars)
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}

    @abc.abstractmethod
    def set_text(self) -> str:
        """Sets the text corpus that going to be used"""
        return

    def encode(self, input_string):
        return [self.stoi[c] for c in input_string]

    def decode(self, token_iter):
        # Int for jax.ArrayImpl coercion
        return "".join([self.itos[int(c)] for c in token_iter])
        
    def batch_decoder(self, tokens_lists):
        return [self.decode(tokens) for tokens in tokens_lists]

    @property
    def data(self):
        if self._data is None:
            self._data = jnp.array(self.encode(self.text))
        return self._data

    def train_test_split(self, split=.9):
        """Return train test split of data.

        We'll do it the same as the tutorial without any fancy reshuffling or anything like that
        """

        # Let's now split up the data into train and validation sets
        n = int(split * len(self.data))  # first 90% will be training data, rest val
        train_data = self.data[:n]
        val_data = self.data[n:]
        return train_data, val_data

    def get_batch(self, key):
        """Sample a random batch of training examples and targets.

        Parameters
        ----------
        key: jax PRNG key used to sample the batch starting positions

        Returns
        -------
        x: Training examples, shape (batch_size, block_size)
        y: Target array, x shifted one position to the right
        ix: Starting indices of each sampled snippet
        """
        # Take batch_size random starting positions within the text
        # (JAX requires more arguments for random sampling than PyTorch does)
        ix = jax.random.randint(key=key, minval=0, maxval=len(self.data) - self.block_size,
                                shape=(self.batch_size,))

        # For each starting position, take a block_size snippet and stack the snippets together
        x = jnp.stack([self.data[i:i + self.block_size] for i in ix])

        # The targets are the same snippets shifted one position to the right
        y = jnp.stack([self.data[i + 1:i + self.block_size + 1] for i in ix])
        return x, y, ix

Three datasets are included:

  • Tiny Shakespeare - This is the original from the tutorial.

  • Alphabet - Just the first seven letters of the alphabet (abcdefg); no character repeats, so the LLM should be able to memorize the next token.

  • QuickBrownFox - A short sentence where most characters are distinct but spaces (and a few letters) repeat, giving a tiny bit of variability.

The last two datasets are included for debugging purposes. One of the easiest ways to debug an LLM, or any deep learning model, is to overfit a simple dataset; if your model can't do that, you have problems.
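
A quick sanity check that pairs well with overfitting: an untrained model should do about as well as guessing uniformly over the vocabulary, so its initial cross-entropy loss should sit near ln(vocab_size), while an overfit model on a tiny dataset should drive the loss close to zero.

# Expected initial loss for uniform guessing over Tiny Shakespeare's 65 characters
print(jnp.log(65.0))  # ~4.17; a much larger initial loss usually indicates a bug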

class TinyShakespeare(BaseTextProcesser):

    def set_text(self):
        """Sets the text corpus that is going to be used"""
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        return requests.get(url).text


class QuickBrownFox(BaseTextProcesser):

    def set_text(self):
        """Sets the text corpus that is going to be used"""
        return "The quick brown fox jumped over the lazy dog"


class Alphabet(BaseTextProcesser):

    def set_text(self):
        """Sets the text corpus that is going to be used"""
        return "abcdefg"

Random Keys#

JAX, on which Flax is built, requires explicit keys for its randomness by design. This is counterintuitive at first but smart for many reasons. Read the docs to learn more.

root_key = jax.random.PRNGKey(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
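
A minimal illustration of the explicit-key model: the same key always produces the same draw, and new randomness comes from splitting keys rather than from hidden global state.

demo_key = jax.random.PRNGKey(42)
a = jax.random.normal(demo_key, (2,))
b = jax.random.normal(demo_key, (2,))          # identical to `a`: same key, same numbers
demo_key, subkey = jax.random.split(demo_key)  # split when you need fresh randomness
c = jax.random.normal(subkey, (2,))            # different from `a`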

Model Constants#

batch_size = 128 # How many independent sequences will we process in parallel?
block_size = 64 # What is the maximum context length for predictions?

Load and Encode Data#

ts = TinyShakespeare(block_size = block_size, batch_size=batch_size)
# ts = utils.QuickBrownFox(block_size = block_size, batch_size=batch_size)
# ts = utils.Alphabet(block_size = block_size, batch_size=batch_size)

xb, yb, ix = ts.get_batch(main_key)
ts.text[:50]
'First Citizen:\nBefore we proceed any further, hear'
ts.vocab_size
65
xb[:2]
Array([[ 1, 61, 43, 43, 42, 57,  6,  1, 61, 47, 58, 46,  1, 53, 60, 43,
        56, 61, 46, 43, 50, 51, 47, 52, 45,  1, 40, 56, 53, 61, 57,  6,
         0, 15, 59, 50, 50, 47, 52, 45,  1, 53, 44,  1, 57, 47, 51, 54,
        50, 43, 57, 11,  1, 51, 43, 39, 45, 56, 43,  1, 61, 43, 56, 43],
       [ 8,  0, 14, 59, 58,  1, 39, 57,  1, 58, 46, 47, 57,  1, 58, 47,
        58, 50, 43,  1, 46, 53, 52, 53, 59, 56, 57,  1, 51, 43,  1, 39,
        52, 42,  1, 51, 47, 52, 43,  6,  0, 31, 53,  1, 63, 53, 59, 56,
         1, 42, 47, 57, 50, 47, 49, 43,  6,  1, 58, 53,  1, 61, 46, 53]],      dtype=int32)
ts.batch_decoder(xb[:2])
[' weeds, with overwhelming brows,\nCulling of simples; meagre were',
 '.\nBut as this title honours me and mine,\nSo your dislike, to who']

Model Parameters#

# Model Parameters
vocab_size = ts.vocab_size
n_embd = 120
n_head = 6
n_layer = 6
dropout_rate = .4

# n_embd = 300
# n_head = 6
# n_layer = 6
#dropout = 0.2
global_mask = nn.make_causal_mask(xb)
global_mask.shape
(128, 1, 64, 64)
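
To see what the mask actually contains, here's a tiny version for a sequence of length 4: each position may attend to itself and to earlier positions only.

small_mask = nn.make_causal_mask(jnp.ones((1, 4), dtype=jnp.int32))
print(small_mask[0, 0].astype(jnp.int32))
# Roughly:
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]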

Multi-Head Attention#

Now, this part is not from scratch: I chose to use the pre-canned Flax attention layer to show what a production implementation looks like. Andrej already does a great job explaining the internals of attention, so repeating that here wouldn't add much extra value.
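
For reference, the core computation inside that pre-canned layer is scaled dot-product attention. Here is a minimal single-head sketch, for illustration only; the model below uses Flax's nn.MultiHeadDotProductAttention rather than this function:

def scaled_dot_product_attention(q, k, v, mask):
    # q, k, v: (T, head_dim); mask: (T, T) with 1 where attention is allowed
    scores = q @ k.T / jnp.sqrt(q.shape[-1])     # similarity between positions
    scores = jnp.where(mask == 1, scores, -1e9)  # hide future positions
    weights = jax.nn.softmax(scores, axis=-1)    # attention weights sum to 1 per row
    return weights @ v                           # weighted sum of the value vectors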

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention followed by a projection back to the embedding size."""
    num_heads: int
    n_embd: int

    @nn.compact
    def __call__(self, x, training):
        # Flax's built-in attention layer; attention dropout is only active while training
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            dropout_rate=dropout_rate,
            deterministic=not training
        )(x, mask=global_mask)  # the shared causal mask blocks attention to future positions
        x = nn.Dense(self.n_embd)(x)
        return x

Test Parameter Initialization#

mha = MultiHeadAttention(
    # head_size=head_size,
    num_heads=n_head,
    n_embd=n_embd,
    )
input_x = jnp.ones((batch_size, block_size, n_embd))
logger.debug(f"{input_x.shape}")

params = mha.init({'params':params_key, "dropout": dropout_key}, input_x, training=True)

print(mha.apply(params, input_x, training=False).shape)
(128, 64, 120)
input_x = jnp.ones((batch_size, block_size, n_embd))
params = mha.init(root_key, input_x, training=False)

mha.apply(params, input_x, training=False);

Feedforward Layer#

(batch_size, block_size, n_embd) -> (batch_size, block_size, n_embd)

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, apply ReLU, project back, then dropout."""

    @nn.compact
    def __call__(self, x, training):
        x = nn.Sequential([
            nn.Dense(4 * n_embd),
            nn.relu,
            nn.Dense(n_embd),
            nn.Dropout(dropout_rate, deterministic=not training)
        ])(x)
        return x
ff = FeedForward()

input_x = jnp.ones((batch_size, block_size, n_embd))
logger.debug(f"{input_x.shape}")

params = ff.init({'params':params_key, "dropout": dropout_key}, input_x, training=True)

print(ff.apply(params, input_x, training=False).shape)
(128, 64, 120)
print(ff.tabulate(root_key, input_x, training=False,
      console_kwargs={'force_terminal': False, 'force_jupyter': True}))
                                          FeedForward Summary                                          
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path          module       inputs                 outputs              params                   ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│              │ FeedForward │ - float32[128,64,120] │ float32[128,64,120] │                          │
│              │             │ - training: False     │                     │                          │
├──────────────┼─────────────┼───────────────────────┼─────────────────────┼──────────────────────────┤
│ Sequential_0 │ Sequential  │ float32[128,64,120]   │ float32[128,64,120] │                          │
├──────────────┼─────────────┼───────────────────────┼─────────────────────┼──────────────────────────┤
│ Dense_0      │ Dense       │ float32[128,64,120]   │ float32[128,64,480] │ bias: float32[480]       │
│              │             │                       │                     │ kernel: float32[120,480] │
│              │             │                       │                     │                          │
│              │             │                       │                     │ 58,080 (232.3 KB)        │
├──────────────┼─────────────┼───────────────────────┼─────────────────────┼──────────────────────────┤
│ Dense_1      │ Dense       │ float32[128,64,480]   │ float32[128,64,120] │ bias: float32[120]       │
│              │             │                       │                     │ kernel: float32[480,120] │
│              │             │                       │                     │                          │
│              │             │                       │                     │ 57,720 (230.9 KB)        │
├──────────────┼─────────────┼───────────────────────┼─────────────────────┼──────────────────────────┤
│ Dropout_0    │ Dropout     │ float32[128,64,120]   │ float32[128,64,120] │                          │
├──────────────┼─────────────┼───────────────────────┼─────────────────────┼──────────────────────────┤
│                                                                 Total  115,800 (463.2 KB)       │
└──────────────┴─────────────┴───────────────────────┴─────────────────────┴──────────────────────────┘
                                                                                                       
                                 Total Parameters: 115,800 (463.2 KB)                                  

Block#

class Block(nn.Module):
    """Transformer block: self-attention then feedforward, each with a pre-LayerNorm and a residual connection."""

    @nn.compact
    def __call__(self, x, training):

        sa = MultiHeadAttention(
            n_embd=n_embd,
            num_heads=n_head,
        )
        ff = FeedForward()

        # Pre-norm residual connections; Flax's LayerNorm infers the feature size from its input
        x = x + sa(nn.LayerNorm()(x), training=training)
        x = x + ff(nn.LayerNorm()(x), training=training)

        # Return a dict so nn.Sequential can pass both x and the training flag to the next Block
        return dict(x=x, training=training)
block = Block()

input_x = jnp.ones((batch_size, block_size, n_embd))
logger.debug(f"{input_x.shape=}")

block_params = block.init({'params':params_key, "dropout": dropout_key}, input_x, training=True)
logger.info(f"{block.apply(block_params, input_x, training=False)['x'].shape=}")
INFO - block.apply(block_params, input_x, training=False)['x'].shape=(128, 64, 120)
print(block.tabulate(root_key, input_x, training=False,
      console_kwargs={'force_terminal': True, 'force_jupyter': True},
                     column_kwargs = {'width': 400}))
                                                   Block Summary                                                   
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path                  module                inputs                outputs              params               ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│                      │ Block                │ -                    │ training: False     │                      │
│                      │                      │ float32[128,64,120]  │ x:                  │                      │
│                      │                      │ - training: False    │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ LayerNorm_0          │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ MultiHeadAttention_0 │ MultiHeadAttention   │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ MultiHeadAttention_… │ MultiHeadDotProduct… │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - mask:              │                     │                      │
│                      │                      │ float32[128,1,64,64] │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ MultiHeadAttention_… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ MultiHeadAttention_… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ MultiHeadAttention_… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ MultiHeadAttention_… │ DenseGeneral         │ float32[128,64,6,20] │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[6,20,120]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ MultiHeadAttention_… │ Dense                │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ LayerNorm_1          │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ FeedForward_0        │ FeedForward          │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ FeedForward_0/Seque… │ Sequential           │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ FeedForward_0/Dense… │ Dense                │ float32[128,64,120]  │ float32[128,64,480] │ bias: float32[480]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,480]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 58,080 (232.3 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ FeedForward_0/Dense… │ Dense                │ float32[128,64,480]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[480,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 57,720 (230.9 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ FeedForward_0/Dropo… │ Dropout              │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│                                                                                 Total  188,880 (755.5 KB)   │
└──────────────────────┴──────────────────────┴──────────────────────┴─────────────────────┴──────────────────────┘
                                                                                                                   
                                       Total Parameters: 188,880 (755.5 KB)                                        

Full Model#

vocab_size = ts.vocab_size

class BigramLanguageModel(nn.Module):
    """The full GPT-style model; the class name is kept from the bigram model in Karpathy's tutorial."""

    @nn.compact
    def __call__(self, idx, training):
        logger.debug(f"In call {idx.shape=}")
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = nn.Embed(vocab_size, n_embd, name="TokenEmbedding")(idx) # (B,T,C)
        pos_emb = nn.Embed(block_size, n_embd, name="Position Embedding")(jnp.arange(T)) # (T,C)
        
        x = tok_emb + pos_emb # (B,T,C)
        x = nn.Sequential([Block() for _ in range(int(n_layer))])(x, training=training)["x"] # (B,T,C)
        x = nn.LayerNorm(name="LayerNorm")(x) # (B,T,C); feature size is inferred from the input
        logits = nn.Dense(vocab_size, name="Final Dense")(x) # (B,T,vocab_size)
        return logits
    
    def generate(self, max_new_tokens):
        # Start generation from a context of zeros (token 0 decodes to '\n' for Tiny Shakespeare)
        idx = jnp.zeros((1, block_size), dtype=jnp.int32)
        
        # We need to get this to enable correct random behavior later
        key = jax.random.PRNGKey(0)

        
        for i in range(max_new_tokens):
            logger.debug(f"In generate {i=}")
            
            # Get the predictions
            logger.debug(f"In generate {idx=}==========")
            logits = self.__call__(idx[:, -block_size:], training=False)
            logger.debug(f"In generate {logits.size=}")

            
            ## Focus only on the logits at the last time step
            # logits_last_t = logits[:, -1, :] # batched version, shape (B, vocab_size)
            logits_last_t = logits[0, -1]  # shape (vocab_size,)
            
            # Due to the way randomness works in jax we have to generate subkeys to get new values
            # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#jax-prng
            _, subkey = jax.random.split(key)

            # jax.random.categorical takes unnormalized log-probabilities, so we don't need a softmax here
            # https://jax.readthedocs.io/en/latest/_autosummary/jax.random.categorical.html
            # sample from the distribution. 
            idx_next = jax.random.categorical(subkey, logits_last_t)
            
            # Rotate the key
            key = subkey

            # append sampled index to the running sequence
            idx = jnp.atleast_2d(jnp.append(idx, idx_next))
            logger.debug(f"In generate after append {idx=}")

        return idx

bglm = BigramLanguageModel()
logger.setLevel(logging.DEBUG)

input_x = jnp.ones((batch_size, block_size), dtype=jnp.int16)
logger.debug(f"{input_x.shape=}")
initial_params = bglm.init({'params':params_key, "dropout": dropout_key}, input_x, training=True)
DEBUG - input_x.shape=(128, 64)
DEBUG - In call idx.shape=(128, 64)
initial_params["params"].keys()
dict_keys(['TokenEmbedding', 'Position Embedding', 'Block_0', 'Block_1', 'Block_2', 'Block_3', 'Block_4', 'Block_5', 'LayerNorm', 'Final Dense'])
print(bglm.tabulate({'params':params_key, "dropout": dropout_key}, input_x, training=False,
      console_kwargs={'force_terminal': False, 'force_jupyter': True},
                     column_kwargs = {'width': 400}))
DEBUG - In call idx.shape=(128, 64)
                                            BigramLanguageModel Summary                                            
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path                  module                inputs                outputs              params               ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│                      │ BigramLanguageModel  │ - int16[128,64]      │ float32[128,64,65]  │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ TokenEmbedding       │ Embed                │ int16[128,64]        │ float32[128,64,120] │ embedding:           │
│                      │                      │                      │                     │ float32[65,120]      │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 7,800 (31.2 KB)      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Position Embedding   │ Embed                │ int32[64]            │ float32[64,120]     │ embedding:           │
│                      │                      │                      │                     │ float32[64,120]      │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 7,680 (30.7 KB)      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Sequential_0         │ Sequential           │ -                    │ training: False     │                      │
│                      │                      │ float32[128,64,120]  │ x:                  │                      │
│                      │                      │ - training: False    │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0              │ Block                │ -                    │ training: False     │                      │
│                      │                      │ float32[128,64,120]  │ x:                  │                      │
│                      │                      │ - training: False    │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/LayerNorm_0  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/MultiHeadAt… │ MultiHeadAttention   │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/MultiHeadAt… │ MultiHeadDotProduct… │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - mask:              │                     │                      │
│                      │                      │ float32[128,1,64,64] │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/MultiHeadAt… │ DenseGeneral         │ float32[128,64,6,20] │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[6,20,120]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/MultiHeadAt… │ Dense                │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/LayerNorm_1  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/FeedForward… │ FeedForward          │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/FeedForward… │ Sequential           │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/FeedForward… │ Dense                │ float32[128,64,120]  │ float32[128,64,480] │ bias: float32[480]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,480]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 58,080 (232.3 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/FeedForward… │ Dense                │ float32[128,64,480]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[480,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 57,720 (230.9 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_0/FeedForward… │ Dropout              │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1              │ Block                │ training: False      │ training: False     │                      │
│                      │                      │ x:                   │ x:                  │                      │
│                      │                      │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/LayerNorm_0  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/MultiHeadAt… │ MultiHeadAttention   │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/MultiHeadAt… │ MultiHeadDotProduct… │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - mask:              │                     │                      │
│                      │                      │ float32[128,1,64,64] │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/MultiHeadAt… │ DenseGeneral         │ float32[128,64,6,20] │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[6,20,120]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/MultiHeadAt… │ Dense                │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/LayerNorm_1  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/FeedForward… │ FeedForward          │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/FeedForward… │ Sequential           │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/FeedForward… │ Dense                │ float32[128,64,120]  │ float32[128,64,480] │ bias: float32[480]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,480]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 58,080 (232.3 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/FeedForward… │ Dense                │ float32[128,64,480]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[480,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 57,720 (230.9 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_1/FeedForward… │ Dropout              │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2              │ Block                │ training: False      │ training: False     │                      │
│                      │                      │ x:                   │ x:                  │                      │
│                      │                      │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/LayerNorm_0  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/MultiHeadAt… │ MultiHeadAttention   │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/MultiHeadAt… │ MultiHeadDotProduct… │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - mask:              │                     │                      │
│                      │                      │ float32[128,1,64,64] │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/MultiHeadAt… │ DenseGeneral         │ float32[128,64,6,20] │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[6,20,120]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/MultiHeadAt… │ Dense                │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/LayerNorm_1  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/FeedForward… │ FeedForward          │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/FeedForward… │ Sequential           │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/FeedForward… │ Dense                │ float32[128,64,120]  │ float32[128,64,480] │ bias: float32[480]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,480]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 58,080 (232.3 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/FeedForward… │ Dense                │ float32[128,64,480]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[480,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 57,720 (230.9 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_2/FeedForward… │ Dropout              │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3              │ Block                │ training: False      │ training: False     │                      │
│                      │                      │ x:                   │ x:                  │                      │
│                      │                      │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/LayerNorm_0  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/MultiHeadAt… │ MultiHeadAttention   │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/MultiHeadAt… │ MultiHeadDotProduct… │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - mask:              │                     │                      │
│                      │                      │ float32[128,1,64,64] │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/MultiHeadAt… │ DenseGeneral         │ float32[128,64,6,20] │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[6,20,120]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/MultiHeadAt… │ Dense                │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/LayerNorm_1  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/FeedForward… │ FeedForward          │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/FeedForward… │ Sequential           │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/FeedForward… │ Dense                │ float32[128,64,120]  │ float32[128,64,480] │ bias: float32[480]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,480]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 58,080 (232.3 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/FeedForward… │ Dense                │ float32[128,64,480]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[480,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 57,720 (230.9 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_3/FeedForward… │ Dropout              │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4              │ Block                │ training: False      │ training: False     │                      │
│                      │                      │ x:                   │ x:                  │                      │
│                      │                      │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/LayerNorm_0  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/MultiHeadAt… │ MultiHeadAttention   │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/MultiHeadAt… │ MultiHeadDotProduct… │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - mask:              │                     │                      │
│                      │                      │ float32[128,1,64,64] │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/MultiHeadAt… │ DenseGeneral         │ float32[128,64,6,20] │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[6,20,120]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/MultiHeadAt… │ Dense                │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/LayerNorm_1  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/FeedForward… │ FeedForward          │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/FeedForward… │ Sequential           │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/FeedForward… │ Dense                │ float32[128,64,120]  │ float32[128,64,480] │ bias: float32[480]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,480]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 58,080 (232.3 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/FeedForward… │ Dense                │ float32[128,64,480]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[480,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 57,720 (230.9 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_4/FeedForward… │ Dropout              │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5              │ Block                │ training: False      │ training: False     │                      │
│                      │                      │ x:                   │ x:                  │                      │
│                      │                      │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/LayerNorm_0  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/MultiHeadAt… │ MultiHeadAttention   │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/MultiHeadAt… │ MultiHeadDotProduct… │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - mask:              │                     │                      │
│                      │                      │ float32[128,1,64,64] │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/MultiHeadAt… │ DenseGeneral         │ float32[128,64,120]  │ float32[128,64,6,2… │ bias: float32[6,20]  │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,6,20]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/MultiHeadAt… │ DenseGeneral         │ float32[128,64,6,20] │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[6,20,120]    │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/MultiHeadAt… │ Dense                │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 14,520 (58.1 KB)     │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/LayerNorm_1  │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/FeedForward… │ FeedForward          │ -                    │ float32[128,64,120] │                      │
│                      │                      │ float32[128,64,120]  │                     │                      │
│                      │                      │ - training: False    │                     │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/FeedForward… │ Sequential           │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/FeedForward… │ Dense                │ float32[128,64,120]  │ float32[128,64,480] │ bias: float32[480]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,480]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 58,080 (232.3 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/FeedForward… │ Dense                │ float32[128,64,480]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[480,120]     │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 57,720 (230.9 KB)    │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Block_5/FeedForward… │ Dropout              │ float32[128,64,120]  │ float32[128,64,120] │                      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ LayerNorm            │ LayerNorm            │ float32[128,64,120]  │ float32[128,64,120] │ bias: float32[120]   │
│                      │                      │                      │                     │ scale: float32[120]  │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 240 (960 B)          │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│ Final Dense          │ Dense                │ float32[128,64,120]  │ float32[128,64,65]  │ bias: float32[65]    │
│                      │                      │                      │                     │ kernel:              │
│                      │                      │                      │                     │ float32[120,65]      │
│                      │                      │                      │                     │                      │
│                      │                      │                      │                     │ 7,865 (31.5 KB)      │
├──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────┼──────────────────────┤
│                                                                                 Total  1,156,865 (4.6 MB)   │
└──────────────────────┴──────────────────────┴──────────────────────┴─────────────────────┴──────────────────────┘
                                                                                                                   
                                       Total Parameters: 1,156,865 (4.6 MB)                                        
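
As a sanity check, the parameter count can be reproduced by hand from the shapes in the table: n_embd=120, 6 heads of size 20, block size 64, vocab size 65, 6 blocks, and a 4x feed-forward expansion. The sketch below assumes the token and position embeddings (earlier in the table) are vocab_size x n_embd and block_size x n_embd respectively; with that, the arithmetic lands exactly on 1,156,865.

n_embd, n_head, head_size = 120, 6, 20   # n_head * head_size == n_embd
block_size, vocab_size, n_layer = 64, 65, 6
ffw_hidden = 4 * n_embd                  # 480, per the FeedForward Dense shapes

def dense(d_in, d_out):
    return d_in * d_out + d_out          # kernel + bias

layernorm = 2 * n_embd                   # scale + bias = 240
attention = 5 * dense(n_embd, n_embd)    # q, k, v, out projections plus the extra Dense = 72,600
feedforward = dense(n_embd, ffw_hidden) + dense(ffw_hidden, n_embd)  # 115,800
block = 2 * layernorm + attention + feedforward                      # 188,880

total = (vocab_size * n_embd             # token embedding (assumed shape)
         + block_size * n_embd           # position embedding (assumed shape)
         + n_layer * block
         + layernorm                     # final LayerNorm
         + dense(n_embd, vocab_size))    # final Dense
print(f"{total:,}")                      # 1,156,865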

Sample Generation#

This is with the initial parameters, which are totally random. The long run of newline characters at the start is the decoded seed context: generation starts from zero-valued tokens, and token 0 decodes to a newline in this character vocabulary.

logger.propagate = False
bglm = BigramLanguageModel()

idx = bglm.apply(initial_params, 50, method='generate')
ts.decode(idx.tolist()[0])
"\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nboysZBZMAhCxFaSKeNyDs,WPHVxVWLoH hus\ndV,XbV\nDhwha'"

Training#

learning_rate = 1e-2

class TrainState(train_state.TrainState):
  key: jax.Array

state = TrainState.create(
  apply_fn=bglm.apply,
  params=initial_params,
  key=dropout_key,
  tx=optax.adam(learning_rate)
)
@jax.jit
def train_step(state: TrainState, inputs, labels):
  # Derive a fresh dropout key for this step so the dropout masks differ between steps
  dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)

  def cross_entropy_loss(params):
    logits = bglm.apply(params,
                        inputs,
                        training=True,
                        rngs={'dropout': dropout_train_key})
    logger.debug(logits.shape)

    # The integer-labels variant saves us from one-hot encoding the targets
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels))
    return loss, logits

  grad_fn = jax.value_and_grad(cross_entropy_loss, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state, loss
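
The comment about integer labels is worth a quick check. optax.softmax_cross_entropy_with_integer_labels takes raw class indices, so the targets never need to be one-hot encoded; the toy example below (dummy shapes, not the real batch) shows it agrees with the one-hot version.

import jax
import jax.numpy as jnp
import optax

# Dummy (B, T, vocab)-shaped logits and (B, T) integer labels, purely illustrative
demo_logits = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 65))
demo_labels = jnp.array([[1, 4, 7, 0, 2, 9, 3, 5],
                         [0, 2, 9, 1, 4, 7, 6, 8]])

loss_int = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=demo_logits, labels=demo_labels))
loss_onehot = jnp.mean(optax.softmax_cross_entropy(logits=demo_logits, labels=jax.nn.one_hot(demo_labels, 65)))
assert jnp.allclose(loss_int, loss_onehot, atol=1e-5)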
state.params["params"].keys()
dict_keys(['TokenEmbedding', 'Position Embedding', 'Block_0', 'Block_1', 'Block_2', 'Block_3', 'Block_4', 'Block_5', 'LayerNorm', 'Final Dense'])
state.step
0
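
TrainState carries the apply function, the parameters, and the Adam optimizer state together, and state.apply_gradients(grads=...) in train_step performs the optax update and bumps state.step. The toy example below (a hypothetical one-parameter model, not the notebook's) shows the manual optax equivalent.

import jax.numpy as jnp
import optax
from flax.training import train_state

toy_state = train_state.TrainState.create(
    apply_fn=lambda params, x: params['w'] * x,  # hypothetical one-weight "model"
    params={'w': jnp.array(1.0)},
    tx=optax.adam(1e-2),
)
toy_grads = {'w': jnp.array(0.5)}

# The manual optax update...
updates, _ = toy_state.tx.update(toy_grads, toy_state.opt_state, toy_state.params)
manual_params = optax.apply_updates(toy_state.params, updates)

# ...matches what apply_gradients produces (which also increments .step)
assert jnp.allclose(manual_params['w'], toy_state.apply_gradients(grads=toy_grads).params['w'])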

Training Loop#

This model was trained on my RTX 4090. It took about 30 minutes to train.

logger.setLevel(logging.DEBUG)

eval_interval = 100
CKPT_DIR = 'ckpts'
epochs = 10000
_loss = []

train_key, dropout_key = jax.random.split(key=root_key, num=2)

for epoch in tqdm(range(epochs)):

    # Generate a new random key for this step
    train_key = jax.random.fold_in(key=root_key, data=state.step)

    # Get a new batch
    xb, yb, ix = ts.get_batch(train_key)

    # Compute gradients and update the model parameters
    state, loss = train_step(state, xb, yb)
    _loss.append(loss)

    # Every once in a while, print the training loss
    if epoch % eval_interval == 0 or epoch == epochs - 1:
        print(f"step {epoch}: train loss {loss:.4f}")

    # Save a checkpoint, keeping one every 100 steps
    checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR,
                                target=state,
                                keep_every_n_steps=100,
                                prefix='nn_compact',
                                overwrite=True,
                                step=epoch)
DEBUG - In call idx.shape=(128, 64)
DEBUG - (128, 64, 65)
step 0: train loss 4.1745
step 100: train loss 2.5395
step 200: train loss 2.3642
step 300: train loss 3.0977
step 400: train loss 2.6543
step 500: train loss 2.4802
step 600: train loss 2.4041
step 700: train loss 2.3522
step 800: train loss 2.2822
step 900: train loss 2.2196
step 1000: train loss 2.1838
step 1100: train loss 2.1544
step 1200: train loss 2.1250
step 1300: train loss 2.0753
step 1400: train loss 2.0177
step 1500: train loss 2.0272
step 1600: train loss 1.9575
step 1700: train loss 1.9646
step 1800: train loss 1.9174
step 1900: train loss 1.9337
step 2000: train loss 1.8997
step 2100: train loss 1.8733
step 2200: train loss 1.8532
step 2300: train loss 1.8371
step 2400: train loss 1.8265
step 2500: train loss 1.7930
step 2600: train loss 1.7741
step 2700: train loss 1.7919
step 2800: train loss 1.7622
step 2900: train loss 1.7659
step 3000: train loss 1.7208
step 3100: train loss 1.6991
step 3200: train loss 1.6876
step 3300: train loss 1.7133
step 3400: train loss 1.6870
step 3500: train loss 1.6644
step 3600: train loss 1.6809
step 3700: train loss 1.6525
step 3800: train loss 1.6607
step 3900: train loss 1.6528
step 4000: train loss 1.6055
step 4100: train loss 1.6374
step 4200: train loss 1.6325
step 4300: train loss 1.6066
step 4400: train loss 1.6246
step 4500: train loss 1.6030
step 4600: train loss 1.5821
step 4700: train loss 1.5942
step 4800: train loss 1.5454
step 4900: train loss 1.5827
step 5000: train loss 1.5408
step 5100: train loss 1.5866
step 5200: train loss 1.5466
step 5300: train loss 1.4982
step 5400: train loss 1.5322
step 5500: train loss 1.5738
step 5600: train loss 1.5510
step 5700: train loss 1.5762
step 5800: train loss 1.5388
step 5900: train loss 1.5330
step 6000: train loss 1.5429
step 6100: train loss 1.5503
step 6200: train loss 1.5298
step 6300: train loss 1.4796
step 6400: train loss 1.5489
step 6500: train loss 1.5252
step 6600: train loss 1.4758
step 6700: train loss 1.5277
step 6800: train loss 1.4985
step 6900: train loss 1.5336
step 7000: train loss 1.5061
step 7100: train loss 1.4715
step 7200: train loss 1.5023
step 7300: train loss 1.5088
step 7400: train loss 1.5028
step 7500: train loss 1.4952
step 7600: train loss 1.4748
step 7700: train loss 1.4821
step 7800: train loss 1.4490
step 7900: train loss 1.4799
step 8000: train loss 1.4497
step 8100: train loss 1.4566
step 8200: train loss 1.4322
step 8300: train loss 1.4875
step 8400: train loss 1.4908
step 8500: train loss 1.4873
step 8600: train loss 1.4771
step 8700: train loss 1.4786
step 8800: train loss 1.4422
step 8900: train loss 1.4572
step 9000: train loss 1.4584
step 9100: train loss 1.4253
step 9200: train loss 1.4107
step 9300: train loss 1.4151
step 9400: train loss 1.4546
step 9500: train loss 1.4369
step 9600: train loss 1.4540
step 9700: train loss 1.4451
step 9800: train loss 1.4437
step 9900: train loss 1.4155
step 9999: train loss 1.4538

Training Loss#

fig, ax = plt.subplots()
ax.plot(np.arange(epochs), _loss)

ax.set_xlabel("Training Step")
ax.set_ylabel("Training Loss")
Text(0, 0.5, 'Training Loss')
[Figure: training loss over the 10,000 training steps]
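
The per-step loss is noisy, so the raw curve jumps around. A simple moving average (a sketch reusing the _loss list collected during training) makes the downward trend easier to read; the window size of 100 is arbitrary.

import numpy as np
import matplotlib.pyplot as plt

window = 100
smoothed = np.convolve(np.array(_loss), np.ones(window) / window, mode='valid')

fig, ax = plt.subplots()
ax.plot(np.arange(len(smoothed)), smoothed)
ax.set_xlabel("Training Step")
ax.set_ylabel(f"Training Loss ({window}-step moving average)")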

Final Results#

We can now use our weights to generate text. At a glance, it’s not bad.

logger.setLevel(logging.INFO)

idx = bglm.apply(state.params, 600, method='generate')
generation = ts.decode(idx.tolist()[0])
print(generation.strip())
MENENIUS:
COMINIUS:
Why, sir, or hast been so what I did;
Respect is his, and go with holds,
And backs hangment tomb, Camillo true,
To speak no away in this son of roopiness;
Whithe that is all for you was, kind how I am
Hope for you as time our place.

LARTIUS:
Alas, only to love, to ho,
Him it with the cause to one this, bitter her
May a shepherd haze. Angelo.
But one till the officers to this true,
Thus expect us was unto the; for I am every begue;
And come unto our bild, man, then joy in his Pome.

BENVOLIO:
How down seems be my voiciant is upon himself,
And since him out of a diath. whom

Comparing Checkpoints#

It’s interesting to go back in time and compare generations from different checkpoints. Since we saved checkpoints throughout training, we can load them and see what the LLM had learned up to each point.
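
If you want to scan several checkpoints in one go, a small loop works; this sketch reuses the restore_checkpoint calls from the individual sections below, with those same step numbers.

# Compare short generations from several saved checkpoints
for step in [301, 1101, 5901]:
    ckpt = checkpoints.restore_checkpoint(ckpt_dir=f'{CKPT_DIR}/nn_compact{step}', target=None)
    idx = bglm.apply(ckpt["params"], 200, method='generate')
    print(f"--- checkpoint {step} ---")
    print(ts.decode(idx.tolist()[0]).strip())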

Checkpoint 301#

ckpt_301 = checkpoints.restore_checkpoint(ckpt_dir=f'{CKPT_DIR}/nn_compact301', target=None)

idx = bglm.apply(ckpt_301["params"], 600, method='generate')
generation = ts.decode(idx.tolist()[0])
print(generation.strip())
o sand Ah tiat,eNyh , t silenod has
d Aobe
Dh,hacr lolipr lh s teisihh mosaaOgn niyaUhrl T,! In

r Bsl a Ime Hot mor CLhtqIw tre Smellosh
t nleawe IiACtl Nssor NCtroh iteke;sacr Ie tg bi:slalrf,sarerygdaaYiku S noS konddiy:iM Hhusote
VA,iktrodh: toweisteuRsodl
AAuHtstopm r :rloRIshtla o,raboc N:w t to
-Ln tOate mo s
Hgr Sub.
TTk h xnehaoao,hehaiAme wheyouok n
ursowfed r .cw s Kboof r ms toA'tiGWtoRfawon
beeTm Mnaufowazirrse tab;gfyosor Ite
Wdochagu pe r se itIr t vol wy Cp omaatit leher
tiBSen,iPd ek tor  Uubheoryan r
f ser m borcses Mdi-k diumubfciOElsamfeunrcisiLun gsh Von REaAITratle bh?e

Checkpoint 1101#

ckpt_1101 = checkpoints.restore_checkpoint(ckpt_dir=f'{CKPT_DIR}/nn_compact1101', target=None)

idx = bglm.apply(ckpt_1101["params"], 600, method='generate')
generation = ts.decode(idx.tolist()[0])
print(generation.strip())
Thostou phithat, Nous, theey, ou has deave sow, a'r couth now sthensing mustatgn nind
Irlls, wencir Beloa I you the gard me weere bolllon bothe awngeis
Wonksser! Cordowitike;
What to greimstant for frighe ming andoo kend staioud us bear tikirought
Heeinte; sodllleantst:
my to love nour h, abe it wit to band,
at omonst gravubstink hext have sheais we wheave, andsesth, I mu.

As sooof warstt yould tould thue ep.
Morclowaze-
Andt beavy sor I hat ochesuspe pese it dame Allll my omantinge.

TROFBUCHANP IESTHARD UV:
You me hand me ma, you, midisk dighur:
O that fout cist trecs: Von heave hat now spa

Checkpoint 5901#

ckpt_5901 = checkpoints.restore_checkpoint(ckpt_dir=f'{CKPT_DIR}/nn_compact5901', target=None)

idx = bglm.apply(ckpt_5901["params"], 600, method='generate')
generation = ts.decode(idx.tolist()[0])
print(generation.strip())
KING EDWARD IV:
You should be our us;n all some acculous;
Respemions his jotal no in shrl beawen to Bilianca
Angromer Captque true of soldiers, away is
Wondomen to roogion by mertaint, being the sakery
Backing and of thy son. Though say the poover
His stoublt loooks stop not: love he fir,
Him it with seen at a toloner gracuble,
That I have should be when woman.

BENVOLIO:
Why, sir, no must you pray! I thus explumply was:
The tabegot sorrow wedow again; and compract our
Company at their match exputine.


GLOUCESTER:
How farewell,
Is it is to dismublcients, for this man could
Thee a dranged by M