Mamba: SSM, Theory, and Implementation in Keras and TensorFlow | by Vedant Jumle | Mar, 2024

Understanding how SSMs and Mamba work, along with how to get started with implementing them in Keras and TensorFlow.

Source: AI Generated (SDXL)

Submitted on 1st December 2023 on arXiv, the paper titled “Mamba: Linear-Time Sequence Modeling with Selective State Spaces” proposed an interesting approach to sequence modeling. The authors, Albert Gu and Tri Dao, introduced ‘Mamba’, which utilizes ‘selective’ state space models (SSM) to achieve results that compete with the performance of the now-ubiquitous Transformer model.

Transformers have seen recent popularity with the rise of Large Language Models (LLMs) like LLaMa-2, GPT-4, Claude, Gemini, etc., but they suffer from the problem of a limited context window. The issue lies at the transformer's core: the multi-head attention mechanism.

The main issue with multi-head attention stems from the fact that for an input sequence of length n, the time and space complexity scale as O(n²). This limits the length of an LLM's context window, because to increase it by 10x, we need to scale the hardware requirements (most notably GPU VRAM) by 100x.

Mamba, on the other hand, scales as O(n), i.e., linearly!

Plot taken from the Mamba paper comparing the FlashAttention and Mamba approaches (indicated by scan(ours) in the legend)[1]

This linear scaling is what has led researchers to speculate that Mamba might be the future of sequence modeling.

The core of the Mamba model comes from the concept of State Space Models. State Space Models, like Transformers and RNNs, process sequences of information, such as text, audio signals, video frames, DNA sequences, etc.

State Space Models build on the idea of describing a physical system as a set of inputs, outputs, and variables. These variables are A, B, C, and D. The process of an SSM involves calculating an internal state vector h(t), given an input x(t). Then, we take a weighted sum of h(t) and x(t), where the weights are A, B, C, and D. In the simplest form (continuous time-invariant), the formulation looks like:

Source: Wikipedia[6]
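
In case the equation image does not come through, the standard continuous-time state space equations are:

h'(t) = A h(t) + B x(t)
y(t) = C h(t) + D x(t)

where h(t) is the state, x(t) the input, and y(t) the output.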

h(t) is often referred to as the ‘hidden’ or the ‘latent’ state; I will stick to calling it the ‘hidden’ state for better readability. It is important to note that A, B, C, and D are learnt parameters in an SSM.

What are the variables?

The variables A, B, C, and D are learnt parameters, and they can be described as:

  • A: How much of the previous hidden state (h) should be considered to calculate the new hidden state.
  • B: How much of the input (x) should be considered to calculate the new hidden state.
  • C: How much of the new hidden state should be considered in calculating the output (y).
  • D: How much of the input (x) should be considered in calculating the output (y).

D comes at the end of the computation and does not affect how the hidden state is calculated. Hence, it is usually considered outside of the SSM and can be thought of as a skip connection.

Going from continuous spaces to discrete spaces

The above formulation applies to a system where the input and output belong to a continuous space. But in cases like language modeling, the input and output belong to discrete spaces (token values in a vocabulary), and finding h(t) analytically is challenging. We can bridge this gap by performing a zero-order hold.

In a zero-order hold, every time an input is received, the model holds its value until the next input arrives. This produces a continuous input space.

How zero-order hold works

The length of this ‘hold’ is determined by a new parameter called the step size ∆. It can be thought of as the resolution of the input. Ideally, ∆ should be infinitesimal.

Mathematically, the zero-order hold can be described as:
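
For reference, the zero-order-hold discretization used in the Mamba paper turns the continuous parameters (∆, A, B) into discrete ones as:

A_bar = exp(∆A)
B_bar = (∆A)^(-1) (exp(∆A) − I) ∆B

(in practice, and in the selective_scan implementation later in this article, B_bar is approximated by the simpler first-order term ∆B).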

Finally, we can create a discrete SSM as:

Since D is used in a skip connection outside of the SSM, the output can be reduced to:

The Dx(t) term acts as a skip connection, hence it is applied outside of the SSM
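
Written out, the discrete SSM that the two equations above describe is the recurrence:

h_k = A_bar h_{k−1} + B_bar x_k
y_k = C h_k

with the D x_k term added back outside the SSM as the skip connection.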

In SSMs, the hidden state is carried over to when the next input is received. This is similar to how Recurrent Neural Networks function.

Comparison of RNN and SSM

This recurrent formulation of an SSM can be unrolled, just like an RNN. But unlike RNNs, which are iterative and slow, an SSM can process the input sequence in parallel (just like transformers), and this makes the training process faster.

Unrolled form of an SSM

Note that ‘D’ is used in a skip connection, which is outside of the SSM.

The key insight into how SSMs make training fast is to use the variables A, B, and C in a pre-computed convolutional kernel. Maarten Grootendorst wrote a really good explanation of how this canonical ‘convolutional’ kernel is constructed. But here is a simple mathematical explanation.

Consider the output y. For a sequence of length k, the output y(k) can be represented as (assuming h0 = 0):

Similarly, y3 can be represented as:

Extrapolating the pattern, yk can be represented as:

This formulation can be further reduced to:

The funny-looking multiplication symbol represents a convolution operation, where the convolution kernel is K. Notice that K does not depend on x, hence K can be pre-computed into a convolutional kernel, which makes the process faster.
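
For reference, this is the standard S4-style derivation that the equations above illustrate. Expanding the recurrence with h_0 = 0 gives:

y_k = C B_bar x_k + C A_bar B_bar x_{k−1} + C A_bar² B_bar x_{k−2} + … + C A_bar^k B_bar x_0

which is exactly a 1-D convolution y = x ∗ K_bar with the pre-computable kernel

K_bar = (C B_bar, C A_bar B_bar, C A_bar² B_bar, …, C A_bar^{L−1} B_bar)

where L is the sequence length.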

As good as the computational capabilities of SSMs sound, they turn out to be pretty meh on metrics like accuracy compared to Transformers.

The core issue lies with the variables ∆, A, B, and C. It turns out that since we apply the same matrices to every input, they cannot really process the context of the sequence.

SSMs are rigid in the way they process information[4]

So what is so special about Mamba? In Mamba, we use a process called ‘selective’ SSM, where the variables ∆, B, and C are computed based on the input. 🤔 We do this by passing the current input through Linear layers and taking the outputs to be ∆, B, and C.
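
As a minimal sketch of that idea (the layer names and sizes here are illustrative placeholders; the real version lives in the MambaBlock implementation later in this article):

import tensorflow as tf
from tensorflow.keras import layers

d_model, n, delta_rank = 64, 16, 4      # illustrative sizes

# a single projection produces Δ, B, and C for every token, from the token itself
to_delta_B_C = layers.Dense(delta_rank + 2 * n, use_bias=False)

x = tf.random.normal((2, 10, d_model))  # (batch, seq_len, d_model)
delta, B, C = tf.split(to_delta_B_C(x), [delta_rank, n, n], axis=-1)

# Δ is projected back up to d_model and kept positive with softplus
delta = tf.nn.softplus(layers.Dense(d_model)(delta))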

But this makes ∆, B, and C input-dependent, which means they cannot be pre-computed 😢, so the fast convolution trick is not going to work here. Instead, the authors discuss a method based on the parallel associative scan.

Parallel Associative Scan

The parallel associative scan is a powerful technique used in parallel computing to perform a prefix sum operation, which is a cumulative operation over a sequence of numbers. This operation is “associative”, meaning the way the numbers are grouped in the operation does not change the result.

Parallel prefix sum is an example of associative scanning. (source: Nvidia)[7]

In the context of the Mamba model, defining an associative operator gives us the elements and the operator needed for a parallel associative scan. This allows the problem to be solved over the whole time interval in parallel, resulting in logarithmic time complexity in the number of sub-intervals.
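
As a toy illustration of the scan idea (a plain prefix sum, not the actual Mamba recurrence), the doubling-stride pattern below reaches the same result as a sequential loop in O(log n) parallel steps, precisely because addition is associative:

import tensorflow as tf

def prefix_sum_scan(x):
    """Inclusive prefix sum via a doubling-stride scan: O(log n) steps."""
    n = x.shape[0]
    offset = 1
    while offset < n:
        # shift the sequence right by `offset`, filling with zeros, and add
        shifted = tf.pad(x, [[offset, 0]])[:n]
        x = x + shifted
        offset *= 2
    return x

x = tf.constant([1., 2., 3., 4.])
print(prefix_sum_scan(x).numpy())   # [ 1.  3.  6. 10.]
print(tf.math.cumsum(x).numpy())    # same result; cumsum is the primitive selective_scan uses later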

Hardware-aware algorithm

Along with the associative scan, the authors also propose a hardware-aware algorithm, where they exploit the characteristics of Nvidia GPUs related to the relative speeds of HBM and SRAM. They argue that the computation of SSM states can be sped up by:

  • keeping the hidden state and A in the faster but lower-capacity SRAM,
  • while computing ∆, B, and C in the slower but higher-capacity HBM,
  • then transferring ∆, B, and C to SRAM and computing the new hidden state within SRAM,
  • and finally writing ∆, B, and C back to HBM.

Illustration taken from the Mamba paper, showing how the hardware-aware algorithm works[1]

In the implementation section, I will not be discussing how to work with the hardware-aware algorithm; rather, I will only be using the parallel associative scan.

With all of this in mind, let's explore and implement the Mamba architecture using Keras and TensorFlow.

After reading the paper and analyzing the code, the Mamba architecture can be broken down into a few key components, which are connected as follows:

Breakdown of a Mamba block

The Mamba architecture consists of multiple stacked layers of ‘Mamba blocks’, which, judging from the above illustration, consist of quite a few components. Another important thing to note is that the authors add the output from the Selective SSM to the original input and then apply a normalization layer to it. This normalization can be either a Layer Normalization or an RMS Normalization.

Let's start coding our part of Mamba. We will be using the following dependencies:

tensorflow[and-cuda]==2.15.0.post1 # if you want to use the GPU, or
tensorflow==2.15.0.post1 # if you want to only use the CPU
transformers==4.36.2 # for using the bert tokenizer
einops==0.7.0 # useful to make matrix manipulation faster
datasets==2.16.1 # to load datasets
# all other modules (like numpy) will be auto-installed

Imports:

import tensorflow_datasets as tfds
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers, Model

from dataclasses import dataclass
from einops import rearrange, repeat
from typing import Union

from transformers import AutoTokenizer

import datasets
import math
import numpy as np

To make processing the model arguments easier, let's create a simple ModelArgs dataclass as a config class. This allows us to easily pass the dataclass variable in the arguments when we are initializing the model.

@dataclass
class ModelArgs:
    model_input_dims: int = 64
    model_states: int = 64
    projection_expand_factor: int = 2
    conv_kernel_size: int = 4
    delta_t_min: float = 0.001
    delta_t_max: float = 0.1
    delta_t_scale: float = 0.1
    delta_t_init_floor: float = 1e-4
    conv_use_bias: bool = True
    dense_use_bias: bool = False
    layer_id: int = -1
    seq_length: int = 128
    num_layers: int = 5
    dropout_rate: float = 0.2
    use_lm_head: bool = False
    num_classes: int = None
    vocab_size: int = None
    final_activation = None
    loss: Union[str, keras.losses.Loss] = None
    optimizer: Union[str, keras.optimizers.Optimizer] = keras.optimizers.AdamW()
    metrics = ['accuracy']

    def __post_init__(self):
        self.model_internal_dim: int = int(self.projection_expand_factor * self.model_input_dims)

        self.delta_t_rank = math.ceil(self.model_input_dims / 16)
        if self.layer_id == -1:
            self.layer_id = np.round(np.random.randint(0, 1000), 4)

        if self.vocab_size is None:
            raise ValueError("vocab size cannot be none")

        if self.use_lm_head:
            self.num_classes = self.vocab_size
        else:
            if self.num_classes is None:
                raise ValueError(f'num classes cannot be {self.num_classes}')

        if self.num_classes == 1:
            self.final_activation = 'sigmoid'
        else:
            self.final_activation = 'softmax'

        if self.loss is None:
            raise ValueError(f"loss cannot be {self.loss}")

Load the bert-base-uncased tokenizer:

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size

Before we implement our Mamba and SSM classes, we need to implement the parallel associative scan; the code looks like this:

def selective_scan(u, delta, A, B, C, D):
    # first step of A_bar = exp(ΔA), i.e., ΔA
    dA = tf.einsum('bld,dn->bldn', delta, A)
    dB_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)

    dA_cumsum = tf.pad(
        dA[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]

    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip along axis 1

    # Cumulative sum along all the input tokens, parallel prefix sum,
    # calculates dA for all the input tokens in parallel
    dA_cumsum = tf.math.cumsum(dA_cumsum, axis=1)

    # second step of A_bar = exp(ΔA), i.e., exp(ΔA)
    dA_cumsum = tf.exp(dA_cumsum)
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip back along axis 1

    x = dB_u * dA_cumsum
    # 1e-12 to avoid division by 0
    x = tf.math.cumsum(x, axis=1) / (dA_cumsum + 1e-12)

    y = tf.einsum('bldn,bln->bld', x, C)

    return y + u * D
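
As a quick sanity check, the shapes below are illustrative and simply chosen to match the einsum signatures above (they are not taken from the article's model config):

batch, seq_len, d_in, n = 2, 16, 32, 8

u     = tf.random.normal((batch, seq_len, d_in))
delta = tf.random.normal((batch, seq_len, d_in))
A     = tf.random.normal((d_in, n))
B     = tf.random.normal((batch, seq_len, n))
C     = tf.random.normal((batch, seq_len, n))
D     = tf.random.normal((d_in,))

y = selective_scan(u, delta, A, B, C, D)
print(y.shape)  # (2, 16, 32) -> same (batch, seq_len, d_in) shape as u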

With this, we can implement the MambaBlock:

class MambaBlock(layers.Layer):
    def __init__(self, modelargs: ModelArgs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = modelargs
        args = modelargs
        self.layer_id = modelargs.layer_id

        self.in_projection = layers.Dense(
            args.model_internal_dim * 2,
            input_shape=(args.model_input_dims,), use_bias=False)

        self.conv1d = layers.Conv1D(
            filters=args.model_internal_dim,
            use_bias=args.conv_use_bias,
            kernel_size=args.conv_kernel_size,
            groups=args.model_internal_dim,
            data_format='channels_first',
            padding='causal'
        )

        # this layer takes in the current token 'x'
        # and outputs the input-specific Δ, B, C (according to S6)
        self.x_projection = layers.Dense(args.delta_t_rank + args.model_states * 2, use_bias=False)

        # this layer projects Δ from delta_t_rank to the mamba internal
        # dimension
        self.delta_t_projection = layers.Dense(args.model_internal_dim,
                                               input_shape=(args.delta_t_rank,), use_bias=True)

        self.A = repeat(
            tf.range(1, args.model_states + 1, dtype=tf.float32),
            'n -> d n', d=args.model_internal_dim)

        self.A_log = tf.Variable(
            tf.math.log(self.A),
            trainable=True, dtype=tf.float32,
            name=f"SSM_A_log_{args.layer_id}")

        self.D = tf.Variable(
            np.ones(args.model_internal_dim),
            trainable=True, dtype=tf.float32,
            name=f"SSM_D_{args.layer_id}")

        self.out_projection = layers.Dense(
            args.model_input_dims,
            input_shape=(args.model_internal_dim,),
            use_bias=args.dense_use_bias)

    def call(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper.
        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """

        (batch_size, seq_len, dimension) = x.shape

        x_and_res = self.in_projection(x)  # shape = (batch, seq_len, 2 * model_internal_dim)
        (x, res) = tf.split(x_and_res,
                            [self.args.model_internal_dim,
                             self.args.model_internal_dim], axis=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :seq_len]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = tf.nn.swish(x)
        y = self.ssm(x)
        y = y * tf.nn.swish(res)
        return self.out_projection(y)

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper
            - run_SSM(A, B, C, u) in The Annotated S4
        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆, A, B, C, D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time-invariant S4,
        #     and is why Mamba is called **selective** state spaces)

        A = -tf.exp(tf.cast(self.A_log, tf.float32))  # shape -> (d_in, n)
        D = tf.cast(self.D, tf.float32)

        x_dbl = self.x_projection(x)  # shape -> (batch, seq_len, delta_t_rank + 2*n)

        (delta, B, C) = tf.split(
            x_dbl,
            num_or_size_splits=[self.args.delta_t_rank, n, n],
            axis=-1)  # delta.shape -> (batch, seq_len, delta_t_rank) & B, C shape -> (batch, seq_len, n)

        delta = tf.nn.softplus(self.delta_t_projection(delta))  # shape -> (batch, seq_len, model_internal_dim)

        return selective_scan(x, delta, A, B, C, D)

Finally, we add a residual block to implement the external skip connection.

class ResidualBlock(layers.Layer):
    def __init__(self, modelargs: ModelArgs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = modelargs
        self.mixer = MambaBlock(modelargs)
        self.norm = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        """
        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

        Note: the official repo chains residual blocks that look like
            [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
        where the first Add is a no-op. This is purely for performance reasons, as it
        allows them to fuse the Add->Norm.

        We instead implement our blocks as the more familiar, simpler, and numerically equivalent
            [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...

        """
        return self.mixer(self.norm(x)) + x

With this, we can initialize our model. In this example, I will be demonstrating how to use the Mamba block to create a simple classification model, but it can easily be modified to become a language model. Let's load the IMDB reviews dataset for a simple sentiment classifier.

from datasets import load_dataset
from tqdm import tqdm

dataset = load_dataset("ajaykarthick/imdb-movie-reviews")

First, we create a function that takes the model args and returns a model.

def init_model(args: ModelArgs):
    input_layer = layers.Input(shape=(args.seq_length,), name='input_ids')
    x = layers.Embedding(
        args.vocab_size,
        args.model_input_dims,
        input_length=args.seq_length)(input_layer)

    for i in range(args.num_layers):
        x = ResidualBlock(args, name=f"Residual_{i}")(x)
        x = layers.Dropout(args.dropout_rate)(x)  # for regularization

    x = layers.LayerNormalization(epsilon=1e-5)(x)  # normalization layer

    # use flatten only if we are not using the model as an LM
    if not args.use_lm_head:
        x = layers.Flatten()(x)
    x = layers.Dense(1024, activation=tf.nn.gelu)(x)
    output_layer = layers.Dense(
        args.num_classes,
        activation=args.final_activation)(x)

    model = Model(
        inputs=input_layer,
        outputs=output_layer, name='Mamba_ka_Mamba')
    model.compile(
        loss=args.loss,
        optimizer=args.optimizer,
        metrics=args.metrics
    )

    return model

Now we can initialize our model and summarize it:

args = ModelArgs(
    model_input_dims=128,
    model_states=32,
    num_layers=12,
    dropout_rate=0.2,
    vocab_size=vocab_size,
    num_classes=1,
    loss='binary_crossentropy',
)
model = init_model(args)
model.summary()
Model: "Mamba_ka_Mamba"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_ids (InputLayer)      [(None, 128)]             0

 embedding_2 (Embedding)     (None, 128, 128)          3906816

 Residual_0 (ResidualBlock)  (None, 128, 128)          129024

 dropout_24 (Dropout)        (None, 128, 128)          0

 Residual_1 (ResidualBlock)  (None, 128, 128)          129024

 dropout_25 (Dropout)        (None, 128, 128)          0

... (I've shortened this to make it more readable)

 dropout_35 (Dropout)        (None, 128, 128)          0

 layer_normalization_38 (La  (None, 128, 128)          256
 yerNormalization)

 flatten_2 (Flatten)         (None, 16384)             0

 dense_148 (Dense)           (None, 1024)              16778240

 dense_149 (Dense)           (None, 1)                 1025

=================================================================
Total params: 22234625 (84.82 MB)
Trainable params: 22234625 (84.82 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

For easier processing, let's pre-tokenize our data into numpy arrays, then convert them into tf.data.Dataset objects:

train_labels, test_labels = [], []
train_ids = np.zeros((len(dataset['train']), args.seq_length))
test_ids = np.zeros((len(dataset['test']), args.seq_length))

for i, item in enumerate(tqdm(dataset['train'])):
    text = item['review']
    train_ids[i, :] = tokenizer.encode_plus(
        text,
        max_length=args.seq_length,
        padding='max_length',
        return_tensors='np')['input_ids'][0][:args.seq_length]

    train_labels.append(item['label'])

for i, item in enumerate(tqdm(dataset['test'])):
    text = item['review']
    test_ids[i, :] = tokenizer.encode_plus(
        text,
        max_length=args.seq_length,
        padding='max_length',
        return_tensors='np')['input_ids'][0][:args.seq_length]

    test_labels.append(item['label'])

del dataset  # delete the original dataset to save memory

BATCH_SIZE = 32
train_dataset = tf.data.Dataset.from_tensor_slices((train_ids, train_labels)).batch(BATCH_SIZE).shuffle(1000)
test_dataset = tf.data.Dataset.from_tensor_slices((test_ids, test_labels)).batch(BATCH_SIZE).shuffle(1000)

Now the model can be trained:

history = model.fit(train_dataset, validation_data=test_dataset, epochs=10)

You can play around with the inference function:

def infer(text: str, model: Model, tokenizer):
    tokens = tokenizer.encode(
        text,
        max_length=args.seq_length,
        padding='max_length', return_tensors='np')
    output = model(tokens)[0, 0]
    return output

This model can be converted into a language model, and algorithms like beam search, top-k sampling, greedy sampling, etc. can then be used to generate text.
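
As a rough sketch of what that could look like (this helper is my own addition, not part of the original code; it assumes a model built with use_lm_head=True, so that the output has shape (1, seq_length, vocab_size), and it reuses the args and tokenizer defined above):

def greedy_generate(prompt: str, model, tokenizer, max_new_tokens=20):
    # greedy decoding: always pick the most probable next token
    ids = tokenizer.encode(prompt, add_special_tokens=False)
    while len(ids) < args.seq_length and max_new_tokens > 0:
        # right-pad the current ids to the fixed model input length
        padded = ids + [tokenizer.pad_token_id] * (args.seq_length - len(ids))
        probs = model(np.array([padded]))                  # (1, seq_length, vocab_size)
        next_id = int(np.argmax(probs[0, len(ids) - 1]))   # prediction at the last real position
        ids.append(next_id)
        max_new_tokens -= 1
    return tokenizer.decode(ids)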

This code can be found on my GitHub.

A lot of the code is inspired by Mamba's official implementation[2] and another PyTorch implementation called ‘mamba-tiny’[3].

Thanks for reading.
