# Text-to-speech pipeline module: predicts per-token durations and pitch,
# feeds them to an acoustic model (`tacotron`) to get mel outputs, and runs a
# vocoder on the result to produce audio.
#
# NOTE(review): this looks like TorchScript-serialized code (mangled
# `__torch__` type names, `unchecked_cast`, `ops.prim.*`, SSA-style `_N`
# temporaries) - presumably machine-generated; confirm before hand-editing.
class MultiTTSModel(Module):
  __parameters__ = []
  __buffers__ = []
  training : bool
  _is_full_backward_hook : Optional[bool]
  # Token id that `replace_comma` writes over the selected comma positions.
  point_id : int
  # Token ids that `get_support_ids` treats as commas.
  comma_ids : List[int]
  # Per-speaker factor, indexed by speaker id in `update_pitch_coef`.
  mean_std_coef : List[float]
  # Fill value that `fx_pauses` writes into mel frames of forced pauses.
  sil_value : float
  # Called as (sequence, speaker_ids, mask, durations, pitch) -> mel outputs.
  tacotron : __torch__.jit_forward_model.___torch_mangle_178.JitMultiForward
  # Called as (mel outputs, sr) -> audio.
  vocoder : __torch__.vocoder.hifigan.jit_vocoder.JitGenerator
  # Called as (sequence, speaker_ids, mask, 1.); its output is passed through
  # exp(x) - 1 in `forward`.
  dur_predictor : __torch__.jit_forward_model.___torch_mangle_168.JitDurPredictor
  # Called as (sequence, speaker_ids, mask) -> pitch.
  pitch_predictor : __torch__.jit_forward_model.___torch_mangle_169.JitPitchPredictor
  # Synthesize audio for a token sequence.
  #
  # Args:
  #   sequence: token id tensor; indexed as sequence[0][position] by the
  #     helpers (presumably shaped (batch, tokens) - TODO confirm). Must have
  #     size 1 along dim 0 when comma pauses are being adjusted.
  #   speaker_ids: forwarded to all sub-models; `update_pitch_coef` calls
  #     `item()` on it, so it must hold a single element when `pitch_coefs`
  #     is given.
  #   sr: passed to the vocoder and used to scale the returned lengths.
  #   symb_durs: optional map {token position: forced duration}. A None or
  #     empty dict disables all pause handling.
  #   durs_rate: optional tensor with the same size as the durations; the
  #     durations are divided by it.
  #   pitch_coefs: optional tensor passed to `update_pitch_coef` to rescale
  #     the predicted pitch.
  #   gt_durs: optional durations used instead of the predicted ones. When
  #     given, `symb_durs` and `durs_rate` are not applied to the durations
  #     (but `fx_pauses` still runs if `symb_durs` is non-empty).
  #   gt_pitch: optional pitch used instead of the predicted one. When given,
  #     `pitch_coefs` is ignored.
  #   device: device string the mask and the support sequence are moved to.
  #
  # Returns:
  #   (audio, audio_lengths): the detached vocoder output with dim 1 squeezed,
  #   and mel_lengths * sr / 80.
  #
  # Raises:
  #   "AssertionError: " (empty message) when the batch-size check, the
  #   `durs_rate` size check, or the vocoder-output None check fails.
  def forward(self: __torch__.silero_vocoder.jit_model.___torch_mangle_179.MultiTTSModel,
    sequence: Tensor,
    speaker_ids: Tensor,
    sr: int=48000,
    symb_durs: Optional[Dict[int, int]]=None,
    durs_rate: Optional[Tensor]=None,
    pitch_coefs: Optional[Tensor]=None,
    gt_durs: Optional[Tensor]=None,
    gt_pitch: Optional[Tensor]=None,
    device: str="cpu") -> Tuple[Tensor, Tensor]:
    _0 = __torch__.tacotron2.fastpitch_layers.make_token_len_mask
    # Placeholder that is only referenced after the raising branch near the
    # end of this method.
    _1 = uninitialized(Tensor)
    # Mask built from the transposed sequence. Its negation multiplies the
    # durations below, so it presumably marks padding - TODO confirm.
    _2 = _0(torch.transpose(sequence, 0, 1), )
    orig_mask = torch.to(_2, torch.device(device))
    # _3 == (symb_durs is not None and len(symb_durs) > 0), expanded.
    if torch.__isnot__(symb_durs, None):
      symb_durs1 = unchecked_cast(Dict[int, int], symb_durs)
      _4 = torch.gt(torch.len(symb_durs1), 0)
      _3, symb_durs0 = _4, symb_durs1
    else:
      _3, symb_durs0 = False, symb_durs
    # Positions from symb_durs that hold a comma token (None if there are none).
    if _3:
      symb_durs3 = unchecked_cast(Dict[int, int], symb_durs0)
      support_ids0 = (self).get_support_ids(sequence, symb_durs3, )
      support_ids, symb_durs2 = support_ids0, symb_durs3
    else:
      support_ids, symb_durs2 = None, symb_durs0
    # Support sequence: a copy of `sequence` with those commas replaced by
    # `point_id`; falls back to `sequence` itself when there is nothing to replace.
    if torch.__isnot__(support_ids, None):
      support_ids2 = unchecked_cast(List[int], support_ids)
      support_sequence0 = (self).replace_comma(sequence, support_ids2, )
      support_ids1, support_sequence = support_ids2, support_sequence0
    else:
      support_ids1, support_sequence = support_ids, sequence
    # Durations are predicted only when no ground-truth durations are given.
    if torch.__is__(gt_durs, None):
      # _7 == (symb_durs non-empty and support_ids is not None), expanded.
      if torch.__isnot__(symb_durs2, None):
        symb_durs6 = unchecked_cast(Dict[int, int], symb_durs2)
        _6 = torch.gt(torch.len(symb_durs6), 0)
        _5, symb_durs5 = _6, symb_durs6
      else:
        _5, symb_durs5 = False, symb_durs2
      if _5:
        symb_durs8 = unchecked_cast(Dict[int, int], symb_durs5)
        _8 = torch.__isnot__(support_ids1, None)
        _7, symb_durs7 = _8, symb_durs8
      else:
        _7, symb_durs7 = False, symb_durs5
      if _7:
        symb_durs10 = unchecked_cast(Dict[int, int], symb_durs7)
        support_ids4 = unchecked_cast(List[int], support_ids1)
        # Comma-pause adjustment requires size 1 along dim 0 of `sequence`.
        _9 = torch.eq((torch.size(sequence))[0], 1)
        if _9:
          pass
        else:
          ops.prim.RaiseException("AssertionError: ")
        support_sequence1 = torch.to(support_sequence, torch.device(device))
        dur_predictor = self.dur_predictor
        # Run the predictor on both variants at once: row 0 is the original
        # sequence, row 1 the support sequence. The mask is duplicated to
        # match; speaker_ids is passed unchanged.
        _10 = torch.cat([sequence, support_sequence1])
        _11 = torch.cat([orig_mask, orig_mask])
        pred_log_dur0 = (dur_predictor).forward(_10, speaker_ids, _11, 1., )
        pred_log_dur, symb_durs9, support_ids3 = pred_log_dur0, symb_durs10, support_ids4
      else:
        dur_predictor0 = self.dur_predictor
        pred_log_dur1 = (dur_predictor0).forward(sequence, speaker_ids, orig_mask, 1., )
        pred_log_dur, symb_durs9, support_ids3 = pred_log_dur1, symb_durs7, support_ids1
      # Map predictor output to durations via exp(x) - 1 (presumably the
      # predictor emits log(1 + duration) - TODO confirm), then set negative
      # values to 0 in place.
      pred_log_dur2 = torch.sub(torch.exp(pred_log_dur), 1)
      _12 = torch.lt(pred_log_dur2, 0.)
      _13 = torch.tensor(0., dtype=ops.prim.dtype(pred_log_dur2), device=ops.prim.device(pred_log_dur2))
      _14 = annotate(List[Optional[Tensor]], [_12])
      _15 = torch.index_put_(pred_log_dur2, _14, _13)
      dur_hat0 = torch.round(pred_log_dur2)
      # _16 == (symb_durs non-empty), expanded again.
      if torch.__isnot__(symb_durs9, None):
        symb_durs12 = unchecked_cast(Dict[int, int], symb_durs9)
        _17 = torch.gt(torch.len(symb_durs12), 0)
        _16, symb_durs11 = _17, symb_durs12
      else:
        _16, symb_durs11 = False, symb_durs9
      if _16:
        symb_durs14 = unchecked_cast(Dict[int, int], symb_durs11)
        _18 = torch.__isnot__(support_ids3, None)
        if _18:
          support_ids5 = unchecked_cast(List[int], support_ids3)
          # Largest requested duration among the comma positions drives the
          # choice/blend between the two predicted rows.
          sdurs = annotate(List[int], [])
          for _19 in range(torch.len(support_ids5)):
            sid = support_ids5[_19]
            _20 = torch.append(sdurs, symb_durs14[sid])
          pause_len = ops.prim.max(sdurs)
          # process_pauses receives the unrounded durations and rounds its
          # own result.
          dur_hat3 = (self).process_pauses(pred_log_dur2, support_ids5, pause_len, )
          dur_hat2 = dur_hat3
        else:
          dur_hat2 = dur_hat0
        dur_hat1, symb_durs13 = dur_hat2, symb_durs14
      else:
        dur_hat1, symb_durs13 = dur_hat0, symb_durs11
      # Optional rescaling: durations are divided by durs_rate, which must
      # have exactly the same size.
      if torch.__isnot__(durs_rate, None):
        durs_rate0 = unchecked_cast(Tensor, durs_rate)
        _21 = torch.eq(torch.size(durs_rate0), torch.size(dur_hat1))
        if _21:
          pass
        else:
          ops.prim.RaiseException("AssertionError: ")
        dur_hat4 = torch.div(dur_hat1, durs_rate0)
      else:
        dur_hat4 = dur_hat1
      _22 = torch.__isnot__(symb_durs13, None)
      if _22:
        symb_durs16 = unchecked_cast(Dict[int, int], symb_durs13)
        _24 = torch.gt(torch.len(symb_durs16), 0)
        _23, symb_durs15 = _24, symb_durs16
      else:
        _23, symb_durs15 = False, symb_durs13
      if _23:
        symb_durs18 = unchecked_cast(Dict[int, int], symb_durs15)
        # Force the explicitly requested durations: dur_hat4[0][sid] = sdur,
        # written in place through a view. This runs after the durs_rate
        # division, so forced values are not rescaled.
        _25 = torch.items(symb_durs18)
        for _26 in range(torch.len(_25)):
          sid0, sdur, = _25[_26]
          _27 = torch.select(torch.select(dur_hat4, 0, 0), 0, sid0)
          _28 = torch.tensor(sdur, dtype=ops.prim.dtype(_27), device=ops.prim.device(_27))
          _29 = torch.copy_(_27, _28)
        symb_durs17 : Optional[Dict[int, int]] = symb_durs18
      else:
        symb_durs17 = symb_durs15
      dur_hat, symb_durs4 = dur_hat4, symb_durs17
    else:
      dur_hat, symb_durs4 = unchecked_cast(Tensor, gt_durs), symb_durs2
    # Per-row total duration over positions where the mask is False.
    _30 = torch.mul(dur_hat, torch.bitwise_not(orig_mask))
    mel_lengths = torch.sum(_30, [1])
    # Pitch: predicted (and optionally rescaled) unless gt_pitch is supplied.
    if torch.__is__(gt_pitch, None):
      pitch_predictor = self.pitch_predictor
      pitch_hat0 = (pitch_predictor).forward(sequence, speaker_ids, orig_mask, )
      _31 = torch.__isnot__(pitch_coefs, None)
      if _31:
        pitch_coefs0 = unchecked_cast(Tensor, pitch_coefs)
        pitch_hat2 = (self).update_pitch_coef(pitch_hat0, pitch_coefs0, speaker_ids, )
        pitch_hat1 = pitch_hat2
      else:
        pitch_hat1 = pitch_hat0
      pitch_hat = pitch_hat1
    else:
      pitch_hat = unchecked_cast(Tensor, gt_pitch)
    tacotron = self.tacotron
    mel_outputs = (tacotron).forward(sequence, speaker_ids, orig_mask, dur_hat, pitch_hat, )
    # Overwrite the mel frames of forced pauses with the silence value
    # whenever symb_durs is non-empty (also on the gt_durs path).
    if torch.__isnot__(symb_durs4, None):
      symb_durs20 = unchecked_cast(Dict[int, int], symb_durs4)
      _33 = torch.gt(torch.len(symb_durs20), 0)
      _32, symb_durs19 = _33, symb_durs20
    else:
      _32, symb_durs19 = False, symb_durs4
    if _32:
      symb_durs21 = unchecked_cast(Dict[int, int], symb_durs19)
      mel_outputs1 = (self).fx_pauses(mel_outputs, dur_hat, symb_durs21, )
      mel_outputs0 = mel_outputs1
    else:
      mel_outputs0 = mel_outputs
    vocoder = self.vocoder
    audio = (vocoder).forward(mel_outputs0, sr, )
    # The vocoder result is checked for None; a None result raises.
    if torch.__isnot__(audio, None):
      audio0 = unchecked_cast(Tensor, audio)
    else:
      ops.prim.RaiseException("AssertionError: ")
      audio0 = _1
    audio1 = torch.detach(audio0)
    audio2 = torch.squeeze(audio1, 1)
    # audio_lengths = mel_lengths * sr / 80.
    # NOTE(review): 80 is a hard-coded divisor whose meaning is not visible
    # here - confirm what it stands for.
    _34 = torch.mul(torch.detach(mel_lengths), sr)
    audio_lengths = torch.div(_34, 80)
    return (audio2, audio_lengths)
  # Return the keys of `symb_durs` (token positions) whose token in
  # sequence[0] is one of `self.comma_ids`, in key iteration order.
  # Returns None instead of an empty list when no position qualifies.
  def get_support_ids(self: __torch__.silero_vocoder.jit_model.___torch_mangle_179.MultiTTSModel,
    sequence: Tensor,
    symb_durs: Dict[int, int]) -> Optional[List[int]]:
    temp_ids = annotate(List[int], [])
    _35 = torch.keys(symb_durs)
    for _36 in range(torch.len(_35)):
      sid = _35[_36]
      # Token at sequence[0][sid], converted to int for the membership test.
      _37 = torch.select(torch.select(sequence, 0, 0), 0, sid)
      comma_ids = self.comma_ids
      _38 = torch.__contains__(comma_ids, annotate(int, _37))
      if _38:
        _39 = torch.append(temp_ids, sid)
      else:
        pass
    if torch.gt(torch.len(temp_ids), 0):
      support_ids : Optional[List[int]] = temp_ids
    else:
      support_ids = None
    return support_ids
  # Return a clone of `sequence` in which the positions listed in
  # `support_ids` (indexed along dim 1, for every row) are set to
  # `self.point_id`. The input tensor itself is not modified.
  def replace_comma(self: __torch__.silero_vocoder.jit_model.___torch_mangle_179.MultiTTSModel,
    sequence: Tensor,
    support_ids: List[int]) -> Tensor:
    support_sequence = torch.clone(sequence)
    point_id = self.point_id
    # dtype=4 is torch's scalar-type code for int64 (index tensor).
    _40 = torch.tensor(support_ids, dtype=4)
    # Full slice (a view of the clone), so the index_put_ below writes into
    # `support_sequence`: support_sequence[:, support_ids] = point_id.
    _41 = torch.slice(support_sequence)
    _42 = torch.tensor(point_id, dtype=ops.prim.dtype(_41), device=ops.prim.device(_41))
    _43 = annotate(List[Optional[Tensor]], [None, _40])
    _44 = torch.index_put_(_41, _43, _42)
    return support_sequence
  # Choose or blend between the two rows of `comma_point_durs` based on the
  # requested pause length.
  #
  # `forward` passes durations predicted for [original sequence, support
  # sequence] stacked along dim 0, so row 0 is the comma variant and row 1
  # the `point_id` variant.
  #
  # Args:
  #   comma_point_durs: durations with at least two rows along dim 0.
  #   comma_ids: positions (dim 1) of the adjusted commas.
  #   pause_len: requested pause length to compare against.
  #
  # Returns:
  #   The selected/blended row, with a leading dim of size 1 added, rounded.
  def process_pauses(self: __torch__.silero_vocoder.jit_model.___torch_mangle_179.MultiTTSModel,
    comma_point_durs: Tensor,
    comma_ids: List[int],
    pause_len: int) -> Tensor:
    _45 = torch.tensor(comma_ids, dtype=4)
    _46 = torch.slice(comma_point_durs)
    _47 = annotate(List[Optional[Tensor]], [None, _45])
    # Per-row minimum over the comma positions (values in _48, indices unused).
    _48, _49 = torch.min(torch.index(_46, _47), 1)
    comma_len = torch.item(torch.select(_48, 0, 0))
    point_len = torch.item(torch.select(_48, 0, 1))
    # Keep row 0 when the comma variant is already longer than the point
    # variant, or when the requested pause does not exceed the comma length.
    if torch.gt(comma_len, point_len):
      _50 = True
    else:
      _50 = torch.le(pause_len, comma_len)
    if _50:
      final_dur0 = torch.select(comma_point_durs, 0, 0)
      final_dur = final_dur0
    else:
      # Requested pause is at least the point length: take row 1 as is.
      if torch.ge(pause_len, point_len):
        final_dur2 = torch.select(comma_point_durs, 0, 1)
        final_dur1 = final_dur2
      else:
        # comma_len < pause_len < point_len here, so the denominator is
        # positive. Linear interpolation: the two weights sum to 1.
        point_cf = torch.div(torch.sub(pause_len, comma_len), torch.sub(point_len, comma_len))
        comma_cf = torch.div(torch.sub(point_len, pause_len), torch.sub(point_len, comma_len))
        _51 = torch.select(comma_point_durs, 0, 0)
        _52 = torch.mul(_51, comma_cf)
        _53 = torch.select(comma_point_durs, 0, 1)
        final_dur3 = torch.add(_52, torch.mul(_53, point_cf))
        final_dur1 = final_dur3
      final_dur = final_dur1
    final_dur4 = torch.unsqueeze(final_dur, 0)
    return torch.round(final_dur4)
  # Rescale a pitch tensor by `pitch_coef` and add a speaker-dependent shift:
  #   norm_pitch * pitch_coef + (pitch_coef - 1) * mean_std_coef[speaker]
  #
  # Side effect: `pitch_coef` is modified in place - positions (along dim 1)
  # where pitch_coef[0] == 0 are set to 1 for every row.
  # `speaker_ids` must hold exactly one element (it is read with `item()`).
  def update_pitch_coef(self: __torch__.silero_vocoder.jit_model.___torch_mangle_179.MultiTTSModel,
    norm_pitch: Tensor,
    pitch_coef: Tensor,
    speaker_ids: Tensor) -> Tensor:
    # NOTE(review): this product is computed BEFORE the zero coefficients are
    # replaced by 1 below, so positions with a 0 coefficient end up as 0
    # rather than keeping the original pitch - confirm this is intended.
    norm_pitch_coefd = torch.mul(norm_pitch, pitch_coef)
    _54 = torch.eq(torch.select(pitch_coef, 0, 0), 0)
    _55 = (torch.where(_54))[0]
    # In-place write through a full-slice view: pitch_coef[:, zero_positions] = 1,
    # which makes the shift term below 0 at those positions.
    _56 = torch.slice(pitch_coef)
    _57 = torch.tensor(1., dtype=ops.prim.dtype(_56), device=ops.prim.device(_56))
    _58 = annotate(List[Optional[Tensor]], [None, _55])
    _59 = torch.index_put_(_56, _58, _57)
    mean_std_coef = self.mean_std_coef
    _60 = mean_std_coef[int(torch.item(speaker_ids))]
    mean_std_shift = torch.mul(torch.sub(pitch_coef, 1.), _60)
    norm_pitch_shift = torch.add(norm_pitch_coefd, mean_std_shift)
    return norm_pitch_shift
  # Overwrite, in place, the mel frames that belong to each token position in
  # `symb_durs` with `self.sil_value`, and return the same `mel_outputs` tensor.
  #
  # Frame ranges are taken along dim 2 of `mel_outputs` and derived from row 0
  # of `dur_hat`: a token starts at the sum of all preceding durations and
  # spans its own duration. For durations above 10 the first 10 frames of the
  # span are left untouched.
  def fx_pauses(self: __torch__.silero_vocoder.jit_model.___torch_mangle_179.MultiTTSModel,
    mel_outputs: Tensor,
    dur_hat: Tensor,
    symb_durs: Dict[int, int]) -> Tensor:
    _61 = torch.keys(symb_durs)
    for _62 in range(torch.len(_61)):
      sid = _61[_62]
      # dur_sum = int(sum(dur_hat[0][:sid])): first frame of token `sid`.
      _63 = torch.slice(torch.select(dur_hat, 0, 0), 0, None, sid)
      dur_sum = int(torch.item(torch.sum(_63)))
      # sdur = int(dur_hat[0][sid]): number of frames of token `sid`.
      _64 = torch.select(torch.select(dur_hat, 0, 0), 0, sid)
      sdur = int(torch.item(_64))
      if torch.gt(sdur, 10):
        sil_value = self.sil_value
        # mel_outputs[:, :, dur_sum + 10 : dur_sum + sdur] = sil_value
        _65 = torch.slice(torch.slice(mel_outputs), 1)
        _66 = torch.slice(_65, 2, torch.add(dur_sum, 10), torch.add(dur_sum, sdur))
        _67 = torch.tensor(sil_value, dtype=ops.prim.dtype(_66), device=ops.prim.device(_66))
        _68 = torch.copy_(_66, _67)
      else:
        sil_value0 = self.sil_value
        # mel_outputs[:, :, dur_sum : dur_sum + sdur] = sil_value
        _69 = torch.slice(torch.slice(mel_outputs), 1)
        _70 = torch.slice(_69, 2, dur_sum, torch.add(dur_sum, sdur))
        _71 = torch.tensor(sil_value0, dtype=ops.prim.dtype(_70), device=ops.prim.device(_70))
        _72 = torch.copy_(_70, _71)
    return mel_outputs
