oweller2 committed · Commit e0229bb · Parent(s): 1f59624

added updated code

Files changed:
- __init__.py            +3 -3
- attention.py           +48 -38
- config.json            +1 -1
- configuration_bert.py  +3 -3
- modeling_flexbert.py   +25 -41
__init__.py CHANGED
@@ -33,13 +33,14 @@ from .modeling_flexbert import (
     FlexBertForMaskedLM,
     FlexBertForSequenceClassification,
     FlexBertForMultipleChoice,
-
+    FlexBertForCausalLM,
 )
 from .bert_padding import(
     IndexFirstAxis,
     IndexPutFirstAxis
 )
 
+
 __all__ = [
     "BertAlibiEmbeddings",
     "BertAlibiEncoder",
@@ -69,6 +70,5 @@ __all__ = [
     "FlexBertForMaskedLM",
     "FlexBertForSequenceClassification",
     "FlexBertForMultipleChoice",
-    "
-    "IndexPutFirstAxis"
+    "FlexBertForCausalLM"
 ]
attention.py CHANGED
@@ -74,7 +74,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
                 f"heads ({config.num_attention_heads})"
             )
 
-        self.
+        self.is_causal = config.causal_mask
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -145,7 +145,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     alibi_slopes=slopes,
-
+                    causal=self.is_causal
                 )
                 attention = attention.to(orig_dtype)  # type: ignore
             else:
@@ -156,10 +156,11 @@ class BertAlibiUnpadSelfAttention(nn.Module):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     alibi_slopes=slopes,
-
+                    causal = self.is_causal
                 )
         else:
-            assert not self.
+            assert not self.is_causal, f"causal mask not implemented here yet"
+            assert False
             qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
             unpad_bs, *_ = qkv.shape
             qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)
@@ -236,6 +237,7 @@ class BertAlibiUnpadAttention(nn.Module):
             slopes: None or (batch, heads) or (heads,)
         """
         assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}"
+        assert False
         self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
         if subset_idx is not None:
             return self.output(
@@ -293,7 +295,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
                 f"heads ({config.num_attention_heads})"
            )
 
-        self.
+        self.is_causal = config.causal_mask
         self.num_attention_heads = config.num_attention_heads
         self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -402,7 +404,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
                 attn = attn.to(orig_dtype)  # type: ignore
             else:
@@ -413,11 +415,12 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
                 attn = attn.view(bs, dim)
         else:
-            assert not self.
+            assert not self.is_causal, f"causal mask not implemented here yet"
+            assert False
             qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
             unpad_bs, seqlen, _ = qkv.shape
 
@@ -456,7 +459,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
                 f"heads ({config.num_attention_heads})"
             )
 
-        self.
+        self.is_causal = config.causal_mask
         self.num_attention_heads = config.num_attention_heads
         self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
         self.hidden_size = config.hidden_size
@@ -556,7 +559,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
                 attn = attn.to(orig_dtype)  # type: ignore
             else:
@@ -567,11 +570,12 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
                 attn = attn.view(bs, dim)
         else:
-            assert not self.
+            assert not self.is_causal, f"causal mask not implemented here yet"
+            assert False
             qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
             unpad_bs, seqlen, _ = qkv.shape
 
@@ -610,7 +614,7 @@ class FlexBertPaddedAttention(FlexBertAttentionBase):
                 f"heads ({config.num_attention_heads})"
             )
 
-        self.
+        self.is_causal = config.causal_mask
         self.num_attention_heads = config.num_attention_heads
         self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -695,7 +699,7 @@ class FlexBertPaddedAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
                 attn = attn.to(orig_dtype)  # type: ignore
             else:
@@ -704,10 +708,11 @@ class FlexBertPaddedAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
         else:
-            assert not self.
+            assert not self.is_causal, f"causal mask not implemented here yet"
+            assert False
             qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
 
             q, k, v = qkv.transpose(3, 1).unbind(dim=2)
@@ -743,7 +748,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
                 f"heads ({config.num_attention_heads})"
             )
 
-        self.
+        self.is_causal = config.causal_mask
         self.num_attention_heads = config.num_attention_heads
         self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -882,7 +887,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
                     max_seqlen_q=max_seqlen,
                     max_seqlen_k=max_seqlen,
                     deterministic=self.deterministic_fa2,
-                    causal=self.
+                    causal=self.is_causal,
                 )
                 attn = attn.to(orig_dtype)  # type: ignore
             else:
@@ -896,7 +901,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
                     max_seqlen_q=max_seqlen,
                     max_seqlen_k=max_seqlen,
                     deterministic=self.deterministic_fa2,
-                    causal=self.
+                    causal=self.is_causal,
                 )
                 attn = attn.view(bs, dim)
         elif self.use_fa2:
@@ -914,7 +919,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-                    causal=self.
+                    causal=self.is_causal,
                 )
                 attn = attn.to(orig_dtype)  # type: ignore
             else:
@@ -925,11 +930,12 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-                    causal=self.
+                    causal=self.is_causal,
                 )
                 attn = attn.view(bs, dim)
         else:
-            assert not self.
+            assert not self.is_causal, f"causal mask not implemented here yet"
+            assert False
             qkv = bert_padding.pad_input(
                 qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
             )  # batch, max_seqlen, thd
@@ -969,7 +975,7 @@ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
                 f"heads ({config.num_attention_heads})"
             )
 
-        self.
+        self.is_causal = config.causal_mask
         self.num_attention_heads = config.num_attention_heads
         self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -1080,7 +1086,7 @@ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal,
                 )
                 attn = attn.to(orig_dtype)  # type: ignore
             else:
@@ -1089,10 +1095,11 @@ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
         else:
-            assert not self.
+            assert not self.is_causal, f"causal mask not implemented here yet"
+            assert False
             qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
             q, k, v = qkv.transpose(3, 1).unbind(dim=2)
             attn = F.scaled_dot_product_attention(
@@ -1127,7 +1134,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
                 f"heads ({config.num_attention_heads})"
             )
 
-        self.
+        self.is_causal = config.causal_mask
         self.num_attention_heads = config.num_attention_heads
         self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
         self.hidden_size = config.hidden_size
@@ -1253,7 +1260,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal,
                 )
                 attn = attn.to(orig_dtype)  # type: ignore
             else:
@@ -1264,11 +1271,12 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal,
                 )
                 attn = attn.view(bs, dim)
         else:
-            assert not self.
+            assert not self.is_causal, f"causal mask not implemented here yet"
+            assert False
             qkv = bert_padding.pad_input(
                 qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
             )  # batch, max_seqlen, thd
@@ -1308,7 +1316,7 @@ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
                 f"heads ({config.num_attention_heads})"
             )
 
-        self.
+        self.is_causal = config.causal_mask
         self.num_attention_heads = config.num_attention_heads
         self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
         self.hidden_size = config.hidden_size
@@ -1413,7 +1421,7 @@ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
                 attn = attn.to(orig_dtype)  # type: ignore
             else:
@@ -1422,10 +1430,11 @@ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
         else:
-            assert not self.
+            assert not self.is_causal, f"causal mask not implemented here yet"
+            assert False
             qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
             q, k, v = qkv.transpose(3, 1).unbind(dim=2)
             attn = F.scaled_dot_product_attention(
@@ -1460,7 +1469,7 @@ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
                 f"heads ({config.num_attention_heads})"
             )
 
-        self.
+        self.is_causal = config.causal_mask
         self.num_attention_heads = config.num_attention_heads
         self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
         self.hidden_size = config.hidden_size
@@ -1537,7 +1546,7 @@ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
                 attn = attn.to(orig_dtype)  # type: ignore
             else:
@@ -1546,10 +1555,11 @@ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
                     dropout_p=self.p_dropout,
                     deterministic=self.deterministic_fa2,
                     window_size=self.sliding_window,
-
+                    causal=self.is_causal
                 )
         else:
-            assert not self.
+            assert not self.is_causal, f"causal attention mask not yet implemented here"
+            assert False
             qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
             q, k, v = qkv.transpose(3, 1).unbind(dim=2)  # b h s d
             attn = F.scaled_dot_product_attention(
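For context on the new causal flag threaded through the attention modules above: a `causal=True` argument restricts each token to attending only to itself and earlier positions, which is what the commit's FlashAttention calls now do when `config.causal_mask` is set. The non-FlashAttention fallback branches instead assert `not self.is_causal` (and `assert False`), i.e. they would need an explicit mask before they could support causal attention. The following is a minimal, self-contained PyTorch sketch of that equivalence; it is illustrative only, uses `scaled_dot_product_attention` rather than the FlashAttention kernels in attention.py, and is not code from this repo.

# Sketch: a causal flag is equivalent to an explicit lower-triangular mask.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, heads, seqlen, head_dim = 2, 4, 8, 16
q = torch.randn(bs, heads, seqlen, head_dim)
k = torch.randn(bs, heads, seqlen, head_dim)
v = torch.randn(bs, heads, seqlen, head_dim)

# Built-in causal handling.
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The same computation with an explicit mask: position i may only attend to j <= i.
causal_mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

assert torch.allclose(out_flag, out_mask, atol=1e-6)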
config.json CHANGED
@@ -88,4 +88,4 @@
     "use_sdpa_attn_mask": false,
     "vocab_size": 50368,
     "is_casual": true
-}
+}
configuration_bert.py CHANGED
@@ -97,7 +97,7 @@ class FlexBertConfig(TransformersBertConfig):
         pad_logits: bool = False,
         compile_model: bool = False,
         masked_prediction: bool = False,
-
+        causal_mask: bool = False,
         **kwargs,
     ):
         """
@@ -157,7 +157,7 @@ class FlexBertConfig(TransformersBertConfig):
             pad_logits (bool): Pad logits after the calculating the loss.
             compile_model (bool): Compile the subset of the model which can be compiled.
             masked_prediction (bool): Use only pass the masked tokens throught the final MLM layers
-
+            causal (bool): Use a causal mask, defaulting to false.
             **kwargs: Additional keyword arguments.
         """
         super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
@@ -215,7 +215,7 @@ class FlexBertConfig(TransformersBertConfig):
         self.pad_logits = pad_logits
         self.compile_model = compile_model
         self.masked_prediction = masked_prediction
-        self.
+        self.causal_mask = causal_mask
 
         if loss_kwargs.get("return_z_loss", False):
             if loss_function != "fa_cross_entropy":
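A short usage sketch of the new flag, assuming the files from this commit are importable on the Python path (the import path here is an assumption, not something defined by the commit):

# Hypothetical usage of the causal_mask flag added to FlexBertConfig in this commit.
from configuration_bert import FlexBertConfig  # assumption: module is on sys.path

config = FlexBertConfig(causal_mask=True)  # defaults to False when omitted
print(config.causal_mask)                  # attention modules read this as self.is_causal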
modeling_flexbert.py CHANGED
@@ -125,7 +125,6 @@ from .rotary import UnpaddedRotaryEmbedding
 
 logger = logging.getLogger(__name__)
 
-
 def _count_parameters(model: nn.Module, trainable: bool = True) -> int:
     if trainable:
         return sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -873,7 +872,7 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
 
     def _init_module_weights(self, module: nn.Module):
         """
-        Custom weight init of modules using .bert_layers.initialization.init_weights
+        Custom weight init of modules using src.bert_layers.initialization.init_weights
         Currently only supports init of embedding modules
         """
         assert isinstance(module, nn.Module)
@@ -1126,7 +1125,6 @@ class FlexBertForMaskedLM(FlexBertPreTrainedModel):
         # seqlen) dimensions are flattened
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
             batch_size, seq_len = input_ids.shape[:2]
             input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
@@ -1506,9 +1504,7 @@ class FlexBertForMultipleChoice(FlexBertPreTrainedModel):
         return params
 
 
-class
-    config_class = FlexBertConfig
-
+class FlexBertForCausalLM(FlexBertPreTrainedModel):
     """Bert Model transformer with a LM head.
 
     This head is just a standard LM head module. Used for causal language modeling tasks.
@@ -1538,23 +1534,14 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
         self._init_weights(reset_params=False)
 
     def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
-        # Handle the XOR condition
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
-
-
-        # Add basic initialization for common module types
-        if isinstance(module, (nn.Linear, nn.Embedding)):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if isinstance(module, nn.Linear) and module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+        if module:
+            self._init_module_weights(module)
         else:
             assert isinstance(reset_params, bool)
             self.bert._init_weights(reset_params=reset_params)
             self.lm_head._init_weights(reset_params=reset_params)
-
+
         if not self.config.tie_word_embeddings:
             init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
 
@@ -1644,7 +1631,6 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
         # seqlen) dimensions are flattened
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
             batch_size, seq_len = input_ids.shape[:2]
             input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
@@ -1664,29 +1650,28 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
             logits = self.compiled_lm_head(hidden_states)
         else:
             logits = self.lm_head(hidden_states)
-
+
         loss = None
         if labels is not None:
-            if
-                # Unpadded case: shift within each sequence using input_ids
-                # Initialize shifted labels from input_ids
+            if cu_seqlens is not None:
                 shift_labels = torch.full_like(input_ids, -100)
-
-
+                shift_labels[:-1] = input_ids[1:]
+
+                # Mask boundaries
                 for i in range(len(cu_seqlens) - 1):
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    boundary_pos = cu_seqlens[i+1] - 1
+                    shift_labels[boundary_pos] = -100
+
+                # Mask out PAD tokens
+                mask = (shift_labels == 50283)
+                shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
+
+
+                # print input_ids[(cu_seqlens[2]+1)-5:(cu_seqlens[2]+1)+5]
+                # print shift_labels[(cu_seqlens[2]+1)-5:(cu_seqlens[2]+1)+5]
+                # print input_ids[(cu_seqlens[-2]+1)-5:(cu_seqlens[-2]+1)+5]
+                # print shift_labels[(cu_seqlens[-2]+1)-5:(cu_seqlens[-2]+1)+5]
+                # breakpoint() # pkill -u oweller2 -f wandb
 
             else:
                 # Padded case: simple shift
@@ -1703,7 +1688,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
             )
 
         if self.pad_logits:
-            print(f"Padding logits: {logits.shape}")
+            # print(f"Padding logits: {logits.shape}")
             new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
             if len(new_logits.shape) == 2:
                 new_logits = new_logits.unsqueeze(0)
@@ -1714,7 +1699,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
                 attentions=None,
             )
         else:
-            print(f"Non-padding logits: {logits.shape}")
+            # print(f"Non-padding logits: {logits.shape}")
             if len(logits.shape) == 2:
                 logits = logits.unsqueeze(0)
             return CausalLMOutput(
@@ -1757,7 +1742,6 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
         params += _count_parameters(self.lm_head, trainable)
         return params
 
-FlexBertForCasualLM.register_for_auto_class("AutoModelForCausalLM")
 
 def init_model_from_pretrained(
     pretrained_model: FlexBertModel,
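To make the unpadded label-shifting logic added above easier to follow, here is a small self-contained sketch with toy tensors. It is an illustration of what the new code in FlexBertForCausalLM.forward does, not code from the repo; the pad id 50283 is the value hard-coded in the diff.

# Sketch of next-token label construction for an unpadded (packed) token stream.
import torch

# Two sequences of lengths 4 and 3, concatenated into one unpadded stream.
input_ids = torch.tensor([11, 12, 13, 14, 21, 22, 23])
cu_seqlens = torch.tensor([0, 4, 7])  # cumulative sequence boundaries

shift_labels = torch.full_like(input_ids, -100)
shift_labels[:-1] = input_ids[1:]          # each position predicts the next token

for i in range(len(cu_seqlens) - 1):
    boundary_pos = cu_seqlens[i + 1] - 1   # last token of each packed sequence
    shift_labels[boundary_pos] = -100      # never predict across a sequence boundary

shift_labels = torch.where(shift_labels == 50283,
                           torch.tensor(-100), shift_labels)  # ignore pad tokens

print(shift_labels)  # tensor([  12,   13,   14, -100,   22,   23, -100])

Each flattened position gets the following token as its target, the last position of every packed sequence is masked so one sequence never predicts the first token of the sequence concatenated after it, and positions labeled -100 are ignored by the cross-entropy loss.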