# MultiheadAttention in what appears to be TorchScript-serialized form (note the
# `__torch__` qualified names, `ops.prim.*` calls and numbered temporaries).
# The code is in single-assignment style: every branch ends by binding the same
# set of "merge" variables, which is why many tuple assignments look redundant.
# NOTE(review): this looks machine-generated; prefer re-exporting from the
# original model over editing it by hand.
class MultiheadAttention(Module):
  __parameters__ = ["in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight", "in_proj_bias", ]
  __buffers__ = []
  in_proj_weight : Tensor
  # The separate q/k/v projection weights are typed NoneType in this module.
  q_proj_weight : NoneType
  k_proj_weight : NoneType
  v_proj_weight : NoneType
  in_proj_bias : Tensor
  training : bool
  _is_full_backward_hook : NoneType
  embed_dim : int
  kdim : int
  vdim : int
  _qkv_same_embed_dim : bool
  num_heads : int
  dropout : float
  head_dim : int
  bias_k : Optional[Tensor]
  bias_v : Optional[Tensor]
  add_zero_attn : bool
  out_proj : __torch__.torch.nn.modules.linear.NonDynamicallyQuantizableLinear
  # Compile-time constant; forward() never reads it and transposes dims 0/1
  # whenever the query is 3-D.
  batch_first : Final[bool] = True
  def forward(self: __torch__.torch.nn.modules.activation.MultiheadAttention,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_padding_mask: Optional[Tensor]=None,
    need_weights: bool=True,
    attn_mask: Optional[Tensor]=None,
    average_attn_weights: bool=True,
    is_causal: bool=False) -> Tuple[Tensor, Optional[Tensor]]:
    # Compute attention over (query, key, value).
    #
    # Two execution routes exist:
    #   * a "fast path" calling torch._native_multi_head_attention, taken only
    #     when no disqualifying reason is found (the reason string stays empty);
    #   * otherwise the general torch.nn.functional.multi_head_attention_forward.
    #
    # Returns a 2-tuple (Tensor, Optional[Tensor]), per the signature.
    # Raises (via ops.prim.RaiseException):
    #   * when attn_mask is given and is_causal is True;
    #   * when any of query/key/value is nested and the fast path was not taken.
    #
    # --- Message templates and function aliases hoisted to locals ---
    _0 = "AssertionError: Only allow causal mask or attn_mask"
    _1 = __torch__.torch.nn.functional._none_or_dtype
    _2 = __torch__.torch.nn.functional._canonical_mask
    _3 = "input not batched; expected query.dim() of 3 but got {}"
    why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
    _4 = "dtypes of query ({}) and self.in_proj_bias ({}) don\'t match"
    _5 = "dtypes of query ({}) and self.in_proj_weight ({}) don\'t match"
    why_not_fast_path0 = "supplying both src_key_padding_mask and src_mask at the same time                              is not supported with NestedTensor input"
    why_not_fast_path1 = "some Tensor argument is neither CUDA nor CPU"
    why_not_fast_path2 = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
    _6 = "The fast path was not hit because {}"
    _7 = "MultiheadAttention does not support NestedTensor outside of its fast path. "
    _8 = __torch__.torch.nn.functional.multi_head_attention_forward
    # Typed placeholders for branches whose value is never consumed: _9 is only
    # paired with a False "fast path taken" flag, _10 only follows a raise.
    _9 = uninitialized(Tuple[Tensor, Tensor])
    _10 = uninitialized(Optional[Tensor])
    # --- Argument validation: attn_mask together with is_causal is rejected ---
    if torch.__isnot__(attn_mask, None):
      _11, attn_mask0 = is_causal, unchecked_cast(Tensor, attn_mask)
    else:
      _11, attn_mask0 = False, attn_mask
    if _11:
      ops.prim.RaiseException(_0)
      attn_mask1 : Optional[Tensor] = _10
    else:
      attn_mask1 = attn_mask0
    # A 3-D query is treated as batched.
    is_batched = torch.eq(torch.dim(query), 3)
    # Normalise key_padding_mask through _canonical_mask, passing attn_mask's
    # dtype (or None) as the "other" mask type and query's dtype as the target.
    key_padding_mask0 = _2(key_padding_mask, "key_padding_mask", _1(attn_mask1, ), "attn_mask", ops.prim.dtype(query), True, )
    # --- Fast-path eligibility cascade ---
    # Each test that fails records a human-readable reason. The final reason
    # lands in why_not_fast_path3; an empty string means "still eligible".
    # key_padding_mask1 is the padding mask threaded through the same cascade.
    if torch.__not__(is_batched):
      why_not_fast_path4 = torch.format(_3, torch.dim(query))
      why_not_fast_path3, key_padding_mask1 = why_not_fast_path4, key_padding_mask0
    else:
      # Self-attention requires query, key and value to be the same object
      # (identity comparison, not value equality).
      if torch.__isnot__(query, key):
        _12 = True
      else:
        _12 = torch.__isnot__(key, value)
      if _12:
        why_not_fast_path5, key_padding_mask2 = why_not_fast_path, key_padding_mask0
      else:
        # query dtype must match in_proj_bias dtype ...
        _13 = ops.prim.dtype(query)
        in_proj_bias = self.in_proj_bias
        _14 = torch.ne(_13, ops.prim.dtype(in_proj_bias))
        if _14:
          _15 = ops.prim.dtype(query)
          in_proj_bias0 = self.in_proj_bias
          why_not_fast_path7 = torch.format(_4, _15, ops.prim.dtype(in_proj_bias0))
          why_not_fast_path6, key_padding_mask3 = why_not_fast_path7, key_padding_mask0
        else:
          # ... and in_proj_weight dtype.
          _16 = ops.prim.dtype(query)
          in_proj_weight = self.in_proj_weight
          _17 = ops.prim.dtype(in_proj_weight)
          if torch.ne(_16, _17):
            _18 = ops.prim.dtype(query)
            in_proj_weight0 = self.in_proj_weight
            _19 = ops.prim.dtype(in_proj_weight0)
            why_not_fast_path8, key_padding_mask4 = torch.format(_5, _18, _19), key_padding_mask0
          else:
            # Module-configuration checks: training mode, bias_k / bias_v,
            # add_zero_attn and _qkv_same_embed_dim each disqualify the fast path.
            training = self.training
            if training:
              why_not_fast_path9, key_padding_mask5 = "training is enabled", key_padding_mask0
            else:
              bias_k = self.bias_k
              _20 = torch.__isnot__(bias_k, None)
              if _20:
                why_not_fast_path10, key_padding_mask6 = "self.bias_k was not None", key_padding_mask0
              else:
                bias_v = self.bias_v
                _21 = torch.__isnot__(bias_v, None)
                if _21:
                  why_not_fast_path11, key_padding_mask7 = "self.bias_v was not None", key_padding_mask0
                else:
                  add_zero_attn = self.add_zero_attn
                  if add_zero_attn:
                    why_not_fast_path12, key_padding_mask8 = "add_zero_attn was enabled", key_padding_mask0
                  else:
                    _qkv_same_embed_dim = self._qkv_same_embed_dim
                    _22 = torch.__not__(_qkv_same_embed_dim)
                    if _22:
                      why_not_fast_path13, key_padding_mask9 = "_qkv_same_embed_dim was not True", key_padding_mask0
                    else:
                      # _24: query is nested AND (a padding mask or an
                      # attention mask was supplied).
                      _23 = ops.prim.is_nested(query)
                      if _23:
                        _25 = torch.__isnot__(key_padding_mask0, None)
                        if _25:
                          key_padding_mask12 = unchecked_cast(Tensor, key_padding_mask0)
                          _26, key_padding_mask11 = True, key_padding_mask12
                        else:
                          _27 = torch.__isnot__(attn_mask1, None)
                          _26, key_padding_mask11 = _27, key_padding_mask0
                        _24, key_padding_mask10 = _26, key_padding_mask11
                      else:
                        _24, key_padding_mask10 = False, key_padding_mask0
                      if _24:
                        why_not_fast_path14 = why_not_fast_path0
                      else:
                        # Last static check: autocast. An empty string here
                        # means every check above passed.
                        _28 = torch.is_autocast_enabled()
                        if _28:
                          why_not_fast_path15 = "autocast is enabled"
                        else:
                          why_not_fast_path15 = ""
                        why_not_fast_path14 = why_not_fast_path15
                      # The assignments below only propagate the innermost
                      # result outward, one nesting level at a time.
                      why_not_fast_path13, key_padding_mask9 = why_not_fast_path14, key_padding_mask10
                    why_not_fast_path12, key_padding_mask8 = why_not_fast_path13, key_padding_mask9
                  why_not_fast_path11, key_padding_mask7 = why_not_fast_path12, key_padding_mask8
                why_not_fast_path10, key_padding_mask6 = why_not_fast_path11, key_padding_mask7
              why_not_fast_path9, key_padding_mask5 = why_not_fast_path10, key_padding_mask6
            why_not_fast_path8, key_padding_mask4 = why_not_fast_path9, key_padding_mask5
          why_not_fast_path6, key_padding_mask3 = why_not_fast_path8, key_padding_mask4
        why_not_fast_path5, key_padding_mask2 = why_not_fast_path6, key_padding_mask3
      why_not_fast_path3, key_padding_mask1 = why_not_fast_path5, key_padding_mask2
    # _29 is True when a disqualifying reason has already been recorded.
    _29 = torch.gt(torch.len(why_not_fast_path3), 0)
    if torch.__not__(_29):
      # --- Runtime checks on the tensors the fast path would consume ---
      in_proj_weight1 = self.in_proj_weight
      in_proj_bias1 = self.in_proj_bias
      out_proj = self.out_proj
      weight = out_proj.weight
      out_proj0 = self.out_proj
      bias = out_proj0.bias
      # _32 collects one bool per tensor (query, key, value, in_proj_weight,
      # in_proj_bias, out_proj.weight, out_proj.bias): True when the tensor is
      # on CUDA or its device string contains "cpu".
      _32 = annotate(List[bool], [])
      if ops.prim.is_cuda(query):
        _33 = True
      else:
        _34 = str(ops.prim.device(query))
        _35 = torch.find(_34, "cpu", 0, torch.len(_34))
        _33 = torch.ne(_35, -1)
      _36 = torch.append(_32, _33)
      if ops.prim.is_cuda(key):
        _37 = True
      else:
        _38 = str(ops.prim.device(key))
        _39 = torch.find(_38, "cpu", 0, torch.len(_38))
        _37 = torch.ne(_39, -1)
      _40 = torch.append(_32, _37)
      if ops.prim.is_cuda(value):
        _41 = True
      else:
        _42 = str(ops.prim.device(value))
        _43 = torch.find(_42, "cpu", 0, torch.len(_42))
        _41 = torch.ne(_43, -1)
      _44 = torch.append(_32, _41)
      if ops.prim.is_cuda(in_proj_weight1):
        _45 = True
      else:
        _46 = ops.prim.device(in_proj_weight1)
        _47 = str(_46)
        _48 = torch.find(_47, "cpu", 0, torch.len(_47))
        _45 = torch.ne(_48, -1)
      _49 = torch.append(_32, _45)
      if ops.prim.is_cuda(in_proj_bias1):
        _50 = True
      else:
        _51 = str(ops.prim.device(in_proj_bias1))
        _52 = torch.find(_51, "cpu", 0, torch.len(_51))
        _50 = torch.ne(_52, -1)
      _53 = torch.append(_32, _50)
      if ops.prim.is_cuda(weight):
        _54 = True
      else:
        _55 = str(ops.prim.device(weight))
        _56 = torch.find(_55, "cpu", 0, torch.len(_55))
        _54 = torch.ne(_56, -1)
      _57 = torch.append(_32, _54)
      if ops.prim.is_cuda(bias):
        _58 = True
      else:
        _59 = str(ops.prim.device(bias))
        _60 = torch.find(_59, "cpu", 0, torch.len(_59))
        _58 = torch.ne(_60, -1)
      _61 = torch.append(_32, _58)
      if torch.__not__(torch.all(_32)):
        why_not_fast_path17 = why_not_fast_path1
      else:
        # _62: grad mode is on AND at least one of the same seven tensors
        # requires grad.
        if torch.is_grad_enabled():
          _63 = annotate(List[bool], [])
          _64 = torch.append(_63, ops.prim.requires_grad(query))
          _65 = torch.append(_63, ops.prim.requires_grad(key))
          _66 = torch.append(_63, ops.prim.requires_grad(value))
          _67 = ops.prim.requires_grad(in_proj_weight1)
          _68 = torch.append(_63, _67)
          _69 = ops.prim.requires_grad(in_proj_bias1)
          _70 = torch.append(_63, _69)
          _71 = ops.prim.requires_grad(weight)
          _72 = torch.append(_63, _71)
          _73 = torch.append(_63, ops.prim.requires_grad(bias))
          _62 = torch.any(_63)
        else:
          _62 = False
        if _62:
          why_not_fast_path18 = why_not_fast_path2
        else:
          # Still the (empty) reason carried in from the cascade above.
          why_not_fast_path18 = why_not_fast_path3
        why_not_fast_path17 = why_not_fast_path18
      _74 = torch.gt(torch.len(why_not_fast_path17), 0)
      if torch.__not__(_74):
        # --- Fast path: merge the masks and call the native fused kernel ---
        _77 = (self).merge_masks(attn_mask1, key_padding_mask1, query, )
        merged_mask, mask_type, = _77
        embed_dim = self.embed_dim
        num_heads = self.num_heads
        in_proj_weight2 = self.in_proj_weight
        in_proj_bias2 = self.in_proj_bias
        out_proj1 = self.out_proj
        weight0 = out_proj1.weight
        out_proj2 = self.out_proj
        bias0 = out_proj2.bias
        _78, _79 = torch._native_multi_head_attention(query, key, value, embed_dim, num_heads, in_proj_weight2, in_proj_bias2, weight0, bias0, merged_mask, need_weights, average_attn_weights, mask_type)
        # _75 flags "fast path produced a result"; _76 carries that result.
        _75, _76 = True, (_78, _79)
      else:
        _75, _76 = False, _9
      _30, _31, why_not_fast_path16 = _75, _76, why_not_fast_path17
    else:
      _30, _31, why_not_fast_path16 = False, _9, why_not_fast_path3
    if _30:
      # Fast path was taken: return its result unchanged.
      _80 = _31
    else:
      # --- General path ---
      # Nested tensors are only handled by the fast path, so reaching this
      # point with any nested input is an error that reports the reason.
      if ops.prim.is_nested(query):
        _81 = True
      else:
        _81 = ops.prim.is_nested(key)
      if _81:
        any_nested = True
      else:
        any_nested = ops.prim.is_nested(value)
      if torch.__not__(any_nested):
        pass
      else:
        _82 = torch.format(_6, why_not_fast_path16)
        _83 = torch.add("AssertionError: ", torch.add(_7, _82))
        ops.prim.RaiseException(_83)
      # For batched input, swap dims 0 and 1 of query/key/value before the
      # functional call. Inputs that are the same object share one transposed
      # tensor rather than being transposed separately.
      if is_batched:
        if torch.__is__(key, value):
          if torch.__is__(query, key):
            _84 = torch.transpose(query, 1, 0)
            query2, key2, value2 = _84, _84, _84
          else:
            _85 = annotate(List[Tensor], [])
            _86 = torch.transpose(query, 1, 0)
            _87 = torch.append(_85, _86)
            _88 = torch.append(_85, torch.transpose(key, 1, 0))
            query3, key3, = _85
            # value is the same object as key here, so reuse key's transpose.
            query2, key2, value2 = query3, key3, key3
          query1, key1, value1 = query2, key2, value2
        else:
          _89 = annotate(List[Tensor], [])
          _90 = torch.append(_89, torch.transpose(query, 1, 0))
          _91 = torch.append(_89, torch.transpose(key, 1, 0))
          _92 = torch.append(_89, torch.transpose(value, 1, 0))
          query4, key4, value3, = _89
          query1, key1, value1 = query4, key4, value3
        query0, key0, value0 = query1, key1, value1
      else:
        query0, key0, value0 = query, key, value
      # Two call variants of multi_head_attention_forward. They differ in the
      # positional arguments after attn_mask1: (True, q/k/v_proj_weight) when
      # _qkv_same_embed_dim is False, (False, None, None, None) otherwise.
      # NOTE(review): that flag is presumably `use_separate_proj_weight` --
      # confirm against the functional's signature.
      _qkv_same_embed_dim0 = self._qkv_same_embed_dim
      _93 = torch.__not__(_qkv_same_embed_dim0)
      if _93:
        embed_dim0 = self.embed_dim
        num_heads0 = self.num_heads
        in_proj_weight3 = self.in_proj_weight
        in_proj_bias3 = self.in_proj_bias
        bias_k0 = self.bias_k
        bias_v0 = self.bias_v
        add_zero_attn0 = self.add_zero_attn
        dropout = self.dropout
        out_proj3 = self.out_proj
        weight1 = out_proj3.weight
        out_proj4 = self.out_proj
        bias1 = out_proj4.bias
        training0 = self.training
        q_proj_weight = self.q_proj_weight
        k_proj_weight = self.k_proj_weight
        v_proj_weight = self.v_proj_weight
        _94 = _8(query0, key0, value0, embed_dim0, num_heads0, in_proj_weight3, in_proj_bias3, bias_k0, bias_v0, add_zero_attn0, dropout, weight1, bias1, training0, key_padding_mask1, need_weights, attn_mask1, True, q_proj_weight, k_proj_weight, v_proj_weight, None, None, average_attn_weights, is_causal, )
        attn_output0, attn_output_weights0, = _94
        attn_output, attn_output_weights = attn_output0, attn_output_weights0
      else:
        embed_dim1 = self.embed_dim
        num_heads1 = self.num_heads
        in_proj_weight4 = self.in_proj_weight
        in_proj_bias4 = self.in_proj_bias
        bias_k1 = self.bias_k
        bias_v1 = self.bias_v
        add_zero_attn1 = self.add_zero_attn
        dropout0 = self.dropout
        out_proj5 = self.out_proj
        weight2 = out_proj5.weight
        out_proj6 = self.out_proj
        bias2 = out_proj6.bias
        training1 = self.training
        _95 = _8(query0, key0, value0, embed_dim1, num_heads1, in_proj_weight4, in_proj_bias4, bias_k1, bias_v1, add_zero_attn1, dropout0, weight2, bias2, training1, key_padding_mask1, need_weights, attn_mask1, False, None, None, None, None, None, average_attn_weights, is_causal, )
        attn_output1, attn_output_weights1, = _95
        attn_output, attn_output_weights = attn_output1, attn_output_weights1
      # Undo the earlier dim swap on the attention output for batched input.
      if is_batched:
        _97 = torch.transpose(attn_output, 1, 0)
        _96 = (_97, attn_output_weights)
      else:
        _98 = (attn_output, attn_output_weights)
        _96 = _98
      _80 = _96
    return _80
  def merge_masks(self: __torch__.torch.nn.modules.activation.MultiheadAttention,
    attn_mask: Optional[Tensor],
    key_padding_mask: Optional[Tensor],
    query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
    # Combine attn_mask and key_padding_mask into one mask plus a type code.
    #
    # Returns (merged_mask, mask_type) where mask_type is:
    #   None -> neither mask given (merged_mask is None)
    #   0    -> only attn_mask given
    #   1    -> only key_padding_mask given
    #   2    -> both given; merged_mask is the sum of the two expanded masks
    _99 = __torch__.torch.nn.functional._none_or_dtype
    _100 = __torch__.torch.nn.functional._canonical_mask
    # Normalise attn_mask via _canonical_mask, with key_padding_mask's dtype
    # (or None) as the "other" mask type and query's dtype as the target.
    attn_mask2 = _100(attn_mask, "attn_mask", _99(key_padding_mask, ), "key_padding_mask", ops.prim.dtype(query), False, )
    if torch.__isnot__(attn_mask2, None):
      attn_mask4 = unchecked_cast(Tensor, attn_mask2)
      merged_mask, mask_type, attn_mask3 = attn_mask4, 0, attn_mask4
    else:
      merged_mask, mask_type, attn_mask3 = None, None, attn_mask2
    # A key padding mask, when present, overrides the attn_mask-only result.
    _101 = torch.__isnot__(key_padding_mask, None)
    if _101:
      key_padding_mask14 = unchecked_cast(Tensor, key_padding_mask)
      key_padding_mask13, merged_mask0, mask_type0 = key_padding_mask14, key_padding_mask14, 1
    else:
      key_padding_mask13, merged_mask0, mask_type0 = key_padding_mask, merged_mask, mask_type
    # _102: both masks are present.
    if torch.__isnot__(attn_mask3, None):
      attn_mask6 = unchecked_cast(Tensor, attn_mask3)
      _103 = torch.__isnot__(key_padding_mask13, None)
      _102, attn_mask5 = _103, attn_mask6
    else:
      _102, attn_mask5 = False, attn_mask3
    if _102:
      attn_mask7 = unchecked_cast(Tensor, attn_mask5)
      key_padding_mask15 = unchecked_cast(Tensor, key_padding_mask13)
      # The three-way unpack requires query to have exactly three dimensions.
      batch_size, seq_len, _104, = torch.size(query)
      # Padding mask: view as [batch, 1, 1, seq] and expand over heads.
      _105 = torch.view(key_padding_mask15, [batch_size, 1, 1, seq_len])
      num_heads = self.num_heads
      key_padding_mask_expanded = torch.expand(_105, [-1, num_heads, -1, -1])
      # Attention mask: view as [1, 1, seq, seq] and expand over batch and heads.
      _106 = torch.view(attn_mask7, [1, 1, seq_len, seq_len])
      num_heads2 = self.num_heads
      attn_mask_expanded = torch.expand(_106, [batch_size, num_heads2, -1, -1])
      merged_mask2 = torch.add(attn_mask_expanded, key_padding_mask_expanded)
      merged_mask1, mask_type1 = merged_mask2, 2
    else:
      merged_mask1, mask_type1 = merged_mask0, mask_type0
    return (merged_mask1, mask_type1)
# ReLU activation module: forwards its input to the functional `relu` with the
# in-place flag fixed to False (matching the `inplace` constant below).
class ReLU(Module):
  __parameters__ = []
  __buffers__ = []
  training : bool
  _is_full_backward_hook : NoneType
  inplace : Final[bool] = False
  def forward(self: __torch__.torch.nn.modules.activation.ReLU,
    input: Tensor) -> Tensor:
    # Bind the functional first, then apply it out-of-place.
    relu_fn = __torch__.torch.nn.functional.relu
    return relu_fn(input, False, )
# LeakyReLU activation module: forwards its input to the functional
# `leaky_relu` with the slope and in-place flag written as literals that match
# the `negative_slope` and `inplace` constants below.
class LeakyReLU(Module):
  __parameters__ = []
  __buffers__ = []
  training : bool
  _is_full_backward_hook : NoneType
  negative_slope : Final[float] = 0.10000000000000001
  inplace : Final[bool] = False
  def forward(self: __torch__.torch.nn.modules.activation.LeakyReLU,
    input: Tensor) -> Tensor:
    # Out-of-place call; slope literal kept identical to the class constant.
    return __torch__.torch.nn.functional.leaky_relu(input, 0.10000000000000001, False, )
