|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from typing import Any, Callable, Optional, Union |
|
|
|
|
|
from transformers import Qwen2_5_VLForConditionalGeneration, AutoModelForImageTextToText |
|
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( |
|
|
Qwen2_5_VisionTransformerPretrainedModel, |
|
|
Qwen2_5_VLModel, |
|
|
Qwen2RMSNorm, |
|
|
Qwen2_5_VLMLP, |
|
|
ALL_ATTENTION_FUNCTIONS |
|
|
) |
|
|
from transformers.image_utils import ImageInput |
|
|
from transformers.tokenization_utils import TextInput, PreTokenizedInput |
|
|
from transformers.video_utils import VideoInput |
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
|
|
|
|
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLConfig |
|
|
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessorKwargs |
|
|
|
|
|
class ADCopilotConfig(Qwen2_5_VLConfig):
    """Qwen2.5-VL configuration for the AD-Copilot compare-token model.

    Adds, on top of ``Qwen2_5_VLConfig``:

    * ``vision_config.compare_token_size`` -- number of compare tokens appended after
      each image's embeddings (default ``100``).
    * ``sequence_compare`` -- stored on the top-level config (default ``True``).
    """

    model_type = "ad_copilot"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Only fill in defaults when the value is missing. Assigning unconditionally
        # would discard values restored from a saved config or passed by the caller.
        if getattr(self.vision_config, "compare_token_size", None) is None:
            self.vision_config.compare_token_size = 100
        self.architectures = ["ADCopilotVLForConditionalGeneration"]
        # NOTE(review): ADCopilotCompareVisualEncoder reads `sequence_compare` from the
        # *vision* config, so this top-level value does not reach it -- confirm intent.
        self.sequence_compare = getattr(self, "sequence_compare", True)
|
|
|
|
|
class ADCopilotProcessor(Qwen2_5_VLProcessor):
    """Qwen2.5-VL processor that reserves compare-token slots after every image.

    Each image placeholder in the text expands to the image's merged patch tokens
    *plus* ``compare_token_size`` extra image tokens. This matches
    ``ADCopilotVLModel.get_image_features``, which appends that many compare
    embeddings to every image. Video placeholders get no compare tokens.
    """

    config_class = ADCopilotConfig

    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        # Remove our own option before delegating, so the base processor never receives
        # an unexpected keyword argument.
        compare_token_size = kwargs.pop("compare_token_size", 100)
        super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs)
        self.compare_token_size = compare_token_size

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Prepare one or several sequence(s), image(s) and video(s) for the model.

        `text` and the text-related `kwargs` are forwarded to the tokenizer. `images` and
        `videos` are forwarded to the image and video processors, respectively, when they
        are not `None`.

        Every image placeholder token in `text` is expanded to
        `grid_t * grid_h * grid_w // merge_size**2 + compare_token_size` image tokens.
        Every video placeholder token is expanded to `grid_t * grid_h * grid_w // merge_size**2`
        video tokens.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
            - **mm_token_type_ids** -- 1 where the token is an image token, else 0. Returned when
              `return_mm_token_type_ids=True`.
        """
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Two independent dicts (rather than one shared object) for clarity.
        image_inputs = {}
        videos_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]

        if videos is not None:
            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]

            # Seconds covered by one temporal grid step, per video.
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        if not isinstance(text, list):
            text = [text]

        # Work on a copy so the caller's list is not modified by the expansion below.
        text = text.copy()
        if images is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    num_image_tokens = image_grid_thw[index].prod() // merge_length
                    # Reserve slots for the merged patches AND the image's compare tokens.
                    # A temporary placeholder keeps the loop from re-matching the tokens
                    # it just inserted.
                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens + self.compare_token_size), 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if videos is not None:
            merge_length = self.video_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    num_video_tokens = video_grid_thw[index].prod() // merge_length
                    text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

        if return_mm_token_type_ids:
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
|
|
|
|
|
|
|
|
class OptimizedCrossAttention(nn.Module): |
|
|
""" |
|
|
仿照 Qwen2_5_VLVisionAttention 结构的优化 Cross Attention |
|
|
""" |
|
|
def __init__(self, config, is_cross_attention=True): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.dim = config.hidden_size |
|
|
self.num_heads = config.num_heads |
|
|
self.head_dim = self.dim // self.num_heads |
|
|
self.scaling = self.head_dim**-0.5 |
|
|
self.attention_dropout = 0.0 |
|
|
self.is_causal = False |
|
|
self.is_cross_attention = is_cross_attention |
|
|
|
|
|
if is_cross_attention: |
|
|
|
|
|
self.q_proj = nn.Linear(self.dim, self.dim, bias=True) |
|
|
self.kv = nn.Linear(self.dim, self.dim * 2, bias=True) |
|
|
else: |
|
|
|
|
|
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) |
|
|
|
|
|
self.proj = nn.Linear(self.dim, self.dim, bias=True) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
query_states: torch.Tensor, |
|
|
key_value_states: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
cu_seqlens: Optional[torch.Tensor] = None, |
|
|
kv_cu_seqlens: Optional[torch.Tensor] = None, |
|
|
**kwargs, |
|
|
) -> torch.Tensor: |
|
|
|
|
|
orig_2d = False |
|
|
if query_states.dim() == 2: |
|
|
query_states = query_states.unsqueeze(0) |
|
|
orig_2d = True |
|
|
|
|
|
batch_size, seq_len_q, _ = query_states.shape |
|
|
|
|
|
|
|
|
if self.is_cross_attention and key_value_states is not None: |
|
|
if key_value_states.dim() == 2: |
|
|
key_value_states = key_value_states.unsqueeze(0) |
|
|
q = self.q_proj(query_states) |
|
|
kv = self.kv(key_value_states) |
|
|
seq_len_kv = kv.shape[1] |
|
|
k, v = kv.reshape(batch_size, seq_len_kv, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0) |
|
|
q = q.reshape(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
else: |
|
|
if key_value_states is None: |
|
|
key_value_states = query_states |
|
|
qkv = self.qkv(query_states) |
|
|
q, k, v = qkv.reshape(batch_size, seq_len_q, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0) |
|
|
|
|
|
|
|
|
attn_impl = getattr(self.config, '_attn_implementation', 'sdpa') |
|
|
attn_impl = 'sdpa' |
|
|
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[attn_impl] |
|
|
|
|
|
|
|
|
if attn_impl == "flash_attention_2": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if cu_seqlens is None: |
|
|
|
|
|
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device) |
|
|
if kv_cu_seqlens is None: |
|
|
cu_seqlens_k = torch.arange(0, (batch_size + 1) * k.shape[2], step=k.shape[2], dtype=torch.int32, device=k.device) |
|
|
else: |
|
|
cu_seqlens_k = kv_cu_seqlens |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
q_ = q.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim) |
|
|
k_ = k.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim) |
|
|
v_ = v.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim) |
|
|
|
|
|
max_seqlen_q = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() |
|
|
max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() |
|
|
|
|
|
attn_output, _ = attention_interface( |
|
|
self, |
|
|
q_, |
|
|
k_, |
|
|
v_, |
|
|
attention_mask=None, |
|
|
scaling=self.scaling, |
|
|
dropout=0.0 if not self.training else self.attention_dropout, |
|
|
cu_seq_lens_q=cu_seqlens, |
|
|
cu_seq_lens_k=cu_seqlens_k, |
|
|
max_length_q=max_seqlen_q, |
|
|
max_length_k=max_seqlen_k, |
|
|
is_causal=self.is_causal, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
attn_output = attn_output.view(batch_size, seq_len_q, self.num_heads, self.head_dim).contiguous() |
|
|
else: |
|
|
|
|
|
attn_output, _ = attention_interface( |
|
|
self, |
|
|
q, k, v, |
|
|
attention_mask=attention_mask, |
|
|
scaling=self.scaling, |
|
|
dropout=0.0 if not self.training else self.attention_dropout, |
|
|
is_causal=self.is_causal, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
|
|
|
attn_output = attn_output.reshape(batch_size, seq_len_q, self.dim) |
|
|
attn_output = self.proj(attn_output) |
|
|
if orig_2d: |
|
|
attn_output = attn_output.squeeze(0) |
|
|
return attn_output.contiguous() |
|
|
|
|
|
|
|
|
class ADCopilotCompareVisualEncoder(nn.Module):
    """Summarise each image, relative to the image before it, into ``token_size`` compare tokens.

    Pipeline per call:

    1. Pad the per-image token sequences to a common length and build boolean padding
       masks (``True`` = real token).
    2. Pair image ``i`` with image ``i - 1`` via ``torch.roll``. When
       ``sequence_compare`` is enabled and there is more than one image, the first
       image is paired with itself instead of wrapping around to the last image.
    3. Encoder: two pre-norm cross-attention + MLP sub-blocks exchange information
       between previous and current features.
    4. Decoder: ``token_size`` learnable queries cross-attend to the encoded features.
    5. ``compare_projector`` maps ``hidden_size`` -> ``out_hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.sequence_compare = getattr(config, "sequence_compare", True)
        self.hidden_size = config.hidden_size
        # getattr works for any config object. The previous `"x" in config` relied on
        # the config being iterable over its attribute names.
        self.token_size = getattr(config, "compare_token_size", 100)

        # Encoder: previous <-> current cross-attention, each followed by an MLP.
        self.encoder_cross_attn1 = OptimizedCrossAttention(config, is_cross_attention=True)
        self.encoder_cross_attn2 = OptimizedCrossAttention(config, is_cross_attention=True)
        self.encoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm3 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm4 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_mlp1 = Qwen2_5_VLMLP(config)
        self.encoder_mlp2 = Qwen2_5_VLMLP(config)

        # Decoder: learnable queries that read out the fixed-size compare tokens.
        self.query_embeddings = nn.Parameter(torch.empty(self.token_size, self.hidden_size))
        self.decoder_cross_attn = OptimizedCrossAttention(config, is_cross_attention=True)
        self.decoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.decoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.decoder_mlp = Qwen2_5_VLMLP(config)

        self.compare_projector = nn.Linear(config.hidden_size, config.out_hidden_size)

        # torch.empty leaves the queries uninitialised (arbitrary memory, possibly NaN).
        # Initialise them here so a directly constructed module is usable. Loading a
        # state dict that contains `query_embeddings` still replaces these values.
        self.init_query_embeddings()

    def init_query_embeddings(self):
        """Initialise the decoder queries from N(0, 0.02**2)."""
        nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)

    def forward(self, images_hidden_states: list) -> torch.Tensor:
        """
        Args:
            images_hidden_states: Sequence of tensors, one per image, each of shape
                ``[seq_len_i, hidden_size]``.

        Returns:
            Tensor of shape ``[num_images, token_size, out_hidden_size]``.
        """
        if not images_hidden_states:
            # Match the non-empty output: projector width, and the parameters'
            # device and dtype.
            return self.query_embeddings.new_empty(0, self.token_size, self.compare_projector.out_features)

        if torch.isnan(self.query_embeddings).any():
            print("警告:query_embeddings 包含 NaN 值")

        batch_size = len(images_hidden_states)
        device = images_hidden_states[0].device

        # Zero-pad to the longest image. The mask is True on real (non-padding) tokens.
        lengths = torch.tensor([state.size(0) for state in images_hidden_states], device=device)
        batched_states = torch.nn.utils.rnn.pad_sequence(list(images_hidden_states), batch_first=True)
        max_seq_len = batched_states.size(1)
        attention_masks = torch.arange(max_seq_len, device=device).unsqueeze(0) < lengths.unsqueeze(1)

        # previous_states[i] is image i-1; image 0 wraps around to the last image.
        # NOTE(review): the roll spans every image in this call, which may come from
        # different samples of a batch -- confirm that cross-sample pairing is intended.
        previous_states = torch.roll(batched_states, shifts=1, dims=0)
        previous_masks = torch.roll(attention_masks, shifts=1, dims=0)
        if batch_size > 1 and self.sequence_compare:
            # The first image has no predecessor, so it is compared with itself
            # (previous_states[1] is image 0).
            previous_states[0] = previous_states[1]
            previous_masks[0] = previous_masks[1]

        encoded_features = self._encoder_forward(batched_states, previous_states, attention_masks, previous_masks)

        batch_queries = self.query_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        query_mask = torch.ones(batch_size, self.token_size, dtype=torch.bool, device=device)
        compare_visual_embeds = self._decoder_forward(batch_queries, encoded_features, query_mask, attention_masks)

        # nn.Linear acts on the last dimension, so no flatten/unflatten round trip is needed.
        return self.compare_projector(compare_visual_embeds)

    def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
        """Encoder: bidirectional interaction between current and previous image features.

        Args:
            current_features: [batch_size, seq_len, hidden_size]
            previous_features: [batch_size, seq_len, hidden_size]
            current_mask: [batch_size, seq_len] boolean, True on real tokens.
            previous_mask: [batch_size, seq_len] boolean, True on real tokens.

        Returns:
            Updated current features, [batch_size, seq_len, hidden_size].
        """
        # Sub-block 1: previous features attend to current features, then an MLP.
        residual = previous_features
        previous_normed = self.encoder_norm1(previous_features)
        current_normed1 = self.encoder_norm1(current_features)
        cross_attn_output1 = self.encoder_cross_attn1(
            query_states=previous_normed,
            key_value_states=current_normed1,
            attention_mask=current_mask.unsqueeze(1).unsqueeze(2) if current_mask is not None else None,
        )
        previous_features = residual + cross_attn_output1

        residual = previous_features
        mlp_output1 = self.encoder_mlp1(self.encoder_norm2(previous_features))
        previous_features = residual + mlp_output1

        # Sub-block 2: current features attend to the updated previous features, then an MLP.
        residual = current_features
        current_normed2 = self.encoder_norm3(current_features)
        previous_normed2 = self.encoder_norm3(previous_features)
        cross_attn_output2 = self.encoder_cross_attn2(
            query_states=current_normed2,
            key_value_states=previous_normed2,
            attention_mask=previous_mask.unsqueeze(1).unsqueeze(2) if previous_mask is not None else None,
        )
        current_features = residual + cross_attn_output2

        residual = current_features
        mlp_output2 = self.encoder_mlp2(self.encoder_norm4(current_features))
        # NOTE: unlike every other residual in this module, the MLP output is
        # *subtracted* here. The MLP presumably ends in a learned linear layer, so the
        # sign is only a parameterisation choice. It is kept so existing weights stay
        # valid.
        current_features = residual - mlp_output2
        return current_features

    def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
        """Decoder: learnable queries read from the encoded image features.

        Args:
            queries: [batch_size, token_size, hidden_size]
            encoded_features: [batch_size, seq_len, hidden_size]
            query_mask: [batch_size, token_size] (accepted but not used).
            encoded_mask: [batch_size, seq_len] boolean, True on real tokens.

        Returns:
            [batch_size, token_size, hidden_size]
        """
        residual = queries
        queries_normed = self.decoder_norm1(queries)
        encoded_normed = self.decoder_norm1(encoded_features)
        cross_attn_output = self.decoder_cross_attn(
            query_states=queries_normed,
            key_value_states=encoded_normed,
            attention_mask=encoded_mask.unsqueeze(1).unsqueeze(2) if encoded_mask is not None else None,
        )
        queries = residual + cross_attn_output

        residual = queries
        mlp_output = self.decoder_mlp(self.decoder_norm2(queries))
        queries = residual + mlp_output
        return queries
|
|
|
|
|
|
|
|
|
|
|
class ADCopilotVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrainedModel):
    """Qwen2.5-VL vision tower that also returns compare-token embeddings.

    ``forward`` runs patch embedding, rotary position embedding and the windowed /
    full-attention transformer blocks. Before the patch ``merger``, it splits the
    hidden states per ``grid_thw`` row and feeds the pieces to
    ``ADCopilotCompareVisualEncoder``.
    """

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        # Extra module that turns per-image hidden states into compare tokens.
        self.compare_visual_encoder = ADCopilotCompareVisualEncoder(config)

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to ``self.patch_embed``. ``ADCopilotVLModel.get_image_features``
                passes its ``pixel_values`` here.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.
            **kwargs: Forwarded unchanged to every transformer block.

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`: ``(hidden_states, compare_visual_embeds)``.
            The first element is the merger output, restored from window order with
            ``argsort(window_index)``. The second is the compare-encoder output, with
            one entry per ``grid_thw`` row.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        # int32 boundaries normally; follow grid_thw's dtype while JIT tracing.
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop repeated boundaries.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder tokens (in groups of `spatial_merge_unit`) into window order, and
        # apply the same permutation to the rotary embeddings.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Cumulative boundaries of each frame (h * w tokens, repeated t times per item),
        # with a leading 0.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        # Layers in `fullatt_block_indexes` attend within whole frames; the rest
        # attend within windows.
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # One chunk of t*h*w pre-merger tokens per grid_thw row.
        # NOTE(review): tokens are still in window order here (the inverse permutation
        # is applied only after `merger`). This split assumes `get_window_index` keeps
        # each item's units contiguous -- confirm.
        split_sizes = grid_thw.prod(-1).tolist()
        splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)

        compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)

        # Merge patches, then undo the window permutation.
        hidden_states = self.merger(hidden_states)
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]

        return hidden_states, compare_visual_embeds
|
|
|
|
|
class ADCopilotVLModel(Qwen2_5_VLModel):
    """Qwen2.5-VL backbone in which every image also carries compare tokens.

    Image features are the merged patch embeddings followed by ``compare_token_size``
    compare-encoder embeddings. mRoPE position ids are computed to match that layout.
    """

    def __init__(self, config):
        super().__init__(config)
        # Swap in a vision tower that returns (embeddings, compare_embeddings).
        self.visual = ADCopilotVisionTransformerPretrainedModel._from_config(config.vision_config)
        self.compare_token_size = config.vision_config.compare_token_size

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """
        Encodes images into continuous embeddings that can be forwarded to the language model.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            list of tensors, one per image, each shaped
            ``[t*h*w // spatial_merge_size**2 + compare_token_size, out_hidden_size]``:
            the merged patch embeddings followed by the image's compare embeddings.
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds, compare_visual_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)

        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)

        enhanced_image_embeds = []
        for i, embeds in enumerate(image_embeds):
            compare_embed = compare_visual_embeds[i].to(device=embeds.device, dtype=embeds.dtype)
            # Compare tokens come right after the image's own tokens, matching the
            # processor's placeholder expansion.
            enhanced_image_embeds.append(torch.cat([embeds, compare_embed], dim=0))
        return enhanced_image_embeds

    def get_video_features(
        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
    ):
        """
        Encodes videos into continuous embeddings that can be forwarded to the language model.

        ``self.visual`` returns ``(embeddings, compare_embeddings)``, so the inherited
        single-tensor implementation cannot be used. Videos get no compare tokens from
        the processor, so the compare embeddings are discarded.

        Args:
            pixel_values_videos (`torch.FloatTensor`):
                The tensors corresponding to the input videos.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.

        Returns:
            tuple of tensors, one per video, each with ``t*h*w // spatial_merge_size**2`` rows.
        """
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_embeds, _ = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        return torch.split(video_embeds, split_sizes)

    def get_rope_index(self, input_ids: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to :meth:`get_rope_index_with_compare_token`."""
        return self.get_rope_index_with_compare_token(input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask)

    def get_rope_index_with_compare_token(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute 3-axis (temporal, height, width) mRoPE position ids including compare tokens.

        Per sample, only positions where ``attention_mask == 1`` are filled:

        * text runs get identical, consecutive ids on all three axes;
        * an image/video grid gets ``(t, h, w)`` ids offset by the running start index;
        * after each *image* grid, ``compare_token_size`` compare tokens all get the id
          ``t_index[-1] + offset`` on every axis.

        Returns:
            ``(position_ids, mrope_position_deltas)`` shaped ``(3, batch, seq_len)`` and
            ``(batch, 1)``. The delta is the offset that continues positions after the
            last prompt token during incremental decoding.
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                # The token right after each vision-start token tells image from video.
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        second_per_grid_t = 0
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        if second_per_grid_ts is not None:
                            second_per_grid_t = second_per_grid_ts[video_index]
                        else:
                            second_per_grid_t = 1.0
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    # Text preceding this vision item.
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    # Temporal ids scale with seconds-per-grid and tokens_per_second
                    # (always 0 for images, where second_per_grid_t == 0).
                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
                    second_per_grid_t = torch.as_tensor(
                        second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
                    )
                    time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
                    time_tensor_long = time_tensor.long()
                    t_index = time_tensor_long.flatten()

                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
                    if ed_image < ed_video:
                        # Compare tokens (images only) share one id on all three axes.
                        compare_t_index = t_index[-1].repeat(self.compare_token_size)
                        compare_h_index = compare_t_index
                        compare_w_index = compare_t_index
                        llm_pos_ids_list.append(torch.stack([compare_t_index, compare_h_index, compare_w_index]) + text_len + st_idx)
                        st = st + self.compare_token_size
                        # NOTE(review): the next chunk starts at this chunk's max + 1, which
                        # can fall inside the preceding image grid's id range. It is kept
                        # as-is to match the layout the model was trained with.

                # Trailing text after the last vision item.
                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                # Continue decoding from the LAST token's position, which is what the
                # teacher-forced layout above gives the next token. With compare tokens
                # the global max can belong to an earlier image grid, so using
                # llm_positions.max() would make generated positions jump.
                mrope_position_deltas.append(llm_positions[:, -1].max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            # Text-only: plain 1-D positions broadcast to the three axes.
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas
|
|
|
|
|
class ADCopilotVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL conditional-generation model whose backbone is ``ADCopilotVLModel``."""

    config_class = ADCopilotConfig

    def __init__(self, config):
        super().__init__(config)
        # Replace the backbone with the compare-token variant.
        # NOTE(review): the parent constructor presumably builds its own backbone first,
        # which is then discarded here (extra construction time/memory) -- confirm this
        # is acceptable.
        self.model = ADCopilotVLModel(config)