# AD-Copilot / modeling_ad_copilot.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Any, Callable, Optional, Union
from transformers import Qwen2_5_VLForConditionalGeneration, AutoModelForImageTextToText
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLModel,
Qwen2RMSNorm,
Qwen2_5_VLMLP,
ALL_ATTENTION_FUNCTIONS
)
from transformers.image_utils import ImageInput
from transformers.tokenization_utils import TextInput, PreTokenizedInput
from transformers.video_utils import VideoInput
from transformers.feature_extraction_utils import BatchFeature
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLConfig
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessorKwargs
class ADCopilotConfig(Qwen2_5_VLConfig):
model_type = "ad_copilot"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.vision_config.compare_token_size = 100
self.architectures = ["ADCopilotVLForConditionalGeneration"]
self.sequence_compare = True
class ADCopilotProcessor(Qwen2_5_VLProcessor):
config_class = ADCopilotConfig
def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs)
self.compare_token_size = 100 if "compare_token_size" not in kwargs else kwargs["compare_token_size"]
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
videos: VideoInput = None,
**kwargs,
) -> BatchFeature:
"""
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `list[str]`, `list[list[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Qwen2_5_VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
image_inputs = videos_inputs = {}
if images is not None:
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
image_grid_thw = image_inputs["image_grid_thw"]
if videos is not None:
fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
if isinstance(fps, (int, float)):
second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
else:
raise ValueError(
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
)
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
if not isinstance(text, list):
text = [text]
text = text.copy() # below lines change text in-place
if images is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while self.image_token in text[i]:
num_image_tokens = image_grid_thw[index].prod() // merge_length
# text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens), 1)
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens + self.compare_token_size), 1)
index += 1
text[i] = text[i].replace("<|placeholder|>", self.image_token)
if videos is not None:
merge_length = self.video_processor.merge_size**2
index = 0
for i in range(len(text)):
while self.video_token in text[i]:
num_video_tokens = video_grid_thw[index].prod() // merge_length
text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
index += 1
text[i] = text[i].replace("<|placeholder|>", self.video_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
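# Note on placeholder expansion (illustrative arithmetic, not executed code): compared to the
# stock Qwen2.5-VL processor, each image placeholder is expanded by an extra
# `compare_token_size` tokens so the compare embeddings produced by the vision tower get
# matching slots in the text sequence. For example, with the default compare_token_size=100,
# an image whose grid_thw is (1, 28, 28) and merge_size is 2 yields
# 1 * 28 * 28 // 2**2 = 196 regular image tokens plus 100 compare tokens, i.e. 296
# `<|image_pad|>` tokens in the tokenized prompt.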
class OptimizedCrossAttention(nn.Module):
"""
    Optimized cross attention modeled on the Qwen2_5_VLVisionAttention structure.
"""
def __init__(self, config, is_cross_attention=True):
super().__init__()
self.config = config
self.dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.dim // self.num_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = 0.0
        self.is_causal = False  # cross attention does not need a causal mask
self.is_cross_attention = is_cross_attention
        if is_cross_attention:
            # Cross attention: Q comes from one sequence, K and V from another
            self.q_proj = nn.Linear(self.dim, self.dim, bias=True)
            self.kv = nn.Linear(self.dim, self.dim * 2, bias=True)  # fused K/V projection
        else:
            # Self attention: Q, K and V come from the same sequence
            self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)  # fused Q/K/V projection
self.proj = nn.Linear(self.dim, self.dim, bias=True)
def forward(
self,
query_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,  # used only by FlashAttention-2
        kv_cu_seqlens: Optional[torch.Tensor] = None,  # used only by FlashAttention-2
**kwargs,
) -> torch.Tensor:
        # Accept query_states as [B, T, d] or [T, d]; automatically add the batch dimension if needed
orig_2d = False
if query_states.dim() == 2:
query_states = query_states.unsqueeze(0)
orig_2d = True
batch_size, seq_len_q, _ = query_states.shape
        # Q/K/V projections
if self.is_cross_attention and key_value_states is not None:
if key_value_states.dim() == 2:
key_value_states = key_value_states.unsqueeze(0)
q = self.q_proj(query_states)
kv = self.kv(key_value_states)
seq_len_kv = kv.shape[1]
k, v = kv.reshape(batch_size, seq_len_kv, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
q = q.reshape(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)
else:
if key_value_states is None:
key_value_states = query_states
qkv = self.qkv(query_states)
q, k, v = qkv.reshape(batch_size, seq_len_q, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
        # Choose the attention kernel
        attn_impl = getattr(self.config, '_attn_implementation', 'sdpa')
        attn_impl = 'sdpa'  # hard-coded to SDPA, so the FlashAttention-2 branch below is currently not taken
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[attn_impl]
        # ========= FlashAttention-2 support ==========
        if attn_impl == "flash_attention_2":
            # Qwen2.5-VL supports FA2 by flattening the sequences and providing cu_seqlens
            # Here query_states/key_value_states are assumed to be variable-length across the batch
            # Use the provided cu_seqlens if available, otherwise generate them automatically
            if cu_seqlens is None:
                # By default treat every batch element as having length seq_len_q
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)
if kv_cu_seqlens is None:
cu_seqlens_k = torch.arange(0, (batch_size + 1) * k.shape[2], step=k.shape[2], dtype=torch.int32, device=k.device)
else:
cu_seqlens_k = kv_cu_seqlens
            # flatten [B, nH, T, d] -> [total_T, nH, d]
            # Note: FlashAttention-2 expects (total, nH, d), not (nH, total, d), unlike the other implementations
            # Safer flatten path:
            # [B, nH, T, d] -> [B, T, nH, d] -> [total_T, nH, d]
q_ = q.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim)
k_ = k.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim)
v_ = v.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim)
max_seqlen_q = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
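            # Worked example (assumed shapes, for illustration only): with batch_size=2,
            # seq_len_q=4 and key length 6, the auto-generated cu_seqlens is
            # tensor([0, 4, 8]) and cu_seqlens_k is tensor([0, 6, 12]); q_ then has shape
            # [8, num_heads, head_dim] and k_/v_ have shape [12, num_heads, head_dim].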
attn_output, _ = attention_interface(
self,
q_,
k_,
v_,
attention_mask=None,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
cu_seq_lens_q=cu_seqlens,
cu_seq_lens_k=cu_seqlens_k,
max_length_q=max_seqlen_q,
max_length_k=max_seqlen_k,
is_causal=self.is_causal,
**kwargs,
)
            # Reshape the flattened output back to a batched layout
            # [total_q, nH, d] -> [B, seq_len_q, nH, d]
attn_output = attn_output.view(batch_size, seq_len_q, self.num_heads, self.head_dim).contiguous()
        else:
            # Regular (non-FA2) path: q/k/v are passed in the [B, nH, T, d] layout
attn_output, _ = attention_interface(
self,
q, k, v,
attention_mask=attention_mask,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
is_causal=self.is_causal,
**kwargs,
)
# attn_output: [B, nH, seq_q, d]
attn_output = attn_output.transpose(1, 2).contiguous() # [B, seq_q, nH, d]
attn_output = attn_output.reshape(batch_size, seq_len_q, self.dim) # [B, seq_q, D]
attn_output = self.proj(attn_output)
if orig_2d:
attn_output = attn_output.squeeze(0)
return attn_output.contiguous()
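# Minimal usage sketch for OptimizedCrossAttention (illustrative only, not part of the model).
# `SimpleNamespace` stands in for the vision config; only the fields the module actually reads
# (hidden_size, num_heads, _attn_implementation) are provided, with made-up values.
def _cross_attention_usage_sketch():
    from types import SimpleNamespace
    cfg = SimpleNamespace(hidden_size=64, num_heads=4, _attn_implementation="sdpa")
    attn = OptimizedCrossAttention(cfg, is_cross_attention=True)
    queries = torch.randn(2, 100, 64)      # [batch, query_len, hidden_size]
    key_values = torch.randn(2, 256, 64)   # [batch, kv_len, hidden_size]
    out = attn(queries, key_value_states=key_values)
    assert out.shape == (2, 100, 64)       # output keeps the query sequence shape
    return out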
class ADCopilotCompareVisualEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.sequence_compare = getattr(config, "sequence_compare", True)
self.hidden_size = config.hidden_size
# self.token_size = 100 * (config.spatial_merge_size**2) if "compare_token_size" not in config else config.compare_token_size * (config.spatial_merge_size**2)
self.token_size = 100 if "compare_token_size" not in config else config.compare_token_size
        # Encoder: bidirectional interaction between image features
        # First cross attention: previous attends to current
        self.encoder_cross_attn1 = OptimizedCrossAttention(config, is_cross_attention=True)
        # Second cross attention: current attends to previous
        self.encoder_cross_attn2 = OptimizedCrossAttention(config, is_cross_attention=True)
self.encoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
self.encoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
self.encoder_norm3 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
self.encoder_norm4 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
self.encoder_mlp1 = Qwen2_5_VLMLP(config)
self.encoder_mlp2 = Qwen2_5_VLMLP(config)
        # Decoder: queries interact with the encoded features
        # Learnable query embeddings
self.query_embeddings = nn.Parameter(
torch.empty(self.token_size, self.hidden_size)
)
        # Keep only a cross attention for the queries to attend to the encoded features
self.decoder_cross_attn = OptimizedCrossAttention(config, is_cross_attention=True)
self.decoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
self.decoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
self.decoder_mlp = Qwen2_5_VLMLP(config)
self.compare_projector = nn.Linear(config.hidden_size, config.out_hidden_size)
def init_query_embeddings(self):
nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)
def forward(self, images_hidden_states: list) -> torch.Tensor:
"""
Args:
images_hidden_states: List of tensor, each tensor has shape [seq_len, hidden_size]
Returns:
Tensor of shape [total_images, token_size, hidden_size]
"""
if not images_hidden_states:
return torch.empty(0, self.token_size, self.hidden_size)
        # Check whether query_embeddings contains NaN values
        if torch.isnan(self.query_embeddings).any():
            print("Warning: query_embeddings contains NaN values")
            # nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)
        # Sequence length of each image
seq_lengths = [state.size(0) for state in images_hidden_states]
max_seq_len = max(seq_lengths)
batch_size = len(images_hidden_states)
device = images_hidden_states[0].device
dtype = images_hidden_states[0].dtype
        # Pad all images to the same length and stack them
padded_states = []
attention_masks = []
for state in images_hidden_states:
pad_len = max_seq_len - state.size(0)
if pad_len > 0:
                # Pad the sequence with zeros
                padded_state = F.pad(state, (0, 0, 0, pad_len), mode='constant', value=0)
                # Build the attention mask (False marks padded positions)
attention_mask = torch.ones(max_seq_len, dtype=torch.bool, device=device)
attention_mask[state.size(0):] = False
else:
padded_state = state
attention_mask = torch.ones(max_seq_len, dtype=torch.bool, device=device)
padded_states.append(padded_state)
attention_masks.append(attention_mask)
# [batch_size, max_seq_len, hidden_size]
batched_states = torch.stack(padded_states)
# [batch_size, max_seq_len]
attention_masks = torch.stack(attention_masks)
        # Build circularly shifted states to compare against
        # For the first image, use the image itself as the "previous" frame
previous_states = torch.roll(batched_states, shifts=1, dims=0)
previous_masks = torch.roll(attention_masks, shifts=1, dims=0)
if previous_states.size(0) > 1 and self.sequence_compare:
previous_states[0] = previous_states[1]
previous_masks[0] = previous_masks[1]
        # Encoder: process all images in a single batch
encoded_features = self._encoder_forward(
batched_states, # [batch_size, max_seq_len, hidden_size]
previous_states, # [batch_size, max_seq_len, hidden_size]
attention_masks, # [batch_size, max_seq_len]
previous_masks # [batch_size, max_seq_len]
)
        # Decoder: process all images in a single batch
        # Expand query_embeddings along the batch dimension
batch_queries = self.query_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
# [batch_size, token_size, hidden_size]
compare_visual_embeds = self._decoder_forward(
batch_queries,
encoded_features,
            torch.ones(batch_size, self.token_size, dtype=torch.bool, device=device),  # query mask
            attention_masks  # mask for the encoded features
)
        # Record the per-sample token count
batch_size = compare_visual_embeds.size(0)
token_size = compare_visual_embeds.size(1)
        # Flatten all batch elements together
# [batch_size * token_size, hidden_size]
flattened_embeds = compare_visual_embeds.view(-1, compare_visual_embeds.size(-1))
merged = self.compare_projector(flattened_embeds) # [batch_size * token_size, merged_hidden_size]
merged_token_size = token_size
# [batch_size, merged_token_size, merged_hidden_size]
compare_visual_embeds = merged.view(batch_size, merged_token_size, -1)
return compare_visual_embeds # [batch_size, token_size, out_hidden_size]
def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
"""
        Encoder: bidirectional interaction between image features
Args:
current_features: [batch_size, seq_len, hidden_size]
previous_features: [batch_size, seq_len, hidden_size]
current_mask: [batch_size, seq_len]
previous_mask: [batch_size, seq_len]
"""
        # Step 1: previous attends to current
residual = previous_features
# Layer norm
previous_normed = self.encoder_norm1(previous_features)
current_normed1 = self.encoder_norm1(current_features)
# Cross attention: previous attend to current
cross_attn_output1 = self.encoder_cross_attn1(
query_states=previous_normed,
key_value_states=current_normed1,
attention_mask=current_mask.unsqueeze(1).unsqueeze(2) if current_mask is not None else None
)
# Residual connection
previous_features = residual + cross_attn_output1
# MLP for previous features
residual = previous_features
mlp_input1 = self.encoder_norm2(previous_features)
mlp_output1 = self.encoder_mlp1(mlp_input1)
previous_features = residual + mlp_output1
        # Step 2: current attends to previous (enhanced)
residual = current_features
# Layer norm
current_normed2 = self.encoder_norm3(current_features)
previous_normed2 = self.encoder_norm3(previous_features)
# Cross attention: current attend to previous
cross_attn_output2 = self.encoder_cross_attn2(
query_states=current_normed2,
key_value_states=previous_normed2,
attention_mask=previous_mask.unsqueeze(1).unsqueeze(2) if previous_mask is not None else None
)
# Residual connection
current_features = residual + cross_attn_output2
# MLP for current features
residual = current_features
mlp_input2 = self.encoder_norm4(current_features)
mlp_output2 = self.encoder_mlp2(mlp_input2)
        # current_features = residual + mlp_output2
        # Changed from a residual addition to a subtraction
        current_features = residual - mlp_output2
return current_features
def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
"""
        Decoder: queries interact with the encoded features
Args:
queries: [batch_size, token_size, hidden_size]
encoded_features: [batch_size, seq_len, hidden_size]
query_mask: [batch_size, token_size]
encoded_mask: [batch_size, seq_len]
"""
# Cross attention: queries attend to encoded features
residual = queries
queries_normed = self.decoder_norm1(queries)
encoded_normed = self.decoder_norm1(encoded_features)
cross_attn_output = self.decoder_cross_attn(
query_states=queries_normed,
key_value_states=encoded_normed,
attention_mask=encoded_mask.unsqueeze(1).unsqueeze(2) if encoded_mask is not None else None
)
queries = residual + cross_attn_output
# MLP
residual = queries
mlp_input = self.decoder_norm2(queries)
mlp_output = self.decoder_mlp(mlp_input)
queries = residual + mlp_output
return queries # [batch_size, token_size, hidden_size]
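# Minimal usage sketch for ADCopilotCompareVisualEncoder (illustrative only, not part of the
# model). The tiny dict-backed config below is a stand-in for the vision config: it supports
# both attribute access and the `"compare_token_size" not in config` membership check used in
# __init__; all field names and values here are made up for the example.
class _SketchVisionConfig(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)
def _compare_encoder_usage_sketch():
    cfg = _SketchVisionConfig(
        hidden_size=64, num_heads=4, intermediate_size=128, hidden_act="silu",
        out_hidden_size=96, compare_token_size=16, sequence_compare=True,
        _attn_implementation="sdpa",
    )
    encoder = ADCopilotCompareVisualEncoder(cfg)
    encoder.init_query_embeddings()
    # Two "images" with different patch counts, as produced by the vision blocks.
    features = [torch.randn(256, 64), torch.randn(196, 64)]
    out = encoder(features)
    assert out.shape == (2, 16, 96)  # [num_images, compare_token_size, out_hidden_size]
    return out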
# Subclass the Qwen2.5-VL components first so they are easier to modify
class ADCopilotVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrainedModel):
def __init__(self, config, *inputs, **kwargs) -> None:
super().__init__(config, *inputs, **kwargs)
self.compare_visual_encoder = ADCopilotCompareVisualEncoder(config)
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
The final hidden states of the model.
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
The temporal, height and width of feature shape of each image in LLM.
        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: the merged vision hidden states and the compare visual embeddings.
"""
hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
seq_len, _ = hidden_states.size()
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
hidden_states = hidden_states[window_index, :, :]
hidden_states = hidden_states.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0,
# Select dtype based on the following factors:
# - FA2 requires that cu_seqlens_q must have dtype int32
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
# See https://github.com/huggingface/transformers/pull/34852 for more information
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for layer_num, blk in enumerate(self.blocks):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
position_embeddings=position_embeddings,
**kwargs,
)
split_sizes = grid_thw.prod(-1).tolist()
splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
        # [total_images, compare_token_size, out_hidden_size]
compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
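        # Shape note (descriptive comment): hidden_states is the merged patch sequence of shape
        # [seq_len // spatial_merge_unit, out_hidden_size], while compare_visual_embeds keeps the
        # [num_images, compare_token_size, out_hidden_size] shape produced by the compare encoder.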
return hidden_states, compare_visual_embeds
class ADCopilotVLModel(Qwen2_5_VLModel):
def __init__(self, config):
super().__init__(config)
self.visual = ADCopilotVisionTransformerPretrainedModel._from_config(config.vision_config)
self.compare_token_size = config.vision_config.compare_token_size
        # self.learnable_image_embeddings = nn.Parameter(
        #     torch.randn(100, config.hidden_size) * 0.02  # small initialization values
        # )
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
"""
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds, compare_visual_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        # Comparison-aware tokens are appended to every image
split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
image_embeds = torch.split(image_embeds, split_sizes)
        # Concatenate the image embeddings with the compare visual embeddings
enhanced_image_embeds = []
for i, embeds in enumerate(image_embeds):
            # Make sure compare_visual_embeds[i] matches embeds in device and dtype
compare_embed = compare_visual_embeds[i].to(device=embeds.device, dtype=embeds.dtype)
enhanced_embeds = torch.cat([embeds, compare_embed], dim=0)
enhanced_image_embeds.append(enhanced_embeds)
# image_embeds = torch.cat(enhanced_image_embeds, dim=0)
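        # Return note (descriptive comment): unlike the stock Qwen2.5-VL implementation, this
        # returns a list with one tensor per image, each of shape
        # [num_merged_patches_i + compare_token_size, out_hidden_size].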
return enhanced_image_embeds
def get_rope_index(self, input_ids: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
return self.get_rope_index_with_compare_token(input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask)
def get_rope_index_with_compare_token(
self,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
spatial_merge_size = self.config.vision_config.spatial_merge_size
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_index, video_index = 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for vision_index in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
second_per_grid_t = 0
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
## normalize type, send to device.
second_per_grid_t = torch.as_tensor(
second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
)
time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if ed_image < ed_video:
                        # The current block is an image: insert positions for compare_token_size image-comparison tokens
compare_t_index = t_index[-1].repeat(self.compare_token_size)
# compare_h_index = torch.arange(self.compare_token_size)
# compare_w_index = torch.arange(self.compare_token_size)
compare_h_index = compare_t_index
compare_w_index = compare_t_index
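                        # Note: for images second_per_grid_t is 0, so t_index is all zeros and every
                        # compare token ends up sharing the single 3D position (st_idx + text_len)
                        # in all three (temporal, height, width) rope dimensions.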
llm_pos_ids_list.append(torch.stack([compare_t_index, compare_h_index, compare_w_index]) + text_len + st_idx)
st = st + self.compare_token_size
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
class ADCopilotVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
config_class = ADCopilotConfig
def __init__(self, config):
super().__init__(config)
self.model = ADCopilotVLModel(config)
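

# End-to-end usage sketch (illustrative only). The checkpoint path below is a placeholder for a
# directory that contains this architecture's config, processor files and weights; the prompt
# format follows the standard Qwen2.5-VL vision tags.
if __name__ == "__main__":
    from PIL import Image

    checkpoint = "path/to/ad-copilot-checkpoint"  # placeholder, replace with a real checkpoint
    processor = ADCopilotProcessor.from_pretrained(checkpoint)
    model = ADCopilotVLForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

    # Two consecutive frames to compare; each <|image_pad|> is expanded with the extra compare tokens.
    frames = [Image.new("RGB", (448, 448)), Image.new("RGB", (448, 448))]
    prompt = "<|vision_start|><|image_pad|><|vision_end|>" * len(frames) + "What changed between the frames?"
    inputs = processor(images=frames, text=[prompt], return_tensors="pt")

    generated = model.generate(**inputs, max_new_tokens=64)
    print(processor.batch_decode(generated, skip_special_tokens=True))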