def embedding(input: Tensor,
    weight: Tensor,
    padding_idx: Optional[int]=None,
    max_norm: Optional[float]=None,
    norm_type: float=2.,
    scale_grad_by_freq: bool=False,
    sparse: bool=False) -> Tensor:
  # Look up rows of `weight` for the indices in `input` via torch.embedding.
  #
  # padding_idx handling:
  #   * None      -> -1 is forwarded to torch.embedding.
  #   * positive  -> must be smaller than weight's first dimension.
  #   * negative  -> must be >= -(weight's first dimension); it is then
  #                  shifted by that dimension before being forwarded.
  #   * zero      -> forwarded unchanged, without a bounds check.
  # max_norm: when not None, `input` is only made contiguous here; this
  #   function contains no renormalisation call and never reads `norm_type`.
  # Raises (via prim.RaiseException) when padding_idx is out of range.
  out_of_range_msg = "AssertionError: Padding_idx must be within num_embeddings"
  if torch.__is__(padding_idx, None):
    resolved_idx = -1
  else:
    idx = unchecked_cast(int, padding_idx)
    if torch.gt(idx, 0):
      if torch.ge(idx, torch.size(weight, 0)):
        ops.prim.RaiseException(out_of_range_msg)
      resolved_idx = idx
    elif torch.lt(idx, 0):
      if torch.lt(idx, torch.neg(torch.size(weight, 0))):
        ops.prim.RaiseException(out_of_range_msg)
      # Shift the negative index by the first dimension of `weight`.
      resolved_idx = torch.add(torch.size(weight, 0), idx)
    else:
      resolved_idx = idx
  indices = torch.contiguous(input) if torch.__isnot__(max_norm, None) else input
  return torch.embedding(weight, indices, resolved_idx, scale_grad_by_freq, sparse)
def dropout(input: Tensor,
    p: float=0.5,
    training: bool=True,
    inplace: bool=False) -> Tensor:
  # Apply dropout to `input` with probability `p`.
  #
  # Dispatches to torch.dropout_ when `inplace` is set and to torch.dropout
  # otherwise; `training` is forwarded unchanged.
  # Raises (via prim.RaiseException) when p < 0 or p > 1.
  range_msg = "dropout probability has to be between 0 and 1, but got {}"
  if torch.lt(p, 0.) or torch.gt(p, 1.):
    ops.prim.RaiseException(torch.format(range_msg, p))
  dropped = torch.dropout_(input, p, training) if inplace else torch.dropout(input, p, training)
  return dropped
def multi_head_attention_forward(query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool=True,
    key_padding_mask: Optional[Tensor]=None,
    need_weights: bool=True,
    attn_mask: Optional[Tensor]=None,
    use_separate_proj_weight: bool=False,
    q_proj_weight: Optional[Tensor]=None,
    k_proj_weight: Optional[Tensor]=None,
    v_proj_weight: Optional[Tensor]=None,
    static_k: Optional[Tensor]=None,
    static_v: Optional[Tensor]=None) -> Tuple[Tensor, Optional[Tensor]]:
  # Multi-head attention forward pass (serialized TorchScript form).
  #
  # Steps, in order:
  #   1. Validate shapes: query is unpacked as (tgt_len, bsz, embed_dim);
  #      embed_dim must equal embed_dim_to_check and be divisible by num_heads.
  #   2. Project query/key/value, either with the packed in_proj_weight or
  #      with the separate q/k/v projection weights.
  #   3. Normalise attn_mask and key_padding_mask (dtype and shape checks).
  #   4. Optionally append bias_k/bias_v and a zero attention slot, padding
  #      the masks by one column each time.
  #   5. Reshape q/k/v to (bsz * num_heads, len, head_dim), merge the masks,
  #      run _scaled_dot_product_attention, and apply the output projection.
  #
  # Returns (attn_output, weights), where weights is the attention summed
  # over heads and divided by num_heads when need_weights is True, else None.
  # Failed validations raise through prim.RaiseException.
  #
  # Message templates and sibling-function handles hoisted by the printer.
  _7 = "was expecting embedding dimension of {}, but got {}"
  _8 = "embed_dim {} not divisible by num_heads {}"
  _9 = "key\'s sequence and batch dims {} do not match value\'s {}"
  _10 = "key shape {} does not match value shape {}"
  _11 = __torch__.torch.nn.functional._in_projection_packed
  _12 = "AssertionError: use_separate_proj_weight is True but q_proj_weight is None"
  _13 = "AssertionError: use_separate_proj_weight is True but k_proj_weight is None"
  _14 = "AssertionError: use_separate_proj_weight is True but v_proj_weight is None"
  _15 = __torch__.torch.nn.functional._in_projection
  _16 = "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
  _17 = "Only float, byte, and bool types are supported for attn_mask, not {}"
  _18 = "The shape of the 2D attn_mask is {}, but should be {}."
  _19 = "The shape of the 3D attn_mask is {}, but should be {}."
  _20 = "attn_mask\'s dimension {} is not supported"
  _21 = "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
  _22 = "AssertionError: bias cannot be added to static key."
  _23 = "AssertionError: bias cannot be added to static value."
  _24 = "expecting static_k.size(0) of {}, but got {}"
  _25 = "expecting static_k.size(2) of {}, but got {}"
  _26 = "expecting static_v.size(0) of {}, but got {}"
  _27 = "expecting static_v.size(2) of {}, but got {}"
  _28 = "expecting key_padding_mask shape of {}, but got {}"
  _29 = __torch__.torch.nn.functional._scaled_dot_product_attention
  # Typed placeholders assigned on branches that have just raised, so that
  # every branch still defines the variable.
  _30 = uninitialized(Optional[Tensor])
  _31 = uninitialized(Optional[Tensor])
  _32 = uninitialized(Tensor)
  # --- 1. shape validation -------------------------------------------------
  tgt_len, bsz, embed_dim, = torch.size(query)
  src_len, _33, _34, = torch.size(key)
  _35 = torch.eq(embed_dim, embed_dim_to_check)
  if _35:
    pass
  else:
    _36 = torch.format(_7, embed_dim_to_check, embed_dim)
    ops.prim.RaiseException(torch.add("AssertionError: ", _36))
  embed_dim0 = unchecked_cast(int, embed_dim)
  head_dim = torch.floordiv(embed_dim0, num_heads)
  # embed_dim must split evenly across the heads.
  _37 = torch.eq(torch.mul(head_dim, num_heads), embed_dim0)
  if _37:
    pass
  else:
    _38 = torch.format(_8, embed_dim0, num_heads)
    ops.prim.RaiseException(torch.add("AssertionError: ", _38))
  if use_separate_proj_weight:
    # With separate weights only the first two dims of key/value must agree.
    _39 = torch.slice(torch.size(key), None, 2)
    _40 = torch.slice(torch.size(value), None, 2)
    if torch.eq(_39, _40):
      pass
    else:
      _41 = torch.slice(torch.size(key), None, 2)
      _42 = torch.slice(torch.size(value), None, 2)
      _43 = torch.add("AssertionError: ", torch.format(_9, _41, _42))
      ops.prim.RaiseException(_43)
  else:
    # With the packed weight key and value must have identical shapes.
    _44 = torch.eq(torch.size(key), torch.size(value))
    if _44:
      pass
    else:
      _45 = torch.format(_10, torch.size(key), torch.size(value))
      _46 = torch.add("AssertionError: ", _45)
      ops.prim.RaiseException(_46)
  # --- 2. input projections ------------------------------------------------
  _47 = torch.__not__(use_separate_proj_weight)
  if _47:
    _48 = _11(query, key, value, in_proj_weight, in_proj_bias, )
    q0, k0, v0, = _48
    k, v, q = k0, v0, q0
  else:
    # Each separate projection weight is required in this mode.
    _49 = torch.__isnot__(q_proj_weight, None)
    if _49:
      q_proj_weight1 = unchecked_cast(Tensor, q_proj_weight)
      q_proj_weight0 = q_proj_weight1
    else:
      ops.prim.RaiseException(_12)
      q_proj_weight0 = _32
    _50 = torch.__isnot__(k_proj_weight, None)
    if _50:
      k_proj_weight1 = unchecked_cast(Tensor, k_proj_weight)
      k_proj_weight0 = k_proj_weight1
    else:
      ops.prim.RaiseException(_13)
      k_proj_weight0 = _32
    _51 = torch.__isnot__(v_proj_weight, None)
    if _51:
      v_proj_weight1 = unchecked_cast(Tensor, v_proj_weight)
      v_proj_weight0 = v_proj_weight1
    else:
      ops.prim.RaiseException(_14)
      v_proj_weight0 = _32
    # The packed bias, when present, is split into three equal chunks.
    if torch.__is__(in_proj_bias, None):
      b_q, b_k, b_v = None, None, None
    else:
      in_proj_bias0 = unchecked_cast(Tensor, in_proj_bias)
      b_q0, b_k0, b_v0, = torch.chunk(in_proj_bias0, 3)
      b_q, b_k, b_v = b_q0, b_k0, b_v0
    _52 = _15(query, key, value, q_proj_weight0, k_proj_weight0, v_proj_weight0, b_q, b_k, b_v, )
    q1, k1, v1, = _52
    k, v, q = k1, v1, q1
  # --- 3. attention mask preparation ---------------------------------------
  if torch.__isnot__(attn_mask, None):
    attn_mask1 = unchecked_cast(Tensor, attn_mask)
    # dtype code 0 is the deprecated byte mask named in the warning text; it
    # is converted to dtype code 11 (the bool type the warning recommends).
    _53 = torch.eq(ops.prim.dtype(attn_mask1), 0)
    if _53:
      torch.warn(_16)
      attn_mask2 = torch.to(attn_mask1, 11)
    else:
      # Otherwise the mask must be floating point or dtype code 11.
      _54 = torch.is_floating_point(attn_mask1)
      if _54:
        _55 = True
      else:
        _56 = torch.eq(ops.prim.dtype(attn_mask1), 11)
        _55 = _56
      if _55:
        pass
      else:
        _57 = torch.format(_17, ops.prim.dtype(attn_mask1))
        _58 = torch.add("AssertionError: ", _57)
        ops.prim.RaiseException(_58)
      attn_mask2 = attn_mask1
    if torch.eq(torch.dim(attn_mask2), 2):
      # A 2D mask must be (tgt_len, src_len); it gains a leading dim of 1.
      correct_2d_size = (tgt_len, src_len)
      _59 = torch.ne(torch.size(attn_mask2), [tgt_len, src_len])
      if _59:
        _60 = torch.format(_18, torch.size(attn_mask2), correct_2d_size)
        ops.prim.RaiseException(_60)
      else:
        pass
      attn_mask3 = torch.unsqueeze(attn_mask2, 0)
    else:
      _61 = torch.eq(torch.dim(attn_mask2), 3)
      if _61:
        # A 3D mask must be (bsz * num_heads, tgt_len, src_len).
        _62 = torch.mul(bsz, num_heads)
        correct_3d_size = (_62, tgt_len, src_len)
        _63 = torch.ne(torch.size(attn_mask2), [_62, tgt_len, src_len])
        if _63:
          _64 = torch.format(_19, torch.size(attn_mask2), correct_3d_size)
          ops.prim.RaiseException(_64)
        else:
          pass
      else:
        _65 = torch.format(_20, torch.dim(attn_mask2))
        ops.prim.RaiseException(_65)
      attn_mask3 = attn_mask2
    attn_mask0 : Optional[Tensor] = attn_mask3
  else:
    attn_mask0 = attn_mask
  # key_padding_mask: a dtype-code-0 (byte) mask is converted to code 11.
  _66 = torch.__isnot__(key_padding_mask, None)
  if _66:
    key_padding_mask1 = unchecked_cast(Tensor, key_padding_mask)
    _68 = torch.eq(ops.prim.dtype(key_padding_mask1), 0)
    _67, key_padding_mask0 = _68, key_padding_mask1
  else:
    _67, key_padding_mask0 = False, key_padding_mask
  if _67:
    key_padding_mask3 = unchecked_cast(Tensor, key_padding_mask0)
    torch.warn(_21)
    key_padding_mask2 : Optional[Tensor] = torch.to(key_padding_mask3, 11)
  else:
    key_padding_mask2 = key_padding_mask0
  # --- 4a. optional bias_k / bias_v ----------------------------------------
  # _69 is True only when both bias_k and bias_v are provided.
  if torch.__isnot__(bias_k, None):
    bias_k1 = unchecked_cast(Tensor, bias_k)
    _69, bias_k0 = torch.__isnot__(bias_v, None), bias_k1
  else:
    _69, bias_k0 = False, bias_k
  if _69:
    bias_k2 = unchecked_cast(Tensor, bias_k0)
    bias_v0 = unchecked_cast(Tensor, bias_v)
    # Biases cannot be combined with static key/value tensors.
    if torch.__is__(static_k, None):
      static_k1 : Optional[Tensor] = static_k
    else:
      ops.prim.RaiseException(_22)
      static_k1 = _31
    if torch.__is__(static_v, None):
      static_v1 : Optional[Tensor] = static_v
    else:
      ops.prim.RaiseException(_23)
      static_v1 = _30
    # Repeat each bias bsz times along dim 1 and concatenate it onto k / v
    # (torch.cat is called without a dim argument here).
    _70 = [k, torch.repeat(bias_k2, [1, bsz, 1])]
    k3 = torch.cat(_70)
    _71 = [v, torch.repeat(bias_v0, [1, bsz, 1])]
    v3 = torch.cat(_71)
    # Pad both masks with one trailing zero column for the added entry.
    if torch.__isnot__(attn_mask0, None):
      attn_mask6 = unchecked_cast(Tensor, attn_mask0)
      attn_mask7 = __torch__.torch.nn.functional._pad(attn_mask6, [0, 1], "constant", 0., )
      attn_mask5 : Optional[Tensor] = attn_mask7
    else:
      attn_mask5 = attn_mask0
    _72 = torch.__isnot__(key_padding_mask2, None)
    if _72:
      key_padding_mask6 = unchecked_cast(Tensor, key_padding_mask2)
      key_padding_mask7 = __torch__.torch.nn.functional._pad(key_padding_mask6, [0, 1], "constant", 0., )
      key_padding_mask5 : Optional[Tensor] = key_padding_mask7
    else:
      key_padding_mask5 = key_padding_mask2
    static_k0, k2, static_v0, v2, attn_mask4, key_padding_mask4 = static_k1, k3, static_v1, v3, attn_mask5, key_padding_mask5
  else:
    # Supplying only one of bias_k / bias_v is rejected. Note that the
    # raised message carries no detail beyond the "AssertionError: " prefix.
    if torch.__is__(bias_k0, None):
      pass
    else:
      ops.prim.RaiseException("AssertionError: ")
    if torch.__is__(bias_v, None):
      pass
    else:
      ops.prim.RaiseException("AssertionError: ")
    static_k0, k2, static_v0, v2, attn_mask4, key_padding_mask4 = static_k, k, static_v, v, attn_mask0, key_padding_mask2
  # --- 5a. reshape q, k, v to (bsz * num_heads, len, head_dim) -------------
  _73 = torch.contiguous(q)
  _74 = [tgt_len, torch.mul(bsz, num_heads), head_dim]
  q2 = torch.transpose(torch.view(_73, _74), 0, 1)
  if torch.__is__(static_k0, None):
    _75 = torch.contiguous(k2)
    _76 = [(torch.size(k2))[0], torch.mul(bsz, num_heads), head_dim]
    k5 = torch.transpose(torch.view(_75, _76), 0, 1)
    k4 = k5
  else:
    # A static key is used as-is after checking dims 0 and 2.
    static_k2 = unchecked_cast(Tensor, static_k0)
    _77 = torch.eq(torch.size(static_k2, 0), torch.mul(bsz, num_heads))
    if _77:
      pass
    else:
      _78 = torch.format(_24, torch.mul(bsz, num_heads), torch.size(static_k2, 0))
      _79 = torch.add("AssertionError: ", _78)
      ops.prim.RaiseException(_79)
    _80 = torch.eq(torch.size(static_k2, 2), head_dim)
    if _80:
      pass
    else:
      _81 = torch.format(_25, head_dim, torch.size(static_k2, 2))
      _82 = torch.add("AssertionError: ", _81)
      ops.prim.RaiseException(_82)
    k4 = static_k2
  if torch.__is__(static_v0, None):
    _83 = torch.contiguous(v2)
    _84 = [(torch.size(v2))[0], torch.mul(bsz, num_heads), head_dim]
    v5 = torch.transpose(torch.view(_83, _84), 0, 1)
    v4 = v5
  else:
    # A static value is used as-is after checking dims 0 and 2.
    static_v2 = unchecked_cast(Tensor, static_v0)
    _85 = torch.eq(torch.size(static_v2, 0), torch.mul(bsz, num_heads))
    if _85:
      pass
    else:
      _86 = torch.format(_26, torch.mul(bsz, num_heads), torch.size(static_v2, 0))
      _87 = torch.add("AssertionError: ", _86)
      ops.prim.RaiseException(_87)
    _88 = torch.eq(torch.size(static_v2, 2), head_dim)
    if _88:
      pass
    else:
      _89 = torch.format(_27, head_dim, torch.size(static_v2, 2))
      _90 = torch.add("AssertionError: ", _89)
      ops.prim.RaiseException(_90)
    v4 = static_v2
  # --- 4b. optional zero attention slot ------------------------------------
  if add_zero_attn:
    # Append one all-zero entry along dim 1 of k and v, then pad the masks
    # with one trailing zero column to match.
    _91 = torch.mul(bsz, num_heads)
    _92 = ops.prim.dtype(k4)
    _93 = ops.prim.device(k4)
    _94 = torch.zeros([_91, 1, head_dim], dtype=_92, layout=None, device=_93)
    k7 = torch.cat([k4, _94], 1)
    _95 = ops.prim.dtype(v4)
    _96 = ops.prim.device(v4)
    _97 = torch.zeros([_91, 1, head_dim], dtype=_95, layout=None, device=_96)
    v7 = torch.cat([v4, _97], 1)
    if torch.__isnot__(attn_mask4, None):
      attn_mask10 = unchecked_cast(Tensor, attn_mask4)
      attn_mask11 = __torch__.torch.nn.functional._pad(attn_mask10, [0, 1], "constant", 0., )
      attn_mask9 : Optional[Tensor] = attn_mask11
    else:
      attn_mask9 = attn_mask4
    _98 = torch.__isnot__(key_padding_mask4, None)
    if _98:
      key_padding_mask10 = unchecked_cast(Tensor, key_padding_mask4)
      key_padding_mask11 = __torch__.torch.nn.functional._pad(key_padding_mask10, [0, 1], "constant", 0., )
      key_padding_mask9 : Optional[Tensor] = key_padding_mask11
    else:
      key_padding_mask9 = key_padding_mask4
    k6, key_padding_mask8, attn_mask8, v6 = k7, key_padding_mask9, attn_mask9, v7
  else:
    k6, key_padding_mask8, attn_mask8, v6 = k4, key_padding_mask4, attn_mask4, v4
  # Source length is re-read from k because the steps above may extend it.
  src_len0 = torch.size(k6, 1)
  # --- 5b. merge key_padding_mask into attn_mask ---------------------------
  _99 = torch.__isnot__(key_padding_mask8, None)
  if _99:
    key_padding_mask12 = unchecked_cast(Tensor, key_padding_mask8)
    _100 = torch.eq(torch.size(key_padding_mask12), [bsz, src_len0])
    if _100:
      pass
    else:
      _101 = torch.format(_28, (bsz, src_len0), torch.size(key_padding_mask12))
      _102 = torch.add("AssertionError: ", _101)
      ops.prim.RaiseException(_102)
    # Expand (bsz, src_len) across heads to (bsz * num_heads, 1, src_len).
    _103 = torch.view(key_padding_mask12, [bsz, 1, 1, src_len0])
    _104 = torch.expand(_103, [-1, num_heads, -1, -1])
    _105 = [torch.mul(bsz, num_heads), 1, src_len0]
    key_padding_mask13 = torch.reshape(_104, _105)
    if torch.__is__(attn_mask8, None):
      attn_mask13 = key_padding_mask13
    else:
      attn_mask14 = unchecked_cast(Tensor, attn_mask8)
      # dtype-code-11 masks are OR-ed together; any other mask has the
      # padded positions filled with -inf.
      _106 = torch.eq(ops.prim.dtype(attn_mask14), 11)
      if _106:
        attn_mask16 = torch.logical_or(attn_mask14, key_padding_mask13)
        attn_mask15 = attn_mask16
      else:
        attn_mask17 = torch.masked_fill(attn_mask14, key_padding_mask13, -inf)
        attn_mask15 = attn_mask17
      attn_mask13 = attn_mask15
    attn_mask12 : Optional[Tensor] = attn_mask13
  else:
    attn_mask12 = attn_mask8
  # A dtype-code-11 mask becomes an additive mask: zeros with -inf at the
  # masked positions. dtype=6 is presumably float32 - TODO confirm against
  # the ScalarType enum.
  if torch.__isnot__(attn_mask12, None):
    attn_mask19 = unchecked_cast(Tensor, attn_mask12)
    _108 = torch.eq(ops.prim.dtype(attn_mask19), 11)
    _107, attn_mask18 = _108, attn_mask19
  else:
    _107, attn_mask18 = False, attn_mask12
  if _107:
    attn_mask21 = unchecked_cast(Tensor, attn_mask18)
    new_attn_mask = torch.zeros_like(attn_mask21, dtype=6)
    _109 = torch.masked_fill_(new_attn_mask, attn_mask21, -inf)
    attn_mask20 : Optional[Tensor] = new_attn_mask
  else:
    attn_mask20 = attn_mask18
  # Dropout is disabled outside of training.
  if torch.__not__(training):
    dropout_p0 = 0.
  else:
    dropout_p0 = dropout_p
  # --- 5c. attention and output projection ---------------------------------
  _110 = _29(q2, k6, v6, attn_mask20, dropout_p0, )
  attn_output, attn_output_weights, = _110
  _111 = torch.contiguous(torch.transpose(attn_output, 0, 1))
  attn_output0 = torch.view(_111, [tgt_len, bsz, embed_dim0])
  attn_output1 = __torch__.torch.nn.functional.linear(attn_output0, out_proj_weight, out_proj_bias, )
  if need_weights:
    # Sum the per-head weights over dim 1 and divide by num_heads.
    attn_output_weights0 = torch.view(attn_output_weights, [bsz, num_heads, tgt_len, src_len0])
    _113 = torch.sum(attn_output_weights0, [1])
    _114 = (attn_output1, torch.div(_113, num_heads))
    _112 = _114
  else:
    _112 = (attn_output1, None)
  return _112
def linear(input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor]=None) -> Tensor:
  # Thin wrapper: forwards `input`, `weight` and the optional `bias`
  # unchanged to torch.linear and returns its result.
  projected = torch.linear(input, weight, bias)
  return projected
def layer_norm(input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor]=None,
    bias: Optional[Tensor]=None,
    eps: float=1.0000000000000001e-05) -> Tensor:
  # Thin wrapper: passes every argument, in the same order, straight to
  # torch.layer_norm and returns its result.
  return torch.layer_norm(input, normalized_shape, weight, bias, eps)
def relu(input: Tensor,
    inplace: bool=False) -> Tensor:
  # Rectified linear unit. With `inplace` set, torch.relu_ is applied to
  # `input` itself; otherwise torch.relu is called on `input`.
  return torch.relu_(input) if inplace else torch.relu(input)
def leaky_relu(input: Tensor,
    negative_slope: float=0.01,
    inplace: bool=False) -> Tensor:
  # Leaky ReLU with the given `negative_slope`. Dispatches to
  # torch.leaky_relu_ when `inplace` is set, else to torch.leaky_relu.
  if inplace:
    activated = torch.leaky_relu_(input, negative_slope)
  else:
    activated = torch.leaky_relu(input, negative_slope)
  return activated
def embedding_bag(input: Tensor,
    weight: Tensor,
    offsets: Optional[Tensor]=None,
    max_norm: Optional[float]=None,
    norm_type: float=2.,
    scale_grad_by_freq: bool=False,
    mode: str="mean",
    sparse: bool=False,
    per_sample_weights: Optional[Tensor]=None,
    include_last_offset: bool=False,
    padding_idx: Optional[int]=None) -> Tensor:
  # Validate and normalise the arguments, then call torch.embedding_bag and
  # return the first of its four outputs.
  #
  # input: 1D (then `offsets` is required and must be 1D) or 2D (then
  #   `offsets` must be None; offsets are generated and input is flattened).
  # mode: "sum", "mean" or "max", mapped to 0, 1 and 2 respectively.
  # per_sample_weights: must match input's shape and requires mode "sum".
  # max_norm / norm_type: not read anywhere in this body.
  # Invalid combinations raise through prim.RaiseException.
  _116 = "Argument order of nn.functional.embedding_bag was changed. Usage `embedding_bag(weight, input, ...)` is deprecated, and should now be `embedding_bag(input, weight, ...)`."
  _117 = "embedding_bag: If per_sample_weights ({}) is not None, then it must have the same shape as the input ({})"
  _118 = "if input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences. However, found offsets of type {}"
  _119 = "offsets has to be a 1D Tensor but got None"
  _120 = "input has to be 1D or 2D Tensor, but got Tensor of dimension {}"
  _121 = "max mode does not support scaling the gradient by the frequency"
  _122 = "max mode does not support sparse weights"
  _123 = "mode has to be one of sum, mean or max"
  _124 = "embedding_bag: per_sample_weights was not None. per_sample_weights is only supported for mode=\'sum\' (got mode=\'{}\'). Please open a feature request on GitHub."
  # Typed placeholders assigned on branches that have just raised.
  _125 = uninitialized(int)
  _126 = uninitialized(Tensor)
  _127 = uninitialized(Optional[Tensor])
  # Deprecated call order: when `weight` has dtype code 4 (presumably int64 -
  # TODO confirm) and `input` is floating point, warn and swap the two.
  if torch.eq(ops.prim.dtype(weight), 4):
    _128 = torch.is_floating_point(input)
  else:
    _128 = False
  if _128:
    torch.warn(_116)
    input1, weight0 = weight, input
  else:
    input1, weight0 = input, weight
  # per_sample_weights, when given, must have exactly input's shape.
  _129 = torch.__isnot__(per_sample_weights, None)
  if _129:
    per_sample_weights1 = unchecked_cast(Tensor, per_sample_weights)
    _131 = torch.ne(torch.size(input1), torch.size(per_sample_weights1))
    _130, per_sample_weights0 = _131, per_sample_weights1
  else:
    _130, per_sample_weights0 = False, per_sample_weights
  if _130:
    per_sample_weights3 = unchecked_cast(Tensor, per_sample_weights0)
    _132 = torch.format(_117, torch.size(per_sample_weights3), torch.size(input1))
    ops.prim.RaiseException(_132)
    per_sample_weights2 : Optional[Tensor] = _127
  else:
    per_sample_weights2 = per_sample_weights0
  if torch.eq(torch.dim(input1), 2):
    # 2D input: explicit offsets are rejected. The reported type is always
    # the literal "<unknown>".
    if torch.__isnot__(offsets, None):
      ops.prim.RaiseException(torch.format(_118, "<unknown>"))
    else:
      pass
    # Generate offsets 0, size(1), 2*size(1), ... up to numel, then flatten
    # input (and per_sample_weights, if any) to 1D.
    offsets1 = torch.arange(0, torch.numel(input1), torch.size(input1, 1), dtype=ops.prim.dtype(input1), layout=None, device=ops.prim.device(input1))
    input3 = torch.reshape(input1, [-1])
    _133 = torch.__isnot__(per_sample_weights2, None)
    if _133:
      per_sample_weights6 = unchecked_cast(Tensor, per_sample_weights2)
      per_sample_weights7 = torch.reshape(per_sample_weights6, [-1])
      per_sample_weights5 : Optional[Tensor] = per_sample_weights7
    else:
      per_sample_weights5 = per_sample_weights2
    per_sample_weights4, input2, offsets0 = per_sample_weights5, input3, offsets1
  else:
    if torch.eq(torch.dim(input1), 1):
      # 1D input: offsets are mandatory and must themselves be 1D.
      if torch.__is__(offsets, None):
        ops.prim.RaiseException(_119)
        offsets3 = _126
      else:
        offsets3 = unchecked_cast(Tensor, offsets)
      if torch.ne(torch.dim(offsets3), 1):
        ops.prim.RaiseException("offsets has to be a 1D Tensor")
      else:
        pass
      offsets2 = offsets3
    else:
      # Any other rank is rejected.
      _134 = torch.format(_120, torch.dim(input1))
      ops.prim.RaiseException(_134)
      offsets2 = _126
    per_sample_weights4, input2, offsets0 = per_sample_weights2, input1, offsets2
  # Map the mode string to the integer passed to torch.embedding_bag.
  if torch.eq(mode, "sum"):
    mode_enum = 0
  else:
    if torch.eq(mode, "mean"):
      mode_enum0 = 1
    else:
      if torch.eq(mode, "max"):
        # "max" is incompatible with scale_grad_by_freq and with sparse.
        if scale_grad_by_freq:
          ops.prim.RaiseException(_121)
        else:
          pass
        if sparse:
          ops.prim.RaiseException(_122)
        else:
          pass
        mode_enum1 = 2
      else:
        ops.prim.RaiseException(_123)
        mode_enum1 = _125
      mode_enum0 = mode_enum1
    mode_enum = mode_enum0
  # per_sample_weights is only accepted together with mode "sum".
  _135 = torch.__isnot__(per_sample_weights4, None)
  if _135:
    per_sample_weights9 = unchecked_cast(Tensor, per_sample_weights4)
    _136, per_sample_weights8 = torch.ne(mode, "sum"), per_sample_weights9
  else:
    _136, per_sample_weights8 = False, per_sample_weights4
  if _136:
    ops.prim.RaiseException(torch.format(_124, mode))
    per_sample_weights10 : Optional[Tensor] = _127
  else:
    per_sample_weights10 = per_sample_weights8
  # Only the first of the four outputs is returned to the caller.
  ret, _137, _138, _139 = torch.embedding_bag(weight0, input2, offsets0, scale_grad_by_freq, mode_enum, sparse, per_sample_weights10, include_last_offset, padding_idx)
  return ret
def _in_projection_packed(q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor]=None) -> List[Tensor]:
  # Project q, k and v with a single packed weight `w` and optional packed
  # bias `b`, returning a list of three projected tensors.
  #
  # The branch taken depends on object identity (`is`), not value equality:
  #   * k is v and q is k: one linear call, result chunked into 3 on the
  #     last dim.
  #   * k is v only: `w`/`b` are split into a part of size E and a part of
  #     size 2*E; q uses the first, k the second, whose result is chunked
  #     into 2 on the last dim.
  #   * otherwise: `w`/`b` are chunked into 3 and applied to q, k, v in turn.
  # E is the size of q's last dimension.
  E = torch.size(q, -1)
  if torch.__is__(k, v):
    if torch.__is__(q, k):
      _142 = __torch__.torch.nn.functional.linear(q, w, b, )
      _141 = torch.chunk(_142, 3, -1)
    else:
      _143 = torch.split(w, [E, torch.mul(E, 2)])
      w_q, w_kv, = _143
      if torch.__is__(b, None):
        b_q, b_kv = None, None
      else:
        b0 = unchecked_cast(Tensor, b)
        _144 = torch.split(b0, [E, torch.mul(E, 2)])
        b_q1, b_kv0, = _144
        b_q, b_kv = b_q1, b_kv0
      _145 = __torch__.torch.nn.functional.linear(q, w_q, b_q, )
      _146 = __torch__.torch.nn.functional.linear(k, w_kv, b_kv, )
      _147 = torch.chunk(_146, 2, -1)
      # List concatenation: [q_proj] + [k_proj, v_proj].
      _141 = torch.add([_145], _147)
    _140 = _141
  else:
    w_q0, w_k, w_v, = torch.chunk(w, 3)
    if torch.__is__(b, None):
      b_q2, b_k, b_v = None, None, None
    else:
      b1 = unchecked_cast(Tensor, b)
      b_q3, b_k1, b_v1, = torch.chunk(b1, 3)
      b_q2, b_k, b_v = b_q3, b_k1, b_v1
    _148 = __torch__.torch.nn.functional.linear(q, w_q0, b_q2, )
    _149 = __torch__.torch.nn.functional.linear(k, w_k, b_k, )
    _150 = __torch__.torch.nn.functional.linear(v, w_v, b_v, )
    _140 = [_148, _149, _150]
  return _140
def _in_projection(q: Tensor,
    k: Tensor,
    v: Tensor,
    w_q: Tensor,
    w_k: Tensor,
    w_v: Tensor,
    b_q: Optional[Tensor]=None,
    b_k: Optional[Tensor]=None,
    b_v: Optional[Tensor]=None) -> Tuple[Tensor, Tensor, Tensor]:
  # Project q, k and v with separate weights and optional biases, returning
  # the three linear outputs as a tuple.
  #
  # With Eq, Ek, Ev the last-dimension sizes of q, k, v, the checks are:
  #   w_q is [Eq, Eq], w_k is [Eq, Ek], w_v is [Eq, Ev];
  #   each bias, when not None, is [Eq].
  # Any mismatch raises through prim.RaiseException.
  _151 = "expecting query weights shape of {}, but got {}"
  _152 = "expecting key weights shape of {}, but got {}"
  _153 = "expecting value weights shape of {}, but got {}"
  _154 = "expecting query bias shape of {}, but got {}"
  _155 = "expecting key bias shape of {}, but got {}"
  _156 = "expecting value bias shape of {}, but got {}"
  # Typed placeholders assigned on branches that have just raised.
  _157 = uninitialized(Optional[Tensor])
  _158 = uninitialized(Optional[Tensor])
  _159 = uninitialized(Optional[Tensor])
  Eq = torch.size(q, -1)
  Ek = torch.size(k, -1)
  Ev = torch.size(v, -1)
  # Weight shape checks.
  if torch.eq(torch.size(w_q), [Eq, Eq]):
    pass
  else:
    _160 = torch.format(_151, (Eq, Eq), torch.size(w_q))
    ops.prim.RaiseException(torch.add("AssertionError: ", _160))
  if torch.eq(torch.size(w_k), [Eq, Ek]):
    pass
  else:
    _161 = torch.format(_152, (Eq, Ek), torch.size(w_k))
    ops.prim.RaiseException(torch.add("AssertionError: ", _161))
  if torch.eq(torch.size(w_v), [Eq, Ev]):
    pass
  else:
    _162 = torch.format(_153, (Eq, Ev), torch.size(w_v))
    ops.prim.RaiseException(torch.add("AssertionError: ", _162))
  # Bias checks: each flag is True when the bias is None or has shape [Eq].
  if torch.__is__(b_q, None):
    _163, b_q4 = True, b_q
  else:
    b_q5 = unchecked_cast(Tensor, b_q)
    _163, b_q4 = torch.eq(torch.size(b_q5), [Eq]), b_q5
  if _163:
    b_q6 : Optional[Tensor] = b_q4
  else:
    b_q7 = unchecked_cast(Tensor, b_q4)
    _164 = torch.format(_154, (Eq,), torch.size(b_q7))
    ops.prim.RaiseException(torch.add("AssertionError: ", _164))
    b_q6 = _159
  if torch.__is__(b_k, None):
    _165, b_k2 = True, b_k
  else:
    b_k3 = unchecked_cast(Tensor, b_k)
    _165, b_k2 = torch.eq(torch.size(b_k3), [Eq]), b_k3
  if _165:
    b_k4 : Optional[Tensor] = b_k2
  else:
    b_k5 = unchecked_cast(Tensor, b_k2)
    _166 = torch.format(_155, (Eq,), torch.size(b_k5))
    ops.prim.RaiseException(torch.add("AssertionError: ", _166))
    b_k4 = _158
  if torch.__is__(b_v, None):
    _167, b_v2 = True, b_v
  else:
    b_v3 = unchecked_cast(Tensor, b_v)
    _167, b_v2 = torch.eq(torch.size(b_v3), [Eq]), b_v3
  if _167:
    b_v4 : Optional[Tensor] = b_v2
  else:
    b_v5 = unchecked_cast(Tensor, b_v2)
    _168 = torch.format(_156, (Eq,), torch.size(b_v5))
    ops.prim.RaiseException(torch.add("AssertionError: ", _168))
    b_v4 = _157
  # Apply the three independent linear projections.
  _169 = __torch__.torch.nn.functional.linear(q, w_q, b_q6, )
  _170 = __torch__.torch.nn.functional.linear(k, w_k, b_k4, )
  _171 = __torch__.torch.nn.functional.linear(v, w_v, b_v4, )
  return (_169, _170, _171)
def _pad(input: Tensor,
    pad: List[int],
    mode: str="constant",
    value: float=0.) -> Tensor:
  # Pad `input` according to `pad` and `mode`.
  #
  # pad: even-length list; len(pad) // 2 must not exceed input's rank.
  # mode "constant": delegates to torch.constant_pad_nd using `value`.
  # other modes ("reflect", "replicate", "circular"): `value` must be 0 and
  #   the operator is chosen from len(pad) together with input's rank:
  #     len 2 with rank 2 or 3 -> 1d ops
  #     len 4 with rank 3 or 4 -> 2d ops
  #     len 6 with rank 4 or 5 -> 3d ops
  #   "circular" always goes through _pad_circular.
  # Anything else raises through prim.RaiseException.
  _172 = "AssertionError: Padding length must be divisible by 2"
  _173 = "AssertionError: Padding length too large"
  # NOTE(review): this template contains a doubled closing quote after {}.
  _174 = "Padding mode \"{}\"\" doesn\'t take in value argument"
  _175 = __torch__.torch.nn.functional._pad_circular
  _176 = "Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now"
  # Typed placeholder assigned on branches that have just raised.
  _177 = uninitialized(Tensor)
  _178 = torch.eq(torch.remainder(torch.len(pad), 2), 0)
  if _178:
    pass
  else:
    ops.prim.RaiseException(_172)
  _179 = torch.le(torch.floordiv(torch.len(pad), 2), torch.dim(input))
  if _179:
    pass
  else:
    ops.prim.RaiseException(_173)
  if torch.eq(mode, "constant"):
    _181 = torch.constant_pad_nd(input, pad, value)
    _180 = _181
  else:
    # Non-constant modes do not accept a fill value.
    if torch.eq(value, 0.):
      pass
    else:
      _182 = torch.add("AssertionError: ", torch.format(_174, mode))
      ops.prim.RaiseException(_182)
    # _183: two pad entries and a rank-2 or rank-3 input -> 1d padding.
    if torch.eq(torch.len(pad), 2):
      if torch.eq(torch.dim(input), 2):
        _184 = True
      else:
        _184 = torch.eq(torch.dim(input), 3)
      _183 = _184
    else:
      _183 = False
    if _183:
      if torch.eq(mode, "reflect"):
        _187 = torch.reflection_pad1d(input, pad)
        _186 = _187
      else:
        if torch.eq(mode, "replicate"):
          _189 = torch.replication_pad1d(input, pad)
          _188 = _189
        else:
          if torch.eq(mode, "circular"):
            _190 = _175(input, pad, )
          else:
            # Unknown mode: raised with an empty message.
            ops.prim.RaiseException("")
            _190 = _177
          _188 = _190
        _186 = _188
      _185 = _186
    else:
      # _191: four pad entries and a rank-3 or rank-4 input -> 2d padding.
      if torch.eq(torch.len(pad), 4):
        if torch.eq(torch.dim(input), 3):
          _192 = True
        else:
          _192 = torch.eq(torch.dim(input), 4)
        _191 = _192
      else:
        _191 = False
      if _191:
        if torch.eq(mode, "reflect"):
          _195 = torch.reflection_pad2d(input, pad)
          _194 = _195
        else:
          if torch.eq(mode, "replicate"):
            _197 = torch.replication_pad2d(input, pad)
            _196 = _197
          else:
            if torch.eq(mode, "circular"):
              _198 = _175(input, pad, )
            else:
              # Unknown mode: raised with an empty message.
              ops.prim.RaiseException("")
              _198 = _177
            _196 = _198
          _194 = _196
        _193 = _194
      else:
        # _199: six pad entries and a rank-4 or rank-5 input -> 3d padding.
        if torch.eq(torch.len(pad), 6):
          if torch.eq(torch.dim(input), 4):
            _200 = True
          else:
            _201 = torch.eq(torch.dim(input), 5)
            _200 = _201
          _199 = _200
        else:
          _199 = False
        if _199:
          if torch.eq(mode, "reflect"):
            _204 = torch.reflection_pad3d(input, pad)
            _203 = _204
          else:
            if torch.eq(mode, "replicate"):
              _206 = torch.replication_pad3d(input, pad)
              _205 = _206
            else:
              _207 = torch.eq(mode, "circular")
              if _207:
                _208 = _175(input, pad, )
              else:
                # Unknown mode: raised with an empty message.
                ops.prim.RaiseException("")
                _208 = _177
              _205 = _208
            _203 = _205
          _202 = _203
        else:
          # No supported pad-length / rank combination matched.
          ops.prim.RaiseException(_176)
          _202 = _177
        _193 = _202
      _185 = _193
    _180 = _185
  return _180
def _scaled_dot_product_attention(q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor]=None,
    dropout_p: float=0.) -> Tuple[Tensor, Tensor]:
  # Scaled dot-product attention over 3D inputs.
  #
  # q is divided by sqrt of its last-dimension size, batch-multiplied with
  # k transposed on its last two dims, optionally offset by `attn_mask`,
  # soft-maxed over the last dim, optionally passed through dropout, and
  # finally batch-multiplied with v.
  # Returns (output, attention weights after dropout).
  batch, tgt_len, embed = torch.size(q)
  scaled_q = torch.div(q, torch.sqrt(embed))
  scores = torch.bmm(scaled_q, torch.transpose(k, -2, -1))
  if torch.__isnot__(attn_mask, None):
    # In-place add: the mask is accumulated directly into the score tensor.
    masked_scores = torch.add_(scores, unchecked_cast(Tensor, attn_mask))
  else:
    masked_scores = scores
  weights = __torch__.torch.nn.functional.softmax(masked_scores, -1, 3, None, )
  if torch.gt(dropout_p, 0.):
    # Dropout is always invoked with training=True and inplace=False here.
    kept = __torch__.torch.nn.functional.dropout(weights, dropout_p, True, False, )
  else:
    kept = weights
  return (torch.bmm(kept, v), kept)
def _pad_circular(input: Tensor,
    padding: List[int]) -> Tensor:
  # Circular (wrap-around) padding of every dimension after the first two.
  #
  # input:   tensor whose dims from index 2 onward are the "paddable" dims.
  # padding: flat list of ints, read from the END in pairs: paddable dim
  #          `idx` uses padding[-(2*idx+1)] and padding[-(2*idx+2)].
  #          Negative entries crop instead of pad.
  #
  # Returns a freshly allocated tensor (same dtype/layout/device as input)
  # whose paddable dims are each enlarged by the sum of their padding pair.
  #
  # Raises (via prim.RaiseException) when a padding value exceeds the size
  # of its dimension, or when a padding pair plus the size is negative.
  _209 = "AssertionError: Padding value causes wrapping around more than once."
  _210 = "AssertionError: Negative padding value is resulting in an empty dimension."
  in_shape = torch.size(input)
  # Drop the first two dims; only the remaining ones are padded.
  paddable_shape = torch.slice(in_shape, 2)
  ndim = torch.len(paddable_shape)
  # Loop bound is min(INT64_MAX, len(paddable_shape)), i.e. one iteration
  # per paddable dim.
  _211 = [9223372036854775807, torch.len(paddable_shape)]
  # Pass 1: validate the padding pair of every paddable dim.
  for idx in range(ops.prim.min(_211)):
    size = paddable_shape[idx]
    # Each padding value may be at most the dim size (wrap at most once).
    _212 = torch.neg(torch.add(torch.mul(idx, 2), 1))
    if torch.le(padding[_212], size):
      pass
    else:
      ops.prim.RaiseException(_209)
    _213 = torch.neg(torch.add(torch.mul(idx, 2), 2))
    if torch.le(padding[_213], size):
      pass
    else:
      ops.prim.RaiseException(_209)
    # Resulting size (size + both paddings) must not be negative.
    _214 = torch.neg(torch.add(torch.mul(idx, 2), 1))
    _215 = padding[_214]
    _216 = torch.neg(torch.add(torch.mul(idx, 2), 2))
    _217 = torch.add(torch.add(_215, padding[_216]), size)
    if torch.ge(_217, 0):
      pass
    else:
      ops.prim.RaiseException(_210)
  # Pass 2: build the output shape = first two input sizes followed by the
  # enlarged size of every paddable dim.
  out_shape = torch.slice(in_shape, None, 2)
  _218 = [9223372036854775807, torch.len(paddable_shape)]
  out_shape0 = out_shape
  for idx0 in range(ops.prim.min(_218)):
    size0 = paddable_shape[idx0]
    _219 = torch.neg(torch.add(torch.mul(idx0, 2), 1))
    _220 = torch.add(size0, padding[_219])
    _221 = torch.neg(torch.add(torch.mul(idx0, 2), 2))
    # add_ extends the list in place with the new size.
    out_shape1 = torch.add_(out_shape0, [torch.add(_220, padding[_221])])
    out_shape0 = out_shape1
  # Uninitialised buffer; it is filled by the copy_ calls below.
  out = torch.empty(out_shape0, dtype=ops.prim.dtype(input), layout=ops.prim.layout(input), device=ops.prim.device(input))
  # Step 1: copy the (possibly cropped) input into the interior of `out`.
  # out_* bounds skip positive padding; in_* bounds skip what negative
  # padding crops away (max(-p, 0)).
  if torch.eq(ndim, 1):
    out_d0 = ops.prim.max(padding[-2], 0)
    out_d1 = torch.sub(out_shape0[2], ops.prim.max(padding[-1], 0))
    in_d0 = ops.prim.max(torch.neg(padding[-2]), 0)
    _222 = in_shape[2]
    _223 = ops.prim.max(torch.neg(padding[-1]), 0)
    in_d1 = torch.sub(_222, _223)
    _224 = torch.slice(input, -1, in_d0, in_d1)
    _225 = torch.slice(out, -1, out_d0, out_d1)
    _226 = torch.copy_(_225, _224)
  else:
    if torch.eq(ndim, 2):
      # d* bounds apply to dim -2 (uses padding[-2], padding[-1]);
      # h* bounds apply to dim -1 (uses padding[-4], padding[-3]).
      out_d00 = ops.prim.max(padding[-2], 0)
      out_d10 = torch.sub(out_shape0[2], ops.prim.max(padding[-1], 0))
      out_h0 = ops.prim.max(padding[-4], 0)
      out_h1 = torch.sub(out_shape0[3], ops.prim.max(padding[-3], 0))
      in_d00 = ops.prim.max(torch.neg(padding[-2]), 0)
      _227 = in_shape[2]
      _228 = ops.prim.max(torch.neg(padding[-1]), 0)
      in_d10 = torch.sub(_227, _228)
      in_h0 = ops.prim.max(torch.neg(padding[-4]), 0)
      _229 = in_shape[3]
      _230 = ops.prim.max(torch.neg(padding[-3]), 0)
      in_h1 = torch.sub(_229, _230)
      _231 = torch.slice(input, -2, in_d00, in_d10)
      _232 = torch.slice(_231, -1, in_h0, in_h1)
      _233 = torch.slice(out, -2, out_d00, out_d10)
      _234 = torch.slice(_233, -1, out_h0, out_h1)
      _235 = torch.copy_(_234, _232)
    else:
      if torch.eq(ndim, 3):
        # d* -> dim -3, h* -> dim -2, w* -> dim -1, using padding pairs
        # (-2,-1), (-4,-3) and (-6,-5) respectively.
        out_d01 = ops.prim.max(padding[-2], 0)
        out_d11 = torch.sub(out_shape0[2], ops.prim.max(padding[-1], 0))
        out_h00 = ops.prim.max(padding[-4], 0)
        out_h10 = torch.sub(out_shape0[3], ops.prim.max(padding[-3], 0))
        out_w0 = ops.prim.max(padding[-6], 0)
        out_w1 = torch.sub(out_shape0[4], ops.prim.max(padding[-5], 0))
        in_d01 = ops.prim.max(torch.neg(padding[-2]), 0)
        _236 = in_shape[2]
        _237 = ops.prim.max(torch.neg(padding[-1]), 0)
        in_d11 = torch.sub(_236, _237)
        in_h00 = ops.prim.max(torch.neg(padding[-4]), 0)
        _238 = in_shape[3]
        _239 = ops.prim.max(torch.neg(padding[-3]), 0)
        in_h10 = torch.sub(_238, _239)
        in_w0 = ops.prim.max(torch.neg(padding[-6]), 0)
        _240 = in_shape[4]
        _241 = ops.prim.max(torch.neg(padding[-5]), 0)
        in_w1 = torch.sub(_240, _241)
        _242 = torch.slice(input, -3, in_d01, in_d11)
        _243 = torch.slice(_242, -2, in_h00, in_h10)
        _244 = torch.slice(_243, -1, in_w0, in_w1)
        _245 = torch.slice(out, -3, out_d01, out_d11)
        _246 = torch.slice(_245, -2, out_h00, out_h10)
        _247 = torch.slice(_246, -1, out_w0, out_w1)
        _248 = torch.copy_(_247, _244)
      else:
        # Any other number of paddable dims: no interior copy is made here.
        pass
  # Step 2: fill the padded borders by copying from `out` itself, one
  # dimension at a time. Later dims read regions that earlier dims have
  # already filled, so the statement order matters.
  #
  # Dim 2, leading border: copy the block that ends where the trailing
  # padding starts into [0, padding[-2]).
  if torch.gt(padding[-2], 0):
    _249 = torch.sub(out_shape0[2], padding[-2])
    i0 = torch.sub(_249, ops.prim.max(padding[-1], 0))
    i1 = torch.sub(out_shape0[2], ops.prim.max(padding[-1], 0))
    o1 = padding[-2]
    # Inner slice calls use default bounds; only dim 2 is restricted.
    _250 = torch.slice(torch.slice(torch.slice(out), 1), 2, i0, i1)
    _251 = torch.slice(torch.slice(torch.slice(out), 1), 2, 0, o1)
    _252 = torch.copy_(_251, _250)
  else:
    pass
  # Dim 2, trailing border: copy the block that starts right after the
  # leading padding into the last padding[-1] positions.
  if torch.gt(padding[-1], 0):
    i00 = ops.prim.max(padding[-2], 0)
    i10 = torch.add(ops.prim.max(padding[-2], 0), padding[-1])
    o0 = torch.sub(out_shape0[2], padding[-1])
    o10 = out_shape0[2]
    _253 = torch.slice(torch.slice(torch.slice(out), 1), 2, i00, i10)
    _254 = torch.slice(torch.slice(torch.slice(out), 1), 2, o0, o10)
    _255 = torch.copy_(_254, _253)
  else:
    pass
  # Dim 3 borders: only when padding has more than one pair.
  if torch.gt(torch.len(padding), 2):
    if torch.gt(padding[-4], 0):
      _256 = torch.sub(out_shape0[3], padding[-4])
      i01 = torch.sub(_256, ops.prim.max(padding[-3], 0))
      i11 = torch.sub(out_shape0[3], ops.prim.max(padding[-3], 0))
      o11 = padding[-4]
      _257 = torch.slice(torch.slice(torch.slice(out), 1), 2)
      _258 = torch.slice(_257, 3, i01, i11)
      _259 = torch.slice(torch.slice(torch.slice(out), 1), 2)
      _260 = torch.copy_(torch.slice(_259, 3, 0, o11), _258)
    else:
      pass
    if torch.gt(padding[-3], 0):
      i02 = ops.prim.max(padding[-4], 0)
      i12 = torch.add(ops.prim.max(padding[-4], 0), padding[-3])
      o00 = torch.sub(out_shape0[3], padding[-3])
      o12 = out_shape0[3]
      _261 = torch.slice(torch.slice(torch.slice(out), 1), 2)
      _262 = torch.slice(_261, 3, i02, i12)
      _263 = torch.slice(torch.slice(torch.slice(out), 1), 2)
      _264 = torch.copy_(torch.slice(_263, 3, o00, o12), _262)
    else:
      pass
  else:
    pass
  # Dim 4 borders: only when padding has more than two pairs.
  if torch.gt(torch.len(padding), 4):
    if torch.gt(padding[-6], 0):
      _265 = torch.sub(out_shape0[4], padding[-6])
      i03 = torch.sub(_265, ops.prim.max(padding[-5], 0))
      i13 = torch.sub(out_shape0[4], ops.prim.max(padding[-5], 0))
      o13 = padding[-6]
      _266 = torch.slice(torch.slice(torch.slice(out), 1), 2)
      _267 = torch.slice(torch.slice(_266, 3), 4, i03, i13)
      _268 = torch.slice(torch.slice(torch.slice(out), 1), 2)
      _269 = torch.slice(torch.slice(_268, 3), 4, 0, o13)
      _270 = torch.copy_(_269, _267)
    else:
      pass
    if torch.gt(padding[-5], 0):
      i04 = ops.prim.max(padding[-6], 0)
      i14 = torch.add(ops.prim.max(padding[-6], 0), padding[-5])
      o01 = torch.sub(out_shape0[4], padding[-5])
      o14 = out_shape0[4]
      _271 = torch.slice(torch.slice(torch.slice(out), 1), 2)
      _272 = torch.slice(torch.slice(_271, 3), 4, i04, i14)
      _273 = torch.slice(torch.slice(torch.slice(out), 1), 2)
      _274 = torch.slice(torch.slice(_273, 3), 4, o01, o14)
      _275 = torch.copy_(_274, _272)
    else:
      pass
  else:
    pass
  return out
def softmax(input: Tensor,
    dim: Optional[int]=None,
    _stacklevel: int=3,
    dtype: Optional[int]=None) -> Tensor:
  # Softmax of `input` along one dimension.
  #
  # dim:         dimension to normalise over; when None it is chosen by
  #              _get_softmax_dim from the rank of `input` (which also
  #              emits a deprecation warning).
  # _stacklevel: forwarded to _get_softmax_dim for that warning.
  # dtype:       when not None, passed through to torch.softmax as its
  #              third argument.
  if torch.__isnot__(dim, None):
    axis = unchecked_cast(int, dim)
  else:
    axis = __torch__.torch.nn.functional._get_softmax_dim("softmax", torch.dim(input), _stacklevel, )
  if torch.__isnot__(dtype, None):
    result = torch.softmax(input, axis, unchecked_cast(int, dtype))
  else:
    result = torch.softmax(input, axis)
  return result
def _get_softmax_dim(name: str,
    ndim: int,
    stacklevel: int) -> int:
  # Pick the implicit dimension for a softmax-style op called without `dim`.
  #
  # Always emits the deprecation warning (formatted with `name`, raised at
  # `stacklevel`), then returns 0 when ndim is 0, 1 or 3, and 1 otherwise.
  deprecation_msg = "Implicit dimension choice for {} has been deprecated. Change the call to include dim=X as an argument."
  torch.warn(torch.format(deprecation_msg, name), stacklevel)
  if torch.eq(ndim, 0) or torch.eq(ndim, 1) or torch.eq(ndim, 3):
    axis = 0
  else:
    axis = 1
  return axis
