From 13406a2832dd4ea76c84a35260a01c3b85271528 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Sun, 30 Mar 2025 12:24:04 +0800 Subject: [PATCH 01/41] update lynxnet2 backbone --- configs/acoustic.yaml | 3 +- configs/templates/config_acoustic.yaml | 3 +- configs/templates/config_variance.yaml | 38 +++++----- configs/variance.yaml | 16 ++-- modules/backbones/__init__.py | 4 +- modules/backbones/lynxnet.py | 20 +---- modules/backbones/lynxnet2.py | 101 +++++++++++++++++++++++++ modules/backbones/wavenet.py | 8 +- modules/commons/common_layers.py | 16 ++++ 9 files changed, 151 insertions(+), 58 deletions(-) create mode 100644 modules/backbones/lynxnet2.py diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 9f27733f7..994e49923 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -70,13 +70,12 @@ sampling_steps: 20 diff_accelerator: ddim diff_speedup: 10 hidden_size: 256 -backbone_type: 'lynxnet' +backbone_type: 'lynxnet2' backbone_args: num_channels: 1024 num_layers: 6 kernel_size: 31 dropout_rate: 0.0 - strong_cond: true main_loss_type: l2 main_loss_log_norm: false schedule_type: 'linear' diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 59778df99..8d99b01e5 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -76,13 +76,12 @@ T_start: 0.4 T_start_infer: 0.4 K_step: 300 K_step_infer: 300 -backbone_type: 'lynxnet' +backbone_type: 'lynxnet2' backbone_args: num_channels: 1024 num_layers: 6 kernel_size: 31 dropout_rate: 0.0 - strong_cond: true #backbone_type: 'wavenet' #backbone_args: # num_channels: 512 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 7d5b211aa..908deba4a 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -94,31 +94,29 @@ pitch_prediction_args: pitd_clip_min: -12.0 pitd_clip_max: 12.0 repeat_bins: 64 - backbone_type: 'wavenet' - backbone_args: - num_layers: 20 - num_channels: 256 - dilation_cycle_length: 5 -# backbone_type: 'lynxnet' +# backbone_type: 'wavenet' # backbone_args: -# num_layers: 6 -# num_channels: 512 -# dropout_rate: 0.0 -# strong_cond: true +# num_layers: 20 +# num_channels: 256 +# dilation_cycle_length: 5 + backbone_type: 'lynxnet2' + backbone_args: + num_layers: 6 + num_channels: 512 + dropout_rate: 0.0 variances_prediction_args: total_repeat_bins: 48 - backbone_type: 'wavenet' - backbone_args: - num_layers: 10 - num_channels: 192 - dilation_cycle_length: 4 -# backbone_type: 'lynxnet' +# backbone_type: 'wavenet' # backbone_args: -# num_layers: 6 -# num_channels: 384 -# dropout_rate: 0.0 -# strong_cond: true +# num_layers: 10 +# num_channels: 192 +# dilation_cycle_length: 4 + backbone_type: 'lynxnet2' + backbone_args: + num_layers: 6 + num_channels: 384 + dropout_rate: 0.0 lambda_dur_loss: 1.0 lambda_pitch_loss: 1.0 diff --git a/configs/variance.yaml b/configs/variance.yaml index 61c508a1b..c99b13381 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -65,11 +65,11 @@ pitch_prediction_args: pitd_clip_min: -12.0 pitd_clip_max: 12.0 repeat_bins: 64 - backbone_type: 'wavenet' + backbone_type: 'lynxnet2' backbone_args: - num_layers: 20 - num_channels: 256 - dilation_cycle_length: 5 + num_layers: 6 + num_channels: 512 + dropout_rate: 0.0 energy_db_min: -96.0 energy_db_max: -12.0 @@ -88,11 +88,11 @@ tension_smooth_width: 0.12 variances_prediction_args: total_repeat_bins: 48 - backbone_type: 'wavenet' + backbone_type: 'lynxnet2' backbone_args: - num_layers: 10 - num_channels: 192 - dilation_cycle_length: 4 + num_layers: 6 + num_channels: 384 + dropout_rate: 0.0 lambda_dur_loss: 1.0 lambda_pitch_loss: 1.0 diff --git a/modules/backbones/__init__.py b/modules/backbones/__init__.py index 8fce796ab..ebd903456 100644 --- a/modules/backbones/__init__.py +++ b/modules/backbones/__init__.py @@ -1,11 +1,13 @@ import torch.nn from modules.backbones.wavenet import WaveNet from modules.backbones.lynxnet import LYNXNet +from modules.backbones.lynxnet2 import LYNXNet2 from utils import filter_kwargs BACKBONES = { 'wavenet': WaveNet, - 'lynxnet': LYNXNet + 'lynxnet': LYNXNet, + 'lynxnet2': LYNXNet2, } diff --git a/modules/backbones/lynxnet.py b/modules/backbones/lynxnet.py index 5dbd1d0a1..88b2be348 100644 --- a/modules/backbones/lynxnet.py +++ b/modules/backbones/lynxnet.py @@ -6,26 +6,10 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Conv1d, Transpose from utils.hparams import hparams -class Conv1d(torch.nn.Conv1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - nn.init.kaiming_normal_(self.weight) - - -class Transpose(nn.Module): - def __init__(self, dims): - super().__init__() - assert len(dims) == 2, 'dims must be a tuple of two dimensions' - self.dims = dims - - def forward(self, x): - return x.transpose(*self.dims) - - class LYNXConvModule(nn.Module): @staticmethod def calc_same_padding(kernel_size): @@ -150,7 +134,7 @@ def forward(self, spec, diffusion_step, cond): # post-norm x = self.norm(x.transpose(1, 2)).transpose(1, 2) - # MLP and GLU + # output_projection x = self.output_projection(x) # [B, 128, T] if self.n_feats == 1: diff --git a/modules/backbones/lynxnet2.py b/modules/backbones/lynxnet2.py new file mode 100644 index 000000000..5a10a856b --- /dev/null +++ b/modules/backbones/lynxnet2.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Conv1d, Transpose +from utils.hparams import hparams + + +class LYNXNet2Block(nn.Module): + def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.): + super().__init__() + inner_dim = int(dim * expansion_factor) + if float(dropout) > 0.: + _dropout = nn.Dropout(dropout) + else: + _dropout = nn.Identity() + self.net = nn.Sequential( + nn.LayerNorm(dim), + Transpose((1, 2)), + nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim), + Transpose((1, 2)), + nn.Linear(dim, inner_dim * 2), + SwiGLU(), + nn.Linear(inner_dim, inner_dim * 2), + SwiGLU(), + nn.Linear(inner_dim, dim), + _dropout + ) + + def forward(self, x): + return x + self.net(x) + + +class LYNXNet2(nn.Module): + def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31, + dropout=0.0): + """ + LYNXNet2(Linear Gated Depthwise Separable Convolution Network Version 2) + """ + super().__init__() + self.in_dims = in_dims + self.n_feats = n_feats + self.input_projection = nn.Linear(in_dims * n_feats, num_channels) + self.conditioner_projection = nn.Linear(hparams['hidden_size'], num_channels) + self.diffusion_embedding = nn.Sequential( + SinusoidalPosEmb(num_channels), + nn.Linear(num_channels, num_channels * 4), + nn.GELU(), + nn.Linear(num_channels * 4, num_channels), + ) + self.residual_layers = nn.ModuleList( + [ + LYNXNet2Block( + dim=num_channels, + expansion_factor=expansion_factor, + kernel_size=kernel_size, + dropout=dropout + ) + for i in range(num_layers) + ] + ) + self.norm = nn.LayerNorm(num_channels) + self.output_projection = nn.Linear(num_channels, in_dims * n_feats) + nn.init.kaiming_normal_(self.input_projection.weight) + nn.init.kaiming_normal_(self.conditioner_projection.weight) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, diffusion_step, cond): + """ + :param spec: [B, F, M, T] + :param diffusion_step: [B, 1] + :param cond: [B, H, T] + :return: + """ + + if self.n_feats == 1: + x = spec[:, 0] # [B, M, T] + else: + x = spec.flatten(start_dim=1, end_dim=2) # [B, F x M, T] + + x = self.input_projection(x.transpose(1, 2)) # [B, T, F x M] + x = x + self.conditioner_projection(cond.transpose(1, 2)) + x = x + self.diffusion_embedding(diffusion_step).unsqueeze(1) + + for layer in self.residual_layers: + x = layer(x) + + # post-norm + x = self.norm(x) + + # output projection + x = self.output_projection(x).transpose(1, 2) # [B, 128, T] + + if self.n_feats == 1: + x = x[:, None, :, :] + else: + # This is the temporary solution since PyTorch 1.13 + # does not support exporting aten::unflatten to ONNX + # x = x.unflatten(dim=1, sizes=(self.n_feats, self.in_dims)) + x = x.reshape(-1, self.n_feats, self.in_dims, x.shape[2]) + return x diff --git a/modules/backbones/wavenet.py b/modules/backbones/wavenet.py index 08e57eff4..58724e5aa 100644 --- a/modules/backbones/wavenet.py +++ b/modules/backbones/wavenet.py @@ -5,16 +5,10 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb +from modules.commons.common_layers import SinusoidalPosEmb, Conv1d from utils.hparams import hparams -class Conv1d(torch.nn.Conv1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - nn.init.kaiming_normal_(self.weight) - - class ResidualBlock(nn.Module): def __init__(self, encoder_hidden, residual_channels, dilation): super().__init__() diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index bf4a2822c..2938d7bb5 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -117,6 +117,22 @@ def forward(self, x): return out * F.silu(gate) +class Conv1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + nn.init.kaiming_normal_(self.weight) + + +class Transpose(nn.Module): + def __init__(self, dims): + super().__init__() + assert len(dims) == 2, 'dims must be a tuple of two dimensions' + self.dims = dims + + def forward(self, x): + return x.transpose(*self.dims) + + class TransformerFFNLayer(nn.Module): def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0., act='gelu'): super().__init__() From 4a4ee3defb4cbbbf985c1064d1c33cb4661f2625 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Thu, 3 Apr 2025 23:59:04 +0800 Subject: [PATCH 02/41] support muon optimizer --- basics/base_task.py | 2 +- configs/acoustic.yaml | 9 +- configs/templates/config_acoustic.yaml | 10 +- configs/templates/config_variance.yaml | 14 ++- configs/variance.yaml | 13 ++- modules/fastspeech/tts_modules.py | 39 +++++--- modules/fastspeech/variance_encoder.py | 3 +- modules/optimizer/chained_optimizer.py | 122 +++++++++++++++++++++++ modules/optimizer/muon.py | 129 +++++++++++++++++++++++++ utils/__init__.py | 3 +- 10 files changed, 316 insertions(+), 28 deletions(-) create mode 100644 modules/optimizer/chained_optimizer.py create mode 100644 modules/optimizer/muon.py diff --git a/basics/base_task.py b/basics/base_task.py index 065f8273a..656893d96 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -307,7 +307,7 @@ def build_optimizer(self, model): optimizer = build_object_from_class_name( optimizer_args['optimizer_cls'], torch.optim.Optimizer, - model.parameters(), + model if optimizer_args['optimizer_cls'] == 'modules.optimizer.muon.Muon_AdamW' else model.parameters(), **optimizer_args ) return optimizer diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 5c94c23de..cdc7c754c 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -104,10 +104,15 @@ lambda_aux_mel_loss: 0.2 # train and eval num_sanity_val_steps: 1 optimizer_args: + optimizer_cls: modules.optimizer.muon.Muon_AdamW lr: 0.0006 + muon_args: + weight_decay: 0.1 + adamw_args: + weight_decay: 0.0 lr_scheduler_args: - step_size: 10000 - gamma: 0.75 + step_size: 5000 + gamma: 0.8 max_batch_frames: 50000 max_batch_size: 64 dataset_size_key: 'lengths' diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 8d99b01e5..65e276dc7 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -101,11 +101,15 @@ shallow_diffusion_args: lambda_aux_mel_loss: 0.2 optimizer_args: + optimizer_cls: modules.optimizer.muon.Muon_AdamW lr: 0.0006 + muon_args: + weight_decay: 0.1 + adamw_args: + weight_decay: 0.0 lr_scheduler_args: - scheduler_cls: torch.optim.lr_scheduler.StepLR - step_size: 10000 - gamma: 0.75 + step_size: 5000 + gamma: 0.8 max_batch_frames: 50000 max_batch_size: 64 max_updates: 160000 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 908deba4a..7022f2000 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -67,8 +67,8 @@ enc_ffn_kernel_size: 3 use_rope: true hidden_size: 256 dur_prediction_args: - arch: fs2 - hidden_size: 512 + arch: resnet + hidden_size: 256 dropout: 0.1 num_layers: 5 kernel_size: 3 @@ -123,11 +123,15 @@ lambda_pitch_loss: 1.0 lambda_var_loss: 1.0 optimizer_args: + optimizer_cls: modules.optimizer.muon.Muon_AdamW lr: 0.0006 + muon_args: + weight_decay: 0.1 + adamw_args: + weight_decay: 0.0 lr_scheduler_args: - scheduler_cls: torch.optim.lr_scheduler.StepLR - step_size: 10000 - gamma: 0.75 + step_size: 5000 + gamma: 0.8 max_batch_frames: 80000 max_batch_size: 48 max_updates: 160000 diff --git a/configs/variance.yaml b/configs/variance.yaml index c99b13381..5c0048411 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -40,8 +40,8 @@ rel_pos: true hidden_size: 256 dur_prediction_args: - arch: fs2 - hidden_size: 512 + arch: resnet + hidden_size: 256 dropout: 0.1 num_layers: 5 kernel_size: 3 @@ -114,10 +114,15 @@ diff_speedup: 10 # train and eval num_sanity_val_steps: 1 optimizer_args: + optimizer_cls: modules.optimizer.muon.Muon_AdamW lr: 0.0006 + muon_args: + weight_decay: 0.1 + adamw_args: + weight_decay: 0.0 lr_scheduler_args: - step_size: 10000 - gamma: 0.75 + step_size: 5000 + gamma: 0.8 max_batch_frames: 80000 max_batch_size: 48 dataset_size_key: 'lengths' diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index 391de11ab..16b358e3c 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -62,7 +62,7 @@ class DurationPredictor(torch.nn.Module): """ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3, - dropout_rate=0.1, offset=1.0, dur_loss_type='mse'): + dropout_rate=0.1, offset=1.0, dur_loss_type='mse', arch='resnet'): """Initialize duration predictor module. Args: in_dims (int): Input dimension. @@ -76,16 +76,29 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3, self.offset = offset self.conv = torch.nn.ModuleList() self.kernel_size = kernel_size + self.use_resnet = (arch == 'resnet') for idx in range(n_layers): in_chans = in_dims if idx == 0 else n_chans - self.conv.append(torch.nn.Sequential( - torch.nn.Identity(), # this is a placeholder for ConstantPad1d which is now merged into Conv1d - torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2), - torch.nn.ReLU(), - LayerNorm(n_chans, dim=1), - torch.nn.Dropout(dropout_rate) - )) - + if self.use_resnet: + self.conv.append(nn.Sequential( + LayerNorm(in_chans, dim=1), + nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2), + nn.ReLU(), + nn.Conv1d(n_chans, n_chans, 1), + nn.Dropout(dropout_rate) + )) + else: + self.conv.append(nn.Sequential( + nn.Identity(), # this is a placeholder for ConstantPad1d which is now merged into Conv1d + nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2), + nn.ReLU(), + LayerNorm(n_chans, dim=1), + nn.Dropout(dropout_rate) + )) + if self.use_resnet and in_dims != n_chans: + self.res_conv = nn.Conv1d(in_dims, n_chans, 1) + else: + self.res_conv = None self.loss_type = dur_loss_type if self.loss_type in ['mse', 'huber']: self.out_dims = 1 @@ -121,8 +134,12 @@ def forward(self, xs, x_masks=None, infer=True): xs = xs.transpose(1, -1) # (B, idim, Tmax) masks = 1 - x_masks.float() masks_ = masks[:, None, :] - for f in self.conv: - xs = f(xs) # (B, C, Tmax) + for idx, f in enumerate(self.conv): + if self.use_resnet: + residual = self.res_conv(xs) if idx == 0 and self.res_conv is not None else xs + xs = residual + f(xs) + else: + xs = f(xs) if x_masks is not None: xs = xs * masks_ xs = self.linear(xs.transpose(1, -1)) # [B, T, C] diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index deab9ee84..557ee6ea1 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -46,7 +46,8 @@ def __init__(self, vocab_size): dropout_rate=dur_hparams['dropout'], kernel_size=dur_hparams['kernel_size'], offset=dur_hparams['log_offset'], - dur_loss_type=dur_hparams['loss_type'] + dur_loss_type=dur_hparams['loss_type'], + arch=dur_hparams['arch'] ) def forward( diff --git a/modules/optimizer/chained_optimizer.py b/modules/optimizer/chained_optimizer.py new file mode 100644 index 000000000..b123f58e8 --- /dev/null +++ b/modules/optimizer/chained_optimizer.py @@ -0,0 +1,122 @@ +from torch import Tensor +from torch.optim import Optimizer +from torch.optim.optimizer import ParamsT +from dataclasses import dataclass +from typing import Any, Dict, List, Type, Callable, Optional, Iterable + + +@dataclass +class OptimizerSpec: + """Spec for creating an optimizer that is part of a `ChainedOptimizer`.""" + + class_type: Type[Optimizer] + init_args: Dict[str, Any] + param_filter: Optional[Callable[[Tensor], bool]] + + +class ChainedOptimizer(Optimizer): + """ + A wrapper around multiple optimizers that allows for chaining them together. + The optimizers are applied in the order they are passed in the constructor. + Each optimizer is responsible for updating a subset of the parameters, which + is determined by the `param_filter` function. If no optimizer is found for a + parameter group, an exception is raised. + """ + + def __init__( + self, + params: ParamsT, + optimizer_specs: List[OptimizerSpec], + lr: float, + weight_decay: float = 0.0, + optimizer_selection_callback: Optional[Callable[[Tensor, int], None]] = None, + **common_kwargs, + ): + self.optimizer_specs = optimizer_specs + self.optimizer_selection_callback = optimizer_selection_callback + self.optimizers: List[Optimizer] = [] + defaults = dict(lr=lr, weight_decay=weight_decay) + super().__init__(params, defaults) + + # Split the params for each optimzier + params_for_optimizers = [[] for _ in optimizer_specs] + for param_group in self.param_groups: + params = param_group["params"] + indices = param_group["optimizer_and_param_group_indices"] = set() + for param in params: + assert isinstance(param, Tensor), f"Expected a Tensor, got {type(param)}" + for index, spec in enumerate(optimizer_specs): + if spec.param_filter is None or spec.param_filter(param): + if self.optimizer_selection_callback is not None: + self.optimizer_selection_callback(param, index) + params_for_optimizers[index].append(param) + indices.add((index, 0)) + break + + # Initialize the optimizers + for spec, selected_params in zip(optimizer_specs, params_for_optimizers): + optimizer_args = { + 'lr': lr, + 'weight_decay': weight_decay, + } + optimizer_args.update(common_kwargs) + optimizer_args.update(spec.init_args) + optimizer = spec.class_type(selected_params, **optimizer_args) + self.optimizers.append(optimizer) + + def state_dict(self) -> Dict[str, Any]: + return { + "optimizers": [opt.state_dict() for opt in self.optimizers], + **super().state_dict(), + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + optimizers = state_dict.pop("optimizers") + super().load_state_dict(state_dict) + for i in range(len(self.optimizers)): + self.optimizers[i].load_state_dict(optimizers[i]) + + def zero_grad(self, set_to_none: bool = True) -> None: + for opt in self.optimizers: + opt.zero_grad(set_to_none=set_to_none) + + def _copy_lr_to_optimizers(self) -> None: + for param_group in self.param_groups: + indices = param_group["optimizer_and_param_group_indices"] + for optimizer_idx, param_group_idx in indices: + self.optimizers[optimizer_idx].param_groups[param_group_idx]["lr"] = param_group["lr"] + + def step(self, closure=None) -> None: + self._copy_lr_to_optimizers() + for opt in self.optimizers: + opt.step(closure) + + def add_param_group(self, param_group: Dict[str, Any]) -> None: + super().add_param_group(param_group) + + # If optimizer has not been initialized, skip adding the param groups + if not self.optimizers: + return + + # Split the params for each optimzier + params_for_optimizers = [[] for _ in self.optimizer_specs] + params = param_group["params"] + indices = param_group["optimizer_and_param_group_indices"] = set() + for param in params: + assert isinstance(param, Tensor), f"Expected a Tensor, got {type(param)}" + found_optimizer = False + for index, spec in enumerate(self.optimizer_specs): + if spec.param_filter is None or spec.param_filter(param): + if self.optimizer_selection_callback is not None: + self.optimizer_selection_callback(param, index) + params_for_optimizers[index].append(param) + indices.add((index, len(self.optimizers[index].param_groups))) + found_optimizer = True + break + if not found_optimizer: + raise ValueError("No valid optimizer found for the given parameter group") + + # Add the selected param group to the optimizers + for optimizer, selected_params in zip(self.optimizers, params_for_optimizers): + if selected_params: + optimizer.add_param_group({"params": selected_params}) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py new file mode 100644 index 000000000..9e59c9a5d --- /dev/null +++ b/modules/optimizer/muon.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Module, Parameter, Embedding +from typing import List +from .chained_optimizer import ChainedOptimizer, OptimizerSpec + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.float() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) + + # Perform the NS iterations + for _ in range(steps): + A = X @ X.mT + B = torch.baddbmm(A, A, A, beta=b, alpha=c) + X = torch.baddbmm(X, B, X, beta=a, alpha=1) + + if G.size(-2) > G.size(-1): + X = X.mT + return X.to(G) + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. + + Arguments: + lr: The learning rate used by the internal SGD. + momentum: The momentum used by the internal SGD. + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iteration steps to use. + """ + + def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + for group in self.param_groups: + shape_groups = {} + for p in filter(lambda p: p.grad is not None, group["params"]): + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf: Tensor = state["momentum_buffer"] + key = (p.shape, p.device, p.dtype) + if key not in shape_groups: + shape_groups[key] = {"params": [], "grads": [], "buffers": []} + shape_groups[key]["params"].append(p) + shape_groups[key]["grads"].append(g) + shape_groups[key]["buffers"].append(buf) + for key in shape_groups: + group_data = shape_groups[key] + g = torch.stack(group_data["grads"]) + buf = torch.stack(group_data["buffers"]) + buf.lerp_(g, 1 - group["momentum"]) + g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf + if g.ndim >= 4: # for the case of conv filters + g = g.view(g.size(0), g.size(1), -1) + g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + for i, p in enumerate(group_data["params"]): + if group["weight_decay"] > 0: + p.data.mul_(1 - group["lr"] * group["weight_decay"]) + p.data.add_(g[i].view_as(p), alpha=-group["lr"] * max(g[i].size()) ** 0.5) + self.state[p]["momentum_buffer"] = buf[i].clone() + + +def get_params_for_muon(model) -> List[Parameter]: + """ + Filter parameters of a module into two groups: those that can be optimized by Muon, + and those that should be optimized by a standard optimizer. + Args: + module: The module to filter parameters for. + Returns: + A list of parameters that should be optimized with muon. + """ + muon_params = [] + for module in model.modules(): + for param in module.parameters(recurse=False): + if not param.requires_grad: + continue + if not isinstance(module, nn.Embedding) and param.ndim >= 2: + muon_params.append(param) + return muon_params + + +class Muon_AdamW(ChainedOptimizer): + def __init__(self, model, lr=0.0005, weight_decay=0.0, muon_args={}, adamw_args={}, verbose=False): + muon_params_id_set = set(id(p) for p in get_params_for_muon(model)) + spec_muon = OptimizerSpec(Muon, muon_args, lambda param: id(param) in muon_params_id_set) + spec_adamw = OptimizerSpec(torch.optim.AdamW, adamw_args, None) + specs = [spec_muon, spec_adamw] + callback = None + if verbose: + callback = lambda p, spec_idx: print( + f"Adding param {p.shape} to optimizer{spec_idx} {str(specs[spec_idx].class_type)}" + ) + super().__init__(model.parameters(), specs, lr=lr, weight_decay=weight_decay, optimizer_selection_callback=callback) \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py index abb5df151..1f4c17c04 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -315,8 +315,9 @@ def helper(params): def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_groups=1): + optimizer_cls = optimizer_args['optimizer_cls'] optimizer = build_object_from_class_name( - optimizer_args['optimizer_cls'], + 'torch.optim.AdamW' if optimizer_cls == 'modules.optimizer.muon.Muon_AdamW' else optimizer_cls, torch.optim.Optimizer, [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)], **optimizer_args From f9fda2781414fe14a293fc116bd74126c1832c25 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Fri, 4 Apr 2025 00:06:59 +0800 Subject: [PATCH 03/41] optimize --- configs/acoustic.yaml | 6 +++--- configs/templates/config_acoustic.yaml | 6 +++--- configs/templates/config_variance.yaml | 4 ++-- configs/variance.yaml | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index cdc7c754c..73648dd1d 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -119,10 +119,10 @@ dataset_size_key: 'lengths' val_with_vocoder: true val_check_interval: 2000 num_valid_plots: 10 -max_updates: 160000 +max_updates: 100000 num_ckpt_keep: 5 -permanent_ckpt_start: 80000 -permanent_ckpt_interval: 20000 +permanent_ckpt_start: 60000 +permanent_ckpt_interval: 10000 finetune_enabled: false finetune_ckpt_path: null diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 65e276dc7..2ea62b33a 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -112,13 +112,13 @@ lr_scheduler_args: gamma: 0.8 max_batch_frames: 50000 max_batch_size: 64 -max_updates: 160000 +max_updates: 100000 num_valid_plots: 10 val_with_vocoder: true val_check_interval: 2000 num_ckpt_keep: 5 -permanent_ckpt_start: 120000 -permanent_ckpt_interval: 20000 +permanent_ckpt_start: 60000 +permanent_ckpt_interval: 10000 pl_trainer_devices: 'auto' pl_trainer_precision: '16-mixed' diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 7022f2000..1fa3be3a1 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -134,12 +134,12 @@ lr_scheduler_args: gamma: 0.8 max_batch_frames: 80000 max_batch_size: 48 -max_updates: 160000 +max_updates: 100000 num_valid_plots: 10 val_check_interval: 2000 num_ckpt_keep: 5 -permanent_ckpt_start: 80000 +permanent_ckpt_start: 60000 permanent_ckpt_interval: 10000 pl_trainer_devices: 'auto' pl_trainer_precision: '16-mixed' diff --git a/configs/variance.yaml b/configs/variance.yaml index 5c0048411..6ed3d8b94 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -128,9 +128,9 @@ max_batch_size: 48 dataset_size_key: 'lengths' val_check_interval: 2000 num_valid_plots: 10 -max_updates: 160000 +max_updates: 100000 num_ckpt_keep: 5 -permanent_ckpt_start: 80000 +permanent_ckpt_start: 60000 permanent_ckpt_interval: 10000 finetune_enabled: false From eb3b606de6d0b26299dfd525d170b57c0f906f32 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Fri, 4 Apr 2025 20:14:13 +0800 Subject: [PATCH 04/41] stabilize fp16 training --- modules/commons/common_layers.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index 2938d7bb5..7a72b1555 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -114,7 +114,15 @@ def forward(self, x): # out, gate = x.chunk(2, dim=self.dim) # Using torch.split instead of chunk for ONNX export compatibility. out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim) - return out * F.silu(gate) + gate = F.silu(gate) + if x.dtype == torch.float16: + out_min, out_max = torch.aminmax(out.detach()) + gate_min, gate_max = torch.aminmax(gate.detach()) + max_abs_out = torch.max(-out_min, out_max).float() + max_abs_gate = torch.max(-gate_min, gate_max).float() + if max_abs_out * max_abs_gate > 65504: + return (out.float() * gate.float()).clamp(-65504, 65504).half() + return out * gate class Conv1d(torch.nn.Conv1d): From 300676a8c13036ee31a53ff11d5d2cdf4106c9d8 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Sat, 5 Apr 2025 02:13:07 +0800 Subject: [PATCH 05/41] stabilize fp16 training --- modules/commons/common_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index 7a72b1555..77381e2de 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -120,8 +120,8 @@ def forward(self, x): gate_min, gate_max = torch.aminmax(gate.detach()) max_abs_out = torch.max(-out_min, out_max).float() max_abs_gate = torch.max(-gate_min, gate_max).float() - if max_abs_out * max_abs_gate > 65504: - return (out.float() * gate.float()).clamp(-65504, 65504).half() + if max_abs_out * max_abs_gate > 1000: + return (out.float() * gate.float()).clamp(-1000, 1000).half() return out * gate From 14c360938d815d0ad4f0b2f001d7329924aa13a0 Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Tue, 22 Apr 2025 22:45:18 +0800 Subject: [PATCH 06/41] Fix some issue about Initialization (#250) issue 249(2/3/4) --- modules/backbones/lynxnet.py | 3 ++- modules/backbones/lynxnet2.py | 6 +++++- modules/backbones/wavenet.py | 3 ++- modules/commons/common_layers.py | 13 ++++++++++--- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/modules/backbones/lynxnet.py b/modules/backbones/lynxnet.py index 88b2be348..6d229aff8 100644 --- a/modules/backbones/lynxnet.py +++ b/modules/backbones/lynxnet.py @@ -6,7 +6,8 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Conv1d, Transpose +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose +from modules.commons.common_layers import KaimingNormalConv1d as Conv1d from utils.hparams import hparams diff --git a/modules/backbones/lynxnet2.py b/modules/backbones/lynxnet2.py index 5a10a856b..9ce522199 100644 --- a/modules/backbones/lynxnet2.py +++ b/modules/backbones/lynxnet2.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Conv1d, Transpose +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose from utils.hparams import hparams @@ -42,6 +42,8 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio self.n_feats = n_feats self.input_projection = nn.Linear(in_dims * n_feats, num_channels) self.conditioner_projection = nn.Linear(hparams['hidden_size'], num_channels) + # It may need to be modified at some point to be compatible with the condition cache + # self.conditioner_projection = nn.Conv1d(hparams['hidden_size'], num_channels, 1) self.diffusion_embedding = nn.Sequential( SinusoidalPosEmb(num_channels), nn.Linear(num_channels, num_channels * 4), @@ -80,6 +82,8 @@ def forward(self, spec, diffusion_step, cond): x = self.input_projection(x.transpose(1, 2)) # [B, T, F x M] x = x + self.conditioner_projection(cond.transpose(1, 2)) + # It may need to be modified at some point to be compatible with the condition cache + # x = x + self.conditioner_projection(cond.transpose(1, 2)) x = x + self.diffusion_embedding(diffusion_step).unsqueeze(1) for layer in self.residual_layers: diff --git a/modules/backbones/wavenet.py b/modules/backbones/wavenet.py index 58724e5aa..2cbff961d 100644 --- a/modules/backbones/wavenet.py +++ b/modules/backbones/wavenet.py @@ -5,7 +5,8 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, Conv1d +from modules.commons.common_layers import SinusoidalPosEmb +from modules.commons.common_layers import KaimingNormalConv1d as Conv1d from utils.hparams import hparams diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index 77381e2de..e24d4488e 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -121,11 +121,11 @@ def forward(self, x): max_abs_out = torch.max(-out_min, out_max).float() max_abs_gate = torch.max(-gate_min, gate_max).float() if max_abs_out * max_abs_gate > 1000: - return (out.float() * gate.float()).clamp(-1000, 1000).half() + return (out.float() * gate.float()).clamp(-1000, 1000).half() return out * gate -class Conv1d(torch.nn.Conv1d): +class KaimingNormalConv1d(torch.nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) nn.init.kaiming_normal_(self.weight) @@ -190,10 +190,17 @@ def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=N # Dropout layer self.dropout = nn.Dropout(dropout) - + # Rotary Embeddings self.rotary_embed = rotary_embed + # Initialization parameters + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.xavier_uniform_(self.out_proj.weight) + if bias: + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + def forward(self, x, key_padding_mask=None): # x: (B, L, C) # key_padding_mask: (B, L) From 954e41c8909bbada5538de05fada3fc56141721e Mon Sep 17 00:00:00 2001 From: yxlllc Date: Mon, 19 May 2025 23:24:51 +0800 Subject: [PATCH 07/41] variance scaling --- configs/acoustic.yaml | 1 + configs/templates/config_acoustic.yaml | 1 + configs/templates/config_variance.yaml | 3 ++- configs/variance.yaml | 3 ++- modules/fastspeech/acoustic_encoder.py | 30 ++++++++++++++++++++++++-- modules/fastspeech/variance_encoder.py | 18 +++++++++++----- 6 files changed, 47 insertions(+), 9 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 73648dd1d..555e7310f 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -64,6 +64,7 @@ timesteps: 1000 max_beta: 0.02 enc_ffn_kernel_size: 3 use_rope: true +use_variance_scaling: true rel_pos: true sampling_algorithm: euler sampling_steps: 20 diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 2ea62b33a..9745a3fbe 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -71,6 +71,7 @@ augmentation_args: diffusion_type: reflow enc_ffn_kernel_size: 3 use_rope: true +use_variance_scaling: true use_shallow_diffusion: true T_start: 0.4 T_start_infer: 0.4 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 1fa3be3a1..e29bf9f7c 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -65,6 +65,7 @@ tension_logit_max: 10.0 enc_ffn_kernel_size: 3 use_rope: true +use_variance_scaling: true hidden_size: 256 dur_prediction_args: arch: resnet @@ -78,7 +79,7 @@ dur_prediction_args: lambda_wdur_loss: 1.0 lambda_sdur_loss: 3.0 -use_melody_encoder: false +use_melody_encoder: true melody_encoder_args: hidden_size: 128 enc_layers: 4 diff --git a/configs/variance.yaml b/configs/variance.yaml index 6ed3d8b94..fbec36634 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -36,6 +36,7 @@ predict_tension: false enc_ffn_kernel_size: 3 use_rope: true +use_variance_scaling: true rel_pos: true hidden_size: 256 @@ -51,7 +52,7 @@ dur_prediction_args: lambda_wdur_loss: 1.0 lambda_sdur_loss: 3.0 -use_melody_encoder: false +use_melody_encoder: true melody_encoder_args: hidden_size: 128 enc_layers: 4 diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index b6f986bb0..90da9cf80 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -49,6 +49,26 @@ def __init__(self, vocab_size): for v_name in self.variance_embed_list }) + self.use_variance_scaling = hparams.get('use_variance_scaling', False) + if self.use_variance_scaling: + self.variance_scaling_factor = { + 'energy': 1. / 96, + 'breathiness': 1. / 96, + 'voicing': 1. / 96, + 'tension': 0.1, + 'key_shift': 1. / 12, + 'speed': 1. + } + else: + self.variance_scaling_factor = { + 'energy': 1., + 'breathiness': 1., + 'voicing': 1., + 'tension': 1., + 'key_shift': 1., + 'speed': 1. + } + self.use_key_shift_embed = hparams.get('use_key_shift_embed', False) if self.use_key_shift_embed: self.key_shift_embed = Linear(1, hparams['hidden_size']) @@ -64,17 +84,20 @@ def __init__(self, vocab_size): def forward_variance_embedding(self, condition, key_shift=None, speed=None, **variances): if self.use_variance_embeds: variance_embeds = torch.stack([ - self.variance_embeds[v_name](variances[v_name][:, :, None]) + self.variance_embeds[v_name](variances[v_name][:, :, None]) + * self.variance_scaling_factor[v_name] for v_name in self.variance_embed_list ], dim=-1).sum(-1) condition += variance_embeds if self.use_key_shift_embed: key_shift_embed = self.key_shift_embed(key_shift[:, :, None]) + key_shift_embed *= self.variance_scaling_factor['key_shift'] condition += key_shift_embed if self.use_speed_embed: speed_embed = self.speed_embed(speed[:, :, None]) + speed_embed *= self.variance_scaling_factor['speed'] condition += speed_embed return condition @@ -87,7 +110,10 @@ def forward( ): txt_embed = self.txt_embed(txt_tokens) dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]).float() - dur_embed = self.dur_embed(dur[:, :, None]) + if self.use_variance_scaling: + dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None])) + else: + dur_embed = self.dur_embed(dur[:, :, None]) if self.use_lang_id: lang_embed = self.lang_embed(languages) extra_embed = dur_embed + lang_embed diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index 557ee6ea1..70edcebcb 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -17,7 +17,7 @@ def __init__(self, vocab_size): self.predict_dur = hparams['predict_dur'] self.linguistic_mode = 'word' if hparams['predict_dur'] else 'phoneme' self.use_lang_id = hparams['use_lang_id'] - + self.use_variance_scaling = hparams.get('use_variance_scaling', False) self.txt_embed = Embedding(vocab_size, hparams['hidden_size'], PAD_INDEX) if self.use_lang_id: self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0) @@ -80,9 +80,11 @@ def forward( word_dur = torch.gather(F.pad(word_dur, [1, 0], value=0), 1, ph2word) # [B, T_w] => [B, T_ph] word_dur_embed = self.word_dur_embed(word_dur.float()[:, :, None]) extra_embed = onset_embed + word_dur_embed + elif self.use_variance_scaling: + extra_embed = self.ph_dur_embed(torch.log(1 + ph_dur.float())[:, :, None]) else: - ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None]) - extra_embed = ph_dur_embed + extra_embed = self.ph_dur_embed(ph_dur.float()[:, :, None]) + if self.use_lang_id: lang_embed = self.lang_embed(languages) extra_embed += lang_embed @@ -109,6 +111,7 @@ def get_hparam(key): # MIDI inputs hidden_size = get_hparam('hidden_size') + self.use_variance_scaling = hparams.get('use_variance_scaling', False) self.note_midi_embed = Linear(1, hidden_size) self.note_dur_embed = Linear(1, hidden_size) @@ -136,8 +139,13 @@ def forward(self, note_midi, note_rest, note_dur, glide=None): :param glide: int64 [B, T_n] :return: [B, T_n, H] """ - midi_embed = self.note_midi_embed(note_midi[:, :, None]) * ~note_rest[:, :, None] - dur_embed = self.note_dur_embed(note_dur.float()[:, :, None]) + if self.use_variance_scaling: + midi_embed = self.note_midi_embed(note_midi[:, :, None] / 128) + dur_embed = self.note_dur_embed(torch.log(1 + note_dur.float())[:, :, None]) + else: + midi_embed = self.note_midi_embed(note_midi[:, :, None]) + dur_embed = self.note_dur_embed(note_dur.float()[:, :, None]) + midi_embed *= ~note_rest[:, :, None] ornament_embed = 0 if self.use_glide_embed: ornament_embed += self.note_glide_embed(glide) * self.glide_embed_scale From 2ea898f516ea3bbdc225197851a10207ddfe1740 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Tue, 20 May 2025 20:25:21 +0800 Subject: [PATCH 08/41] variance scaling for onnx --- deployment/modules/fastspeech2.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/deployment/modules/fastspeech2.py b/deployment/modules/fastspeech2.py index 20dfdb0d7..bb9d6a7bb 100644 --- a/deployment/modules/fastspeech2.py +++ b/deployment/modules/fastspeech2.py @@ -75,7 +75,10 @@ def forward( mel2ph = self.lr(durations) f0 = f0 * (mel2ph > 0) mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size'])) - dur_embed = self.dur_embed(durations.float()[:, :, None]) + if self.use_variance_scaling: + dur_embed = self.dur_embed(torch.log(1 + durations.float())[:, :, None]) + else: + dur_embed = self.dur_embed(durations.float()[:, :, None]) if self.use_lang_id: lang_mask = torch.any( tokens[..., None] == self.cross_lingual_token_idx[None, None], @@ -99,7 +102,8 @@ def forward( if self.use_variance_embeds: variance_embeds = torch.stack([ - self.variance_embeds[v_name](variances[v_name][:, :, None]) + self.variance_embeds[v_name](variances[v_name][:, :, None]) + * self.variance_scaling_factor[v_name] for v_name in self.variance_embed_list ], dim=-1).sum(-1) condition += variance_embeds @@ -112,6 +116,7 @@ def forward( gender_mask = (gender < 0.).float() key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min)) key_shift_embed = self.key_shift_embed(key_shift[:, :, None]) + key_shift_embed *= self.variance_scaling_factor['key_shift'] condition += key_shift_embed if hparams['use_speed_embed']: @@ -120,6 +125,7 @@ def forward( speed_embed = self.speed_embed(velocity[:, :, None]) else: speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None]) + speed_embed *= self.variance_scaling_factor['speed'] condition += speed_embed if hparams['use_spk_id']: @@ -162,7 +168,10 @@ def forward_encoder_word(self, tokens, word_div, word_dur, languages=None): def forward_encoder_phoneme(self, tokens, ph_dur, languages=None): txt_embed = self.txt_embed(tokens) - ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None]) + if self.use_variance_scaling: + ph_dur_embed = self.ph_dur_embed(torch.log(1 + ph_dur.float())[:, :, None]) + else: + ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None]) if self.use_lang_id: lang_mask = torch.any( tokens[..., None] == self.cross_lingual_token_idx[None, None], From b0ae9ca8ba08cb88719dce5fc4d9050c6b789900 Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Tue, 20 May 2025 23:41:58 +0800 Subject: [PATCH 09/41] [DONE]Supplement the Variance Model Scaling / Retake Scaling / Conditioner cache on LYNXNet2 (#259) * Supplement the Variance Model Scaling / Retake Scaling / Conditioner cache on LYNXNet2 * Update toplevel.py * del use_retake_scaling --- configs/acoustic.yaml | 1 + configs/templates/config_acoustic.yaml | 1 + configs/templates/config_variance.yaml | 2 ++ configs/variance.yaml | 2 ++ deployment/modules/toplevel.py | 17 ++++++++--- modules/backbones/lynxnet2.py | 19 +++++++----- modules/toplevel.py | 40 +++++++++++++++++++++++--- 7 files changed, 67 insertions(+), 15 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 555e7310f..dec306f3a 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -77,6 +77,7 @@ backbone_args: num_layers: 6 kernel_size: 31 dropout_rate: 0.0 + use_conditioner_cache: true main_loss_type: l2 main_loss_log_norm: false schedule_type: 'linear' diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 9745a3fbe..76d148f1a 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -83,6 +83,7 @@ backbone_args: num_layers: 6 kernel_size: 31 dropout_rate: 0.0 + use_conditioner_cache: true #backbone_type: 'wavenet' #backbone_args: # num_channels: 512 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index e29bf9f7c..eecb5cad7 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -105,6 +105,7 @@ pitch_prediction_args: num_layers: 6 num_channels: 512 dropout_rate: 0.0 + use_conditioner_cache: true variances_prediction_args: total_repeat_bins: 48 @@ -118,6 +119,7 @@ variances_prediction_args: num_layers: 6 num_channels: 384 dropout_rate: 0.0 + use_conditioner_cache: true lambda_dur_loss: 1.0 lambda_pitch_loss: 1.0 diff --git a/configs/variance.yaml b/configs/variance.yaml index fbec36634..2eb263d77 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -71,6 +71,7 @@ pitch_prediction_args: num_layers: 6 num_channels: 512 dropout_rate: 0.0 + use_conditioner_cache: true energy_db_min: -96.0 energy_db_max: -12.0 @@ -94,6 +95,7 @@ variances_prediction_args: num_layers: 6 num_channels: 384 dropout_rate: 0.0 + use_conditioner_cache: true lambda_dur_loss: 1.0 lambda_pitch_loss: 1.0 diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 90ade235d..0d1752ef3 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -252,10 +252,16 @@ def forward_pitch_preprocess( base_pitch = self.smooth(frame_midi_pitch) if self.use_melody_encoder: delta_pitch = (pitch - base_pitch) * ~retake - pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None]) + if self.use_variance_scaling: + pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None] / 12) + else: + pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None]) else: base_pitch = base_pitch * retake + pitch * ~retake - pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) + if self.use_variance_scaling: + pitch_cond += self.base_pitch_embed(base_pitch[:, :, None] / 128) + else: + pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) if hparams['use_spk_id'] and spk_embed is not None: pitch_cond += spk_embed return pitch_cond, base_pitch @@ -275,13 +281,16 @@ def forward_variance_preprocess( variances: dict = None, retake=None, spk_embed=None ): condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) - variance_cond = condition + self.pitch_embed(pitch[:, :, None]) + if self.use_variance_scaling: + variance_cond = condition + self.pitch_embed(pitch[:, :, None] / 12) + else: + variance_cond = condition + self.pitch_embed(pitch[:, :, None]) non_retake_masks = [ v_retake.float() # [B, T, 1] for v_retake in (~retake).split(1, dim=2) ] variance_embeds = [ - self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks + self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks * self.variance_retake_scaling[v_name] for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks) ] variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1) diff --git a/modules/backbones/lynxnet2.py b/modules/backbones/lynxnet2.py index 9ce522199..b819f5509 100644 --- a/modules/backbones/lynxnet2.py +++ b/modules/backbones/lynxnet2.py @@ -33,7 +33,7 @@ def forward(self, x): class LYNXNet2(nn.Module): def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31, - dropout=0.0): + dropout=0.0, use_conditioner_cache=False): """ LYNXNet2(Linear Gated Depthwise Separable Convolution Network Version 2) """ @@ -41,9 +41,12 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio self.in_dims = in_dims self.n_feats = n_feats self.input_projection = nn.Linear(in_dims * n_feats, num_channels) - self.conditioner_projection = nn.Linear(hparams['hidden_size'], num_channels) - # It may need to be modified at some point to be compatible with the condition cache - # self.conditioner_projection = nn.Conv1d(hparams['hidden_size'], num_channels, 1) + self.use_conditioner_cache = use_conditioner_cache + if self.use_conditioner_cache: + # It may need to be modified at some point to be compatible with the condition cache + self.conditioner_projection = nn.Conv1d(hparams['hidden_size'], num_channels, 1) + else: + self.conditioner_projection = nn.Linear(hparams['hidden_size'], num_channels) self.diffusion_embedding = nn.Sequential( SinusoidalPosEmb(num_channels), nn.Linear(num_channels, num_channels * 4), @@ -81,9 +84,11 @@ def forward(self, spec, diffusion_step, cond): x = spec.flatten(start_dim=1, end_dim=2) # [B, F x M, T] x = self.input_projection(x.transpose(1, 2)) # [B, T, F x M] - x = x + self.conditioner_projection(cond.transpose(1, 2)) - # It may need to be modified at some point to be compatible with the condition cache - # x = x + self.conditioner_projection(cond.transpose(1, 2)) + if self.use_conditioner_cache: + # It may need to be modified at some point to be compatible with the condition cache + x = x + self.conditioner_projection(cond).transpose(1, 2) + else: + x = x + self.conditioner_projection(cond.transpose(1, 2)) x = x + self.diffusion_embedding(diffusion_step).unsqueeze(1) for layer in self.residual_layers: diff --git a/modules/toplevel.py b/modules/toplevel.py index aceff1f70..1ad5aa4b7 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -195,6 +195,28 @@ def __init__(self, vocab_size): else: raise NotImplementedError(self.diffusion_type) + self.use_variance_scaling = hparams.get('use_variance_scaling', False) + self.custom_variance_scaling_factor = { + 'energy': 1. / 96, + 'breathiness': 1. / 96, + 'voicing': 1. / 96, + 'tension': 0.1, + 'key_shift': 1. / 12, + 'speed': 1. + } + self.default_variance_scaling_factor = { + 'energy': 1., + 'breathiness': 1., + 'voicing': 1., + 'tension': 1., + 'key_shift': 1., + 'speed': 1. + } + if self.use_variance_scaling: + self.variance_retake_scaling = self.custom_variance_scaling_factor + else: + self.variance_retake_scaling = self.default_variance_scaling_factor + def forward( self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None, note_midi=None, note_rest=None, note_dur=None, note_glide=None, mel2note=None, @@ -271,11 +293,17 @@ def forward( delta_pitch_in = torch.zeros_like(base_pitch) else: delta_pitch_in = (pitch - base_pitch) * ~pitch_retake - pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None]) + if self.use_variance_scaling: + pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None] / 12) + else: + pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None]) else: if not retake_unset: # retake base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake - pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) + if self.use_variance_scaling: + pitch_cond += self.base_pitch_embed(base_pitch[:, :, None] / 128) + else: + pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) if infer: pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True) @@ -289,12 +317,16 @@ def forward( if pitch is None: pitch = base_pitch + pitch_pred_out - var_cond = condition + self.pitch_embed(pitch[:, :, None]) + if self.use_variance_scaling: + var_cond = condition + self.pitch_embed(pitch[:, :, None] / 12) + else: + var_cond = condition + self.pitch_embed(pitch[:, :, None]) variance_inputs = self.collect_variance_inputs(**kwargs) + if variance_retake is not None: variance_embeds = [ - self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None] + self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None] * self.variance_retake_scaling[v_name] for v_name, v_input in zip(self.variance_prediction_list, variance_inputs) ] var_cond += torch.stack(variance_embeds, dim=-1).sum(-1) From 3ae76d72fe2e3f5d22785d4bade5665037ee9ee6 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Wed, 21 May 2025 23:26:36 +0800 Subject: [PATCH 10/41] avoid precision conversions --- modules/commons/common_layers.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index e24d4488e..ef69f8487 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -120,8 +120,13 @@ def forward(self, x): gate_min, gate_max = torch.aminmax(gate.detach()) max_abs_out = torch.max(-out_min, out_max).float() max_abs_gate = torch.max(-gate_min, gate_max).float() - if max_abs_out * max_abs_gate > 1000: - return (out.float() * gate.float()).clamp(-1000, 1000).half() + max_abs_value = max_abs_out * max_abs_gate + if max_abs_value > 1000: + ratio = 1000 / max_abs_value + sqrt_ratio = torch.sqrt(ratio) + out = out * sqrt_ratio + gate = gate * sqrt_ratio + return (out * gate).clamp(-1000 * ratio, 1000 * ratio) / ratio return out * gate From d099e3c9522acefeeb30d4f696c1f4b7abb56225 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Thu, 22 May 2025 12:21:02 +0800 Subject: [PATCH 11/41] support bf16 calculation --- modules/optimizer/muon.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 9e59c9a5d..3ffe33a01 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -7,7 +7,25 @@ from .chained_optimizer import ChainedOptimizer, OptimizerSpec -def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: +def get_bf16_support_map(): + bf16_support_map = {} + + if not torch.cuda.is_available(): + return bf16_support_map + + device_count = torch.cuda.device_count() + if device_count == 0: + return bf16_support_map + + for i in range(device_count): + device = torch.device(f'cuda:{i}') + major, minor = torch.cuda.get_device_capability(device) + bf16_support_map[device] = (major >= 8) + + return bf16_support_map + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor: """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose @@ -19,7 +37,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: """ assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng a, b, c = (3.4445, -4.7750, 2.0315) - X = G.float() + if use_bf16: + X = G.bfloat16() + else: + X = G.float() if G.size(-2) > G.size(-1): X = X.mT @@ -63,7 +84,8 @@ class Muon(torch.optim.Optimizer): def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5): defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) super().__init__(params, defaults) - + self.bf16_support_map = get_bf16_support_map() + @torch.no_grad() def step(self, closure=None): for group in self.param_groups: @@ -88,7 +110,8 @@ def step(self, closure=None): g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) - g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + use_bf16 = self.bf16_support_map.get(g.device, False) + g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16) for i, p in enumerate(group_data["params"]): if group["weight_decay"] > 0: p.data.mul_(1 - group["lr"] * group["weight_decay"]) From 0db91f27d9a8f4a8bdd4df4efaa442e3b496ce32 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Thu, 22 May 2025 16:07:44 +0800 Subject: [PATCH 12/41] save memory --- modules/commons/common_layers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index ef69f8487..03527ef78 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -122,10 +122,8 @@ def forward(self, x): max_abs_gate = torch.max(-gate_min, gate_max).float() max_abs_value = max_abs_out * max_abs_gate if max_abs_value > 1000: - ratio = 1000 / max_abs_value - sqrt_ratio = torch.sqrt(ratio) - out = out * sqrt_ratio - gate = gate * sqrt_ratio + ratio = (1000 / max_abs_value).half() + gate *= ratio return (out * gate).clamp(-1000 * ratio, 1000 * ratio) / ratio return out * gate From f5915341d09dbd85d627e10b8b0569208befb520 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Thu, 5 Jun 2025 20:28:34 +0800 Subject: [PATCH 13/41] support atanglu --- configs/acoustic.yaml | 1 + configs/templates/config_acoustic.yaml | 1 + configs/templates/config_variance.yaml | 2 ++ configs/variance.yaml | 2 ++ modules/backbones/lynxnet2.py | 19 +++++++++++++------ modules/commons/common_layers.py | 16 ++++++++++++++++ 6 files changed, 35 insertions(+), 6 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index dec306f3a..435ddf9a4 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -78,6 +78,7 @@ backbone_args: kernel_size: 31 dropout_rate: 0.0 use_conditioner_cache: true + glu_type: 'atanglu' main_loss_type: l2 main_loss_log_norm: false schedule_type: 'linear' diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 76d148f1a..624e959b1 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -84,6 +84,7 @@ backbone_args: kernel_size: 31 dropout_rate: 0.0 use_conditioner_cache: true + glu_type: 'atanglu' #backbone_type: 'wavenet' #backbone_args: # num_channels: 512 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index eecb5cad7..fb33b46ef 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -106,6 +106,7 @@ pitch_prediction_args: num_channels: 512 dropout_rate: 0.0 use_conditioner_cache: true + glu_type: 'atanglu' variances_prediction_args: total_repeat_bins: 48 @@ -120,6 +121,7 @@ variances_prediction_args: num_channels: 384 dropout_rate: 0.0 use_conditioner_cache: true + glu_type: 'atanglu' lambda_dur_loss: 1.0 lambda_pitch_loss: 1.0 diff --git a/configs/variance.yaml b/configs/variance.yaml index 2eb263d77..9795293d0 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -72,6 +72,7 @@ pitch_prediction_args: num_channels: 512 dropout_rate: 0.0 use_conditioner_cache: true + glu_type: 'atanglu' energy_db_min: -96.0 energy_db_max: -12.0 @@ -96,6 +97,7 @@ variances_prediction_args: num_channels: 384 dropout_rate: 0.0 use_conditioner_cache: true + glu_type: 'atanglu' lambda_dur_loss: 1.0 lambda_pitch_loss: 1.0 diff --git a/modules/backbones/lynxnet2.py b/modules/backbones/lynxnet2.py index b819f5509..9d6d15d37 100644 --- a/modules/backbones/lynxnet2.py +++ b/modules/backbones/lynxnet2.py @@ -2,14 +2,20 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, ATanGLU, Transpose from utils.hparams import hparams class LYNXNet2Block(nn.Module): - def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.): + def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0., glu_type='swiglu'): super().__init__() inner_dim = int(dim * expansion_factor) + if glu_type == 'swiglu': + _glu = SwiGLU() + elif glu_type == 'atanglu': + _glu = ATanGLU() + else: + raise ValueError(f'{glu_type} is not a valid activation') if float(dropout) > 0.: _dropout = nn.Dropout(dropout) else: @@ -20,9 +26,9 @@ def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.): nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim), Transpose((1, 2)), nn.Linear(dim, inner_dim * 2), - SwiGLU(), + _glu, nn.Linear(inner_dim, inner_dim * 2), - SwiGLU(), + _glu, nn.Linear(inner_dim, dim), _dropout ) @@ -33,7 +39,7 @@ def forward(self, x): class LYNXNet2(nn.Module): def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31, - dropout=0.0, use_conditioner_cache=False): + dropout=0.0, use_conditioner_cache=False, glu_type='swiglu'): """ LYNXNet2(Linear Gated Depthwise Separable Convolution Network Version 2) """ @@ -59,7 +65,8 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio dim=num_channels, expansion_factor=expansion_factor, kernel_size=kernel_size, - dropout=dropout + dropout=dropout, + glu_type=glu_type ) for i in range(num_layers) ] diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index 03527ef78..f0e359e52 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -128,6 +128,19 @@ def forward(self, x): return out * gate +class ATanGLU(nn.Module): + # ArcTan-Applies the gated linear unit function. + def __init__(self, dim=-1): + super().__init__() + self.dim = dim + + def forward(self, x): + # out, gate = x.chunk(2, dim=self.dim) + # Using torch.split instead of chunk for ONNX export compatibility. + out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim) + return out * torch.atan(gate) + + class KaimingNormalConv1d(torch.nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -160,6 +173,9 @@ def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0., act='gel elif self.act == 'swiglu': self.act_fn = SwiGLU() filter_size_1 = filter_size * 2 + elif self.act == 'atanglu': + self.act_fn = ATanGLU() + filter_size_1 = filter_size * 2 else: raise ValueError(f'{act} is not a valid activation') self.ffn_1 = nn.Conv1d(hidden_size, filter_size_1, kernel_size, padding=kernel_size // 2) From f6af252039344f1ff052d8f199c260ea5fc81dde Mon Sep 17 00:00:00 2001 From: yxlllc Date: Mon, 30 Jun 2025 23:53:25 +0800 Subject: [PATCH 14/41] support one smooth_kernel --- utils/binarizer_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/utils/binarizer_utils.py b/utils/binarizer_utils.py index df5216429..940d14a5d 100644 --- a/utils/binarizer_utils.py +++ b/utils/binarizer_utils.py @@ -214,13 +214,16 @@ def __init__(self, kernel_size): super().__init__( in_channels=1, out_channels=1, - kernel_size=kernel_size, + kernel_size=max(kernel_size, 1), bias=False, padding='same', padding_mode='replicate' ) - smooth_kernel = torch.sin(torch.from_numpy( - np.linspace(0, 1, kernel_size).astype(np.float32) * np.pi - )) - smooth_kernel /= smooth_kernel.sum() + if kernel_size > 1: + smooth_kernel = torch.sin(torch.from_numpy( + np.linspace(0, 1, kernel_size).astype(np.float32) * np.pi + )) + smooth_kernel /= smooth_kernel.sum() + else: + smooth_kernel = torch.tensor([1.0], dtype=torch.float32) self.weight.data = smooth_kernel[None, None] From 2a10f2719481ea67e50c3addfada9fed595727e0 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Thu, 17 Jul 2025 01:26:20 +0800 Subject: [PATCH 15/41] optimize smooth width --- configs/acoustic.yaml | 8 ++++---- configs/variance.yaml | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 435ddf9a4..f85f28238 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -42,10 +42,10 @@ spec_max: [0] mel_vmin: -14. mel_vmax: 4. mel_base: 'e' -energy_smooth_width: 0.12 -breathiness_smooth_width: 0.12 -voicing_smooth_width: 0.12 -tension_smooth_width: 0.12 +energy_smooth_width: 0.06 +breathiness_smooth_width: 0.06 +voicing_smooth_width: 0.06 +tension_smooth_width: 0.06 use_lang_id: false num_lang: 1 diff --git a/configs/variance.yaml b/configs/variance.yaml index 9795293d0..c9f525461 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -76,18 +76,18 @@ pitch_prediction_args: energy_db_min: -96.0 energy_db_max: -12.0 -energy_smooth_width: 0.12 +energy_smooth_width: 0.06 breathiness_db_min: -96.0 breathiness_db_max: -20.0 -breathiness_smooth_width: 0.12 +breathiness_smooth_width: 0.06 voicing_db_min: -96.0 voicing_db_max: -12.0 -voicing_smooth_width: 0.12 +voicing_smooth_width: 0.06 tension_logit_min: -10.0 tension_logit_max: 10.0 -tension_smooth_width: 0.12 +tension_smooth_width: 0.06 variances_prediction_args: total_repeat_bins: 48 From 277c08210d7650dbb40df76fd301796e7d0db987 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Sat, 16 Aug 2025 20:39:56 +0800 Subject: [PATCH 16/41] fix --- modules/optimizer/chained_optimizer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/modules/optimizer/chained_optimizer.py b/modules/optimizer/chained_optimizer.py index b123f58e8..655680bf8 100644 --- a/modules/optimizer/chained_optimizer.py +++ b/modules/optimizer/chained_optimizer.py @@ -1,3 +1,4 @@ +import torch from torch import Tensor from torch.optim import Optimizer from torch.optim.optimizer import ParamsT @@ -87,9 +88,14 @@ def _copy_lr_to_optimizers(self) -> None: self.optimizers[optimizer_idx].param_groups[param_group_idx]["lr"] = param_group["lr"] def step(self, closure=None) -> None: + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() self._copy_lr_to_optimizers() for opt in self.optimizers: - opt.step(closure) + opt.step(closure=None) + return loss def add_param_group(self, param_group: Dict[str, Any]) -> None: super().add_param_group(param_group) From c4b01f665f30954a5889673277e3029305d31728 Mon Sep 17 00:00:00 2001 From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com> Date: Tue, 15 Apr 2025 23:03:41 +0800 Subject: [PATCH 17/41] Fix RoPE cache issue about 'find_unused_parameters' when DDP training. (#244) * Fix issue about 'find_unused_parameters' when DDP training. * annotation * slim * Fix issue about 'find_unused_parameters' when DDP training. annotation slim * Update rotary_embedding_torch.py --- modules/commons/rotary_embedding_torch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py index 4efcb514f..e0ab05f2a 100644 --- a/modules/commons/rotary_embedding_torch.py +++ b/modules/commons/rotary_embedding_torch.py @@ -306,7 +306,10 @@ def forward( exists(self.cached_freqs) and \ (offset + seq_len) <= self.cached_freqs_seq_len ): - return self.cached_freqs[offset:(offset + seq_len)].detach() + freqs = self.cached_freqs[offset:(offset + seq_len)].detach() + # Fix issue about 'find_unused_parameters' when DDP training.(#244) + freqs = freqs + 0. * self.freqs.sum() + return freqs freqs = self.freqs From 8e71afb044b74bb88a363bbcc213d3c3108ab79d Mon Sep 17 00:00:00 2001 From: yxlllc Date: Sat, 13 Sep 2025 17:43:24 +0800 Subject: [PATCH 18/41] fix and optimize --- configs/acoustic.yaml | 2 +- configs/templates/config_acoustic.yaml | 1 + modules/backbones/lynxnet.py | 4 +- modules/backbones/lynxnet2.py | 4 +- modules/optimizer/muon.py | 53 +++++++++++++------------- 5 files changed, 33 insertions(+), 31 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index f85f28238..0d7250d00 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -70,7 +70,7 @@ sampling_algorithm: euler sampling_steps: 20 diff_accelerator: ddim diff_speedup: 10 -hidden_size: 256 +hidden_size: 384 backbone_type: 'lynxnet2' backbone_args: num_channels: 1024 diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 624e959b1..669c3a6da 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -77,6 +77,7 @@ T_start: 0.4 T_start_infer: 0.4 K_step: 300 K_step_infer: 300 +hidden_size: 384 backbone_type: 'lynxnet2' backbone_args: num_channels: 1024 diff --git a/modules/backbones/lynxnet.py b/modules/backbones/lynxnet.py index 6d229aff8..9f5c6a383 100644 --- a/modules/backbones/lynxnet.py +++ b/modules/backbones/lynxnet.py @@ -74,7 +74,7 @@ def forward(self, x, conditioner, diffusion_step, front_cond_inject=False): class LYNXNet(nn.Module): def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=2, kernel_size=31, - activation='PReLU', dropout=0.0, strong_cond=False): + activation='PReLU', dropout_rate=0.0, strong_cond=False): """ LYNXNet(Linear Gated Depthwise Separable Convolution Network) TIPS:You can control the style of the generated results by modifying the 'activation', @@ -100,7 +100,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio expansion_factor=expansion_factor, kernel_size=kernel_size, activation=activation, - dropout=dropout + dropout=dropout_rate ) for i in range(num_layers) ] diff --git a/modules/backbones/lynxnet2.py b/modules/backbones/lynxnet2.py index 9d6d15d37..76e8580b5 100644 --- a/modules/backbones/lynxnet2.py +++ b/modules/backbones/lynxnet2.py @@ -39,7 +39,7 @@ def forward(self, x): class LYNXNet2(nn.Module): def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31, - dropout=0.0, use_conditioner_cache=False, glu_type='swiglu'): + dropout_rate=0.0, use_conditioner_cache=False, glu_type='swiglu'): """ LYNXNet2(Linear Gated Depthwise Separable Convolution Network Version 2) """ @@ -65,7 +65,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio dim=num_channels, expansion_factor=expansion_factor, kernel_size=kernel_size, - dropout=dropout, + dropout=dropout_rate, glu_type=glu_type ) for i in range(num_layers) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 3ffe33a01..caf3a45e4 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -37,25 +37,25 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor """ assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng a, b, c = (3.4445, -4.7750, 2.0315) - if use_bf16: - X = G.bfloat16() - else: - X = G.float() - if G.size(-2) > G.size(-1): - X = X.mT + + X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) # Ensure spectral norm is at most 1 X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) # Perform the NS iterations - for _ in range(steps): - A = X @ X.mT - B = torch.baddbmm(A, A, A, beta=b, alpha=c) - X = torch.baddbmm(X, B, X, beta=a, alpha=1) - - if G.size(-2) > G.size(-1): - X = X.mT - return X.to(G) + if X.size(-2) < X.size(-1): + for _ in range(steps): + A = torch.bmm(X, X.mT) + A = torch.baddbmm(A, A, A, beta=b, alpha=c) + X = torch.baddbmm(X, A, X, beta=a, alpha=1) + else: + for _ in range(steps): + A = torch.bmm(X.mT, X) + A = torch.baddbmm(A, A, A, beta=b, alpha=c) + X = torch.baddbmm(X, X, A, beta=a, alpha=1) + + return X class Muon(torch.optim.Optimizer): @@ -85,7 +85,7 @@ def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=Tr defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) super().__init__(params, defaults) self.bf16_support_map = get_bf16_support_map() - + @torch.no_grad() def step(self, closure=None): for group in self.param_groups: @@ -95,28 +95,29 @@ def step(self, closure=None): state = self.state[p] if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(g) - buf: Tensor = state["momentum_buffer"] key = (p.shape, p.device, p.dtype) if key not in shape_groups: shape_groups[key] = {"params": [], "grads": [], "buffers": []} shape_groups[key]["params"].append(p) shape_groups[key]["grads"].append(g) - shape_groups[key]["buffers"].append(buf) + shape_groups[key]["buffers"].append(state["momentum_buffer"]) for key in shape_groups: group_data = shape_groups[key] - g = torch.stack(group_data["grads"]) - buf = torch.stack(group_data["buffers"]) - buf.lerp_(g, 1 - group["momentum"]) - g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf + p, g, buf, m = group_data["params"], group_data["grads"], group_data["buffers"], group["momentum"] + torch._foreach_lerp_(buf, g, 1-m) + if group["nesterov"]: + torch._foreach_lerp_(g, buf, m) + g = torch.stack(g) + else: + g = torch.stack(buf) + original_shape = g.shape if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) use_bf16 = self.bf16_support_map.get(g.device, False) g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16) - for i, p in enumerate(group_data["params"]): - if group["weight_decay"] > 0: - p.data.mul_(1 - group["lr"] * group["weight_decay"]) - p.data.add_(g[i].view_as(p), alpha=-group["lr"] * max(g[i].size()) ** 0.5) - self.state[p]["momentum_buffer"] = buf[i].clone() + if group["weight_decay"] > 0: + torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"]) + torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5) def get_params_for_muon(model) -> List[Parameter]: From 6df0ee977c3728f14cb79c2db8b19df30b23a0bf Mon Sep 17 00:00:00 2001 From: yxlllc Date: Fri, 3 Oct 2025 17:28:40 +0800 Subject: [PATCH 19/41] optimize --- configs/templates/config_variance.yaml | 4 ++-- configs/variance.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index fb33b46ef..3947928f2 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -66,7 +66,7 @@ tension_logit_max: 10.0 enc_ffn_kernel_size: 3 use_rope: true use_variance_scaling: true -hidden_size: 256 +hidden_size: 384 dur_prediction_args: arch: resnet hidden_size: 256 @@ -109,7 +109,7 @@ pitch_prediction_args: glu_type: 'atanglu' variances_prediction_args: - total_repeat_bins: 48 + total_repeat_bins: 72 # backbone_type: 'wavenet' # backbone_args: # num_layers: 10 diff --git a/configs/variance.yaml b/configs/variance.yaml index c9f525461..5a4799203 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -38,7 +38,7 @@ enc_ffn_kernel_size: 3 use_rope: true use_variance_scaling: true rel_pos: true -hidden_size: 256 +hidden_size: 384 dur_prediction_args: arch: resnet @@ -90,7 +90,7 @@ tension_logit_max: 10.0 tension_smooth_width: 0.06 variances_prediction_args: - total_repeat_bins: 48 + total_repeat_bins: 72 backbone_type: 'lynxnet2' backbone_args: num_layers: 6 From 47ad6f6be368f6dda66b3fc024c149de19003f95 Mon Sep 17 00:00:00 2001 From: yxlllc <33565655+yxlllc@users.noreply.github.com> Date: Tue, 11 Nov 2025 21:30:44 +0800 Subject: [PATCH 20/41] Stretch embed (#274) * Acoustic SR_embed (#270) * Acoustic SR_embed / Cosine annealing * del 'WarmupCosineSchedule' in config del 'WarmupCosineSchedule' in config * Fix the precision problem of 'StretchRegulator' in ONNX model * fix some odds and ends... * set 'use_stretch_embed' true on default * fix some odds and ends... * adjust * add stretch embed to variance models * fix * fix * fix * optimize * using lookup table for optimization * update --------- Co-authored-by: Kakaru <97896816+KakaruHayate@users.noreply.github.com> --- configs/acoustic.yaml | 1 + configs/templates/config_acoustic.yaml | 1 + configs/templates/config_variance.yaml | 1 + configs/variance.yaml | 1 + deployment/modules/fastspeech2.py | 22 ++++++++----- deployment/modules/toplevel.py | 16 +++++++--- modules/commons/common_layers.py | 2 +- modules/fastspeech/acoustic_encoder.py | 44 ++++++++++++++++++++------ modules/fastspeech/tts_modules.py | 9 +++--- modules/toplevel.py | 32 +++++++++++++++++-- utils/training_utils.py | 6 ++-- 11 files changed, 102 insertions(+), 33 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 0d7250d00..9bc978f26 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -64,6 +64,7 @@ timesteps: 1000 max_beta: 0.02 enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: true use_variance_scaling: true rel_pos: true sampling_algorithm: euler diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 669c3a6da..acbe25dff 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -71,6 +71,7 @@ augmentation_args: diffusion_type: reflow enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: true use_variance_scaling: true use_shallow_diffusion: true T_start: 0.4 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 3947928f2..58a4d3a6d 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -65,6 +65,7 @@ tension_logit_max: 10.0 enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: false use_variance_scaling: true hidden_size: 384 dur_prediction_args: diff --git a/configs/variance.yaml b/configs/variance.yaml index 5a4799203..6bd86cfad 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -36,6 +36,7 @@ predict_tension: false enc_ffn_kernel_size: 3 use_rope: true +use_stretch_embed: false use_variance_scaling: true rel_pos: true hidden_size: 384 diff --git a/deployment/modules/fastspeech2.py b/deployment/modules/fastspeech2.py index bb9d6a7bb..b22590c2a 100644 --- a/deployment/modules/fastspeech2.py +++ b/deployment/modules/fastspeech2.py @@ -73,6 +73,7 @@ def forward( txt_embed = self.txt_embed(tokens) durations = durations * (tokens > 0) mel2ph = self.lr(durations) + _mel2ph = mel2ph f0 = f0 * (mel2ph > 0) mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size'])) if self.use_variance_scaling: @@ -92,6 +93,14 @@ def forward( encoded = F.pad(encoded, (0, 0, 1, 0)) condition = torch.gather(encoded, 1, mel2ph) + if self.use_stretch_embed: + stretch = torch.round(1000 * self.sr(_mel2ph, durations)) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition) + condition += stretch_embed + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition) + condition += stretch_embed_rnn_out + if self.f0_embed_type == 'discrete': pitch = f0_to_coarse(f0) pitch_embed = self.pitch_embed(pitch) @@ -102,30 +111,27 @@ def forward( if self.use_variance_embeds: variance_embeds = torch.stack([ - self.variance_embeds[v_name](variances[v_name][:, :, None]) - * self.variance_scaling_factor[v_name] + self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_scaling_factor[v_name]) for v_name in self.variance_embed_list ], dim=-1).sum(-1) condition += variance_embeds if hparams['use_key_shift_embed']: if hasattr(self, 'frozen_key_shift'): - key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None]) + key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None] * self.variance_scaling_factor['key_shift']) else: gender = torch.clip(gender, min=-1., max=1.) gender_mask = (gender < 0.).float() key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min)) - key_shift_embed = self.key_shift_embed(key_shift[:, :, None]) - key_shift_embed *= self.variance_scaling_factor['key_shift'] + key_shift_embed = self.key_shift_embed(key_shift[:, :, None] * self.variance_scaling_factor['key_shift']) condition += key_shift_embed if hparams['use_speed_embed']: if velocity is not None: velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max) - speed_embed = self.speed_embed(velocity[:, :, None]) + speed_embed = self.speed_embed(velocity[:, :, None] * self.variance_scaling_factor['speed']) else: - speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None]) - speed_embed *= self.variance_scaling_factor['speed'] + speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None] * self.variance_scaling_factor['speed']) condition += speed_embed if hparams['use_spk_id']: diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 0d1752ef3..3043168ca 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -211,14 +211,22 @@ def forward_linguistic_encoder_phoneme(self, tokens, ph_dur, languages=None): def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None): return self.fs2.forward_dur_predictor(encoder_out, x_masks, ph_midi, spk_embed=spk_embed) - def forward_mel2x_gather(self, x_src, x_dur, x_dim=None): + def forward_mel2x_gather(self, x_src, x_dur, x_dim=None, check_stretch_embed=False): mel2x = self.lr(x_dur) + _mel2x = mel2x if x_dim is not None: x_src = F.pad(x_src, [0, 0, 1, 0]) mel2x = mel2x[..., None].repeat([1, 1, x_dim]) else: x_src = F.pad(x_src, [1, 0]) x_cond = torch.gather(x_src, 1, mel2x) + if self.use_stretch_embed and check_stretch_embed: + stretch = torch.round(1000 * self.sr(_mel2x, x_dur)) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(x_cond) + x_cond += stretch_embed + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(x_cond) + x_cond += stretch_embed_rnn_out return x_cond def forward_pitch_preprocess( @@ -226,7 +234,7 @@ def forward_pitch_preprocess( note_midi=None, note_rest=None, note_dur=None, note_glide=None, pitch=None, expr=None, retake=None, spk_embed=None ): - condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) + condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size, check_stretch_embed=True) if self.use_melody_encoder: if self.melody_encoder.use_glide_embed and note_glide is None: note_glide = torch.LongTensor([[0]]).to(encoder_out.device) @@ -280,7 +288,7 @@ def forward_variance_preprocess( self, encoder_out, ph_dur, pitch, variances: dict = None, retake=None, spk_embed=None ): - condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) + condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size, check_stretch_embed=True) if self.use_variance_scaling: variance_cond = condition + self.pitch_embed(pitch[:, :, None] / 12) else: @@ -290,7 +298,7 @@ def forward_variance_preprocess( for v_retake in (~retake).split(1, dim=2) ] variance_embeds = [ - self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks * self.variance_retake_scaling[v_name] + self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_retake_scaling[v_name]) * v_masks for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks) ] variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index f0e359e52..b6ebca235 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -325,6 +325,6 @@ def forward(self, x): half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] + emb = x.unsqueeze(-1) * emb.unsqueeze(0) emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index 90da9cf80..86aa535a3 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -5,8 +5,9 @@ from modules.commons.common_layers import ( NormalInitEmbedding as Embedding, XavierUniformInitLinear as Linear, + SinusoidalPosEmb, ) -from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur +from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur, StretchRegulator from utils.hparams import hparams from utils.phoneme_utils import PAD_INDEX @@ -18,6 +19,19 @@ def __init__(self, vocab_size): self.use_lang_id = hparams.get('use_lang_id', False) if self.use_lang_id: self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0) + + self.use_stretch_embed = hparams.get('use_stretch_embed', None) + assert self.use_stretch_embed is not None, "You may be loading an old version of the model checkpoint, which is incompatible with the new version due to some bug fixes. It is recommended to roll back to the old version (commit id: 6df0ee977c3728f14cb79c2db8b19df30b23a0bf)" + if self.use_stretch_embed: + self.sr = StretchRegulator() + self.stretch_embed = nn.Sequential( + SinusoidalPosEmb(hparams['hidden_size']), + nn.Linear(hparams['hidden_size'], hparams['hidden_size'] * 4), + nn.GELU(), + nn.Linear(hparams['hidden_size'] * 4, hparams['hidden_size']), + ) + self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True) + self.dur_embed = Linear(1, hparams['hidden_size']) self.encoder = FastSpeech2Encoder( hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'], @@ -84,20 +98,17 @@ def __init__(self, vocab_size): def forward_variance_embedding(self, condition, key_shift=None, speed=None, **variances): if self.use_variance_embeds: variance_embeds = torch.stack([ - self.variance_embeds[v_name](variances[v_name][:, :, None]) - * self.variance_scaling_factor[v_name] + self.variance_embeds[v_name](variances[v_name][:, :, None] * self.variance_scaling_factor[v_name]) for v_name in self.variance_embed_list ], dim=-1).sum(-1) condition += variance_embeds if self.use_key_shift_embed: - key_shift_embed = self.key_shift_embed(key_shift[:, :, None]) - key_shift_embed *= self.variance_scaling_factor['key_shift'] + key_shift_embed = self.key_shift_embed(key_shift[:, :, None] * self.variance_scaling_factor['key_shift']) condition += key_shift_embed if self.use_speed_embed: - speed_embed = self.speed_embed(speed[:, :, None]) - speed_embed *= self.variance_scaling_factor['speed'] + speed_embed = self.speed_embed(speed[:, :, None] * self.variance_scaling_factor['speed']) condition += speed_embed return condition @@ -109,11 +120,11 @@ def forward( **kwargs ): txt_embed = self.txt_embed(txt_tokens) - dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]).float() + dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]) if self.use_variance_scaling: - dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None])) + dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None].float())) else: - dur_embed = self.dur_embed(dur[:, :, None]) + dur_embed = self.dur_embed(dur[:, :, None].float()) if self.use_lang_id: lang_embed = self.lang_embed(languages) extra_embed = dur_embed + lang_embed @@ -125,6 +136,19 @@ def forward( mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]]) condition = torch.gather(encoder_out, 1, mel2ph_) + if self.use_stretch_embed: + stretch = torch.round(1000 * self.sr(mel2ph, dur)) + if self.training and stretch.numel() > 1000: + # construct a phoneme stretching index lookup table with a total of 1001 indexes (0~1000) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition) + else: + stretch_embed = self.stretch_embed(stretch) + condition += stretch_embed + self.stretch_embed_rnn.flatten_parameters() + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition) + condition = condition + stretch_embed_rnn_out + if self.use_spk_id: spk_mix_embed = kwargs.get('spk_mix_embed') if spk_mix_embed is not None: diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index 16b358e3c..882ebc115 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -347,14 +347,13 @@ def forward(self, mel2ph, dur=None): """ if dur is None: dur = mel2ph_to_dur(mel2ph, mel2ph.max()) - dur = F.pad(dur, [1, 0], value=1) # Avoid dividing by zero + dur = torch.cat([torch.ones_like(dur[:, :1]), dur], dim=1) # Avoid dividing by zero mel2dur = torch.gather(dur, 1, mel2ph) bound_mask = torch.gt(mel2ph[:, 1:], mel2ph[:, :-1]) - bound_mask = F.pad(bound_mask, [0, 1], mode='constant', value=True) - stretch_delta = 1 - bound_mask * mel2dur - stretch_delta = F.pad(stretch_delta, [1, -1], mode='constant', value=0) + stretch_delta = 1 - bound_mask * mel2dur[:, :-1] + stretch_delta = F.pad(stretch_delta, [1, 0]) stretch_denorm = torch.cumsum(stretch_delta, dim=1) - stretch = stretch_denorm / mel2dur + stretch = stretch_denorm.float() / mel2dur return stretch * (mel2ph > 0) diff --git a/modules/toplevel.py b/modules/toplevel.py index 1ad5aa4b7..3c3129665 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -10,7 +10,8 @@ from modules.aux_decoder import AuxDecoderAdaptor from modules.commons.common_layers import ( XavierUniformInitLinear as Linear, - NormalInitEmbedding as Embedding + NormalInitEmbedding as Embedding, + SinusoidalPosEmb ) from modules.core import ( GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion, @@ -18,7 +19,7 @@ ) from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic from modules.fastspeech.param_adaptor import ParameterAdaptorModule -from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator +from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator, StretchRegulator from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder from utils.hparams import hparams @@ -133,6 +134,18 @@ def __init__(self, vocab_size): self.predict_dur = hparams['predict_dur'] self.predict_pitch = hparams['predict_pitch'] + self.use_stretch_embed = hparams.get('use_stretch_embed', None) + assert self.use_stretch_embed is not None, "You may be loading an old version of the model checkpoint, which is incompatible with the new version due to some bug fixes. It is recommended to roll back to the old version (commit id: 6df0ee977c3728f14cb79c2db8b19df30b23a0bf)" + if self.use_stretch_embed and (self.predict_pitch or self.predict_variances): + self.sr = StretchRegulator() + self.stretch_embed = nn.Sequential( + SinusoidalPosEmb(hparams['hidden_size']), + nn.Linear(hparams['hidden_size'], hparams['hidden_size'] * 4), + nn.GELU(), + nn.Linear(hparams['hidden_size'] * 4, hparams['hidden_size']), + ) + self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True) + self.use_spk_id = hparams['use_spk_id'] if self.use_spk_id: self.spk_embed = Embedding(hparams['num_spk'], hparams['hidden_size']) @@ -255,6 +268,19 @@ def forward( mel2ph_ = mel2ph[..., None].repeat([1, 1, hparams['hidden_size']]) condition = torch.gather(encoder_out, 1, mel2ph_) + if self.use_stretch_embed: + stretch = torch.round(1000 * self.sr(mel2ph, ph_dur)) + if self.training and stretch.numel() > 1000: + # construct a phoneme stretching index lookup table with a total of 1001 indexes (0~1000) + table = self.stretch_embed(torch.arange(0, 1001, device=stretch.device)) + stretch_embed = torch.index_select(table, 0, stretch.view(-1).long()).view_as(condition) + else: + stretch_embed = self.stretch_embed(stretch) + condition += stretch_embed + self.stretch_embed_rnn.flatten_parameters() + stretch_embed_rnn_out, _ = self.stretch_embed_rnn(condition) + condition = condition + stretch_embed_rnn_out + if self.use_spk_id: condition += spk_embed @@ -326,7 +352,7 @@ def forward( if variance_retake is not None: variance_embeds = [ - self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None] * self.variance_retake_scaling[v_name] + self.variance_embeds[v_name](v_input[:, :, None] * self.variance_retake_scaling[v_name]) * ~variance_retake[v_name][:, :, None] for v_name, v_input in zip(self.variance_prediction_list, variance_inputs) ] var_cond += torch.stack(variance_embeds, dim=-1).sum(-1) diff --git a/utils/training_utils.py b/utils/training_utils.py index 26d24eec5..e906f7721 100644 --- a/utils/training_utils.py +++ b/utils/training_utils.py @@ -54,16 +54,18 @@ class WarmupCosineSchedule(LambdaLR): `eta_min` (default=0.0) corresponds to the minimum learning rate reached by the scheduler. """ - def __init__(self, optimizer, warmup_steps, t_total, eta_min=0.0, cycles=.5, last_epoch=-1): + def __init__(self, optimizer, warmup_steps, t_total, warmup_min=0.0, eta_min=0.0, cycles=.5, last_epoch=-1): self.warmup_steps = warmup_steps self.t_total = t_total self.eta_min = eta_min self.cycles = cycles + self.warmup_min = warmup_min super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) def lr_lambda(self, step): if step < self.warmup_steps: - return step / max(1.0, self.warmup_steps) + progress = step / max(1.0, self.warmup_steps) + return self.warmup_min + progress * (1.0 - self.warmup_min) # progress after warmup progress = (step - self.warmup_steps) / max(1, self.t_total - self.warmup_steps) return max(self.eta_min, 0.5 * (1. + math.cos(math.pi * self.cycles * 2.0 * progress))) From 1638ccdab34ad92d191c920393eb4364d0ef8731 Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Mon, 24 Nov 2025 21:59:26 +0800 Subject: [PATCH 21/41] refactor RoPE (#276) refactor RoPE --- modules/commons/rotary_embedding_torch.py | 352 ++++------------------ 1 file changed, 53 insertions(+), 299 deletions(-) diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py index e0ab05f2a..9af4a2771 100644 --- a/modules/commons/rotary_embedding_torch.py +++ b/modules/commons/rotary_embedding_torch.py @@ -1,323 +1,77 @@ from __future__ import annotations -from math import pi, log - import torch -from torch.amp import autocast -from torch.nn import Module, ModuleList -from torch import nn, einsum, broadcast_tensors, Tensor - +from torch import nn, einsum, Tensor +from torch.nn import Module from einops import rearrange, repeat -from typing import Literal - -# helper functions - -def exists(val): - return val is not None - -def default(val, d): - return val if exists(val) else d - -# broadcat, as tortoise-tts was using it - -def broadcat(tensors, dim = -1): - broadcasted_tensors = broadcast_tensors(*tensors) - return torch.cat(broadcasted_tensors, dim = dim) - -def slice_at_dim(t, dim_slice: slice, *, dim): - dim += (t.ndim if dim < 0 else 0) - colons = [slice(None)] * t.ndim - colons[dim] = dim_slice - return t[tuple(colons)] - -# rotary embedding helper functions - -def rotate_half(x): - x = rearrange(x, '... (d r) -> ... d r', r = 2) - x1, x2 = x.unbind(dim = -1) - x = torch.stack((-x2, x1), dim = -1) - return rearrange(x, '... d r -> ... (d r)') - -@autocast('cuda', enabled = False) -def apply_rotary_emb( - freqs, - t, - start_index = 0, - scale = 1., - seq_dim = -2, - freqs_seq_dim = None -): - dtype = t.dtype - if not exists(freqs_seq_dim): - if freqs.ndim == 2 or t.ndim == 3: - freqs_seq_dim = 0 +def rotate_half(x: Tensor, interleaved=True) -> Tensor: + if not interleaved: + # x_half1, x_half2 = x.chunk(2, dim=-1) + # Using torch.split instead of chunk for ONNX export compatibility. + x1, x2 = torch.split(x, x.size(-1) // 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x = rearrange(x, '... (d r) -> ... d r', r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, '... d r -> ... (d r)') - if t.ndim == 3 or exists(freqs_seq_dim): - seq_len = t.shape[seq_dim] - freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) +def apply_rotary_emb(freqs: Tensor, t: Tensor, interleaved=True) -> Tensor: rot_dim = freqs.shape[-1] - end_index = start_index + rot_dim + t_to_rotate = t[..., :rot_dim] + t_pass_through = t[..., rot_dim:] + + t_rotated = (t_to_rotate * freqs.cos()) + (rotate_half(t_to_rotate, interleaved) * freqs.sin()) + + return torch.cat((t_rotated, t_pass_through), dim=-1) - assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' - - # Split t into three parts: left, middle (to be transformed), and right - t_left = t[..., :start_index] - t_middle = t[..., start_index:end_index] - t_right = t[..., end_index:] - - # Apply rotary embeddings without modifying t in place - t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale) - - out = torch.cat((t_left, t_transformed, t_right), dim=-1) - - return out.type(dtype) - -# learned rotation helpers - -def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): - if exists(freq_ranges): - rotations = einsum('..., f -> ... f', rotations, freq_ranges) - rotations = rearrange(rotations, '... r f -> ... (r f)') - - rotations = repeat(rotations, '... n -> ... (n r)', r = 2) - return apply_rotary_emb(rotations, t, start_index = start_index) - -# classes class RotaryEmbedding(Module): def __init__( self, dim, - custom_freqs: Tensor | None = None, - freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang', - theta = 10000, - max_freq = 10, - num_freqs = 1, - learned_freq = False, - use_xpos = False, - xpos_scale_base = 512, - interpolate_factor = 1., - theta_rescale_factor = 1., - seq_before_head_dim = False, - cache_if_possible = True, - cache_max_seq_len = 8192 + theta=10000, + precompute_len=8192, + cache_max_seq_len=8192, + interleaved: bool = True ): super().__init__() - # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning - # has some connection to NTK literature - # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - - theta *= theta_rescale_factor ** (dim / (dim - 2)) - - self.freqs_for = freqs_for + self.interleaved = interleaved - if exists(custom_freqs): - freqs = custom_freqs - elif freqs_for == 'lang': - freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) - elif freqs_for == 'pixel': - freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi - elif freqs_for == 'constant': - freqs = torch.ones(num_freqs).float() + inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) - self.cache_if_possible = cache_if_possible - self.cache_max_seq_len = cache_max_seq_len + self._cache_max_seq_len = max(precompute_len, cache_max_seq_len) + self._precomputed_len = precompute_len - self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.register_buffer('cached_freqs', None, persistent=True) self.cached_freqs_seq_len = 0 + + if self._precomputed_len > 0: + self._precompute_cache(self._precomputed_len) - self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) - - self.learned_freq = learned_freq - - # dummy for device - - self.register_buffer('dummy', torch.tensor(0), persistent = False) - - # default sequence dimension - - self.seq_before_head_dim = seq_before_head_dim - self.default_seq_dim = -3 if seq_before_head_dim else -2 - - # interpolation factors - - assert interpolate_factor >= 1. - self.interpolate_factor = interpolate_factor - - # xpos - - self.use_xpos = use_xpos - - if not use_xpos: - return - - scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) - self.scale_base = xpos_scale_base - - self.register_buffer('scale', scale, persistent = False) - self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) - self.cached_scales_seq_len = 0 - - # add apply_rotary_emb as static method - - self.apply_rotary_emb = staticmethod(apply_rotary_emb) - - @property - def device(self): - return self.dummy.device - - def get_seq_pos(self, seq_len, device, dtype, offset = 0): - return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor - - def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, scale = None): - seq_dim = default(seq_dim, self.default_seq_dim) - - assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' - - device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] - - seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset) - - freqs = self.forward(seq, seq_len = seq_len, offset = offset) - - if seq_dim == -3: - freqs = rearrange(freqs, 'n d -> n 1 d') - - return apply_rotary_emb(freqs, t, scale = default(scale, 1.), seq_dim = seq_dim) - - def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): - dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim) - - q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] - assert q_len <= k_len - - q_scale = k_scale = 1. - - if self.use_xpos: - seq = self.get_seq_pos(k_len, dtype = dtype, device = device) - - q_scale = self.get_scale(seq[-q_len:]).type(dtype) - k_scale = self.get_scale(seq).type(dtype) - - rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset) - rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1) - - rotated_q = rotated_q.type(q.dtype) - rotated_k = rotated_k.type(k.dtype) - - return rotated_q, rotated_k - - def rotate_queries_and_keys(self, q, k, seq_dim = None): - seq_dim = default(seq_dim, self.default_seq_dim) - - assert self.use_xpos - device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] - - seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) - - freqs = self.forward(seq, seq_len = seq_len) - scale = self.get_scale(seq, seq_len = seq_len).to(dtype) - - if seq_dim == -3: - freqs = rearrange(freqs, 'n d -> n 1 d') - scale = rearrange(scale, 'n d -> n 1 d') - - rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) - rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) - - rotated_q = rotated_q.type(q.dtype) - rotated_k = rotated_k.type(k.dtype) - - return rotated_q, rotated_k - - def get_scale( - self, - t: Tensor, - seq_len: int | None = None, - offset = 0 - ): - assert self.use_xpos - - should_cache = ( - self.cache_if_possible and - exists(seq_len) and - (offset + seq_len) <= self.cache_max_seq_len - ) - - if ( - should_cache and \ - exists(self.cached_scales) and \ - (seq_len + offset) <= self.cached_scales_seq_len - ): - return self.cached_scales[offset:(offset + seq_len)] - - scale = 1. - if self.use_xpos: - power = (t - len(t) // 2) / self.scale_base - scale = self.scale ** rearrange(power, 'n -> n 1') - scale = repeat(scale, 'n d -> n (d r)', r = 2) - - if should_cache and offset == 0: - self.cached_scales[:seq_len] = scale.detach() - self.cached_scales_seq_len = seq_len - - return scale - - def get_axial_freqs(self, *dims): - Colon = slice(None) - all_freqs = [] - - for ind, dim in enumerate(dims): - if self.freqs_for == 'pixel': - pos = torch.linspace(-1, 1, steps = dim, device = self.device) - else: - pos = torch.arange(dim, device = self.device) - - freqs = self.forward(pos, seq_len = dim) - - all_axis = [None] * len(dims) - all_axis[ind] = Colon - - new_axis_slice = (Ellipsis, *all_axis, Colon) - all_freqs.append(freqs[new_axis_slice]) - - all_freqs = broadcast_tensors(*all_freqs) - return torch.cat(all_freqs, dim = -1) - - @autocast('cuda', enabled = False) - def forward( - self, - t: Tensor, - seq_len: int | None = None, - offset = 0 - ): - should_cache = ( - self.cache_if_possible and - not self.learned_freq and - exists(seq_len) and - self.freqs_for != 'pixel' and - (offset + seq_len) <= self.cache_max_seq_len - ) - - if ( - should_cache and \ - exists(self.cached_freqs) and \ - (offset + seq_len) <= self.cached_freqs_seq_len - ): - freqs = self.cached_freqs[offset:(offset + seq_len)].detach() - # Fix issue about 'find_unused_parameters' when DDP training.(#244) - freqs = freqs + 0. * self.freqs.sum() - return freqs - - freqs = self.freqs + def _precompute_cache(self, seq_len: int): + seq = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = einsum('i, j -> i j', seq, self.inv_freq) + + if self.interleaved: + freqs = repeat(freqs, '... n -> ... (n r)', r=2) + else: + freqs = torch.cat((freqs, freqs), dim=-1) - freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) - freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + self.cached_freqs = freqs + self.cached_freqs_seq_len = seq_len - if should_cache and offset == 0: - self.cached_freqs[:seq_len] = freqs.detach() - self.cached_freqs_seq_len = seq_len + def forward(self, t: Tensor, seq_len: int) -> Tensor: + if self.cached_freqs is None or seq_len > self.cached_freqs_seq_len: + self._precompute_cache(seq_len) + + return self.cached_freqs[0: seq_len].detach() - return freqs + def rotate_queries_or_keys(self, t: Tensor) -> Tensor: + device, dtype, seq_len = t.device, t.dtype, t.shape[-2] + freqs = self.forward(t, seq_len=seq_len) + + return apply_rotary_emb(freqs.to(device=device, dtype=dtype), t, self.interleaved) From c315d38bd236fdc6ee5bab397b5649b5a6b92b26 Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Thu, 4 Dec 2025 00:42:45 +0800 Subject: [PATCH 22/41] NeoX style RoPE (#277) * refactor RoPE refactor RoPE * NeoX style RoPE * fix export ONNX model before RoPE refactor --- configs/acoustic.yaml | 1 + configs/templates/config_acoustic.yaml | 1 + configs/templates/config_variance.yaml | 1 + configs/variance.yaml | 1 + deployment/exporters/acoustic_exporter.py | 17 ++++++++++++++++- deployment/exporters/variance_exporter.py | 15 +++++++++++++++ modules/fastspeech/acoustic_encoder.py | 2 +- modules/fastspeech/tts_modules.py | 4 ++-- modules/fastspeech/variance_encoder.py | 4 ++-- 9 files changed, 40 insertions(+), 6 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 9bc978f26..935d6e160 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -64,6 +64,7 @@ timesteps: 1000 max_beta: 0.02 enc_ffn_kernel_size: 3 use_rope: true +rope_interleaved: false use_stretch_embed: true use_variance_scaling: true rel_pos: true diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index acbe25dff..9d63028f0 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -71,6 +71,7 @@ augmentation_args: diffusion_type: reflow enc_ffn_kernel_size: 3 use_rope: true +rope_interleaved: false use_stretch_embed: true use_variance_scaling: true use_shallow_diffusion: true diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 58a4d3a6d..40f4c532b 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -65,6 +65,7 @@ tension_logit_max: 10.0 enc_ffn_kernel_size: 3 use_rope: true +rope_interleaved: false use_stretch_embed: false use_variance_scaling: true hidden_size: 384 diff --git a/configs/variance.yaml b/configs/variance.yaml index 6bd86cfad..a819c1c43 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -36,6 +36,7 @@ predict_tension: false enc_ffn_kernel_size: 3 use_rope: true +rope_interleaved: false use_stretch_embed: false use_variance_scaling: true rel_pos: true diff --git a/deployment/exporters/acoustic_exporter.py b/deployment/exporters/acoustic_exporter.py index 849dae5db..e74373653 100644 --- a/deployment/exporters/acoustic_exporter.py +++ b/deployment/exporters/acoustic_exporter.py @@ -1,6 +1,7 @@ import json from pathlib import Path from typing import List, Union, Tuple, Dict +import warnings import onnx import onnxsim @@ -78,6 +79,7 @@ def __init__( self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()] if self.freeze_spk is not None: self.model.fs2.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1])) + self.rope_interleaved = hparams.get('rope_interleaved', None) def build_model(self) -> DiffSingerAcousticONNX: model = DiffSingerAcousticONNX( @@ -88,8 +90,21 @@ def build_model(self) -> DiffSingerAcousticONNX: for p in self.phoneme_dictionary.cross_lingual_phonemes }) ).eval().to(self.device) + if self.rope_interleaved is None: + warnings.warn( + "After RoPE is refactored, the checkpoint no longer contains relevant parameters. " + "(https://github.com/openvpi/DiffSinger/pull/276)" + "In order to export ONNX with behavior compatible with past checkpoints, " + "it will be set to 'strict=False', which will no longer check the validity of the checkpoint. " + "Please understand what you are doing.", + UserWarning, + stacklevel=2 + ) + strict=False + else: + strict=True load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps, - prefix_in_ckpt='model', strict=True, device=self.device) + prefix_in_ckpt='model', strict=strict, device=self.device) return model def export(self, path: Path): diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 82808ec08..69af991c4 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -1,6 +1,7 @@ import json from pathlib import Path from typing import Union, List, Tuple, Dict +import warnings import onnx import onnxsim @@ -81,6 +82,7 @@ def __init__( self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()] if self.freeze_spk is not None: self.model.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1])) + self.rope_interleaved = hparams.get('rope_interleaved', None) def build_model(self) -> DiffSingerVarianceONNX: model = DiffSingerVarianceONNX( @@ -90,6 +92,19 @@ def build_model(self) -> DiffSingerVarianceONNX: for p in self.phoneme_dictionary.cross_lingual_phonemes }) ).eval().to(self.device) + if self.rope_interleaved is None: + warnings.warn( + "After RoPE is refactored, the checkpoint no longer contains relevant parameters. " + "(https://github.com/openvpi/DiffSinger/pull/276)" + "In order to export ONNX with behavior compatible with past checkpoints, " + "it will be set to 'strict=False', which will no longer check the validity of the checkpoint. " + "Please understand what you are doing.", + UserWarning, + stacklevel=2 + ) + strict=False + else: + strict=True load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps, prefix_in_ckpt='model', strict=True, device=self.device) model.build_smooth_op(self.device) diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index 86aa535a3..868d383fd 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -38,7 +38,7 @@ def __init__(self, vocab_size): ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'], dropout=hparams['dropout'], num_heads=hparams['num_heads'], use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False), - use_rope=hparams.get('use_rope', False) + use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True) ) self.pitch_embed = Linear(1, hparams['hidden_size']) diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index 882ebc115..cc840aed3 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -369,14 +369,14 @@ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None): class FastSpeech2Encoder(nn.Module): def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, ffn_act='gelu', - dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False): + dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True): super().__init__() self.num_layers = num_layers embed_dim = self.hidden_size = hidden_size self.dropout = dropout self.use_pos_embed = use_pos_embed if use_pos_embed and use_rope: - rotary_embed = RotaryEmbedding(dim = embed_dim // num_heads) + rotary_embed = RotaryEmbedding(dim = embed_dim // num_heads, interleaved = rope_interleaved) else: rotary_embed = None self.layers = nn.ModuleList([ diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index 70edcebcb..ba6994c1e 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -33,7 +33,7 @@ def __init__(self, vocab_size): ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'], dropout=hparams['dropout'], num_heads=hparams['num_heads'], use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False), - use_rope=hparams.get('use_rope', False) + use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True) ) dur_hparams = hparams['dur_prediction_args'] @@ -127,7 +127,7 @@ def get_hparam(key): ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), ffn_act=get_hparam('ffn_act'), dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'), use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos'), - use_rope=get_hparam('use_rope') + use_rope=get_hparam('use_rope'), rope_interleaved=hparams.get('rope_interleaved', True) ) self.out_proj = Linear(hidden_size, hparams['hidden_size']) From a39677b4202aa4c20bba2327b5c5455a7458a8f5 Mon Sep 17 00:00:00 2001 From: yxlllc <33565655+yxlllc@users.noreply.github.com> Date: Thu, 4 Dec 2025 00:56:00 +0800 Subject: [PATCH 23/41] Optimized glu (#278) * optimize * Keep AtanGLU behavior unchanged during eval (#275) --------- Co-authored-by: Kakaru <97896816+KakaruHayate@users.noreply.github.com> --- modules/commons/common_layers.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index b6ebca235..0012b99c3 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -128,6 +128,22 @@ def forward(self, x): return out * gate +class ATanGLUFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, out, gate): + atan_gate = torch.atan(gate) + decay_out = out / gate.square().add(1.0) + ctx.save_for_backward(decay_out, atan_gate) + return out * atan_gate + + @staticmethod + def backward(ctx, grad_output): + decay_out, atan_gate = ctx.saved_tensors + grad_out_part = grad_output * atan_gate + grad_gate_part = grad_output * decay_out + return grad_out_part, grad_gate_part + + class ATanGLU(nn.Module): # ArcTan-Applies the gated linear unit function. def __init__(self, dim=-1): @@ -136,9 +152,12 @@ def __init__(self, dim=-1): def forward(self, x): # out, gate = x.chunk(2, dim=self.dim) - # Using torch.split instead of chunk for ONNX export compatibility. + # Using torch.split instead of chunk for ONNX export compatibility. out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim) - return out * torch.atan(gate) + if self.training: + return ATanGLUFunction.apply(out, gate) + else: + return out * torch.atan(gate) class KaimingNormalConv1d(torch.nn.Conv1d): From 31056972a08fdc9169152b6b60ad70e009ca9dbd Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Wed, 10 Dec 2025 23:31:44 +0800 Subject: [PATCH 24/41] fix export (#281) --- deployment/exporters/acoustic_exporter.py | 2 +- deployment/exporters/variance_exporter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deployment/exporters/acoustic_exporter.py b/deployment/exporters/acoustic_exporter.py index e74373653..259343f38 100644 --- a/deployment/exporters/acoustic_exporter.py +++ b/deployment/exporters/acoustic_exporter.py @@ -35,6 +35,7 @@ def __init__( self.lang_map: dict = self.build_lang_map() self.phoneme_dictionary = load_phoneme_dictionary() self.use_lang_id = hparams.get('use_lang_id', False) and len(self.phoneme_dictionary.cross_lingual_phonemes) > 0 + self.rope_interleaved = hparams.get('rope_interleaved', None) self.model = self.build_model() self.fs2_aux_cache_path = self.cache_dir / ( 'fs2_aux.onnx' if self.model.use_shallow_diffusion else 'fs2.onnx' @@ -79,7 +80,6 @@ def __init__( self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()] if self.freeze_spk is not None: self.model.fs2.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1])) - self.rope_interleaved = hparams.get('rope_interleaved', None) def build_model(self) -> DiffSingerAcousticONNX: model = DiffSingerAcousticONNX( diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 69af991c4..0f84918c5 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -35,6 +35,7 @@ def __init__( self.lang_map: dict = self.build_lang_map() self.phoneme_dictionary = load_phoneme_dictionary() self.use_lang_id = hparams.get('use_lang_id', False) and len(self.phoneme_dictionary.cross_lingual_phonemes) > 0 + self.rope_interleaved = hparams.get('rope_interleaved', None) self.model = self.build_model() self.linguistic_encoder_cache_path = self.cache_dir / 'linguistic.onnx' self.dur_predictor_cache_path = self.cache_dir / 'dur.onnx' @@ -82,7 +83,6 @@ def __init__( self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()] if self.freeze_spk is not None: self.model.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1])) - self.rope_interleaved = hparams.get('rope_interleaved', None) def build_model(self) -> DiffSingerVarianceONNX: model = DiffSingerVarianceONNX( From 32d7fa1e61f651815e218c3a9fb8f434a460db5d Mon Sep 17 00:00:00 2001 From: yxlllc Date: Mon, 12 Jan 2026 21:14:12 +0800 Subject: [PATCH 25/41] fix padding of rmvpe and vr inferences --- modules/hnsep/vr/layers.py | 10 ++++++-- modules/hnsep/vr/nets.py | 43 +++++++++++++++++++---------------- modules/pe/rmvpe/inference.py | 14 +++++++++--- 3 files changed, 43 insertions(+), 24 deletions(-) diff --git a/modules/hnsep/vr/layers.py b/modules/hnsep/vr/layers.py index bff9460a4..4b4a2ecb3 100644 --- a/modules/hnsep/vr/layers.py +++ b/modules/hnsep/vr/layers.py @@ -64,8 +64,14 @@ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=F # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) self.dropout = nn.Dropout2d(0.1) if dropout else None - def forward(self, x, skip=None): - x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + def forward(self, x, skip=None, fixed_length=True): + if fixed_length: + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + else: + _, _, h, w = x.size() + x = F.pad(x, (0, 1, 0, 1), mode='replicate') + x = F.interpolate(x, size=(2*h+1,2*w+1), mode='bilinear', align_corners=True) + x = x[:, :, :-1, :-1] if skip is not None: skip = crop_center(skip, x) diff --git a/modules/hnsep/vr/nets.py b/modules/hnsep/vr/nets.py index 58ca2fe07..f9da1d8c8 100644 --- a/modules/hnsep/vr/nets.py +++ b/modules/hnsep/vr/nets.py @@ -7,7 +7,7 @@ class BaseNet(nn.Module): - def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))): + def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6)), fixed_length=True): super(BaseNet, self).__init__() self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1) self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1) @@ -22,8 +22,10 @@ def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (1 self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1) self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm) self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1) + + self.fixed_length = fixed_length - def forward(self, x): + def __call__(self, x): e1 = self.enc1(x) e2 = self.enc2(e1) e3 = self.enc3(e2) @@ -32,21 +34,22 @@ def forward(self, x): h = self.aspp(e5) - h = self.dec4(h, e4) - h = self.dec3(h, e3) - h = self.dec2(h, e2) + h = self.dec4(h, e4, fixed_length=self.fixed_length) + h = self.dec3(h, e3, fixed_length=self.fixed_length) + h = self.dec2(h, e2, fixed_length=self.fixed_length) h = torch.cat([h, self.lstm_dec2(h)], dim=1) - h = self.dec1(h, e1) + h = self.dec1(h, e1, fixed_length=self.fixed_length) return h class CascadedNet(nn.Module): - def __init__(self, n_fft, hop_length, nout=32, nout_lstm=128, is_complex=False, is_mono=False): + def __init__(self, n_fft, hop_length, nout=32, nout_lstm=128, is_complex=False, is_mono=False, fixed_length=True): super(CascadedNet, self).__init__() self.n_fft = n_fft self.hop_length = hop_length + self.seg_length = 32 * hop_length self.is_complex = is_complex self.is_mono = is_mono self.register_buffer("window", torch.hann_window(n_fft), persistent=False) @@ -60,23 +63,23 @@ def __init__(self, n_fft, hop_length, nout=32, nout_lstm=128, is_complex=False, nin = nin // 2 self.stg1_low_band_net = nn.Sequential( - BaseNet(nin, nout // 2, self.nin_lstm // 2, nout_lstm), + BaseNet(nin, nout // 2, self.nin_lstm // 2, nout_lstm, fixed_length=fixed_length), layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0) ) self.stg1_high_band_net = BaseNet( - nin, nout // 4, self.nin_lstm // 2, nout_lstm // 2 + nin, nout // 4, self.nin_lstm // 2, nout_lstm // 2, fixed_length=fixed_length ) self.stg2_low_band_net = nn.Sequential( - BaseNet(nout // 4 + nin, nout, self.nin_lstm // 2, nout_lstm), + BaseNet(nout // 4 + nin, nout, self.nin_lstm // 2, nout_lstm, fixed_length=fixed_length), layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0) ) self.stg2_high_band_net = BaseNet( - nout // 4 + nin, nout // 2, self.nin_lstm // 2, nout_lstm // 2 + nout // 4 + nin, nout // 2, self.nin_lstm // 2, nout_lstm // 2, fixed_length=fixed_length ) self.stg3_full_band_net = BaseNet( - 3 * nout // 4 + nin, nout, self.nin_lstm, nout_lstm + 3 * nout // 4 + nin, nout, self.nin_lstm, nout_lstm, fixed_length=fixed_length ) self.out = nn.Conv2d(nout, nin, 1, bias=False) @@ -150,8 +153,8 @@ def audio2spec(self, x, use_pad=False): B, C, T = x.shape x = x.reshape(B * C, T) if use_pad: - n_frames = T // self.hop_length + 1 - T_pad = (32 * ((n_frames - 1) // 32 + 1) - 1) * self.hop_length - T + T1 = T + self.hop_length + T_pad = self.seg_length * ((T1 - 1) // self.seg_length + 1) - T1 nl_pad = T_pad // 2 // self.hop_length Tl_pad = nl_pad * self.hop_length x = F.pad(x, (Tl_pad, T_pad - Tl_pad)) @@ -161,7 +164,8 @@ def audio2spec(self, x, use_pad=False): hop_length=self.hop_length, return_complex=True, window=self.window, - pad_mode='constant') + pad_mode='constant' + ) spec = spec.reshape(B, C, spec.shape[-2], spec.shape[-1]) return spec @@ -175,10 +179,10 @@ def spec2audio(self, x): def predict_from_audio(self, x): B, C, T = x.shape x = x.reshape(B * C, T) - n_frames = T // self.hop_length + 1 - T_pad = (32 * (n_frames // 32 + 1) - 1) * self.hop_length - T + T1 = T + self.hop_length + T_pad = self.seg_length * ((T1 - 1) // self.seg_length + 1) - T1 nl_pad = T_pad // 2 // self.hop_length - Tl_pad = nl_pad * self.hop_length + Tl_pad = nl_pad * self.hop_length x = F.pad(x, (Tl_pad, T_pad - Tl_pad)) spec = torch.stft( x, @@ -186,7 +190,8 @@ def predict_from_audio(self, x): hop_length=self.hop_length, return_complex=True, window=self.window, - pad_mode='constant') + pad_mode='constant' + ) spec = spec.reshape(B, C, spec.shape[-2], spec.shape[-1]) mask = self.forward(spec) spec_pred = spec * mask diff --git a/modules/pe/rmvpe/inference.py b/modules/pe/rmvpe/inference.py index 816a3582b..745b88b7b 100644 --- a/modules/pe/rmvpe/inference.py +++ b/modules/pe/rmvpe/inference.py @@ -19,6 +19,8 @@ def __init__(self, model_path, hop_length=160): self.model = E2E0(4, 1, (2, 2)).eval().to(self.device) ckpt = torch.load(model_path, map_location=self.device) self.model.load_state_dict(ckpt['model'], strict=False) + self.hop_length = hop_length + self.seg_length = 32 * hop_length self.mel_extractor = MelSpectrogram( N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX ).to(self.device) @@ -26,7 +28,7 @@ def __init__(self, model_path, hop_length=160): @torch.no_grad() def mel2hidden(self, mel): n_frames = mel.shape[-1] - mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant') + mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflecr') hidden = self.model(mel) return hidden[:, :n_frames] @@ -47,9 +49,15 @@ def infer_from_audio(self, audio, sample_rate=16000, thred=0.03, use_viterbi=Fal self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128) self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.device) audio_res = self.resample_kernel[key_str](audio) + B, T = audio_res.shape + n_frames = T // self.hop_length + 1 + T1 = T + self.hop_length + T_pad = self.seg_length * ((T1 - 1) // self.seg_length + 1) - T1 + audio_res = F.pad(audio_res, (0, T_pad)) mel = self.mel_extractor(audio_res, center=True) - hidden = self.mel2hidden(mel) - f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi) + with torch.no_grad(): + hidden = self.model(mel) + f0 = self.decode(hidden[:, :n_frames], thred=thred, use_viterbi=use_viterbi) return f0 def get_pitch( From 494d5d8ab27160bbbb62dd62a11fbcdfc51111e1 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Mon, 12 Jan 2026 21:27:15 +0800 Subject: [PATCH 26/41] fix --- modules/pe/rmvpe/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/pe/rmvpe/inference.py b/modules/pe/rmvpe/inference.py index 745b88b7b..f3b8ad50e 100644 --- a/modules/pe/rmvpe/inference.py +++ b/modules/pe/rmvpe/inference.py @@ -28,7 +28,7 @@ def __init__(self, model_path, hop_length=160): @torch.no_grad() def mel2hidden(self, mel): n_frames = mel.shape[-1] - mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflecr') + mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflect') hidden = self.model(mel) return hidden[:, :n_frames] From 1c184abf2b2ad14e3ce55c7d3d81bcc8367ad0e0 Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Mon, 12 Jan 2026 21:59:00 +0800 Subject: [PATCH 27/41] Merge pull request #288 from KakaruHayate/fix_rope fix export old variance model weight / load ckpt for finetune --- deployment/exporters/variance_exporter.py | 2 +- modules/commons/rotary_embedding_torch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 0f84918c5..49f63d8e3 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -106,7 +106,7 @@ def build_model(self) -> DiffSingerVarianceONNX: else: strict=True load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps, - prefix_in_ckpt='model', strict=True, device=self.device) + prefix_in_ckpt='model', strict=strict, device=self.device) model.build_smooth_op(self.device) return model diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py index 9af4a2771..ea74b208f 100644 --- a/modules/commons/rotary_embedding_torch.py +++ b/modules/commons/rotary_embedding_torch.py @@ -46,7 +46,7 @@ def __init__( self._cache_max_seq_len = max(precompute_len, cache_max_seq_len) self._precomputed_len = precompute_len - self.register_buffer('cached_freqs', None, persistent=True) + self.register_buffer('cached_freqs', None, persistent=False) self.cached_freqs_seq_len = 0 if self._precomputed_len > 0: From 9ce87119f7bfaf72de452218e0b8fb6148587c65 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Mon, 12 Jan 2026 22:01:50 +0800 Subject: [PATCH 28/41] Keep T dim to avoid confusing (cherry picked from commit d8b38e66d8a59b9eafa650f3363ed8864a58766d) --- deployment/exporters/acoustic_exporter.py | 4 ++-- deployment/exporters/variance_exporter.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deployment/exporters/acoustic_exporter.py b/deployment/exporters/acoustic_exporter.py index ae7f917d3..51b3a1d12 100644 --- a/deployment/exporters/acoustic_exporter.py +++ b/deployment/exporters/acoustic_exporter.py @@ -354,8 +354,8 @@ def _perform_spk_mix(self, spk_mix: Dict[str, float]): spk_mix_value_N /= spk_mix_value_sum # normalize spk_mix_embed = torch.sum( self.model.fs2.spk_embed(spk_mix_id_N) * spk_mix_value_N.unsqueeze(2), # => [1, N, H] - dim=1, keepdim=False - ) # => [1, H] + dim=1, keepdim=True + ) # => [1, 1, H] return spk_mix_embed def _optimize_fs2_aux_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto: diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 49f63d8e3..959ab0436 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -653,8 +653,8 @@ def _perform_spk_mix(self, spk_mix: Dict[str, float]): spk_mix_value_N /= spk_mix_value_sum # normalize spk_mix_embed = torch.sum( self.model.spk_embed(spk_mix_id_N) * spk_mix_value_N.unsqueeze(2), # => [1, N, H] - dim=1, keepdim=False - ) # => [1, H] + dim=1, keepdim=True + ) # => [1, 1, H] return spk_mix_embed def _optimize_linguistic_graph(self, linguistic: onnx.ModelProto) -> onnx.ModelProto: From 9bb3a9fe8a267aba5e474e5e8cb71a415fc41ef8 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 13 Jan 2026 15:51:25 +0800 Subject: [PATCH 29/41] Compute RoPE cache in init and make it read-only (cherry picked from commit 4e298fe7e7178810660a6d53eb625acc0a549a40) --- modules/commons/rotary_embedding_torch.py | 50 ++++++++--------------- modules/fastspeech/tts_modules.py | 2 +- scripts/export.py | 2 +- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py index ea74b208f..1a1fa193e 100644 --- a/modules/commons/rotary_embedding_torch.py +++ b/modules/commons/rotary_embedding_torch.py @@ -1,8 +1,7 @@ -from __future__ import annotations import torch -from torch import nn, einsum, Tensor -from torch.nn import Module from einops import rearrange, repeat +from torch import einsum, Tensor +from torch.nn import Module def rotate_half(x: Tensor, interleaved=True) -> Tensor: @@ -22,56 +21,43 @@ def apply_rotary_emb(freqs: Tensor, t: Tensor, interleaved=True) -> Tensor: rot_dim = freqs.shape[-1] t_to_rotate = t[..., :rot_dim] t_pass_through = t[..., rot_dim:] - + t_rotated = (t_to_rotate * freqs.cos()) + (rotate_half(t_to_rotate, interleaved) * freqs.sin()) - + return torch.cat((t_rotated, t_pass_through), dim=-1) class RotaryEmbedding(Module): def __init__( - self, - dim, - theta=10000, - precompute_len=8192, - cache_max_seq_len=8192, - interleaved: bool = True + self, + dim, + theta=10000, + max_seq_len=8192, + interleaved: bool = True ): super().__init__() self.interleaved = interleaved - + self.cached_freqs_seq_len = max_seq_len inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) - - self._cache_max_seq_len = max(precompute_len, cache_max_seq_len) - self._precomputed_len = precompute_len - - self.register_buffer('cached_freqs', None, persistent=False) - self.cached_freqs_seq_len = 0 - - if self._precomputed_len > 0: - self._precompute_cache(self._precomputed_len) + self.register_buffer('inv_freq', inv_freq, persistent=False) + self.register_buffer('cached_freqs', self._precompute_cache(max_seq_len), persistent=False) def _precompute_cache(self, seq_len: int): seq = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) freqs = einsum('i, j -> i j', seq, self.inv_freq) - if self.interleaved: freqs = repeat(freqs, '... n -> ... (n r)', r=2) else: freqs = torch.cat((freqs, freqs), dim=-1) + return freqs - self.cached_freqs = freqs - self.cached_freqs_seq_len = seq_len - - def forward(self, t: Tensor, seq_len: int) -> Tensor: - if self.cached_freqs is None or seq_len > self.cached_freqs_seq_len: - self._precompute_cache(seq_len) - + def forward(self, seq_len: int) -> Tensor: + if seq_len > self.cached_freqs_seq_len: + raise RuntimeError("sequence exceeds RoPE max_seq_len!") return self.cached_freqs[0: seq_len].detach() def rotate_queries_or_keys(self, t: Tensor) -> Tensor: device, dtype, seq_len = t.device, t.dtype, t.shape[-2] - freqs = self.forward(t, seq_len=seq_len) - + freqs = self.forward(seq_len=seq_len) + return apply_rotary_emb(freqs.to(device=device, dtype=dtype), t, self.interleaved) diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index cc840aed3..5d8f01261 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -376,7 +376,7 @@ def __init__(self, hidden_size, num_layers, self.dropout = dropout self.use_pos_embed = use_pos_embed if use_pos_embed and use_rope: - rotary_embed = RotaryEmbedding(dim = embed_dim // num_heads, interleaved = rope_interleaved) + rotary_embed = RotaryEmbedding(dim=embed_dim // num_heads, interleaved=rope_interleaved) else: rotary_embed = None self.layers = nn.ModuleList([ diff --git a/scripts/export.py b/scripts/export.py index d666175d6..c63597e66 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -298,5 +298,5 @@ def nsf_hifigan( if __name__ == '__main__': - check_pytorch_version() + # check_pytorch_version() main() From 446494cb5dff96893a09f279d46d0232bd203c75 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 13 Jan 2026 15:51:51 +0800 Subject: [PATCH 30/41] Fix checkpoint compatibility (cherry picked from commit 74e8d32b4fb2254180c63e57a9ca202a43e797f4) --- deployment/exporters/acoustic_exporter.py | 19 ++------------- deployment/exporters/variance_exporter.py | 15 ------------ utils/__init__.py | 29 ++++++++++++++++------- 3 files changed, 22 insertions(+), 41 deletions(-) diff --git a/deployment/exporters/acoustic_exporter.py b/deployment/exporters/acoustic_exporter.py index 51b3a1d12..1f39ffdc6 100644 --- a/deployment/exporters/acoustic_exporter.py +++ b/deployment/exporters/acoustic_exporter.py @@ -1,7 +1,6 @@ import json from pathlib import Path -from typing import List, Union, Tuple, Dict -import warnings +from typing import Union, List, Tuple, Dict import onnx import onnxsim @@ -35,7 +34,6 @@ def __init__( self.lang_map: dict = self.build_lang_map() self.phoneme_dictionary = load_phoneme_dictionary() self.use_lang_id = hparams.get('use_lang_id', False) and len(self.phoneme_dictionary.cross_lingual_phonemes) > 0 - self.rope_interleaved = hparams.get('rope_interleaved', None) self.model = self.build_model() self.fs2_aux_cache_path = self.cache_dir / ( 'fs2_aux.onnx' if self.model.use_shallow_diffusion else 'fs2.onnx' @@ -90,21 +88,8 @@ def build_model(self) -> DiffSingerAcousticONNX: for p in self.phoneme_dictionary.cross_lingual_phonemes }) ).eval().to(self.device) - if self.rope_interleaved is None: - warnings.warn( - "After RoPE is refactored, the checkpoint no longer contains relevant parameters. " - "(https://github.com/openvpi/DiffSinger/pull/276)" - "In order to export ONNX with behavior compatible with past checkpoints, " - "it will be set to 'strict=False', which will no longer check the validity of the checkpoint. " - "Please understand what you are doing.", - UserWarning, - stacklevel=2 - ) - strict=False - else: - strict=True load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps, - prefix_in_ckpt='model', strict=strict, device=self.device) + prefix_in_ckpt='model', strict=True, device=self.device) return model def export(self, path: Path): diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 959ab0436..9ae0a2e2a 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -1,7 +1,6 @@ import json from pathlib import Path from typing import Union, List, Tuple, Dict -import warnings import onnx import onnxsim @@ -35,7 +34,6 @@ def __init__( self.lang_map: dict = self.build_lang_map() self.phoneme_dictionary = load_phoneme_dictionary() self.use_lang_id = hparams.get('use_lang_id', False) and len(self.phoneme_dictionary.cross_lingual_phonemes) > 0 - self.rope_interleaved = hparams.get('rope_interleaved', None) self.model = self.build_model() self.linguistic_encoder_cache_path = self.cache_dir / 'linguistic.onnx' self.dur_predictor_cache_path = self.cache_dir / 'dur.onnx' @@ -92,19 +90,6 @@ def build_model(self) -> DiffSingerVarianceONNX: for p in self.phoneme_dictionary.cross_lingual_phonemes }) ).eval().to(self.device) - if self.rope_interleaved is None: - warnings.warn( - "After RoPE is refactored, the checkpoint no longer contains relevant parameters. " - "(https://github.com/openvpi/DiffSinger/pull/276)" - "In order to export ONNX with behavior compatible with past checkpoints, " - "it will be set to 'strict=False', which will no longer check the validity of the checkpoint. " - "Please understand what you are doing.", - UserWarning, - stacklevel=2 - ) - strict=False - else: - strict=True load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps, prefix_in_ckpt='model', strict=strict, device=self.device) model.build_smooth_op(self.device) diff --git a/utils/__init__.py b/utils/__init__.py index 1f4c17c04..78da6d6a7 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -5,6 +5,7 @@ import time import types from collections import OrderedDict +from fnmatch import fnmatch import numpy as np import torch @@ -165,12 +166,14 @@ def filter_kwargs(dict_to_filter, kwarg_obj): def load_ckpt( cur_model, ckpt_base_dir, ckpt_steps=None, - prefix_in_ckpt='model', ignored_prefixes=None, key_in_ckpt='state_dict', + prefix_in_ckpt='model', exclude_key_patterns=None, key_in_ckpt='state_dict', strict=True, device='cpu' ): - if ignored_prefixes is None: - # NOTICE: this is for compatibility with old checkpoints which have duplicate txt_embed layer in them. - ignored_prefixes = ['model.fs2.encoder.embed_tokens'] + if exclude_key_patterns is None: + # Pop all RoPE buffers from some old checkpoints, + # Because these buffers are all computed during initialization now. + # TODO: this is a legacy handling and should be removed in the future. + exclude_key_patterns = ['*.rotary_embed.*'] if not isinstance(ckpt_base_dir, pathlib.Path): ckpt_base_dir = pathlib.Path(ckpt_base_dir) if ckpt_base_dir.is_file(): @@ -197,11 +200,19 @@ def load_ckpt( else: state_dict = ckpt_loaded[key_in_ckpt] if prefix_in_ckpt is not None: - state_dict = OrderedDict({ - k[len(prefix_in_ckpt) + 1:]: v - for k, v in state_dict.items() if k.startswith(f'{prefix_in_ckpt}.') - if all(not k.startswith(p) for p in ignored_prefixes) - }) + state_dict = OrderedDict() + for k, v in ckpt_loaded[key_in_ckpt].items(): + if not k.startswith(f'{prefix_in_ckpt}.'): + continue + k = k[len(prefix_in_ckpt) + 1:] + excluded = False + for pat in exclude_key_patterns: + if fnmatch(k, pat): + excluded = True + break + if excluded: + continue + state_dict[k] = v if not strict: cur_model_state_dict = cur_model.state_dict() unmatched_keys = [] From 8aa988691f810c3db0904a1d8cd5a931c26a893d Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 13 Jan 2026 16:01:05 +0800 Subject: [PATCH 31/41] Fix check version --- scripts/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/export.py b/scripts/export.py index c63597e66..d666175d6 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -298,5 +298,5 @@ def nsf_hifigan( if __name__ == '__main__': - # check_pytorch_version() + check_pytorch_version() main() From e7ab9cd00bd6cd37e70b397c054b55189dcde3e2 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 13 Jan 2026 16:03:41 +0800 Subject: [PATCH 32/41] Fix strict=True --- deployment/exporters/variance_exporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 9ae0a2e2a..74b455ee5 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -91,7 +91,7 @@ def build_model(self) -> DiffSingerVarianceONNX: }) ).eval().to(self.device) load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps, - prefix_in_ckpt='model', strict=strict, device=self.device) + prefix_in_ckpt='model', strict=True, device=self.device) model.build_smooth_op(self.device) return model From 1a96f26dfb8a5aefa9ba5612c193575e70f6a971 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Wed, 14 Jan 2026 21:54:01 +0800 Subject: [PATCH 33/41] Check embed dim for RoPE RoPE requires the hidden size to be multiple of num_heads * 2 (cherry picked from commit ebc3805f941a14490a8e8817d0a4553fb94c7945) --- modules/fastspeech/tts_modules.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index 5d8f01261..e1dde0499 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -376,6 +376,11 @@ def __init__(self, hidden_size, num_layers, self.dropout = dropout self.use_pos_embed = use_pos_embed if use_pos_embed and use_rope: + if embed_dim % (num_heads * 2) != 0: + raise ValueError( + "RoPE requires the hidden size to be multiple of " + f"num_heads * 2 = {num_heads * 2}, but got {embed_dim}." + ) rotary_embed = RotaryEmbedding(dim=embed_dim // num_heads, interleaved=rope_interleaved) else: rotary_embed = None From b03d3e606a3ddfad7bef7904013fddf7eb73aa2f Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Sun, 25 Jan 2026 18:06:59 +0800 Subject: [PATCH 34/41] fix "input Linear" and "output Linear" optim by Muon (#285) * fix "input Linear" and "output Linear" optim by Muon fix "input Linear" and "output Linear" optim by Muon * Update excluded module classes in muon.py --- modules/aux_decoder/convnext.py | 4 +++- modules/backbones/lynxnet.py | 4 ++-- modules/backbones/lynxnet2.py | 4 ++-- modules/backbones/wavenet.py | 4 ++-- modules/commons/common_layers.py | 21 +++++++++++++++++++++ modules/fastspeech/acoustic_encoder.py | 11 ++++++----- modules/fastspeech/tts_modules.py | 4 ++-- modules/fastspeech/variance_encoder.py | 9 +++++---- modules/optimizer/muon.py | 16 +++++++++++++--- modules/toplevel.py | 10 +++++----- 10 files changed, 61 insertions(+), 26 deletions(-) diff --git a/modules/aux_decoder/convnext.py b/modules/aux_decoder/convnext.py index a03959ddf..2b6ef1a80 100644 --- a/modules/aux_decoder/convnext.py +++ b/modules/aux_decoder/convnext.py @@ -3,6 +3,8 @@ import torch import torch.nn as nn +from modules.commons.common_layers import AdamWCov1d + class ConvNeXtBlock(nn.Module): """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. @@ -71,7 +73,7 @@ def __init__( layer_scale_init_value=1e-6, drop_out=dropout_rate ) for _ in range(num_layers) ) - self.outconv = nn.Conv1d( + self.outconv = AdamWCov1d( num_channels, out_dims, kernel_size, stride=1, padding=(kernel_size - 1) // 2 ) diff --git a/modules/backbones/lynxnet.py b/modules/backbones/lynxnet.py index 9f5c6a383..766dc960f 100644 --- a/modules/backbones/lynxnet.py +++ b/modules/backbones/lynxnet.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose, AdamWCov1d from modules.commons.common_layers import KaimingNormalConv1d as Conv1d from utils.hparams import hparams @@ -106,7 +106,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio ] ) self.norm = nn.LayerNorm(num_channels) - self.output_projection = Conv1d(num_channels, in_dims * n_feats, kernel_size=1) + self.output_projection = AdamWCov1d(num_channels, in_dims * n_feats, kernel_size=1) self.strong_cond = strong_cond nn.init.zeros_(self.output_projection.weight) diff --git a/modules/backbones/lynxnet2.py b/modules/backbones/lynxnet2.py index 76e8580b5..864ba3779 100644 --- a/modules/backbones/lynxnet2.py +++ b/modules/backbones/lynxnet2.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, ATanGLU, Transpose +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, ATanGLU, Transpose, AdamWLinear from utils.hparams import hparams @@ -72,7 +72,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio ] ) self.norm = nn.LayerNorm(num_channels) - self.output_projection = nn.Linear(num_channels, in_dims * n_feats) + self.output_projection = AdamWLinear(num_channels, in_dims * n_feats) nn.init.kaiming_normal_(self.input_projection.weight) nn.init.kaiming_normal_(self.conditioner_projection.weight) nn.init.zeros_(self.output_projection.weight) diff --git a/modules/backbones/wavenet.py b/modules/backbones/wavenet.py index 2cbff961d..1baedbfa3 100644 --- a/modules/backbones/wavenet.py +++ b/modules/backbones/wavenet.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb +from modules.commons.common_layers import SinusoidalPosEmb, AdamWCov1d from modules.commons.common_layers import KaimingNormalConv1d as Conv1d from utils.hparams import hparams @@ -64,7 +64,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=20, num_channels=256, dilatio for i in range(num_layers) ]) self.skip_projection = Conv1d(num_channels, num_channels, 1) - self.output_projection = Conv1d(num_channels, in_dims * n_feats, 1) + self.output_projection = AdamWCov1d(num_channels, in_dims * n_feats, 1) nn.init.zeros_(self.output_projection.weight) def forward(self, spec, diffusion_step, cond): diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index 0012b99c3..cb5aa72ac 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -26,6 +26,21 @@ def __init__( nn.init.constant_(self.weight[padding_idx], 0) +class AdamWLinear(torch.nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + *args, + bias: bool = True, + **kwargs + ): + super().__init__(in_features, out_features, *args, bias=bias, **kwargs) + nn.init.xavier_uniform_(self.weight) + if bias: + nn.init.constant_(self.bias, 0.) + + class XavierUniformInitLinear(torch.nn.Linear): def __init__( self, @@ -160,6 +175,12 @@ def forward(self, x): return out * torch.atan(gate) +class AdamWCov1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + nn.init.kaiming_normal_(self.weight) + + class KaimingNormalConv1d(torch.nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index 868d383fd..f75ab2d5d 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -6,6 +6,7 @@ NormalInitEmbedding as Embedding, XavierUniformInitLinear as Linear, SinusoidalPosEmb, + AdamWLinear, ) from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur, StretchRegulator from utils.hparams import hparams @@ -32,7 +33,7 @@ def __init__(self, vocab_size): ) self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True) - self.dur_embed = Linear(1, hparams['hidden_size']) + self.dur_embed = AdamWLinear(1, hparams['hidden_size']) self.encoder = FastSpeech2Encoder( hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'], ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'], @@ -41,7 +42,7 @@ def __init__(self, vocab_size): use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True) ) - self.pitch_embed = Linear(1, hparams['hidden_size']) + self.pitch_embed = AdamWLinear(1, hparams['hidden_size']) self.variance_embed_list = [] self.use_energy_embed = hparams.get('use_energy_embed', False) self.use_breathiness_embed = hparams.get('use_breathiness_embed', False) @@ -59,7 +60,7 @@ def __init__(self, vocab_size): self.use_variance_embeds = len(self.variance_embed_list) > 0 if self.use_variance_embeds: self.variance_embeds = nn.ModuleDict({ - v_name: Linear(1, hparams['hidden_size']) + v_name: AdamWLinear(1, hparams['hidden_size']) for v_name in self.variance_embed_list }) @@ -85,11 +86,11 @@ def __init__(self, vocab_size): self.use_key_shift_embed = hparams.get('use_key_shift_embed', False) if self.use_key_shift_embed: - self.key_shift_embed = Linear(1, hparams['hidden_size']) + self.key_shift_embed = AdamWLinear(1, hparams['hidden_size']) self.use_speed_embed = hparams.get('use_speed_embed', False) if self.use_speed_embed: - self.speed_embed = Linear(1, hparams['hidden_size']) + self.speed_embed = AdamWLinear(1, hparams['hidden_size']) self.use_spk_id = hparams['use_spk_id'] if self.use_spk_id: diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index e1dde0499..daf85127f 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch.nn import functional as F from modules.commons.rotary_embedding_torch import RotaryEmbedding -from modules.commons.common_layers import SinusoidalPositionalEmbedding, EncSALayer +from modules.commons.common_layers import SinusoidalPositionalEmbedding, EncSALayer, AdamWLinear from modules.commons.espnet_positional_embedding import RelPositionalEncoding DEFAULT_MAX_SOURCE_POSITIONS = 2000 @@ -110,7 +110,7 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3, # self.crf = CRF(out_dims, batch_first=True) else: raise NotImplementedError() - self.linear = torch.nn.Linear(n_chans, self.out_dims) + self.linear = AdamWLinear(n_chans, self.out_dims) def out2dur(self, xs): if self.loss_type in ['mse', 'huber']: diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index ba6994c1e..712964846 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -5,6 +5,7 @@ from modules.commons.common_layers import ( NormalInitEmbedding as Embedding, XavierUniformInitLinear as Linear, + AdamWLinear, ) from modules.fastspeech.tts_modules import FastSpeech2Encoder, DurationPredictor from utils.hparams import hparams @@ -24,9 +25,9 @@ def __init__(self, vocab_size): if self.predict_dur: self.onset_embed = Embedding(2, hparams['hidden_size']) - self.word_dur_embed = Linear(1, hparams['hidden_size']) + self.word_dur_embed = AdamWLinear(1, hparams['hidden_size']) else: - self.ph_dur_embed = Linear(1, hparams['hidden_size']) + self.ph_dur_embed = AdamWLinear(1, hparams['hidden_size']) self.encoder = FastSpeech2Encoder( hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'], @@ -112,8 +113,8 @@ def get_hparam(key): # MIDI inputs hidden_size = get_hparam('hidden_size') self.use_variance_scaling = hparams.get('use_variance_scaling', False) - self.note_midi_embed = Linear(1, hidden_size) - self.note_dur_embed = Linear(1, hidden_size) + self.note_midi_embed = AdamWLinear(1, hidden_size) + self.note_dur_embed = AdamWLinear(1, hidden_size) # ornament inputs self.use_glide_embed = hparams['use_glide_embed'] diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index caf3a45e4..678f22e65 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -1,3 +1,4 @@ +import collections import torch import torch.nn as nn import torch.nn.functional as F @@ -6,6 +7,8 @@ from typing import List from .chained_optimizer import ChainedOptimizer, OptimizerSpec +from modules.commons.common_layers import AdamWLinear, AdamWCov1d + def get_bf16_support_map(): bf16_support_map = {} @@ -129,13 +132,20 @@ def get_params_for_muon(model) -> List[Parameter]: Returns: A list of parameters that should be optimized with muon. """ + excluded_module_classes = (nn.Embedding, AdamWLinear, AdamWCov1d) muon_params = [] - for module in model.modules(): + # BFS through all submodules and exclude parameters from certain module types + queue = collections.deque([model]) + while queue: + module = queue.popleft() + if isinstance(module, excluded_module_classes): + continue for param in module.parameters(recurse=False): if not param.requires_grad: continue - if not isinstance(module, nn.Embedding) and param.ndim >= 2: + if param.ndim >= 2: muon_params.append(param) + queue.extend(list(module.children())) return muon_params @@ -150,4 +160,4 @@ def __init__(self, model, lr=0.0005, weight_decay=0.0, muon_args={}, adamw_args= callback = lambda p, spec_idx: print( f"Adding param {p.shape} to optimizer{spec_idx} {str(specs[spec_idx].class_type)}" ) - super().__init__(model.parameters(), specs, lr=lr, weight_decay=weight_decay, optimizer_selection_callback=callback) \ No newline at end of file + super().__init__(model.parameters(), specs, lr=lr, weight_decay=weight_decay, optimizer_selection_callback=callback) diff --git a/modules/toplevel.py b/modules/toplevel.py index 3c3129665..bc8029af3 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -11,7 +11,7 @@ from modules.commons.common_layers import ( XavierUniformInitLinear as Linear, NormalInitEmbedding as Embedding, - SinusoidalPosEmb + SinusoidalPosEmb, AdamWLinear, ) from modules.core import ( GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion, @@ -160,9 +160,9 @@ def __init__(self, vocab_size): self.use_melody_encoder = hparams.get('use_melody_encoder', False) if self.use_melody_encoder: self.melody_encoder = MelodyEncoder(enc_hparams=hparams['melody_encoder_args']) - self.delta_pitch_embed = Linear(1, hparams['hidden_size']) + self.delta_pitch_embed = AdamWLinear(1, hparams['hidden_size']) else: - self.base_pitch_embed = Linear(1, hparams['hidden_size']) + self.base_pitch_embed = AdamWLinear(1, hparams['hidden_size']) self.pitch_retake_embed = Embedding(2, hparams['hidden_size']) pitch_hparams = hparams['pitch_prediction_args'] @@ -195,9 +195,9 @@ def __init__(self, vocab_size): raise ValueError(f"Invalid diffusion type: {self.diffusion_type}") if self.predict_variances: - self.pitch_embed = Linear(1, hparams['hidden_size']) + self.pitch_embed = AdamWLinear(1, hparams['hidden_size']) self.variance_embeds = nn.ModuleDict({ - v_name: Linear(1, hparams['hidden_size']) + v_name: AdamWLinear(1, hparams['hidden_size']) for v_name in self.variance_prediction_list }) From c2ae8afcfec4f484eb215e6a1ce7a8d272e1783d Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Sat, 9 May 2026 20:16:35 +0800 Subject: [PATCH 35/41] fix typo (#300) --- modules/aux_decoder/convnext.py | 4 ++-- modules/backbones/lynxnet.py | 4 ++-- modules/backbones/wavenet.py | 4 ++-- modules/commons/common_layers.py | 2 +- modules/optimizer/muon.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/modules/aux_decoder/convnext.py b/modules/aux_decoder/convnext.py index 2b6ef1a80..9cfa1ae16 100644 --- a/modules/aux_decoder/convnext.py +++ b/modules/aux_decoder/convnext.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from modules.commons.common_layers import AdamWCov1d +from modules.commons.common_layers import AdamWCovn1d class ConvNeXtBlock(nn.Module): @@ -73,7 +73,7 @@ def __init__( layer_scale_init_value=1e-6, drop_out=dropout_rate ) for _ in range(num_layers) ) - self.outconv = AdamWCov1d( + self.outconv = AdamWCovn1d( num_channels, out_dims, kernel_size, stride=1, padding=(kernel_size - 1) // 2 ) diff --git a/modules/backbones/lynxnet.py b/modules/backbones/lynxnet.py index 766dc960f..41e72f97f 100644 --- a/modules/backbones/lynxnet.py +++ b/modules/backbones/lynxnet.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose, AdamWCov1d +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose, AdamWCovn1d from modules.commons.common_layers import KaimingNormalConv1d as Conv1d from utils.hparams import hparams @@ -106,7 +106,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio ] ) self.norm = nn.LayerNorm(num_channels) - self.output_projection = AdamWCov1d(num_channels, in_dims * n_feats, kernel_size=1) + self.output_projection = AdamWCovn1d(num_channels, in_dims * n_feats, kernel_size=1) self.strong_cond = strong_cond nn.init.zeros_(self.output_projection.weight) diff --git a/modules/backbones/wavenet.py b/modules/backbones/wavenet.py index 1baedbfa3..0792a99d7 100644 --- a/modules/backbones/wavenet.py +++ b/modules/backbones/wavenet.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, AdamWCov1d +from modules.commons.common_layers import SinusoidalPosEmb, AdamWCovn1d from modules.commons.common_layers import KaimingNormalConv1d as Conv1d from utils.hparams import hparams @@ -64,7 +64,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=20, num_channels=256, dilatio for i in range(num_layers) ]) self.skip_projection = Conv1d(num_channels, num_channels, 1) - self.output_projection = AdamWCov1d(num_channels, in_dims * n_feats, 1) + self.output_projection = AdamWCovn1d(num_channels, in_dims * n_feats, 1) nn.init.zeros_(self.output_projection.weight) def forward(self, spec, diffusion_step, cond): diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index cb5aa72ac..8f15f2ce7 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -175,7 +175,7 @@ def forward(self, x): return out * torch.atan(gate) -class AdamWCov1d(torch.nn.Conv1d): +class AdamWCovn1d(torch.nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) nn.init.kaiming_normal_(self.weight) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 678f22e65..f5c5e7654 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -7,7 +7,7 @@ from typing import List from .chained_optimizer import ChainedOptimizer, OptimizerSpec -from modules.commons.common_layers import AdamWLinear, AdamWCov1d +from modules.commons.common_layers import AdamWLinear, AdamWCovn1d def get_bf16_support_map(): @@ -132,7 +132,7 @@ def get_params_for_muon(model) -> List[Parameter]: Returns: A list of parameters that should be optimized with muon. """ - excluded_module_classes = (nn.Embedding, AdamWLinear, AdamWCov1d) + excluded_module_classes = (nn.Embedding, AdamWLinear, AdamWCovn1d) muon_params = [] # BFS through all submodules and exclude parameters from certain module types queue = collections.deque([model]) From 36cd2df17cd6aa0d035219af11e21028306c02df Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Sat, 9 May 2026 20:55:19 +0800 Subject: [PATCH 36/41] Fix typo --- modules/aux_decoder/convnext.py | 4 ++-- modules/backbones/lynxnet.py | 7 +++---- modules/backbones/wavenet.py | 4 ++-- modules/commons/common_layers.py | 2 +- modules/optimizer/muon.py | 6 +++--- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/modules/aux_decoder/convnext.py b/modules/aux_decoder/convnext.py index 9cfa1ae16..ad3fa1e2f 100644 --- a/modules/aux_decoder/convnext.py +++ b/modules/aux_decoder/convnext.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from modules.commons.common_layers import AdamWCovn1d +from modules.commons.common_layers import AdamWConv1d class ConvNeXtBlock(nn.Module): @@ -73,7 +73,7 @@ def __init__( layer_scale_init_value=1e-6, drop_out=dropout_rate ) for _ in range(num_layers) ) - self.outconv = AdamWCovn1d( + self.outconv = AdamWConv1d( num_channels, out_dims, kernel_size, stride=1, padding=(kernel_size - 1) // 2 ) diff --git a/modules/backbones/lynxnet.py b/modules/backbones/lynxnet.py index 41e72f97f..9529d1efe 100644 --- a/modules/backbones/lynxnet.py +++ b/modules/backbones/lynxnet.py @@ -2,11 +2,10 @@ # https://github.com/CNChTu/Diffusion-SVC/blob/v2.0_dev/diffusion/naive_v2/model_conformer_naive.py # https://github.com/CNChTu/Diffusion-SVC/blob/v2.0_dev/diffusion/naive_v2/naive_v2_diff.py -import torch import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose, AdamWCovn1d +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose, AdamWConv1d from modules.commons.common_layers import KaimingNormalConv1d as Conv1d from utils.hparams import hparams @@ -102,11 +101,11 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio activation=activation, dropout=dropout_rate ) - for i in range(num_layers) + for _ in range(num_layers) ] ) self.norm = nn.LayerNorm(num_channels) - self.output_projection = AdamWCovn1d(num_channels, in_dims * n_feats, kernel_size=1) + self.output_projection = AdamWConv1d(num_channels, in_dims * n_feats, kernel_size=1) self.strong_cond = strong_cond nn.init.zeros_(self.output_projection.weight) diff --git a/modules/backbones/wavenet.py b/modules/backbones/wavenet.py index 0792a99d7..f70d82b60 100644 --- a/modules/backbones/wavenet.py +++ b/modules/backbones/wavenet.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, AdamWCovn1d +from modules.commons.common_layers import SinusoidalPosEmb, AdamWConv1d from modules.commons.common_layers import KaimingNormalConv1d as Conv1d from utils.hparams import hparams @@ -64,7 +64,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=20, num_channels=256, dilatio for i in range(num_layers) ]) self.skip_projection = Conv1d(num_channels, num_channels, 1) - self.output_projection = AdamWCovn1d(num_channels, in_dims * n_feats, 1) + self.output_projection = AdamWConv1d(num_channels, in_dims * n_feats, 1) nn.init.zeros_(self.output_projection.weight) def forward(self, spec, diffusion_step, cond): diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index 8f15f2ce7..1e20299f1 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -175,7 +175,7 @@ def forward(self, x): return out * torch.atan(gate) -class AdamWCovn1d(torch.nn.Conv1d): +class AdamWConv1d(torch.nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) nn.init.kaiming_normal_(self.weight) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index f5c5e7654..d5991dac0 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -3,11 +3,11 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch.nn import Module, Parameter, Embedding +from torch.nn import Parameter from typing import List from .chained_optimizer import ChainedOptimizer, OptimizerSpec -from modules.commons.common_layers import AdamWLinear, AdamWCovn1d +from modules.commons.common_layers import AdamWLinear, AdamWConv1d def get_bf16_support_map(): @@ -132,7 +132,7 @@ def get_params_for_muon(model) -> List[Parameter]: Returns: A list of parameters that should be optimized with muon. """ - excluded_module_classes = (nn.Embedding, AdamWLinear, AdamWCovn1d) + excluded_module_classes = (nn.Embedding, AdamWLinear, AdamWConv1d) muon_params = [] # BFS through all submodules and exclude parameters from certain module types queue = collections.deque([model]) From 5771a3bb5f74a0e4eaa1e5b37b47baf6575dd1ed Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Sat, 20 Jun 2026 23:02:53 +0800 Subject: [PATCH 37/41] perf: optimize Muon with GramNS and fix fp16 spectral norm stability (#301) - Fix fp16 stability: Using float32 exclusively for the initial spectral normalization step prevents instability, allowing the rest of the algorithm to safely execute in fp16. - Integrate Gram Newton-Schulz: Computes iterations on the smaller Gram matrix. - Benchmarks show up to a 42% time reduction for heavily rectangular matrices (e.g., 8192x1024 drops from 58ms to 33ms) with no performance penalty on square shapes. --- modules/optimizer/muon.py | 80 ++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index d5991dac0..3cc8e453d 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -10,25 +10,7 @@ from modules.commons.common_layers import AdamWLinear, AdamWConv1d -def get_bf16_support_map(): - bf16_support_map = {} - - if not torch.cuda.is_available(): - return bf16_support_map - - device_count = torch.cuda.device_count() - if device_count == 0: - return bf16_support_map - - for i in range(device_count): - device = torch.device(f'cuda:{i}') - major, minor = torch.cuda.get_device_capability(device) - bf16_support_map[device] = (major >= 8) - - return bf16_support_map - - -def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor: +def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose @@ -41,11 +23,13 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng a, b, c = (3.4445, -4.7750, 2.0315) - X = G.to(dtype = torch.bfloat16 if use_bf16 else torch.float32) + X = G.to(torch.float32) # Ensure spectral norm is at most 1 X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) + X = X.to(torch.float16) + # Perform the NS iterations if X.size(-2) < X.size(-1): for _ in range(steps): @@ -61,6 +45,57 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor return X +def gram_newton_schulz(G: Tensor, steps: int, reset_iterations: List[int]=[2]) -> Tensor: + """ + Refer to: + Gram Newton-Schulz: A Fast, Hardware-Aware Newton-Schulz Algorithm for Muon + Authors: Jack Zhang, Noah Amsel, Berlin Chen, Tri Dao + Blogpost: https://dao-ailab.github.io/blog/2026/gram-newton-schulz/ + + Gram Newton-Schulz iteration to compute the orthogonalization of G. + Mathematically identical to standard Newton-Schulz but computes iterating + on the smaller NxN Gram matrix to save up to 50% FLOPs. + """ + assert G.ndim == 3 + original_shape = G.shape + dtype = G.dtype + + X = G.to(torch.float32) + X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7) + should_transpose = X.size(-2) > X.size(-1) + if should_transpose: + X = X.mT + X = X.to(torch.float16) + + a, b, c = (3.4445, -4.7750, 2.0315) + + if X.size(-2) != X.size(-1): + R = torch.bmm(X, X.mT) + Q = None + for i in range(steps): + if i in reset_iterations and i != 0: + X = torch.bmm(Q, X) + R = torch.bmm(X, X.mT) + Q = None + Z = torch.baddbmm(R, R, R, beta=b, alpha=c) + if i != 0 and i not in reset_iterations: + Q = torch.baddbmm(Q, Q, Z, beta=a, alpha=1.0) + else: + Q = Z.clone() + Q.diagonal(dim1=-2, dim2=-1).add_(a) + if i < steps - 1 and (i + 1) not in reset_iterations: + RZ = torch.baddbmm(R, R, Z, beta=a, alpha=1.0) + R = torch.baddbmm(RZ, Z, RZ, beta=a, alpha=1.0) + X = torch.bmm(Q, X) if not should_transpose else torch.bmm(X.mT, Q) + else: + for _ in range(steps): + A = torch.bmm(X, X.mT) + B = torch.baddbmm(A, A, A, beta=b, alpha=c) + X = torch.baddbmm(X, B, X, beta=a, alpha=1.0) + + return X.to(dtype).view(original_shape) + + class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -87,7 +122,6 @@ class Muon(torch.optim.Optimizer): def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5): defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) super().__init__(params, defaults) - self.bf16_support_map = get_bf16_support_map() @torch.no_grad() def step(self, closure=None): @@ -116,8 +150,8 @@ def step(self, closure=None): original_shape = g.shape if g.ndim >= 4: # for the case of conv filters g = g.view(g.size(0), g.size(1), -1) - use_bf16 = self.bf16_support_map.get(g.device, False) - g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16) + g = gram_newton_schulz(g, steps=group["ns_steps"]) + if group["weight_decay"] > 0: torch._foreach_mul_(p, 1 - group["lr"] * group["weight_decay"]) torch._foreach_add_(p, g.view(original_shape).unbind(0), alpha=-group["lr"] * max(g[0].size()) ** 0.5) From fe2df4506d346892fa0f985c492de39ab292fd2a Mon Sep 17 00:00:00 2001 From: wolfgitpr <133209402+wolfgitpr@users.noreply.github.com> Date: Mon, 22 Jun 2026 13:47:15 +0800 Subject: [PATCH 38/41] Adjust configs --- configs/acoustic.yaml | 10 +-- configs/duration.yaml | 149 ++++++++++++++++++++++++++++++++++++++++++ configs/variance.yaml | 26 ++++---- 3 files changed, 167 insertions(+), 18 deletions(-) create mode 100644 configs/duration.yaml diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 935d6e160..176ef563b 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -78,7 +78,7 @@ backbone_args: num_channels: 1024 num_layers: 6 kernel_size: 31 - dropout_rate: 0.0 + dropout_rate: 0.1 use_conditioner_cache: true glu_type: 'atanglu' main_loss_type: l2 @@ -116,18 +116,18 @@ optimizer_args: adamw_args: weight_decay: 0.0 lr_scheduler_args: - step_size: 5000 + step_size: 4000 gamma: 0.8 max_batch_frames: 50000 max_batch_size: 64 dataset_size_key: 'lengths' val_with_vocoder: true -val_check_interval: 2000 +val_check_interval: 4000 num_valid_plots: 10 max_updates: 100000 -num_ckpt_keep: 5 +num_ckpt_keep: 8 permanent_ckpt_start: 60000 -permanent_ckpt_interval: 10000 +permanent_ckpt_interval: 8000 finetune_enabled: false finetune_ckpt_path: null diff --git a/configs/duration.yaml b/configs/duration.yaml new file mode 100644 index 000000000..0f7005185 --- /dev/null +++ b/configs/duration.yaml @@ -0,0 +1,149 @@ +base_config: + - configs/base.yaml + +task_cls: training.variance_task.VarianceTask + +dictionaries: {} +extra_phonemes: [] +merged_phoneme_groups: [] +datasets: [] + +audio_sample_rate: 44100 +hop_size: 512 # Hop size. +fft_size: 2048 # FFT size. +win_size: 2048 # FFT size. +midi_smooth_width: 0.06 # in seconds + +binarization_args: + shuffle: true + num_workers: 0 + prefer_ds: false + +binary_data_dir: 'data/opencpop_duration/binary' +binarizer_cls: preprocessing.variance_binarizer.VarianceBinarizer + +use_lang_id: false +num_lang: 1 +use_spk_id: false +num_spk: 1 + +predict_dur: true +predict_pitch: false +predict_energy: false +predict_breathiness: false +predict_voicing: false +predict_tension: false + +enc_ffn_kernel_size: 3 +use_rope: true +rope_interleaved: false +use_stretch_embed: false +use_variance_scaling: true +rel_pos: true +hidden_size: 384 + +dur_prediction_args: + arch: resnet + hidden_size: 256 + dropout: 0.1 + num_layers: 5 + kernel_size: 3 + log_offset: 1.0 + loss_type: mse + lambda_pdur_loss: 0.3 + lambda_wdur_loss: 1.0 + lambda_sdur_loss: 3.0 + +use_melody_encoder: true +melody_encoder_args: + hidden_size: 128 + enc_layers: 4 +use_glide_embed: false +glide_types: [up, down] +glide_embed_scale: 11.313708498984760 # sqrt(128) + +pitch_prediction_args: + pitd_norm_min: -8.0 + pitd_norm_max: 8.0 + pitd_clip_min: -12.0 + pitd_clip_max: 12.0 + repeat_bins: 96 + backbone_type: 'lynxnet2' + backbone_args: + num_layers: 6 + num_channels: 512 + dropout_rate: 0.1 + use_conditioner_cache: true + glu_type: 'atanglu' + +energy_db_min: -96.0 +energy_db_max: -12.0 +energy_smooth_width: 0.06 + +breathiness_db_min: -96.0 +breathiness_db_max: -20.0 +breathiness_smooth_width: 0.06 +voicing_db_min: -96.0 +voicing_db_max: -12.0 +voicing_smooth_width: 0.06 + +tension_logit_min: -10.0 +tension_logit_max: 10.0 +tension_smooth_width: 0.06 + +variances_prediction_args: + total_repeat_bins: 72 + backbone_type: 'lynxnet2' + backbone_args: + num_layers: 6 + num_channels: 384 + dropout_rate: 0.1 + use_conditioner_cache: true + glu_type: 'atanglu' + +lambda_dur_loss: 1.0 +lambda_pitch_loss: 1.0 +lambda_var_loss: 1.0 + +diffusion_type: reflow # ddpm +time_scale_factor: 1000 +schedule_type: 'linear' +K_step: 1000 +timesteps: 1000 +max_beta: 0.02 +main_loss_type: l2 +main_loss_log_norm: true +sampling_algorithm: euler +sampling_steps: 20 +diff_accelerator: ddim +diff_speedup: 10 + +# train and eval +num_sanity_val_steps: 1 +optimizer_args: + optimizer_cls: torch.optim.AdamW + lr: 0.0006 +lr_scheduler_args: + scheduler_cls: torch.optim.lr_scheduler.StepLR + step_size: 5000 + gamma: 0.75 +max_batch_frames: 80000 +max_batch_size: 48 +dataset_size_key: 'lengths' +val_check_interval: 4000 +num_valid_plots: 10 +max_updates: 80000 +num_ckpt_keep: 8 +permanent_ckpt_start: 30000 +permanent_ckpt_interval: 5000 + +finetune_enabled: false +finetune_ckpt_path: null +finetune_ignored_params: + - model.spk_embed + - model.fs2.txt_embed + - model.fs2.encoder.embed_tokens +finetune_strict_shapes: true + +freezing_enabled: false +frozen_params: [] diff --git a/configs/variance.yaml b/configs/variance.yaml index a819c1c43..34c189134 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -3,10 +3,10 @@ base_config: task_cls: training.variance_task.VarianceTask -dictionaries: {} -extra_phonemes: [] -merged_phoneme_groups: [] -datasets: [] +dictionaries: { } +extra_phonemes: [ ] +merged_phoneme_groups: [ ] +datasets: [ ] audio_sample_rate: 44100 hop_size: 512 # Hop size. @@ -67,12 +67,12 @@ pitch_prediction_args: pitd_norm_max: 8.0 pitd_clip_min: -12.0 pitd_clip_max: 12.0 - repeat_bins: 64 + repeat_bins: 96 backbone_type: 'lynxnet2' backbone_args: num_layers: 6 num_channels: 512 - dropout_rate: 0.0 + dropout_rate: 0.1 use_conditioner_cache: true glu_type: 'atanglu' @@ -97,7 +97,7 @@ variances_prediction_args: backbone_args: num_layers: 6 num_channels: 384 - dropout_rate: 0.0 + dropout_rate: 0.1 use_conditioner_cache: true glu_type: 'atanglu' @@ -128,17 +128,17 @@ optimizer_args: adamw_args: weight_decay: 0.0 lr_scheduler_args: - step_size: 5000 + step_size: 4000 gamma: 0.8 max_batch_frames: 80000 max_batch_size: 48 dataset_size_key: 'lengths' -val_check_interval: 2000 +val_check_interval: 4000 num_valid_plots: 10 -max_updates: 100000 -num_ckpt_keep: 5 -permanent_ckpt_start: 60000 -permanent_ckpt_interval: 10000 +max_updates: 80000 +num_ckpt_keep: 8 +permanent_ckpt_start: 30000 +permanent_ckpt_interval: 8000 finetune_enabled: false finetune_ckpt_path: null From 41d1cfd05d5b4b6cea09ed81678120d83dc32fac Mon Sep 17 00:00:00 2001 From: wolfgitpr <133209402+wolfgitpr@users.noreply.github.com> Date: Mon, 22 Jun 2026 14:05:38 +0800 Subject: [PATCH 39/41] Revert "Adjust configs" This reverts commit fe2df4506d346892fa0f985c492de39ab292fd2a. --- configs/acoustic.yaml | 10 +-- configs/duration.yaml | 149 ------------------------------------------ configs/variance.yaml | 26 ++++---- 3 files changed, 18 insertions(+), 167 deletions(-) delete mode 100644 configs/duration.yaml diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 176ef563b..935d6e160 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -78,7 +78,7 @@ backbone_args: num_channels: 1024 num_layers: 6 kernel_size: 31 - dropout_rate: 0.1 + dropout_rate: 0.0 use_conditioner_cache: true glu_type: 'atanglu' main_loss_type: l2 @@ -116,18 +116,18 @@ optimizer_args: adamw_args: weight_decay: 0.0 lr_scheduler_args: - step_size: 4000 + step_size: 5000 gamma: 0.8 max_batch_frames: 50000 max_batch_size: 64 dataset_size_key: 'lengths' val_with_vocoder: true -val_check_interval: 4000 +val_check_interval: 2000 num_valid_plots: 10 max_updates: 100000 -num_ckpt_keep: 8 +num_ckpt_keep: 5 permanent_ckpt_start: 60000 -permanent_ckpt_interval: 8000 +permanent_ckpt_interval: 10000 finetune_enabled: false finetune_ckpt_path: null diff --git a/configs/duration.yaml b/configs/duration.yaml deleted file mode 100644 index 0f7005185..000000000 --- a/configs/duration.yaml +++ /dev/null @@ -1,149 +0,0 @@ -base_config: - - configs/base.yaml - -task_cls: training.variance_task.VarianceTask - -dictionaries: {} -extra_phonemes: [] -merged_phoneme_groups: [] -datasets: [] - -audio_sample_rate: 44100 -hop_size: 512 # Hop size. -fft_size: 2048 # FFT size. -win_size: 2048 # FFT size. -midi_smooth_width: 0.06 # in seconds - -binarization_args: - shuffle: true - num_workers: 0 - prefer_ds: false - -binary_data_dir: 'data/opencpop_duration/binary' -binarizer_cls: preprocessing.variance_binarizer.VarianceBinarizer - -use_lang_id: false -num_lang: 1 -use_spk_id: false -num_spk: 1 - -predict_dur: true -predict_pitch: false -predict_energy: false -predict_breathiness: false -predict_voicing: false -predict_tension: false - -enc_ffn_kernel_size: 3 -use_rope: true -rope_interleaved: false -use_stretch_embed: false -use_variance_scaling: true -rel_pos: true -hidden_size: 384 - -dur_prediction_args: - arch: resnet - hidden_size: 256 - dropout: 0.1 - num_layers: 5 - kernel_size: 3 - log_offset: 1.0 - loss_type: mse - lambda_pdur_loss: 0.3 - lambda_wdur_loss: 1.0 - lambda_sdur_loss: 3.0 - -use_melody_encoder: true -melody_encoder_args: - hidden_size: 128 - enc_layers: 4 -use_glide_embed: false -glide_types: [up, down] -glide_embed_scale: 11.313708498984760 # sqrt(128) - -pitch_prediction_args: - pitd_norm_min: -8.0 - pitd_norm_max: 8.0 - pitd_clip_min: -12.0 - pitd_clip_max: 12.0 - repeat_bins: 96 - backbone_type: 'lynxnet2' - backbone_args: - num_layers: 6 - num_channels: 512 - dropout_rate: 0.1 - use_conditioner_cache: true - glu_type: 'atanglu' - -energy_db_min: -96.0 -energy_db_max: -12.0 -energy_smooth_width: 0.06 - -breathiness_db_min: -96.0 -breathiness_db_max: -20.0 -breathiness_smooth_width: 0.06 -voicing_db_min: -96.0 -voicing_db_max: -12.0 -voicing_smooth_width: 0.06 - -tension_logit_min: -10.0 -tension_logit_max: 10.0 -tension_smooth_width: 0.06 - -variances_prediction_args: - total_repeat_bins: 72 - backbone_type: 'lynxnet2' - backbone_args: - num_layers: 6 - num_channels: 384 - dropout_rate: 0.1 - use_conditioner_cache: true - glu_type: 'atanglu' - -lambda_dur_loss: 1.0 -lambda_pitch_loss: 1.0 -lambda_var_loss: 1.0 - -diffusion_type: reflow # ddpm -time_scale_factor: 1000 -schedule_type: 'linear' -K_step: 1000 -timesteps: 1000 -max_beta: 0.02 -main_loss_type: l2 -main_loss_log_norm: true -sampling_algorithm: euler -sampling_steps: 20 -diff_accelerator: ddim -diff_speedup: 10 - -# train and eval -num_sanity_val_steps: 1 -optimizer_args: - optimizer_cls: torch.optim.AdamW - lr: 0.0006 -lr_scheduler_args: - scheduler_cls: torch.optim.lr_scheduler.StepLR - step_size: 5000 - gamma: 0.75 -max_batch_frames: 80000 -max_batch_size: 48 -dataset_size_key: 'lengths' -val_check_interval: 4000 -num_valid_plots: 10 -max_updates: 80000 -num_ckpt_keep: 8 -permanent_ckpt_start: 30000 -permanent_ckpt_interval: 5000 - -finetune_enabled: false -finetune_ckpt_path: null -finetune_ignored_params: - - model.spk_embed - - model.fs2.txt_embed - - model.fs2.encoder.embed_tokens -finetune_strict_shapes: true - -freezing_enabled: false -frozen_params: [] diff --git a/configs/variance.yaml b/configs/variance.yaml index 34c189134..a819c1c43 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -3,10 +3,10 @@ base_config: task_cls: training.variance_task.VarianceTask -dictionaries: { } -extra_phonemes: [ ] -merged_phoneme_groups: [ ] -datasets: [ ] +dictionaries: {} +extra_phonemes: [] +merged_phoneme_groups: [] +datasets: [] audio_sample_rate: 44100 hop_size: 512 # Hop size. @@ -67,12 +67,12 @@ pitch_prediction_args: pitd_norm_max: 8.0 pitd_clip_min: -12.0 pitd_clip_max: 12.0 - repeat_bins: 96 + repeat_bins: 64 backbone_type: 'lynxnet2' backbone_args: num_layers: 6 num_channels: 512 - dropout_rate: 0.1 + dropout_rate: 0.0 use_conditioner_cache: true glu_type: 'atanglu' @@ -97,7 +97,7 @@ variances_prediction_args: backbone_args: num_layers: 6 num_channels: 384 - dropout_rate: 0.1 + dropout_rate: 0.0 use_conditioner_cache: true glu_type: 'atanglu' @@ -128,17 +128,17 @@ optimizer_args: adamw_args: weight_decay: 0.0 lr_scheduler_args: - step_size: 4000 + step_size: 5000 gamma: 0.8 max_batch_frames: 80000 max_batch_size: 48 dataset_size_key: 'lengths' -val_check_interval: 4000 +val_check_interval: 2000 num_valid_plots: 10 -max_updates: 80000 -num_ckpt_keep: 8 -permanent_ckpt_start: 30000 -permanent_ckpt_interval: 8000 +max_updates: 100000 +num_ckpt_keep: 5 +permanent_ckpt_start: 60000 +permanent_ckpt_interval: 10000 finetune_enabled: false finetune_ckpt_path: null From 76c048849cb7f90c5c16ca18372cbb0ef6cf6a5a Mon Sep 17 00:00:00 2001 From: wolfgitpr <133209402+wolfgitpr@users.noreply.github.com> Date: Mon, 22 Jun 2026 16:14:19 +0800 Subject: [PATCH 40/41] Adjust configs && Add duration config --- configs/acoustic.yaml | 4 ++-- configs/duration.yaml | 34 ++++++++++++++++++++++++++++++++++ configs/variance.yaml | 6 +++--- 3 files changed, 39 insertions(+), 5 deletions(-) create mode 100644 configs/duration.yaml diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 935d6e160..e03618704 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -122,10 +122,10 @@ max_batch_frames: 50000 max_batch_size: 64 dataset_size_key: 'lengths' val_with_vocoder: true -val_check_interval: 2000 +val_check_interval: 4000 num_valid_plots: 10 max_updates: 100000 -num_ckpt_keep: 5 +num_ckpt_keep: 8 permanent_ckpt_start: 60000 permanent_ckpt_interval: 10000 diff --git a/configs/duration.yaml b/configs/duration.yaml new file mode 100644 index 000000000..89810a3f7 --- /dev/null +++ b/configs/duration.yaml @@ -0,0 +1,34 @@ +base_config: + - configs/variance.yaml + +binary_data_dir: 'data/opencpop_duration/binary' + +predict_dur: true +predict_pitch: false +predict_energy: false +predict_breathiness: true +predict_voicing: true +predict_tension: false + +dur_prediction_args: + arch: resnet + hidden_size: 256 + dropout: 0.1 + num_layers: 5 + kernel_size: 3 + log_offset: 1.0 + loss_type: mse + lambda_pdur_loss: 0.3 + lambda_wdur_loss: 1.0 + lambda_sdur_loss: 3.0 + +# train and eval +optimizer_args: + optimizer_cls: torch.optim.AdamW + lr: 0.0006 +lr_scheduler_args: + scheduler_cls: torch.optim.lr_scheduler.StepLR + step_size: 5000 + gamma: 0.75 + +max_updates: 60000 \ No newline at end of file diff --git a/configs/variance.yaml b/configs/variance.yaml index a819c1c43..d4e203670 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -133,11 +133,11 @@ lr_scheduler_args: max_batch_frames: 80000 max_batch_size: 48 dataset_size_key: 'lengths' -val_check_interval: 2000 +val_check_interval: 4000 num_valid_plots: 10 -max_updates: 100000 +max_updates: 80000 num_ckpt_keep: 5 -permanent_ckpt_start: 60000 +permanent_ckpt_start: 30000 permanent_ckpt_interval: 10000 finetune_enabled: false From 20d129159a0c013a8df901b9ffae3cdd5a555249 Mon Sep 17 00:00:00 2001 From: wolfgitpr <133209402+wolfgitpr@users.noreply.github.com> Date: Mon, 22 Jun 2026 16:40:47 +0800 Subject: [PATCH 41/41] Add duration template --- configs/duration.yaml | 34 ------ configs/templates/config_acoustic.yaml | 4 +- configs/templates/config_duration.yaml | 139 +++++++++++++++++++++++++ configs/templates/config_variance.yaml | 6 +- 4 files changed, 144 insertions(+), 39 deletions(-) delete mode 100644 configs/duration.yaml create mode 100644 configs/templates/config_duration.yaml diff --git a/configs/duration.yaml b/configs/duration.yaml deleted file mode 100644 index 89810a3f7..000000000 --- a/configs/duration.yaml +++ /dev/null @@ -1,34 +0,0 @@ -base_config: - - configs/variance.yaml - -binary_data_dir: 'data/opencpop_duration/binary' - -predict_dur: true -predict_pitch: false -predict_energy: false -predict_breathiness: true -predict_voicing: true -predict_tension: false - -dur_prediction_args: - arch: resnet - hidden_size: 256 - dropout: 0.1 - num_layers: 5 - kernel_size: 3 - log_offset: 1.0 - loss_type: mse - lambda_pdur_loss: 0.3 - lambda_wdur_loss: 1.0 - lambda_sdur_loss: 3.0 - -# train and eval -optimizer_args: - optimizer_cls: torch.optim.AdamW - lr: 0.0006 -lr_scheduler_args: - scheduler_cls: torch.optim.lr_scheduler.StepLR - step_size: 5000 - gamma: 0.75 - -max_updates: 60000 \ No newline at end of file diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 6b7e3795d..e755e8846 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -117,8 +117,8 @@ max_updates: 100000 num_valid_plots: 10 val_with_vocoder: true -val_check_interval: 2000 -num_ckpt_keep: 5 +val_check_interval: 4000 +num_ckpt_keep: 8 permanent_ckpt_start: 60000 permanent_ckpt_interval: 10000 pl_trainer_devices: 'auto' diff --git a/configs/templates/config_duration.yaml b/configs/templates/config_duration.yaml new file mode 100644 index 000000000..1753364fb --- /dev/null +++ b/configs/templates/config_duration.yaml @@ -0,0 +1,139 @@ +base_config: + - configs/variance.yaml + +dictionaries: + zh: dictionaries/opencpop-extension.txt +extra_phonemes: [] +merged_phoneme_groups: [] + +datasets: + - raw_data_dir: data/xxx1/raw + speaker: speaker1 + spk_id: 0 + language: zh + test_prefixes: + - wav1 + - wav2 + - wav3 + - wav4 + - wav5 + - raw_data_dir: data/xxx2/raw + speaker: speaker2 + spk_id: 1 + language: zh + test_prefixes: + - wav1 + - wav2 + - wav3 + - wav4 + - wav5 + +binary_data_dir: data/xxx/binary +binarization_args: + num_workers: 0 +pe: parselmouth +pe_ckpt: 'checkpoints/rmvpe/model.pt' +hnsep: vr +hnsep_ckpt: 'checkpoints/vr/model.pt' + +use_lang_id: false +num_lang: 1 +use_spk_id: false +num_spk: 1 +# NOTICE: before enabling variance modules, please read the docs at +# https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#mutual-influence-between-variance-modules +predict_dur: true +predict_pitch: false +# NOTICE: before enabling variance predictions, please read the docs at +# https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#choosing-variance-parameters +predict_energy: false +predict_breathiness: true +predict_voicing: true +predict_tension: false + +energy_db_min: -96.0 +energy_db_max: -12.0 + +breathiness_db_min: -96.0 +breathiness_db_max: -20.0 + +voicing_db_min: -96.0 +voicing_db_max: -12.0 + +tension_logit_min: -10.0 +tension_logit_max: 10.0 + +enc_ffn_kernel_size: 3 +use_rope: true +rope_interleaved: false +use_stretch_embed: false +use_variance_scaling: true +hidden_size: 384 +dur_prediction_args: + arch: resnet + hidden_size: 256 + dropout: 0.1 + num_layers: 5 + kernel_size: 3 + log_offset: 1.0 + loss_type: mse + lambda_pdur_loss: 0.3 + lambda_wdur_loss: 1.0 + lambda_sdur_loss: 3.0 + +use_melody_encoder: true +melody_encoder_args: + hidden_size: 128 + enc_layers: 4 +use_glide_embed: false +glide_types: [up, down] +glide_embed_scale: 11.313708498984760 # sqrt(128) + +diffusion_type: reflow + +pitch_prediction_args: + pitd_norm_min: -8.0 + pitd_norm_max: 8.0 + pitd_clip_min: -12.0 + pitd_clip_max: 12.0 + repeat_bins: 64 + backbone_type: 'lynxnet2' + backbone_args: + num_layers: 6 + num_channels: 512 + dropout_rate: 0.0 + use_conditioner_cache: true + glu_type: 'atanglu' + +variances_prediction_args: + total_repeat_bins: 72 + backbone_type: 'lynxnet2' + backbone_args: + num_layers: 6 + num_channels: 384 + dropout_rate: 0.0 + use_conditioner_cache: true + glu_type: 'atanglu' + +lambda_dur_loss: 1.0 +lambda_pitch_loss: 1.0 +lambda_var_loss: 1.0 + +optimizer_args: + optimizer_cls: torch.optim.AdamW + lr: 0.0006 +lr_scheduler_args: + scheduler_cls: torch.optim.lr_scheduler.StepLR + step_size: 5000 + gamma: 0.75 +max_batch_frames: 80000 +max_batch_size: 48 +max_updates: 60000 + +num_valid_plots: 10 +val_check_interval: 4000 +num_ckpt_keep: 8 +permanent_ckpt_start: 30000 +permanent_ckpt_interval: 10000 +pl_trainer_devices: 'auto' +pl_trainer_precision: '16-mixed' diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index f87adb3cd..116154ac7 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -131,12 +131,12 @@ lr_scheduler_args: gamma: 0.8 max_batch_frames: 80000 max_batch_size: 48 -max_updates: 100000 +max_updates: 80000 num_valid_plots: 10 val_check_interval: 2000 -num_ckpt_keep: 5 -permanent_ckpt_start: 60000 +num_ckpt_keep: 8 +permanent_ckpt_start: 30000 permanent_ckpt_interval: 10000 pl_trainer_devices: 'auto' pl_trainer_precision: '16-mixed'