+34% throughput in vLLM

See GitHub issue, PR

vLLM took the AI world by storm a few weeks ago by showcasing a technique to speed up LLM serving at scale. We investigated whether we could accelerate vLLM further by profiling its performance with GPU counters. We ended up achieving a speed-up of 1.34x and found several interesting directions which are still open.

This issue has 3 sections:

  1. An optimization in the main attention kernel (single_query_cached_kv_attention)
  2. An optimization in the python code serving the models
  3. Further open directions + ideas which did not work out (yet)

This post assumes decent knowledge of how LLM inference works in general and some basic CUDA knowledge – if you would like us to write a post describing either of these in more detail, please reach out!

Benchmark

The main benchmark that vLLM uses to measure throughput runs LLaMA13B to complete 1000 randomly sampled prompts from ShareGPT, creating a single completion for each prompt. To run the benchmark, clone vLLM, download the dataset from the project website, and run the following command.

python benchmarks/benchmark_throughput.py --backend vllm --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json --model openlm-research/open_llama_13b --tokenizer hf-internal-testing/llama-tokenizer --num-prompts=1000

We begin by running the above on a clean clone of vLLM on an A100 (80GB), which produces the following output.

Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1000/1000 [04:08<00:00,  4.02it/s]
Throughput: 4.02 requests/s, 1921.31 tokens/s

This rate of 4.02 sequences completed per second translates to 241.2 seq/min. The project website reports a throughput of 154.2 seq/min for the same model, albeit on an A100 (40GB). Since we are using an A100 (80GB) for this issue, we set the reference point at 4.02 seq/sec. By the end of this post, we reach 5.41 seq/sec, an improvement of 1.34x.

Analyzing single_query_cached_kv_attention

The main kernel in vLLM is single_query_cached_kv_attention, which is used to compute the forward pass of an attention layer, using the KV cache designed in vLLM. We begin by profiling this kernel using NVIDIA Nsight Compute to check for potential improvements.

A preliminary look through NVIDIA Nsight Compute reveals several points to tackle.
Nsight Compute Overview of `single_query_cached_kv_attention`.
As seen above, the kernel underutilizes the SM resources in terms of both compute and memory – it uses only roughly 15% of the compute capacity and 50% of the memory bandwidth.

Nsight Compute shows that in most clock cycles, there are no warps issued.
As seen here, each SM does not schedule any warp in 5 out of every 6 cycles. Thus, we begin by trying to identify why the warps are stalling.

Let’s start by understanding roughly how single_query_cached_kv_attention works (a simplified reference of the computation follows the list below). Each CUDA block (128 threads by default) is responsible for computing the entire attention mechanism for the last token of one specific sequence and one specific head of that token.

  1. Each thread begins by reading its portion of the query into its registers: Q_vec q_vecs[NUM_VECS_PER_THREAD];. The threads of the block are split into ‘thread groups’ such that each group loads the entire query. In our configuration (default settings, running LLaMA13B on an A100 80GB), each thread group has 2 threads. That means that every 2 threads in the warp read the entire query head into their own registers/local memory – i.e., each thread holds half of the query head.
  2. The code then iterates over the entire sequence. Each thread group fetches a single key from the KV cache and computes the dot product of the loaded query head with the corresponding key head. We continue in this fashion until this dot product has been computed for every key in the sequence. This for-loop basically loads keys from global memory into registers, computes the dot product, and then aggregates within each thread group. Throughout, the values for the softmax are also computed; the logits are stored in shared memory.
  3. After the loop is done, each warp aggregates the results from all the thread groups inside the warp. Following this, aggregation happens between the warps in the block.
  4. Now, the value vectors are fetched from global memory and are summed according to the computed logits. Each thread stores in its registers an accumulator for the sums it computes.
  5. Finally, the logit-weighted value sums are aggregated within each warp and then across the entire block.
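
To make the above concrete, here is a heavily simplified PyTorch-style reference of what one CUDA block computes for a single (sequence, head) pair. This is an illustration rather than the kernel's actual code: it ignores the paged layout of the KV cache, the thread-group tiling, and the numerically stable softmax bookkeeping, but the dataflow is the same.

import torch

def single_query_attention_reference(q_head, k_heads, v_heads):
    # q_head:  (head_size,)          query of the last token, for one head
    # k_heads: (seq_len, head_size)  cached keys of that head for the whole sequence
    # v_heads: (seq_len, head_size)  cached values of that head for the whole sequence
    scale = q_head.shape[0] ** -0.5

    # Steps 1-2: dot product of the query head with every cached key head.
    logits = (k_heads @ q_head) * scale      # (seq_len,)

    # Steps 2-3: softmax over the sequence (the kernel keeps these logits in
    # shared memory and aggregates max/sum across thread groups and warps).
    probs = torch.softmax(logits, dim=0)     # (seq_len,)

    # Steps 4-5: weight the cached value heads by the probabilities and sum.
    return probs @ v_heads                   # (head_size,)

The real kernel never materializes k_heads and v_heads contiguously – it streams them from the paged KV cache block by block, which is exactly where the global-memory loads discussed below come from.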

To find which stage is holding the warps back, we look at the assembly analysis in Nsight Compute. The warps spend a lot of time stalled on the instructions in this screenshot.
SASS instructions where warps stall frequently.
As we can see, there is a global load, and roughly 4% of the stalls happen there (a value is loaded from global memory into register R78, and the warps halt before executing the highlighted instruction because, to run it, they must wait for the load into R78 to finish). Notably, further below, this code repeats 14 times in total (due to loop unrolling), which causes most of the stalls in this kernel.

These instructions are part of the first two steps above, where the threads load the query head and the key heads. Specifically, these are the loads of the key heads. As not much can be done about loading the key heads, we focus on the query heads, which are also loaded from global memory.

The query head is read multiple times from global memory – specifically, in our case (default configs, LLaMA13B, A100 80GB), every byte of the query is read by 64 different threads (each block has 128 threads, and every two threads read the entire query). Therefore, we optimize this so that each byte of the query head is read by exactly one thread and then stored in shared memory for the other threads in the block to access.

We replace this code:

const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
  const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
  q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}

with this code (see this PR):

constexpr int NUM_THREAD_GROUPS_LOWER_BOUND = NUM_THREADS / THREAD_GROUP_SIZE;
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
if (thread_group_idx <= NUM_THREAD_GROUPS_LOWER_BOUND) {
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS_LOWER_BOUND) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs

Running the benchmark gives:

Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:45<00:00,  4.43it/s]
Throughput: 4.43 requests/s, 2118.42 tokens/s

As the reference point above is 4.02 seq/sec, the result of 4.43 seq/sec that we get here is a 1.10x improvement.

At this point, we rerun the Nsight Compute analysis from above.
The Nsight Compute Overview of `single_query_cached_kv_attention` now shows that the memory bandwidth is well-utilized.
As we can see, the kernel now achieves rather high memory bandwidth utilization (86%). We tried several other improvements (see this section) to squeeze a bit more performance out of this kernel, yet they did not improve the overall runtime of the benchmark. Since the memory bandwidth utilization is already rather high and the kernel appears to load only the data it must load from global memory (the keys and values), we decided to stop looking at the kernel itself and began looking elsewhere.

Overall Program Analysis

Observe the following report generated by using NVIDIA Nsight Systems to profile the entire program execution.
Nsight Systems overview of vLLM.

As we can see, roughly half the time the program does not use the GPU at all (observe that DRAM Bandwidth, SM Warp Occupancy, etc., are practically zero half the time). This time is spent on the CPU, running the Python code which surrounds the model. We investigate and find that the culprit is the sampling of the generated tokens. Observe the forward code of the class LlamaForCausalLM.

class LlamaForCausalLM(nn.Module):
    ...
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states, input_metadata)
        return next_tokens

It turns out that half the program time is spent in the above call to self.model and half in the call to self.sampler (note: this is not visible by timing the Python code alone, as the kernels run on the GPU asynchronously and the CPU only waits for them later on).
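
As an aside, measuring this split yourself requires synchronizing with the GPU before reading the clock. Here is a minimal sketch of such a timing helper (a hypothetical utility, not code from vLLM):

import time
import torch

def timed(fn, *args, **kwargs):
    # Drain all previously launched kernels before starting the clock,
    # otherwise their remaining GPU time would be attributed to `fn`.
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    # Kernel launches return immediately; synchronize again so the measured
    # interval includes the GPU work that `fn` enqueued.
    torch.cuda.synchronize()
    return out, time.perf_counter() - start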

Specifically, the sampler performs the following for each sequence being completed (link).

# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(probs, num_samples=1, replacement=True)

That is, probs holds the generated probabilities for the next token of one specific sequence, and the code above samples just for that sequence. We replace this by feeding the entire matrix (num_sequences x token_space) into torch.multinomial to perform the sampling for all sequences at once. The change is a POC for the current benchmark (sampling just 1 token for each sequence; no beam search or any other decoding technique).

As the full code change is rather long, we do not write it out here – please refer to the following commit to see the change; a simplified sketch of the core idea is shown below.
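
To give a flavor of the change, here is a minimal sketch (illustrative only, not the code from the commit), assuming all sequences use plain sampling and their next-token distributions have been stacked into a single matrix:

import torch

def sample_all_sequences(probs: torch.Tensor) -> torch.Tensor:
    # probs: (num_sequences, vocab_size), one next-token distribution per row.
    # A single batched call replaces the per-sequence torch.multinomial loop,
    # so the GPU runs one sampling kernel instead of num_sequences of them.
    return torch.multinomial(probs, num_samples=1, replacement=True).squeeze(-1)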

Rerunning the benchmark gives the following.

Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:12<00:00,  5.18it/s]
Throughput: 5.18 requests/s, 2478.31 tokens/s

As the reference point above is 4.02 seq/sec, the result of 5.18 seq/sec that we get here is a 1.28x improvement so far!

We rerun Nsight Systems and observe the following.
Nsight Systems now shows that the gaps between the forward passes have shrunk, but there is still more room to go.

Indeed, the time between GPU calls has shrunk drastically. We zoom in to see what remains.
Nsight Systems shows that the remaining gaps are due to many small memory reads from the GPU to the CPU.
It seems that there are many small 4-byte reads from the GPU to the CPU. The culprit is the following line, where the logprobs of the chosen tokens are read from the GPU to the CPU one by one.

output_logprobs[next_token_id] = logprob[j, next_token_id].item()

These many small reads have a huge overhead and incur high synchronization costs. Fixing this by coalescing the reads requires some maneuvering in the code (it turns out that there is another small read in another place); see this commit for a POC. A rough sketch of the coalescing idea follows.
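
The sketch below (hypothetical names, not the code from the commit) shows the general pattern: gather everything that is needed while still on the GPU, then move it to the CPU in one transfer instead of issuing one .item() call per token. It assumes logprob and next_token_ids are tensors living on the GPU.

import torch

def read_chosen_logprobs(logprob: torch.Tensor, next_token_ids: torch.Tensor) -> dict:
    # logprob: (num_sequences, vocab_size); next_token_ids: (num_sequences,)
    rows = torch.arange(next_token_ids.numel(), device=logprob.device)
    chosen = logprob[rows, next_token_ids]       # gather on the GPU, no sync yet
    return dict(zip(next_token_ids.tolist(),     # one transfer for the token ids...
                    chosen.tolist()))            # ...and one for their logprobs

We rerun the benchmark and get the following.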

Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:04<00:00,  5.41it/s]
Throughput: 5.41 requests/s, 2587.22 tokens/s

As the reference point above is 4.02 seq/sec, the result of 5.41 seq/sec that we get here is a 1.34x improvement!

Further potential directions & Ideas which did not pan out

Even though we managed to improve vLLM’s throughput by roughly 34%, we tried many things along the way which did not end up helping. Let’s list a few of them here, as it’s always good to learn from failure :)

Potential idea: Parallel Tokenization

There is a potential further improvement of roughly 10% from parallelizing tokenization after sampling. Specifically, this line gets called sequentially for every sequence when we convert each sampled token into text. This takes roughly 10% of the execution time – time during which the GPU sits completely idle.

Edit: This used to be hard to do before GIL-free Python was introduced. This is now much easier to implement, gg to the community!
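
For illustration, here is a sketch of what parallel detokenization could look like with a thread pool (a hypothetical helper; it assumes a free-threaded Python build or a tokenizer backend that releases the GIL, and the real change would also need to respect vLLM's per-sequence decoding state):

from concurrent.futures import ThreadPoolExecutor

def detokenize_all(tokenizer, sampled_token_ids, max_workers=8):
    # sampled_token_ids: one newly sampled token id per sequence.
    # Instead of decoding each token sequentially, fan the decode calls out
    # over a thread pool; this only helps if they can actually run in parallel.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(tokenizer.decode, [[tid] for tid in sampled_token_ids]))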

Failed idea: Batch reading from the block tables

At the start of the for-loop fetching the keys, the physical block number is read from the block table (global memory). This line stalls many threads. It turns out that all the threads in a warp read the same position in the block table (which is fine since, iirc, only one read is sent to memory and its result is broadcast to the threads automatically). To try to reduce the stall, we can have each thread read a different value, so that a global memory read is only needed once every 32 for-loop iterations.

We implemented this, yet it had no effect on the runtime. We believe this is both because the kernel is memory-bound anyhow and because the stall appeared to simply move to the key reads (which come right after).

Failed idea: atomicAdd for the last aggregation

We tried replacing the final aggregation in the kernel with atomicAdd between the 4 warps in the block. This degraded performance in our measurements.

Final Thoughts

vLLM is truly a thought-provoking and intriguing concept! We very much enjoyed delving into this code and are very eager to see how far this can be optimized.

We are now beginning a journey of “democratizing LLM inference” so that everyone can run them at scale! LLMs are becoming a key part of everyday life, so let’s ensure that not only Google/OpenAI have the know-how to run them at scale :)