Skip to content

Commit 76133f8

Browse files
nealwu and claude
authored
Fix get_lora_param_count with model-specific per-component table (#695)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent d881771 commit 76133f8

2 files changed

Lines changed: 425 additions & 55 deletions

File tree

tests/downstream_compat/test_cli_and_hyperparam.py

Lines changed: 340 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,306 @@
33
Validates that CLI utilities and hyperparameter functions remain stable.
44
"""
55

6+
import pytest
7+
68
from tinker_cookbook.cli_utils import check_log_dir
79
from tinker_cookbook.hyperparam_utils import (
810
get_lora_lr_over_full_finetune_lr,
911
get_lora_param_count,
1012
get_lr,
1113
)
1214

15+
# Independently-measured rank=1 LoRA parameter counts for every Tinker base
16+
# model under every valid combination of train_mlp / train_attn / train_unembed.
17+
# Keys are (train_mlp, train_attn, train_unembed). The all-False case is
18+
# excluded because Tinker rejects it.
19+
#
20+
# To refresh these values (e.g., when a new model ships), ask a Tinker team
21+
# member to re-run the measurement script.
22+
_REFERENCE_PARAMS_PER_RANK: dict[str, dict[tuple[bool, bool, bool], int]] = {
23+
"Qwen/Qwen3-235B-A22B-Instruct-2507": {
24+
(True, True, True): 59_931_008,
25+
(True, True, False): 59_774_976,
26+
(True, False, True): 56_754_560,
27+
(True, False, False): 56_598_528,
28+
(False, True, True): 3_332_480,
29+
(False, True, False): 3_176_448,
30+
(False, False, True): 156_032,
31+
},
32+
"Qwen/Qwen3-30B-A3B": {
33+
(True, True, True): 15_440_256,
34+
(True, True, False): 15_286_272,
35+
(True, False, True): 14_604_672,
36+
(True, False, False): 14_450_688,
37+
(False, True, True): 989_568,
38+
(False, True, False): 835_584,
39+
(False, False, True): 153_984,
40+
},
41+
"Qwen/Qwen3-30B-A3B-Base": {
42+
(True, True, True): 15_440_256,
43+
(True, True, False): 15_286_272,
44+
(True, False, True): 14_604_672,
45+
(True, False, False): 14_450_688,
46+
(False, True, True): 989_568,
47+
(False, True, False): 835_584,
48+
(False, False, True): 153_984,
49+
},
50+
"Qwen/Qwen3-30B-A3B-Instruct-2507": {
51+
(True, True, True): 15_440_256,
52+
(True, True, False): 15_286_272,
53+
(True, False, True): 14_604_672,
54+
(True, False, False): 14_450_688,
55+
(False, True, True): 989_568,
56+
(False, True, False): 835_584,
57+
(False, False, True): 153_984,
58+
},
59+
"Qwen/Qwen3-32B": {
60+
(True, True, True): 8_545_664,
61+
(True, True, False): 8_388_608,
62+
(True, False, True): 6_055_296,
63+
(True, False, False): 5_898_240,
64+
(False, True, True): 2_647_424,
65+
(False, True, False): 2_490_368,
66+
(False, False, True): 157_056,
67+
},
68+
"Qwen/Qwen3-4B-Instruct-2507": {
69+
(True, True, True): 2_218_880,
70+
(True, True, False): 2_064_384,
71+
(True, False, True): 1_481_600,
72+
(True, False, False): 1_327_104,
73+
(False, True, True): 891_776,
74+
(False, True, False): 737_280,
75+
(False, False, True): 154_496,
76+
},
77+
"Qwen/Qwen3-8B": {
78+
(True, True, True): 2_883_968,
79+
(True, True, False): 2_727_936,
80+
(True, False, True): 1_925_504,
81+
(True, False, False): 1_769_472,
82+
(False, True, True): 1_114_496,
83+
(False, True, False): 958_464,
84+
(False, False, True): 156_032,
85+
},
86+
"Qwen/Qwen3-8B-Base": {
87+
(True, True, True): 2_883_968,
88+
(True, True, False): 2_727_936,
89+
(True, False, True): 1_925_504,
90+
(True, False, False): 1_769_472,
91+
(False, True, True): 1_114_496,
92+
(False, True, False): 958_464,
93+
(False, False, True): 156_032,
94+
},
95+
"Qwen/Qwen3-VL-235B-A22B-Instruct": {
96+
(True, True, True): 59_931_008,
97+
(True, True, False): 59_774_976,
98+
(True, False, True): 56_754_560,
99+
(True, False, False): 56_598_528,
100+
(False, True, True): 3_332_480,
101+
(False, True, False): 3_176_448,
102+
(False, False, True): 156_032,
103+
},
104+
"Qwen/Qwen3-VL-30B-A3B-Instruct": {
105+
(True, True, True): 15_440_256,
106+
(True, True, False): 15_286_272,
107+
(True, False, True): 14_604_672,
108+
(True, False, False): 14_450_688,
109+
(False, True, True): 989_568,
110+
(False, True, False): 835_584,
111+
(False, False, True): 153_984,
112+
},
113+
"Qwen/Qwen3.5-27B": {
114+
(True, True, True): 7_544_320,
115+
(True, True, False): 7_290_880,
116+
(True, False, True): 4_578_816,
117+
(True, False, False): 4_325_376,
118+
(False, True, True): 3_218_944,
119+
(False, True, False): 2_965_504,
120+
(False, False, True): 253_440,
121+
},
122+
"Qwen/Qwen3.5-35B-A3B": {
123+
(True, True, True): 17_545_728,
124+
(True, True, False): 17_295_360,
125+
(True, False, True): 16_531_968,
126+
(True, False, False): 16_281_600,
127+
(False, True, True): 1_264_128,
128+
(False, True, False): 1_013_760,
129+
(False, False, True): 250_368,
130+
},
131+
"Qwen/Qwen3.5-397B-A17B": {
132+
(True, True, True): 99_124_736,
133+
(True, True, False): 98_872_320,
134+
(True, False, True): 96_283_136,
135+
(True, False, False): 96_030_720,
136+
(False, True, True): 3_094_016,
137+
(False, True, False): 2_841_600,
138+
(False, False, True): 252_416,
139+
},
140+
"Qwen/Qwen3.5-4B": {
141+
(True, True, True): 2_278_400,
142+
(True, True, False): 2_027_520,
143+
(True, False, True): 1_381_376,
144+
(True, False, False): 1_130_496,
145+
(False, True, True): 1_147_904,
146+
(False, True, False): 897_024,
147+
(False, False, True): 250_880,
148+
},
149+
"Qwen/Qwen3.6-27B": {
150+
(True, True, True): 7_544_320,
151+
(True, True, False): 7_290_880,
152+
(True, False, True): 4_578_816,
153+
(True, False, False): 4_325_376,
154+
(False, True, True): 3_218_944,
155+
(False, True, False): 2_965_504,
156+
(False, False, True): 253_440,
157+
},
158+
"Qwen/Qwen3.6-35B-A3B": {
159+
(True, True, True): 17_545_728,
160+
(True, True, False): 17_295_360,
161+
(True, False, True): 16_531_968,
162+
(True, False, False): 16_281_600,
163+
(False, True, True): 1_264_128,
164+
(False, True, False): 1_013_760,
165+
(False, False, True): 250_368,
166+
},
167+
"deepseek-ai/DeepSeek-V3.1": {
168+
(True, True, True): 96_883_776,
169+
(True, True, False): 96_747_328,
170+
(True, False, True): 94_443_776,
171+
(True, False, False): 94_307_328,
172+
(False, True, True): 2_576_448,
173+
(False, True, False): 2_440_000,
174+
(False, False, True): 136_448,
175+
},
176+
"deepseek-ai/DeepSeek-V3.1-Base": {
177+
(True, True, True): 96_883_776,
178+
(True, True, False): 96_747_328,
179+
(True, False, True): 94_443_776,
180+
(True, False, False): 94_307_328,
181+
(False, True, True): 2_576_448,
182+
(False, True, False): 2_440_000,
183+
(False, False, True): 136_448,
184+
},
185+
"meta-llama/Llama-3.1-70B": {
186+
(True, True, True): 13_079_808,
187+
(True, True, False): 12_943_360,
188+
(True, False, True): 8_983_808,
189+
(True, False, False): 8_847_360,
190+
(False, True, True): 4_232_448,
191+
(False, True, False): 4_096_000,
192+
(False, False, True): 136_448,
193+
},
194+
"meta-llama/Llama-3.1-8B": {
195+
(True, True, True): 2_753_792,
196+
(True, True, False): 2_621_440,
197+
(True, False, True): 1_901_824,
198+
(True, False, False): 1_769_472,
199+
(False, True, True): 984_320,
200+
(False, True, False): 851_968,
201+
(False, False, True): 132_352,
202+
},
203+
"meta-llama/Llama-3.1-8B-Instruct": {
204+
(True, True, True): 2_753_792,
205+
(True, True, False): 2_621_440,
206+
(True, False, True): 1_901_824,
207+
(True, False, False): 1_769_472,
208+
(False, True, True): 984_320,
209+
(False, True, False): 851_968,
210+
(False, False, True): 132_352,
211+
},
212+
"meta-llama/Llama-3.2-1B": {
213+
(True, True, True): 834_816,
214+
(True, True, False): 704_512,
215+
(True, False, True): 621_824,
216+
(True, False, False): 491_520,
217+
(False, True, True): 343_296,
218+
(False, True, False): 212_992,
219+
(False, False, True): 130_304,
220+
},
221+
"meta-llama/Llama-3.2-3B": {
222+
(True, True, True): 1_650_944,
223+
(True, True, False): 1_519_616,
224+
(True, False, True): 1_077_504,
225+
(True, False, False): 946_176,
226+
(False, True, True): 704_768,
227+
(False, True, False): 573_440,
228+
(False, False, True): 131_328,
229+
},
230+
"meta-llama/Llama-3.3-70B-Instruct": {
231+
(True, True, True): 13_079_808,
232+
(True, True, False): 12_943_360,
233+
(True, False, True): 8_983_808,
234+
(True, False, False): 8_847_360,
235+
(False, True, True): 4_232_448,
236+
(False, True, False): 4_096_000,
237+
(False, False, True): 136_448,
238+
},
239+
"moonshotai/Kimi-K2-Thinking": {
240+
(True, True, True): 146_694_976,
241+
(True, True, False): 146_523_968,
242+
(True, False, True): 144_754_688,
243+
(True, False, False): 144_583_680,
244+
(False, True, True): 2_111_296,
245+
(False, True, False): 1_940_288,
246+
(False, False, True): 171_008,
247+
},
248+
"moonshotai/Kimi-K2.5": {
249+
(True, True, True): 146_694_976,
250+
(True, True, False): 146_523_968,
251+
(True, False, True): 144_754_688,
252+
(True, False, False): 144_583_680,
253+
(False, True, True): 2_111_296,
254+
(False, True, False): 1_940_288,
255+
(False, False, True): 171_008,
256+
},
257+
"moonshotai/Kimi-K2.6": {
258+
(True, True, True): 146_694_976,
259+
(True, True, False): 146_523_968,
260+
(True, False, True): 144_754_688,
261+
(True, False, False): 144_583_680,
262+
(False, True, True): 2_111_296,
263+
(False, True, False): 1_940_288,
264+
(False, False, True): 171_008,
265+
},
266+
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16": {
267+
(True, True, True): 12_064_768,
268+
(True, True, False): 11_931_008,
269+
(True, False, True): 11_479_936,
270+
(True, False, False): 11_346_176,
271+
(False, True, True): 718_592,
272+
(False, True, False): 584_832,
273+
(False, False, True): 133_760,
274+
},
275+
"nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": {
276+
(True, True, True): 113_160_192,
277+
(True, True, False): 113_025_024,
278+
(True, False, True): 111_484_928,
279+
(True, False, False): 111_349_760,
280+
(False, True, True): 1_810_432,
281+
(False, True, False): 1_675_264,
282+
(False, False, True): 135_168,
283+
},
284+
"openai/gpt-oss-120b": {
285+
(True, True, True): 41_074_624,
286+
(True, True, False): 40_870_656,
287+
(True, False, True): 40_328_128,
288+
(True, False, False): 40_124_160,
289+
(False, True, True): 950_464,
290+
(False, True, False): 746_496,
291+
(False, False, True): 203_968,
292+
},
293+
"openai/gpt-oss-20b": {
294+
(True, True, True): 7_544_512,
295+
(True, True, False): 7_340_544,
296+
(True, False, True): 7_046_848,
297+
(True, False, False): 6_842_880,
298+
(False, True, True): 701_632,
299+
(False, True, False): 497_664,
300+
(False, False, True): 203_968,
301+
},
302+
}
303+
304+
_TEST_RANKS = (1, 2, 4, 8, 16, 32, 64)
305+
13306

14307
class TestCliUtils:
15308
def test_check_log_dir_signature(self):
@@ -30,11 +323,55 @@ def test_get_lora_lr_over_full_finetune_lr_signature(self):
30323
assert_params(get_lora_lr_over_full_finetune_lr, ["model_name", "lora_alpha"])
31324

32325
def test_get_lora_param_count_signature(self):
33-
from tests.downstream_compat.sig_helpers import assert_params_subset
326+
from tests.downstream_compat.sig_helpers import assert_params
34327

35-
assert_params_subset(get_lora_param_count, ["model_name", "lora_rank"])
328+
assert_params(
329+
get_lora_param_count,
330+
["model_name", "lora_rank", "train_mlp", "train_attn", "train_unembed"],
331+
)
36332

37333
def test_get_lr_returns_float(self):
38-
lr = get_lr("Qwen/Qwen3-8B", is_lora=True)
334+
lr = get_lr("Qwen/Qwen3.6-27B", is_lora=True)
39335
assert isinstance(lr, float)
40336
assert lr > 0
337+
338+
def test_get_lora_param_count_rejects_all_false(self):
339+
with pytest.raises(ValueError):
340+
get_lora_param_count(
341+
"Qwen/Qwen3.6-27B",
342+
lora_rank=32,
343+
train_mlp=False,
344+
train_attn=False,
345+
train_unembed=False,
346+
)
347+
348+
@pytest.mark.parametrize(
349+
"flag_combo",
350+
sorted(
351+
{combo for params in _REFERENCE_PARAMS_PER_RANK.values() for combo in params},
352+
reverse=True,
353+
),
354+
)
355+
@pytest.mark.parametrize("model_name", sorted(_REFERENCE_PARAMS_PER_RANK.keys()))
356+
def test_get_lora_param_count_matches_measurements(
357+
self, model_name: str, flag_combo: tuple[bool, bool, bool]
358+
) -> None:
359+
"""Function output must equal the measured rank=1 value times the rank,
360+
for every (model, train_mlp, train_attn, train_unembed) combination.
361+
"""
362+
train_mlp, train_attn, train_unembed = flag_combo
363+
params_at_rank_1 = _REFERENCE_PARAMS_PER_RANK[model_name][flag_combo]
364+
for rank in _TEST_RANKS:
365+
expected = params_at_rank_1 * rank
366+
actual = get_lora_param_count(
367+
model_name,
368+
lora_rank=rank,
369+
train_mlp=train_mlp,
370+
train_attn=train_attn,
371+
train_unembed=train_unembed,
372+
)
373+
assert actual == expected, (
374+
f"{model_name} rank={rank} "
375+
f"(mlp={train_mlp}, attn={train_attn}, unembed={train_unembed}): "
376+
f"expected {expected:,}, got {actual:,} (off by {actual - expected:+,})"
377+
)

0 commit comments

Comments
 (0)