from functools import partial

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from .attention import FlashAttentionRope
from .block import BlockRope
from ..dinov2.layers import Mlp


class TransformerDecoder(nn.Module):
    """Stack of RoPE transformer blocks with an input projection and a linear output head."""

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()
        # Project incoming tokens to the decoder width (or pass them through unchanged).
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            BlockRope(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                # attn_class=MemEffAttentionRope,
                attn_class=FlashAttentionRope,
                rope=rope,
            ) for _ in range(depth)])
        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None):
        hidden = self.projects(hidden)
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Gradient checkpointing: recompute activations in backward to save memory.
                hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, xpos=xpos)
        out = self.linear_out(hidden)
        return out
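# Usage sketch (illustrative, not from the original file): assuming `rope=None` is a
# valid configuration for BlockRope/FlashAttentionRope, the decoder maps
# (B, S, in_dim) tokens to (B, S, out_dim), e.g.:
#
#   decoder = TransformerDecoder(in_dim=1024, out_dim=768, dec_embed_dim=512, depth=5)
#   out = decoder(torch.randn(2, 196, 1024))   # -> (2, 196, 768)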
class LinearPts3d(nn.Module):
    """
    Linear head for dust3r.
    Each token outputs a patch_size x patch_size block of 3D points (+ confidence).
    """

    def __init__(self, patch_size, dec_embed_dim, output_dim=3):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(dec_embed_dim, output_dim * self.patch_size**2)

    def forward(self, decout, img_shape):
        H, W = img_shape
        tokens = decout[-1]
        B, S, D = tokens.shape

        # Predict a flattened patch of 3D points per token.
        feat = self.proj(tokens)  # (B, S, output_dim * patch_size**2)
        feat = feat.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
        feat = F.pixel_shuffle(feat, self.patch_size)  # (B, output_dim, H, W)

        # Return in channels-last layout: (B, H, W, output_dim).
        return feat.permute(0, 2, 3, 1)
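if __name__ == "__main__":
    # Minimal smoke test for LinearPts3d on its own (illustrative shapes, assuming a
    # 224x224 image with 16x16 patches, i.e. 14*14 = 196 tokens).
    import torch

    head = LinearPts3d(patch_size=16, dec_embed_dim=768, output_dim=3)
    tokens = torch.randn(2, 196, 768)            # (B, S, dec_embed_dim)
    pts = head([tokens], img_shape=(224, 224))   # decout is a list; only the last entry is used
    print(pts.shape)                             # torch.Size([2, 224, 224, 3])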