Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| BSD 3-Clause License | |
| Copyright (c) Soumith Chintala 2016, | |
| All rights reserved. | |
| Redistribution and use in source and binary forms, with or without | |
| modification, are permitted provided that the following conditions are met: | |
| * Redistributions of source code must retain the above copyright notice, this | |
| list of conditions and the following disclaimer. | |
| * Redistributions in binary form must reproduce the above copyright notice, | |
| this list of conditions and the following disclaimer in the documentation | |
| and/or other materials provided with the distribution. | |
| * Neither the name of the copyright holder nor the names of its | |
| contributors may be used to endorse or promote products derived from | |
| this software without specific prior written permission. | |
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
| FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
| DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
| SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
| CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
| OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
| OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| """ | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| __all__ = ["DeepLabV3Decoder"] | |
class DeepLabV3Decoder(nn.Sequential):
    """DeepLabV3 decoder: an ASPP head followed by a 3x3 conv and ReLU.

    Only the last element of the feature list passed to ``forward`` is used;
    earlier (shallower) features are ignored.

    Args:
        in_channels: number of channels of the feature map fed to ASPP.
        out_channels: number of channels produced by ASPP and by the final conv.
        atrous_rates: dilation rates for the ASPP atrous branches.
        separable: if True, ASPP atrous branches use depthwise-separable
            convolutions instead of regular 3x3 convolutions. Defaults to
            False, which reproduces the previous behaviour exactly.
    """

    def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36), separable=False):
        super().__init__(
            ASPP(in_channels, out_channels, atrous_rates, separable=separable),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            # BatchNorm intentionally omitted (removed following https://arxiv.org/abs/2305.02310).
            nn.ReLU(),
        )
        # Exposed so callers can size whatever consumes this decoder's output.
        self.out_channels = out_channels

    def forward(self, *features):
        # Only the deepest (last) feature map is decoded.
        return super().forward(features[-1])
class ASPPConv(nn.Sequential):
    """One atrous branch of ASPP: a dilated 3x3 convolution followed by ReLU.

    Padding is set equal to the dilation, so with the default stride of 1 the
    branch keeps the input's spatial size. The convolution has no bias.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            bias=False,
        )
        # BatchNorm intentionally omitted (removed following https://arxiv.org/abs/2305.02310).
        super().__init__(atrous, nn.ReLU())
class ASPPSeparableConv(nn.Sequential):
    """Atrous ASPP branch built from a depthwise-separable 3x3 conv plus ReLU.

    Padding equals the dilation, so with SeparableConv2d's default stride of 1
    the branch keeps the input's spatial size. No bias is used.
    """

    def __init__(self, in_channels, out_channels, dilation):
        separable_atrous = SeparableConv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            bias=False,
        )
        # BatchNorm intentionally omitted (removed following https://arxiv.org/abs/2305.02310).
        super().__init__(separable_atrous, nn.ReLU())
class ASPPPooling(nn.Sequential):
    """Image-level (global pooling) branch of ASPP.

    Averages the input down to 1x1 spatially, applies a bias-free 1x1 conv and
    ReLU, then bilinearly resizes the result back to the input's height/width.
    """

    def __init__(self, in_channels, out_channels):
        global_pool = nn.AdaptiveAvgPool2d(1)
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        # BatchNorm intentionally omitted (removed following https://arxiv.org/abs/2305.02310).
        super().__init__(global_pool, projection, nn.ReLU())

    def forward(self, x):
        # Remember the spatial size before pooling collapses it to 1x1.
        target_size = x.shape[-2:]
        pooled = super().forward(x)
        return F.interpolate(pooled, size=target_size, mode="bilinear", align_corners=False)
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Applies parallel branches to the same input:
      * a 1x1 conv + ReLU,
      * one atrous (dilated) 3x3 branch per entry of ``atrous_rates``,
      * an image-level pooling branch.
    Their outputs are concatenated along the channel dimension and projected
    back to ``out_channels`` with a 1x1 conv + ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by each branch and by the projection.
        atrous_rates: iterable of dilation rates; one atrous branch is built
            per rate. Previously exactly three rates were required. Any
            non-empty sequence is now accepted, and three rates produce the
            same modules (and state_dict keys) as before.
        separable: if True, atrous branches use depthwise-separable convs.

    Raises:
        ValueError: if ``atrous_rates`` is empty.
    """

    def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
        super().__init__()
        rates = tuple(atrous_rates)
        if not rates:
            raise ValueError("ASPP requires at least one atrous rate")

        branch_cls = ASPPSeparableConv if separable else ASPPConv

        # Branch order (1x1, atrous..., pooling) fixes the ModuleList indices and
        # therefore the state_dict keys; keep it stable for checkpoint loading.
        modules = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                # BatchNorm intentionally omitted (removed following https://arxiv.org/abs/2305.02310).
                nn.ReLU(),
            )
        ]
        modules.extend(branch_cls(in_channels, out_channels, rate) for rate in rates)
        modules.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(modules)

        # The projection sees every branch's output concatenated, so its input
        # width is derived from the branch count (5 * out_channels for 3 rates)
        # rather than hard-coded.
        self.project = nn.Sequential(
            nn.Conv2d(len(modules) * out_channels, out_channels, kernel_size=1, bias=False),
            # BatchNorm intentionally omitted (removed following https://arxiv.org/abs/2305.02310).
            nn.ReLU(),
            # Dropout intentionally removed (was nn.Dropout(0.5)).
        )

    def forward(self, x):
        branch_outputs = [conv(x) for conv in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))
class SeparableConv2d(nn.Sequential):
    """Depthwise-separable 2D convolution.

    A per-channel convolution (``groups=in_channels``) using the given kernel
    size, stride, padding and dilation, followed by a 1x1 pointwise convolution
    that maps ``in_channels`` to ``out_channels``. The depthwise stage never has
    a bias; ``bias`` controls only the pointwise stage.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        # Spatial filtering, one filter per input channel.
        depthwise_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        # Channel mixing with a 1x1 kernel.
        pointwise_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
        )
        super().__init__(depthwise_conv, pointwise_conv)