kilianhaefeli
committed on
Commit
·
3a8246f
1
Parent(s):
55543b4
add new modeling file which contains fixes to batching
- modeling.py +174 -47
- modeling_old.py +785 -0
modeling.py
CHANGED
Old version (left side of the diff; removed lines are marked with -):

@@ -483,16 +483,69 @@ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
| 483 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
| 484 | )
| 485 |
| 486 | if position_ids is None:
| 487 | - position_ids = cache_position.unsqueeze(0)
| 488 | -
| 489 | if self.training:
| 490 | attention_mask = self.gen_mask(labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads).to(device=inputs_embeds.device)
| 491 | else:
| 492 | if use_block_cache and block_past_key_values.get_seq_length() != 0:
| 493 | attention_mask = None
| 494 | else:
| 495 | - attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
| 496 |
| 497 | hidden_states = inputs_embeds
| 498 |

@@ -652,7 +705,8 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
| 652 | def generate(
| 653 | self,
| 654 | input_ids,
| 655 | - max_new_tokens,
| 656 | mask_id=151665,
| 657 | threshold=1,
| 658 | small_block_size=8,

@@ -664,59 +718,86 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
| 664 | use_block_cache=False,
| 665 | **kwargs
| 666 | ):
| 667 | num_blocks = max_new_tokens // block_size
| 668 | original_input_length = input_ids.shape[1]
| 669 |
| 670 | if input_ids.shape[1] > block_size:
| 671 | - output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
| 672 | logits, past_key_values = output.logits, output.past_key_values
| 673 | if input_ids.shape[1] % block_size == 0:
| 674 | next_token = logits[:, -1:, :].argmax(dim=-1)
| 675 | input_ids = torch.cat([input_ids, next_token], dim=1)
| 676 | else:
| 677 | past_key_values = None
| 678 |
| 679 | num_small_blocks = block_size // small_block_size
| 680 |
| 681 | for block_idx in range(num_blocks):
| 682 | -
| 683 | -
| 684 | prompt_length = input_ids.shape[1]
| 685 | # Initialize x_init with mask_id
| 686 | x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
| 687 | - x_init = torch.cat([input_ids, x_init], dim=1)
| 688 |
| 689 | x_t = x_init.clone()
| 690 | block_past_key_values = None
| 691 | while True:
| 692 | -
| 693 | - stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
| 694 | - if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
| 695 | - break
| 696 | mask_idx = (x_t[:, -block_size:] == mask_id)
| 697 | -
| 698 | if mask_idx.sum() == 0:
| 699 | -
| 700 | logits, past_key_values = output.logits, output.past_key_values
| 701 | next_token = logits[:, -1:, :].argmax(dim=-1)
| 702 | x_t = torch.cat([x_t, next_token], dim=1)
| 703 | break
| 704 | for small_block_idx in range(num_small_blocks):
| 705 | - small_block_start_idx = small_block_idx * small_block_size
| 706 | - small_block_end_idx = small_block_start_idx + small_block_size
| 707 |
| 708 | - start = -block_size + small_block_start_idx
| 709 | end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
| 710 | while True:
| 711 | mask_idx = (x_t[:, -block_size:] == mask_id)
| 712 | - if mask_idx[:, start:end].sum() == 0:
| 713 | break
| 714 | - if stop_token in x_t[:, prompt_length:]:
| 715 | - stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
| 716 | - if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
| 717 | - break
| 718 | -
| 719 | if use_block_cache:
| 720 | if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
| 721 | output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
| 722 | logits, block_past_key_values = output.logits, output.block_past_key_values

@@ -726,28 +807,71 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
| 726 | logits = self.forward(input_ids=x_t[:,start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
| 727 | logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
| 728 | else:
| 729 | -
| 730 | -
| 731 | logits = logits[:, start:end]
| 732 |
| 733 | -
| 734 | x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
| 735 | # Select tokens with probability greater than threshold from p_1t
| 736 | x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
| 737 | x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
| 738 |
| 739 | unmask_idx = (x1_p > threshold)
| 740 | max_prob_idx = x1_p.argmax(dim=-1)
| 741 | unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
| 742 | unmask_idx = unmask_idx & mask_idx[:, start:end]
| 743 |
| 744 | x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
| 745 |
| 746 | input_ids = x_t
| 747 | -
| 748 | -
| 749 | -
| 750 | -
| 751 | return input_ids
| 752 |
| 753 | def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):

@@ -758,28 +882,31 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
| 758 | p_1t = torch.softmax(logits, dim=-1)
| 759 | x_1 = p_1t.argmax(dim=-1)
| 760 | return x_1, p_1t
| 761 | -
| 762 | - probs = F.softmax(scaled_logits, dim=-1)
| 763 |
| 764 | - sorted_probs, sorted_indices = torch.sort(probs, descending=True)
| 765 | - cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
| 766 |
| 767 | - sorted_indices_to_remove = cumulative_probs > top_p
| 768 | - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
| 769 | - sorted_indices_to_remove[..., 0] = 0
| 770 |
| 771 | indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
| 772 | dim=-1, index=sorted_indices, src=sorted_indices_to_remove
| 773 | - )
| 774 | -
| 775 | -
| 776 |
| 777 | -
| 778 | -
| 779 | - probs_sum = torch.sum(probs, dim=-1, keepdim=True)
| 780 | - normalized_probs = probs / probs_sum
| 781 |
| 782 | -
| 783 | -
| 784 |
| 785 | - return x_1, p_1t

New version (right side of the diff; added lines are marked with +):
|
| 483 |
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
|
| 484 |
)
|
| 485 |
|
| 486 |
+
# --- keep the user/tokenizer padding mask BEFORE you overwrite attention_mask ---
|
| 487 |
+
padding_mask_2d = attention_mask # shape [B, KV_LEN], 1=token, 0=pad
|
| 488 |
+
|
| 489 |
+
# -------------------------
|
| 490 |
+
# Position ids (left padding)
|
| 491 |
+
# -------------------------
|
| 492 |
if position_ids is None:
|
| 493 |
+
if (padding_mask_2d is not None) and (not self.training):
|
| 494 |
+
# full, per-sample positions over KV_LEN
|
| 495 |
+
pos_full = padding_mask_2d.long().cumsum(-1) - 1 # pads => -1
|
| 496 |
+
pos_full = pos_full.clamp_min(0) # pads => 0
|
| 497 |
+
|
| 498 |
+
q_len = inputs_embeds.shape[1]
|
| 499 |
+
kv_len = pos_full.shape[1]
|
| 500 |
+
if kv_len < q_len:
|
| 501 |
+
raise ValueError(f"attention_mask KV_LEN={kv_len} < input_len={q_len}. "
|
| 502 |
+
"When using cache, pass the FULL mask (past+current).")
|
| 503 |
+
|
| 504 |
+
q_start = kv_len - q_len # assumes current tokens are the last q_len positions
|
| 505 |
+
position_ids = pos_full[:, q_start:]
|
| 506 |
+
else:
|
| 507 |
+
# no padding mask: same positions for all batch elements
|
| 508 |
+
position_ids = cache_position.unsqueeze(0)
|
| 509 |
+
|
| 510 |
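A minimal standalone sketch of the cumsum trick above (toy mask values, not real model inputs), showing that left-pad positions collapse to 0 while real tokens count up per sample:

```python
import torch

# toy left-padded batch: 0 = pad, 1 = real token (illustrative values only)
padding_mask_2d = torch.tensor([[0, 0, 1, 1, 1],
                                [1, 1, 1, 1, 1]])
pos_full = padding_mask_2d.long().cumsum(-1) - 1   # pads become -1
pos_full = pos_full.clamp_min(0)                   # pads clamped to 0
# pos_full: [[0, 0, 0, 1, 2],
#            [0, 1, 2, 3, 4]]
```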
+
# -------------------------
|
| 511 |
+
# Attention mask (block-causal + padding), per sample
|
| 512 |
+
# -------------------------
|
| 513 |
if self.training:
|
| 514 |
attention_mask = self.gen_mask(labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads).to(device=inputs_embeds.device)
|
| 515 |
else:
|
| 516 |
if use_block_cache and block_past_key_values.get_seq_length() != 0:
|
| 517 |
attention_mask = None
|
| 518 |
else:
|
| 519 |
+
# attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
|
| 520 |
+
if padding_mask_2d is None:
|
| 521 |
+
# fallback: original behavior (no padding)
|
| 522 |
+
structural = self.eval_mask(
|
| 523 |
+
seqlen=input_ids.shape[1],
|
| 524 |
+
block_size=block_size,
|
| 525 |
+
cache_seq_len=past_key_values.get_seq_length() if past_key_values is not None else 0,
|
| 526 |
+
).to(device=inputs_embeds.device)
|
| 527 |
+
attention_mask = structural[None, None, :, :] # [1,1,Q,KV]
|
| 528 |
+
else:
|
| 529 |
+
pad = padding_mask_2d.to(torch.bool) # [B, KV]
|
| 530 |
+
B, kv_len = pad.shape
|
| 531 |
+
q_len = inputs_embeds.shape[1]
|
| 532 |
+
q_start = kv_len - q_len
|
| 533 |
+
|
| 534 |
+
# Per-sample block ids computed from *non-pad* positions
|
| 535 |
+
pos_full = pad.long().cumsum(-1) - 1
|
| 536 |
+
pos_full = pos_full.clamp_min(0)
|
| 537 |
+
block_full = pos_full // block_size # [B, KV]
|
| 538 |
+
|
| 539 |
+
block_q = block_full[:, q_start:] # [B, Q]
|
| 540 |
+
block_k = block_full # [B, KV]
|
| 541 |
+
|
| 542 |
+
structural = block_q.unsqueeze(-1) >= block_k.unsqueeze(-2) # [B, Q, KV]
|
| 543 |
+
|
| 544 |
+
# Mask keys AND queries (only valid tokens participate)
|
| 545 |
+
key_ok = pad[:, None, None, :] # [B,1,1,KV]
|
| 546 |
+
query_ok = pad[:, None, q_start:, None] # [B,1,Q,1]
|
| 547 |
+
|
| 548 |
+
attention_mask = structural[:, None, :, :] & key_ok & query_ok # [B,1,Q,KV]
|
| 549 |
|
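A minimal sketch of the mask construction above with toy sizes (block_size=2, a single sample with one pad position; values are illustrative, not the model's real shapes): the block-causal comparison gives the structural part, and the pad row/column terms switch off padded queries and keys:

```python
import torch

block_size = 2
pad = torch.tensor([[0, 1, 1, 1]]).bool()            # [B=1, KV=4], first position is padding
q_len = 4                                             # no KV cache in this toy case
q_start = pad.shape[1] - q_len                        # 0

pos = (pad.long().cumsum(-1) - 1).clamp_min(0)        # [[0, 0, 1, 2]]
block = pos // block_size                             # [[0, 0, 0, 1]]

structural = block[:, q_start:].unsqueeze(-1) >= block.unsqueeze(-2)   # [B, Q, KV]
key_ok = pad[:, None, None, :]                        # [B, 1, 1, KV]
query_ok = pad[:, None, q_start:, None]               # [B, 1, Q, 1]
attention_mask = structural[:, None, :, :] & key_ok & query_ok         # [B, 1, Q, KV], True = attend
```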
| 550 |
hidden_states = inputs_embeds
|
| 551 |
|
|
|
|
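The `generate()` hunk below drives this mask logic with a confidence-threshold unmasking rule: inside each small block, every masked position whose top-1 probability exceeds `threshold` is committed, and at least the single most confident one is always committed. A toy, model-free sketch of that selection step (random probabilities stand in for the model's `p_1t`; shapes and names are illustrative):

```python
import torch

mask_id, threshold = 151665, 0.9
x_t = torch.tensor([[5, mask_id, mask_id, mask_id]])        # one partially masked block
p_1t = torch.softmax(torch.randn(1, 4, 10), dim=-1)          # stand-in for model probabilities

x_1 = p_1t.argmax(dim=-1)                                    # candidate token per position
conf = p_1t.gather(-1, x_1.unsqueeze(-1)).squeeze(-1)        # confidence of each candidate
mask_idx = x_t == mask_id
conf = torch.where(mask_idx, conf, -torch.inf)               # only masked positions compete

unmask = conf > threshold
unmask[torch.arange(x_t.shape[0]), conf.argmax(dim=-1)] = True   # always commit the best one
unmask &= mask_idx
x_t[unmask] = x_1[unmask]                                    # reveal the chosen tokens
```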
| 705 |
def generate(
|
| 706 |
self,
|
| 707 |
input_ids,
|
| 708 |
+
attention_mask=None, # --- ADDED ARGUMENT ---
|
| 709 |
+
max_new_tokens=20, # Added default value for safety
|
| 710 |
mask_id=151665,
|
| 711 |
threshold=1,
|
| 712 |
small_block_size=8,
|
|
|
|
| 718 |
use_block_cache=False,
|
| 719 |
**kwargs
|
| 720 |
):
|
| 721 |
+
if use_block_cache:
|
| 722 |
+
raise ValueError("use_block_cache=True is not supported in this generate() implementation.")
|
| 723 |
+
assert attention_mask is not None, "attention_mask must be provided for this generate() implementation."
|
| 724 |
num_blocks = max_new_tokens // block_size
|
| 725 |
+
device = input_ids.device
|
| 726 |
+
batch_size = input_ids.size(0)
|
| 727 |
original_input_length = input_ids.shape[1]
|
| 728 |
|
| 729 |
+
# Track which sequences in the batch are still active
|
| 730 |
+
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
|
| 731 |
+
|
| 732 |
+
# Handle prefix processing (Context Encoding)
|
| 733 |
if input_ids.shape[1] > block_size:
|
| 734 |
+
output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], attention_mask=attention_mask[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
|
| 735 |
logits, past_key_values = output.logits, output.past_key_values
|
| 736 |
if input_ids.shape[1] % block_size == 0:
|
| 737 |
next_token = logits[:, -1:, :].argmax(dim=-1)
|
| 738 |
input_ids = torch.cat([input_ids, next_token], dim=1)
|
| 739 |
+
|
| 740 |
+
# Update finished status
|
| 741 |
+
# unfinished_sequences = unfinished_sequences & (next_token.squeeze(-1) != stop_token).long()
|
| 742 |
+
# Append to mask: If unfinished, append 1. If finished, append 0.
|
| 743 |
+
new_mask_col = unfinished_sequences.unsqueeze(1).to(dtype=attention_mask.dtype)
|
| 744 |
+
attention_mask = torch.cat([attention_mask, new_mask_col], dim=1)
|
| 745 |
else:
|
| 746 |
past_key_values = None
|
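The bookkeeping above is the heart of the batching fix: every token appended to `input_ids` must be mirrored by one extra `attention_mask` column, carrying 1 for still-active sequences and 0 for finished ones. A tiny illustration with made-up values:

```python
import torch

attention_mask = torch.tensor([[0, 1, 1],     # left-padded, still active
                               [1, 1, 1]])    # already finished
unfinished_sequences = torch.tensor([1, 0])

new_mask_col = unfinished_sequences.unsqueeze(1).to(dtype=attention_mask.dtype)
attention_mask = torch.cat([attention_mask, new_mask_col], dim=1)
# [[0, 1, 1, 1],
#  [1, 1, 1, 0]]
```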
| 747 |
|
| 748 |
num_small_blocks = block_size // small_block_size
|
| 749 |
|
| 750 |
+
iterations = torch.zeros((batch_size,), device=device)
|
| 751 |
+
n_generated_tokens = torch.zeros((batch_size,), device=device)
|
| 752 |
+
finished = torch.zeros((batch_size,), dtype=torch.bool, device=device)
|
| 753 |
for block_idx in range(num_blocks):
|
| 754 |
+
new_tokens = input_ids[:, original_input_length:]
|
| 755 |
+
has_stop_now = (new_tokens == stop_token).any(dim=1)
|
| 756 |
+
finished |= has_stop_now
|
| 757 |
+
|
| 758 |
+
stop_tokens = torch.where(input_ids[:, original_input_length:] == stop_token, 1.0, 0.0).sum(dim=1)
|
| 759 |
+
if stop_tokens.min() > 0:
|
| 760 |
+
break # break if all sequences have generated the stop token
|
| 761 |
+
|
| 762 |
prompt_length = input_ids.shape[1]
|
| 763 |
# Initialize x_init with mask_id
|
| 764 |
x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
|
| 765 |
+
x_init = torch.cat([input_ids, x_init], dim=1) # 1: [1, 32]
|
| 766 |
+
|
| 767 |
+
mask_extension = unfinished_sequences.unsqueeze(1).repeat(1, block_size - prompt_length % block_size).to(dtype=attention_mask.dtype)
|
| 768 |
+
curr_attention_mask = torch.cat([attention_mask, mask_extension], dim=1)
|
| 769 |
|
| 770 |
x_t = x_init.clone()
|
| 771 |
block_past_key_values = None
|
| 772 |
+
|
| 773 |
while True:
|
| 774 |
+
# Mask idx marks where the current block still contains masks. (Note that the first token is never a mask token, because at least one token must be present to condition on!)
|
| 775 |
mask_idx = (x_t[:, -block_size:] == mask_id)
|
| 776 |
+
|
| 777 |
+
# If the current 32-token block contains no mask id, then we have already generated at most 31 tokens! (one token must always be set, otherwise no conditional generation is possible!)
|
| 778 |
if mask_idx.sum() == 0:
|
| 779 |
+
# Here we predict the last token, which corresponds to the first token of the next block.
|
| 780 |
+
# Why not predict it earlier? Sampling it before the rest of the block was sampled would violate the semi-autoregressive order.
|
| 781 |
+
output = self.forward(input_ids=x_t[:, -block_size:], attention_mask=curr_attention_mask, use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
|
| 782 |
logits, past_key_values = output.logits, output.past_key_values
|
| 783 |
next_token = logits[:, -1:, :].argmax(dim=-1)
|
| 784 |
x_t = torch.cat([x_t, next_token], dim=1)
|
| 785 |
+
curr_attention_mask = torch.cat([curr_attention_mask, unfinished_sequences.unsqueeze(1).to(curr_attention_mask.dtype)], dim=1)
|
| 786 |
+
iterations += (~finished).long()
|
| 787 |
+
n_generated_tokens += (~finished).long()
|
| 788 |
break
|
| 789 |
for small_block_idx in range(num_small_blocks):
|
| 790 |
+
small_block_start_idx = small_block_idx * small_block_size # 0
|
| 791 |
+
small_block_end_idx = small_block_start_idx + small_block_size # 32
|
| 792 |
|
| 793 |
+
start = -block_size + small_block_start_idx # -32
|
| 794 |
end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
|
| 795 |
while True:
|
| 796 |
mask_idx = (x_t[:, -block_size:] == mask_id)
|
| 797 |
+
if mask_idx[:, start:end].sum() == 0: # [-32:None] (look at the newest 32 tokens; if no mask ids remain there, there are no pending predictions for this small block)
|
| 798 |
break
|
| 799 |
if use_block_cache:
|
| 800 |
+
assert False, "use_block_cache=True is not supported in this generate() implementation."
|
| 801 |
if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
|
| 802 |
output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
|
| 803 |
logits, block_past_key_values = output.logits, output.block_past_key_values
|
|
|
|
| 807 |
logits = self.forward(input_ids=x_t[:,start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
|
| 808 |
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
|
| 809 |
else:
|
| 810 |
+
# feed the most recent 32 tokens as the input
|
| 811 |
+
# the kv cache does not yet contain the kv for the current block, so it is always recomputed!
|
| 812 |
+
logits = self.forward(input_ids=x_t[:, -block_size:], attention_mask=curr_attention_mask, use_cache=True, past_key_values=past_key_values, update_past_key_values=False,block_size=block_size,).logits
|
| 813 |
+
# the logits to be sampled from are the most recent 32 tokens
|
| 814 |
+
# shift because of the autoregressive conversion; prepending anything at the start is valid since the first token is never masked anyway.
|
| 815 |
+
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1) # TODO: maybe prepend NaN or a sentinel instead
|
| 816 |
logits = logits[:, start:end]
|
| 817 |
|
|
|
|
| 818 |
x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
|
| 819 |
# Select tokens with probability greater than threshold from p_1t
|
| 820 |
x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
|
| 821 |
x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
|
| 822 |
|
| 823 |
unmask_idx = (x1_p > threshold)
|
| 824 |
+
# Ensure at least one token is unmasked in the current small block
|
| 825 |
max_prob_idx = x1_p.argmax(dim=-1)
|
| 826 |
unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
|
| 827 |
unmask_idx = unmask_idx & mask_idx[:, start:end]
|
| 828 |
|
| 829 |
+
# Add 1 to iterations if the sequence is not stopped AND at least one token is generated in this iteration
|
| 830 |
+
iterations += (~finished & unmask_idx.any(dim=1)).long()
|
| 831 |
+
|
| 832 |
+
# count number of generated tokens in this iteration if not stopped
|
| 833 |
+
n_generated_iter = torch.where(finished, 0, unmask_idx.sum(dim=1))
|
| 834 |
+
n_generated_tokens += n_generated_iter
|
| 835 |
+
|
| 836 |
+
# Only update the positions where unmask_idx is True AND the sequence is not finished. TODO: check this.
|
| 837 |
x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
|
| 838 |
|
| 839 |
+
new_tokens = input_ids[:, original_input_length:]
|
| 840 |
+
has_stop_now = (new_tokens == stop_token).any(dim=1)
|
| 841 |
+
finished |= has_stop_now
|
| 842 |
+
|
| 843 |
input_ids = x_t
|
| 844 |
+
attention_mask = curr_attention_mask
|
| 845 |
+
|
| 846 |
+
print("Generated iterations per sequence:", iterations)
|
| 847 |
+
print("Generated tokens per sequence:", n_generated_tokens)
|
| 848 |
+
print("Average tokens generated per iteration:", (n_generated_tokens / iterations))
|
| 849 |
+
# Final truncation: for each sequence, keep everything up to and including its first stop_token and pad the rest
|
| 850 |
+
new_tokens = input_ids[:, original_input_length:]
|
| 851 |
+
has_stop = (new_tokens == stop_token)
|
| 852 |
+
|
| 853 |
+
gen = input_ids[:, original_input_length:] # (B, T)
|
| 854 |
+
|
| 855 |
+
T = gen.size(1)
|
| 856 |
+
|
| 857 |
+
if T > 0:
|
| 858 |
+
device = input_ids.device
|
| 859 |
+
B = input_ids.size(0)
|
| 860 |
+
|
| 861 |
+
idx = torch.arange(T, device=device).unsqueeze(0).expand(B, T)
|
| 862 |
+
stop_mask = gen.eq(stop_token)
|
| 863 |
+
|
| 864 |
+
first_stop = torch.where(stop_mask, idx, torch.full_like(idx, T)).min(dim=1).values
|
| 865 |
+
has_stop = first_stop < T
|
| 866 |
+
keep = torch.where(has_stop, first_stop + 1, torch.full_like(first_stop, T))
|
| 867 |
+
|
| 868 |
+
pad_id = self.config.pad_token_id if getattr(self.config, "pad_token_id", None) is not None else stop_token
|
| 869 |
+
after = idx >= keep.unsqueeze(1)
|
| 870 |
+
gen = gen.clone()
|
| 871 |
+
gen[after] = pad_id
|
| 872 |
+
|
| 873 |
+
input_ids = torch.cat([input_ids[:, :original_input_length], gen], dim=1)
|
| 874 |
+
|
| 875 |
return input_ids
|
| 876 |
|
| 877 |
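The truncation logic above can be exercised in isolation; a small sketch with toy token ids (stop_token=2 and pad_id=0 chosen purely for illustration) showing that everything after each sequence's first stop token is overwritten with the pad id while sequences without a stop token are left intact:

```python
import torch

stop_token, pad_id = 2, 0
gen = torch.tensor([[7, 2, 9, 9],     # stops early
                    [7, 8, 9, 2],     # stops at the last position
                    [7, 8, 9, 9]])    # never stops
B, T = gen.shape

idx = torch.arange(T).unsqueeze(0).expand(B, T)
stop_mask = gen.eq(stop_token)
first_stop = torch.where(stop_mask, idx, torch.full_like(idx, T)).min(dim=1).values
keep = torch.where(first_stop < T, first_stop + 1, torch.full_like(first_stop, T))
gen = gen.masked_fill(idx >= keep.unsqueeze(1), pad_id)
# [[7, 2, 0, 0],
#  [7, 8, 9, 2],
#  [7, 8, 9, 9]]
```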
def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
|
|
|
|
| 882 |
p_1t = torch.softmax(logits, dim=-1)
|
| 883 |
x_1 = p_1t.argmax(dim=-1)
|
| 884 |
return x_1, p_1t
|
| 885 |
+
probs = torch.softmax(scaled_logits, dim=-1) # [B, seq_len, vocab_size]
|
|
|
|
| 886 |
|
| 887 |
+
sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) # [B, seq_len, sorted(vocab_size)]
|
| 888 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # [B, seq_len, cumsum(sorted(vocab_size))]
|
| 889 |
|
| 890 |
+
sorted_indices_to_remove = cumulative_probs > top_p # [B, seq_len, bool(sorted(vocab_size))]
|
| 891 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() # clone the tensor to avoid in-place operation error
|
| 892 |
+
sorted_indices_to_remove[..., 0] = 0 # always keep at least one token
|
| 893 |
|
| 894 |
indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
|
| 895 |
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
|
| 896 |
+
) # [B, seq_len, vocab_size]: start from an all-False array and
|
| 897 |
+
# set True at the indices where sorted_indices_to_remove is True
|
| 898 |
+
# we index using the sorted indices in order to put the values back to their original position
|
| 899 |
+
|
| 900 |
+
# prev: probs[indices_to_remove] = 0, indices_to_remove is of the same shape as probs
|
| 901 |
+
# and therefore masked_fill zeroes out exactly the same entries
|
| 902 |
+
probs = probs.masked_fill(indices_to_remove, 0.0)
|
| 903 |
|
| 904 |
+
probs_sum = probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
|
| 905 |
+
p_1t = probs / probs_sum
|
| 906 |
|
| 907 |
+
vocab = p_1t.shape[-1]
|
| 908 |
+
flat = p_1t.reshape(-1, vocab)
|
| 909 |
+
samples = torch.multinomial(flat, num_samples=1).squeeze(-1)
|
| 910 |
+
x_1 = samples.view(*p_1t.shape[:-1])
|
| 911 |
|
| 912 |
+
return x_1, p_1t
|
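The top-p path above follows the usual nucleus-sampling pattern: sort, cut the tail once the cumulative mass exceeds `top_p`, scatter the cut flags back to vocabulary order, renormalize, and draw from a multinomial. A compact standalone sketch of the same idea (independent of the class; it folds temperature in with a simple guard instead of the greedy early return used above):

```python
import torch

def top_p_sample(logits: torch.Tensor, top_p: float = 0.95, temperature: float = 1.0):
    probs = torch.softmax(logits / max(temperature, 1e-6), dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    remove = cumulative > top_p
    remove[..., 1:] = remove[..., :-1].clone()    # shift so the boundary token is kept
    remove[..., 0] = False                         # always keep the most likely token
    remove_mask = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, remove)
    probs = probs.masked_fill(remove_mask, 0.0)
    probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
    flat = probs.reshape(-1, probs.shape[-1])
    samples = torch.multinomial(flat, num_samples=1).squeeze(-1)
    return samples.view(probs.shape[:-1]), probs

x_1, p_1t = top_p_sample(torch.randn(2, 4, 10), top_p=0.9, temperature=0.8)
```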
modeling_old.py
ADDED
|
@@ -0,0 +1,785 @@
|
| 1 |
+
from typing import Callable, Optional, Union
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
from transformers.activations import ACT2FN
|
| 10 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 11 |
+
from transformers.generation import GenerationMixin
|
| 12 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 13 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 14 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 15 |
+
from transformers.modeling_outputs import (
|
| 16 |
+
BaseModelOutputWithPast,
|
| 17 |
+
CausalLMOutputWithPast,
|
| 18 |
+
)
|
| 19 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 20 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 21 |
+
from transformers.processing_utils import Unpack
|
| 22 |
+
from transformers.utils import auto_docstring, can_return_tuple, logging
|
| 23 |
+
from .configuration import Fast_dLLM_QwenConfig
|
| 24 |
+
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
| 25 |
+
from einops import rearrange, repeat
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class CausalLMOutputWithPastAndBlockCache(CausalLMOutputWithPast):
|
| 32 |
+
block_past_key_values: Optional[Cache] = None
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class BaseModelOutputWithPastAndBlockCache(BaseModelOutputWithPast):
|
| 36 |
+
block_past_key_values: Optional[Cache] = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
|
| 40 |
+
def fused_flex_attention(q, k, v, mask=None):
|
| 41 |
+
return flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
|
| 42 |
+
|
| 43 |
+
def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
|
| 44 |
+
"""
|
| 45 |
+
Constructs the specialized block diffusion attention mask for training
|
| 46 |
+
composed of three masks:
|
| 47 |
+
- **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
|
| 48 |
+
- **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
|
| 49 |
+
- **Block Causal Mask (M_BC)**: Attention to update x0
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
b, h: Batch and head indices (ignored for mask logic).
|
| 53 |
+
q_idx, kv_idx: Query and Key indices.
|
| 54 |
+
seq_len: Total sequence length.
|
| 55 |
+
block_size: Defines the block structure.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
A boolean attention mask.
|
| 59 |
+
"""
|
| 60 |
+
# Indicate whether token belongs to xt or x0
|
| 61 |
+
x0_flag_q = (q_idx >= n)
|
| 62 |
+
x0_flag_kv = (kv_idx >= n)
|
| 63 |
+
|
| 64 |
+
# Compute block indices
|
| 65 |
+
block_q = torch.where(x0_flag_q == 1,
|
| 66 |
+
(q_idx - n) // block_size,
|
| 67 |
+
q_idx // block_size)
|
| 68 |
+
block_kv = torch.where(x0_flag_kv == 1,
|
| 69 |
+
(kv_idx - n) // block_size,
|
| 70 |
+
kv_idx // block_size)
|
| 71 |
+
|
| 72 |
+
# **1. Block Diagonal Mask (M_BD) **
|
| 73 |
+
block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
|
| 74 |
+
|
| 75 |
+
# **2. Offset Block-Causal Mask (M_OBC) **
|
| 76 |
+
offset_block_causal = (
|
| 77 |
+
(block_q > block_kv)
|
| 78 |
+
& (x0_flag_kv == 1)
|
| 79 |
+
& (x0_flag_q == 0)
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# **3. Block-Causal Mask (M_BC) **
|
| 83 |
+
block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
|
| 84 |
+
|
| 85 |
+
# **4. Combine Masks **
|
| 86 |
+
return block_diagonal | offset_block_causal | block_causal
|
| 87 |
+
|
| 88 |
+
def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
|
| 89 |
+
# Compute block indices
|
| 90 |
+
block_q = q_idx // block_size
|
| 91 |
+
block_kv = kv_idx // block_size
|
| 92 |
+
|
| 93 |
+
return block_q >= block_kv
|
| 94 |
+
|
| 95 |
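As a quick illustration of the predicates above (not part of the file; `eval_block_diff_mask` is restated so the snippet runs standalone): the evaluation mask is a plain block-causal rule over block indices, shown here for four positions with block_size=2:

```python
import torch

def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
    return (q_idx // block_size) >= (kv_idx // block_size)

q = torch.arange(4)[:, None]    # query positions
k = torch.arange(4)[None, :]    # key positions
print(eval_block_diff_mask(q, k, block_size=2).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])
```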
+
class Fast_dLLM_QwenMLP(nn.Module):
|
| 96 |
+
def __init__(self, config):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.config = config
|
| 99 |
+
self.hidden_size = config.hidden_size
|
| 100 |
+
self.intermediate_size = config.intermediate_size
|
| 101 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 102 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 103 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 104 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 105 |
+
|
| 106 |
+
def forward(self, x):
|
| 107 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 108 |
+
return down_proj
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def rotate_half(x):
|
| 112 |
+
"""Rotates half the hidden dims of the input."""
|
| 113 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 114 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 115 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 119 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
q (`torch.Tensor`): The query tensor.
|
| 123 |
+
k (`torch.Tensor`): The key tensor.
|
| 124 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 125 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 126 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 127 |
+
Deprecated and unused.
|
| 128 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 129 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 130 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 131 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 132 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 133 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 134 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 135 |
+
Returns:
|
| 136 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 137 |
+
"""
|
| 138 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 139 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 140 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 141 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 142 |
+
return q_embed, k_embed
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 146 |
+
"""
|
| 147 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 148 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 149 |
+
"""
|
| 150 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 151 |
+
if n_rep == 1:
|
| 152 |
+
return hidden_states
|
| 153 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 154 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class Fast_dLLM_QwenAttention(nn.Module):
|
| 158 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 159 |
+
|
| 160 |
+
def __init__(self, config: Fast_dLLM_QwenConfig, layer_idx: int):
|
| 161 |
+
super().__init__()
|
| 162 |
+
self.config = config
|
| 163 |
+
self.layer_idx = layer_idx
|
| 164 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 165 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 166 |
+
self.scaling = self.head_dim**-0.5
|
| 167 |
+
self.attention_dropout = config.attention_dropout
|
| 168 |
+
self.is_causal = True
|
| 169 |
+
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
|
| 170 |
+
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
|
| 171 |
+
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
|
| 172 |
+
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
| 173 |
+
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
| 174 |
+
|
| 175 |
+
def forward(
|
| 176 |
+
self,
|
| 177 |
+
hidden_states: torch.Tensor,
|
| 178 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 179 |
+
attention_mask: Optional[torch.Tensor],
|
| 180 |
+
past_key_value: Optional[Cache] = None,
|
| 181 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 182 |
+
update_past_key_values: Optional[bool] = False,
|
| 183 |
+
block_past_key_values: Optional[Cache] = None,
|
| 184 |
+
replace_position: Optional[int] = None,
|
| 185 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 186 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
| 187 |
+
input_shape = hidden_states.shape[:-1]
|
| 188 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 189 |
+
|
| 190 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 191 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 192 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 193 |
+
|
| 194 |
+
cos, sin = position_embeddings
|
| 195 |
+
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 196 |
+
if self.training:
|
| 197 |
+
#split q into two parts
|
| 198 |
+
q_1 = query_states[:,:,:query_states.shape[2]//2]
|
| 199 |
+
q_2 = query_states[:,:,query_states.shape[2]//2:]
|
| 200 |
+
#split k into two parts
|
| 201 |
+
k_1 = key_states[:,:,:key_states.shape[2]//2]
|
| 202 |
+
k_2 = key_states[:,:,key_states.shape[2]//2:]
|
| 203 |
+
q_1, k_1 = apply_rotary_pos_emb(q_1, k_1, cos, sin)
|
| 204 |
+
q_2, k_2 = apply_rotary_pos_emb(q_2, k_2, cos, sin)
|
| 205 |
+
query_states = torch.cat((q_1, q_2), dim=-2)
|
| 206 |
+
key_states = torch.cat((k_1, k_2), dim=-2)
|
| 207 |
+
else:
|
| 208 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 209 |
+
|
| 210 |
+
if block_past_key_values is not None:
|
| 211 |
+
if len(block_past_key_values) <= self.layer_idx:
|
| 212 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 213 |
+
key_states, value_states = block_past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 214 |
+
else:
|
| 215 |
+
block_cache_key_states = block_past_key_values[self.layer_idx][0]
|
| 216 |
+
block_cache_value_states = block_past_key_values[self.layer_idx][1]
|
| 217 |
+
|
| 218 |
+
block_cache_key_states[:, :, replace_position:replace_position+key_states.shape[2]] = key_states
|
| 219 |
+
block_cache_value_states[:, :, replace_position:replace_position+value_states.shape[2]] = value_states
|
| 220 |
+
key_states = block_cache_key_states
|
| 221 |
+
value_states = block_cache_value_states
|
| 222 |
+
|
| 223 |
+
if past_key_value is not None:
|
| 224 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 225 |
+
if update_past_key_values:
|
| 226 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 227 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 228 |
+
elif len(past_key_value) > self.layer_idx:
|
| 229 |
+
key_states = torch.cat((past_key_value[self.layer_idx][0], key_states), dim=-2)
|
| 230 |
+
value_states = torch.cat((past_key_value[self.layer_idx][1], value_states), dim=-2)
|
| 231 |
+
|
| 232 |
+
if self.training:
|
| 233 |
+
attn_output = fused_flex_attention(query_states, key_states, value_states, mask=attention_mask)
|
| 234 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 235 |
+
else:
|
| 236 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
|
| 237 |
+
|
| 238 |
+
attn_output, attn_weights = attention_interface(
|
| 239 |
+
self,
|
| 240 |
+
query_states,
|
| 241 |
+
key_states,
|
| 242 |
+
value_states,
|
| 243 |
+
attention_mask,
|
| 244 |
+
is_causal=False,
|
| 245 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 246 |
+
scaling=self.scaling,
|
| 247 |
+
sliding_window=self.sliding_window, # main diff with Llama
|
| 248 |
+
**kwargs,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 252 |
+
attn_output = self.o_proj(attn_output)
|
| 253 |
+
return attn_output
|
| 254 |
+
|
| 255 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 256 |
+
class Fast_dLLM_QwenRMSNorm(nn.Module):
|
| 257 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 258 |
+
"""
|
| 259 |
+
Fast_dLLM_QwenRMSNorm is equivalent to T5LayerNorm
|
| 260 |
+
"""
|
| 261 |
+
super().__init__()
|
| 262 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 263 |
+
self.variance_epsilon = eps
|
| 264 |
+
|
| 265 |
+
def forward(self, hidden_states):
|
| 266 |
+
input_dtype = hidden_states.dtype
|
| 267 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 268 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 269 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 270 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 271 |
+
|
| 272 |
+
def extra_repr(self):
|
| 273 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class Fast_dLLM_QwenDecoderLayer(GradientCheckpointingLayer):
|
| 277 |
+
def __init__(self, config: Fast_dLLM_QwenConfig, layer_idx: int):
|
| 278 |
+
super().__init__()
|
| 279 |
+
self.hidden_size = config.hidden_size
|
| 280 |
+
|
| 281 |
+
self.self_attn = Fast_dLLM_QwenAttention(config=config, layer_idx=layer_idx)
|
| 282 |
+
|
| 283 |
+
self.mlp = Fast_dLLM_QwenMLP(config)
|
| 284 |
+
self.input_layernorm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 285 |
+
self.post_attention_layernorm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 286 |
+
self.attention_type = config.layer_types[layer_idx]
|
| 287 |
+
|
| 288 |
+
def forward(
|
| 289 |
+
self,
|
| 290 |
+
hidden_states: torch.Tensor,
|
| 291 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 292 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 293 |
+
past_key_value: Optional[Cache] = None,
|
| 294 |
+
use_cache: Optional[bool] = False,
|
| 295 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 296 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 297 |
+
update_past_key_values: Optional[bool] = False,
|
| 298 |
+
use_block_cache: Optional[bool] = False,
|
| 299 |
+
block_past_key_values: Optional[Cache] = None,
|
| 300 |
+
replace_position: Optional[int] = None,
|
| 301 |
+
**kwargs
|
| 302 |
+
) -> tuple[torch.Tensor]:
|
| 303 |
+
residual = hidden_states
|
| 304 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 305 |
+
# Self Attention
|
| 306 |
+
hidden_states = self.self_attn(
|
| 307 |
+
hidden_states=hidden_states,
|
| 308 |
+
attention_mask=attention_mask,
|
| 309 |
+
position_ids=position_ids,
|
| 310 |
+
past_key_value=past_key_value,
|
| 311 |
+
use_cache=use_cache,
|
| 312 |
+
cache_position=cache_position,
|
| 313 |
+
position_embeddings=position_embeddings,
|
| 314 |
+
update_past_key_values=update_past_key_values,
|
| 315 |
+
use_block_cache=use_block_cache,
|
| 316 |
+
block_past_key_values=block_past_key_values,
|
| 317 |
+
replace_position=replace_position,
|
| 318 |
+
**kwargs,
|
| 319 |
+
)
|
| 320 |
+
hidden_states = residual + hidden_states
|
| 321 |
+
|
| 322 |
+
# Fully Connected
|
| 323 |
+
residual = hidden_states
|
| 324 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 325 |
+
hidden_states = self.mlp(hidden_states)
|
| 326 |
+
hidden_states = residual + hidden_states
|
| 327 |
+
return hidden_states
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class Fast_dLLM_QwenPreTrainedModel(PreTrainedModel):
|
| 332 |
+
config_class = Fast_dLLM_QwenConfig
|
| 333 |
+
base_model_prefix = "model"
|
| 334 |
+
supports_gradient_checkpointing = True
|
| 335 |
+
_no_split_modules = ["Fast_dLLM_QwenDecoderLayer"]
|
| 336 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 337 |
+
_supports_flash_attn_2 = True
|
| 338 |
+
_supports_sdpa = True
|
| 339 |
+
_supports_flex_attn = True
|
| 340 |
+
_supports_cache_class = True
|
| 341 |
+
_supports_quantized_cache = True
|
| 342 |
+
_supports_static_cache = True
|
| 343 |
+
_supports_attention_backend = True
|
| 344 |
+
_can_record_outputs = {
|
| 345 |
+
"hidden_states": Fast_dLLM_QwenDecoderLayer,
|
| 346 |
+
"attentions": Fast_dLLM_QwenAttention,
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
def _init_weights(self, module):
|
| 350 |
+
std = self.config.initializer_range
|
| 351 |
+
if isinstance(module, nn.Linear):
|
| 352 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 353 |
+
if module.bias is not None:
|
| 354 |
+
module.bias.data.zero_()
|
| 355 |
+
elif isinstance(module, nn.Embedding):
|
| 356 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 357 |
+
if module.padding_idx is not None:
|
| 358 |
+
module.weight.data[module.padding_idx].zero_()
|
| 359 |
+
elif isinstance(module, Fast_dLLM_QwenRMSNorm):
|
| 360 |
+
module.weight.data.fill_(1.0)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class Fast_dLLM_QwenRotaryEmbedding(nn.Module):
|
| 364 |
+
def __init__(self, config: Fast_dLLM_QwenConfig, device=None):
|
| 365 |
+
super().__init__()
|
| 366 |
+
# BC: "rope_type" was originally "type"
|
| 367 |
+
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 368 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 369 |
+
else:
|
| 370 |
+
self.rope_type = "default"
|
| 371 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 372 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 373 |
+
|
| 374 |
+
self.config = config
|
| 375 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 376 |
+
|
| 377 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 378 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 379 |
+
self.original_inv_freq = self.inv_freq
|
| 380 |
+
|
| 381 |
+
@torch.no_grad()
|
| 382 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 383 |
+
def forward(self, x, position_ids):
|
| 384 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 385 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 386 |
+
|
| 387 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 388 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 389 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 390 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 391 |
+
cos = emb.cos() * self.attention_scaling
|
| 392 |
+
sin = emb.sin() * self.attention_scaling
|
| 393 |
+
|
| 394 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
|
| 399 |
+
def __init__(self, config: Fast_dLLM_QwenConfig):
|
| 400 |
+
super().__init__(config)
|
| 401 |
+
self.padding_idx = config.pad_token_id
|
| 402 |
+
self.vocab_size = config.vocab_size
|
| 403 |
+
self.bd_size = config.bd_size
|
| 404 |
+
|
| 405 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 406 |
+
self.layers = nn.ModuleList(
|
| 407 |
+
[Fast_dLLM_QwenDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 408 |
+
)
|
| 409 |
+
self.norm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 410 |
+
self.rotary_emb = Fast_dLLM_QwenRotaryEmbedding(config=config)
|
| 411 |
+
self.gradient_checkpointing = True
|
| 412 |
+
|
| 413 |
+
# Initialize weights and apply final processing
|
| 414 |
+
self.post_init()
|
| 415 |
+
|
| 416 |
+
def get_input_embeddings(self):
|
| 417 |
+
return self.embed_tokens
|
| 418 |
+
|
| 419 |
+
def set_input_embeddings(self, value):
|
| 420 |
+
self.embed_tokens = value
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def eval_mask(self, seqlen, block_size, cache_seq_len):
|
| 424 |
+
q_indices = torch.arange(seqlen) + cache_seq_len
|
| 425 |
+
k_indices = torch.arange(seqlen + cache_seq_len)
|
| 426 |
+
mask = eval_block_diff_mask(
|
| 427 |
+
q_idx=q_indices[:, None],
|
| 428 |
+
kv_idx=k_indices[None, :],
|
| 429 |
+
block_size=block_size
|
| 430 |
+
)
|
| 431 |
+
return mask
|
| 432 |
+
|
| 433 |
+
def gen_mask(self, seqlen, block_size, B, H):
|
| 434 |
+
mask = create_block_mask(
|
| 435 |
+
partial(block_diff_mask, block_size=block_size, n=seqlen),
|
| 436 |
+
B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
|
| 437 |
+
|
| 438 |
+
return mask
|
| 439 |
+
|
| 440 |
+
def forward(
|
| 441 |
+
self,
|
| 442 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 443 |
+
labels: Optional[torch.LongTensor] = None,
|
| 444 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 445 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 446 |
+
past_key_values: Optional[Cache] = None,
|
| 447 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 448 |
+
use_cache: Optional[bool] = None,
|
| 449 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 450 |
+
update_past_key_values: Optional[bool] = False,
|
| 451 |
+
block_size: Optional[int] = 32,
|
| 452 |
+
use_block_cache: Optional[bool] = False,
|
| 453 |
+
block_past_key_values: Optional[Cache] = None,
|
| 454 |
+
replace_position: Optional[int] = None,
|
| 455 |
+
**kwargs
|
| 456 |
+
) -> BaseModelOutputWithPast:
|
| 457 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 458 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 459 |
+
|
| 460 |
+
if inputs_embeds is None:
|
| 461 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 462 |
+
|
| 463 |
+
if use_cache and past_key_values is None:
|
| 464 |
+
past_key_values = DynamicCache()
|
| 465 |
+
|
| 466 |
+
if use_block_cache and block_past_key_values is None:
|
| 467 |
+
block_past_key_values = DynamicCache()
|
| 468 |
+
|
| 469 |
+
if cache_position is None:
|
| 470 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 471 |
+
if self.training:
|
| 472 |
+
cache_position = torch.arange(
|
| 473 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]//2, device=inputs_embeds.device
|
| 474 |
+
)
|
| 475 |
+
else:
|
| 476 |
+
if use_block_cache:
|
| 477 |
+
block_start_position = past_seen_tokens+replace_position if replace_position is not None else past_seen_tokens
|
| 478 |
+
cache_position = torch.arange(
|
| 479 |
+
block_start_position, block_start_position + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 480 |
+
)
|
| 481 |
+
else:
|
| 482 |
+
cache_position = torch.arange(
|
| 483 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
if position_ids is None:
|
| 487 |
+
position_ids = cache_position.unsqueeze(0)
|
| 488 |
+
|
| 489 |
+
if self.training:
|
| 490 |
+
attention_mask = self.gen_mask(labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads).to(device=inputs_embeds.device)
|
| 491 |
+
else:
|
| 492 |
+
if use_block_cache and block_past_key_values.get_seq_length() != 0:
|
| 493 |
+
attention_mask = None
|
| 494 |
+
else:
|
| 495 |
+
attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
|
| 496 |
+
|
| 497 |
+
hidden_states = inputs_embeds
|
| 498 |
+
|
| 499 |
+
# create position embeddings to be shared across the decoder layers
|
| 500 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 501 |
+
|
| 502 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 503 |
+
hidden_states = decoder_layer(
|
| 504 |
+
hidden_states,
|
| 505 |
+
attention_mask=attention_mask,
|
| 506 |
+
position_ids=position_ids,
|
| 507 |
+
past_key_value=past_key_values,
|
| 508 |
+
use_cache=use_cache,
|
| 509 |
+
cache_position=cache_position,
|
| 510 |
+
position_embeddings=position_embeddings,
|
| 511 |
+
update_past_key_values=update_past_key_values,
|
| 512 |
+
use_block_cache=use_block_cache,
|
| 513 |
+
block_past_key_values=block_past_key_values,
|
| 514 |
+
replace_position=replace_position,
|
| 515 |
+
**kwargs,
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
hidden_states = self.norm(hidden_states)
|
| 519 |
+
return BaseModelOutputWithPastAndBlockCache(
|
| 520 |
+
last_hidden_state=hidden_states,
|
| 521 |
+
past_key_values=past_key_values if use_cache else None,
|
| 522 |
+
block_past_key_values=block_past_key_values if use_block_cache else None,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
| 527 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 528 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 529 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 530 |
+
|
| 531 |
+
def __init__(self, config):
|
| 532 |
+
super().__init__(config)
|
| 533 |
+
self.model = Fast_dLLM_QwenModel(config)
|
| 534 |
+
self.vocab_size = config.vocab_size
|
| 535 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 536 |
+
|
| 537 |
+
# Initialize weights and apply final processing
|
| 538 |
+
self.post_init()
|
| 539 |
+
|
| 540 |
+
def get_input_embeddings(self):
|
| 541 |
+
return self.model.embed_tokens
|
| 542 |
+
|
| 543 |
+
def set_input_embeddings(self, value):
|
| 544 |
+
self.model.embed_tokens = value
|
| 545 |
+
|
| 546 |
+
def get_output_embeddings(self):
|
| 547 |
+
return self.lm_head
|
| 548 |
+
|
| 549 |
+
def set_output_embeddings(self, new_embeddings):
|
| 550 |
+
self.lm_head = new_embeddings
|
| 551 |
+
|
| 552 |
+
def set_decoder(self, decoder):
|
| 553 |
+
self.model = decoder
|
| 554 |
+
|
| 555 |
+
def get_decoder(self):
|
| 556 |
+
return self.model
|
| 557 |
+
|
| 558 |
+
@can_return_tuple
|
| 559 |
+
def forward(
|
| 560 |
+
self,
|
| 561 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 562 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 563 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 564 |
+
past_key_values: Optional[Cache] = None,
|
| 565 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 566 |
+
labels: Optional[torch.LongTensor] = None,
|
| 567 |
+
use_cache: Optional[bool] = None,
|
| 568 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 569 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 570 |
+
update_past_key_values: Optional[bool] = False,
|
| 571 |
+
block_size: Optional[int] = 32,
|
| 572 |
+
use_block_cache: Optional[bool] = False,
|
| 573 |
+
block_past_key_values: Optional[Cache] = None,
|
| 574 |
+
replace_position: Optional[int] = None,
|
| 575 |
+
mask_id: Optional[int] = 151665,
|
| 576 |
+
**kwargs
|
| 577 |
+
) -> CausalLMOutputWithPastAndBlockCache:
|
| 578 |
+
|
| 579 |
+
if self.training:
|
| 580 |
+
original_labels = labels.clone()
|
| 581 |
+
original_input_ids = input_ids.clone()
|
| 582 |
+
|
| 583 |
+
noisy_input_ids = input_ids.clone()
|
| 584 |
+
|
| 585 |
+
input_ids = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
|
| 586 |
+
b, l = input_ids.shape
|
| 587 |
+
t = torch.rand((b,), device=input_ids.device)
|
| 588 |
+
eps=1e-3
|
| 589 |
+
p_mask = (1 - eps) * t + eps
|
| 590 |
+
p_mask = p_mask[:, None].repeat(1, l)
|
| 591 |
+
|
| 592 |
+
mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
|
| 593 |
+
x_t = torch.where(mask_indices, mask_id, input_ids).reshape(labels.shape)
|
| 594 |
+
noisy_input_ids[labels != -100] = x_t[labels != -100]
|
| 595 |
+
mask = (noisy_input_ids != mask_id)
|
| 596 |
+
labels[mask] = -100
|
| 597 |
+
input_ids = torch.cat([noisy_input_ids, input_ids.reshape(labels.shape)], dim=1)
|
| 598 |
+
|
+            complementary_noisy_input_ids = original_input_ids.clone()
+            complementary_labels = original_labels.clone()
+
+            complementary_input_ids = original_input_ids.reshape(original_input_ids.shape[0] * original_input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
+
+            complementary_mask_indices = ~mask_indices
+            complementary_x_t = torch.where(complementary_mask_indices, mask_id, complementary_input_ids).reshape(labels.shape)
+            complementary_noisy_input_ids[complementary_labels != -100] = complementary_x_t[complementary_labels != -100]
+            complementary_mask = (complementary_noisy_input_ids != mask_id)
+            complementary_labels[complementary_mask] = -100
+            complementary_input_ids = torch.cat([complementary_noisy_input_ids, complementary_input_ids.reshape(complementary_labels.shape)], dim=1)
+
+            input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
+            labels = torch.cat([labels, complementary_labels], dim=0)
+
+        outputs: BaseModelOutputWithPastAndBlockCache = self.model(
+            input_ids=input_ids,
+            labels=labels,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            update_past_key_values=update_past_key_values,
+            block_size=block_size,
+            use_block_cache=use_block_cache,
+            block_past_key_values=block_past_key_values,
+            replace_position=replace_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
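+        # During training the model input was [noisy, clean] concatenated along the
+        # sequence dimension; only the first (noisy) half feeds the LM head.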
+        if self.training:
+            hidden_states = hidden_states[:, :hidden_states.shape[1]//2, :]
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPastAndBlockCache(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            block_past_key_values=outputs.block_past_key_values,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids,
+        max_new_tokens,
+        mask_id=151665,
+        threshold=1,
+        small_block_size=8,
+        block_size=32,
+        stop_token=151645,
+        stopping_criteria=None,
+        top_p=0.95,
+        temperature=0,
+        use_block_cache=False,
+        **kwargs
+    ):
+        num_blocks = max_new_tokens // block_size
+        original_input_length = input_ids.shape[1]
+
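+        # Prefill the KV cache with all complete blocks of the prompt; when the prompt
+        # length is an exact multiple of block_size, the first token of the next block
+        # is taken greedily from the final prefill logit.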
+        if input_ids.shape[1] > block_size:
+            output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
+            logits, past_key_values = output.logits, output.past_key_values
+            if input_ids.shape[1] % block_size == 0:
+                next_token = logits[:, -1:, :].argmax(dim=-1)
+                input_ids = torch.cat([input_ids, next_token], dim=1)
+        else:
+            past_key_values = None
+
+        num_small_blocks = block_size // small_block_size
+
+        for block_idx in range(num_blocks):
+            if stop_token in input_ids[:, original_input_length:]:
+                break
+            prompt_length = input_ids.shape[1]
+            # Initialize x_init with mask_id
+            x_init = mask_id * torch.ones((input_ids.shape[0], block_size - prompt_length % block_size), device=self.device, dtype=torch.long)
+            x_init = torch.cat([input_ids, x_init], dim=1)
+
+            x_t = x_init.clone()
+            block_past_key_values = None
+            while True:
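+                # Iteratively denoise the current block: exit once a stop token is
+                # reached with no masks before it, or once the block is fully unmasked
+                # (in which case its KV cache is stored and the next token is appended).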
+                if stop_token in x_t[:, prompt_length:]:
+                    stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
+                    if (x_t[:, prompt_length:prompt_length + stop_token_idx] == mask_id).sum() == 0:
+                        break
+                mask_idx = (x_t[:, -block_size:] == mask_id)
+                # Decode a complete block, update cache, and generate the next token
+                if mask_idx.sum() == 0:
+                    output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
+                    logits, past_key_values = output.logits, output.past_key_values
+                    next_token = logits[:, -1:, :].argmax(dim=-1)
+                    x_t = torch.cat([x_t, next_token], dim=1)
+                    break
+                for small_block_idx in range(num_small_blocks):
+                    small_block_start_idx = small_block_idx * small_block_size
+                    small_block_end_idx = small_block_start_idx + small_block_size
+
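+                    # start/end select the current small block inside the last
+                    # block_size positions; end is None when the small block reaches
+                    # the right edge of the block.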
+                    start = -block_size + small_block_start_idx
+                    end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
+                    while True:
+                        mask_idx = (x_t[:, -block_size:] == mask_id)
+                        if mask_idx[:, start:end].sum() == 0:
+                            break
+                        if stop_token in x_t[:, prompt_length:]:
+                            stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
+                            if (x_t[:, prompt_length:prompt_length + stop_token_idx] == mask_id).sum() == 0:
+                                break
+
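+                        # Re-encode the whole block (rebuilding the block cache) while the cache
+                        # is missing or the first token of this small block is still masked;
+                        # otherwise only this small block is recomputed against the block cache.
+                        # Logits are shifted right by one so each position is predicted from the
+                        # hidden state of the preceding position.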
+                        if use_block_cache:
+                            if block_past_key_values is None or (x_t[:, -block_size + small_block_start_idx] == mask_id).any():
+                                output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
+                                logits, block_past_key_values = output.logits, output.block_past_key_values
+                                logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
+                                logits = logits[:, start:end]
+                            else:
+                                logits = self.forward(input_ids=x_t[:, start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
+                                logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
+                        else:
+                            logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False).logits
+                            logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
+                            logits = logits[:, start:end]
+
+                        x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
+                        # Select tokens with probability greater than threshold from p_1t
+                        x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
+                        x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
+
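+                        # Unmask every position whose sampled token probability exceeds the
+                        # threshold; the highest-confidence position is always unmasked so each
+                        # pass makes progress.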
+                        unmask_idx = (x1_p > threshold)
+                        max_prob_idx = x1_p.argmax(dim=-1)
+                        unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
+                        unmask_idx = unmask_idx & mask_idx[:, start:end]
+
+                        x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
+
+            input_ids = x_t
+        # Truncate stop_token
+        if stop_token in input_ids[:, original_input_length:]:
+            stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
+            input_ids = input_ids[:, :stop_token_idx + original_input_length + 1]
+        return input_ids
+
+    def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
+        # Calculate probabilities
+        if temperature > 0:
+            scaled_logits = logits / temperature
+        else:
+            p_1t = torch.softmax(logits, dim=-1)
+            x_1 = p_1t.argmax(dim=-1)
+            return x_1, p_1t
+
+        probs = F.softmax(scaled_logits, dim=-1)
+
+        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
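+        # Top-p (nucleus) filtering: zero out tokens outside the smallest set whose
+        # cumulative probability exceeds top_p, always keeping the most likely token.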
+        sorted_indices_to_remove = cumulative_probs > top_p
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
+            dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+        )
+
+        probs[indices_to_remove] = 0
+
+        # Renormalize so that the probabilities of remaining tokens sum to 1
+        # Add a small epsilon value to prevent division by zero
+        probs_sum = torch.sum(probs, dim=-1, keepdim=True)
+        normalized_probs = probs / (probs_sum + 1e-12)
+
+        p_1t = normalized_probs
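+        # Note: sampling uses p_1t[0] only, so the temperature > 0 path assumes a
+        # batch size of 1.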
+        x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
+
+        return x_1, p_1t