Fix get_lora_param_count with model-specific per-component table #695
Conversation
Mirrors `create_lora_training_client`'s signature so callers can predict
the exact trainable parameter count for any Tinker base model:
```
get_lora_param_count(model_name, lora_rank=32,
                     train_mlp=True, train_attn=True, train_unembed=True)
```
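For example (the model name here is illustrative, not a confirmed Tinker model ID):
```
# Illustrative call; the model name is a placeholder.
n = get_lora_param_count("Qwen/Qwen3-8B", lora_rank=16, train_unembed=False)
print(f"trainable LoRA params: {n:,}")
```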
Removes the previous implementation, which parsed HuggingFace
safetensors headers: it was off by 10x or more for many models, required
network access and HF auth on gated repos, and was often extremely slow.
```
total = lora_rank * (mlp * train_mlp + attn * train_attn
                     + unembed * train_unembed
                     + mlp_attn_extra * (train_mlp and train_attn))
```
The `mlp_attn_extra` term is NemotronH-specific: SSM/Mamba layers are
LoRA-adapted only when both `train_mlp` and `train_attn` are True.
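A minimal sketch of the lookup-based approach, assuming a hardcoded per-rank component table; the model key and numbers below are placeholders, not real entries:
```
# Sketch only: the model key and per-rank counts are placeholders.
_PARAMS_PER_RANK_BY_COMPONENT = {
    "example/model-8b": {
        "mlp": 1_000_000, "attn": 400_000,
        "unembed": 250_000, "mlp_attn_extra": 0,
    },
}

def get_lora_param_count(model_name, lora_rank=32, train_mlp=True,
                         train_attn=True, train_unembed=True):
    c = _PARAMS_PER_RANK_BY_COMPONENT[model_name]
    return lora_rank * (c["mlp"] * train_mlp
                        + c["attn"] * train_attn
                        + c["unembed"] * train_unembed
                        + c["mlp_attn_extra"] * (train_mlp and train_attn))
```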
Removed kwargs: `detailed`, `include_experts`,
`shared_expert_outer_loras`.
Tests verify the function reproduces an independently-maintained
reference table for all 31 Tinker base models across 7 ranks and 7
valid flag combinations (217 parametrizations × 7 ranks).
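The 7 valid flag combinations presumably exclude the all-False case (nothing would be trained); a quick sketch of that enumeration:
```
from itertools import product

# All (train_mlp, train_attn, train_unembed) combos except all-False.
FLAG_COMBOS = [c for c in product([True, False], repeat=3) if any(c)]
assert len(FLAG_COMBOS) == 7  # 2**3 - 1
```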
Next up: updating docs to reflect the new signature.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Changed the title: get_lora_param_count with hardcoded per-component table → get_lora_param_count with model-specific per-component table
Maybe add a test to validate that all models in model_info support get_lora_param_count?
hmm I like this idea! but the problem is it has a bunch of models that aren't actually on Tinker:
For the two Nemotron-H models, derive per-component (mlp / attn / unembed) param counts by partitioning the all-True adapter's tensors into groups by sub-layer name, rather than measuring each flag combination directly. The result is additive across all 7 valid flag combinations, matching every other model in the table.

This changes the values returned by `get_lora_param_count` for partial flag combinations on the two Nemotron-H models; the default (all-True) total is unchanged. Drops the now-unused `mlp_attn_extra` branch and its corresponding key in the lookup dict.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
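A rough sketch of the partitioning idea this commit describes; the name patterns and the bucket chosen for SSM/Mamba tensors are assumptions, not the actual implementation:
```
from collections import defaultdict

def partition_param_counts(adapter_tensors):
    """Bucket the all-True adapter's tensors by sub-layer name.

    The name patterns below are assumed for illustration; the real
    checkpoint's sub-layer names may differ.
    """
    buckets = defaultdict(int)
    for name, tensor in adapter_tensors.items():
        if "lm_head" in name or "embed" in name:
            component = "unembed"
        elif "attn" in name:
            component = "attn"
        else:
            # MLP plus SSM/Mamba sub-layers; assigning SSM here is an
            # assumption about the chosen partition.
            component = "mlp"
        buckets[component] += tensor.numel()
    return dict(buckets)
```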
LGTM after @nealwu walked me through his measurement methodology (I also ran some to validate); mirroring the class signature is quite neat. So we need to update this list for every model release? Also, do note that if developers and researchers want to compute the number of params themselves and cross-reference, we could let them know LoRA counts for MoEs use shared-outer-LoRA for experts.
```
@pytest.mark.parametrize("model_name", sorted(_REFERENCE_PARAMS_PER_RANK.keys()))
def test_get_lora_param_count_matches_measurements(
    self, model_name: str, flag_combo: tuple[bool, bool, bool]
) -> None:
```
nit: this test and `_LORA_PARAMS_PER_RANK_BY_COMPONENT` seem to come from the same measurement, so it's mostly testing adding up the components. Is that the goal of the test? That won't cover any actual LoRA count measurement issues.
The difference is the test doesn't assume addition of components and actually measures each configuration, so if there are any inconsistencies (like we saw here) the test would help pick it up.
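To make that concrete, a sketch of the check, assuming a reference table of independently measured counts keyed by flag combination (all names here are hypothetical):
```
# Hypothetical reference shape: one independently measured count per
# flag combination, not derived from the component table.
def check_against_measurements(measured, model_name, lora_rank):
    for (train_mlp, train_attn, train_unembed), expected in measured.items():
        got = get_lora_param_count(model_name, lora_rank=lora_rank,
                                   train_mlp=train_mlp,
                                   train_attn=train_attn,
                                   train_unembed=train_unembed)
        assert got == expected
```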
(And the note on "shared-outer-LoRA for experts" sounds good; I'll update the docs to explain this and have you help double-check it!)