Add Qwen3.5 export support for 0.8B/2B/4B #17800
Phineas1500 wants to merge 6 commits into pytorch:main from
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17800
Note: links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 2 Unrelated Failures as of commit 1084695 with merge base 153adbf.
NEW FAILURES: the following jobs have failed.
BROKEN TRUNK: the following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @Phineas1500! Thank you for your pull request and welcome to our community.

Action Required: in order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: in order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
@pytorchbot label "release notes: none"
Pull request overview
This PR adds initial Qwen3.5 (0.8B/2B/4B) support to the ExecuTorch Llama export pipeline, including checkpoint conversion, model/attention wiring, and basic unit tests.
Changes:
- Added Qwen3.5 model types to the export config/registry and HF auto-download/convert path.
- Introduced Qwen3.5 weight conversion + model configs (0.8B/2B/4B) and an XNNPACK fp32 export YAML.
- Implemented Qwen3.5 full-attention + gated DeltaNet linear-attention in the Llama attention stack, plus unit tests.
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 3 comments.
Summary per file:
| File | Description |
|---|---|
| extension/llm/export/config/llm_config.py | Adds ModelType enum entries for Qwen3.5 sizes. |
| examples/models/llama/export_llama_lib.py | Registers Qwen3.5 model names, HF repo IDs, and hooks in Qwen3.5 weight conversion on download. |
| examples/models/llama/attention.py | Adds qwen3_5_full attention and gated_deltanet linear attention implementations; extends RMSNorm usage for q/k norm. |
| examples/models/llama/model_args.py | Adds Qwen3.5-specific linear-attention dimension args + rms_norm_add_unit_offset, and defaulting logic in __post_init__. |
| examples/models/llama/norm.py | Extends RMSNorm to optionally use (1 + weight) scaling. |
| examples/models/llama/llama_transformer.py | Wires rms_norm_add_unit_offset into norms and supports "linear_attention" layers via gated_deltanet. |
| examples/models/llama/tests/test_qwen3_5_attention.py | Adds unit tests for Qwen3.5 full-attention shape and DeltaNet state reset behavior. |
| examples/models/llama/tests/BUCK | Adds BUCK target for the new Qwen3.5 attention unit test. |
| `examples/models/llama/__init__.py` | Switches to lazy import pattern for Llama2Model. |
| examples/models/qwen3_5/convert_weights.py | Adds HF → meta checkpoint key conversion for Qwen3.5 (including legacy packed tensor splitting). |
| examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml | Adds fp32 + static-shape XNNPACK export config for Qwen3.5. |
| examples/models/qwen3_5/config/0_8b_config.json | Adds 0.8B model parameter config (hybrid layer layout). |
| examples/models/qwen3_5/config/2b_config.json | Adds 2B model parameter config (hybrid layer layout). |
| examples/models/qwen3_5/config/4b_config.json | Adds 4B model parameter config (hybrid layer layout). |
| examples/models/qwen3_5/tests/test_convert_weights.py | Adds unit test validating key mapping for both full- and linear-attention weights. |
| `examples/models/qwen3_5/tests/__init__.py` | Adds package marker/license header for tests. |
| `examples/models/qwen3_5/__init__.py` | Adds lazy Qwen3_5Model wrapper and exports convert_weights. |
| examples/models/qwen3_5/README.md | Documents export + runner usage for Qwen3.5 sizes and current limitations. |
| examples/models/qwen3_5/BUCK | Adds BUCK target for the Qwen3.5 python library and deps. |
| examples/models/BUCK | Adds Qwen3.5 to the examples/models python_library deps list. |
examples/models/llama/attention.py (Outdated)
```python
self.wo = nn.Linear(
    self.n_heads * self.head_dim,
    self.dim,
    bias=self.attention_qkv_bias,
```
AttentionQwen3_5Full sets the output projection (self.wo) bias based on args.attention_qkv_bias. In the rest of the llama attention implementations, the output projection is bias-free, and HF checkpoints typically don’t include an o_proj.bias. If attention_qkv_bias is enabled, this will introduce an extra bias parameter that won’t be loaded from the checkpoint (and will change numerics). Consider forcing wo to bias=False (or gating it on a dedicated flag rather than attention_qkv_bias).
Suggested change:
```diff
-    bias=self.attention_qkv_bias,
+    bias=False,
```
```python
for i in range(sequence_length):
    q_t = query[:, :, i]
    k_t = key[:, :, i]
    v_t = value[:, :, i]
    g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
    beta_t = beta[:, :, i].unsqueeze(-1)

    last_recurrent_state = last_recurrent_state * g_t
    kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
    delta = (v_t - kv_mem) * beta_t
    last_recurrent_state = (
        last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
    )
    core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(
        dim=-2
    )
```
_recurrent_gated_delta_rule uses a Python for loop over sequence_length, which will be very slow for long prefill sequences and hard for compilers/export backends to optimize. Consider rewriting this recurrence using vectorized/scan-style tensor ops (or constraining this path to decode-only seq_len==1 and providing a separate efficient prefill implementation).
Suggested change:
```diff
-for i in range(sequence_length):
-    q_t = query[:, :, i]
-    k_t = key[:, :, i]
-    v_t = value[:, :, i]
-    g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
-    beta_t = beta[:, :, i].unsqueeze(-1)
-    last_recurrent_state = last_recurrent_state * g_t
-    kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
-    delta = (v_t - kv_mem) * beta_t
-    last_recurrent_state = (
-        last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
-    )
-    core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(
-        dim=-2
-    )
+# This recurrent implementation is intended for decode-only usage,
+# where sequence_length == 1. Longer sequences should use a separate
+# prefill implementation that can be efficiently vectorized.
+if sequence_length != 1:
+    raise NotImplementedError(
+        "_recurrent_gated_delta_rule only supports decode-only "
+        "(sequence_length == 1). Use a dedicated prefill path for "
+        "longer sequences."
+    )
+# Unrolled single-step recurrence for sequence_length == 1.
+q_t = query[:, :, 0]
+k_t = key[:, :, 0]
+v_t = value[:, :, 0]
+g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1)
+beta_t = beta[:, :, 0].unsqueeze(-1)
+last_recurrent_state = last_recurrent_state * g_t
+kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+delta = (v_t - kv_mem) * beta_t
+last_recurrent_state = (
+    last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+)
+core_attn_out[:, :, 0] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(
+    dim=-2
+)
```
```python
    raise ValueError(
        f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}"
    )
key_dim = (conv_dim - value_dim) // 2
```
key_dim is computed but never used, which makes the legacy in_proj_qkvz split logic harder to follow. Either remove it or use it to validate the packed tensor layout (e.g., sanity-check expected q/k/v sizing).
Suggested change:
```diff
-key_dim = (conv_dim - value_dim) // 2
```
|
@Phineas1500 thank you for adding this! Have you tried to export with the xnnpack recipe and test out the .pte file?
I exported with the XNNPACK recipe for Qwen3.5-0.8B, and I exported 4B in no-backend mode (my computer had memory constraints). I'm in the process of testing 0.8B. Should have done that before making the PR, and I'll get back to you soon.
Pull request overview
Copilot reviewed 20 out of 20 changed files in this pull request and generated 4 comments.
Comments suppressed due to low confidence (1)
examples/models/llama/norm.py:24
- `RMSNorm.__init__` now accepts `add_unit_offset`, but the docstring's Args section still only documents `dim` and `eps`. Please update the docstring to include `add_unit_offset` and describe how it changes the scaling (e.g., `output * (1 + weight)` when enabled), so callers know when to set it.
```python
def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
    """
    Initialize the RMSNorm normalization layer.
    Args:
        dim (int): The dimension of the input tensor.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
    Attributes:
        eps (float): A small value added to the denominator for numerical stability.
        weight (nn.Parameter): Learnable scaling parameter.
    """
```
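One possible shape for the updated docstring is sketched below. The wording is a suggestion only, and the class body is a minimal stand-in rather than the real `torch.nn.Module` subclass.

```python
class RMSNorm:
    """Minimal stand-in used to illustrate the suggested docstring."""

    def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for
                numerical stability. Default is 1e-6.
            add_unit_offset (bool, optional): If True, scale the normalized
                input by ``(1 + weight)`` instead of ``weight``, for
                checkpoints whose norm weights are stored as an offset from 1.
                Default is False.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        self.dim = dim
        self.eps = eps
        self.add_unit_offset = add_unit_offset
```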
```python
input_pos = kwargs.get("input_pos")
batch_size, seq_len, _ = x.shape
assert (
    batch_size <= self.max_batch_size
), f"batch_size ({batch_size}) exceeds max_batch_size ({self.max_batch_size})"

self._maybe_reset_state(input_pos, batch_size)
```
AttentionGatedDeltaNet maintains internal conv_state/recurrent_state, but when input_pos is omitted it never resets, so outputs can depend on prior forward calls (state leakage across sequences) when the model is run without kv-cache / without passing input_pos. If this attention is only valid in kv-cache mode, consider asserting input_pos is not None (or alternatively resetting state when input_pos is None) to avoid silently incorrect results.
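A toy sketch of the fail-loudly option follows. The class and method names are illustrative stand-ins (only `input_pos` mirrors the PR's API), and the state update is a placeholder rather than the real recurrence.

```python
import numpy as np

class GatedDeltaNetStateDemo:
    """Toy stand-in showing an explicit input_pos guard and state reset."""

    def __init__(self, max_batch_size: int = 2, state_dim: int = 4):
        self.max_batch_size = max_batch_size
        self.recurrent_state = np.zeros((max_batch_size, state_dim))

    def _maybe_reset_state(self, input_pos: int, batch_size: int) -> None:
        # Position 0 marks the start of a new sequence: clear carried state.
        if int(input_pos) == 0:
            self.recurrent_state[:batch_size] = 0.0

    def forward(self, x: np.ndarray, input_pos=None) -> np.ndarray:
        # Fail loudly rather than silently carrying state across sequences.
        if input_pos is None:
            raise ValueError(
                "stateful linear attention needs input_pos to detect "
                "sequence boundaries"
            )
        batch_size = x.shape[0]
        assert batch_size <= self.max_batch_size
        self._maybe_reset_state(input_pos, batch_size)
        self.recurrent_state[:batch_size] += 1.0  # stand-in for the real update
        return self.recurrent_state[:batch_size].copy()
```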
```python
checkpoint_shards = sorted(set(weight_map.values()))

shard_to_weights = {}
for shard in checkpoint_shards:
    shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

merged_state_dict = {}
for weight_name, shard in weight_map.items():
    merged_state_dict[weight_name] = shard_to_weights[shard][weight_name]
```
_load_checkpoint_from_safetensors() loads every shard into shard_to_weights before merging, which can significantly increase peak memory for sharded Qwen3.5 checkpoints. Consider switching to the streaming pattern used in examples/models/qwen2_5/convert_weights.py (group keys per shard, load one shard at a time, then del shard_data) to reduce peak RAM usage.
Suggested change:
```diff
-checkpoint_shards = sorted(set(weight_map.values()))
-shard_to_weights = {}
-for shard in checkpoint_shards:
-    shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))
-merged_state_dict = {}
-for weight_name, shard in weight_map.items():
-    merged_state_dict[weight_name] = shard_to_weights[shard][weight_name]
+# Group parameter names by shard so we can load each shard only once.
+shard_to_weight_names: Dict[str, list[str]] = {}
+for weight_name, shard in weight_map.items():
+    if shard not in shard_to_weight_names:
+        shard_to_weight_names[shard] = []
+    shard_to_weight_names[shard].append(weight_name)
+merged_state_dict: Dict[str, torch.Tensor] = {}
+# Stream shards: load one shard at a time, copy required tensors, then free it.
+for shard, weight_names in shard_to_weight_names.items():
+    shard_path = os.path.join(input_dir, shard)
+    shard_data = load_file(shard_path)
+    for weight_name in weight_names:
+        merged_state_dict[weight_name] = shard_data[weight_name]
+    del shard_data
```
```python
try:
    new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META)
except Exception:
    # Ignore non-text weights and training-only extras (e.g., MTP).
    if (
        key.startswith("mtp.")
        or key.startswith("model.visual.")
        or ".vision_" in key
        or key.startswith("visual.")
    ):
        continue
    # Ignore unsupported keys that are not required by the export model.
    continue
converted_state_dict[new_key] = value
```
qwen_3_5_to_meta() currently catches the exception from get_mapped_key(...) and then unconditionally continues for all unmapped keys. This makes conversion silently succeed even if a required text weight is missing from _QWEN_3_5_TO_META (e.g., due to a naming change), which can lead to hard-to-debug runtime failures later. Suggestion: explicitly whitelist/skip known non-text prefixes (visual/MTP/etc.), but for any other unexpected key either raise (default) or at least collect and print a summary of skipped keys behind a --verbose/--strict flag.
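For illustration, the suggested strict-by-default behavior might look like the sketch below. `map_keys`, the prefix tuple, and the flat-dict mapping are simplified stand-ins for the PR's `get_mapped_key`-based converter (which also handles per-layer key templating).

```python
_IGNORED_PREFIXES = ("mtp.", "model.visual.", "visual.")

def map_keys(state_dict, mapping, strict=True, verbose=False):
    """Map HF-style checkpoint keys to meta keys, skipping known non-text ones.

    Unknown keys raise by default (strict=True), so a required text weight
    that falls through the mapping fails at conversion time instead of
    surfacing as a hard-to-debug runtime error.
    """
    converted, skipped = {}, []
    for key, value in state_dict.items():
        if key.startswith(_IGNORED_PREFIXES) or ".vision_" in key:
            skipped.append(key)  # known non-text weight (visual tower, MTP)
            continue
        if key not in mapping:
            if strict:
                raise ValueError(f"Unexpected checkpoint key not mapped: {key}")
            skipped.append(key)
            continue
        converted[mapping[key]] = value
    if verbose and skipped:
        print(f"Skipped {len(skipped)} keys: {sorted(skipped)}")
    return converted
```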
```python
# Legacy packed tensors (older checkpoints):
#   in_proj_qkvz -> split into in_proj_qkv and in_proj_z
#   in_proj_ba -> split into in_proj_b and in_proj_a
if normalized_key.endswith(".linear_attn.in_proj_qkvz.weight"):
    pending_qkvz[normalized_key] = value
    continue
if normalized_key.endswith(".linear_attn.in_proj_ba.weight"):
    pending_ba[normalized_key] = value
    continue
```
The legacy packed tensor handling (*.linear_attn.in_proj_qkvz.weight and *.linear_attn.in_proj_ba.weight splitting) is new behavior but isn’t covered by the added unit test (which only uses the already-split in_proj_qkv/in_proj_z and in_proj_b/in_proj_a keys). Please add a small test case exercising both packed keys and validating the produced in_proj_qkv, in_proj_z, in_proj_b, and in_proj_a outputs (including shape checks).
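A minimal test along those lines might look like the following NumPy sketch. The `split_packed_qkvz`/`split_packed_ba` helpers and the toy dimensions are illustrative; the PR's converter uses `torch.split` on real checkpoint keys and infers `value_dim` from `out_proj`.

```python
import numpy as np

def split_packed_qkvz(packed, value_dim):
    """Split a legacy packed [qkv; z] weight along dim 0."""
    conv_dim = packed.shape[0] - value_dim
    assert conv_dim > 0 and (conv_dim - value_dim) % 2 == 0, "bad packed shape"
    return packed[:conv_dim], packed[conv_dim:]

def split_packed_ba(packed):
    """Split a legacy packed [b; a] weight into equal halves along dim 0."""
    assert packed.shape[0] % 2 == 0, "bad packed shape"
    half = packed.shape[0] // 2
    return packed[:half], packed[half:]

# Toy dimensions: key_dim=2, value_dim=4, hidden=3.
key_dim, value_dim, hidden = 2, 4, 3
conv_dim = 2 * key_dim + value_dim              # q + k + v rows
packed_qkvz = np.arange((conv_dim + value_dim) * hidden, dtype=np.float32)
packed_qkvz = packed_qkvz.reshape(conv_dim + value_dim, hidden)
qkv, z = split_packed_qkvz(packed_qkvz, value_dim)
assert qkv.shape == (conv_dim, hidden) and z.shape == (value_dim, hidden)
assert np.array_equal(np.vstack([qkv, z]), packed_qkvz)  # lossless split

packed_ba = np.ones((2 * 5, hidden), dtype=np.float32)
b, a = split_packed_ba(packed_ba)
assert b.shape == a.shape == (5, hidden)
```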
Just ran a forward pass successfully. Now I'm recreating the .pte file and testing full generation (I initially set max sequence length to 1 to save resources).
Pull request overview
Copilot reviewed 20 out of 20 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (1)
examples/models/llama/norm.py:24
`RMSNorm.__init__` now accepts `add_unit_offset`, but the docstring doesn't describe what this flag does or how it changes the scaling (e.g., using `(1 + weight)` instead of `weight`). Please update the docstring/Args section so callers understand the semantics.
```python
def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
    """
    Initialize the RMSNorm normalization layer.
    Args:
        dim (int): The dimension of the input tensor.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
    Attributes:
        eps (float): A small value added to the denominator for numerical stability.
        weight (nn.Parameter): Learnable scaling parameter.
    """
```
```python
# Ignore non-language-model keys up front.
if not (
    normalized_key.startswith("model.") or normalized_key.startswith("lm_head.")
):
    if _should_ignore_unmapped_key(key, normalized_key):
        continue
    continue
```
In qwen_3_5_to_meta, keys that don’t start with model. or lm_head. are currently always skipped (even when they’re not in the explicit ignore list). This can silently drop unexpected checkpoint prefixes and produce a partially-converted/empty state dict without error. Consider raising a ValueError for non-text keys unless _should_ignore_unmapped_key(...) returns true (or at least logging them in non-verbose mode).
Suggested change:
```diff
-# Ignore non-language-model keys up front.
-if not (
-    normalized_key.startswith("model.") or normalized_key.startswith("lm_head.")
-):
-    if _should_ignore_unmapped_key(key, normalized_key):
-        continue
-    continue
+# Ignore non-language-model keys up front, but fail on unexpected prefixes.
+if not (
+    normalized_key.startswith("model.") or normalized_key.startswith("lm_head.")
+):
+    if _should_ignore_unmapped_key(key, normalized_key):
+        continue
+    raise ValueError(
+        f"Unexpected non-language-model checkpoint key for Qwen3.5 export: {key}"
+    )
```
```python
"model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attention.in_proj_b.weight",
"model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attention.in_proj_a.weight",
"model.layers.{}.linear_attn.conv1d.weight": "layers.{}.attention.conv1d.weight",
"model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias",
```
The mapping includes model.layers.*.linear_attn.conv1d.bias -> layers.*.attention.conv1d.bias, but the implemented DeltaNet conv (nn.Conv1d(..., bias=False)) has no bias parameter. If HF checkpoints contain this bias, it will become an unexpected key and be ignored by the loader (strict=False), potentially degrading correctness. Either remove/ignore the conv1d.bias mapping or enable a bias in the model and load it consistently.
Suggested change:
```diff
-"model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias",
```
```python
# Legacy packed tensors (older checkpoints):
#   in_proj_qkvz -> split into in_proj_qkv and in_proj_z
#   in_proj_ba -> split into in_proj_b and in_proj_a
if normalized_key.endswith(".linear_attn.in_proj_qkvz.weight"):
    pending_qkvz[normalized_key] = value
    continue
if normalized_key.endswith(".linear_attn.in_proj_ba.weight"):
    pending_ba[normalized_key] = value
    continue

try:
    new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META)
except Exception as err:
    if _should_ignore_unmapped_key(key, normalized_key):
        continue
    raise ValueError(
        f"Unexpected checkpoint key not mapped for Qwen3.5 export: {key}"
    ) from err
converted_state_dict[new_key] = value

for key, value in pending_qkvz.items():
    layer_match = re.search(r"model\.layers\.(\d+)\.", key)
    if layer_match is None:
        raise ValueError(f"Failed to parse layer id from key: {key}")
    layer_id = layer_match.group(1)
    out_proj_key = f"layers.{layer_id}.attention.out_proj.weight"
    if out_proj_key not in converted_state_dict:
        raise ValueError(
            f"Cannot split {key}: missing {out_proj_key} to infer value dimension."
        )

    value_dim = converted_state_dict[out_proj_key].shape[1]
    total_dim = value.shape[0]
    conv_dim = total_dim - value_dim
    if conv_dim <= 0 or (conv_dim - value_dim) % 2 != 0:
        raise ValueError(
            f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}"
        )
    qkv, z = torch.split(value, [conv_dim, value_dim], dim=0)
    converted_state_dict[f"layers.{layer_id}.attention.in_proj_qkv.weight"] = qkv
    converted_state_dict[f"layers.{layer_id}.attention.in_proj_z.weight"] = z
    print(f"Split legacy packed key {key} -> in_proj_qkv + in_proj_z")

for key, value in pending_ba.items():
    layer_match = re.search(r"model\.layers\.(\d+)\.", key)
    if layer_match is None:
        raise ValueError(f"Failed to parse layer id from key: {key}")
    layer_id = layer_match.group(1)
    if value.shape[0] % 2 != 0:
        raise ValueError(
            f"Invalid packed in_proj_ba shape for {key}: {tuple(value.shape)}"
        )
    half = value.shape[0] // 2
    b, a = torch.split(value, [half, half], dim=0)
    converted_state_dict[f"layers.{layer_id}.attention.in_proj_b.weight"] = b
    converted_state_dict[f"layers.{layer_id}.attention.in_proj_a.weight"] = a
    print(f"Split legacy packed key {key} -> in_proj_b + in_proj_a")
```
Legacy packed key support (*.linear_attn.in_proj_qkvz.weight and *.linear_attn.in_proj_ba.weight) is new behavior here, but the unit tests don’t cover these split paths. Adding a small test case that includes packed tensors and asserts the expected split keys/shapes would help prevent regressions.
```yaml
backend:
  xnnpack:
    enabled: True
    extended_ops: True
```
Hi @Phineas1500 thanks for putting up this PR! I'm wondering if you were able to successfully export to executorch using this config?
Hey! I validated qwen3_5_0_8b with qwen3_5_xnnpack_fp32.yaml, and I exported successfully to .pte. I asked the model what 2+2 was, and it said 4 😂
I also successfully exported qwen3_5_4b with xnnpack not enabled (my MacBook Air isn't powerful enough to fully do it).
I exported qwen3_5_4b with xnnpack successfully on Google Colab with max_seq_length and max_context_length set to 128; the resulting qwen3_5_4b_no_backend_smoke.pte was 15.5GB, while the qwen3_5_4b_xnnpack_fp32_128.pte was 18GB.
Currently trying to export qwen3_5_4b with xnnpack and q8da4w so I can try running it on a phone. Hopefully should work 🤞
Note, I was using the code from this PR and from #17801
Wow, that's great to hear! Thank you
examples/models/llama/attention.py (Outdated)
```python
    return x * inv_norm


class RMSNormGated(nn.Module):
```
examples/models/llama/attention.py (Outdated)
```python
@register_attention("qwen3_5_full")
class AttentionQwen3_5Full(Attention):
```
This looks very similar to AttentionMHA. Can you refactor the new features out into the existing AttentionMHA? E.g.
```python
class AttentionMHA(Attention):
    def __init__(self, args, ...):
        self.use_q_gate = args.use_q_gate  # from config
        q_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)
        ...
        self.wq = nn.Linear(self.dim, q_dim, bias=...)
```
I made a new commit addressing your recommendations on here and on #17801
```python
)
last_recurrent_state = self.recurrent_state[:batch_size].to(value.dtype)

for i in range(sequence_length):
```
so for dynamic shape we will need this to be a torch.scan, but really we should be swapping this to a custom op and doing a prefix sumish thing if we want to hope for any decent perf on prefill right?
I agree. I didn't add this to make the PR simple (just wanted to add basic compatibility). Do you think torch.scan or a custom operation should be added to this PR or a subsequent one?
Summary
Test Plan
cc @mergennachin @iseeyuan @lucylq @helunwencser @tarun292 @kimishpatel @jackzhxng