# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import math
from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union

import _torch as torch
import _torch.nn as nn
import _torch.nn.functional as F
from _torch import Tensor
from pyre_extensions import Multiply

N_HEADS = TypeVar("N_HEADS", bound=int)
DIM = TypeVar("DIM", bound=int)
BS = TypeVar("BS", bound=int)
QLEN = TypeVar("QLEN", bound=int)
KLEN = TypeVar("KLEN", bound=int)
DIM_PER_HEAD = TypeVar("DIM_PER_HEAD", bound=int)

A = TypeVar("A", bound=int)
B = TypeVar("B", bound=int)


def mult(a: A, b: B) -> Multiply[A, B]:
    ...
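
# Note: `mult` is a type-level helper with no runtime body. Given
# n_heads: N_HEADS and dim_per_head: DIM_PER_HEAD, the expression
# mult(n_heads, dim_per_head) is typed as Multiply[N_HEADS, DIM_PER_HEAD],
# which lets the reshape in `unshape` below carry the product of the two
# dimensions in its type.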


class MultiHeadAttention(Generic[N_HEADS, DIM]):

    NEW_ID = itertools.count()

    layer_id: int
    dim: DIM
    n_heads: N_HEADS
    q_lin: nn.Linear[DIM, DIM]
    k_lin: nn.Linear[DIM, DIM]
    v_lin: nn.Linear[DIM, DIM]
    out_lin: nn.Linear[DIM, DIM]

    def __init__(self, n_heads: N_HEADS, dim: DIM, dropout: float):
        super().__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
        self.dim = dim
        self.n_heads = n_heads
        self.dropout = dropout
        # assert self.dim % self.n_heads == 0
        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)

    def forward(
        self,
        input: Tensor[BS, QLEN, DIM],
        mask: Union[Tensor[BS, KLEN], Tensor[BS, KLEN, KLEN]],
        kv: Optional[Tensor[BS, KLEN, DIM]],
        cache: Optional[
            Dict[
                int,
                Tuple[
                    Tensor[BS, N_HEADS, Any, DIM_PER_HEAD],
                    Tensor[BS, N_HEADS, Any, DIM_PER_HEAD],
                ],
            ]
        ],
        cache_slen: int,
        dim_per_head: DIM_PER_HEAD,
    ) -> Tensor[BS, QLEN, DIM]:
        """
        Self-attention (if kv is None) or attention
        over source sentence (provided by kv).
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        bs, qlen, dim = input.size()
        if kv is None:
            klen = qlen if cache is None else cache_slen + qlen
        else:
            klen = kv.size(1)
        # dim_per_head = dim // self.n_heads  #> dim_per_head cannot be a literal
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x: Tensor[BS, Any, DIM]) -> Tensor[BS, N_HEADS, Any, DIM_PER_HEAD]:
            """projection"""
            # Variables defined outside of the body of the function are not typed,
            # so re-declare the ones the type checker needs here.
            bs: BS
            dim_per_head: DIM_PER_HEAD
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(
            x: Tensor[BS, N_HEADS, QLEN, DIM_PER_HEAD]
        ) -> Tensor[BS, QLEN, Any]:
            """compute context"""
            return (
                x.transpose(1, 2)
                .contiguous()
                .view(bs, -1, mult(self.n_heads, dim_per_head))
            )
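
        # Scaled dot-product attention: scores = Q @ K^T / sqrt(dim_per_head);
        # masked positions are set to -inf before the softmax; the context is
        # softmax(scores) @ V, after which the heads are merged and projected
        # by out_lin.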
        q = shape(self.q_lin(input))  # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))  # (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            k = v = kv
            k = shape(self.k_lin(k))  # (bs, n_heads, klen, dim_per_head)
            v = shape(self.v_lin(v))  # (bs, n_heads, klen, dim_per_head)

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    k_, v_ = cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
                else:
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
        scores: Tensor[BS, N_HEADS, QLEN, KLEN] = torch.matmul(
            q, k.transpose(2, 3)
        )  # (bs, n_heads, qlen, klen)
        mask2 = (
            (mask == 0).view(mask_reshape).expand_as(scores)
        )  # (bs, n_heads, qlen, klen)
        scores2 = scores.masked_fill(mask2, -float("inf"))  # (bs, n_heads, qlen, klen)
        weights = F.softmax(scores2.float(), dim=-1).type_as(
            scores
        )  # (bs, n_heads, qlen, klen)
        weights = F.dropout(
            weights, p=self.dropout, training=True
        )  # True stands in for self.training, which needs an nn.Module base; (bs, n_heads, qlen, klen)
        context: Tensor[BS, N_HEADS, QLEN, DIM_PER_HEAD] = torch.matmul(
            weights, v
        )  # (bs, n_heads, qlen, dim_per_head)
        context2: Tensor[BS, QLEN, DIM] = unshape(context)  # (bs, qlen, dim)
        return self.out_lin(context2)


# Example instantiation binding N_HEADS=20 and DIM=4 for the type checker
# (note: dim is not divisible by n_heads here, so this is a typing example,
# not a usable attention configuration).
a = MultiHeadAttention(20, 4, 0.5)
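
# A minimal usage sketch (hypothetical; assumes real torch tensors in place of
# the _torch stubs, and illustrative shapes bs=2, qlen=5, dim=20, n_heads=4):
#
#     attn = MultiHeadAttention(4, 20, 0.1)
#     x = torch.rand(2, 5, 20)       # (bs, qlen, dim)
#     mask = torch.ones(2, 5)        # (bs, klen), non-causal
#     out = attn.forward(
#         x, mask, kv=None, cache=None, cache_slen=0, dim_per_head=5
#     )
#     # out is (bs, qlen, dim) == (2, 5, 20)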