
Commit b831118

forforever73, lvyichen, and CISC authored
model : support Step3.5-Flash (ggml-org#19283)
* Support Step3.5-Flash
* fix: norm.weight + 1 (HF zero_centered=true)
* step35: simplify GGUF conversion + drop redundant rope KVs
* Address review feedback
* rename limits -> clamp
* Apply suggestions from code review
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Apply suggestion from @CISC
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* rename swiglu limits -> swiglu clamp in LLM_KV
* avoid CI fail
* Apply suggestions from code review
* Apply suggestions from code review
* disabled KV shifting for LLM_ARCH_STEP35
* Apply suggestions from code review
* mistakenly removed cmath
* add model size && apply missed suggestion
* assert partial_rotary_factors
* fix CI errors:
* load freq_base_swa

---------

Co-authored-by: lvyichen <lvyichen@stepfun.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent 3228e77 commit b831118

15 files changed: +576 −38 lines changed
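The "fix: norm.weight + 1 (HF zero_centered=true)" item in the commit message corresponds to the `data_torch += 1.0` applied to every *norm.weight in the converter below: if the zero-centered RMSNorm in the HF modeling code applies its gain as 1 + stored weight, adding one at conversion time lets llama.cpp keep its plain "x_norm * weight" path. A minimal sketch of that equivalence, under that assumption (the RMSNorm helper and tensor shapes are illustrative, not from the repository):

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # plain RMSNorm without a learned gain
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(4, 8)
w = torch.randn(8) * 0.02                 # zero-centered gain as stored in the checkpoint (assumption)
hf_out   = rms_norm(x) * (1.0 + w)        # zero_centered=true: gain applied as 1 + weight (assumption)
gguf_out = rms_norm(x) * (w + 1.0)        # llama.cpp's "x_norm * weight" after the converter's += 1.0
assert torch.allclose(hf_out, gguf_out)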

convert_hf_to_gguf.py

Lines changed: 130 additions & 1 deletion
@@ -920,7 +920,7 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_expert_group_used_count(n_group_used)
             logger.info(f"gguf: expert groups used count = {n_group_used}")
 
-        if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func", "moe_router_activation_func"], optional=True)) is not None:
+        if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func", "moe_router_activation", "moe_router_activation_func"], optional=True)) is not None:
             if score_func == "sigmoid":
                 self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
             elif score_func == "softmax":
@@ -7912,6 +7912,135 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("Step3p5ForCausalLM")
+class Step35Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.STEP35
+
+    def set_gguf_parameters(self):
+        rope_theta = self.hparams.get("rope_theta")
+        if isinstance(rope_theta, list):
+            self.hparams["rope_theta"] = float(rope_theta[0])
+            self.hparams["local_rope_theta"] = float(rope_theta[1])
+            self.rope_parameters["rope_theta"] = self.hparams["rope_theta"]
+            self.rope_parameters["sliding_attention"] = {"rope_theta": self.hparams["local_rope_theta"]}
+
+        super().set_gguf_parameters()
+
+        layer_types = self.hparams.get("layer_types") or []
+        partial_rotary_factors = self.hparams.get("partial_rotary_factors") or []
+        attn_other = self.hparams.get("attention_other_setting") or {}
+
+        n_head_base = self.hparams["num_attention_heads"]
+        n_kv_base = self.hparams["num_attention_groups"]
+
+        n_head_swa = attn_other.get("num_attention_heads", n_head_base)
+        n_kv_swa = attn_other.get("num_attention_groups", n_kv_base)
+
+        layer_types = layer_types[: self.block_count]
+        partial_rotary_factors = partial_rotary_factors[: self.block_count]
+        assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types] == partial_rotary_factors
+        head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types]
+        kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
+        swa_pat = [lt == "sliding_attention" for lt in layer_types]
+
+        self.gguf_writer.add_head_count(head_arr)
+        self.gguf_writer.add_head_count_kv(kv_arr)
+
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+        self.gguf_writer.add_sliding_window_pattern(swa_pat)
+
+        self.gguf_writer.add_value_length(self.hparams["head_dim"])
+
+        # MoE params
+        self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"])
+        self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
+        self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["share_expert_dim"])
+
+        if (moe_router_scaling_factor := self.hparams.get("moe_router_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(moe_router_scaling_factor)
+        if (norm_expert_weight := self.hparams.get("norm_expert_weight")) is not None:
+            self.gguf_writer.add_expert_weights_norm(norm_expert_weight)
+
+        # leading dense blocks
+        leading_dense = 0
+        moe_layers_enum = self.hparams.get("moe_layers_enum")
+        if isinstance(moe_layers_enum, str) and moe_layers_enum.strip():
+            moe_layers = sorted(int(i) for i in moe_layers_enum.strip().split(","))
+            if moe_layers:
+                leading_dense = max(0, moe_layers[0])
+        self.gguf_writer.add_leading_dense_block_count(leading_dense)
+        self.gguf_writer.add_moe_every_n_layers(int(self.hparams.get("moe_every_n_layer", 1)))
+
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
+
+        # Optional per-layer SwiGLU clamps.
+        if (limits := self.hparams.get("swiglu_limits")) is not None:
+            limits_f = [0.0 if v is None else float(v) for v in limits[: self.block_count]]
+            self.gguf_writer.add_swiglu_clamp_exp(limits_f)
+        if (limits_shared := self.hparams.get("swiglu_limits_shared")) is not None:
+            limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self.block_count]]
+            self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        # remove mtp layers
+        if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None:
+            il = int(m.group(1))
+            n_main = int(self.hparams.get("num_hidden_layers", self.block_count))
+            if il >= n_main:
+                return
+        if name.endswith("norm.weight"):
+            data_torch += 1.0
+        # Map router bias (expert selection bias) to a GGUF bias tensor
+        if name.endswith(".moe.router_bias"):
+            name += ".bias"
+
+        if name.endswith((".self_attn.g_proj.weight", ".moe.gate.weight", ".moe.up_proj.weight", ".moe.gate_proj.weight", ".moe.down_proj.weight")):
+            data_torch = data_torch.squeeze().contiguous()
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        # Step35 can optionally use Llama-3 style RoPE scaling (HF: rope_scaling.rope_type == "llama3").
+        # llama.cpp represents this via a single extra tensor: "rope_freqs.weight" (aka MODEL_TENSOR.ROPE_FREQS).
+        rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
+        rope_type = rope_params.get("rope_type") or ""
+        if rope_type.lower() != "llama3":
+            return
+
+        # Step35 configs can carry per-layer rope_theta as a list; for llama3 rope factors we use the base value.
+        rope_theta = self.hparams.get("rope_theta", 10000.0)
+        if isinstance(rope_theta, list):
+            rope_theta = rope_theta[0]
+        base = float(rope_theta)
+        if (dim := self.hparams.get("head_dim")) is None:
+            dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+        dim = int(dim)
+
+        freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+        factor = float(rope_params.get("factor", 8.0))
+        low_freq_factor = float(rope_params.get("low_freq_factor", 1.0))
+        high_freq_factor = float(rope_params.get("high_freq_factor", 4.0))
+        old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192)))
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+
+        rope_factors: list[float] = []
+        for freq in freqs:
+            wavelen = 2 * math.pi / float(freq)
+            if wavelen < high_freq_wavelen:
+                rope_factors.append(1.0)
+            elif wavelen > low_freq_wavelen:
+                rope_factors.append(factor)
+            else:
+                smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                rope_factors.append(1.0 / ((1.0 - smooth) / factor + smooth))
+
+        yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
+
+
 @ModelBase.register("PanguEmbeddedForCausalLM")
 class PanguEmbeddedModel(TextModel):
     model_arch = gguf.MODEL_ARCH.PANGU_EMBED
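For reference, the Llama-3 RoPE smoothing used in generate_extra_tensors() can be exercised standalone. This sketch uses made-up parameters (the head_dim and rope base are illustrative; the factors are the fallback defaults from the code above) and only checks that the produced per-frequency factors lie between 1.0 and `factor`:

import math
import torch

dim, base = 128, 500000.0                                   # illustrative head_dim and rope base
factor, low_freq_factor, high_freq_factor = 8.0, 1.0, 4.0   # defaults the converter falls back to
old_context_len = 8192

freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

rope_factors = []
for freq in freqs:
    wavelen = 2 * math.pi / float(freq)
    if wavelen < high_freq_wavelen:        # high-frequency dims are left unscaled
        rope_factors.append(1.0)
    elif wavelen > low_freq_wavelen:       # low-frequency dims are stretched by `factor`
        rope_factors.append(factor)
    else:                                  # smooth interpolation in between
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        rope_factors.append(1.0 / ((1.0 - smooth) / factor + smooth))

print(min(rope_factors), max(rope_factors))   # all factors lie in [1.0, factor]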

gguf-py/gguf/constants.py

Lines changed: 50 additions & 20 deletions
@@ -146,6 +146,8 @@ class LLM:
         ALTUP_ACTIVE_IDX             = "{arch}.altup.active_idx"
         ALTUP_NUM_INPUTS             = "{arch}.altup.num_inputs"
         EMBD_LENGTH_PER_LAYER_INP    = "{arch}.embedding_length_per_layer_input"
+        SWIGLU_CLAMP_EXP             = "{arch}.swiglu_clamp_exp"
+        SWIGLU_CLAMP_SHEXP           = "{arch}.swiglu_clamp_shexp"
         DENSE_FEAT_IN_SIZE           = "{arch}.{dense}_feat_in"
         DENSE_FEAT_OUT_SIZE          = "{arch}.{dense}_feat_out"
@@ -179,20 +181,20 @@ class Attention:
         TEMPERATURE_SCALE        = "{arch}.attention.temperature_scale"
 
     class Rope:
-        DIMENSION_COUNT          = "{arch}.rope.dimension_count"
-        DIMENSION_SECTIONS       = "{arch}.rope.dimension_sections"
-        FREQ_BASE                = "{arch}.rope.freq_base"
-        FREQ_BASE_SWA            = "{arch}.rope.freq_base_swa"
-        SCALING_TYPE             = "{arch}.rope.scaling.type"
-        SCALING_FACTOR           = "{arch}.rope.scaling.factor"
-        SCALING_ATTN_FACTOR      = "{arch}.rope.scaling.attn_factor"
-        SCALING_ORIG_CTX_LEN     = "{arch}.rope.scaling.original_context_length"
-        SCALING_FINETUNED        = "{arch}.rope.scaling.finetuned"
-        SCALING_YARN_LOG_MUL     = "{arch}.rope.scaling.yarn_log_multiplier"
-        SCALING_YARN_EXT_FACTOR  = "{arch}.rope.scaling.yarn_ext_factor"
-        SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
-        SCALING_YARN_BETA_FAST   = "{arch}.rope.scaling.yarn_beta_fast"
-        SCALING_YARN_BETA_SLOW   = "{arch}.rope.scaling.yarn_beta_slow"
+        DIMENSION_COUNT            = "{arch}.rope.dimension_count"
+        DIMENSION_SECTIONS         = "{arch}.rope.dimension_sections"
+        FREQ_BASE                  = "{arch}.rope.freq_base"
+        FREQ_BASE_SWA              = "{arch}.rope.freq_base_swa"
+        SCALING_TYPE               = "{arch}.rope.scaling.type"
+        SCALING_FACTOR             = "{arch}.rope.scaling.factor"
+        SCALING_ATTN_FACTOR        = "{arch}.rope.scaling.attn_factor"
+        SCALING_ORIG_CTX_LEN       = "{arch}.rope.scaling.original_context_length"
+        SCALING_FINETUNED          = "{arch}.rope.scaling.finetuned"
+        SCALING_YARN_LOG_MUL       = "{arch}.rope.scaling.yarn_log_multiplier"
+        SCALING_YARN_EXT_FACTOR    = "{arch}.rope.scaling.yarn_ext_factor"
+        SCALING_YARN_ATTN_FACTOR   = "{arch}.rope.scaling.yarn_attn_factor"
+        SCALING_YARN_BETA_FAST     = "{arch}.rope.scaling.yarn_beta_fast"
+        SCALING_YARN_BETA_SLOW     = "{arch}.rope.scaling.yarn_beta_slow"
 
     class Split:
         LLM_KV_SPLIT_NO          = "split.no"
@@ -462,6 +464,7 @@ class MODEL_ARCH(IntEnum):
     PANGU_EMBED  = auto()
     MISTRAL3     = auto()
     MIMO2        = auto()
+    STEP35       = auto()
     LLAMA_EMBED  = auto()
     MAINCODER    = auto()
     KIMI_LINEAR  = auto()
@@ -892,6 +895,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.PANGU_EMBED:  "pangu-embedded",
     MODEL_ARCH.MISTRAL3:     "mistral3",
     MODEL_ARCH.MIMO2:        "mimo2",
+    MODEL_ARCH.STEP35:       "step35",
     MODEL_ARCH.LLAMA_EMBED:  "llama-embed",
     MODEL_ARCH.MAINCODER:    "maincoder",
     MODEL_ARCH.KIMI_LINEAR:  "kimi-linear",
@@ -3364,6 +3368,32 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.FFN_EXP_PROBS_B,
     ],
+    MODEL_ARCH.STEP35: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_GATE,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.LLAMA_EMBED: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -3753,12 +3783,12 @@ class VisionProjectorType:
 KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS
 
 # RoPE
-KEY_ROPE_DIMENSION_COUNT      = Keys.Rope.DIMENSION_COUNT
-KEY_ROPE_FREQ_BASE            = Keys.Rope.FREQ_BASE
-KEY_ROPE_SCALING_TYPE         = Keys.Rope.SCALING_TYPE
-KEY_ROPE_SCALING_FACTOR       = Keys.Rope.SCALING_FACTOR
-KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
-KEY_ROPE_SCALING_FINETUNED    = Keys.Rope.SCALING_FINETUNED
+KEY_ROPE_DIMENSION_COUNT        = Keys.Rope.DIMENSION_COUNT
+KEY_ROPE_FREQ_BASE              = Keys.Rope.FREQ_BASE
+KEY_ROPE_SCALING_TYPE           = Keys.Rope.SCALING_TYPE
+KEY_ROPE_SCALING_FACTOR         = Keys.Rope.SCALING_FACTOR
+KEY_ROPE_SCALING_ORIG_CTX_LEN   = Keys.Rope.SCALING_ORIG_CTX_LEN
+KEY_ROPE_SCALING_FINETUNED      = Keys.Rope.SCALING_FINETUNED
 
 # SSM
 KEY_SSM_CONV_KERNEL             = Keys.SSM.CONV_KERNEL
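A tiny sketch of how the per-arch key templates above expand for the new architecture; this assumes a gguf-py checkout that includes this commit (MODEL_ARCH.STEP35 maps to the "step35" name via MODEL_ARCH_NAMES):

import gguf

arch = gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.STEP35]       # "step35"
print(gguf.Keys.LLM.SWIGLU_CLAMP_EXP.format(arch=arch))    # step35.swiglu_clamp_exp
print(gguf.Keys.LLM.SWIGLU_CLAMP_SHEXP.format(arch=arch))  # step35.swiglu_clamp_shexp
print(gguf.Keys.Rope.FREQ_BASE_SWA.format(arch=arch))      # step35.rope.freq_base_swa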

gguf-py/gguf/gguf_writer.py

Lines changed: 6 additions & 0 deletions
@@ -824,6 +824,12 @@ def add_expert_weights_norm(self, value: bool) -> None:
     def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
         self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
 
+    def add_swiglu_clamp_exp(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.LLM.SWIGLU_CLAMP_EXP.format(arch=self.arch), values)
+
+    def add_swiglu_clamp_shexp(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.LLM.SWIGLU_CLAMP_SHEXP.format(arch=self.arch), values)
+
     def add_expert_group_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)
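A minimal sketch of the two new writer helpers in use, assuming a gguf-py build that includes this commit; the output path, layer count, and clamp values below are made up for illustration:

import gguf

writer = gguf.GGUFWriter("step35-kv-only.gguf", arch="step35")
writer.add_swiglu_clamp_exp([7.0] * 61)     # per-layer clamp for the routed experts (illustrative values)
writer.add_swiglu_clamp_shexp([7.0] * 61)   # per-layer clamp for the shared expert
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()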

gguf-py/gguf/tensor_mapping.py

Lines changed: 9 additions & 0 deletions
@@ -359,6 +359,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.ATTN_GATE: (
             "model.layers.{bid}.self_attn.gate_proj",  # afmoe
+            "model.layers.{bid}.self_attn.g_proj",     # step3.5 head-wise attention gate
         ),
 
         # Feed-forward norm
@@ -423,6 +424,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.router.gate",      # afmoe
             "layers.{bid}.gate",                       # mistral-large
             "backbone.layers.{bid}.mixer.gate",        # nemotron-h-moe
+            "model.layers.{bid}.moe.gate",             # step3.5
         ),
 
         MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -439,6 +441,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.mixer.gate.e_score_correction",          # nemotron-h-moe
             "model.layers.{bid}.mlp.e_score_correction",                    # exaone-moe
             "model.layers.{bid}.block_sparse_moe.gate.e_score_correction",  # kimi
+            "model.layers.{bid}.moe.router_bias",                           # step3.5 expert selection bias
         ),
 
         # Feed-forward up
@@ -493,6 +496,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.experts.up_proj",   # llama4
             "encoder.layers.{bid}.mlp.experts.mlp.w1",           # nomic-bert-moe
             "model.layers.{bid}.block_sparse_moe.experts.up",    # smallthinker
+            "model.layers.{bid}.moe.up_proj",                    # step3.5
         ),
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -504,6 +508,7 @@ class TensorNameMap:
             "layers.{bid}.shared_experts.w3",                             # mistral-large
             "backbone.layers.{bid}.mixer.shared_experts.up_proj",         # nemotron-h-moe
             "model.layers.{bid}.block_sparse_moe.shared_experts.up_proj", # kimi
+            "model.layers.{bid}.share_expert.up_proj",                    # step3.5
         ),
 
         MODEL_TENSOR.FFN_UP_CHEXP: (
@@ -543,6 +548,7 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.experts.w1",    # phimoe (merged)
             "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
             "model.layers.{bid}.block_sparse_moe.experts.gate",  # smallthinker
+            "model.layers.{bid}.moe.gate_proj",                  # step3.5
         ),
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -552,6 +558,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_mlp.gate_proj",                  # hunyuan
             "layers.{bid}.shared_experts.w1",                               # mistral-large
             "model.layers.{bid}.block_sparse_moe.shared_experts.gate_proj", # kimi
+            "model.layers.{bid}.share_expert.gate_proj",                    # step3.5
         ),
 
         MODEL_TENSOR.FFN_GATE_CHEXP: (
@@ -606,6 +613,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.experts.down_proj", # llama4
             "encoder.layers.{bid}.mlp.experts.mlp.w2",           # nomic-bert-moe
             "model.layers.{bid}.block_sparse_moe.experts.down",  # smallthinker
+            "model.layers.{bid}.moe.down_proj",                  # step3.5
         ),
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
@@ -617,6 +625,7 @@ class TensorNameMap:
             "layers.{bid}.shared_experts.w2",                               # mistral-large
             "backbone.layers.{bid}.mixer.shared_experts.down_proj",         # nemotron-h-moe
             "model.layers.{bid}.block_sparse_moe.shared_experts.down_proj", # kimi
+            "model.layers.{bid}.share_expert.down_proj",                    # step3.5
         ),
 
         MODEL_TENSOR.FFN_DOWN_CHEXP: (
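A hypothetical smoke test that the new step3.5 entries resolve through TensorNameMap, assuming the STEP35 arch from this commit; the block count and example tensor names are illustrative:

import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.STEP35, 61)
for hf_name in (
    "model.layers.0.self_attn.g_proj.weight",
    "model.layers.0.moe.gate.weight",
    "model.layers.0.share_expert.down_proj.weight",
):
    # get_name() returns the mapped GGUF tensor name, or None if the HF name is unknown
    print(hf_name, "->", tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))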

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -135,6 +135,7 @@ add_library(llama
             models/stablelm.cpp
             models/starcoder.cpp
             models/starcoder2.cpp
+            models/step35-iswa.cpp
             models/t5-dec.cpp
             models/t5-enc.cpp
             models/wavtokenizer-dec.cpp
