def embedding(input: Tensor,
    weight: Tensor,
    padding_idx: Optional[int]=None,
    max_norm: Optional[float]=None,
    norm_type: float=2.,
    scale_grad_by_freq: bool=False,
    sparse: bool=False) -> Tensor:
  # Serialized TorchScript form of `embedding`: validates and normalizes
  # `padding_idx`, then dispatches to torch.embedding.
  #
  # padding_idx: a positive value must be < weight.size(0); a negative value
  #   must be >= -weight.size(0) and is wrapped by adding weight.size(0);
  #   0 passes through unchecked; None is passed to torch.embedding as -1.
  # max_norm / norm_type: in this body max_norm only makes `input`
  #   contiguous and norm_type is never read -- no renormalization of
  #   `weight` is visible here. NOTE(review): confirm against the eager
  #   implementation if the model relies on max_norm.
  # Raises: prim.RaiseException with message _0 when padding_idx is out of
  #   range.
  _0 = "AssertionError: Padding_idx must be within num_embeddings"
  if torch.__isnot__(padding_idx, None):
    padding_idx1 = unchecked_cast(int, padding_idx)
    if torch.gt(padding_idx1, 0):
      # Positive index: must be strictly below weight.size(0).
      _1 = torch.lt(padding_idx1, torch.size(weight, 0))
      if _1:
        pass
      else:
        ops.prim.RaiseException(_0)
      padding_idx2 = padding_idx1
    else:
      if torch.lt(padding_idx1, 0):
        # Negative index: bounds-check, then wrap by adding weight.size(0).
        _2 = torch.neg(torch.size(weight, 0))
        if torch.ge(padding_idx1, _2):
          pass
        else:
          ops.prim.RaiseException(_0)
        padding_idx4 = torch.add(torch.size(weight, 0), padding_idx1)
        padding_idx3 = padding_idx4
      else:
        padding_idx3 = padding_idx1
      padding_idx2 = padding_idx3
    padding_idx0 = padding_idx2
  else:
    # No padding index given: torch.embedding receives -1.
    padding_idx0 = -1
  if torch.__isnot__(max_norm, None):
    input0 = torch.contiguous(input)
  else:
    input0 = input
  _3 = torch.embedding(weight, input0, padding_idx0, scale_grad_by_freq, sparse)
  return _3
def dropout(input: Tensor,
    p: float=0.5,
    training: bool=True,
    inplace: bool=False) -> Tensor:
  # Serialized TorchScript form of `dropout`.
  #
  # p: must satisfy 0 <= p <= 1, otherwise ValueError (message _4) is raised.
  # training: forwarded unchanged to the ATen dropout op.
  # inplace: selects torch.dropout_ (the in-place variant) over torch.dropout.
  # Returns the tensor produced by the selected op.
  _4 = "dropout probability has to be between 0 and 1, but got {}"
  # _5 is the expanded short-circuit form of `p < 0. or p > 1.`.
  if torch.lt(p, 0.):
    _5 = True
  else:
    _5 = torch.gt(p, 1.)
  if _5:
    ops.prim.RaiseException(torch.format(_4, p), "builtins.ValueError")
  else:
    pass
  if inplace:
    _6 = torch.dropout_(input, p, training)
  else:
    _6 = torch.dropout(input, p, training)
  return _6
def _none_or_dtype(input: Optional[Tensor]) -> Optional[int]:
  # Returns the dtype of `input` as an integer dtype code (see the
  # Optional[int] return type), or None when `input` is None.
  if torch.__is__(input, None):
    _7 : Optional[int] = None
  else:
    input1 = unchecked_cast(Tensor, input)
    # The second cast is redundant compiler output; it does not change the value.
    input2 = unchecked_cast(Tensor, input1)
    _7 = ops.prim.dtype(input2)
  return _7
def _canonical_mask(mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[int],
    other_name: str,
    target_type: int,
    check_other: bool=True) -> Optional[Tensor]:
  # Normalizes an optional mask so that callers can add it to attention scores.
  #
  # mask: None, or a bool / floating-point tensor; any other dtype raises
  #   (message _8).
  # mask_name / other_name: used only in the error and warning messages.
  # other_type: dtype code of a companion mask, or None. When check_other is
  #   True and it differs from mask's dtype, a deprecation warning (_9) is
  #   emitted; this is not an error.
  # target_type: dtype code used for the tensor built from a bool mask.
  # Returns None if mask is None; mask itself if it is floating point;
  #   otherwise a new target_type tensor that is 0 where mask is False and
  #   -inf where mask is True.
  _8 = "only bool and floating types of {} are supported"
  _9 = "Support for mismatched {} and {} is deprecated. Use same type for both instead."
  if torch.__isnot__(mask, None):
    mask1 = unchecked_cast(Tensor, mask)
    _mask_dtype = ops.prim.dtype(mask1)
    _mask_is_float = torch.is_floating_point(mask1)
    # 11 is the bool dtype code (consistent with message _8); reject anything
    # that is neither bool nor floating point.
    if torch.ne(_mask_dtype, 11):
      _10 = torch.__not__(_mask_is_float)
    else:
      _10 = False
    if _10:
      _11 = torch.add("AssertionError: ", torch.format(_8, mask_name))
      ops.prim.RaiseException(_11)
    else:
      pass
    if check_other:
      _12 = torch.__isnot__(other_type, None)
    else:
      _12 = False
    if _12:
      other_type0 = unchecked_cast(int, other_type)
      _13 = torch.ne(_mask_dtype, other_type0)
      if _13:
        _14 = torch.format(_9, mask_name, other_name)
        torch.warn(_14)
      else:
        pass
    else:
      pass
    if torch.__not__(_mask_is_float):
      # Bool mask -> float mask: True positions become -inf, the rest 0.
      _15 = torch.zeros_like(mask1, dtype=target_type)
      mask3 = torch.masked_fill_(_15, mask1, -inf)
      mask2 = mask3
    else:
      mask2 = mask1
    mask0 : Optional[Tensor] = mask2
  else:
    mask0 = mask
  return mask0
def multi_head_attention_forward(query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool=True,
    key_padding_mask: Optional[Tensor]=None,
    need_weights: bool=True,
    attn_mask: Optional[Tensor]=None,
    use_separate_proj_weight: bool=False,
    q_proj_weight: Optional[Tensor]=None,
    k_proj_weight: Optional[Tensor]=None,
    v_proj_weight: Optional[Tensor]=None,
    static_k: Optional[Tensor]=None,
    static_v: Optional[Tensor]=None,
    average_attn_weights: bool=True,
    is_causal: bool=False) -> Tuple[Tensor, Optional[Tensor]]:
  # Serialized TorchScript form of `multi_head_attention_forward`.
  #
  # Layout: query is unpacked as (tgt_len, bsz, embed_dim) and key as
  #   (src_len, bsz, _), i.e. sequence-first. Unbatched 2-D inputs get a batch
  #   dim of size 1 inserted at dim 1, and that dim is removed again before
  #   returning.
  # Projections: packed in_proj_weight (via _in_projection_packed), or, when
  #   use_separate_proj_weight, q/k/v_proj_weight (via _in_projection) with
  #   in_proj_bias chunked into three parts.
  # Masks: key_padding_mask and attn_mask are normalized by _canonical_mask
  #   and merged into a single mask that is added to the attention scores.
  # Returns (attn_output, attn_weights). attn_output is
  #   (tgt_len, bsz, out_dim). attn_weights is None unless need_weights. When
  #   returned it is (bsz, num_heads, tgt_len, src_len), or its mean over the
  #   head dim (dim 1) when average_attn_weights. Unbatched inputs drop the
  #   batch dim.
  # Raises via prim.RaiseException on shape mismatches (messages _19.._38).
  #
  # NOTE(review): when is_causal is True, attn_mask is discarded (the
  #   `if is_causal` below). On the need_weights path is_causal is never read
  #   again, so no causal masking is visible there. On the
  #   scaled_dot_product_attention path is_causal is forwarded together with
  #   any mask built from key_padding_mask. Confirm this is intended.
  #
  # Local aliases for sibling helpers and message templates used below.
  _16 = __torch__.torch.nn.functional._mha_shape_check
  _17 = __torch__.torch.nn.functional._none_or_dtype
  _18 = __torch__.torch.nn.functional._canonical_mask
  _19 = "was expecting embedding dimension of {}, but got {}"
  _20 = "embed_dim {} not divisible by num_heads {}"
  _21 = "key\'s sequence and batch dims {} do not match value\'s {}"
  _22 = "key shape {} does not match value shape {}"
  _23 = "AssertionError: use_separate_proj_weight is False but in_proj_weight is None"
  _24 = __torch__.torch.nn.functional._in_projection_packed
  _25 = "AssertionError: use_separate_proj_weight is True but q_proj_weight is None"
  _26 = "AssertionError: use_separate_proj_weight is True but k_proj_weight is None"
  _27 = "AssertionError: use_separate_proj_weight is True but v_proj_weight is None"
  _28 = __torch__.torch.nn.functional._in_projection
  _29 = "The shape of the 2D attn_mask is {}, but should be {}."
  _30 = "The shape of the 3D attn_mask is {}, but should be {}."
  _31 = "attn_mask\'s dimension {} is not supported"
  _32 = "AssertionError: bias cannot be added to static key."
  _33 = "AssertionError: bias cannot be added to static value."
  _34 = "expecting static_k.size(0) of {}, but got {}"
  _35 = "expecting static_k.size(2) of {}, but got {}"
  _36 = "expecting static_v.size(0) of {}, but got {}"
  _37 = "expecting static_v.size(2) of {}, but got {}"
  _38 = "expecting key_padding_mask shape of {}, but got {}"
  # Compiler-emitted placeholders: assigned on branches that end in
  # RaiseException so that every branch binds the variable.
  _39 = uninitialized(Optional[Tensor])
  _40 = uninitialized(Optional[Tensor])
  _41 = uninitialized(Tensor)
  # True for a 3-D (batched) query, False for 2-D (unbatched); raises otherwise.
  is_batched = _16(query, key, value, key_padding_mask, attn_mask, num_heads, )
  if torch.__not__(is_batched):
    # Unbatched: insert a batch dim of size 1 (dim 1 for q/k/v, dim 0 for
    # key_padding_mask) so the rest of the function sees batched shapes.
    query1 = torch.unsqueeze(query, 1)
    key1 = torch.unsqueeze(key, 1)
    value1 = torch.unsqueeze(value, 1)
    _42 = torch.__isnot__(key_padding_mask, None)
    if _42:
      key_padding_mask2 = unchecked_cast(Tensor, key_padding_mask)
      key_padding_mask3 = torch.unsqueeze(key_padding_mask2, 0)
      key_padding_mask1 : Optional[Tensor] = key_padding_mask3
    else:
      key_padding_mask1 = key_padding_mask
    query0, key0, key_padding_mask0, value0 = query1, key1, key_padding_mask1, value1
  else:
    query0, key0, key_padding_mask0, value0 = query, key, key_padding_mask, value
  tgt_len, bsz, embed_dim, = torch.size(query0)
  src_len, _43, _44, = torch.size(key0)
  # key_padding_mask -> mask of query's dtype; warns if its dtype differs
  # from attn_mask's.
  key_padding_mask4 = _18(key_padding_mask0, "key_padding_mask", _17(attn_mask, ), "attn_mask", ops.prim.dtype(query0), True, )
  # With is_causal the explicit attn_mask is dropped (see NOTE(review) above).
  if is_causal:
    attn_mask0 : Optional[Tensor] = None
  else:
    attn_mask0 = attn_mask
  # embed_dim must equal embed_dim_to_check and be divisible by num_heads.
  _45 = torch.eq(embed_dim, embed_dim_to_check)
  if _45:
    pass
  else:
    _46 = torch.format(_19, embed_dim_to_check, embed_dim)
    ops.prim.RaiseException(torch.add("AssertionError: ", _46))
  embed_dim0 = unchecked_cast(int, embed_dim)
  head_dim = torch.floordiv(embed_dim0, num_heads)
  _47 = torch.eq(torch.mul(head_dim, num_heads), embed_dim0)
  if _47:
    pass
  else:
    _48 = torch.format(_20, embed_dim0, num_heads)
    ops.prim.RaiseException(torch.add("AssertionError: ", _48))
  # With separate projection weights only the first two dims of key and value
  # must match; with packed weights their full shapes must match.
  if use_separate_proj_weight:
    _49 = torch.slice(torch.size(key0), None, 2)
    _50 = torch.slice(torch.size(value0), None, 2)
    if torch.eq(_49, _50):
      pass
    else:
      _51 = torch.slice(torch.size(key0), None, 2)
      _52 = torch.slice(torch.size(value0), None, 2)
      _53 = torch.add("AssertionError: ", torch.format(_21, _51, _52))
      ops.prim.RaiseException(_53)
  else:
    _54 = torch.eq(torch.size(key0), torch.size(value0))
    if _54:
      pass
    else:
      _55 = torch.format(_22, torch.size(key0), torch.size(value0))
      _56 = torch.add("AssertionError: ", _55)
      ops.prim.RaiseException(_56)
  # Input projections: packed in_proj_weight, or separate q/k/v weights with
  # in_proj_bias (if any) chunked into three parts.
  _57 = torch.__not__(use_separate_proj_weight)
  if _57:
    _58 = torch.__isnot__(in_proj_weight, None)
    if _58:
      in_proj_weight1 = unchecked_cast(Tensor, in_proj_weight)
      in_proj_weight0 = in_proj_weight1
    else:
      ops.prim.RaiseException(_23)
      in_proj_weight0 = _41
    _59 = _24(query0, key0, value0, in_proj_weight0, in_proj_bias, )
    q0, k0, v0, = _59
    q, k, v = q0, k0, v0
  else:
    _60 = torch.__isnot__(q_proj_weight, None)
    if _60:
      q_proj_weight1 = unchecked_cast(Tensor, q_proj_weight)
      q_proj_weight0 = q_proj_weight1
    else:
      ops.prim.RaiseException(_25)
      q_proj_weight0 = _41
    _61 = torch.__isnot__(k_proj_weight, None)
    if _61:
      k_proj_weight1 = unchecked_cast(Tensor, k_proj_weight)
      k_proj_weight0 = k_proj_weight1
    else:
      ops.prim.RaiseException(_26)
      k_proj_weight0 = _41
    _62 = torch.__isnot__(v_proj_weight, None)
    if _62:
      v_proj_weight1 = unchecked_cast(Tensor, v_proj_weight)
      v_proj_weight0 = v_proj_weight1
    else:
      ops.prim.RaiseException(_27)
      v_proj_weight0 = _41
    if torch.__is__(in_proj_bias, None):
      b_q, b_k, b_v = None, None, None
    else:
      in_proj_bias0 = unchecked_cast(Tensor, in_proj_bias)
      b_q0, b_k0, b_v0, = torch.chunk(in_proj_bias0, 3)
      b_q, b_k, b_v = b_q0, b_k0, b_v0
    _63 = _28(query0, key0, value0, q_proj_weight0, k_proj_weight0, v_proj_weight0, b_q, b_k, b_v, )
    q1, k1, v1, = _63
    q, k, v = q1, k1, v1
  # attn_mask -> mask of q's dtype (no cross-dtype warning: check_other=False).
  attn_mask1 = _18(attn_mask0, "attn_mask", _17(key_padding_mask4, ), "key_padding_mask", ops.prim.dtype(q), False, )
  if torch.__isnot__(attn_mask1, None):
    attn_mask3 = unchecked_cast(Tensor, attn_mask1)
    if torch.eq(torch.dim(attn_mask3), 2):
      # A 2-D mask must be (tgt_len, src_len); a leading dim of 1 is added.
      correct_2d_size = (tgt_len, src_len)
      _64 = torch.ne(torch.size(attn_mask3), [tgt_len, src_len])
      if _64:
        _65 = torch.format(_29, torch.size(attn_mask3), correct_2d_size)
        ops.prim.RaiseException(_65, "builtins.RuntimeError")
      else:
        pass
      attn_mask4 = torch.unsqueeze(attn_mask3, 0)
    else:
      _66 = torch.eq(torch.dim(attn_mask3), 3)
      if _66:
        # A 3-D mask must be (bsz * num_heads, tgt_len, src_len).
        _67 = torch.mul(bsz, num_heads)
        correct_3d_size = (_67, tgt_len, src_len)
        _68 = torch.ne(torch.size(attn_mask3), [_67, tgt_len, src_len])
        if _68:
          _69 = torch.format(_30, torch.size(attn_mask3), correct_3d_size)
          ops.prim.RaiseException(_69, "builtins.RuntimeError")
        else:
          pass
      else:
        _70 = torch.format(_31, torch.dim(attn_mask3))
        ops.prim.RaiseException(_70, "builtins.RuntimeError")
      attn_mask4 = attn_mask3
    attn_mask2 : Optional[Tensor] = attn_mask4
  else:
    attn_mask2 = attn_mask1
  # bias_k / bias_v must be both set or both None. When set, static_k/static_v
  # must be None. bias_k/bias_v (repeated over the batch) are then concatenated
  # onto k and v along dim 0, and each mask gets one trailing column of padding.
  if torch.__isnot__(bias_k, None):
    bias_k1 = unchecked_cast(Tensor, bias_k)
    _71, bias_k0 = torch.__isnot__(bias_v, None), bias_k1
  else:
    _71, bias_k0 = False, bias_k
  if _71:
    bias_k2 = unchecked_cast(Tensor, bias_k0)
    bias_v0 = unchecked_cast(Tensor, bias_v)
    if torch.__is__(static_k, None):
      static_k1 : Optional[Tensor] = static_k
    else:
      ops.prim.RaiseException(_32)
      static_k1 = _40
    if torch.__is__(static_v, None):
      static_v1 : Optional[Tensor] = static_v
    else:
      ops.prim.RaiseException(_33)
      static_v1 = _39
    _72 = [k, torch.repeat(bias_k2, [1, bsz, 1])]
    k3 = torch.cat(_72)
    _73 = [v, torch.repeat(bias_v0, [1, bsz, 1])]
    v3 = torch.cat(_73)
    if torch.__isnot__(attn_mask2, None):
      attn_mask7 = unchecked_cast(Tensor, attn_mask2)
      attn_mask6 : Optional[Tensor] = torch.pad(attn_mask7, [0, 1])
    else:
      attn_mask6 = attn_mask2
    _74 = torch.__isnot__(key_padding_mask4, None)
    if _74:
      key_padding_mask7 = unchecked_cast(Tensor, key_padding_mask4)
      key_padding_mask8 = torch.pad(key_padding_mask7, [0, 1])
      key_padding_mask6 : Optional[Tensor] = key_padding_mask8
    else:
      key_padding_mask6 = key_padding_mask4
    static_k0, k2, static_v0, v2, attn_mask5, key_padding_mask5 = static_k1, k3, static_v1, v3, attn_mask6, key_padding_mask6
  else:
    # Only one of bias_k / bias_v set: raise.
    if torch.__is__(bias_k0, None):
      pass
    else:
      ops.prim.RaiseException("AssertionError: ")
    if torch.__is__(bias_v, None):
      pass
    else:
      ops.prim.RaiseException("AssertionError: ")
    static_k0, k2, static_v0, v2, attn_mask5, key_padding_mask5 = static_k, k, static_v, v, attn_mask2, key_padding_mask4
  # Reshape q to (bsz * num_heads, tgt_len, head_dim).
  _75 = [tgt_len, torch.mul(bsz, num_heads), head_dim]
  q2 = torch.transpose(torch.view(q, _75), 0, 1)
  # k / v: reshape to (bsz * num_heads, seq, head_dim). Alternatively, use
  # static_k / static_v as given after checking their dims 0 and 2.
  if torch.__is__(static_k0, None):
    _76 = [(torch.size(k2))[0], torch.mul(bsz, num_heads), head_dim]
    k5 = torch.transpose(torch.view(k2, _76), 0, 1)
    k4 = k5
  else:
    static_k2 = unchecked_cast(Tensor, static_k0)
    _77 = torch.eq(torch.size(static_k2, 0), torch.mul(bsz, num_heads))
    if _77:
      pass
    else:
      _78 = torch.format(_34, torch.mul(bsz, num_heads), torch.size(static_k2, 0))
      _79 = torch.add("AssertionError: ", _78)
      ops.prim.RaiseException(_79)
    _80 = torch.eq(torch.size(static_k2, 2), head_dim)
    if _80:
      pass
    else:
      _81 = torch.format(_35, head_dim, torch.size(static_k2, 2))
      _82 = torch.add("AssertionError: ", _81)
      ops.prim.RaiseException(_82)
    k4 = static_k2
  if torch.__is__(static_v0, None):
    _83 = [(torch.size(v2))[0], torch.mul(bsz, num_heads), head_dim]
    v5 = torch.transpose(torch.view(v2, _83), 0, 1)
    v4 = v5
  else:
    static_v2 = unchecked_cast(Tensor, static_v0)
    _84 = torch.eq(torch.size(static_v2, 0), torch.mul(bsz, num_heads))
    if _84:
      pass
    else:
      _85 = torch.format(_36, torch.mul(bsz, num_heads), torch.size(static_v2, 0))
      _86 = torch.add("AssertionError: ", _85)
      ops.prim.RaiseException(_86)
    _87 = torch.eq(torch.size(static_v2, 2), head_dim)
    if _87:
      pass
    else:
      _88 = torch.format(_37, head_dim, torch.size(static_v2, 2))
      _89 = torch.add("AssertionError: ", _88)
      ops.prim.RaiseException(_89)
    v4 = static_v2
  # add_zero_attn: append one all-zero position along dim 1 of k and v (same
  # dtype/device), and pad each mask by one trailing column to match.
  if add_zero_attn:
    _90 = torch.mul(bsz, num_heads)
    _91 = ops.prim.dtype(k4)
    _92 = ops.prim.device(k4)
    _93 = torch.zeros([_90, 1, head_dim], dtype=_91, layout=None, device=_92)
    k7 = torch.cat([k4, _93], 1)
    _94 = ops.prim.dtype(v4)
    _95 = ops.prim.device(v4)
    _96 = torch.zeros([_90, 1, head_dim], dtype=_94, layout=None, device=_95)
    v7 = torch.cat([v4, _96], 1)
    if torch.__isnot__(attn_mask5, None):
      attn_mask10 = unchecked_cast(Tensor, attn_mask5)
      attn_mask9 : Optional[Tensor] = torch.pad(attn_mask10, [0, 1])
    else:
      attn_mask9 = attn_mask5
    _97 = torch.__isnot__(key_padding_mask5, None)
    if _97:
      key_padding_mask11 = unchecked_cast(Tensor, key_padding_mask5)
      key_padding_mask12 = torch.pad(key_padding_mask11, [0, 1])
      key_padding_mask10 : Optional[Tensor] = key_padding_mask12
    else:
      key_padding_mask10 = key_padding_mask5
    k6, key_padding_mask9, attn_mask8, v6 = k7, key_padding_mask10, attn_mask9, v7
  else:
    k6, key_padding_mask9, attn_mask8, v6 = k4, key_padding_mask5, attn_mask5, v4
  # Effective source length after the bias_k / add_zero_attn adjustments.
  src_len0 = torch.size(k6, 1)
  # Merge key_padding_mask, which must be (bsz, src_len0). It is expanded over
  # heads to (bsz * num_heads, 1, src_len0), then added to attn_mask, or used
  # alone when there is no attn_mask.
  _98 = torch.__isnot__(key_padding_mask9, None)
  if _98:
    key_padding_mask13 = unchecked_cast(Tensor, key_padding_mask9)
    _99 = torch.eq(torch.size(key_padding_mask13), [bsz, src_len0])
    if _99:
      pass
    else:
      _100 = torch.format(_38, (bsz, src_len0), torch.size(key_padding_mask13))
      _101 = torch.add("AssertionError: ", _100)
      ops.prim.RaiseException(_101)
    _102 = torch.view(key_padding_mask13, [bsz, 1, 1, src_len0])
    _103 = torch.expand(_102, [-1, num_heads, -1, -1])
    _104 = [torch.mul(bsz, num_heads), 1, src_len0]
    key_padding_mask14 = torch.reshape(_103, _104)
    if torch.__is__(attn_mask8, None):
      attn_mask12 = key_padding_mask14
    else:
      attn_mask13 = unchecked_cast(Tensor, attn_mask8)
      attn_mask14 = torch.add(attn_mask13, key_padding_mask14)
      attn_mask12 = attn_mask14
    attn_mask11 : Optional[Tensor] = attn_mask12
  else:
    attn_mask11 = attn_mask8
  # Dropout is disabled outside training.
  if torch.__not__(training):
    dropout_p0 = 0.
  else:
    dropout_p0 = dropout_p
  if need_weights:
    # Explicit path so the attention weights can be returned:
    # softmax(mask + (q / sqrt(head_dim)) @ k^T) with optional dropout, then @ v.
    # E below is q2's last dim, i.e. head_dim.
    B, Nt, E, = torch.size(q2)
    q_scaled = torch.div(q2, torch.sqrt(E))
    if torch.__isnot__(attn_mask11, None):
      attn_mask15 = unchecked_cast(Tensor, attn_mask11)
      attn_output_weights0 = torch.baddbmm(attn_mask15, q_scaled, torch.transpose(k6, -2, -1))
      attn_output_weights = attn_output_weights0
    else:
      attn_output_weights1 = torch.bmm(q_scaled, torch.transpose(k6, -2, -1))
      attn_output_weights = attn_output_weights1
    attn_output_weights2 = __torch__.torch.nn.functional.softmax(attn_output_weights, -1, 3, None, )
    if torch.gt(dropout_p0, 0.):
      attn_output_weights4 = __torch__.torch.nn.functional.dropout(attn_output_weights2, dropout_p0, True, False, )
      attn_output_weights3 = attn_output_weights4
    else:
      attn_output_weights3 = attn_output_weights2
    attn_output = torch.bmm(attn_output_weights3, v6)
    # Back to (tgt_len * bsz, embed_dim), then out projection, then
    # (tgt_len, bsz, out_dim).
    _106 = torch.contiguous(torch.transpose(attn_output, 0, 1))
    _107 = [torch.mul(tgt_len, bsz), embed_dim0]
    attn_output0 = torch.view(_106, _107)
    attn_output1 = torch.linear(attn_output0, out_proj_weight, out_proj_bias)
    _108 = [tgt_len, bsz, torch.size(attn_output1, 1)]
    attn_output2 = torch.view(attn_output1, _108)
    # Per-head weights as (bsz, num_heads, tgt_len, src_len0); optionally
    # averaged over the head dim.
    attn_output_weights5 = torch.view(attn_output_weights3, [bsz, num_heads, tgt_len, src_len0])
    if average_attn_weights:
      attn_output_weights7 = torch.mean(attn_output_weights5, [1])
      attn_output_weights6 = attn_output_weights7
    else:
      attn_output_weights6 = attn_output_weights5
    if torch.__not__(is_batched):
      attn_output4 = torch.squeeze(attn_output2, 1)
      attn_output_weights9 = torch.squeeze(attn_output_weights6, 0)
      attn_output3, attn_output_weights8 = attn_output4, attn_output_weights9
    else:
      attn_output3, attn_output_weights8 = attn_output2, attn_output_weights6
    _109 = (attn_output3, attn_output_weights8)
    _105 = _109
  else:
    # scaled_dot_product_attention path: weights are not returned (None).
    # The mask is made 4-D. A 3-D mask with size(0) == 1 gets a leading dim;
    # otherwise it is viewed as (bsz, num_heads, -1, src_len0).
    if torch.__isnot__(attn_mask11, None):
      attn_mask17 = unchecked_cast(Tensor, attn_mask11)
      _110 = torch.eq(torch.size(attn_mask17, 0), 1)
      if _110:
        _112 = torch.eq(torch.dim(attn_mask17), 3)
        _111 = _112
      else:
        _111 = False
      if _111:
        attn_mask18 = torch.unsqueeze(attn_mask17, 0)
      else:
        attn_mask19 = torch.view(attn_mask17, [bsz, num_heads, -1, src_len0])
        attn_mask18 = attn_mask19
      attn_mask16 : Optional[Tensor] = attn_mask18
    else:
      attn_mask16 = attn_mask11
    # q / k / v as (bsz, num_heads, seq, head_dim).
    q3 = torch.view(q2, [bsz, num_heads, tgt_len, head_dim])
    _113 = [bsz, num_heads, src_len0, head_dim]
    k8 = torch.view(k6, _113)
    _114 = [bsz, num_heads, src_len0, head_dim]
    v8 = torch.view(v6, _114)
    attn_output5 = torch.scaled_dot_product_attention(q3, k8, v8, attn_mask16, dropout_p0, is_causal)
    # (bsz, heads, tgt, head_dim) -> (tgt, bsz, heads, head_dim), flattened
    # to (bsz * tgt_len, embed_dim) for the out projection.
    _115 = torch.permute(attn_output5, [2, 0, 1, 3])
    _116 = torch.contiguous(_115)
    _117 = [torch.mul(bsz, tgt_len), embed_dim0]
    attn_output6 = torch.view(_116, _117)
    attn_output7 = torch.linear(attn_output6, out_proj_weight, out_proj_bias)
    _118 = [tgt_len, bsz, torch.size(attn_output7, 1)]
    attn_output8 = torch.view(attn_output7, _118)
    if torch.__not__(is_batched):
      attn_output9 = torch.squeeze(attn_output8, 1)
    else:
      attn_output9 = attn_output8
    _105 = (attn_output9, None)
  return _105
def layer_norm(input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor]=None,
    bias: Optional[Tensor]=None,
    eps: float=1.0000000000000001e-05) -> Tensor:
  # Thin wrapper: forwards every argument unchanged to torch.layer_norm.
  # The eps default 1.0000000000000001e-05 is the serialized repr of 1e-5.
  _119 = torch.layer_norm(input, normalized_shape, weight, bias, eps)
  return _119
def relu(input: Tensor,
    inplace: bool=False) -> Tensor:
  # Applies torch.relu_ (in-place variant) when inplace, else torch.relu, and
  # returns the result.
  if inplace:
    result = torch.relu_(input)
  else:
    result = torch.relu(input)
  return result
def leaky_relu(input: Tensor,
    negative_slope: float=0.01,
    inplace: bool=False) -> Tensor:
  # Applies torch.leaky_relu_ (in-place variant) when inplace, else
  # torch.leaky_relu, with the given negative_slope, and returns the result.
  if inplace:
    result0 = torch.leaky_relu_(input, negative_slope)
    result = result0
  else:
    result1 = torch.leaky_relu(input, negative_slope)
    result = result1
  return result
def embedding_bag(input: Tensor,
    weight: Tensor,
    offsets: Optional[Tensor]=None,
    max_norm: Optional[float]=None,
    norm_type: float=2.,
    scale_grad_by_freq: bool=False,
    mode: str="mean",
    sparse: bool=False,
    per_sample_weights: Optional[Tensor]=None,
    include_last_offset: bool=False,
    padding_idx: Optional[int]=None) -> Tensor:
  # Serialized TorchScript form of `embedding_bag`. Validates its arguments
  # and returns only the first output of torch.embedding_bag.
  #
  # input: either 1-D indices with a 1-D `offsets`, or 2-D with offsets=None.
  #   In the 2-D case, offsets are generated as
  #   arange(0, numel, input.size(1)) (one bag per row) and input is flattened.
  # mode: "sum" -> 0, "mean" -> 1, "max" -> 2. "max" rejects
  #   scale_grad_by_freq and sparse; any other string raises ValueError.
  # per_sample_weights: must match input's shape (ValueError otherwise), is
  #   flattened alongside a 2-D input, and is only allowed with mode "sum"
  #   (NotImplementedError otherwise).
  # max_norm / norm_type: not read anywhere in this body. NOTE(review):
  #   confirm against the eager implementation if the model relies on
  #   max_norm.
  _120 = "Argument order of nn.functional.embedding_bag was changed. Usage `embedding_bag(weight, input, ...)` is deprecated, and should now be `embedding_bag(input, weight, ...)`."
  _121 = "embedding_bag: If per_sample_weights ({}) is not None, then it must have the same shape as the input ({})"
  _122 = "if input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences. However, found offsets of type {}"
  _123 = "offsets has to be a 1D Tensor but got None"
  _124 = "input has to be 1D or 2D Tensor, but got Tensor of dimension {}"
  _125 = "max mode does not support scaling the gradient by the frequency"
  _126 = "max mode does not support sparse weights"
  _127 = "mode has to be one of sum, mean or max"
  _128 = "embedding_bag: per_sample_weights was not None. per_sample_weights is only supported for mode=\'sum\' (got mode=\'{}\'). Please open a feature request on GitHub."
  # Compiler-emitted placeholders bound on branches that end in RaiseException.
  _129 = uninitialized(int)
  _130 = uninitialized(Tensor)
  _131 = uninitialized(Optional[Tensor])
  # Legacy argument order: if weight has dtype code 4 (presumably torch.long)
  # and input is floating point, swap them and warn (message _120).
  if torch.eq(ops.prim.dtype(weight), 4):
    _132 = torch.is_floating_point(input)
  else:
    _132 = False
  if _132:
    torch.warn(_120)
    input3, weight0 = weight, input
  else:
    input3, weight0 = input, weight
  _133 = torch.__isnot__(per_sample_weights, None)
  if _133:
    per_sample_weights1 = unchecked_cast(Tensor, per_sample_weights)
    _135 = torch.ne(torch.size(input3), torch.size(per_sample_weights1))
    _134, per_sample_weights0 = _135, per_sample_weights1
  else:
    _134, per_sample_weights0 = False, per_sample_weights
  if _134:
    per_sample_weights3 = unchecked_cast(Tensor, per_sample_weights0)
    _136 = torch.format(_121, torch.size(per_sample_weights3), torch.size(input3))
    ops.prim.RaiseException(_136, "builtins.ValueError")
    per_sample_weights2 : Optional[Tensor] = _131
  else:
    per_sample_weights2 = per_sample_weights0
  if torch.eq(torch.dim(input3), 2):
    # 2-D input: offsets must be None; build one offset per row and flatten.
    if torch.__isnot__(offsets, None):
      ops.prim.RaiseException(torch.format(_122, "<unknown>"), "builtins.ValueError")
    else:
      pass
    offsets1 = torch.arange(0, torch.numel(input3), torch.size(input3, 1), dtype=ops.prim.dtype(input3), layout=None, device=ops.prim.device(input3))
    input5 = torch.reshape(input3, [-1])
    _137 = torch.__isnot__(per_sample_weights2, None)
    if _137:
      per_sample_weights6 = unchecked_cast(Tensor, per_sample_weights2)
      per_sample_weights7 = torch.reshape(per_sample_weights6, [-1])
      per_sample_weights5 : Optional[Tensor] = per_sample_weights7
    else:
      per_sample_weights5 = per_sample_weights2
    per_sample_weights4, input4, offsets0 = per_sample_weights5, input5, offsets1
  else:
    if torch.eq(torch.dim(input3), 1):
      # 1-D input: offsets is required and must itself be 1-D.
      if torch.__is__(offsets, None):
        ops.prim.RaiseException(_123, "builtins.ValueError")
        offsets3 = _130
      else:
        offsets3 = unchecked_cast(Tensor, offsets)
      if torch.ne(torch.dim(offsets3), 1):
        ops.prim.RaiseException("offsets has to be a 1D Tensor", "builtins.ValueError")
      else:
        pass
      offsets2 = offsets3
    else:
      _138 = torch.format(_124, torch.dim(input3))
      ops.prim.RaiseException(_138, "builtins.ValueError")
      offsets2 = _130
    per_sample_weights4, input4, offsets0 = per_sample_weights2, input3, offsets2
  # mode string -> integer code passed to torch.embedding_bag.
  if torch.eq(mode, "sum"):
    mode_enum = 0
  else:
    if torch.eq(mode, "mean"):
      mode_enum0 = 1
    else:
      if torch.eq(mode, "max"):
        if scale_grad_by_freq:
          ops.prim.RaiseException(_125, "builtins.ValueError")
        else:
          pass
        if sparse:
          ops.prim.RaiseException(_126, "builtins.ValueError")
        else:
          pass
        mode_enum1 = 2
      else:
        ops.prim.RaiseException(_127, "builtins.ValueError")
        mode_enum1 = _129
      mode_enum0 = mode_enum1
    mode_enum = mode_enum0
  # per_sample_weights is only accepted together with mode "sum".
  _139 = torch.__isnot__(per_sample_weights4, None)
  if _139:
    per_sample_weights9 = unchecked_cast(Tensor, per_sample_weights4)
    _140, per_sample_weights8 = torch.ne(mode, "sum"), per_sample_weights9
  else:
    _140, per_sample_weights8 = False, per_sample_weights4
  if _140:
    ops.prim.RaiseException(torch.format(_128, mode), "builtins.NotImplementedError")
    per_sample_weights10 : Optional[Tensor] = _131
  else:
    per_sample_weights10 = per_sample_weights8
  # Only the first of torch.embedding_bag's four outputs is returned.
  ret, _141, _142, _143 = torch.embedding_bag(weight0, input4, offsets0, scale_grad_by_freq, mode_enum, sparse, per_sample_weights10, include_last_offset, padding_idx)
  return ret
def _mha_shape_check(query: Tensor,
    key: Tensor,
    value: Tensor,
    key_padding_mask: Optional[Tensor],
    attn_mask: Optional[Tensor],
    num_heads: int) -> bool:
  # Validates the ranks of multi-head-attention inputs and reports whether
  # they are batched.
  #
  # Batched (3-D query): key and value must be 3-D; key_padding_mask None or
  #   2-D; attn_mask None, 2-D or 3-D. Returns True.
  # Unbatched (2-D query): key and value must be 2-D; key_padding_mask None
  #   or 1-D; attn_mask None, 2-D or 3-D. A 3-D attn_mask must have shape
  #   (num_heads, query.size(0), key.size(0)). Returns False.
  # Any other query rank raises (message _151). All failures are raised via
  #   prim.RaiseException with an "AssertionError: " prefix.
  _144 = "For batched (3-D) `query`, expected `key` and `value` to be 3-D but found {}-D and {}-D tensors respectively"
  _145 = "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D but found {}-D tensor instead"
  _146 = "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D but found {}-D tensor instead"
  _147 = "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D but found {}-D and {}-D tensors respectively"
  _148 = "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D but found {}-D tensor instead"
  _149 = "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D but found {}-D tensor instead"
  _150 = "Expected `attn_mask` shape to be {} but got {}"
  _151 = "query should be unbatched 2D or batched 3D tensor but received {}-D query tensor"
  # Compiler-emitted placeholder for the branch that ends in RaiseException.
  _152 = uninitialized(bool)
  if torch.eq(torch.dim(query), 3):
    # Batched case.
    if torch.eq(torch.dim(key), 3):
      _153 = torch.eq(torch.dim(value), 3)
    else:
      _153 = False
    if _153:
      pass
    else:
      _154 = torch.format(_144, torch.dim(key), torch.dim(value))
      _155 = torch.add("AssertionError: ", _154)
      ops.prim.RaiseException(_155)
    _156 = torch.__isnot__(key_padding_mask, None)
    if _156:
      key_padding_mask15 = unchecked_cast(Tensor, key_padding_mask)
      _157 = torch.eq(torch.dim(key_padding_mask15), 2)
      if _157:
        pass
      else:
        _158 = torch.format(_145, torch.dim(key_padding_mask15))
        _159 = torch.add("AssertionError: ", _158)
        ops.prim.RaiseException(_159)
    else:
      pass
    if torch.__isnot__(attn_mask, None):
      attn_mask20 = unchecked_cast(Tensor, attn_mask)
      _160 = torch.dim(attn_mask20)
      if torch.__contains__([2, 3], _160):
        pass
      else:
        _161 = torch.format(_146, torch.dim(attn_mask20))
        _162 = torch.add("AssertionError: ", _161)
        ops.prim.RaiseException(_162)
    else:
      pass
    is_batched = True
  else:
    if torch.eq(torch.dim(query), 2):
      # Unbatched case.
      if torch.eq(torch.dim(key), 2):
        _163 = torch.eq(torch.dim(value), 2)
      else:
        _163 = False
      if _163:
        pass
      else:
        _164 = torch.format(_147, torch.dim(key), torch.dim(value))
        _165 = torch.add("AssertionError: ", _164)
        ops.prim.RaiseException(_165)
      _166 = torch.__isnot__(key_padding_mask, None)
      if _166:
        key_padding_mask16 = unchecked_cast(Tensor, key_padding_mask)
        _167 = torch.eq(torch.dim(key_padding_mask16), 1)
        if _167:
          pass
        else:
          _168 = torch.format(_148, torch.dim(key_padding_mask16))
          _169 = torch.add("AssertionError: ", _168)
          ops.prim.RaiseException(_169)
      else:
        pass
      if torch.__isnot__(attn_mask, None):
        attn_mask21 = unchecked_cast(Tensor, attn_mask)
        _170 = torch.dim(attn_mask21)
        _171 = torch.__contains__([2, 3], _170)
        if _171:
          pass
        else:
          _172 = torch.format(_149, torch.dim(attn_mask21))
          _173 = torch.add("AssertionError: ", _172)
          ops.prim.RaiseException(_173)
        _174 = torch.eq(torch.dim(attn_mask21), 3)
        if _174:
          # Unbatched 3-D mask: must be (num_heads, query.size(0), key.size(0)).
          _175 = (torch.size(query))[0]
          _176 = (torch.size(key))[0]
          expected_shape = (num_heads, _175, _176)
          _177 = torch.eq(torch.size(attn_mask21), [num_heads, _175, _176])
          if _177:
            pass
          else:
            _178 = torch.format(_150, expected_shape, torch.size(attn_mask21))
            _179 = torch.add("AssertionError: ", _178)
            ops.prim.RaiseException(_179)
        else:
          pass
      else:
        pass
      is_batched0 = False
    else:
      _180 = torch.format(_151, torch.dim(query))
      _181 = torch.add("AssertionError: ", _180)
      ops.prim.RaiseException(_181)
      is_batched0 = _152
    is_batched = is_batched0
  return is_batched
def _in_projection_packed(q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor]=None) -> List[Tensor]:
  # Projects q, k, v with one packed weight `w` and optional packed bias `b`.
  # The code splits w along dim 0 into parts of E rows each, where
  # E = q.size(-1). Returns [q_proj, k_proj, v_proj].
  #
  # The case is chosen by object identity (`is`), not by value:
  #   q is k is v : one linear over q; the output is split into three.
  #   k is v only : q uses w[:E]; k uses the remaining rows in one linear,
  #                 whose output is split into the k and v projections.
  #   otherwise   : w (and b) are chunked into three; one linear per input.
  E = torch.size(q, -1)
  if torch.__is__(k, v):
    if torch.__is__(q, k):
      proj = torch.linear(q, w, b)
      # (..., 3*E) -> (..., 3, E) -> (3, ..., E) so the three projections can
      # be selected along dim 0.
      _184 = torch.unsqueeze(torch.unflatten(proj, -1, [3, E]), 0)
      _185 = torch.squeeze(torch.transpose(_184, 0, -2), -2)
      proj0 = torch.contiguous(_185)
      _186 = [torch.select(proj0, 0, 0), torch.select(proj0, 0, 1), torch.select(proj0, 0, 2)]
      _183 = _186
    else:
      _187 = torch.split(w, [E, torch.mul(E, 2)])
      w_q, w_kv, = _187
      if torch.__is__(b, None):
        b_q, b_kv = None, None
      else:
        b0 = unchecked_cast(Tensor, b)
        _188 = torch.split(b0, [E, torch.mul(E, 2)])
        b_q1, b_kv0, = _188
        b_q, b_kv = b_q1, b_kv0
      q_proj = torch.linear(q, w_q, b_q)
      kv_proj = torch.linear(k, w_kv, b_kv)
      # (..., 2*E) -> (2, ..., E); index 0 is k, index 1 is v.
      _189 = torch.unflatten(kv_proj, -1, [2, E])
      _190 = torch.transpose(torch.unsqueeze(_189, 0), 0, -2)
      kv_proj0 = torch.contiguous(torch.squeeze(_190, -2))
      _191 = [q_proj, torch.select(kv_proj0, 0, 0), torch.select(kv_proj0, 0, 1)]
      _183 = _191
    _182 = _183
  else:
    w_q0, w_k, w_v, = torch.chunk(w, 3)
    if torch.__is__(b, None):
      b_q2, b_k, b_v = None, None, None
    else:
      b1 = unchecked_cast(Tensor, b)
      b_q3, b_k1, b_v1, = torch.chunk(b1, 3)
      b_q2, b_k, b_v = b_q3, b_k1, b_v1
    _192 = [torch.linear(q, w_q0, b_q2), torch.linear(k, w_k, b_k), torch.linear(v, w_v, b_v)]
    _182 = _192
  return _182
def _in_projection(q: Tensor,
    k: Tensor,
    v: Tensor,
    w_q: Tensor,
    w_k: Tensor,
    w_v: Tensor,
    b_q: Optional[Tensor]=None,
    b_k: Optional[Tensor]=None,
    b_v: Optional[Tensor]=None) -> Tuple[Tensor, Tensor, Tensor]:
  # Projects q, k, v with separate weights and optional biases, after
  # checking the shapes.
  #
  # With Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1), the required shapes
  # are w_q (Eq, Eq), w_k (Eq, Ek), w_v (Eq, Ev), and each non-None bias (Eq,).
  # Mismatches raise via prim.RaiseException with an "AssertionError: " prefix.
  # Returns (linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)).
  _193 = "expecting query weights shape of {}, but got {}"
  _194 = "expecting key weights shape of {}, but got {}"
  _195 = "expecting value weights shape of {}, but got {}"
  _196 = "expecting query bias shape of {}, but got {}"
  _197 = "expecting key bias shape of {}, but got {}"
  _198 = "expecting value bias shape of {}, but got {}"
  # Compiler-emitted placeholders bound on branches that end in RaiseException.
  _199 = uninitialized(Optional[Tensor])
  _200 = uninitialized(Optional[Tensor])
  _201 = uninitialized(Optional[Tensor])
  Eq = torch.size(q, -1)
  Ek = torch.size(k, -1)
  Ev = torch.size(v, -1)
  if torch.eq(torch.size(w_q), [Eq, Eq]):
    pass
  else:
    _202 = torch.format(_193, (Eq, Eq), torch.size(w_q))
    ops.prim.RaiseException(torch.add("AssertionError: ", _202))
  if torch.eq(torch.size(w_k), [Eq, Ek]):
    pass
  else:
    _203 = torch.format(_194, (Eq, Ek), torch.size(w_k))
    ops.prim.RaiseException(torch.add("AssertionError: ", _203))
  if torch.eq(torch.size(w_v), [Eq, Ev]):
    pass
  else:
    _204 = torch.format(_195, (Eq, Ev), torch.size(w_v))
    ops.prim.RaiseException(torch.add("AssertionError: ", _204))
  # Each bias check below is the expanded form of
  # `b is None or b.shape == (Eq,)`.
  if torch.__is__(b_q, None):
    _205, b_q4 = True, b_q
  else:
    b_q5 = unchecked_cast(Tensor, b_q)
    _205, b_q4 = torch.eq(torch.size(b_q5), [Eq]), b_q5
  if _205:
    b_q6 : Optional[Tensor] = b_q4
  else:
    b_q7 = unchecked_cast(Tensor, b_q4)
    _206 = torch.format(_196, (Eq,), torch.size(b_q7))
    ops.prim.RaiseException(torch.add("AssertionError: ", _206))
    b_q6 = _201
  if torch.__is__(b_k, None):
    _207, b_k2 = True, b_k
  else:
    b_k3 = unchecked_cast(Tensor, b_k)
    _207, b_k2 = torch.eq(torch.size(b_k3), [Eq]), b_k3
  if _207:
    b_k4 : Optional[Tensor] = b_k2
  else:
    b_k5 = unchecked_cast(Tensor, b_k2)
    _208 = torch.format(_197, (Eq,), torch.size(b_k5))
    ops.prim.RaiseException(torch.add("AssertionError: ", _208))
    b_k4 = _200
  if torch.__is__(b_v, None):
    _209, b_v2 = True, b_v
  else:
    b_v3 = unchecked_cast(Tensor, b_v)
    _209, b_v2 = torch.eq(torch.size(b_v3), [Eq]), b_v3
  if _209:
    b_v4 : Optional[Tensor] = b_v2
  else:
    b_v5 = unchecked_cast(Tensor, b_v2)
    _210 = torch.format(_198, (Eq,), torch.size(b_v5))
    ops.prim.RaiseException(torch.add("AssertionError: ", _210))
    b_v4 = _199
  _211 = (torch.linear(q, w_q, b_q6), torch.linear(k, w_k, b_k4), torch.linear(v, w_v, b_v4))
  return _211
def softmax(input: Tensor,
    dim: Optional[int]=None,
    _stacklevel: int=3,
    dtype: Optional[int]=None) -> Tensor:
  # Serialized TorchScript form of `softmax`.
  #
  # dim: when None, it is inferred by _get_softmax_dim (which also warns);
  #   _stacklevel is only forwarded to that call.
  # dtype: an integer dtype code. When given it is passed to torch.softmax;
  #   otherwise torch.softmax is called without a dtype.
  _212 = __torch__.torch.nn.functional._get_softmax_dim
  if torch.__is__(dim, None):
    dim1 = _212("softmax", torch.dim(input), _stacklevel, )
    dim0 = dim1
  else:
    dim0 = unchecked_cast(int, dim)
  if torch.__is__(dtype, None):
    ret = torch.softmax(input, dim0)
  else:
    dtype0 = unchecked_cast(int, dtype)
    ret = torch.softmax(input, dim0, dtype0)
  return ret
def _get_softmax_dim(name: str,
    ndim: int,
    stacklevel: int) -> int:
  # Legacy implicit-dimension rule for softmax-like ops.
  # Always warns that the implicit dim choice is deprecated (naming `name`),
  # then returns 0 when ndim is 0, 1 or 3, and 1 for any other ndim.
  _213 = "Implicit dimension choice for {} has been deprecated. Change the call to include dim=X as an argument."
  torch.warn(torch.format(_213, name), stacklevel)
  # _214 / _215 are the expanded form of `ndim == 0 or ndim == 1 or ndim == 3`.
  if torch.eq(ndim, 0):
    _214 = True
  else:
    _214 = torch.eq(ndim, 1)
  if _214:
    _215 = True
  else:
    _215 = torch.eq(ndim, 3)
  if _215:
    ret = 0
  else:
    ret = 1
  return ret
