Goal: implement my own inference engine for the Qwen3 models that matches the performance of vllm.
The Qwen3 dense models have the following variants: 0.6B, 1.7B, 4B, 8B, 14B, and 32B.
Some architectural details:
- Qwen3 uses grouped query attention, so the number of kv heads is less than the number of query heads. The number of kv heads is always 8.
- The attention head dim is always 128
- The larger models primarily scale up hidden_size, num_hidden_layers, and num_attention_heads
- It uses full attention on all layers, no variants like sliding window or linear attention
- They all have a max context length of 32,768
- The 4B models and up can have their context length extended to 131,072 with YaRN
- All use SwiGLU feed forward
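For reference, the feed-forward block in each layer looks roughly like this (a minimal sketch that follows the structure of the transformers Qwen3 MLP; all three projections are bias-free):

# Sketch of the SwiGLU feed-forward block used in every Qwen3 layer.
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # SwiGLU: silu(gate(x)) * up(x), projected back down to hidden_size
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))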
First I took the Qwen3 modeling code from the Hugging Face transformers library and stripped out all the unnecessary config logic. I also took the KVCache implementation from transformers and stripped out the parts that Qwen3 wasn't using. So now I'm left with a clean implementation of the Qwen3 model (eager attention) and a simple KV cache that dynamically grows.
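The dynamic cache is essentially a list of per-layer key/value tensors that get concatenated with every new token, along these lines (a simplified sketch, not the exact transformers code):

# Simplified sketch of a dynamically growing KV cache (one entry per layer).
import torch

class DynamicKVCache:
    def __init__(self, num_layers):
        self.keys = [None] * num_layers
        self.values = [None] * num_layers

    def update(self, layer_idx, k, v):
        # k, v: [batch, num_kv_heads, new_tokens, head_dim]
        if self.keys[layer_idx] is None:
            self.keys[layer_idx], self.values[layer_idx] = k, v
        else:
            # torch.cat reallocates and copies the whole cache on every step
            self.keys[layer_idx] = torch.cat([self.keys[layer_idx], k], dim=2)
            self.values[layer_idx] = torch.cat([self.values[layer_idx], v], dim=2)
        return self.keys[layer_idx], self.values[layer_idx]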
--------
One thing I noticed was that model initialization takes a few seconds. This is unnecessary overhead: we are loading pretrained weights, so the randomly initialized weights are thrown away immediately anyway. To speed this up, I initialize the model on the meta device so that the initialization code doesn't run, and then move the model to cuda with empty weight tensors.
However, this introduced a subtle bug that caused the RoPE positional embeddings to load incorrectly. On initialization, the RoPE module pre-computes an inverse frequency buffer that we reuse during the forward pass. When the model is built on the meta device, that buffer is never actually computed, and moving to cuda just gives us a tensor of zeros.
The fix is to reinitialize the RoPE buffer after moving the model to cuda. Without this, I saw a severe accuracy drop: the model would mostly repeat the last token over and over instead of producing coherent text.
import torch
from transformers import AutoModelForCausalLM

# model_id and device (e.g. "cuda") are defined earlier;
# Qwen3ForCausalLM is the stripped-down custom modeling code

# get pretrained model from hf
pretrained_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
config = pretrained_model.config

# build the custom model on the meta device so no real initialization runs
with torch.device("meta"):
    custom_model = Qwen3ForCausalLM(config)
custom_model.to(dtype=torch.bfloat16)

# materialize empty (uninitialized) weight tensors on the target device
custom_model.to_empty(device=device)

# Reinitialize rotary embeddings after moving from meta device
custom_model.model.rotary_emb.reinitialize_on_device(device)

# move weights from pretrained model to custom model
...
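The reinitialize_on_device call is a small method I added to my rotary embedding module. A sketch of what it might look like, assuming the standard RoPE inverse-frequency computation (exact attribute names are assumptions; head_dim and rope_theta come from the config):

# Hypothetical sketch of the helper on the rotary embedding module: recompute
# the inv_freq buffer that __init__ would normally build, on a real device.
def reinitialize_on_device(self, device):
    head_dim = self.config.head_dim  # 128 for all Qwen3 dense models
    base = self.config.rope_theta
    inv_freq = 1.0 / (
        base ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
    )
    self.register_buffer("inv_freq", inv_freq, persistent=False)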
I realized later that a cleaner alternative is to reimplement the Linear layers so they allocate empty tensors directly, but alas.
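A sketch of that alternative, assuming a plain nn.Linear subclass whose reset_parameters is a no-op (torch.nn.utils.skip_init would be another way to get the same effect):

# Hypothetical sketch: a Linear layer that skips weight initialization entirely,
# since the values get overwritten by pretrained weights anyway.
import torch.nn as nn

class UninitializedLinear(nn.Linear):
    def reset_parameters(self):
        # nn.Linear.__init__ allocates the tensors with torch.empty and then
        # calls this; leaving it a no-op skips the slow init kernels.
        pass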
--------
Now onto performance benchmarking.
I'm testing an offline inference workload with the following config:
- Batch size: 1
- new_tokens: 10,000
- prompt: "hi there"
- ignore_eos: true
vllm v0.11.0 does this in 69 seconds.
My custom implementation does this in 235 seconds.
Not great, but I have some ideas to speed this up:
- Use flash attention
- Use torch.compile
Trying these gave the following results:
- custom w/ flash attention: 231 sec
- custom w/ torch.compile: 154 sec (3 min warmup)
- custom w/ flash attention + torch.compile: 94 sec (3 min warmup)
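Neither change needs much code. A minimal sketch, assuming the eager attention gets swapped for PyTorch's scaled_dot_product_attention (which can dispatch to a flash attention kernel) and the whole model is wrapped in torch.compile:

# Sketch of the attention swap; the real call site lives inside the attention module.
import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v, num_kv_groups):
    # q: [batch, num_heads, q_len, head_dim]; k, v: [batch, num_kv_heads, kv_len, head_dim]
    # Expand kv heads to match the query heads (grouped query attention)
    k = k.repeat_interleave(num_kv_groups, dim=1)
    v = v.repeat_interleave(num_kv_groups, dim=1)
    # Causal masking only matters during prefill; a single decode query already
    # attends to the whole cache, so is_causal must be False in that case.
    return F.scaled_dot_product_attention(q, k, v, is_causal=q.shape[2] > 1)

# Compile the model once up front; the first forward pass pays the warmup cost.
custom_model = torch.compile(custom_model)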
--------
Nice! The low hanging fruit helped a lot. But now I was stuck.
I spent some time reading through the vllm codebase to find some inspiration and noted down the following ideas:
- Preallocate kv cache memory. Currently I'm dynamically growing the kv cache on each new token. Vllm lets you set how much memory you want it to use and allocates all of it on initialization. (See the sketch after this list.)
- Preallocate the buffer for the output tokens. This also dynamically grows on each new token.
- Pack QKV computation
- Pack gate and up projection in SwiGLU
- Merge RMS Norm and residual add steps together
- Write more optimized kernels? vllm uses PagedAttention, but that seems to solve a different problem from the one I'm dealing with.
- Also saw something about CUDA graphs
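Here is a simplified sketch of the preallocated cache idea: allocate full-size key/value buffers up front and write each new token into its slot instead of concatenating (assumes a fixed batch size and a known maximum sequence length):

# Simplified sketch of a preallocated (static) KV cache.
import torch

class StaticKVCache:
    def __init__(self, num_layers, batch, num_kv_heads, max_seq_len, head_dim,
                 device="cuda", dtype=torch.bfloat16):
        shape = (batch, num_kv_heads, max_seq_len, head_dim)
        self.keys = [torch.zeros(shape, device=device, dtype=dtype) for _ in range(num_layers)]
        self.values = [torch.zeros(shape, device=device, dtype=dtype) for _ in range(num_layers)]
        self.seq_len = 0

    def update(self, layer_idx, k, v):
        # k, v: [batch, num_kv_heads, new_tokens, head_dim]
        start, end = self.seq_len, self.seq_len + k.shape[2]
        self.keys[layer_idx][:, :, start:end] = k
        self.values[layer_idx][:, :, start:end] = v
        if layer_idx == len(self.keys) - 1:
            self.seq_len = end  # advance once every layer has written this step
        return self.keys[layer_idx][:, :, :end], self.values[layer_idx][:, :, :end]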
I ended up trying the following combinations:
- Preallocated kv cache: 78 sec
- Preallocated kv cache + merged rms norm and residual add: 78 sec
- Preallocated kv cache + preallocated output tokens: 78 sec
- Preallocated kv cache + packed gate and up proj: 74 sec
- Preallocated kv cache + packed gate and up proj + packed qkv: 68 sec!!!
Implementation notes: implementing the packed qkv and gate/up projections was quite tricky. You have to change the modeling code and the forward pass, and make sure the pretrained weights get loaded into the layout the forward pass expects.
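The gist of the packing, as a sketch (Qwen3's projections are bias-free, so fusing them amounts to concatenating the pretrained weight matrices along the output dimension and splitting the activations afterwards; the class names here are illustrative):

# Sketch of packing Q/K/V into one matmul and gate/up into another.
import torch
import torch.nn as nn
import torch.nn.functional as F

class PackedQKV(nn.Module):
    def __init__(self, hidden_size, num_heads, num_kv_heads, head_dim):
        super().__init__()
        self.q_size = num_heads * head_dim
        self.kv_size = num_kv_heads * head_dim
        self.qkv_proj = nn.Linear(hidden_size, self.q_size + 2 * self.kv_size, bias=False)

    def load_pretrained(self, q_weight, k_weight, v_weight):
        # Rows of the fused weight are [q; k; v], matching the split in forward()
        self.qkv_proj.weight.data.copy_(torch.cat([q_weight, k_weight, v_weight], dim=0))

    def forward(self, x):
        # One matmul, then split back into the per-projection activations
        q, k, v = self.qkv_proj(x).split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        return q, k, v

class PackedGateUp(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.intermediate_size = intermediate_size
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).split(self.intermediate_size, dim=-1)
        return self.down_proj(F.silu(gate) * up)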
--------
Summary: I was able to get the Hugging Face implementation to match vllm inference performance on a single-batch workload. The tricks I used were flash attention, torch.compile, packed qkv and gate/up projections, and a preallocated kv cache.
This was all done on a single GPU device to simplify the implementation. Some potential next steps would be to get tensor parallel versions working, try other workloads like larger batch sizes or online inference, and learn more about the lower level optimizations like kernels and CUDA graphs.