diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py index 14eb047..ed8a305 100644 --- a/src/mcore_bridge/model/gpts/deepseek_v4.py +++ b/src/mcore_bridge/model/gpts/deepseek_v4.py @@ -70,7 +70,7 @@ def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) self.layer_type = self.config.hf_config.layer_types[self.layer_number - 1] self.rope_layer_type = 'main' if self.layer_type == 'sliding_attention' else 'compress' - if getattr(config, 'fp8_param', False): + if config.fp8_param: group_proj_in_size = self.query_projection_size // config.o_groups del self.linear_o_group_proj self.linear_o_group_proj = te.GroupedLinear( @@ -442,7 +442,7 @@ def _set_o_group_proj_grouped(self, mg_attn, hf_state_dict, to_mcore): scale_invs.append(param._rowwise_scale_inv) else: weights.append(param.data) - hf_state_dict['wo_a.weight'] = torch.cat(weights, dim=0) + hf_state_dict['wo_a.weight'] = torch.cat(weights, dim=0).view(torch.float8_e4m3fn) if scale_invs: hf_state_dict['wo_a.weight_scale_inv'] = torch.cat(scale_invs, dim=0)