Gaurav's Blog

Efficient AI: KV Caching and KV Sharing

Firstly, it feels good to blog again after a seven-year hiatus. Secondly, with this post I am starting a series on Efficient AI techniques (architecture, optimization, data, etc.). Most of these posts will focus on techniques for improving autoregressive decoder-only models (LLMs), though they are generally applicable to other models with some tweaks. I also assume that you are familiar with the basics of the transformer architecture (if you are not, this or this might be a good first step).

Summary

Doing inference on transformers can be expensive. Inference latency scales linearly with the model depth, i.e., the number of layers ($l$). Efforts like Early Exits aim to reduce the number of layers used when processing easier tokens (for some definition of ‘easier’), but they aren’t trivial to implement. We will cover them in a future post.

In this post, we cover two different, but related techniques:

  1. KV Caching: cache and reuse the K, V representations of tokens that were already computed in previous steps. You might already be familiar with this.
  2. KV Sharing: share the key and value representations ($K$ and $V$) of tokens across the last half of the layers of a transformer model, thereby avoiding re-computing them in those layers. Other weight tensors, such as the query and MLP weights, remain un-shared. This is a relatively newer technique.

Discussion

One of the reasons a naive transformer implementation’s inference is expensive is the need to compute the key and value representations for all the tokens in the given sequence in the Self-Attention module. It looks something like the figure below.

Self Attention block.
A typical Self-Attention block. Source: Sebastian Raschka's blog.

Since transformers are typically operating in the auto-regressive (AR) setting, this works as follows:

  1. Assume we have a sequence of tokens $S = [s_0, s_1, s_2, …, s_{n-1}]$, that we have already generated.
  2. To predict the next token, $s_n$, we need the K, V representations of all the tokens in the sequence seen so far, and we need them at each of the $l$ layers.
  3. This means we need to compute $l$ matrices of size $n \times d$, where $d$ is the model dimension / embedding dimension, via a matrix multiplication of the form $X W$, where $X$ is the input at a particular layer, and $W$ is either the key or value weight matrix ($W_K$ or $W_V$) at that layer.
  4. This is expensive because we will incur $l$ matrix multiplications, each costing $O(nd^2)$, for a total cost of $O(lnd^2)$ per predicted token! It’s growing faster than the US National Debt!

In the above calculation, we assume a single attention head. Thus the total cost of computing the K,V representations per head ($O(lnd^2)$) is made up of three components:

  1. Sequence Dimension: The number of tokens seen so far ($n$).
  2. Depth Dimension: The number of layers ($l$).
  3. Model Dimension: The width of the input stream ($d$).
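
To get a feel for the magnitude (purely illustrative numbers, not taken from any particular model): with $l = 32$ layers, model dimension $d = 4096$, and a prefix of $n = 1024$ tokens, this is

\[l \cdot n \cdot d^2 = 32 \times 1024 \times 4096^2 \approx 5.5 \times 10^{11}\]

multiply-accumulates for each of $W_K$ and $W_V$, for every single predicted token.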

We can’t do much about $d$ yet, but let’s see how we can tackle the other two for now.

1. KV Caching: Optimizing the sequence dimension.

KV Caching relies on two observations about inference:

  • Model weights, including $W_K$ and $W_V$, are fixed.
  • The K, V representations of a given token $s_i$ at a given layer depend only on that layer's input for $s_i$ and on $W_K$ and $W_V$; thanks to causal masking, they never change once later tokens arrive.

Since $W_K$ and $W_V$ are fixed, once we compute the K, V representations for a given (token, layer, head) tuple, they can be reused when predicting any subsequent token, simply by caching those representations in memory and looking them up at each later step.

If we can do this, we only need to compute the K, V representations of the token $s_{n-1}$ when predicting the $n$-th token, since that’s the only token for which we don’t yet have the K, V vectors. Therefore, it is easy to show that the new total cost of computing the representations is $O(ld^2)$, an $n$-times speedup! This is a significant win, especially if the sequence is very long.

Question: Why don’t we cache the query representation?
Answer: In the Self-Attention block we only compute and use the Q vector of the last token, so there is no need to cache the Q representations of the previous tokens.
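
Below is a minimal single-head, single-layer sketch of this idea in numpy (the function names, variable names, and toy dimensions are my own illustrative choices, not any particular library's API): at each decoding step we project only the newest token and append its K, V rows to a cache, instead of re-projecting the whole prefix.

import numpy as np

d = 16                                   # toy model / embedding dimension
rng = np.random.default_rng(0)

# Fixed weights for a single attention head in a single layer.
W_Q, W_K, W_V = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

def softmax(x):
    x = x - x.max()
    e = np.exp(x)
    return e / e.sum()

def attend_with_cache(x_new, K_cache, V_cache):
    """One decoding step for the newest token's layer input x_new, of shape (d,).

    We project only x_new and append its K, V rows to the cache; the rows for
    all earlier tokens were computed in earlier steps and are simply reused.
    """
    K_cache.append(x_new @ W_K)          # the only K projection this step
    V_cache.append(x_new @ W_V)          # the only V projection this step

    q = x_new @ W_Q                      # query only for the newest token (see Q&A above)
    K = np.stack(K_cache)                # (n, d): K rows of every token seen so far
    V = np.stack(V_cache)                # (n, d)
    scores = K @ q / np.sqrt(d)          # attention logits over the prefix
    return softmax(scores) @ V           # (d,) attention output for the newest token

# Usage: feed a toy sequence of token representations one step at a time.
K_cache, V_cache = [], []
for x_t in rng.standard_normal((5, d)):  # 5 toy "token inputs"
    out = attend_with_cache(x_t, K_cache, V_cache)
print(len(K_cache), out.shape)           # 5 cached K rows, output of shape (16,)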

2. KV Sharing: Optimizing the depth dimension.

KV Sharing reduces the cost of computing the K, V representations along the depth dimension ($l$). Concretely, the proposal is that the actual K, V representations are shared across the last half (or any other fraction) of the layers.

Note that we are referring to the actual $K$, $V$ tensors being the same, not just the $W_K$ and $W_V$ matrices being shared. What this means is that the last layer which doesn’t share its K, V representations computes them once, and they are used as-is across the remaining half of the layers (regardless of what the inputs to those layers are).

Said in an even simpler way, there is a KV cache across the last half of the layers. This is illustrated in the figure below.

KV Sharing
KV Sharing. Source: You Only Cache Once paper.

The figure above illustrates KV Sharing by showing a shared KV cache between the last $l/2$ layers. It is easy to see that if this works, we simply do not compute the K, V representations in $l/2$ of the layers. More generally, if the last $l/k$ layers share their K, V, we skip the K, V computation in those $l/k$ layers, i.e., we save a $1/k$ fraction of these FLOPs.
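
Here is a rough numpy sketch of the mechanism as described above (again with made-up names and toy sizes; the full training recipe and architecture are in the You Only Cache Once paper, and for brevity the sketch keeps the layer input x fixed across layers, which a real model would not do): the first $l/2$ layers project their own K, V as usual, the last non-shared layer writes its K, V into a shared cache, and the last $l/2$ layers read from that cache instead of projecting anything.

import numpy as np

d, n_layers = 16, 8                      # toy width and depth; the last half shares K, V
rng = np.random.default_rng(0)

# K/V projection weights exist only for the first half of the layers.
W_K = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(n_layers // 2)]
W_V = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(n_layers // 2)]

def kv_for_layer(layer, x, shared_kv):
    """Return (K, V) for this layer, given the layer's input x of shape (n, d).

    Follows the post's description: the last non-shared layer (index
    n_layers // 2 - 1) computes K, V once and stores them; every later layer
    reuses that pair as-is, so it needs no W_K / W_V and no projection FLOPs.
    """
    if layer < n_layers // 2:
        K, V = x @ W_K[layer], x @ W_V[layer]        # normal per-layer projection
        if layer == n_layers // 2 - 1:
            shared_kv["K"], shared_kv["V"] = K, V    # fill the shared cache once
        return K, V
    return shared_kv["K"], shared_kv["V"]            # shared layers: just read the cache

# Usage: walk the layers for a toy prefix of 5 token representations.
x = rng.standard_normal((5, d))
shared_kv = {}
for layer in range(n_layers):
    K, V = kv_for_layer(layer, x, shared_kv)
    # ... queries, attention scores, and the MLP would still be computed per layer here ...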

However, to make this work we need to ensure that the model is trained with the KV-Sharing behavior. This is detailed in the You Only Cache Once paper.

Some of the intuition for why this works in the first place comes from works like these, which show empirically that in a deep transformer-like model, the representations of the later layers are highly correlated with each other. What this means is that the last few layers are not necessarily adding a lot of new information, but rather tweaking the output so far. This redundancy can potentially be exploited (on another note, how can we make these layers do more heavy lifting?).

Additionally, note that we are only sharing the K, V representations, so sharing only affects how the past tokens are represented in the Self-Attention block; the per-layer queries and MLPs still leave the model some cheap degrees of freedom.

Other bonuses of this technique:

  1. You also save memory, since you don’t have to store $W_K$ and $W_V$ for the shared layers at all.
  2. It is applicable during training as well, so you save compute and memory during training too.

Conclusion

In this post we saw that we can significantly reduce the cost of computing the K, V representations in the Self-Attention block using KV Caching and KV Sharing. Concretely, we reduced it by a factor of:

  1. $n$ by implementing KV Caching.
  2. $2$ by implementing KV Sharing across the last $l/2$ layers (only $l/2$ layers now compute K, V).

The total cost is now $O(ld^2)$, but with a constant roughly halved due to KV Sharing. Additionally, KV Sharing eliminates the $W_K$ and $W_V$ matrices for half the layers, which is another big win.

That brings us to the end of this post. Please feel free to drop in any comments if I missed something.

Read more

Scaling SGD

I’ve been reading a few papers related to scaling Stochastic Gradient Descent for large datasets, and wanted to summarize them here.

  • One of the popular papers in this domain describes work on a new distributed training framework called DistBelief, a precursor to the distributed training support in TensorFlow.
  • Before this work, approaches to distributed SGD restricted the kinds of models that could be trained (convex objectives, sparse gradient updates, or smaller models on GPUs with gradient averaging).
  • This work describes how to do distributed asynchronous SGD.

Model-Level Parallelism: Works with large models by splitting the model graph itself into several parts. Each part of the model is assigned to a different machine. If there is an edge between two nodes in different parts, the two machines hosting those parts would need to communicate. This is to get around the problem of fitting a large model on a single GPU.

Splitting the Model Graph
Figure 1: Splitting the Model Graph.
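
A toy, single-process illustration of the idea (this is not the DistBelief implementation; the "machines" here are just Python objects, and in the real system the activations crossing a partition boundary are a network transfer):

import numpy as np

rng = np.random.default_rng(0)

class Machine:
    """Holds one contiguous slice of the model's 'layers' (here: plain matmul + ReLU)."""
    def __init__(self, weights):
        self.weights = weights

    def forward(self, x):
        for W in self.weights:
            x = np.maximum(x @ W, 0.0)      # one toy layer
        return x                            # the activations that cross a partition edge

# A 6-layer toy model, split across two machines.
layer_weights = [rng.standard_normal((32, 32)) * 0.1 for _ in range(6)]
machine_a = Machine(layer_weights[:3])
machine_b = Machine(layer_weights[3:])

x = rng.standard_normal((4, 32))            # a small batch of inputs
h = machine_a.forward(x)                    # computed on machine A
y = machine_b.forward(h)                    # 'h' is what would be sent over the network
print(y.shape)                              # (4, 32)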

Downpour SGD: To be able to scale to large datasets, DistBelief also runs several replicas of the model itself. The training data is split into several subsets, and each replica works on a single subset. Each replica sends updates to its parameters to a Parameter Server. The parameter server itself is sharded, and each shard is responsible for the updates to a subset of the parameters.

Whenever a replica starts a new minibatch, it gets the relevant parameters from the parameter server shards, and then sends its updates back when it's done with the minibatch.

Parameter Server
Figure 2: Parameter Server.

The authors found Adagrad to be useful in the asynchronous SGD setting, since it uses an adaptive learning rate for each parameter, which makes it easy to implement locally per parameter-server shard.
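
Here is a rough, in-process sketch of the Downpour pattern on a toy least-squares problem (names like ParameterShard are made up, everything runs sequentially in one process, and the real system performs these pulls and pushes asynchronously over the network): parameters live in shards, each replica works on its own data subset, and each shard applies a per-parameter Adagrad update when gradients are pushed to it.

import numpy as np

rng = np.random.default_rng(0)

class ParameterShard:
    """One shard of the parameter server: owns a slice of w and applies Adagrad updates."""
    def __init__(self, w_slice, lr=0.5):
        self.w = w_slice.copy()
        self.lr = lr
        self.accum = np.zeros_like(w_slice)      # Adagrad: per-parameter squared-gradient sum

    def pull(self):
        return self.w.copy()                     # a replica fetching current parameters

    def push(self, grad_slice):
        self.accum += grad_slice ** 2            # adaptive, per-parameter learning rate
        self.w -= self.lr * grad_slice / (np.sqrt(self.accum) + 1e-8)

# Toy linear-regression data, split into subsets (one per model replica).
dim, n = 10, 4000
w_true = rng.standard_normal(dim)
X = rng.standard_normal((n, dim))
y = X @ w_true
replica_rows = np.array_split(np.arange(n), 4)           # 4 replicas, one data subset each

# Parameters sharded across 2 parameter-server shards.
shard_idx = np.array_split(np.arange(dim), 2)
shards = [ParameterShard(np.zeros(len(idx))) for idx in shard_idx]

for step in range(200):
    for rows in replica_rows:                            # each replica processes a minibatch
        w = np.concatenate([s.pull() for s in shards])   # pull current params from all shards
        batch = rng.choice(rows, size=32, replace=False)
        grad = X[batch].T @ (X[batch] @ w - y[batch]) / len(batch)
        for s, idx in zip(shards, shard_idx):            # push each shard its slice of the gradient
            s.push(grad[idx])

w_final = np.concatenate([s.pull() for s in shards])
print(float(np.abs(w_final - w_true).max()))             # error shrinks as training proceeds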

  • This paper describes how the authors trained an ImageNet model using synchronous SGD. Given the synchronous nature of the setup, the idea is to use large batches (on the order of thousands of samples), instead of the usual small minibatches (typically tens to a few hundred samples), to amortize the communication overhead.
  • They demonstrate that with their method, they are able to use large batch sizes (up to 8192) without hurting accuracy with a ResNet-50 model (as compared to the baseline model with a batch-size of 256). Using 256 Tesla P100 GPUs, their model trains on the ImageNet dataset within 1 hour.
  • Linear Scaling Rule for Learning Rate: “When the minibatch size is multiplied by $k$, multiply the learning rate by $k$.” One way to think about this: if the batch size is increased by $k$ times, there are $k$ times fewer weight updates (since there are $k$ times fewer iterations per epoch). Another intuition: with smaller batches, the stochasticity (randomness) of the gradient estimate is higher; with bigger batches, you can confidently take bigger steps.
  • The authors do a gradual warm-up of the learning rate, from a small value up to the target learning rate given by the linear scaling rule. They hypothesize that the linear scaling rule breaks down for large batches in the initial stages of training, where a gradual warm-up helps training behave better (a small sketch of this schedule follows after this list).
  • Another paper is similar to the one by Goyal et al., but uses a bigger batch size (32k instead of 8k).
  • As per the numbers reported in the paper, with a 32k batch size they get accuracy comparable to smaller batches. The training finishes in 14 minutes using an unspecified number of Intel Knights Landing CPUs (possibly 1024 or 2048).
  • They use the gradual warm-up reported in Goyal et al., along with an algorithm that adapts the learning rate on a layer-wise basis (the LARS algorithm, You et al., 2017). LARS is similar in spirit to Adagrad (which works at the per-parameter level), which was useful in Dean et al.’s work.
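
Here is a minimal sketch of the linear scaling rule combined with a gradual warm-up, as described above (the baseline learning rate, reference batch size of 256, and 5-epoch warm-up are illustrative placeholders rather than a faithful reproduction of Goyal et al.'s schedule):

def learning_rate(epoch, batch_size,
                  base_lr=0.1, base_batch=256, warmup_epochs=5):
    """Linear scaling rule with gradual warm-up.

    The target learning rate is base_lr * (batch_size / base_batch), i.e. the
    linear scaling rule. During the first warmup_epochs epochs we ramp linearly
    from base_lr up to that target; afterwards we stay at the target (any decay
    schedule would be layered on top of this).
    """
    target_lr = base_lr * (batch_size / base_batch)
    if epoch < warmup_epochs:
        t = epoch / warmup_epochs            # fraction of warm-up completed, in [0, 1)
        return base_lr + t * (target_lr - base_lr)
    return target_lr

# Example: an 8192-sample batch targets a 32x larger learning rate after warm-up.
for epoch in [0, 1, 2, 5, 10]:
    print(epoch, round(learning_rate(epoch, batch_size=8192), 3))
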
Read more

Dynamic Programming: You Can Do It Half Asleep!

That was a click-baity title. :)

But seriously, people make a big deal out of ‘Dynamic Programming’ in the context of software engineering interviews. The name sounds fancy, but for most problems in such interviews you can go from a naive recursive solution to an efficient solution pretty easily.

Any problem that has the following properties can be solved with Dynamic Programming:

  1. Overlapping sub-problems: the solutions to the same sub-problems are needed again and again.
  2. Optimal substructure: the optimal solution to the bigger problem can be built from optimal solutions to its sub-problems.

You just have to do two things here:

  1. Check if you can apply the above criteria.
  2. Get a recursive solution to the problem.

That’s it.

Usually the second part is harder. After that, it is like clockwork, and the steps remain the same almost all the time.

Example

Assume your recursive solution to, say, compute the n-th Fibonacci number is:

\[F(n) = F(n - 1) + F(n - 2)\] \[F(0) = F(1) = 1\]

Step 1: Write this as a recursive solution first

int fib(int n) {
  if (n == 0 || n == 1) {
    return 1;
  } else {
    return fib(n-1) + fib(n-2);
  }
}

Now, this is an exponential time solution. Most of the inefficiency comes in because we recompute the solutions again and again. Draw the recursion tree as an exercise, and convince yourself that this is true.

Also, when you do this, you at least get a naive solution out of the way. The interviewer at least knows that you can solve the problem (perhaps, not efficiently, yet).

Step 2: Let’s just simply cache everything

Store every value ever computed.

int cache[20];
int fib(int n) {
  // Pre-fill all the values of cache as -1.
  memset(cache, -1, sizeof(cache)); 
  return fibDP(n);
}

int fibDP(int n) {
  // Check if we have already computed this value before.
  if (cache[n] != -1) {
    // Yes, we have. Return that.
    return cache[n];
  }
  
  // This part is just identical to the solution before.
  // Just make sure that we store the value in the cache after computing
  if (n == 0 || n == 1) {
    cache[n] = 1;
  } else {
    cache[n] = fibDP(n-1) + fibDP(n-2);
  }
  
  return cache[n]; 
}

Let us compute how many unique calls we can make to fibDP.

  • There is one parameter, n.
  • Hence, n unique values of n can be passed to fibDP.
  • Hence, n unique calls.

Now, realize two things:

  1. We would never compute the value of the function twice for the same value, ever.
    • So, given $n$, we would call the function $n$ times, as seen above.
    • Each time with $O(1)$ work in each function call.
    • Total = $n$ calls with $O(1)$ work each => $O(n)$ total time complexity.
  2. We are using up extra space.
    • We use as much extra space as:
    • Number of possible unique calls to the recursive function * space required for each value.
    • Since there are $n$ unique calls possible with an int value, space would be $O(n)$.
    • I have hard-coded a limit of 20 in my code. We can also use a std::vector, etc.

That’s it. We just optimized the recursive code from $O(2^n)$ time and $O(n)$ space (recursive call-stack space) to $O(n)$ time and $O(n)$ space (recursion + the cache).

Example with a higher number of parameters

int foo(int n, int m) {
  if (n <= 0 || m <= 0) {
   return 1;
  }
    
  return foo(n-1, m) + foo(n, m-1) + foo(n-1, m-1);
}

Time complexity: $O(3^{n+m})$ [Work it out on paper, why this would be the complexity, if you are not sure.]

DP Code

int cache[100][100];
int foo(int n, int m) {
  memset(cache, -1, sizeof(cache));
  return fooDP(n, m);
}

int fooDP(int n, int m) {
  if (n <= 0 || m <= 0) {
   return 1;
  }
  
  if (cache[n][m] == -1) {
    cache[n][m] = fooDP(n-1, m) + fooDP(n, m-1) + fooDP(n-1, m-1);
  }
  return cache[n][m];
}

  • Number of unique calls: $O(nm)$
  • Space Complexity: $O(nm)$
  • Time Complexity: $O(nm)$

Assume I tweak foo and add $O(n \log m)$ work inside each call; that would just get multiplied into the time complexity, i.e.,

Time complexity = O(unique calls) * O(work-per-call)

\[\implies O(nm) \times O(n \log m)\] \[\implies O(n^2 m \log m)\]

Space complexity = O(unique calls) * O(space-per-call)

\[\implies O(nm) \times O(1)\] \[\implies O(nm)\]

Now just reinforce these ideas with this question

  • Given a rectangular grid of size N x M,
  • What is the length of the longest path from the bottom-left corner (0, 0) to the top-right corner (N - 1, M - 1), assuming you can move up, right, or diagonally?

Extra Credit

What we saw is called top-down DP, because we take a bigger problem, break it down into sub-problems, and solve those first. This is basically recursion with memoization: we ‘memoize’ (a fancy word for caching) the solutions to the sub-problems.

When you absolutely, totally nail the recursive solution, some interviewers might want a solution without recursion, or might want to optimize the space complexity even further (which is often not possible in the recursive case). In that case, we want bottom-up DP, which is slightly more involved: it starts by solving the smallest problems iteratively, and builds the solutions to bigger problems from those.

Go into this area only if you have time. Otherwise, even if you just mention to the interviewer that you know there is something called bottom-up DP which can be used to do this iteratively, they should be at least somewhat okay with it. I did a short blog post on converting a top-down DP to a bottom-up DP, if that sounds interesting.

Read more