LogoPierre Lepagnol, PhD

Notes on NLP, ML and applied data science

Back to all articles

Cross-LoRA - Partie 2 : l'implémentation

De l'algorithme du papier à un adaptateur PEFT qui se recharge dans le modèle cible.

·21 min read
Au sommaire

Dans la Partie 1 nous avons posé les briques théoriques de Cross-LoRA. Ici nous passons à la pratique ! En prenant l'Algorithm 1, pour en faire un pipeline complet qui produit un adaptateur PEFT chargeable dans le modèle cible.

Tout au long de l'article nous prendrons l'exemple de 2 modèles: Qwen2.5-1.5B \rightarrow LLaMA-3.2-3B.

Tout le code est disponible sur Lien du GitHub de l'implémentation.

Ce que le papier ne précise pas

L'article Cross-LoRA, aussi bien écrit soit-il, ne précise pas des composants très opérationnels à savoir:

  • le nommage et matching des modules entre familles. En effet, les modèles Qwen, LLaMA, Gemma n'ont pas les mêmes conventions;
  • le mapping inter-couches quand les profondeurs diffèrents (24 vs 32 couches) ;
  • le surcharge du symbole rr : rLoRA=16\rlora = 16 (rang LoRA) vs rSVD=320\rsvd = 320 (rang de sous-espace). Il y a deux quantités sans rapport ;
  • le format exact du checkpoint PEFT à produire en sortie.

Ainsi, dès le départ, nous prendrons les décisions suivantes pour ce pipeline :

  • séparer rLoRA\rlora et rSVD\rsvd partout dans le code (deux variables nommées) ;
  • Algorithm 1 comme référence d'implémentation (et non l'équation globale) ;
  • matching par index de couche + type de module, fallback proportionnel sur la profondeur si les nombres de couches diffèrent;
  • sérialisation au format peft (Hugging Face) standard, avec adapter_config.json + adapter_model.safetensors.

Pipeline en 7 étapes

Pipeline Cross-LoRA en 7 étapes
Pipeline Cross-LoRA en 7 étapes

Chaque étape est isolée : un échec à l'étape matcher doit produire un rapport avec une erreur expliquée et non un crash trois étapes plus loin.

Étape 1: I/O des checkpoints

Inputs, Indexation & Outputs

Nous devons charger en mémoire, pour chaque couche, les poids source Ws\Ws et cible Wt\Wt, ainsi que l'adaptateur source lora_A, lora_B, α\alpha, rLoRA\rlora, et la liste des modules visés.

On indexe chaque tenseur par un triplet (layer_id, module_type, side)side prend la valeur "out" (côté BB, espace de sortie) ou "in" (côté AA, espace d'entrée). Ce schéma permet ensuite de joindre source et cible sans dépendre du nommage propre à chaque famille de modèles.

@dataclass(frozen=True)
class TensorKey:
    layer_id: int
    module_type: str   # "q_proj", "k_proj", ..., "down_proj"
    side: str          # "in" ou "out"

En sortie nous avons besoin d'un répertoire au format PEFT compatible from_pretrained :

adapter/
├── adapter_config.json      # rang, alpha, target_modules, base_model_name_or_path
└── adapter_model.safetensors # tenseurs nommés selon la convention PEFT cible

Étape 2: Matching des modules & Normalisation

Comme cité précédemment, les conventions de nommage diffèrent entre les modèles.

FamilleAttentionMLP
LLaMAq_proj, k_proj, v_proj, o_projgate_proj, up_proj, down_proj
Qwenq_proj, k_proj, v_proj, o_proj (avec biais sur QKV)gate_proj, up_proj, down_proj
Gemmaq_proj, k_proj, v_proj, o_projgate_proj, up_proj, down_proj

Bonne nouvelle : pour les sept modules visés, les noms convergent. Les pièges sont ailleurs : présence/absence de biais (Qwen QKV), RMSNorm vs LayerNorm, têtes groupées (GQA) qui changent les dimensions de kprojk_{\text{proj}} et vprojv_{\text{proj}} sans changer le nom.

Sept modules cibles

TARGET_MODULES = (
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
)

Stratégie de matching

def match_layers(n_src: int, n_tgt: int) -> list[tuple[int, int]]:
    """Apparie les couches source et cible.
 
    Cas 1 : profondeurs égales -> appariement direct.
    Cas 2 : profondeurs différentes -> appariement proportionnel.
    """
    if n_src == n_tgt:
        return [(i, i) for i in range(n_src)]
    return [
        (i, round(i * (n_tgt - 1) / max(n_src - 1, 1)))
        for i in range(n_src)
    ]

Rapport de couverture

Règle : ne jamais publier un adaptateur partiel sans signaler le taux de couverture. Un module manquant est un trou silencieux dans le transfert.

Le rapport inclut :

  • nombre de modules transférés / attendus ;
  • liste des modules ignorés et raison (dimension incohérente, module absent, etc.) ;
  • résidus moyens et max d'alignement (cf. § 6) ;
  • énergie SVD conservée par couche (cf. § 5).

Extraction des sous-espaces (SVD)

SVD tronquée

def truncated_svd(W: torch.Tensor, r_svd: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Renvoie (U, S, V) tronqués au rang r_svd."""
    U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
    r = min(r_svd, S.numel())
    return U[:, :r].contiguous(), S[:r].contiguous(), Vh[:r, :].T.contiguous()

Notes :

  • on caste en float32 avant la SVD : bfloat16 est trop bruité pour les petites valeurs singulières ;
  • on renvoie VV (et non VV^\top) pour rester cohérent avec la convention WUΣVW \approx U \Sigma V^\top du papier.

Cache par poids de base

Sans cache, on recalcule la même SVD pour les côtés AA et BB du même module - et plus généralement à chaque appel. Avec un dictionnaire indexé par id(W) ou par hash du tenseur, on divise le temps de transfert par 2 environ.

@functools.lru_cache(maxsize=None)
def svd_cached(weight_id: int, r_svd: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    W = WEIGHT_REGISTRY[weight_id]
    return truncated_svd(W, r_svd)

Diagnostics SVD

Avant d'utiliser une SVD, vérifier :

  • énergie conservée : i=1rSVDSi2/iSi2\sum_{i=1}^{\rsvd} S_i^2 \,/\, \sum_i S_i^2 - typiquement >0.95> 0.95 sur les couches d'attention pour rSVD=320\rsvd = 320, modèles 1.5B–3B ;
  • orthonormalité : UUIF\lVert U^\top U - I \rVert_F proche de zéro ;
  • shapes : URm×rSVDU \in \R^{m \times \rsvd}, VRn×rSVDV \in \R^{n \times \rsvd}.
Énergie SVD cumulée par rang
Énergie SVD cumulée par rang

Alignement (moindres carrés)

Résolution

On cherche les matrices d'alignement PURmt×ms\PU \in \R^{m_t \times m_s} et PVRnt×ns\PV \in \R^{n_t \times n_s} solutions de :

PU=argminPPUsUtF2,PV=argminPPVsVtF2.\PU = \operatorname*{argmin}_{P} \, \lVert P\, \Us - \Ut \rVert_F^2, \qquad \PV = \operatorname*{argmin}_{P} \, \lVert P\, \Vs - \Vt \rVert_F^2.
def solve_alignment(U_s: torch.Tensor, U_t: torch.Tensor) -> torch.Tensor:
    """Résout argmin_P ||P U_s - U_t||_F^2.
 
    U_s ∈ R^(m_s × r_svd), U_t ∈ R^(m_t × r_svd).
    Renvoie P ∈ R^(m_t × m_s).
    """
    # lstsq résout argmin_X ||A X - B||_F^2 ; on pose A = U_sᵀ, B = U_tᵀ.
    sol = torch.linalg.lstsq(U_s.T, U_t.T)
    return sol.solution.T.contiguous()

On applique la même fonction à (Vs,Vt)(\Vs, \Vt) pour obtenir PV\PV.

Construire les bases alignées

Les bases source ré-exprimées dans les dimensions cible sont :

U~s=PUUs    Rmt×rSVD,V~s=PVVs    Rnt×rSVD.\widetilde{U}_s = \PU\, \Us \;\in\; \R^{m_t \times \rsvd}, \qquad \widetilde{V}_s = \PV\, \Vs \;\in\; \R^{n_t \times \rsvd}.
U_tilde_s = P_U @ U_s          # ∈ R^(m_t × r_svd)
V_tilde_s = P_V @ V_s          # ∈ R^(n_t × r_svd)

Monitorer les résidus

Deux quantités à logguer pour chaque couche :

resU  =  PUUsUtFUtF,resV  =  PVVsVtFVtF.\mathrm{res}_U \;=\; \frac{\lVert \PU\, \Us - \Ut \rVert_F}{\lVert \Ut \rVert_F}, \qquad \mathrm{res}_V \;=\; \frac{\lVert \PV\, \Vs - \Vt \rVert_F}{\lVert \Vt \rVert_F}.
res_U = torch.linalg.norm(P_U @ U_s - U_t) / torch.linalg.norm(U_t)
res_V = torch.linalg.norm(P_V @ V_s - V_t) / torch.linalg.norm(V_t)

Empiriquement (Qwen2.5-1.5B \rightarrow LLaMA-3.2-3B) :

  • resU\mathrm{res}_U typique : 0.40.40.70.7 sur les couches d'attention, plus haut sur les MLP ;
  • une couche au-dessus de 0.950.95 en résidu relatif est suspecte - le transfert sur cette couche n'apportera rien.
Heatmap des résidus d'alignement
Heatmap des résidus d'alignement

Projection LoRA - Algorithm 1 complet

Le cœur de la projection s'écrit, pour chaque module aligné :

Bt  =  U~sU~sBs,At  =  AsV~sV~s.\Bt \;=\; \widetilde{U}_s\, \widetilde{U}_s^{\top}\, \Bs, \qquad \At \;=\; \As\, \widetilde{V}_s\, \widetilde{V}_s^{\top}.

avec AsRrLoRA×ns\As \in \R^{\rlora \times n_s}, BsRms×rLoRA\Bs \in \R^{m_s \times \rlora}, et en sortie AtRrLoRA×nt\At \in \R^{\rlora \times n_t}, BtRmt×rLoRA\Bt \in \R^{m_t \times \rlora}.

Pseudo-code

Entrées :
  - W_s, W_t           # poids des deux modèles, indexés par (layer_id, module_type)
  - (A_s, B_s, alpha)  # adaptateur source
  - r_svd              # rang de troncature SVD
Sortie :
  - (A_t, B_t, alpha)  # adaptateur projeté pour le modèle cible

Pour chaque (layer_src, layer_tgt) dans le matching de couches :
  Pour chaque module_type dans TARGET_MODULES :
      W_s_lm = W_s[layer_src, module_type]
      W_t_lm = W_t[layer_tgt, module_type]
      A_s_lm = A_s[layer_src, module_type]    # ∈ R^(r_lora × n_s)
      B_s_lm = B_s[layer_src, module_type]    # ∈ R^(m_s × r_lora)

      # 1. SVD tronquée (avec cache)
      U_s, _, V_s = svd_cached(W_s_lm, r_svd)
      U_t, _, V_t = svd_cached(W_t_lm, r_svd)

      # 2. Alignement
      P_U = solve_alignment(U_s, U_t)
      P_V = solve_alignment(V_s, V_t)
      U_tilde_s = P_U @ U_s
      V_tilde_s = P_V @ V_s

      # 3. Projection (Algorithm 1)
      B_t_lm = U_tilde_s @ (U_tilde_s.T @ B_s_lm)
      A_t_lm = (A_s_lm @ V_tilde_s) @ V_tilde_s.T

      # 4. Vérification de shapes
      assert B_t_lm.shape == (m_t, r_lora)
      assert A_t_lm.shape == (r_lora, n_t)

      A_t[layer_tgt, module_type] = A_t_lm
      B_t[layer_tgt, module_type] = B_t_lm

Implémentation PyTorch (cœur)

def project_lora_pair(
    W_s: torch.Tensor, W_t: torch.Tensor,
    A_s: torch.Tensor, B_s: torch.Tensor,
    r_svd: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    U_s, _, V_s = svd_cached(id(W_s), r_svd)
    U_t, _, V_t = svd_cached(id(W_t), r_svd)
 
    P_U = solve_alignment(U_s, U_t)
    P_V = solve_alignment(V_s, V_t)
 
    U_tilde_s = P_U @ U_s
    V_tilde_s = P_V @ V_s
 
    B_t = U_tilde_s @ (U_tilde_s.T @ B_s.float())
    A_t = (A_s.float() @ V_tilde_s) @ V_tilde_s.T
 
    return A_t, B_t

Orchestrateur complet

Voici l'enchaînement complet, en pseudo-Python lisible. Il suffit de coller les fonctions précédentes pour obtenir une version exécutable.

def cross_lora_transfer(
    base_src: PreTrainedModel,
    base_tgt: PreTrainedModel,
    adapter_src: PeftModel,
    r_svd: int = 320,
    target_modules: tuple[str, ...] = TARGET_MODULES,
) -> dict[TensorKey, torch.Tensor]:
 
    layer_pairs = match_layers(
        n_src=base_src.config.num_hidden_layers,
        n_tgt=base_tgt.config.num_hidden_layers,
    )
 
    out: dict[TensorKey, torch.Tensor] = {}
    coverage = {"matched": 0, "skipped": []}
 
    for layer_src, layer_tgt in layer_pairs:
        for mod in target_modules:
            try:
                W_s = get_weight(base_src, layer_src, mod)
                W_t = get_weight(base_tgt, layer_tgt, mod)
                A_s = get_lora_A(adapter_src, layer_src, mod)
                B_s = get_lora_B(adapter_src, layer_src, mod)
            except KeyError as e:
                coverage["skipped"].append((layer_src, mod, str(e)))
                continue
 
            A_t, B_t = project_lora_pair(W_s, W_t, A_s, B_s, r_svd)
 
            out[TensorKey(layer_tgt, mod, "in")]  = A_t.to(base_tgt.dtype)
            out[TensorKey(layer_tgt, mod, "out")] = B_t.to(base_tgt.dtype)
            coverage["matched"] += 1
 
    log_coverage(coverage, total=len(layer_pairs) * len(target_modules))
    return out

Trois choses à remarquer :

  • la boucle est plate : couches ×\times modules. Toute logique propre à une famille (biais Qwen, GQA) doit être absorbée par get_weight / get_lora_* - pas injectée dans la boucle ;
  • les exceptions de matching sont comptées, pas levées, pour produire un rapport de couverture au lieu d'un crash ;
  • le cast vers base_tgt.dtype est fait en sortie uniquement.

Encadré - biais Qwen sur QKV

Qwen2.5 attache des biais à q_proj, k_proj, v_proj que LLaMA-3.2 et Gemma-2 n'ont pas. Trois options pratiques :

  1. Ignorer les biais source quand la cible n'en a pas (option par défaut, simple et sûre). Conséquence : on perd la composante de translation apprise côté source - mineur en général.
  2. Les projeter comme des vecteurs dans l'espace de sortie via U~sU~s\widetilde{U}_s \widetilde{U}_s^{\top} appliqué côté gauche, et les attacher si la cible accepte un biais.
  3. Les rejeter avec un avertissement quand la cible n'en a pas et que le norme du biais source est non négligeable.

Ce choix doit apparaître dans le rapport de couverture.

Packaging final

  • caster vers le dtype du modèle cible (bfloat16 typiquement) après les calculs en float32 ;
  • réinjecter α\alpha (par défaut on garde l'α\alpha source - voir Table 5 du papier qui suggère parfois α=64\alpha = 64 côté cible) ;
  • écrire adapter_config.json avec target_modules = TARGET_MODULES, r = rLoRA\rlora, lora_alpha = α\alpha, base_model_name_or_path = <modèle cible> ;
  • sérialiser les tenseurs au format safetensors.

Validation à trois niveaux

À faire dans cet ordre - chaque niveau est moins coûteux à corriger que le suivant.

Niveau 1 - Algébrique

Pour chaque module transféré :

  • shapes attendues respectées ;
  • dtype final correct ;
  • pas de NaN / Inf ;
  • normes AtF\lVert \At \rVert_F et BtF\lVert \Bt \rVert_F du même ordre que côté source (un facteur 100100 est suspect).

Niveau 2 - Compatibilité

  • l'adaptateur se recharge via PeftModel.from_pretrained(base_model, path) sans avertissement de tenseurs manquants ;
  • tous les modules attendus sont présents dans la liste retournée par PEFT ;
  • une passe avant sur un batch jouet ne plante pas.

Niveau 3 - Papier

Reproduire au moins Qwen2.5-1.5B \rightarrow LLaMA-3.2-3B sur ARC-C ou OBQA, et comparer à :

  • le modèle cible nu (sans adaptateur) ;
  • une interpolation naïve des poids LoRA (zéro-padding quand les dimensions permettent).

Cibles indicatives sur ARC-C : Cross-LoRA doit être strictement au-dessus du modèle nu et de l'interpolation naïve, sans atteindre un fine-tuning natif sur la cible - c'est cohérent avec le papier.

Ablations utiles

  • rSVD{80,160,320}\rsvd \in \{80, 160, 320\} - pour voir où la qualité plafonne ;
  • attention seule vs attention + MLP ;
  • avec / sans cache SVD (impact temps, pas qualité).
ARC-C en fonction de r_svd
ARC-C en fonction de r_svd

Coût et reproductibilité

Hardware testé dans le papier

  • V100 32 GB et RTX 4090 24 GB ;
  • CUDA 11.8, Ubuntu 22.04.1.

Ordres de grandeur

MétriqueV100RTX 4090
Temps transfert~349 s~564 s
Pic mémoire2.3–5.5 GB2.3–5.5 GB

Recette LoRA source (Table 4 du papier)

À reproduire si on entraîne le LoRA source de zéro :

  • batch =16= 16, optimiseur AdamW ;
  • lr =105= 10^{-5}, weight_decay =0.1= 0.1, max_grad_norm =1.0= 1.0 ;
  • rLoRA=16\rlora = 16, α=32\alpha = 32, lora_dropout =0.1= 0.1, bias = None ;
  • steps =600= 600 (ARC-C / ARC-E / OBQA), 300300 (HellaSwag).

Config Cross-LoRA (Table 5)

  • rSVD=320\rsvd = 320 ;
  • α=64\alpha = 64 côté cible ;
  • modules cibles : q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj.
Hyperparamètres LoRA source vs Cross-LoRA
Hyperparamètres LoRA source vs Cross-LoRA

Pièges à éviter

  • Lire r=320r = 320 comme un rang LoRA : c'est le rang de sous-espace (rSVD\rsvd), à ne pas confondre avec rLoRA\rlora.
  • Implémenter l'équation globale au lieu d'Algorithm 1 : équivalent en théorie, fragile dès qu'il y a un mismatch dimensionnel.
  • Oublier de cacher les SVD : on recalcule deux fois la même décomposition pour AA et BB du même module.
  • Publier un adaptateur partiel sans signaler le taux de couverture.
  • Mélanger les conventions gauche/droite entre lora_A et lora_B - symptôme typique : shapes correctes mais qualité dégradée à zéro.
  • SVD en bfloat16 : les petites valeurs singulières deviennent du bruit pur. Toujours caster en float32 avant.
  • Comparer à un mauvais baseline : le LoRA source ne tourne pas sur la cible ; les seules baselines honnêtes sont base nue et interpolation naïve.

Résultats indicatifs (fil rouge Qwen2.5-1.5B \rightarrow LLaMA-3.2-3B)

Ordres de grandeur à attendre une fois le pipeline en place. À titre indicatif, inspiré des chiffres du papier - vos résultats varieront selon rSVD\rsvd, la recette LoRA source, et le dtype.

ConfigurationARC-CCommentaire
LLaMA-3.2-3B nuréférenceborne basse
LLaMA-3.2-3B + LoRA Qwen interpolé naïf\approx référenceles dimensions ne coïncident pas, l'interpolation est essentiellement bruit
LLaMA-3.2-3B + Cross-LoRA (rSVD=320\rsvd = 320)au-dessusgain modeste mais reproductible
LLaMA-3.2-3B + LoRA natif (réentraîné)borne hautenécessite des données

Cross-LoRA n'est pas magique : il préserve une part du gain du fine-tuning source. C'est suffisant quand les données ne sont plus accessibles et qu'on veut éviter une journée de calcul.

Résultats ARC-C par configuration
Résultats ARC-C par configuration

Checklist finale

À cocher avant de publier un adaptateur transféré :

  • rLoRA\rlora et rSVD\rsvd séparés explicitement dans le code et dans la config.
  • SVD calculée en float32, cache activé.
  • Énergie SVD conservée >0.9> 0.9 sur les couches utilisées (sinon : augmenter rSVD\rsvd ou exclure la couche).
  • Résidus d'alignement loggés ; couches au-dessus de 0.950.95 en résidu relatif signalées dans le rapport.
  • Couverture des modules 100%\geq 100\,\% attendu, ou écart documenté.
  • Biais source traités explicitement (cf. § 7.4).
  • Adaptateur rechargeable via PeftModel.from_pretrained sans warning.
  • Évaluation sur au moins un benchmark, comparée à base nue et interpolation naïve.
  • Hyperparamètres (rSVD\rsvd, α\alpha, target_modules, dtype) inscrits dans adapter_config.json ou dans une note adjacente.

Exemple d'appel end-to-end

Une fois le pipeline implémenté, l'usage côté utilisateur tient en quelques lignes :

from transformers import AutoModelForCausalLM
from peft import PeftModel
from cross_lora import cross_lora_transfer, save_peft_adapter
 
base_src = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B")
base_tgt = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")
adapter_src = PeftModel.from_pretrained(base_src, "path/to/qwen-arc-lora")
 
projected = cross_lora_transfer(
    base_src=base_src,
    base_tgt=base_tgt,
    adapter_src=adapter_src,
    r_svd=320,
)
 
save_peft_adapter(
    tensors=projected,
    out_dir="./llama-3b-arc-from-qwen",
    base_model_name="meta-llama/Llama-3.2-3B",
    r_lora=16,
    alpha=64,
    target_modules=("q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"),
)

À partir de là, l'adaptateur ./llama-3b-arc-from-qwen/ est chargeable directement comme n'importe quel adaptateur PEFT - y compris dans des inférences vLLM ou text-generation-inference.

Quand ne pas utiliser Cross-LoRA

Cross-LoRA n'est pas la bonne réponse à toutes les situations :

  • Vous avez encore les données : ré-entraînez. C'est plus rapide, plus fidèle, et ça évite l'intégralité du pipeline ci-dessus.
  • Source et cible sont identiques : un simple chargement PEFT suffit. Aucune projection nécessaire.
  • Cible d'architecture exotique (SSM, Mamba, MoE non standard) : l'hypothèse de sous-espaces partagés ne tient plus de la même manière. Mieux vaut re-fine-tuner même sur peu de données.
  • Adaptateur source DoRA / rsLoRA sans avoir adapté le pipeline à leur structure : vous obtiendrez un adaptateur qui se charge mais perd la composante de magnitude / le facteur d'échelle correct.
  • Vous cherchez la qualité maximale absolue : la borne haute reste un fine-tuning natif. Cross-LoRA vise la portabilité, pas la SOTA.

Conclusion

Ce que Cross-LoRA achète : la portabilité de l'investissement de fine-tuning, sans données ni ré-entraînement. Le coût d'un transfert se compte en minutes sur GPU grand public, pas en heures de calcul ni en jeux de données à reconstituer.

Ce qui reste ouvert :

  • familles très éloignées - l'hypothèse de sous-espaces partagés faiblit ;
  • profondeurs très différentes - le matching proportionnel est une heuristique, pas une garantie ;
  • modules non standards (MoE, GQA agressifs, attention factorisée) - il faut étendre la liste TARGET_MODULES au cas par cas et adapter les vérifications de dimensions.

Pour la suite : on peut imaginer combiner Cross-LoRA avec un léger ajustement post-transfert (quelques centaines de pas sur un petit corpus) pour récupérer le gain manquant - mais ce n'est plus zero-data, c'est un autre régime.