Introduction: What Problems Does Weight Loading Solve?

Before diving into vLLM's weight loading implementation, it's essential to understand the core challenges it addresses.

Large language model weights are typically stored on disk as checkpoint files. The weight loading task seems straightforward: read these files, match tensors by name, and copy data into the model's parameters. However, three critical complexities make this far from simple.

Challenge 1: Tensor Sharding and Memory Control in Tensor Parallelism

vLLM supports splitting a model across multiple GPUs for parallel inference, known as Tensor Parallelism (TP). The core idea: divide large matrices into rows or columns, with each GPU holding only a portion, computing independently, then merging results through communication (AllReduce/AllGather).

Example with a [4096, 4096] linear layer weight and TP=2:

  • Column Parallel: Weights split by columns. GPU-0 holds [4096, 2048], GPU-1 holds the other half. Each GPU multiplies the complete input by its half of the weight, producing half of the output; AllGather then concatenates the results.
  • Row Parallel: Weights split by rows. GPU-0 holds [2048, 4096], GPU-1 holds the other half. The input is split along its last dimension to match, each GPU computes a partial result independently, and AllReduce sums them.
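As a quick numerical check of the two schemes, the sketch below simulates TP=2 on a single device with plain PyTorch. It assumes the Y = X @ W convention (weight shaped [in, out]) implied by the shapes above; chunk, cat, and a plain sum stand in for sharding, AllGather, and AllReduce, and the small sizes are illustrative rather than taken from vLLM.

```python
import torch

torch.manual_seed(0)
tp_size = 2
x = torch.randn(3, 8)  # [batch, in_features]
w = torch.randn(8, 8)  # [in_features, out_features], i.e. Y = X @ W
full_out = x @ w

# Column parallel: each "rank" holds a column block of W and sees the full input.
col_shards = w.chunk(tp_size, dim=1)
col_partials = [x @ shard for shard in col_shards]       # each [3, 4]
col_out = torch.cat(col_partials, dim=1)                 # stands in for AllGather

# Row parallel: each "rank" holds a row block of W and the matching input slice.
row_shards = w.chunk(tp_size, dim=0)
x_shards = x.chunk(tp_size, dim=1)
row_partials = [xs @ ws for xs, ws in zip(x_shards, row_shards)]  # each [3, 8]
row_out = torch.stack(row_partials).sum(dim=0)           # stands in for AllReduce
```

Both col_out and row_out should match full_out up to floating-point rounding, which is why either split can replace the unsharded layer.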

The Challenge: During weight loading, you can't simply copy entire tensors to parameters. Instead, you must extract the appropriate slice from the complete weights based on the current GPU's rank. How is this "slicing" implemented? Additionally, since checkpoints store complete weights while each GPU only needs a 1/TP slice, how do you avoid running out of CPU or GPU memory (OOM) during loading?

Challenge 2: QKV Fusion and Gate-Up Fusion

To reduce kernel launch overhead and improve GPU utilization, vLLM fuses multiple logically independent weights into a single physical parameter:

QKV Fusion: The Transformer attention layer has three projection matrices: Q, K, and V. In checkpoints, they exist as separate weights (q_proj.weight, k_proj.weight, v_proj.weight), but vLLM concatenates them into qkv_proj.weight. This enables computing Q, K, V in a single GEMM operation, eliminating two kernel launches.

Gate-Up Fusion: Similarly, gate_proj and up_proj in the FFN layer are fused into gate_up_proj, replacing two GEMM operations with one.
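To make the "one GEMM instead of three" claim concrete, here is a small sketch in plain PyTorch. It assumes nn.Linear-style weights shaped [out_features, in_features] and equal Q/K/V sizes for simplicity; the tensor names are illustrative and not vLLM identifiers.

```python
import torch

torch.manual_seed(0)
hidden = 8
x = torch.randn(2, hidden)
# nn.Linear-style weights, shaped [out_features, in_features] (assumption).
q_w, k_w, v_w = (torch.randn(hidden, hidden) for _ in range(3))

# Unfused: three separate GEMMs.
q, k, v = x @ q_w.T, x @ k_w.T, x @ v_w.T

# Fused: concatenate along the output dimension once, then one GEMM and a split.
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)          # [3 * hidden, hidden]
q2, k2, v2 = (x @ qkv_w.T).split(hidden, dim=-1)
```

The same pattern applies to gate_proj and up_proj, with two blocks instead of three.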

The Challenge: Checkpoints don't contain a qkv_proj key—only q_proj, k_proj, and v_proj. How is this mapping handled during loading?

Challenge 3: Meta Device Initialization and Delayed Materialization

PyTorch provides a special meta device (device="meta"): tensors created on meta devices only record metadata like shape, dtype, and stride without allocating actual memory. This is crucial for large models—initializing empty parameters for a 500B parameter model directly on GPU would require approximately 1000GB VRAM (FP16), far exceeding single-card capacity.

vLLM uses meta devices for delayed memory allocation in scenarios like online quantization and Transformers Backend.

The Challenge: When parameters reside on meta devices, you can't directly copy_ data into them (meta tensors have no actual storage). How does weight loading handle these "virtual" parameters?
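The snippet below is a minimal illustration of what a meta tensor is, using only standard PyTorch calls. "Materializing" here just means allocating a real tensor with the same metadata; CPU is used so the sketch runs anywhere, and it does not reproduce vLLM's own materialization code.

```python
import torch

# A meta tensor records shape, dtype and stride but owns no data.
meta_w = torch.empty(1024, 1024, dtype=torch.float16, device="meta")
bytes_if_real = meta_w.numel() * meta_w.element_size()

# Allocate real storage with the same metadata, then fill it.
real_w = torch.empty_like(meta_w, device="cpu")
real_w.copy_(torch.ones(1024, 1024, dtype=torch.float16))
```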

With these three questions in mind, let's examine vLLM's actual implementation.

Weight Loading System Overview

This section systematically introduces vLLM's weight loading workflow.

Overall Workflow

vLLM's weight loading consists of four stages: Model Initialization → Weight Reading → Weight Distribution → Post-Processing.

┌─────────────────────────────────────────────────────────────────────┐
│                    BaseModelLoader.load_model()                      │
│                                                                      │
│  ① initialize_model()          Build model structure (empty params)  │
│         │                                                            │
│  ② load_weights(model, ...)    Read checkpoint and distribute        │
│         │                                                            │
│  ③ process_weights_after_loading()  Quantization post-processing     │
│         │                                                            │
│  ④ model.eval()                Return inference-ready model          │
└─────────────────────────────────────────────────────────────────────┘

Weight Reading: From Files to Iterator

DefaultModelLoader is the most commonly used loader. It converts checkpoint files (safetensors/PyTorch bin) into an Iterable[tuple[str, torch.Tensor]] iterator—each element is a (weight_name, tensor) pair.

The get_all_weights() function internally calls safetensors_weights_iterator() and similar functions, yielding (name, tensor) pairs file by file, key by key. This streaming iterator (yield) avoids loading the entire checkpoint into memory at once—only one tensor is read at a time, released after processing. CPU memory peaks at just the size of the largest single tensor.

At this stage, yielded tensors reside in CPU memory, maintaining the checkpoint's original key naming and complete shape.
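The streaming behaviour can be sketched with a plain Python generator. Everything here is a stand-in: FAKE_CHECKPOINT, read_tensor, and weights_iterator are hypothetical names, and Python lists play the role of tensors; the point is only that a tensor is read when the consumer asks for it, not before.

```python
from collections.abc import Iterable, Iterator

# Hypothetical stand-in for checkpoint shards: file name -> {key: tensor data}.
FAKE_CHECKPOINT = {
    "model-00001.safetensors": {
        "embed.weight": [0.0] * 8,
        "layers.0.q_proj.weight": [1.0] * 4,
    },
    "model-00002.safetensors": {
        "layers.0.k_proj.weight": [2.0] * 4,
    },
}
reads: list[str] = []  # records the moment each tensor is actually "read"


def read_tensor(file_name: str, key: str) -> list[float]:
    reads.append(key)
    return list(FAKE_CHECKPOINT[file_name][key])


def weights_iterator(file_names: Iterable[str]) -> Iterator[tuple[str, list[float]]]:
    """Yield (name, tensor) pairs file by file, key by key."""
    for file_name in file_names:
        for key in FAKE_CHECKPOINT[file_name]:
            yield key, read_tensor(file_name, key)


it = weights_iterator(FAKE_CHECKPOINT)
first_name, first_tensor = next(it)
reads_after_first = list(reads)          # only one tensor has been read so far
remaining = [name for name, _ in it]     # the rest are read as they are consumed
```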

Weight Distribution: Two Coexisting Modes

The weight iterator is passed to model.load_weights(), where the model level decides how to distribute each (name, tensor) to corresponding parameters. Currently, two distribution modes exist:

Mode A: Manual Traversal (Traditional, Gradually Being Replaced)

Top-level model classes (inheriting from nn.Module, like QWenLMHeadModel) manually traverse the iterator in their load_weights method, handling key renaming, fusion mapping, and shard_id injection entry by entry, and ultimately call param.weight_loader(param, loaded_weight, shard_id) to complete loading.

Mode B: Automatic Recursion (AutoWeightsLoader Mode, Current Mainstream Direction)

Top-level model classes create an AutoWeightsLoader instance, which automatically distributes weights according to the module tree. AutoWeightsLoader receives the top-level model instance (the root of the entire module tree), splits weight names by ., matches submodules or parameters level by level, adopting a three-level priority strategy:

  1. Module-Level Priority: If module has load_weights method → delegate to it
  2. Submodule Recursion: Match by prefix → recursively call _load_module
  3. Parameter-Level Processing: Match by prefix → call param.weight_loader
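The three-level priority can be sketched as a small recursive function over an nn.Module tree. This is a simplified illustration, not vLLM's AutoWeightsLoader: load_module, default_loader, Fused, and Toy are hypothetical names, and the sketch assumes weights with the same prefix arrive consecutively (as itertools.groupby requires).

```python
from itertools import groupby

import torch
from torch import nn


def default_loader(param: nn.Parameter, loaded: torch.Tensor) -> None:
    param.data.copy_(loaded)


def load_module(prefix: str, module: nn.Module, weights, root: nn.Module):
    """Yield the fully qualified names of weights loaded under `module`."""

    def qualify(name: str) -> str:
        return f"{prefix}.{name}" if prefix else name

    # 1. Module-level priority: delegate to the module's own load_weights.
    #    The root is skipped because it is the module that started the traversal.
    if module is not root and callable(getattr(module, "load_weights", None)):
        yield from (qualify(name) for name in module.load_weights(weights))
        return

    children = dict(module.named_children())
    params = dict(module.named_parameters(recurse=False))
    # Group consecutive weights by the first dotted component of their name.
    split = ((name.partition("."), tensor) for name, tensor in weights)
    for head, group in groupby(split, key=lambda item: item[0][0]):
        rest = [(parts[2], tensor) for parts, tensor in group]
        if head in children:
            # 2. Submodule recursion.
            yield from load_module(qualify(head), children[head], rest, root)
        elif head in params:
            # 3. Parameter-level processing.
            loader = getattr(params[head], "weight_loader", default_loader)
            for _, tensor in rest:
                loader(params[head], tensor)
            yield qualify(head)
        else:
            raise KeyError(f"no module or parameter named {qualify(head)!r}")


class Fused(nn.Module):
    """Toy submodule with its own load_weights, which takes priority."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2), requires_grad=False)

    def load_weights(self, weights):
        loaded = set()
        for name, tensor in weights:
            self.weight.data.copy_(tensor)
            loaded.add(name)
        return loaded


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fused = Fused()
        self.norm = nn.LayerNorm(2)

    def load_weights(self, weights):
        return set(load_module("", self, weights, root=self))


model = Toy()
loaded = model.load_weights([
    ("fused.weight", torch.ones(2)),
    ("norm.weight", torch.full((2,), 3.0)),
])
```

Here fused.weight is handled by Fused.load_weights (priority 1), while norm.weight is reached by recursing into the LayerNorm submodule and falling back to the parameter-level loader.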

Trend: The "routing distribution" part of manual traversal is being replaced by AutoWeightsLoader. Comparing the evolution of the same model series clearly shows this trend: early qwen.py (Qwen-1) used about 30 lines of manual traversal code handling both routing and fusion, while subsequent qwen3.py (Qwen-3) delegated routing responsibilities to AutoWeightsLoader, requiring only 4 lines at the top level.

Fusion Key Mapping: stacked_params_mapping Mechanism

As mentioned in Challenge 2, vLLM fuses multiple logically independent weights into one physical parameter (e.g., q_proj + k_proj + v_proj → qkv_proj). However, checkpoints only contain the original separate keys, not the fused keys. stacked_params_mapping solves this mapping problem—it tells the loader "which position in the fused parameter this checkpoint key should fill."

Mapping Table Structure

Each mapping is a triplet (param_name, shard_name, shard_id):

  • param_name: The fused parameter name (actually existing in the model)
  • shard_name: Original key fragment in the checkpoint
  • shard_id: Position identifier of this original key in the fused parameter

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),      # q_proj → Q region of qkv_proj
    ("qkv_proj", "k_proj", "k"),      # k_proj → K region of qkv_proj
    ("qkv_proj", "v_proj", "v"),      # v_proj → V region of qkv_proj
    ("gate_up_proj", "gate_proj", 0), # gate_proj → slice 0 of gate_up_proj
    ("gate_up_proj", "up_proj", 1),   # up_proj → slice 1 of gate_up_proj
]

Loading Process

When encountering checkpoint key model.layers.0.self_attn.q_proj.weight:

  1. Match shard_name="q_proj", replace q_proj with qkv_proj in the key, getting model.layers.0.self_attn.qkv_proj.weight
  2. Call weight_loader(param, loaded_weight, shard_id="q") on the fused parameter
  3. Inside weight_loader, calculate offset based on shard_id, writing data to the Q region of qkv_proj parameter
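The key-rewriting step can be sketched as a small pure-Python helper. resolve_fused_key is a hypothetical name used for illustration; it relies on simple substring matching, the same matching style as the mapping table, and returns the rewritten key together with the shard_id (or None for keys that are not fused).

```python
from typing import Optional, Union

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def resolve_fused_key(name: str) -> tuple[str, Optional[Union[str, int]]]:
    """Map a checkpoint key to (model parameter name, shard_id)."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None  # not part of a fused parameter


q_key = resolve_fused_key("model.layers.0.self_attn.q_proj.weight")
up_key = resolve_fused_key("model.layers.0.mlp.up_proj.weight")
norm_key = resolve_fused_key("model.norm.weight")
```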

Fusion mapping is used in both Mode A and Mode B. Whether manual traversal (Mode A) or AutoWeightsLoader recursive distribution (Mode B), fusion mapping processing logic is implemented by each model file itself—defining stacked_params_mapping in the load_weights method and traversing for matches.

Parameter-Level Loading: weight_loader Responsibilities

Regardless of distribution mode, the final step always calls the weight_loader on the parameter to complete actual data copying. weight_loader handles TP sharding (narrowing complete weights to current rank's slice) and fusion offset (concatenating multiple sub-weights into different regions of the same parameter).

Understanding nn.Parameter

Before diving into two generations of parameter systems, understand nn.Parameter itself. nn.Parameter is essentially torch.Tensor—it directly inherits from Tensor, only doing two additional things:

  1. Default requires_grad=True: Ordinary Tensors don't participate in gradient calculation by default, while Parameters do. This is its semantic marker as a "learnable parameter."
  2. Automatic registration to nn.Module: When a Parameter is assigned as a Module attribute (e.g., self.weight = nn.Parameter(...)), Module's __setattr__ automatically registers it in the _parameters dictionary, making it discoverable by named_parameters(), optimizers, and state_dict() serialization.

Beyond this, nn.Parameter has no additional data storage or methods.
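Both behaviours are easy to observe with standard PyTorch; the Tiny module below is only a throwaway example.

```python
import torch
from torch import nn

plain = torch.zeros(2, 2)
param = nn.Parameter(torch.zeros(2, 2))


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2, 2))  # registered by Module.__setattr__
        self.scratch = torch.zeros(2, 2)               # plain tensor: not registered


tiny = Tiny()
param_names = [name for name, _ in tiny.named_parameters()]
state_keys = list(tiny.state_dict().keys())
```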

Two Generations of Parameter Systems in vLLM

vLLM has two generations of parameter systems, attaching weight loading capabilities to this "pure Tensor subclass" in different ways:

| Aspect | v1 (nn.Parameter + Dynamic Attributes) | v2 (BasevLLMParameter Subclass) |
| --- | --- | --- |
| Type | PyTorch native nn.Parameter | BasevLLMParameter and subclasses |
| weight_loader Source | Dynamically attached via set_weight_attrs or direct assignment | Passed as constructor parameter, exposed as formal class attribute |
| TP Sharding Logic | Manual narrow + copy_ inside weight_loader function | Encapsulated in parameter subclass methods like load_column_parallel_weight() |
| Representative | ColumnParallelLinear.weight_loader (v1) | ModelWeightParameter.load_column_parallel_weight() (v2) |

v1: nn.Parameter + Dynamic Attributes

v1's approach leverages Python's dynamic attribute mechanism, bypassing type system constraints. As mentioned, nn.Parameter is essentially just a Tensor, lacking a weight_loader attribute itself. v1 forcibly injects weight_loader onto nn.Parameter instances via setattr or direct assignment.
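A minimal sketch of the v1 pattern follows. attach_attrs is a hypothetical stand-in written for this example (it is not vLLM's set_weight_attrs), and copy_loader is a trivial loader; the point is that the attribute exists only in the instance's __dict__ at runtime.

```python
import torch
from torch import nn


def attach_attrs(param: nn.Parameter, attrs: dict) -> None:
    """Hypothetical helper: attach arbitrary attributes to a parameter instance."""
    for key, value in attrs.items():
        setattr(param, key, value)


def copy_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
    param.data.copy_(loaded_weight)


weight = nn.Parameter(torch.zeros(4, 2), requires_grad=False)
had_loader_before = hasattr(weight, "weight_loader")

attach_attrs(weight, {"weight_loader": copy_loader, "output_dim": 0})

# Callers rely purely on the attribute being present at runtime.
weight.weight_loader(weight, torch.ones(4, 2))
```

No type checker can know that weight has a weight_loader attribute, which is the constraint-bypassing behaviour described above.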

v2: BasevLLMParameter Subclass System

v2's BasevLLMParameter is a better design. It inherits nn.Parameter, treating weight_loader as a formal constructor parameter, exposing it via @property as a class attribute with complete type constraints.

Additionally, v2 encapsulates TP sharding logic as the parameter's own methods (like load_column_parallel_weight(), load_merged_column_weight()), rather than scattering them in external weight_loader functions, achieving better cohesion.
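The shape of the v2 approach can be sketched with a small nn.Parameter subclass. SketchColumnParameter is a hypothetical, heavily simplified class written for this example, not vLLM's BasevLLMParameter; the output_dim and tp_rank constructor arguments are assumptions about what such a class needs in order to shard by itself.

```python
from typing import Callable

import torch
from torch import nn


class SketchColumnParameter(nn.Parameter):
    """Hypothetical simplified parameter that carries its own loading logic."""

    def __new__(cls, data: torch.Tensor, **kwargs):
        # nn.Parameter builds the tensor in __new__; extra kwargs go to __init__.
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable,
                 output_dim: int, tp_rank: int):
        self._weight_loader = weight_loader
        self._output_dim = output_dim
        self._tp_rank = tp_rank

    @property
    def weight_loader(self) -> Callable:
        return self._weight_loader

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
        # Narrow the full tensor to this rank's slice, then copy it in.
        shard_size = self.data.shape[self._output_dim]
        start = self._tp_rank * shard_size
        self.data.copy_(loaded_weight.narrow(self._output_dim, start, shard_size))


full = torch.arange(16, dtype=torch.float32).reshape(8, 2)  # full checkpoint tensor
param = SketchColumnParameter(
    torch.empty(4, 2),
    weight_loader=lambda p, w: p.load_column_parallel_weight(w),
    output_dim=0,
    tp_rank=1,
)
param.weight_loader(param, full)
```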

Post-Processing: process_weights_after_loading

process_weights_after_loading converts weights from storage format to the format required by runtime kernels, completing quantization weight repacking, scale calculation, format conversion, etc. Its call timing depends on the loading scenario:

Default Scenario (Non-Online Quantization): Called uniformly after all model weights are loaded. The flow from BaseModelLoader.load_model clearly shows this sequence: first load all weights, then unified post-processing.

Online Quantization Scenario (Layerwise Reload): Post-processing is performed layer by layer—immediately after each layer's weights are loaded, executing that layer's process_weights_after_loading, converting full-precision weights to low-precision format then releasing, before processing the next layer. This way, GPU only needs to hold one layer's full-precision weights at a time, significantly reducing peak memory.

Key Participants Summary

Throughout the workflow, two core participants deserve particular attention: the module-level load_weights method and the parameter-level weight_loader attribute. They respectively undertake "scheduling" and "execution" responsibilities for weight loading.

Module-Level: Module.load_weights

load_weights is a convention method defined by vLLM model classes themselves. The framework detects whether a module implements this method via hasattr(module, "load_weights") and calls it if present. It's responsible for weight routing and scheduling—deciding which parameter should handle each checkpoint weight. It appears at two levels:

  • Top-level model class (e.g., Qwen3ForCausalLM): Serves as the entry point for entire weight loading, called by BaseModelLoader. Top-level load_weights either manually traverses the iterator (Mode A) or creates AutoWeightsLoader to delegate recursive distribution (Mode B).
  • Intermediate submodule (e.g., Qwen3NextModel): When AutoWeightsLoader recurses to a submodule, if that submodule has a load_weights method, it's prioritized for delegation. Submodule load_weights typically handles fusion mapping (stacked_params_mapping) and other layer-specific logic.

Core responsibilities include:

  1. Key renaming: Mapping checkpoint keys to model parameter names
  2. Fusion mapping: Using stacked_params_mapping to map separate checkpoint keys (q_proj, k_proj, v_proj) to fused parameters (qkv_proj) and inject shard_id
  3. Routing distribution: Handing processed (name, tensor) to corresponding parameter's weight_loader for actual loading

Parameter-Level: param.weight_loader

weight_loader is a callable attribute mounted on nn.Parameter (or its subclass BasevLLMParameter), responsible for actual weight writing—correctly filling a checkpoint tensor into the parameter's data storage. It's the last link in the weight loading chain, handling two key things:

  • TP Sharding: Narrow out the 1/TP slice belonging to current rank from complete weights based on current rank
  • Fusion Offset: Calculate offset based on shard_id, writing data to the correct region of fused parameters

Typical weight_loader call patterns:

# Non-fused weights: 2-parameter call
weight_loader(param, loaded_weight)

# Fused weights: 3-parameter call with shard_id
weight_loader(param, loaded_weight, shard_id)

weight_loader is a general parameter-level loading protocol, not limited to linear layers. Any layer needing custom weight writing logic can provide weight_loader for its parameters. Common sources include:

  • Linear layers: ColumnParallelLinear.weight_loader, MergedColumnParallelLinear.weight_loader, QKVParallelLinear.weight_loader, etc., handling TP sharding and fusion offset
  • Embedding layers: VocabParallelEmbedding.weight_loader, handling vocabulary sharding
  • Mamba layers: mamba_v2_sharded_weight_loader, handling interleaved sharding of SSM projections
  • MoE layers: weight_loader in FusedMoE, handling expert weight distribution
  • v2 parameter subclasses: BasevLLMParameter and subclasses carry weight_loader attribute themselves, internalizing loading logic into parameter types

These weight_loaders are "mounted" onto parameters, indirectly called by the external framework through parameters. This design allows the external framework to call param.weight_loader without knowing which type of layer the parameter belongs to.

Collaboration Relationship

Module.load_weights (Scheduler Layer)
  │  "Which parameter should handle this checkpoint key? What's the shard_id?"
  │
  ▼
param.weight_loader (Execution Layer)
  │  "I have the data and shard_id, write to correct position per TP sharding rules"
  │
  ▼
Parameter data update complete

In short: load_weights solves "who loads," weight_loader solves "how to load." The former is scheduling logic, the latter is execution logic.

Solutions to the Three Challenges

This section returns to the three challenges from Chapter 1, combining the system introduced in Chapter 2 to provide vLLM's solutions one by one.

Solution 1: Weight Sharding and Memory Control under TP

Sharding Mechanism: During weight loading, the narrow (slicing) operation extracts the slice belonging to current rank from complete weights, then copy_ to parameters. This "slicing" operation is one of weight_loader's core responsibilities.

Specifically, ColumnParallelLinear.weight_loader calculates current rank's starting position and shard size based on tp_rank and tp_size, then executes narrow on the CPU tensor.
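The narrow-then-copy step can be sketched as follows. column_parallel_loader is a hypothetical function, not the vLLM method, and it assumes the sharded (output) dimension is dim 0 of the stored tensor, as in nn.Linear's [out_features, in_features] layout; both "ranks" are simulated on CPU.

```python
import torch


def column_parallel_loader(param_data: torch.Tensor, loaded_weight: torch.Tensor,
                           tp_rank: int, output_dim: int = 0) -> None:
    """Copy this rank's slice of a full checkpoint tensor into its parameter."""
    shard_size = param_data.shape[output_dim]   # parameter already has 1/TP size
    start = tp_rank * shard_size
    shard = loaded_weight.narrow(output_dim, start, shard_size)  # a view, no copy yet
    assert shard.shape == param_data.shape
    param_data.copy_(shard)   # in a real setup this copy would also move CPU -> GPU


tp_size = 2
full = torch.arange(32, dtype=torch.float32).reshape(8, 4)   # "checkpoint" tensor
rank_params = [torch.empty(8 // tp_size, 4) for _ in range(tp_size)]
for rank, data in enumerate(rank_params):
    column_parallel_loader(data, full, tp_rank=rank)
```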

GPU Memory Side: Parameters only allocate 1/TP space

During model initialization, vLLM creates the model within GPU device context. At this point, parallel layers like ColumnParallelLinear and RowParallelLinear calculate sharded dimensions based on tp_size, only allocating [4096, 4096/TP] sized parameters on GPU, not complete [4096, 4096]. Therefore, GPU memory occupies only 1/TP from the start.

CPU Memory Side: Per-tensor reading + narrow on CPU

Checkpoint weight reading adopts streaming iteration mode. safetensors' safe_open uses mmap mechanism, with get_tensor() reading only the currently requested single tensor from disk to CPU memory, not loading the entire file at once. Subsequently, in weight_loader, the narrow operation executes on CPU tensor, slicing out the 1/TP slice needed by current rank, then cross-device copying to GPU via param_data.copy_(loaded_weight).

This way, CPU memory peak ≈ complete size of single largest tensor (typically hundreds of MB), not entire model size; GPU memory always holds only 1/TP of parameters.
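For a sense of scale, the arithmetic for the [4096, 4096] FP16 example used earlier works out as below; tensor_bytes is a throwaway helper, and the 500B figure is the one quoted in the introduction.

```python
def tensor_bytes(shape: tuple, bytes_per_element: int) -> int:
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_element


FP16_BYTES = 2
tp_size = 2

full_mib = tensor_bytes((4096, 4096), FP16_BYTES) / 2**20              # one full tensor on CPU
shard_mib = tensor_bytes((4096, 4096 // tp_size), FP16_BYTES) / 2**20  # per-rank GPU parameter

# 500B parameters at 2 bytes each, expressed in GB (10**9 bytes).
model_gb = 500 * 10**9 * FP16_BYTES // 10**9
```

So a single [4096, 4096] FP16 tensor costs 32 MiB of CPU memory while it is being processed, and each rank keeps a 16 MiB slice on the GPU when TP=2.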

Note: Each rank independently reads the complete checkpoint file. Although each rank ultimately needs only 1/TP of the data, every rank traverses all checkpoint files, reads each complete tensor, and then narrows out its own slice. This means disk I/O is redundant by a factor of TP. It is a trade-off in the current design: I/O redundancy is accepted in exchange for implementation simplicity (no inter-rank coordination is needed to distribute reads). The fastsafetensors and instanttensor loaders attempt to optimize this through distributed I/O.

Solution 2: QKV Fusion and Gate-Up Fusion Loading

This is precisely why stacked_params_mapping and shard_id mechanisms exist—they tell the loader "which position in the fused parameter this checkpoint key should fill."

Briefly, during loading:

  1. Identify that q_proj should map to the Q region of qkv_proj (shard_id="q")
  2. Write q_proj data to corresponding region of qkv_proj parameter
  3. Repeat above process for k_proj and v_proj, writing to their respective regions
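The offset calculation combined with TP sharding can be sketched as below. qkv_weight_loader is a hypothetical function and the sizes are made up; it assumes each of Q, K, V is split evenly across ranks along dim 0 and ignores details such as KV-head replication that a real implementation has to handle.

```python
import torch


def qkv_weight_loader(param_data: torch.Tensor, loaded_weight: torch.Tensor,
                      shard_id: str, full_sizes: dict, tp_rank: int, tp_size: int) -> None:
    """Write one of q/k/v into its region of the per-rank fused parameter."""
    per_rank = {key: rows // tp_size for key, rows in full_sizes.items()}
    # Fusion offset: regions are laid out in q, k, v order.
    offset = 0
    for key in ("q", "k", "v"):
        if key == shard_id:
            break
        offset += per_rank[key]
    size = per_rank[shard_id]
    src = loaded_weight.narrow(0, tp_rank * size, size)   # TP sharding
    param_data.narrow(0, offset, size).copy_(src)         # write into fused region


hidden, tp_size, tp_rank = 4, 2, 1
full_sizes = {"q": 8, "k": 4, "v": 4}   # full output rows per projection (illustrative)
q_w = torch.arange(32, dtype=torch.float32).reshape(8, hidden)
k_w = 100 + torch.arange(16, dtype=torch.float32).reshape(4, hidden)
v_w = 200 + torch.arange(16, dtype=torch.float32).reshape(4, hidden)

qkv = torch.zeros(sum(full_sizes.values()) // tp_size, hidden)  # per-rank fused parameter
# Checkpoint order does not matter: each shard lands in its own region.
for shard_id, tensor in (("v", v_w), ("q", q_w), ("k", k_w)):
    qkv_weight_loader(qkv, tensor, shard_id, full_sizes, tp_rank, tp_size)
```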

Solution 3: Meta Device Initialization and Delayed Materialization

vLLM uses meta devices in two scenarios, handling them through different materialization strategies. Taking Online Quantization as an example:

When users specify online quantization (e.g., FP8 per-tensor), the model loading goal is: read full-precision checkpoint → online quantize to low-precision → store quantized weights. If full-precision parameters are allocated on GPU first, then quantized to FP8, GPU would need to hold both full-precision and quantized weights simultaneously before quantization completes, doubling peak memory.

To solve this, online quantization methods (like Fp8OnlineLinearMethod) create weights on meta devices, then handle through layerwise reload mechanism:

  1. During weight loading, first buffer checkpoint data in CPU memory without immediately writing to parameters (parameters are on meta device, cannot write yet). Specifically, online_process_loader intercepts weight_loader calls, caching call parameters (containing CPU tensor references from checkpoint iterator) to LayerReloadingInfo.loaded_weights list.
  2. When all weights for a layer are buffered, materialize that layer—allocate real memory on GPU.
  3. Load buffered weights into materialized parameters.
  4. Immediately execute quantization processing (process_weights_after_loading), converting full-precision weights to FP8.
  5. Release full-precision weights, retaining only quantized results.

This way, GPU only needs to hold one layer's full-precision weights at a time, releasing immediately after quantization, significantly reducing peak memory.
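The control flow of the five steps can be simulated in a few lines. This is only a sketch under loud assumptions: all names are hypothetical (it does not use online_process_loader or LayerReloadingInfo), CPU stands in for the GPU, and a simple int8 quantizer stands in for FP8.

```python
from collections import defaultdict

import torch

torch.manual_seed(0)


def quantize_int8(weight: torch.Tensor):
    """Stand-in for online quantization (int8 here instead of FP8)."""
    scale = weight.abs().max().clamp(min=1e-8) / 127.0
    return (weight / scale).round().clamp(-127, 127).to(torch.int8), scale


# Two layers, each declaring two full-precision parameters on the meta device.
layer_params = {
    f"layer{i}": {name: torch.empty(4, 4, device="meta") for name in ("a.weight", "b.weight")}
    for i in range(2)
}


def checkpoint_iter():
    for layer, params in layer_params.items():
        for name in params:
            yield f"{layer}.{name}", torch.randn(4, 4)


buffers = defaultdict(dict)   # layer -> {param name: buffered CPU tensor}
quantized = {}
live_layers = 0               # materialized full-precision layers alive right now
peak_live_layers = 0

for full_name, cpu_tensor in checkpoint_iter():
    layer, _, name = full_name.partition(".")
    buffers[layer][name] = cpu_tensor                      # 1. buffer, do not write yet
    if len(buffers[layer]) < len(layer_params[layer]):
        continue                                           # layer not complete
    real = {n: torch.empty_like(m, device="cpu")           # 2. materialize the layer
            for n, m in layer_params[layer].items()}
    live_layers += 1
    peak_live_layers = max(peak_live_layers, live_layers)
    for n, tensor in buffers.pop(layer).items():           # 3. load buffered weights
        real[n].copy_(tensor)
    quantized[layer] = {n: quantize_int8(w) for n, w in real.items()}  # 4. quantize
    del real                                               # 5. release full precision
    live_layers -= 1
```

At no point is more than one layer's full-precision storage alive, which is the property the layerwise flow is designed to guarantee.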

Design Defect Analysis

This section analyzes unreasonable designs in vLLM's weight loading system one by one. While I call them "design defects," they have no impact on vLLM system stability or performance—most of the time they only affect human programmers, requiring extra mental gymnastics when reading and additional attention to details during development.

Defect 4.1: [Unnecessary Separation Design Increases Development Burden] AutoWeightsLoader's Anti-Recursion and Bidirectional Dependency

AutoWeightsLoader is a tool class independent from models, created by the model's load_weights, but it in turn calls submodules' load_weights, forming a bidirectional dependency.

Case A: Defensive Code for Anti-Recursion

The check module != self.module exists because the framework recognized a recursion risk: if the top-level module's load_weights creates AutoWeightsLoader, and AutoWeightsLoader calls the same module's load_weights, infinite recursion occurs. This is a symptom of a design defect; a good design shouldn't need such seemingly mysterious defenses.

Case B: Bidirectional Dependency Call Chain

Model.load_weights()
  └─ Creates AutoWeightsLoader(self)
       └─ AutoWeightsLoader._load_module()
            └─ Calls child_module.load_weights()   ← Reverse call

AutoWeightsLoader is created in multiple model files, each model's load_weights is both AutoWeightsLoader's creator and its potential call target. This bidirectional dependency increases understanding and maintenance burden.

Defect 4.2: [Not Cohesive] Fusion Key Mapping Scattered in Model Layer, Not Internalized in Fusion Operators

Fusion layers (like MergedColumnParallelLinear, QKVParallelLinear) merge multiple checkpoint keys into one parameter, but mapping relationships are defined by each model file itself, not declared by fusion operators themselves, resulting in nearly identical stacked_params_mapping definitions across multiple model files.

Essence of Problem: Fusion operators (like MergedColumnParallelLinear) know which sub-weights compose them, but they don't declare this information. Instead, every model file using them repeatedly declares it. This violates the cohesion principle that "information should be managed by the owner."

Defect 4.3: [Core Defect] nn.Parameter Bears Responsibilities Not Its Own, Making Parameter Objects Impure

As described in section 2.4, nn.Parameter is essentially just a torch.Tensor with a requires_grad flag, a pure data container. But vLLM mounts weight loading logic (weight_loader) and other attributes onto nn.Parameter via dynamic attributes, making it bear responsibilities that are not its own. This root cause leads to problems at three levels: dynamic mounting bypasses the type system (4.3.1), weight_loader v1 and v2 coexist as a version split (4.3.2), and meta device materialization is forced to use a __class__ hack (4.3.3).

Manifestation 1: Dynamically Mounting weight_loader on Native nn.Parameter (Bypassing Type System)

nn.Parameter is a PyTorch native type that lacks a weight_loader attribute. vLLM leverages Python's dynamic language features, forcibly injecting this attribute through two methods: set_weight_attrs and direct attribute assignment.

Manifestation 2: weight_loader v1/v2 Coexistence (Version Split)

vllm/model_executor/layers/linear.py maintains whitelist WEIGHT_LOADER_V2_SUPPORTED. Quantization methods in whitelist use v2 (BasevLLMParameter subclass's load_column_parallel_weight() etc.), otherwise use v1 (external function manual narrow + copy_). Two versions coexisting means: adding new quantization methods requires deciding which version to support and manually adding to whitelist, with two styles mixed in existing code.

Manifestation 3: Meta Device Materialization Relies on __class__ Hack (Chain Reaction)

Another chain reaction of impure parameter objects appears in meta device materialization. When parameters reside on meta devices, equivalent parameter objects need to be created on real devices. Since nn.Parameter has dynamically mounted attributes like weight_loader and output_dim via setattr, stored in the instance's __dict__, they cannot be reconstructed through the standard nn.Parameter(data, requires_grad) constructor. Therefore, materialize_meta_tensor() can only bypass the normal object construction flow, using a __class__ assignment plus __dict__ copy hack.
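The following sketch shows why the standard constructor is not enough and what the hack amounts to. materialize_with_hack is a hypothetical function written for this example, not vLLM's materialize_meta_tensor(); CPU stands in for the real device.

```python
import torch
from torch import nn


def materialize_with_hack(meta_param: nn.Parameter, device: str = "cpu") -> nn.Parameter:
    """Allocate real storage, then graft the class and instance attributes across."""
    real = nn.Parameter(
        torch.empty(meta_param.shape, dtype=meta_param.dtype, device=device),
        requires_grad=meta_param.requires_grad,
    )
    real.__class__ = meta_param.__class__        # keep the (sub)class
    real.__dict__ = meta_param.__dict__.copy()   # carry over dynamic attributes
    return real


meta_param = nn.Parameter(torch.empty(4, 2, device="meta"), requires_grad=False)
meta_param.weight_loader = lambda p, w: p.data.copy_(w)   # dynamically mounted
meta_param.output_dim = 0

# Standard constructor: a valid parameter, but the mounted attributes are gone.
naive = nn.Parameter(torch.empty(4, 2), requires_grad=False)

hacked = materialize_with_hack(meta_param)
hacked.weight_loader(hacked, torch.ones(4, 2))
```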

Ideal Design

Based on defect analysis from Chapter 4, this section expands ideal design directions closely following defects. Core idea: introduce nn.Module base class undertaking recursive loading responsibilities (eliminating AutoWeightsLoader), internalize fusion mapping to fusion operators, remove weight_loader and other dynamic attributes from nn.Parameter, with all custom loading logic implemented by parameter owners (nn.Module derivatives) through load_weights implementation.

5.1 Eliminating AutoWeightsLoader: Introduce nn.Module Base Class (Addressing Defect 4.1)

Problem Review: As in 4.1, AutoWeightsLoader is a tool class independent from models, created by the model's load_weights, but it in turn calls submodules' load_weights, forming a bidirectional dependency. The existence of the module != self.module anti-recursion check itself indicates an unnatural design.

Root Cause: Recursive traversal of module tree and weight distribution should inherently be the module system's own capability, not undertaken by an external tool class.

Ideal Design: vLLMModule Base Class

Introduce a base class vLLMModule inheriting from nn.Module, internalizing AutoWeightsLoader's recursive distribution logic as the base class's default load_weights implementation:

class vLLMModule(nn.Module):
    """vLLM module base class, providing a default implementation for recursive weight loading."""

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        self._maybe_materialize()                     # ★ Before loading: materialize meta parameters
        weights = self._apply_fused_routing(weights)  # ★ Fusion routing

        child_modules = dict(self.named_children())
        child_params = dict(self.named_parameters(recurse=False))
        loaded: set[str] = set()

        for child_prefix, child_weights in self._groupby_prefix(weights):
            if child_prefix in child_modules:
                # Delegate to submodule, re-qualifying the names it reports
                names = child_modules[child_prefix].load_weights(child_weights)
                loaded.update(f"{child_prefix}.{name}" for name in names)
            elif child_prefix in child_params:
                # Leaf parameter owned directly by this module
                self._load_single_param(child_params[child_prefix], child_weights)
                loaded.add(child_prefix)

        self._maybe_post_process()                    # ★ After loading: quantization post-processing
        return loaded

Effect After Transformation

Before (bidirectional dependency, anti-recursion hack):

Qwen3ForCausalLM.load_weights()
  └─ Creates AutoWeightsLoader(self)          ← External tool class
       └─ if module != self.module:            ← Anti-recursion hack

After (single-direction inheritance, natural recursion):

Qwen3ForCausalLM(vLLMModule).load_weights()   ← Inherit base class, no override needed
  └─ Base class recursive distribution → child_module.load_weights() → natural polymorphism

Key Changes:

  1. Eliminated AutoWeightsLoader external tool class—recursive distribution is module system's own capability
  2. Eliminated anti-recursion hack—base class's load_weights only recurses to submodules, won't call itself
  3. Top-level model classes become extremely simple—most don't need to override load_weights, just inherit base class; only fusion linear layers, MoE layers etc. need override

5.2 Internalizing Fusion Mapping to Fusion Operators (Addressing Defect 4.2)

Problem Review: As in 4.2, fusion layers merge multiple checkpoint keys into one parameter, but mapping relationships are defined by each model file through stacked_params_mapping, causing nearly identical mapping table declarations across multiple model files.

Ideal Design: Fusion Operators Declare Mapping Relationships Themselves

Fusion operators know which sub-weights compose them and should declare this information themselves. Fusion layers override load_weights method, completing checkpoint key to shard_id mapping internally.

Fusion Routing Before Distribution: Fusion Routing in Recursive Scheduler Layer

The weight keys received by fusion layers in 5.2.2 are still checkpoint original names (like gate_proj.weight), but when base class matches submodules by prefix, only gate_up_proj exists, not gate_proj—routing would fail.

Solution: Before routing, base class's recursive scheduling logic automatically scans submodules' shard_names attributes, building fusion routing table—when checkpoint key prefix (like gate_proj) hits routing table, directly route that weight to corresponding fusion submodule (like gate_up_proj). Scheduling remains scheduling, processing remains processing—routing logic stays in recursive scheduler layer, fusion layer only receives weights and processes shard_id.
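A pure-Python sketch of such a routing table is shown below. Everything in it is hypothetical and belongs to the proposed design rather than to vLLM today: the shard_names class attribute, build_fusion_routes, and route_key are names invented for this illustration.

```python
class FusedGateUp:
    shard_names = ("gate_proj", "up_proj")   # declared by the fusion operator itself


class PlainLinear:
    pass


def build_fusion_routes(children: dict) -> dict:
    """Scan submodules and map each checkpoint shard name to its fused submodule."""
    routes = {}
    for child_name, child in children.items():
        for shard_name in getattr(child, "shard_names", ()):
            routes[shard_name] = child_name
    return routes


def route_key(key: str, routes: dict) -> tuple:
    """Return (target submodule, key handed to that submodule)."""
    prefix, _, rest = key.partition(".")
    if prefix in routes:
        # Fused: keep the original name so the fusion layer can infer shard_id.
        return routes[prefix], key
    return prefix, rest   # regular: strip the matched prefix


children = {"gate_up_proj": FusedGateUp(), "down_proj": PlainLinear()}
routes = build_fusion_routes(children)
fused_target = route_key("gate_proj.weight", routes)
plain_target = route_key("down_proj.weight", routes)
# Inside the fusion layer, shard_id can be derived from the declared order.
up_shard_id = FusedGateUp.shard_names.index("up_proj")
```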

Effect After Transformation

Before (mapping scattered in model layer):

# Every model file must repeatedly define
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ...
    ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1),
]
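# Model layer manually traverses, replaces keys, injects shard_id
# (runs inside the loop over (name, loaded_weight); params_dict maps names to parameters)
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name not in name:
        continue
    name = name.replace(weight_name, param_name)
    param = params_dict[name]
    param.weight_loader(param, loaded_weight, shard_id)
    break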

After (mapping internalized in fusion operators, routing automatically completed by base class recursive scheduler layer):

# Model layer no longer needs stacked_params_mapping
# Base class recursive scheduler layer automatically scans submodules' shard_names,
# builds fusion routing table
# Checkpoint keys routed to fusion layers via fusion routing table
# Fusion layer infers shard_id itself based on shard_names
# Attention/MLP and other upper modules need no override load_weights

Key Changes:

  1. Eliminated repeated stacked_params_mapping definitions in model files—mapping relationships declared by fusion layers themselves
  2. shard_id no longer externally injected—fusion layer infers it based on weight names
  3. Routing problems automatically solved by base class recursive scheduler layer—base class scans submodules' shard_names to build fusion routing table, automatically routing checkpoint keys to correct fusion submodules, upper modules need no extra code

5.3 Eliminating weight_loader on nn.Parameter (Addressing Defect 4.3)

Problem Review: As in 4.3, nn.Parameter is essentially a pure data container, but vLLM mounts weight_loader and other dynamic attributes onto native nn.Parameter via setattr, bypassing the type system. This causes problems at three levels: type unsafety (4.3.1), the v1/v2 version split (4.3.2), and meta materialization relying on a __class__ hack (4.3.3).

Ideal Design: Custom Logic Implemented by Parameter Owners

Core Principle: nn.Parameter should not have dynamic attributes bypassing type system. If a parameter needs custom loading logic, its owner (the nn.Module derivative holding that parameter) implements it through load_weights.

This naturally cooperates with vLLMModule base class introduced in 5.1—base class provides default recursive distribution and simple copy_ loading, subclasses implement custom logic through override load_weights:

Linear Layer's load_weights (illustrative):

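class ColumnParallelLinear(vLLMModule):
    def load_weights(self, weights):
        params = dict(self.named_parameters(recurse=False))
        for name, loaded_weight in weights:
            param = params[name]
            if isinstance(param, BasevLLMParameter):
                param.load_column_parallel_weight(loaded_weight)  # Parameter self-service sharding
            else:
                # Native nn.Parameter: the module handles TP sharding (narrow + copy_).
                # Assumes the sharded output dimension is dim 0 and self.tp_rank is set.
                shard_size = param.shape[0]
                shard = loaded_weight.narrow(0, self.tp_rank * shard_size, shard_size)
                param.data.copy_(shard)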

Other modules needing custom loading similarly override load_weights, implementing their own loading logic internally, rather than hanging weight_loader on parameters.

Panoramic View of Eliminating Dynamic Attributes

Once dynamic attributes are eliminated, weight_loader and every other attribute attached via setattr to native nn.Parameter are removed, and their responsibilities move to the owner module's load_weights. Native nn.Parameter returns to being a pure data container (__dict__ empty). The sharding metadata of BasevLLMParameter subclasses (_output_dim, _input_dim, tp_rank, and so on) is retained: these are the parameter's own inherent attributes, defined through formal constructors and @property, and are different in nature from the eliminated weight_loader (external scheduling logic "hung" on parameters).

Chain Benefit 1: v1/v2 Version Split Naturally Disappears

When all modules schedule weight loading through load_weights, neither v1 (external functions that manually narrow + copy_) nor v2 (BasevLLMParameter subclass methods mounted on parameters) is needed. The module's load_weights directly calls param.load_column_parallel_weight() and other self-service methods, and the WEIGHT_LOADER_V2_SUPPORTED whitelist disappears.

Chain Benefit 2: Meta Materialization Simplified

Once dynamic attributes on native nn.Parameter are gone, neither the __class__ assignment hack nor the __dict__ copy hack is needed. A native nn.Parameter's __dict__ is empty, so calling the nn.Parameter(real_data) constructor is sufficient. BasevLLMParameter subclasses are constructed formally through a newly added materialize_on method, which creates the instance on the real device via __new__ (skipping __init__ side effects) and inherits the sharding metadata. Materialization changes from a hack into a formal constructor call.
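
The materialize_on idea can be sketched without PyTorch. The class below is a toy stand-in, not a torch.Tensor subclass: ShardedParam, its fields, and the string device names are assumptions made for illustration, and a real tensor subclass would also have to allocate storage on the target device, which is omitted here. The sketch only shows the mechanism named in the text: allocate with __new__, skip __init__, copy the sharding metadata.

```python
class ShardedParam:
    """Toy stand-in for a BasevLLMParameter subclass (hypothetical)."""

    init_calls = 0  # counts __init__ runs so skipped side effects are visible

    def __init__(self, data, device, output_dim, tp_rank):
        type(self).init_calls += 1  # stands in for any __init__ side effect
        self.data = data
        self.device = device
        self._output_dim = output_dim
        self._tp_rank = tp_rank

    @property
    def output_dim(self):
        return self._output_dim

    def materialize_on(self, device):
        """Return a same-typed parameter on `device` with inherited metadata."""
        cls = type(self)
        new = cls.__new__(cls)  # allocate the instance without running __init__
        new.data = [0.0] * len(self.data)  # placeholder for real storage
        new.device = device
        new._output_dim = self._output_dim
        new._tp_rank = self._tp_rank
        return new


meta_param = ShardedParam([None] * 4, "meta", output_dim=0, tp_rank=1)
real_param = meta_param.materialize_on("cuda:0")
```

The result keeps its concrete type and sharding metadata without any __class__ reassignment or __dict__ copying.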

Chain Benefit 3: Materialization Logic Built into Recursive Flow, Unifying All Loading Scenarios

As shown in the base class skeleton code in 5.1.2, load_weights calls _maybe_materialize() before recursive distribution and _maybe_post_process() after distribution completes. Each module materializes its direct parameters at the start of load_weights (after checking whether they are on the meta device), and submodules materialize their own parameters. The recursion therefore guarantees the "materialize first, then load" order. Normal loading, Layerwise Reload, and Transformers Backend scenarios can then share one entry point and one recursive flow, branching internally on parameter state. No external orchestrator or weight_loader interception mechanism is needed, and because the any() check in _maybe_materialize() returns False immediately in non-Layerwise scenarios, the overhead is nearly zero.
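
The ordering argument can be checked with a framework-free toy. This is not the skeleton from 5.1.2: ToyModule, its dict-based parameters, and the default device string are assumptions for illustration. Only the method names load_weights, _maybe_materialize, and _maybe_post_process come from the text.

```python
from collections import defaultdict


class ToyModule:
    """Framework-free stand-in for the vLLMModule base class (hypothetical)."""

    def __init__(self, params=None, children=None):
        # name -> {"device": str, "value": object}
        self.params = params or {}
        self.children = children or {}
        self.post_processed = False

    def _maybe_materialize(self, device="cuda:0"):
        # Cheap check: return at once unless a direct parameter is on meta.
        if not any(p["device"] == "meta" for p in self.params.values()):
            return False
        for p in self.params.values():
            if p["device"] == "meta":
                p["device"] = device
        return True

    def _maybe_post_process(self):
        self.post_processed = True

    def load_weights(self, weights):
        self._maybe_materialize()  # direct parameters only
        per_child = defaultdict(list)
        for name, value in weights:
            head, _, rest = name.partition(".")
            if head in self.children:
                per_child[head].append((rest, value))
            else:
                self.params[name]["value"] = value  # default copy_-style load
        for child_name, child_weights in per_child.items():
            # Each child materializes itself before loading its own weights.
            self.children[child_name].load_weights(child_weights)
        self._maybe_post_process()


leaf = ToyModule(params={"weight": {"device": "meta", "value": None}})
root = ToyModule(
    params={"scale": {"device": "cuda:0", "value": None}},
    children={"proj": leaf},
)
root.load_weights([("scale", 1.0), ("proj.weight", 2.0)])
```

In this toy the same load_weights call serves both a module whose parameters already live on a real device (root) and one that starts on meta (leaf); the only difference is which branch _maybe_materialize takes.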

Transformation Path Summary

The ideal transformation addresses the three defects one by one, forming a coherent whole:

Defect 4.1: AutoWeightsLoader's anti-recursion and bidirectional dependency
└── Transformation 5.1: Introduce vLLMModule base class, internalizing recursive distribution as module system's own capability
    ├── Eliminate AutoWeightsLoader external tool class
    └── Eliminate anti-recursion hack

Defect 4.2: Fusion key mapping scattered in model layer
└── Transformation 5.2: Fusion layer overrides load_weights, declaring mapping relationships itself
    ├── Eliminate repeated stacked_params_mapping definitions in model files
    ├── shard_id inferred by fusion layer itself
    └── Fusion routing automatically completed by base class recursive scheduler layer

Defect 4.3: nn.Parameter bears responsibilities not its own
└── Transformation 5.3: Custom logic implemented by parameter owners (nn.Module subclasses) by overriding load_weights, replacing weight_loader
    ├── Eliminate dynamic attributes (weight_loader, output_dim, etc.) on native nn.Parameter
    ├── Chain: v1/v2 version split naturally disappears
    ├── Chain: Meta materialization simplified, eliminating __class__ hack
    └── Chain: Materialization logic built into recursive flow, unifying normal loading and Layerwise Reload

Core Principle: All weight loading logic is carried by nn.Module subclasses: the base class provides the default recursive distribution, and subclasses implement custom logic (TP sharding, fusion mapping, etc.) by overriding load_weights. nn.Parameter returns to being a pure data container that carries no loading or scheduling logic. The BasevLLMParameter system, as type-safe parameter subclasses with self-service sharding capabilities (load_column_parallel_weight, etc.), is a reasonable design and is not targeted for elimination. With materialization as an integral part of the recursive flow, normal loading, Layerwise Reload, and Transformers Backend scenarios follow the same code path, removing the dependency on the weight_loader interception mechanism.

Appendix: SGLang Weight Loading System Comparison

SGLang's weight loading system derives directly from vLLM, and its architecture is highly consistent: the four-stage loading flow is identical, and the core defects (weight_loader dynamically attached to nn.Parameter, stacked_params_mapping scattered across the model layer, and the v1/v2 version split) all exist. The ideal transformation plan above is therefore valuable for SGLang as well. Because SGLang hardly uses AutoWeightsLoader (only one file, transformers.py, uses it) and 43+ model files traverse weights manually, introducing a base class load_weights (Transformation 5.1) offers the biggest benefit. The load_weights methods in these model files are highly similar (30~130 lines each), so substantial code could be removed.

A significant difference from vLLM lies in meta device usage: SGLang's mainstream path (DefaultModelLoader) creates models directly on GPU devices and does not involve meta devices; meta devices appear only in two non-mainstream paths. SGLang therefore does not have vLLM's __class__ hack problem. LayeredModelLoader uses PyTorch's native to_empty() for per-module materialization and delegates weight filling to the model's own load_weights_to_module method, but currently only torch_native_llama.py implements this interface, and its logic duplicates load_weights. Adopting the ideal base class approach could unify the normal and layer-by-layer loading code paths and remove this extra interface.