kilianhaefeli committed
Commit 3a8246f · 1 Parent(s): 55543b4

add new modeling file which contains fixes to batching

Files changed (2)
  1. modeling.py +174 -47
  2. modeling_old.py +785 -0
modeling.py CHANGED
@@ -483,16 +483,69 @@ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
483
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
484
  )
485
486
  if position_ids is None:
487
- position_ids = cache_position.unsqueeze(0)
488
-
489
  if self.training:
490
  attention_mask = self.gen_mask(labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads).to(device=inputs_embeds.device)
491
  else:
492
  if use_block_cache and block_past_key_values.get_seq_length() != 0:
493
  attention_mask = None
494
  else:
495
- attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
496
 
497
  hidden_states = inputs_embeds
498
 
@@ -652,7 +705,8 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
652
  def generate(
653
  self,
654
  input_ids,
655
- max_new_tokens,
 
656
  mask_id=151665,
657
  threshold=1,
658
  small_block_size=8,
@@ -664,59 +718,86 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
664
  use_block_cache=False,
665
  **kwargs
666
  ):
667
  num_blocks = max_new_tokens // block_size
668
  original_input_length = input_ids.shape[1]
669
670
  if input_ids.shape[1] > block_size:
671
- output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
672
  logits, past_key_values = output.logits, output.past_key_values
673
  if input_ids.shape[1] % block_size == 0:
674
  next_token = logits[:, -1:, :].argmax(dim=-1)
675
  input_ids = torch.cat([input_ids, next_token], dim=1)
676
  else:
677
  past_key_values = None
678
 
679
  num_small_blocks = block_size // small_block_size
680
681
  for block_idx in range(num_blocks):
682
- if stop_token in input_ids[:, original_input_length:]:
683
- break
684
  prompt_length = input_ids.shape[1]
685
  # Initialize x_init with mask_id
686
  x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
687
- x_init = torch.cat([input_ids, x_init], dim=1)
688
 
689
  x_t = x_init.clone()
690
  block_past_key_values = None
 
691
  while True:
692
- if stop_token in x_t[:, prompt_length:]:
693
- stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
694
- if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
695
- break
696
  mask_idx = (x_t[:, -block_size:] == mask_id)
697
- # Decode a complete block, update cache, and generate the next token
 
698
  if mask_idx.sum() == 0:
699
- output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
700
  logits, past_key_values = output.logits, output.past_key_values
701
  next_token = logits[:, -1:, :].argmax(dim=-1)
702
  x_t = torch.cat([x_t, next_token], dim=1)
703
  break
704
  for small_block_idx in range(num_small_blocks):
705
- small_block_start_idx = small_block_idx * small_block_size
706
- small_block_end_idx = small_block_start_idx + small_block_size
707
 
708
- start = -block_size + small_block_start_idx
709
  end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
710
  while True:
711
  mask_idx = (x_t[:, -block_size:] == mask_id)
712
- if mask_idx[:, start:end].sum() == 0:
713
  break
714
- if stop_token in x_t[:, prompt_length:]:
715
- stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
716
- if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
717
- break
718
-
719
  if use_block_cache:
 
720
  if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
721
  output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
722
  logits, block_past_key_values = output.logits, output.block_past_key_values
@@ -726,28 +807,71 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
726
  logits = self.forward(input_ids=x_t[:,start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
727
  logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
728
  else:
729
- logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False).logits
730
- logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
731
  logits = logits[:, start:end]
732
 
733
-
734
  x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
735
  # Select tokens with probability greater than threshold from p_1t
736
  x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
737
  x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
738
 
739
  unmask_idx = (x1_p > threshold)
 
740
  max_prob_idx = x1_p.argmax(dim=-1)
741
  unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
742
  unmask_idx = unmask_idx & mask_idx[:, start:end]
743
744
  x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
745
746
  input_ids = x_t
747
- # Truncate stop_token
748
- if stop_token in input_ids[:, original_input_length:]:
749
- stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
750
- input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
751
  return input_ids
752
 
753
  def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
@@ -758,28 +882,31 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
758
  p_1t = torch.softmax(logits, dim=-1)
759
  x_1 = p_1t.argmax(dim=-1)
760
  return x_1, p_1t
761
-
762
- probs = F.softmax(scaled_logits, dim=-1)
763
 
764
- sorted_probs, sorted_indices = torch.sort(probs, descending=True)
765
- cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
766
 
767
- sorted_indices_to_remove = cumulative_probs > top_p
768
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
769
- sorted_indices_to_remove[..., 0] = 0
770
 
771
  indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
772
  dim=-1, index=sorted_indices, src=sorted_indices_to_remove
773
- )
774
-
775
- probs[indices_to_remove] = 0
776
 
777
- # Renormalize so that the probabilities of remaining tokens sum to 1
778
- # Add a small epsilon value to prevent division by zero
779
- probs_sum = torch.sum(probs, dim=-1, keepdim=True)
780
- normalized_probs = probs / probs_sum
781
 
782
- p_1t = normalized_probs
783
- x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
784
 
785
- return x_1, p_1t
 
483
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
484
  )
485
 
486
+ # --- keep the user/tokenizer padding mask BEFORE you overwrite attention_mask ---
487
+ padding_mask_2d = attention_mask # shape [B, KV_LEN], 1=token, 0=pad
488
+
489
+ # -------------------------
490
+ # Position ids (left padding)
491
+ # -------------------------
492
  if position_ids is None:
493
+ if (padding_mask_2d is not None) and (not self.training):
494
+ # full, per-sample positions over KV_LEN
495
+ pos_full = padding_mask_2d.long().cumsum(-1) - 1 # pads => -1
496
+ pos_full = pos_full.clamp_min(0) # pads => 0
497
+
498
+ q_len = inputs_embeds.shape[1]
499
+ kv_len = pos_full.shape[1]
500
+ if kv_len < q_len:
501
+ raise ValueError(f"attention_mask KV_LEN={kv_len} < input_len={q_len}. "
502
+ "When using cache, pass the FULL mask (past+current).")
503
+
504
+ q_start = kv_len - q_len # assumes current tokens are the last q_len positions
505
+ position_ids = pos_full[:, q_start:]
506
+ else:
507
+ # no padding mask: same positions for all batch elements
508
+ position_ids = cache_position.unsqueeze(0)
509
+
510
+ # -------------------------
511
+ # Attention mask (block-causal + padding), per sample
512
+ # -------------------------
513
  if self.training:
514
  attention_mask = self.gen_mask(labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads).to(device=inputs_embeds.device)
515
  else:
516
  if use_block_cache and block_past_key_values.get_seq_length() != 0:
517
  attention_mask = None
518
  else:
519
+ # attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
520
+ if padding_mask_2d is None:
521
+ # fallback: original behavior (no padding)
522
+ structural = self.eval_mask(
523
+ seqlen=input_ids.shape[1],
524
+ block_size=block_size,
525
+ cache_seq_len=past_key_values.get_seq_length() if past_key_values is not None else 0,
526
+ ).to(device=inputs_embeds.device)
527
+ attention_mask = structural[None, None, :, :] # [1,1,Q,KV]
528
+ else:
529
+ pad = padding_mask_2d.to(torch.bool) # [B, KV]
530
+ B, kv_len = pad.shape
531
+ q_len = inputs_embeds.shape[1]
532
+ q_start = kv_len - q_len
533
+
534
+ # Per-sample block ids computed from *non-pad* positions
535
+ pos_full = pad.long().cumsum(-1) - 1
536
+ pos_full = pos_full.clamp_min(0)
537
+ block_full = pos_full // block_size # [B, KV]
538
+
539
+ block_q = block_full[:, q_start:] # [B, Q]
540
+ block_k = block_full # [B, KV]
541
+
542
+ structural = block_q.unsqueeze(-1) >= block_k.unsqueeze(-2) # [B, Q, KV]
543
+
544
+ # Mask keys AND queries (only valid tokens participate)
545
+ key_ok = pad[:, None, None, :] # [B,1,1,KV]
546
+ query_ok = pad[:, None, q_start:, None] # [B,1,Q,1]
547
+
548
+ attention_mask = structural[:, None, :, :] & key_ok & query_ok # [B,1,Q,KV]
549
 
550
  hidden_states = inputs_embeds
551
 
 
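A minimal standalone sketch (plain PyTorch, illustrative names only) of the left-padding handling added above: per-sample position ids and the block-causal-plus-padding attention mask are derived from the 2-D padding mask rather than from cache_position alone.

import torch

def positions_and_block_mask(padding_mask_2d, q_len, block_size):
    # padding_mask_2d: [B, KV_LEN], 1 = real token, 0 = left padding
    pad = padding_mask_2d.bool()
    B, kv_len = pad.shape
    q_start = kv_len - q_len                      # current tokens are the last q_len positions
    pos_full = padding_mask_2d.long().cumsum(-1) - 1
    pos_full = pos_full.clamp_min(0)              # pads end up at position 0
    position_ids = pos_full[:, q_start:]          # [B, Q]
    block_full = pos_full // block_size           # [B, KV]
    structural = block_full[:, q_start:].unsqueeze(-1) >= block_full.unsqueeze(-2)  # [B, Q, KV]
    key_ok = pad[:, None, None, :]                # [B, 1, 1, KV]
    query_ok = pad[:, None, q_start:, None]       # [B, 1, Q, 1]
    attention_mask = structural[:, None, :, :] & key_ok & query_ok                  # [B, 1, Q, KV]
    return position_ids, attention_mask

# Example: batch of 2, left-padded to length 8, decoding the last 4 positions
pad_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1],
                         [1, 1, 1, 1, 1, 1, 1, 1]])
pos, attn = positions_and_block_mask(pad_mask, q_len=4, block_size=4)
print(pos)         # per-sample positions that ignore the left pads
print(attn.shape)  # torch.Size([2, 1, 4, 8])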
705
  def generate(
706
  self,
707
  input_ids,
708
+ attention_mask=None, # --- ADDED ARGUMENT ---
709
+ max_new_tokens=20, # Added default value for safety
710
  mask_id=151665,
711
  threshold=1,
712
  small_block_size=8,
 
718
  use_block_cache=False,
719
  **kwargs
720
  ):
721
+ if use_block_cache:
722
+ raise ValueError("use_block_cache=True is not supported in this generate() implementation.")
723
+ assert attention_mask is not None, "attention_mask must be provided for this generate() implementation."
724
  num_blocks = max_new_tokens // block_size
725
+ device = input_ids.device
726
+ batch_size = input_ids.size(0)
727
  original_input_length = input_ids.shape[1]
728
 
729
+ # Track which sequences in the batch are still active
730
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
731
+
732
+ # Handle prefix processing (Context Encoding)
733
  if input_ids.shape[1] > block_size:
734
+ output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], attention_mask=attention_mask[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
735
  logits, past_key_values = output.logits, output.past_key_values
736
  if input_ids.shape[1] % block_size == 0:
737
  next_token = logits[:, -1:, :].argmax(dim=-1)
738
  input_ids = torch.cat([input_ids, next_token], dim=1)
739
+
740
+ # Update finished status
741
+ # unfinished_sequences = unfinished_sequences & (next_token.squeeze(-1) != stop_token).long()
742
+ # Append to mask: If unfinished, append 1. If finished, append 0.
743
+ new_mask_col = unfinished_sequences.unsqueeze(1).to(dtype=attention_mask.dtype)
744
+ attention_mask = torch.cat([attention_mask, new_mask_col], dim=1)
745
  else:
746
  past_key_values = None
747
 
748
  num_small_blocks = block_size // small_block_size
749
 
750
+ iterations = torch.zeros((batch_size,), device=device)
751
+ n_generated_tokens = torch.zeros((batch_size,), device=device)
752
+ finished = torch.zeros((batch_size,), dtype=torch.bool, device=device)
753
  for block_idx in range(num_blocks):
754
+ new_tokens = input_ids[:, original_input_length:]
755
+ has_stop_now = (new_tokens == stop_token).any(dim=1)
756
+ finished |= has_stop_now
757
+
758
+ stop_tokens = torch.where(input_ids[:, original_input_length:] == stop_token, 1.0, 0.0).sum(dim=1)
759
+ if stop_tokens.min() > 0:
760
+ break # break if all sequences have generated the stop token
761
+
762
  prompt_length = input_ids.shape[1]
763
  # Initialize x_init with mask_id
764
  x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
765
+ x_init = torch.cat([input_ids, x_init], dim=1) # pad the sequence up to a multiple of block_size with mask tokens, e.g. [1, 32]
766
+
767
+ mask_extension = unfinished_sequences.unsqueeze(1).repeat(1, block_size - prompt_length % block_size).to(dtype=attention_mask.dtype)
768
+ curr_attention_mask = torch.cat([attention_mask, mask_extension], dim=1)
769
 
770
  x_t = x_init.clone()
771
  block_past_key_values = None
772
+
773
  while True:
774
+ # mask_idx marks the positions in the current block that are still masked. (Note that the first token is never a mask token, because at least one token must be present to condition on.)
775
  mask_idx = (x_t[:, -block_size:] == mask_id)
776
+
777
+ # If there is no mask id left in the current block of 32, then we have already generated at most 31 tokens here (one token must always be set, otherwise no conditional generation is possible).
778
  if mask_idx.sum() == 0:
779
+ # Here we predict one extra token, which corresponds to the first token of the next block.
780
+ # Why not predict it earlier? Sampling it before the rest of the block is decoded would violate the semi-autoregressive order.
781
+ output = self.forward(input_ids=x_t[:, -block_size:], attention_mask=curr_attention_mask, use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
782
  logits, past_key_values = output.logits, output.past_key_values
783
  next_token = logits[:, -1:, :].argmax(dim=-1)
784
  x_t = torch.cat([x_t, next_token], dim=1)
785
+ curr_attention_mask = torch.cat([curr_attention_mask, unfinished_sequences.unsqueeze(1).to(curr_attention_mask.dtype)], dim=1)
786
+ iterations += (~finished).long()
787
+ n_generated_tokens += (~finished).long()
788
  break
789
  for small_block_idx in range(num_small_blocks):
790
+ small_block_start_idx = small_block_idx * small_block_size # 0
791
+ small_block_end_idx = small_block_start_idx + small_block_size # 32
792
 
793
+ start = -block_size + small_block_start_idx # -32
794
  end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
795
  while True:
796
  mask_idx = (x_t[:, -block_size:] == mask_id)
797
+ if mask_idx[:, start:end].sum() == 0: # [-32:None]: no masked positions remain in this small block, so there is nothing left to predict here
798
  break
799
  if use_block_cache:
800
+ assert False, "use_block_cache=True is not supported in this generate() implementation."
801
  if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
802
  output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
803
  logits, block_past_key_values = output.logits, output.block_past_key_values
 
807
  logits = self.forward(input_ids=x_t[:,start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
808
  logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
809
  else:
810
+ # feed the most recent 32 tokens (the current block) as the input
811
+ # the KV cache does not yet contain the KVs for the current block, so they are always recomputed here
812
+ logits = self.forward(input_ids=x_t[:, -block_size:], attention_mask=curr_attention_mask, use_cache=True, past_key_values=past_key_values, update_past_key_values=False,block_size=block_size,).logits
813
+ # the logits to be sampled from are those of the most recent 32 tokens
814
+ # shift by one because of the autoregressive conversion; prepending anything at the start is valid since the first token is never masked anyway
815
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1) # TODO maybe prepend NaN or a sentinel instead
816
  logits = logits[:, start:end]
817
 
 
818
  x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
819
  # Select tokens with probability greater than threshold from p_1t
820
  x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
821
  x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
822
 
823
  unmask_idx = (x1_p > threshold)
824
+ # Ensure at least one token is unmasked in the current small block
825
  max_prob_idx = x1_p.argmax(dim=-1)
826
  unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
827
  unmask_idx = unmask_idx & mask_idx[:, start:end]
828
 
829
+ # Add 1 to iterations if the sequence is not stopped AND at least one token is generated in this iteration
830
+ iterations += (~finished & unmask_idx.any(dim=1)).long()
831
+
832
+ # count number of generated tokens in this iteration if not stopped
833
+ n_generated_iter = torch.where(finished, 0, unmask_idx.sum(dim=1))
834
+ n_generated_tokens += n_generated_iter
835
+
836
+ # Only update the positions where unmask_idx is True AND the sequence is not finished (TODO: check this; finished sequences are not excluded from the update yet)
837
  x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
838
 
839
+ new_tokens = input_ids[:, original_input_length:]
840
+ has_stop_now = (new_tokens == stop_token).any(dim=1)
841
+ finished |= has_stop_now
842
+
843
  input_ids = x_t
844
+ attention_mask = curr_attention_mask
845
+
846
+ print("Generated iterations per sequence:", iterations)
847
+ print("Generated tokens per sequence:", n_generated_tokens)
848
+ print("Average tokens generated per iteration:", (n_generated_tokens / iterations))
849
+ # Final truncation: for each sequence, keep everything up to its first stop_token and pad the rest
850
+ new_tokens = input_ids[:, original_input_length:]
851
+ has_stop = (new_tokens == stop_token)
852
+
853
+ gen = input_ids[:, original_input_length:] # (B, T)
854
+
855
+ T = gen.size(1)
856
+
857
+ if T > 0:
858
+ device = input_ids.device
859
+ B = input_ids.size(0)
860
+
861
+ idx = torch.arange(T, device=device).unsqueeze(0).expand(B, T)
862
+ stop_mask = gen.eq(stop_token)
863
+
864
+ first_stop = torch.where(stop_mask, idx, torch.full_like(idx, T)).min(dim=1).values
865
+ has_stop = first_stop < T
866
+ keep = torch.where(has_stop, first_stop + 1, torch.full_like(first_stop, T))
867
+
868
+ pad_id = self.config.pad_token_id if getattr(self.config, "pad_token_id", None) is not None else stop_token
869
+ after = idx >= keep.unsqueeze(1)
870
+ gen = gen.clone()
871
+ gen[after] = pad_id
872
+
873
+ input_ids = torch.cat([input_ids[:, :original_input_length], gen], dim=1)
874
+
875
  return input_ids
876
 
877
  def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
 
882
  p_1t = torch.softmax(logits, dim=-1)
883
  x_1 = p_1t.argmax(dim=-1)
884
  return x_1, p_1t
885
+ probs = torch.softmax(scaled_logits, dim=-1) # [B, seq_len, vocab_size]
 
886
 
887
+ sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) # [B, seq_len, sorted(vocab_size)]
888
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # [B, seq_len, cumsum(sorted(vocab_size))]
889
 
890
+ sorted_indices_to_remove = cumulative_probs > top_p # [B, seq_len, bool(sorted(vocab_size))]
891
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() # clone the tensor to avoid in-place operation error
892
+ sorted_indices_to_remove[..., 0] = 0 # always keep at least one token
893
 
894
  indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
895
  dim=-1, index=sorted_indices, src=sorted_indices_to_remove
896
+ ) # [B, seq_len, vocab_size]: start from an all-False tensor and
897
+ # set True at the indices where sorted_indices_to_remove is True
898
+ # we index using the sorted indices in order to put the values back to their original position
899
+
900
+ # prev: probs[indices_to_remove] = 0, indices_to_remove is of the same shape as probs
901
+ # and therefore that operation selected exactly those elements; masked_fill below is the out-of-place equivalent
902
+ probs = probs.masked_fill(indices_to_remove, 0.0)
903
 
904
+ probs_sum = probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
905
+ p_1t = probs / probs_sum
 
 
906
 
907
+ vocab = p_1t.shape[-1]
908
+ flat = p_1t.reshape(-1, vocab)
909
+ samples = torch.multinomial(flat, num_samples=1).squeeze(-1)
910
+ x_1 = samples.view(*p_1t.shape[:-1])
911
 
912
+ return x_1, p_1t
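A small illustrative sketch (assumed helper name, plain PyTorch) of the per-sequence stop handling used by the batched generate() above: instead of truncating the whole batch at one stop_token, everything after each row's first stop_token is overwritten with a pad id.

import torch

def pad_after_first_stop(gen, stop_token, pad_id):
    # gen: [B, T] generated tokens with the prompt already stripped off
    B, T = gen.shape
    idx = torch.arange(T, device=gen.device).unsqueeze(0).expand(B, T)
    stop_mask = gen.eq(stop_token)
    # position of the first stop token per row, T if there is none
    first_stop = torch.where(stop_mask, idx, torch.full_like(idx, T)).min(dim=1).values
    keep = torch.where(first_stop < T, first_stop + 1, torch.full_like(first_stop, T))
    out = gen.clone()
    out[idx >= keep.unsqueeze(1)] = pad_id        # blank out everything after the first stop
    return out

gen = torch.tensor([[5, 7, 2, 9, 9],     # 2 = stop token
                    [5, 7, 8, 9, 4]])    # no stop token in this row
print(pad_after_first_stop(gen, stop_token=2, pad_id=0))
# tensor([[5, 7, 2, 0, 0],
#         [5, 7, 8, 9, 4]])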
modeling_old.py ADDED
@@ -0,0 +1,785 @@
1
+ from typing import Callable, Optional, Union
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from functools import partial
8
+
9
+ from transformers.activations import ACT2FN
10
+ from transformers.cache_utils import Cache, DynamicCache
11
+ from transformers.generation import GenerationMixin
12
+ from transformers.integrations import use_kernel_forward_from_hub
13
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
14
+ from transformers.modeling_layers import GradientCheckpointingLayer
15
+ from transformers.modeling_outputs import (
16
+ BaseModelOutputWithPast,
17
+ CausalLMOutputWithPast,
18
+ )
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import auto_docstring, can_return_tuple, logging
23
+ from .configuration import Fast_dLLM_QwenConfig
24
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
25
+ from einops import rearrange, repeat
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ @dataclass
31
+ class CausalLMOutputWithPastAndBlockCache(CausalLMOutputWithPast):
32
+ block_past_key_values: Optional[Cache] = None
33
+
34
+ @dataclass
35
+ class BaseModelOutputWithPastAndBlockCache(BaseModelOutputWithPast):
36
+ block_past_key_values: Optional[Cache] = None
37
+
38
+
39
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
40
+ def fused_flex_attention(q, k, v, mask=None):
41
+ return flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
42
+
43
+ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
44
+ """
45
+ Constructs the specialized block diffusion attention mask for training
46
+ composed of three masks:
47
+ - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
48
+ - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
49
+ - **Block Causal Mask (M_BC)**: Attention to update x0
50
+
51
+ Args:
52
+ b, h: Batch and head indices (ignored for mask logic).
53
+ q_idx, kv_idx: Query and Key indices.
54
+ seq_len: Total sequence length.
55
+ block_size: Defines the block structure.
56
+
57
+ Returns:
58
+ A boolean attention mask.
59
+ """
60
+ # Indicate whether token belongs to xt or x0
61
+ x0_flag_q = (q_idx >= n)
62
+ x0_flag_kv = (kv_idx >= n)
63
+
64
+ # Compute block indices
65
+ block_q = torch.where(x0_flag_q == 1,
66
+ (q_idx - n) // block_size,
67
+ q_idx // block_size)
68
+ block_kv = torch.where(x0_flag_kv == 1,
69
+ (kv_idx - n) // block_size,
70
+ kv_idx // block_size)
71
+
72
+ # **1. Block Diagonal Mask (M_BD) **
73
+ block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
74
+
75
+ # **2. Offset Block-Causal Mask (M_OBC) **
76
+ offset_block_causal = (
77
+ (block_q > block_kv)
78
+ & (x0_flag_kv == 1)
79
+ & (x0_flag_q == 0)
80
+ )
81
+
82
+ # **3. Block-Causal Mask (M_BC) **
83
+ block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
84
+
85
+ # **4. Combine Masks **
86
+ return block_diagonal | offset_block_causal | block_causal
87
+
88
+ def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
89
+ # Compute block indices
90
+ block_q = q_idx // block_size
91
+ block_kv = kv_idx // block_size
92
+
93
+ return block_q >= block_kv
94
+
95
+ class Fast_dLLM_QwenMLP(nn.Module):
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.config = config
99
+ self.hidden_size = config.hidden_size
100
+ self.intermediate_size = config.intermediate_size
101
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
102
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
103
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
104
+ self.act_fn = ACT2FN[config.hidden_act]
105
+
106
+ def forward(self, x):
107
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
108
+ return down_proj
109
+
110
+
111
+ def rotate_half(x):
112
+ """Rotates half the hidden dims of the input."""
113
+ x1 = x[..., : x.shape[-1] // 2]
114
+ x2 = x[..., x.shape[-1] // 2 :]
115
+ return torch.cat((-x2, x1), dim=-1)
116
+
117
+
118
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
119
+ """Applies Rotary Position Embedding to the query and key tensors.
120
+
121
+ Args:
122
+ q (`torch.Tensor`): The query tensor.
123
+ k (`torch.Tensor`): The key tensor.
124
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
125
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
126
+ position_ids (`torch.Tensor`, *optional*):
127
+ Deprecated and unused.
128
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
129
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
130
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
131
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
132
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
133
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
134
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
135
+ Returns:
136
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
137
+ """
138
+ cos = cos.unsqueeze(unsqueeze_dim)
139
+ sin = sin.unsqueeze(unsqueeze_dim)
140
+ q_embed = (q * cos) + (rotate_half(q) * sin)
141
+ k_embed = (k * cos) + (rotate_half(k) * sin)
142
+ return q_embed, k_embed
143
+
144
+
145
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
146
+ """
147
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
148
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
149
+ """
150
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
151
+ if n_rep == 1:
152
+ return hidden_states
153
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
154
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
155
+
156
+
157
+ class Fast_dLLM_QwenAttention(nn.Module):
158
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
159
+
160
+ def __init__(self, config: Fast_dLLM_QwenConfig, layer_idx: int):
161
+ super().__init__()
162
+ self.config = config
163
+ self.layer_idx = layer_idx
164
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
165
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
166
+ self.scaling = self.head_dim**-0.5
167
+ self.attention_dropout = config.attention_dropout
168
+ self.is_causal = True
169
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
170
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
171
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
172
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
173
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
174
+
175
+ def forward(
176
+ self,
177
+ hidden_states: torch.Tensor,
178
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
179
+ attention_mask: Optional[torch.Tensor],
180
+ past_key_value: Optional[Cache] = None,
181
+ cache_position: Optional[torch.LongTensor] = None,
182
+ update_past_key_values: Optional[bool] = False,
183
+ block_past_key_values: Optional[Cache] = None,
184
+ replace_position: Optional[int] = None,
185
+ **kwargs: Unpack[FlashAttentionKwargs],
186
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
187
+ input_shape = hidden_states.shape[:-1]
188
+ hidden_shape = (*input_shape, -1, self.head_dim)
189
+
190
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
191
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
192
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
193
+
194
+ cos, sin = position_embeddings
195
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
196
+ if self.training:
197
+ #split q into two parts
198
+ q_1 = query_states[:,:,:query_states.shape[2]//2]
199
+ q_2 = query_states[:,:,query_states.shape[2]//2:]
200
+ #split k into two parts
201
+ k_1 = key_states[:,:,:key_states.shape[2]//2]
202
+ k_2 = key_states[:,:,key_states.shape[2]//2:]
203
+ q_1, k_1 = apply_rotary_pos_emb(q_1, k_1, cos, sin)
204
+ q_2, k_2 = apply_rotary_pos_emb(q_2, k_2, cos, sin)
205
+ query_states = torch.cat((q_1, q_2), dim=-2)
206
+ key_states = torch.cat((k_1, k_2), dim=-2)
207
+ else:
208
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
209
+
210
+ if block_past_key_values is not None:
211
+ if len(block_past_key_values) <= self.layer_idx:
212
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
213
+ key_states, value_states = block_past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
214
+ else:
215
+ block_cache_key_states = block_past_key_values[self.layer_idx][0]
216
+ block_cache_value_states = block_past_key_values[self.layer_idx][1]
217
+
218
+ block_cache_key_states[:, :, replace_position:replace_position+key_states.shape[2]] = key_states
219
+ block_cache_value_states[:, :, replace_position:replace_position+value_states.shape[2]] = value_states
220
+ key_states = block_cache_key_states
221
+ value_states = block_cache_value_states
222
+
223
+ if past_key_value is not None:
224
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
225
+ if update_past_key_values:
226
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
227
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
228
+ elif len(past_key_value) > self.layer_idx:
229
+ key_states = torch.cat((past_key_value[self.layer_idx][0], key_states), dim=-2)
230
+ value_states = torch.cat((past_key_value[self.layer_idx][1], value_states), dim=-2)
231
+
232
+ if self.training:
233
+ attn_output = fused_flex_attention(query_states, key_states, value_states, mask=attention_mask)
234
+ attn_output = attn_output.transpose(1, 2).contiguous()
235
+ else:
236
+ attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
237
+
238
+ attn_output, attn_weights = attention_interface(
239
+ self,
240
+ query_states,
241
+ key_states,
242
+ value_states,
243
+ attention_mask,
244
+ is_causal=False,
245
+ dropout=0.0 if not self.training else self.attention_dropout,
246
+ scaling=self.scaling,
247
+ sliding_window=self.sliding_window, # main diff with Llama
248
+ **kwargs,
249
+ )
250
+
251
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
252
+ attn_output = self.o_proj(attn_output)
253
+ return attn_output
254
+
255
+ @use_kernel_forward_from_hub("RMSNorm")
256
+ class Fast_dLLM_QwenRMSNorm(nn.Module):
257
+ def __init__(self, hidden_size, eps=1e-6):
258
+ """
259
+ Fast_dLLM_QwenRMSNorm is equivalent to T5LayerNorm
260
+ """
261
+ super().__init__()
262
+ self.weight = nn.Parameter(torch.ones(hidden_size))
263
+ self.variance_epsilon = eps
264
+
265
+ def forward(self, hidden_states):
266
+ input_dtype = hidden_states.dtype
267
+ hidden_states = hidden_states.to(torch.float32)
268
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
269
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
270
+ return self.weight * hidden_states.to(input_dtype)
271
+
272
+ def extra_repr(self):
273
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
274
+
275
+
276
+ class Fast_dLLM_QwenDecoderLayer(GradientCheckpointingLayer):
277
+ def __init__(self, config: Fast_dLLM_QwenConfig, layer_idx: int):
278
+ super().__init__()
279
+ self.hidden_size = config.hidden_size
280
+
281
+ self.self_attn = Fast_dLLM_QwenAttention(config=config, layer_idx=layer_idx)
282
+
283
+ self.mlp = Fast_dLLM_QwenMLP(config)
284
+ self.input_layernorm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
285
+ self.post_attention_layernorm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
286
+ self.attention_type = config.layer_types[layer_idx]
287
+
288
+ def forward(
289
+ self,
290
+ hidden_states: torch.Tensor,
291
+ attention_mask: Optional[torch.Tensor] = None,
292
+ position_ids: Optional[torch.LongTensor] = None,
293
+ past_key_value: Optional[Cache] = None,
294
+ use_cache: Optional[bool] = False,
295
+ cache_position: Optional[torch.LongTensor] = None,
296
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
297
+ update_past_key_values: Optional[bool] = False,
298
+ use_block_cache: Optional[bool] = False,
299
+ block_past_key_values: Optional[Cache] = None,
300
+ replace_position: Optional[int] = None,
301
+ **kwargs
302
+ ) -> tuple[torch.Tensor]:
303
+ residual = hidden_states
304
+ hidden_states = self.input_layernorm(hidden_states)
305
+ # Self Attention
306
+ hidden_states = self.self_attn(
307
+ hidden_states=hidden_states,
308
+ attention_mask=attention_mask,
309
+ position_ids=position_ids,
310
+ past_key_value=past_key_value,
311
+ use_cache=use_cache,
312
+ cache_position=cache_position,
313
+ position_embeddings=position_embeddings,
314
+ update_past_key_values=update_past_key_values,
315
+ use_block_cache=use_block_cache,
316
+ block_past_key_values=block_past_key_values,
317
+ replace_position=replace_position,
318
+ **kwargs,
319
+ )
320
+ hidden_states = residual + hidden_states
321
+
322
+ # Fully Connected
323
+ residual = hidden_states
324
+ hidden_states = self.post_attention_layernorm(hidden_states)
325
+ hidden_states = self.mlp(hidden_states)
326
+ hidden_states = residual + hidden_states
327
+ return hidden_states
328
+
329
+
330
+
331
+ class Fast_dLLM_QwenPreTrainedModel(PreTrainedModel):
332
+ config_class = Fast_dLLM_QwenConfig
333
+ base_model_prefix = "model"
334
+ supports_gradient_checkpointing = True
335
+ _no_split_modules = ["Fast_dLLM_QwenDecoderLayer"]
336
+ _skip_keys_device_placement = ["past_key_values"]
337
+ _supports_flash_attn_2 = True
338
+ _supports_sdpa = True
339
+ _supports_flex_attn = True
340
+ _supports_cache_class = True
341
+ _supports_quantized_cache = True
342
+ _supports_static_cache = True
343
+ _supports_attention_backend = True
344
+ _can_record_outputs = {
345
+ "hidden_states": Fast_dLLM_QwenDecoderLayer,
346
+ "attentions": Fast_dLLM_QwenAttention,
347
+ }
348
+
349
+ def _init_weights(self, module):
350
+ std = self.config.initializer_range
351
+ if isinstance(module, nn.Linear):
352
+ module.weight.data.normal_(mean=0.0, std=std)
353
+ if module.bias is not None:
354
+ module.bias.data.zero_()
355
+ elif isinstance(module, nn.Embedding):
356
+ module.weight.data.normal_(mean=0.0, std=std)
357
+ if module.padding_idx is not None:
358
+ module.weight.data[module.padding_idx].zero_()
359
+ elif isinstance(module, Fast_dLLM_QwenRMSNorm):
360
+ module.weight.data.fill_(1.0)
361
+
362
+
363
+ class Fast_dLLM_QwenRotaryEmbedding(nn.Module):
364
+ def __init__(self, config: Fast_dLLM_QwenConfig, device=None):
365
+ super().__init__()
366
+ # BC: "rope_type" was originally "type"
367
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
368
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
369
+ else:
370
+ self.rope_type = "default"
371
+ self.max_seq_len_cached = config.max_position_embeddings
372
+ self.original_max_seq_len = config.max_position_embeddings
373
+
374
+ self.config = config
375
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
376
+
377
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
378
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
379
+ self.original_inv_freq = self.inv_freq
380
+
381
+ @torch.no_grad()
382
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
383
+ def forward(self, x, position_ids):
384
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
385
+ position_ids_expanded = position_ids[:, None, :].float()
386
+
387
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
388
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
389
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
390
+ emb = torch.cat((freqs, freqs), dim=-1)
391
+ cos = emb.cos() * self.attention_scaling
392
+ sin = emb.sin() * self.attention_scaling
393
+
394
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
395
+
396
+
397
+
398
+ class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
399
+ def __init__(self, config: Fast_dLLM_QwenConfig):
400
+ super().__init__(config)
401
+ self.padding_idx = config.pad_token_id
402
+ self.vocab_size = config.vocab_size
403
+ self.bd_size = config.bd_size
404
+
405
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
406
+ self.layers = nn.ModuleList(
407
+ [Fast_dLLM_QwenDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
408
+ )
409
+ self.norm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
410
+ self.rotary_emb = Fast_dLLM_QwenRotaryEmbedding(config=config)
411
+ self.gradient_checkpointing = True
412
+
413
+ # Initialize weights and apply final processing
414
+ self.post_init()
415
+
416
+ def get_input_embeddings(self):
417
+ return self.embed_tokens
418
+
419
+ def set_input_embeddings(self, value):
420
+ self.embed_tokens = value
421
+
422
+
423
+ def eval_mask(self, seqlen, block_size, cache_seq_len):
424
+ q_indices = torch.arange(seqlen) + cache_seq_len
425
+ k_indices = torch.arange(seqlen + cache_seq_len)
426
+ mask = eval_block_diff_mask(
427
+ q_idx=q_indices[:, None],
428
+ kv_idx=k_indices[None, :],
429
+ block_size=block_size
430
+ )
431
+ return mask
432
+
433
+ def gen_mask(self, seqlen, block_size, B, H):
434
+ mask = create_block_mask(
435
+ partial(block_diff_mask, block_size=block_size, n=seqlen),
436
+ B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
437
+
438
+ return mask
439
+
440
+ def forward(
441
+ self,
442
+ input_ids: Optional[torch.LongTensor] = None,
443
+ labels: Optional[torch.LongTensor] = None,
444
+ attention_mask: Optional[torch.Tensor] = None,
445
+ position_ids: Optional[torch.LongTensor] = None,
446
+ past_key_values: Optional[Cache] = None,
447
+ inputs_embeds: Optional[torch.FloatTensor] = None,
448
+ use_cache: Optional[bool] = None,
449
+ cache_position: Optional[torch.LongTensor] = None,
450
+ update_past_key_values: Optional[bool] = False,
451
+ block_size: Optional[int] = 32,
452
+ use_block_cache: Optional[bool] = False,
453
+ block_past_key_values: Optional[Cache] = None,
454
+ replace_position: Optional[int] = None,
455
+ **kwargs
456
+ ) -> BaseModelOutputWithPast:
457
+ if (input_ids is None) ^ (inputs_embeds is not None):
458
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
459
+
460
+ if inputs_embeds is None:
461
+ inputs_embeds = self.embed_tokens(input_ids)
462
+
463
+ if use_cache and past_key_values is None:
464
+ past_key_values = DynamicCache()
465
+
466
+ if use_block_cache and block_past_key_values is None:
467
+ block_past_key_values = DynamicCache()
468
+
469
+ if cache_position is None:
470
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
471
+ if self.training:
472
+ cache_position = torch.arange(
473
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]//2, device=inputs_embeds.device
474
+ )
475
+ else:
476
+ if use_block_cache:
477
+ block_start_position = past_seen_tokens+replace_position if replace_position is not None else past_seen_tokens
478
+ cache_position = torch.arange(
479
+ block_start_position, block_start_position + inputs_embeds.shape[1], device=inputs_embeds.device
480
+ )
481
+ else:
482
+ cache_position = torch.arange(
483
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
484
+ )
485
+
486
+ if position_ids is None:
487
+ position_ids = cache_position.unsqueeze(0)
488
+
489
+ if self.training:
490
+ attention_mask = self.gen_mask(labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads).to(device=inputs_embeds.device)
491
+ else:
492
+ if use_block_cache and block_past_key_values.get_seq_length() != 0:
493
+ attention_mask = None
494
+ else:
495
+ attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
496
+
497
+ hidden_states = inputs_embeds
498
+
499
+ # create position embeddings to be shared across the decoder layers
500
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
501
+
502
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
503
+ hidden_states = decoder_layer(
504
+ hidden_states,
505
+ attention_mask=attention_mask,
506
+ position_ids=position_ids,
507
+ past_key_value=past_key_values,
508
+ use_cache=use_cache,
509
+ cache_position=cache_position,
510
+ position_embeddings=position_embeddings,
511
+ update_past_key_values=update_past_key_values,
512
+ use_block_cache=use_block_cache,
513
+ block_past_key_values=block_past_key_values,
514
+ replace_position=replace_position,
515
+ **kwargs,
516
+ )
517
+
518
+ hidden_states = self.norm(hidden_states)
519
+ return BaseModelOutputWithPastAndBlockCache(
520
+ last_hidden_state=hidden_states,
521
+ past_key_values=past_key_values if use_cache else None,
522
+ block_past_key_values=block_past_key_values if use_block_cache else None,
523
+ )
524
+
525
+
526
+ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
527
+ _tied_weights_keys = ["lm_head.weight"]
528
+ _tp_plan = {"lm_head": "colwise_rep"}
529
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
530
+
531
+ def __init__(self, config):
532
+ super().__init__(config)
533
+ self.model = Fast_dLLM_QwenModel(config)
534
+ self.vocab_size = config.vocab_size
535
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
536
+
537
+ # Initialize weights and apply final processing
538
+ self.post_init()
539
+
540
+ def get_input_embeddings(self):
541
+ return self.model.embed_tokens
542
+
543
+ def set_input_embeddings(self, value):
544
+ self.model.embed_tokens = value
545
+
546
+ def get_output_embeddings(self):
547
+ return self.lm_head
548
+
549
+ def set_output_embeddings(self, new_embeddings):
550
+ self.lm_head = new_embeddings
551
+
552
+ def set_decoder(self, decoder):
553
+ self.model = decoder
554
+
555
+ def get_decoder(self):
556
+ return self.model
557
+
558
+ @can_return_tuple
559
+ def forward(
560
+ self,
561
+ input_ids: Optional[torch.LongTensor] = None,
562
+ attention_mask: Optional[torch.Tensor] = None,
563
+ position_ids: Optional[torch.LongTensor] = None,
564
+ past_key_values: Optional[Cache] = None,
565
+ inputs_embeds: Optional[torch.FloatTensor] = None,
566
+ labels: Optional[torch.LongTensor] = None,
567
+ use_cache: Optional[bool] = None,
568
+ cache_position: Optional[torch.LongTensor] = None,
569
+ logits_to_keep: Union[int, torch.Tensor] = 0,
570
+ update_past_key_values: Optional[bool] = False,
571
+ block_size: Optional[int] = 32,
572
+ use_block_cache: Optional[bool] = False,
573
+ block_past_key_values: Optional[Cache] = None,
574
+ replace_position: Optional[int] = None,
575
+ mask_id: Optional[int] = 151665,
576
+ **kwargs
577
+ ) -> CausalLMOutputWithPastAndBlockCache:
578
+
579
+ if self.training:
580
+ original_labels = labels.clone()
581
+ original_input_ids = input_ids.clone()
582
+
583
+ noisy_input_ids = input_ids.clone()
584
+
585
+ input_ids = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
586
+ b, l = input_ids.shape
587
+ t = torch.rand((b,), device=input_ids.device)
588
+ eps=1e-3
589
+ p_mask = (1 - eps) * t + eps
590
+ p_mask = p_mask[:, None].repeat(1, l)
591
+
592
+ mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
593
+ x_t = torch.where(mask_indices, mask_id, input_ids).reshape(labels.shape)
594
+ noisy_input_ids[labels != -100] = x_t[labels != -100]
595
+ mask = (noisy_input_ids != mask_id)
596
+ labels[mask] = -100
597
+ input_ids = torch.cat([noisy_input_ids, input_ids.reshape(labels.shape)], dim=1)
598
+
599
+ complementary_noisy_input_ids = original_input_ids.clone()
600
+ complementary_labels = original_labels.clone()
601
+
602
+ complementary_input_ids = original_input_ids.reshape(original_input_ids.shape[0] * original_input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
603
+
604
+ complementary_mask_indices = ~mask_indices
605
+ complementary_x_t = torch.where(complementary_mask_indices, mask_id, complementary_input_ids).reshape(labels.shape)
606
+ complementary_noisy_input_ids[complementary_labels != -100] = complementary_x_t[complementary_labels != -100]
607
+ complementary_mask = (complementary_noisy_input_ids != mask_id)
608
+ complementary_labels[complementary_mask] = -100
609
+ complementary_input_ids = torch.cat([complementary_noisy_input_ids, complementary_input_ids.reshape(complementary_labels.shape)], dim=1)
610
+
611
+ input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
612
+ labels = torch.cat([labels, complementary_labels], dim=0)
613
+
614
+ outputs: BaseModelOutputWithPastAndBlockCache = self.model(
615
+ input_ids=input_ids,
616
+ labels=labels,
617
+ attention_mask=attention_mask,
618
+ position_ids=position_ids,
619
+ past_key_values=past_key_values,
620
+ inputs_embeds=inputs_embeds,
621
+ use_cache=use_cache,
622
+ cache_position=cache_position,
623
+ update_past_key_values=update_past_key_values,
624
+ block_size=block_size,
625
+ use_block_cache=use_block_cache,
626
+ block_past_key_values=block_past_key_values,
627
+ replace_position=replace_position,
628
+ **kwargs,
629
+ )
630
+
631
+ hidden_states = outputs.last_hidden_state
632
+ if self.training:
633
+ hidden_states = hidden_states[:, :hidden_states.shape[1]//2, :]
634
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
635
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
636
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
637
+
638
+ loss = None
639
+ if labels is not None:
640
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
641
+
642
+ return CausalLMOutputWithPastAndBlockCache(
643
+ loss=loss,
644
+ logits=logits,
645
+ past_key_values=outputs.past_key_values,
646
+ hidden_states=outputs.hidden_states,
647
+ attentions=outputs.attentions,
648
+ block_past_key_values=outputs.block_past_key_values,
649
+ )
650
+
651
+ @torch.no_grad()
652
+ def generate(
653
+ self,
654
+ input_ids,
655
+ max_new_tokens,
656
+ mask_id=151665,
657
+ threshold=1,
658
+ small_block_size=8,
659
+ block_size=32,
660
+ stop_token=151645,
661
+ stopping_criteria=None,
662
+ top_p=0.95,
663
+ temperature=0,
664
+ use_block_cache=False,
665
+ **kwargs
666
+ ):
667
+ num_blocks = max_new_tokens // block_size
668
+ original_input_length = input_ids.shape[1]
669
+
670
+ if input_ids.shape[1] > block_size:
671
+ output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
672
+ logits, past_key_values = output.logits, output.past_key_values
673
+ if input_ids.shape[1] % block_size == 0:
674
+ next_token = logits[:, -1:, :].argmax(dim=-1)
675
+ input_ids = torch.cat([input_ids, next_token], dim=1)
676
+ else:
677
+ past_key_values = None
678
+
679
+ num_small_blocks = block_size // small_block_size
680
+
681
+ for block_idx in range(num_blocks):
682
+ if stop_token in input_ids[:, original_input_length:]:
683
+ break
684
+ prompt_length = input_ids.shape[1]
685
+ # Initialize x_init with mask_id
686
+ x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
687
+ x_init = torch.cat([input_ids, x_init], dim=1)
688
+
689
+ x_t = x_init.clone()
690
+ block_past_key_values = None
691
+ while True:
692
+ if stop_token in x_t[:, prompt_length:]:
693
+ stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
694
+ if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
695
+ break
696
+ mask_idx = (x_t[:, -block_size:] == mask_id)
697
+ # Decode a complete block, update cache, and generate the next token
698
+ if mask_idx.sum() == 0:
699
+ output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
700
+ logits, past_key_values = output.logits, output.past_key_values
701
+ next_token = logits[:, -1:, :].argmax(dim=-1)
702
+ x_t = torch.cat([x_t, next_token], dim=1)
703
+ break
704
+ for small_block_idx in range(num_small_blocks):
705
+ small_block_start_idx = small_block_idx * small_block_size
706
+ small_block_end_idx = small_block_start_idx + small_block_size
707
+
708
+ start = -block_size + small_block_start_idx
709
+ end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
710
+ while True:
711
+ mask_idx = (x_t[:, -block_size:] == mask_id)
712
+ if mask_idx[:, start:end].sum() == 0:
713
+ break
714
+ if stop_token in x_t[:, prompt_length:]:
715
+ stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
716
+ if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
717
+ break
718
+
719
+ if use_block_cache:
720
+ if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
721
+ output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
722
+ logits, block_past_key_values = output.logits, output.block_past_key_values
723
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
724
+ logits = logits[:, start:end]
725
+ else:
726
+ logits = self.forward(input_ids=x_t[:,start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
727
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
728
+ else:
729
+ logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False).logits
730
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
731
+ logits = logits[:, start:end]
732
+
733
+
734
+ x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
735
+ # Select tokens with probability greater than threshold from p_1t
736
+ x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
737
+ x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
738
+
739
+ unmask_idx = (x1_p > threshold)
740
+ max_prob_idx = x1_p.argmax(dim=-1)
741
+ unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
742
+ unmask_idx = unmask_idx & mask_idx[:, start:end]
743
+
744
+ x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
745
+
746
+ input_ids = x_t
747
+ # Truncate stop_token
748
+ if stop_token in input_ids[:, original_input_length:]:
749
+ stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
750
+ input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
751
+ return input_ids
752
+
753
+ def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
754
+ # Calculate probabilities
755
+ if temperature > 0:
756
+ scaled_logits = logits / temperature
757
+ else:
758
+ p_1t = torch.softmax(logits, dim=-1)
759
+ x_1 = p_1t.argmax(dim=-1)
760
+ return x_1, p_1t
761
+
762
+ probs = F.softmax(scaled_logits, dim=-1)
763
+
764
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True)
765
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
766
+
767
+ sorted_indices_to_remove = cumulative_probs > top_p
768
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
769
+ sorted_indices_to_remove[..., 0] = 0
770
+
771
+ indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
772
+ dim=-1, index=sorted_indices, src=sorted_indices_to_remove
773
+ )
774
+
775
+ probs[indices_to_remove] = 0
776
+
777
+ # Renormalize so that the probabilities of remaining tokens sum to 1
778
+ # Add a small epsilon value to prevent division by zero
779
+ probs_sum = torch.sum(probs, dim=-1, keepdim=True)
780
+ normalized_probs = probs / probs_sum
781
+
782
+ p_1t = normalized_probs
783
+ x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
784
+
785
+ return x_1, p_1t
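For reference, a self-contained sketch of batched nucleus (top-p) sampling in the shape the new modeling.py uses (the old sampler above draws only from p_1t[0], i.e. batch element 0); function names and the epsilon value are illustrative.

import torch

def top_p_sample(logits, top_p=0.95, temperature=1.0):
    # logits: [B, T, V]
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cum = sorted_probs.cumsum(dim=-1)
    remove = cum > top_p
    remove[..., 1:] = remove[..., :-1].clone()    # shift so the token that crosses top_p is kept
    remove[..., 0] = False                        # always keep the most likely token
    to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_idx, remove)
    probs = probs.masked_fill(to_remove, 0.0)
    probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
    flat = probs.reshape(-1, probs.shape[-1])     # torch.multinomial expects a 2-D input
    samples = torch.multinomial(flat, num_samples=1).squeeze(-1)
    return samples.view(*probs.shape[:-1]), probs

x, p = top_p_sample(torch.randn(2, 3, 10), top_p=0.9, temperature=0.7)
print(x.shape)  # torch.Size([2, 3])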