State Space Models For Sequence Modeling

Atufa Shireen
Mar 1, 2024


In control systems, state space models are frameworks (often probabilistic) used to model systems with unobservable states. Consider the Kalman filter as an example: it represents the evolution of hidden states over time through transition distributions and connects these states to observed data through observation distributions.
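To make "transition distribution" and "observation distribution" concrete, here is a minimal sketch of a one-dimensional Kalman filter. It assumes a linear-Gaussian model with scalar parameters a, h, q and r; the function name and parameter names are chosen for this illustration only.

```python
def kalman_filter_1d(observations, a=1.0, h=1.0, q=1e-4, r=1.0, m0=0.0, p0=100.0):
    """Scalar Kalman filter (illustrative sketch) for the assumed model
        hidden:   s_t = a * s_{t-1} + w_t,  w_t ~ N(0, q)   (transition distribution)
        observed: z_t = h * s_t + v_t,      v_t ~ N(0, r)   (observation distribution)
    Returns the filtered means and variances of the hidden state s_t."""
    m, p = m0, p0
    means, variances = [], []
    for z in observations:
        # Predict: push the current belief through the transition distribution.
        m = a * m
        p = a * a * p + q
        # Update: correct the prediction using the new observation.
        k = p * h / (h * h * p + r)  # Kalman gain
        m = m + k * (z - h * m)
        p = (1.0 - k * h) * p
        means.append(m)
        variances.append(p)
    return means, variances


means, variances = kalman_filter_1d([5.0] * 20, q=0.0)
print(means[-1], variances[-1])
```

With a constant observation of 5.0, the estimated hidden state moves towards 5.0 while its variance shrinks, because each observation adds information about the unobserved state.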

In this blog, we discuss structured and selective state space models for sequence modeling. Structured state space models, proposed earlier, work well in domains involving continuous data such as audio, while selective state space models, their successors, show strong performance on discrete data such as text.

Before that, let’s recap a few sequence models:

Having gone through the pros, cons and structure of these sequence models, we can see what SSMs aim to address: the computational inefficiency of these models on long sequences, such as the quadratic cost of attention in transformers and the sequential, non-parallelisable training of RNNs.

STRUCTURED STATE SPACE MODEL (S4):

The SSM generates the output based on the current hidden state and the input of the model.

General state space model

Similarly, the structured SSM consists of the following equations:

a. Input processing

b. Dynamic state representation

c. Output generation

and the following phases:

a. Discretisation of the parameters

b. Model Computation

A. Input processing:

In the equations h’(t) = Ah(t) + Bx(t) and y(t) = Ch(t), the matrices A, B and C are the learnable parameters, and x(t) is the input sequence (for example, a sequence of characters).

* The matrix A captures information from all previous states and the input history, which the model then uses to generate the output. Think of it like a Fourier transform that approximates the input signal and stores it as a vector of coefficients; instead of a Fourier transform, however, the authors use the HiPPO method to initialise A.

* The matrix C describes how the state can be translated to an output.

* The matrix B multiplies the input signal x(t), describing how the input influences the system.
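As a concrete example of a HiPPO matrix, the sketch below builds the Legendre-based "LegS" variant using the formula given in the S4 paper: A[n, k] = -sqrt(2n+1)·sqrt(2k+1) below the diagonal, -(n+1) on the diagonal, and 0 above it. The function name `hippo_legs` is hypothetical, and this shows only the initialisation, not how S4 later reparameterises A.

```python
import numpy as np


def hippo_legs(n_state: int) -> np.ndarray:
    """HiPPO-LegS matrix (sketch following the formula in the S4 paper):
        A[n, k] = -sqrt(2n+1) * sqrt(2k+1)   if n > k
                  -(n + 1)                   if n == k
                  0                          if n < k
    """
    idx = np.arange(n_state)
    rows = idx[:, None]
    cols = idx[None, :]
    sq = np.sqrt(2 * idx + 1)
    A = -np.outer(sq, sq)               # -sqrt(2n+1) * sqrt(2k+1) everywhere
    A = np.where(rows > cols, A, 0.0)   # keep only the strictly lower triangle
    A = A - np.diag(idx + 1.0)          # diagonal entries -(n + 1)
    return A


print(hippo_legs(4))
```

Because the matrix is lower triangular, its eigenvalues are the diagonal entries -1, -2, ..., which are all negative, so the continuous-time state stays stable.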

B. Dynamic state representation:

The state equation is h’(t) = Ah(t) + Bx(t), and the goal is to find h(t). It describes how the current state evolves through matrix A and how the input influences that evolution through matrix B.

C. Output generation:

The output is generated by the equation y(t) = Ch(t), which describes how the current state is translated into the output through matrix C.

Together, these two equations aim to predict the state of a system from observed data. Since the input is expected to be continuous, the main representation of the SSM is a continuous-time representation.
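To see the continuous-time system in action, the sketch below approximates h’(t) = Ah(t) + Bx(t) and y(t) = Ch(t) with many tiny forward-Euler steps. This is purely illustrative: Euler integration and the function name `simulate_continuous_ssm` are assumptions for this example, not how S4 computes the model (S4 discretises the system, as described next).

```python
import numpy as np


def simulate_continuous_ssm(A, B, C, x_fn, t_end, dt=1e-3):
    """Approximate h'(t) = A h(t) + B x(t), y(t) = C h(t) with forward Euler.
    A: (N, N), B: (N,), C: (N,), x_fn: scalar input signal x(t)."""
    h = np.zeros(A.shape[0])             # start from a zero state
    ts = np.arange(0.0, t_end, dt)
    ys = np.empty(len(ts))
    for i, t in enumerate(ts):
        ys[i] = C @ h                    # output equation
        h = h + dt * (A @ h + B * x_fn(t))  # one small step of the state equation
    return ts, ys


# One-dimensional example: h' = -h + x with a constant input x(t) = 1,
# whose exact solution is h(t) = 1 - exp(-t).
ts, ys = simulate_continuous_ssm(np.array([[-1.0]]), np.array([1.0]),
                                 np.array([1.0]), lambda t: 1.0, t_end=5.0)
print(ys[-1], 1 - np.exp(-ts[-1]))
```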

A. The Discretisation with delta:

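The state system is continuous so far. To discretise it, SSMs introduce another parameter, the step size ∆, so that instead of a function-to-function mapping x(t) → y(t), the model becomes a sequence-to-sequence mapping xₜ → yₜ. The discretisation uses the zero-order hold (ZOH) rule, which gives Ā = exp(∆A) and B̄ = (∆A)⁻¹(exp(∆A) − I)·∆B, and hence the recurrence hₜ = Āhₜ₋₁ + B̄xₜ, yₜ = Chₜ.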
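The ZOH rule can be sketched in a few lines if we assume a diagonal state matrix A (as in diagonal variants of S4 and in Mamba), because the matrix exponential and the inverse then act element by element. The function name `discretize_zoh` is hypothetical; for a general dense A you would need a matrix exponential instead.

```python
import numpy as np


def discretize_zoh(A_diag, B, delta):
    """Zero-order hold for a diagonal state matrix (sketch).
    General ZOH:  A_bar = exp(delta*A)
                  B_bar = (delta*A)^{-1} (exp(delta*A) - I) delta*B
    With A diagonal, both simplify to elementwise operations.
    Assumes every entry of A_diag is nonzero."""
    dA = delta * A_diag
    A_bar = np.exp(dA)
    B_bar = (A_bar - 1.0) / A_diag * B  # equals (dA)^{-1} (exp(dA) - 1) * delta * B
    return A_bar, B_bar


A_bar, B_bar = discretize_zoh(np.array([-1.0, -2.0]), np.ones(2), delta=0.1)
print(A_bar, B_bar)
```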

A smaller step size ∆ makes the model largely ignore the current word and rely more on the previous context, whereas a larger step size ∆ makes it focus more on the current input word than on the context.
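The effect of ∆ can be checked numerically on a single state dimension. The sketch below assumes A = -1 and B = 1 (values chosen only for illustration) and applies the ZOH formulas for a scalar state: Ā is the fraction of the previous state (the context) that is kept, and B̄ is the weight given to the current input in hₜ = Āhₜ₋₁ + B̄xₜ.

```python
import numpy as np


def step_coefficients(delta, A=-1.0, B=1.0):
    """ZOH coefficients for one scalar state: h_t = A_bar * h_{t-1} + B_bar * x_t."""
    A_bar = np.exp(delta * A)       # fraction of the previous state (context) that is kept
    B_bar = (A_bar - 1.0) / A * B   # weight given to the current input
    return A_bar, B_bar


for delta in (0.01, 1.0, 10.0):
    A_bar, B_bar = step_coefficients(delta)
    print(f"delta={delta}: A_bar={A_bar:.3f}, B_bar={B_bar:.3f}")
```

For a small ∆, Ā is close to 1 and B̄ is close to 0, so the state mostly carries the old context forward; for a large ∆, Ā is close to 0 and B̄ is close to 1, so the state is dominated by the current input.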

B. Model Computation:

* The Recurrence mode:

If you look closely at the state equation, you will notice that it resembles an RNN: the current hidden state depends on the previous hidden state. This recurrence makes inference time linear in the sequence length, since each new output depends only on the previous state, so every step costs the same fixed amount of computation.
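A minimal sketch of the recurrence mode, assuming the matrices have already been discretised (the names `A_bar`, `B_bar` and `ssm_recurrent` are chosen for this illustration) and the state starts at zero:

```python
import numpy as np


def ssm_recurrent(A_bar, B_bar, C, x):
    """Recurrent mode: h_t = A_bar h_{t-1} + B_bar x_t,  y_t = C h_t, with h = 0 initially.
    One constant-cost step per token, so the total cost is linear in the sequence length."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x_t in x:
        h = A_bar @ h + B_bar * x_t  # update the hidden state from the previous one
        ys.append(C @ h)             # read out the current output
    return np.array(ys)


# Scalar example: an impulse input decays by a factor of 0.5 at each step.
print(ssm_recurrent(np.array([[0.5]]), np.array([1.0]), np.array([2.0]), [1.0, 0.0, 0.0]))
```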

However, the recurrent formulation is slow to train: computing the states for backpropagation and parameter updates has to proceed step by step and cannot be parallelised across the sequence. The convolution mode of SSMs addresses this.

* The convolution mode:

Because the state and output equations form a linear RNN, the SSM can unroll it into a wide CNN for training, i.e. a single long convolution over the input sequence.

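For example, expanding y₂ (here A and B stand for the discretised matrices, and the state before the first input is zero, so h₀ = Bx₀):

$$
\begin{aligned}
y_2 &= C h_2 \\
    &= C(A h_1 + B x_2) \\
    &= C\big(A(A h_0 + B x_1) + B x_2\big) \\
    &= C\big(A(A B x_0 + B x_1) + B x_2\big) \\
    &= C(A^2 B x_0 + A B x_1 + B x_2) \\
    &= C A^2 B x_0 + C A B x_1 + C B x_2
\end{aligned}
$$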

So we can substitute

$$K = (CB,\ CAB,\ CA^2B,\ \dots), \qquad x = (x_0, x_1, x_2, \dots)$$

and obtain the following:

$$y = x * K$$

(x convolved with K) {From image 2b, eq 3a and 3b}
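The convolution view can be sketched directly: build the kernel K = (CB, CAB, CA²B, …) and apply it as a causal convolution. The function names below are hypothetical, and the kernel is built with a plain loop for clarity; S4 itself avoids computing powers of A this way and performs the long convolution with FFTs.

```python
import numpy as np


def ssm_kernel(A_bar, B_bar, C, length):
    """K = (C B, C A B, C A^2 B, ...) truncated to `length` terms."""
    K = np.empty(length)
    v = np.array(B_bar, dtype=float)  # holds A^k B, starting with k = 0
    for k in range(length):
        K[k] = C @ v
        v = A_bar @ v
    return K


def ssm_convolutional(A_bar, B_bar, C, x):
    """Convolution mode: y_k = sum_{j<=k} K_j x_{k-j}, a causal convolution of x with K."""
    x = np.asarray(x, dtype=float)
    K = ssm_kernel(A_bar, B_bar, C, len(x))
    return np.convolve(x, K)[: len(x)]  # keep only the first len(x) (causal) outputs


A = np.array([[0.5]])
B = np.array([1.0])
C = np.array([2.0])
print(ssm_kernel(A, B, C, 3), ssm_convolutional(A, B, C, [1.0, 0.0, 0.0]))
```

For the same discretised matrices, this produces the same outputs as running the recurrence step by step; the difference is that the convolution over the whole sequence can be computed in parallel.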

Therefore, the output can be calculated with either the recurrent (RNN) method or the convolution method, because the system is linear and time-invariant and each output does not depend on the previous outputs. Hence, the SSM uses the convolution method during training, since it can be parallelised, and the recurrent method during inference, since it runs in linear time.

SELECTIVE STATE SPACE MODELS:

You may have noticed that the structured SSM cannot focus on only the relevant information while maintaining the order in its state, because the matrices B and C and the step size ∆ are the same for every input (time-invariant). Selective state space models add a selection mechanism to SSMs that makes these parameters input-dependent (time-varying).

Selection Mechanism:

A linear layer is added for each of these parameters, so that B, C and ∆ take a different value for each input. Because the parameters now change from step to step, the output can no longer be computed with a single fixed convolution kernel; the authors address this by introducing the selective scan, a hardware-aware algorithm.
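A minimal sketch of the selection mechanism, assuming an input of shape (L, D), a fixed negative diagonal A of shape (D, N), and hypothetical weight matrices `W_B`, `W_C`, `W_delta` and bias `b_delta` standing in for the linear layers. It uses softplus to keep ∆ positive and reuses the ZOH formulas for both Ā and B̄ (an assumption for consistency with the earlier sections). It runs the recurrence with a plain sequential loop rather than the hardware-aware selective scan, and it omits the rest of the Mamba block.

```python
import numpy as np


def softplus(z):
    return np.logaddexp(0.0, z)  # numerically stable log(1 + exp(z))


def selective_ssm(x, A_diag, W_B, W_C, W_delta, b_delta):
    """Illustrative selective SSM over x of shape (L, D).
    A_diag: (D, N) negative diagonal state matrices, one per channel.
    W_B, W_C: (D, N) and W_delta: (D, D), b_delta: (D,) are hypothetical
    linear layers that make B, C and delta depend on the current input."""
    L, D = x.shape
    N = A_diag.shape[1]
    h = np.zeros((D, N))
    ys = np.empty((L, D))
    for t in range(L):
        x_t = x[t]                                      # (D,)
        B_t = x_t @ W_B                                 # (N,) input-dependent B
        C_t = x_t @ W_C                                 # (N,) input-dependent C
        delta_t = softplus(x_t @ W_delta + b_delta)     # (D,) positive, input-dependent step size
        A_bar = np.exp(delta_t[:, None] * A_diag)       # (D, N) ZOH for diagonal A
        B_bar = (A_bar - 1.0) / A_diag * B_t[None, :]   # (D, N) ZOH for B, per channel
        h = A_bar * h + B_bar * x_t[:, None]            # time-varying recurrence
        ys[t] = h @ C_t                                 # (D,) output for this step
    return ys


rng = np.random.default_rng(0)
L, D, N = 6, 3, 4
x = rng.standard_normal((L, D))
A_diag = -np.exp(rng.standard_normal((D, N)))
params = (rng.standard_normal((D, N)), rng.standard_normal((D, N)),
          0.1 * rng.standard_normal((D, D)), np.zeros(D))
print(selective_ssm(x, A_diag, *params).shape)
```

Since B̄, C and ∆ now differ at every step, there is no single kernel K to convolve with, which is why a scan over the sequence replaces the convolution mode.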

The selective SSM that we have explored so far can be implemented as a block, in the same way that attention is a block in a transformer: multiple SSM blocks are stacked, and the output of each block is used as the input for the next.
