import torch import torch.nn as nn from torchvision import models import torch.nn.functional as F class UpSample(nn.Sequential): def __init__(self, skip_input, output_features): super(UpSample, self).__init__() self.convA = nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1) self.leakyreluA = nn.LeakyReLU(0.2) self.convB = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1) self.leakyreluB = nn.LeakyReLU(0.2) def forward(self, x, concat_with): up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True) return self.leakyreluB( self.convB( self.convA( torch.cat([up_x, concat_with], dim=1) ) ) ) class Decoder(nn.Module): def __init__(self, num_features=1664, decoder_width = 1.0): super(Decoder, self).__init__() features = int(num_features * decoder_width) self.conv2 = nn.Conv2d(num_features, features, kernel_size=1, stride=1, padding=0) self.up1 = UpSample(skip_input=features//1 + 256, output_features=features//2) self.up2 = UpSample(skip_input=features//2 + 128, output_features=features//4) self.up3 = UpSample(skip_input=features//4 + 64, output_features=features//8) self.up4 = UpSample(skip_input=features//8 + 64, output_features=features//16) self.conv3 = nn.Conv2d(features//16, 1, kernel_size=3, stride=1, padding=1) def forward(self, features): x_block0, x_block1, x_block2, x_block3, x_block4 = features[3], features[4], features[6], features[8], features[12] x_d0 = self.conv2(F.relu(x_block4)) x_d1 = self.up1(x_d0, x_block3) x_d2 = self.up2(x_d1, x_block2) x_d3 = self.up3(x_d2, x_block1) x_d4 = self.up4(x_d3, x_block0) return self.conv3(x_d4) class Encoder(nn.Module): def __init__(self): super(Encoder, self).__init__() self.original_model = models.densenet169( pretrained=False ) def forward(self, x): features = [x] for k, v in self.original_model.features._modules.items(): features.append( v(features[-1]) ) return features class PTModel(nn.Module): def __init__(self): super(PTModel, self).__init__() self.encoder = Encoder() self.decoder = Decoder() def forward(self, x): return self.decoder( self.encoder(x) )