conv_MPN

conv_MPN论文代码阅读

论文地址:(注:原文此处应附论文链接,链接缺失)

config.py

import torch

# Primary and secondary compute devices; both fall back to CPU without CUDA.
# NOTE(review): device2 assumes a second GPU ("cuda:1") whenever CUDA is
# available — confirm behavior on single-GPU machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device2 = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

input_nbr = 3  # input image channels (RGB)
lr = 0.0005  # learning rate
patience = 50  # presumably early-stopping patience in epochs — confirm in trainer
start_epoch = 0  # epoch counter to resume from
epochs = 120  # total training epochs
print_freq = 20  # presumably log every N iterations — confirm in trainer
interval_training = 8
save_folder = 'conv_mpn_loop_3_pretrain_2'  # checkpoint/output directory name
model_loop_time = 3  # number of message-passing iterations in the model
edge_feature_channels = 32  # per-edge feature-map channels
# Model variant switches: conv-MPN, GNN baseline, or (when both are off)
# the per-edge classifier baseline.
conv_mpn = True
gnn = False
pretrain = False
per_edge_classifier = not gnn and not conv_mpn
# Graph models consume one whole graph per step; the per-edge baseline can
# batch 32 independent edges.
batch_size = 1 if not per_edge_classifier else 32
new_label = True  # enable the extended gt-edge matching in dataset.get_gt

dataset.py

import numpy as np
import random
from torch.utils.data import Dataset
import os
import skimage
import cv2
from config import *


# ImageNet RGB normalization statistics, applied to images in get_data().
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
class Graphdataset(Dataset):
    """Building-wireframe graph dataset.

    Each sample is an RGB image plus a set of detected corner positions.
    Every pair of corners becomes a candidate edge rendered as a 256x256
    line mask; ground-truth edge labels are derived by matching detected
    corners against the annotated corner graph.
    """

    def __init__(self, datapath,
                 detcornerpath, phase='train',
                 full_connected_init=True, mix_gt=False):
        """
        :param datapath: root folder holding 'rgb/', 'annot/' and the
            train/valid list files
        :param detcornerpath: folder with per-image detected corners (.npy)
        :param phase: 'train' reads train_list.txt, anything else valid_list.txt
        :param full_connected_init: stored for compatibility; the candidate
            graph built here is always fully connected
        :param mix_gt: if True, replace detections with ground-truth corners
        """
        super(Graphdataset, self).__init__()
        self.datapath = datapath
        self.detcornerpath = detcornerpath
        self.phase = phase
        self.mix_gt = mix_gt
        self.full_connected_init = full_connected_init
        # Fix: the original assigned `self.demo = demo` with `demo` undefined,
        # which raised NameError on construction; the attribute was never used.
        if phase == 'train':
            datalistfile = os.path.join(datapath, 'train_list.txt')
        else:
            datalistfile = os.path.join(datapath, 'valid_list.txt')
        with open(datalistfile, 'r') as f:
            self._data_names = f.readlines()

    def __len__(self):
        return len(self._data_names)

    def get_annot(self, data_name):
        """Load one sample's annotation: a dict mapping a corner (y, x) to
        the list of its neighboring corners."""
        annot = np.load(os.path.join(self.datapath, 'annot', data_name+'.npy'),
                        allow_pickle=True, encoding='latin1').tolist()
        return annot

    def getbyname(self, name):
        """Return the prepared sample whose list entry equals `name`
        (None when the name is not in the list)."""
        for i in range(len(self._data_names)):
            # rstrip('\n') instead of [:-1]: the last line of a list file may
            # lack a trailing newline, which would otherwise truncate the name.
            if self._data_names[i].rstrip('\n') == name:
                return self.__getitem__(i)

    def __getitem__(self, idx):
        data_name = self._data_names[idx].rstrip('\n')
        rgb = skimage.img_as_float(cv2.imread(os.path.join(self.datapath, 'rgb', data_name+'.jpg')))
        annot = self.get_annot(data_name)  # reuse the single loading path
        corners = np.array(np.load(
            os.path.join(self.detcornerpath, data_name + '.npy'), allow_pickle=True))  # [N, 2]

        if self.mix_gt:
            # Use ground-truth corners instead of detections; annot keys are
            # (y, x) while detections are (x, y), hence the column swap.
            corners = np.array(list(annot.keys()))[:, [1, 0]]
            # Zero-std jitter is currently a no-op; kept as a noise knob.
            corners += np.random.normal(0, 0, size=corners.shape)
        cornerList = []
        for corner_i in range(corners.shape[0]):
            cornerList.append((corners[corner_i][1], corners[corner_i][0]))
        edge_masks = []
        edges = []
        # Fully connected candidate graph: one line mask per unordered corner
        # pair, two directed entries per pair in `edges`.
        for edge_i in range(corners.shape[0]):
            for edge_j in range(edge_i + 1, corners.shape[0]):
                edge_mask = np.zeros((256, 256)).astype(np.double)
                # `np.int` was removed from NumPy; the builtin int is equivalent.
                loc1 = np.array(cornerList[edge_i]).astype(int)
                loc2 = np.array(cornerList[edge_j]).astype(int)
                cv2.line(edge_mask, (loc1[0], loc1[1]), (loc2[0], loc2[1]), 1.0, 3)
                edge_masks.append(edge_mask)

                edges.append([edge_i, edge_j])
                edges.append([edge_j, edge_i])

        edges = np.array(edges).T  # [2, N * (N - 1)] directed corner pairs

        gt_edge = self.get_gt(corners, annot)  # list of (i, j) corner-index pairs

        raw_data = {
            'name': data_name,
            'rgb': rgb,
            'edges': edges,
            'edge_feature': edge_masks,
            'corners': corners,
            'corners_feature': None,
            'gt': gt_edge,
            'annot': annot
        }
        return self.get_data(raw_data)

    def get_gt(self, preds, annot):
        """Match detected corners to annotated corners and return the list of
        ground-truth directed edges as (pred_i, pred_j) index pairs.

        :param preds: detected corners; preds uses (x, y) while annot keys
            use (y, x)
        :param annot: dict corner -> list of neighbor corners
        """
        gt_edges = set()
        gt_corners = list(annot.keys())
        if self.mix_gt:
            # Detections ARE the gt corners here, so match a neighbor back to
            # a corner by near-exact position (squared distance < 1).
            for corner_i in range(len(gt_corners)):
                for corner_neighbor in annot[gt_corners[corner_i]]:
                    for corner_j in range(len(gt_corners)):
                        if (gt_corners[corner_j][0] - corner_neighbor[0]) ** 2 + \
                            (gt_corners[corner_j][1] - corner_neighbor[1]) ** 2 < 1:
                            gt_edges.add((corner_i, corner_j))
                            gt_edges.add((corner_j, corner_i))
                            break
            return list(gt_edges)
        # Greedy one-to-one matching of gt corners to detections within 7 px.
        gt_map = {}
        match_id_set = set()
        for gt_corner_ in gt_corners:
            dist = 7
            match_idx = -1
            for pred_i in range(preds.shape[0]):
                if pred_i in match_id_set:
                    continue
                pred = preds[pred_i]
                temp_dist = np.sqrt((pred[0] - gt_corner_[1]) ** 2 + (pred[1] - gt_corner_[0]) ** 2)
                if temp_dist < dist:
                    dist = temp_dist
                    match_idx = pred_i
            match_id_set.add(match_idx)
            gt_map[gt_corner_] = match_idx
        if new_label:
            # Second pass with a looser 15 px radius (reuse allowed) for gt
            # corners that found no one-to-one match above.
            for gt_corner_ in gt_corners:
                dist = 15
                match_idx = -1
                if gt_map[gt_corner_] == -1:
                    for pred_i in range(preds.shape[0]):
                        pred = preds[pred_i]
                        temp_dist = np.sqrt((pred[0] - gt_corner_[1]) ** 2 + (pred[1] - gt_corner_[0]) ** 2)
                        if temp_dist < dist:
                            dist = temp_dist
                            match_idx = pred_i
                    gt_map[gt_corner_] = match_idx

            for gt_corner_ in gt_corners:
                if gt_map[gt_corner_] == -1:
                    continue
                for neighbor in annot[gt_corner_]:
                    if gt_map[tuple(neighbor)] == -1:
                        # Unmatched neighbor: try to bridge through a
                        # second-hop neighbor lying in roughly the same
                        # direction (cosine similarity > 0.5).
                        target_dir = (neighbor - np.array(gt_corner_)) / np.sqrt(np.sum((neighbor - np.array(gt_corner_)) ** 2))
                        cos_value = 0.5
                        neighbor_good = None
                        for neighbor_v2 in annot[tuple(neighbor)]:
                            if gt_map[tuple(neighbor_v2)] == -1:
                                continue
                            test_dir = (neighbor_v2 - neighbor) / np.sqrt(np.sum((neighbor_v2 - neighbor) ** 2))
                            if np.sum(test_dir * target_dir) > cos_value:
                                cos_value = np.sum(test_dir * target_dir)
                                neighbor_good = neighbor_v2
                        if neighbor_good is not None:
                            gt_edges.add((gt_map[gt_corner_], gt_map[tuple(neighbor_good)]))
                            gt_edges.add((gt_map[tuple(neighbor_good)], gt_map[gt_corner_]))
                    elif gt_map[tuple(neighbor)] == gt_map[gt_corner_]:
                        # Both endpoints collapsed onto the same detection.
                        continue
                    else:
                        gt_edges.add((gt_map[gt_corner_], gt_map[tuple(neighbor)]))
                        gt_edges.add((gt_map[tuple(neighbor)], gt_map[gt_corner_]))
            return list(gt_edges)

        # new_label disabled: keep only edges whose endpoints both matched.
        for gt_corner_ in gt_corners:
            if gt_map[gt_corner_] == -1:
                continue
            for neighbor in annot[gt_corner_]:
                if gt_map[tuple(neighbor)] == -1:
                    continue
                if gt_map[tuple(neighbor)] == gt_map[gt_corner_]:
                    continue
                gt_edges.add((gt_map[gt_corner_], gt_map[tuple(neighbor)]))
                gt_edges.add((gt_map[tuple(neighbor)], gt_map[gt_corner_]))
        return list(gt_edges)

    def get_data(self, data):
        """Convert a raw sample dict into model-ready tensors.

        Returns per-edge masks `x`, candidate-edge adjacency `edge_index`
        (pairs of candidate edges sharing a corner), binary labels `y`, the
        normalized image, corner positions and the original annotation.
        """
        img = data['rgb']
        corners = data['corners']
        edge_masks = data['edge_feature']
        gt = data['gt']
        annot = data['annot']

        edges = []
        for edge_i in range(corners.shape[0]):
            for edge_j in range(edge_i + 1, corners.shape[0]):
                edges.append((edge_i, edge_j))

        # Two candidate edges are adjacent iff they share an endpoint.
        edge_index = []
        for i in range(len(edges)):
            for j in range(i + 1, len(edges)):
                if edges[j][0] == edges[i][0] or edges[j][0] == edges[i][1] or \
                                edges[j][1] == edges[i][0] or edges[j][1] == edges[i][1]:
                    edge_index.append((i, j))
                    edge_index.append((j, i))
        edge_index = np.array(edge_index).T

        y = []
        for corner_i in range(corners.shape[0]):
            for corner_j in range(corner_i + 1, corners.shape[0]):
                if (corner_i, corner_j) in gt or (corner_j, corner_i) in gt:
                    y.append(1)
                else:
                    y.append(0)
        y = torch.Tensor(y).long()

        # Per-edge line masks as the model input features.
        x = torch.Tensor(edge_masks).double()

        edge_index = torch.Tensor(edge_index).long()
        # HWC -> CHW, then ImageNet normalization.
        img = img.transpose((2,0,1))
        img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]

        # Fix: this flag lives in config (imported via *), not on the
        # instance; `self.per_edge_classifier` raised AttributeError.
        if per_edge_classifier:
            # Per-edge baseline: sample a single random candidate edge.
            choice_id = random.randint(0, y.shape[0] - 1)
            return {
                "x": x[choice_id],
                "y": y[choice_id],
                "img": img,
                "pos": corners,
                "annot": annot,
                "name": data['name']
            }

        return {
            "x": x,
            "edge_index": edge_index,
            "y": y,
            "img": img,
            "pos": corners,
            "annot": annot,
            "name": data['name']
        }

    def get_neighbor(self, corner_idx, edge_index):
        """Return the indices connected to `corner_idx` in a [2, E] edge list."""
        neighbor_ids = set()
        for j in range(edge_index.shape[1]):
            if corner_idx == edge_index[0, j]:
                neighbor_ids.add(edge_index[1, j])
            if corner_idx == edge_index[1, j]:
                neighbor_ids.add(edge_index[0, j])

        return list(neighbor_ids)

model.py 实现gnn conv_mpn

import torch.nn as nn
from config import *
from unet import UNet
from torch.nn.parameter import Parameter
import math


class graphNetwork(nn.Module):
    """Edge classifier over a fully connected corner graph.

    Each candidate edge's line mask is concatenated with the RGB image and
    encoded by `backbone` into a per-edge 64x64 feature map.  Optionally,
    conv-MPN (feature-map level) or GNN (vector level) message passing mixes
    features between edges that share a corner; a conv head plus a linear
    layer then produces 2-way (no-edge / edge) logits per candidate edge.
    """
    def __init__(self, times, backbone, edge_feature_map_channel=32,
                 conv_mpn=False, gnn=False):
        """
        :param times: number of message-passing iterations
        :param backbone: image encoder producing 2*C channels
        :param edge_feature_map_channel: per-edge feature channels C
        :param conv_mpn: enable conv-MPN message passing
        :param gnn: enable the GNN (vector) baseline instead
        """
        super(graphNetwork, self).__init__()
        self.edge_feature_channel = edge_feature_map_channel
        # Backbone output (2*C channels) is projected down to C channels.
        self.rgb_net = nn.Sequential(
            backbone,
            nn.Conv2d(2 * self.edge_feature_channel, self.edge_feature_channel, kernel_size=3, stride=1, padding=1)
        )
        self.gnn = gnn
        self.times = times
        self.conv_mpn = conv_mpn
        # Flattened per-edge vector size: 4*C channels pooled to a 2x2 map.
        self.vector_size = 16 * self.edge_feature_channel
        if gnn:
            # GNN baseline: a 1x1-conv MLP applied to the concatenation of an
            # edge's own vector and the max over its neighbors' vectors.
            vector_size = self.vector_size
            self.loop_net = nn.ModuleList([nn.Sequential(
                nn.Conv2d(2 * vector_size, 2 * vector_size, kernel_size=1, stride=1),
                nn.BatchNorm2d(2 * vector_size),
                nn.ReLU(inplace=True),
                nn.Conv2d(2 * vector_size, 2 * vector_size, kernel_size=1, stride=1),
                nn.BatchNorm2d(2 * vector_size),
                nn.ReLU(inplace=True),
                nn.Conv2d(2 * vector_size, 2 * vector_size, kernel_size=1, stride=1),
                nn.BatchNorm2d(2 * vector_size),
                nn.ReLU(inplace=True),
                nn.Conv2d(2 * vector_size, vector_size, kernel_size=1, stride=1),
                nn.BatchNorm2d(vector_size),
                nn.ReLU(inplace=True),
                nn.Conv2d(vector_size, vector_size, kernel_size=1, stride=1),
                nn.BatchNorm2d(vector_size),
                nn.ReLU(inplace=True),
                nn.Conv2d(vector_size, vector_size, kernel_size=1, stride=1),
                nn.BatchNorm2d(vector_size),
                nn.ReLU(inplace=True)
            ) for _ in range(self.times)])

        if conv_mpn:
            # conv-MPN: message passing on the 64x64 feature maps themselves.
            self.loop_net = nn.ModuleList([
                conv_mpn_model(2 * self.edge_feature_channel,
                               self.edge_feature_channel)
                for _ in range(self.times)])

        # Conv head expanding C -> 4*C channels before pooling.
        self.edge_pred_layer = nn.Sequential(
            nn.Conv2d(self.edge_feature_channel, self.edge_feature_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.edge_feature_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.edge_feature_channel, 2 * self.edge_feature_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(2 * self.edge_feature_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(2 * self.edge_feature_channel, 2 * self.edge_feature_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(2 * self.edge_feature_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(2 * self.edge_feature_channel, 4 * self.edge_feature_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(4 * self.edge_feature_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * self.edge_feature_channel, 4 * self.edge_feature_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(4 * self.edge_feature_channel),
            nn.ReLU(inplace=True)
        )
        # NOTE(review): despite the name this is AVERAGE pooling, not max
        # pooling; renaming would break existing checkpoints' state_dict keys.
        self.maxpool = nn.AdaptiveAvgPool2d((2,2))
        self.fc = nn.Linear(self.vector_size, 2)
        
        # Disable running-stat tracking on every norm layer: batch norm then
        # always uses the current batch's statistics.  With batch_size 1
        # graphs the running estimates would be unreliable; the trade-off is
        # that inference statistics depend on the evaluated batch.
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                m.track_running_stats=False


    def change_device(self):
        # Split the model across two devices: image encoder and final fc on
        # `device`, message passing and the conv head on `device2`.
        self.rgb_net.to(device)
        self.loop_net.to(device2)
        self.edge_pred_layer.to(device2)
        self.fc.to(device)

    def forward(self, img, edge_masks, edge_index=None):
        """
        :param img: [1, 3, H, W] normalized image
        :param edge_masks: [E, H, W] one line mask per candidate edge
        :param edge_index: [2, K] pairs of candidate-edge indices sharing a corner
        :return: [E, 2] per-edge logits
        """
        if self.training is False:
            # Evaluation: encode edges in chunks of 105 to bound GPU memory
            # (105 is presumably tuned to the eval GPU — confirm).
            tt = math.ceil(edge_masks.shape[0] / 105)
            edge_feature_init = torch.zeros((edge_masks.shape[0], self.edge_feature_channel, 64, 64)).double().to(device)
            for time in range(tt):
                if time == tt - 1:
                    edge_sub_masks = edge_masks[time * 105:, :, :]
                else:
                    edge_sub_masks = edge_masks[time * 105:(time+1) * 105, :, :]
                # Replicate the image per edge and stack the edge mask as a
                # 4th input channel.
                img_expand = img.expand(edge_sub_masks.shape[0], -1, -1, -1)
                feature_in = torch.cat((img_expand, edge_sub_masks.unsqueeze(1)), 1)
                if time == tt - 1:
                    edge_feature_init[time * 105:] = self.rgb_net(feature_in)
                else:
                    edge_feature_init[time*105:(time+1)*105] = self.rgb_net(feature_in)
                del feature_in
        else:
            img = img.expand(edge_masks.shape[0], -1, -1, -1)
            feature_in = torch.cat((img, edge_masks.unsqueeze(1)), 1)
            edge_feature_init = self.rgb_net(feature_in)
        edge_feature = edge_feature_init
        if device != device2:
            edge_feature = edge_feature.to(device2)
        if self.conv_mpn:
            for t in range(self.times):
                # For each edge, pool (elementwise max) the feature maps of
                # all edges that share a corner with it, then fuse.
                feature_neighbor = torch.zeros_like(edge_feature)
                for edge_iter in range(edge_masks.shape[0]):
                    feature_temp = edge_feature[edge_index[1, torch.where(edge_index[0,:] == edge_iter)[0]]]
                    feature_neighbor[edge_iter] = torch.max(feature_temp, 0)[0]
                edge_feature = torch.cat((edge_feature, feature_neighbor), 1)
                edge_feature = self.loop_net[t](edge_feature)
        if self.training is False:
            tt = math.ceil(edge_masks.shape[0] / 105)
            # NOTE(review): this buffer is created on `device` while
            # edge_pred_layer may sit on `device2` after change_device();
            # verify the two-GPU path.
            edge_pred = torch.zeros((edge_masks.shape[0], 4*self.edge_feature_channel, 64, 64)).double().to(device)
            for time in range(tt):
                if time == tt - 1:
                    edge_sub_feature = edge_feature[time * 105:, :, :]
                else:
                    edge_sub_feature = edge_feature[time * 105:(time+1) * 105, :, :]
                if time == tt - 1:
                    edge_pred[time * 105:] = self.edge_pred_layer(edge_sub_feature)
                else:
                    edge_pred[time*105:(time+1)*105] = self.edge_pred_layer(edge_sub_feature)
                del edge_sub_feature
        else:
            edge_pred = self.edge_pred_layer(edge_feature)
        edge_pred = self.maxpool(edge_pred)
        # 4C x 2 x 2 pooled map flattened into a [E, vector_size, 1, 1] tensor.
        edge_pred = edge_pred.view((edge_masks.shape[0], self.vector_size, 1, 1))
        if self.gnn:
            # Same neighbor-max message passing, but on pooled vectors.
            for t in range(self.times):
                feature_neighbor = torch.zeros_like(edge_pred)
                for edge_iter in range(edge_masks.shape[0]):
                    feature_temp = edge_pred[edge_index[1, torch.where(edge_index[0,:] == edge_iter)[0]]]
                    feature_neighbor[edge_iter] = torch.max(feature_temp, 0)[0]
                edge_pred = torch.cat((edge_pred, feature_neighbor), 1)
                edge_pred = self.loop_net[t](edge_pred)
        edge_pred = torch.flatten(edge_pred, 1)
        if device != device2:
            edge_pred = edge_pred.to(device)
        fc = self.fc(edge_pred)
        return fc


class conv_mpn_model(nn.Module):
    """One conv-MPN message-fusion stage.

    Seven 3x3 Conv -> ReLU -> BatchNorm stages that reduce `inchannels`
    (the concatenation of an edge's own features and its pooled neighbor
    features) down to `out_channels` while keeping the spatial size.
    """

    def __init__(self, inchannels, out_channels):
        super(conv_mpn_model, self).__init__()
        assert inchannels >= out_channels
        self.out_channels = out_channels
        # Channel plan: four stages at full width, one reducing stage, then
        # two refinement stages at the reduced width.  The layer order is
        # identical to the hand-unrolled original, so state_dict keys match.
        plan = (
            [(inchannels, inchannels)] * 4
            + [(inchannels, out_channels)]
            + [(out_channels, out_channels)] * 2
        )
        stages = []
        for cin, cout in plan:
            stages += [
                nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(cout, track_running_stats=True),
            ]
        self.seq = nn.Sequential(*stages)

    def forward(self, x):
        """Fuse the stacked edge+neighbor features into out_channels maps."""
        return self.seq(x)

DRN.py

# 有意思的条件判断(摘自下文 DRN.__init__,此处仅为阅读笔记,非可执行代码):
# self.layer6 = None if layers[5] == 0 else \
#     self._make_layer(block, channels[5], layers[5], dilation=4,
#                      new_level=False)
import pdb

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
#Loads the Torch serialized object at the given URL.
#从给定网站读取网络结构
import torch

# Alias so the normalization layer used throughout DRN can be swapped in one place.
BatchNorm = nn.BatchNorm2d


# __all__ = ['DRN', 'drn26', 'drn42', 'drn58']


# Base URL for the official DRN pretrained checkpoints.
webroot = 'http://dl.yf.io/drn/'

# Pretrained checkpoint URLs keyed by architecture name.
# NOTE(review): drn_d_24() below looks up 'drn-d-24', which is absent from
# this table, so drn_d_24(pretrained=True) raises KeyError.
model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'drn-c-26': webroot + 'drn_c_26-ddedf421.pth',
    'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth',
    'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth',
    'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth',
    'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth',
    'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth',
    'drn-d-105': webroot + 'drn_d_105-12b40979.pth'
}


def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
    """Build a bias-free 3x3 convolution with the given stride/padding/dilation."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Two 3x3 convs with an optional (switchable) residual connection.

    `dilation` is a (d1, d2) pair applied to the first and second conv; the
    optional `downsample` module adapts the shortcut when shape changes.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(BasicBlock, self).__init__()
        # First conv carries the stride; second keeps the spatial size.
        self.conv1 = conv3x3(inplanes, planes, stride,
                             padding=dilation[0], dilation=dilation[0])
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes,
                             padding=dilation[1], dilation=dilation[1])
        self.bn2 = BatchNorm(planes)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        # Shortcut branch (identity, or projected when downsample is set).
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # DRN-C disables the residual add in some de-gridding layers.
        if self.residual:
            out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck with a 4x channel expansion.

    NOTE: the `residual` argument is accepted for signature parity with
    BasicBlock but is not used — the residual add always happens here.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(Bottleneck, self).__init__()
        # Reduce, transform (stride/dilation live on the 3x3), then expand.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation[1], bias=False,
                               dilation=dilation[1])
        self.bn2 = BatchNorm(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut branch (identity, or projected when downsample is set).
        identity = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += identity
        return self.relu(y)


class DRN(nn.Module):
    """Dilated Residual Network (arch 'C' or 'D').

    Stages 5+ use dilation instead of stride, so the feature map stays at
    1/8 resolution.  With `out_map` the classifier is applied as a 1x1 conv
    over the map; otherwise features are average-pooled first.  `out_middle`
    additionally returns the list of intermediate stage outputs.
    """

    def __init__(self, block, layers, num_classes=1000,
                 channels=(16, 32, 64, 128, 256, 512, 512, 512),
                 out_map=False, out_middle=False, pool_size=28, arch='D',
                 image_channels=3):
        super(DRN, self).__init__()
        self.inplanes = channels[0]
        self.out_map = out_map
        self.out_dim = channels[-1]
        self.out_middle = out_middle
        self.arch = arch

        # Stem + first two stages differ between the 'C' and 'D' variants:
        # 'C' uses residual blocks, 'D' plain conv layers.
        if arch == 'C':
            self.conv1 = nn.Conv2d(image_channels, channels[0], kernel_size=7, stride=1,
                                   padding=3, bias=False)
            self.bn1 = BatchNorm(channels[0])
            self.relu = nn.ReLU(inplace=True)

            self.layer1 = self._make_layer(
                BasicBlock, channels[0], layers[0], stride=1)
            self.layer2 = self._make_layer(
                BasicBlock, channels[1], layers[1], stride=2)
        elif arch == 'D':
            # NOTE(review): this stem hard-codes 3 input channels and ignores
            # the `image_channels` parameter — confirm before using arch 'D'
            # with non-RGB input.
            self.layer0 = nn.Sequential(
                nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
                          bias=False),
                BatchNorm(channels[0]),
                nn.ReLU(inplace=True)
            )

            self.layer1 = self._make_conv_layers(
                channels[0], layers[0], stride=1)
            self.layer2 = self._make_conv_layers(
                channels[1], layers[1], stride=2)

        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2)
        self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2)
        # From here on: dilation instead of stride (resolution stays 1/8).
        self.layer5 = self._make_layer(block, channels[4], layers[4],
                                       dilation=2, new_level=False)
        # layer6 is optional: skipped entirely when its block count is 0.
        self.layer6 = None if layers[5] == 0 else \
            self._make_layer(block, channels[5], layers[5], dilation=4,
                             new_level=False)

        # De-gridding tail: 'C' uses non-residual basic blocks, 'D' plain convs.
        if arch == 'C':
            self.layer7 = None if layers[6] == 0 else \
                self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
                                 new_level=False, residual=False)
            self.layer8 = None if layers[7] == 0 else \
                self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
                                 new_level=False, residual=False)
        elif arch == 'D':
            self.layer7 = None if layers[6] == 0 else \
                self._make_conv_layers(channels[6], layers[6], dilation=2)
            self.layer8 = None if layers[7] == 0 else \
                self._make_conv_layers(channels[7], layers[7], dilation=1)

        # NOTE(review): forward() uses self.fc unconditionally, so it assumes
        # num_classes > 0 — confirm callers never pass num_classes <= 0.
        if num_classes > 0:
            self.avgpool = nn.AvgPool2d(pool_size)
            self.fc = nn.Conv2d(self.out_dim, num_classes, kernel_size=1,
                                stride=1, padding=0, bias=True)
        # He-style init for convs, unit-gain init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    new_level=True, residual=True):
        """Build one residual stage.

        `new_level` halves the first block's dilation so the receptive field
        grows smoothly when entering a stage with a larger dilation.
        """
        assert dilation == 1 or dilation % 2 == 0
        downsample = None
        # Project the shortcut when spatial size or channel count changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = list()
        layers.append(block(
            self.inplanes, planes, stride, downsample,
            dilation=(1, 1) if dilation == 1 else (
                dilation // 2 if new_level else dilation, dilation),
            residual=residual))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, residual=residual,
                                dilation=(dilation, dilation)))

        return nn.Sequential(*layers)

    def _make_conv_layers(self, channels, convs, stride=1, dilation=1):
        """Build a plain Conv-BN-ReLU stage (arch 'D'); only the first conv strides."""
        modules = []
        for i in range(convs):
            modules.extend([
                nn.Conv2d(self.inplanes, channels, kernel_size=3,
                          stride=stride if i == 0 else 1,
                          padding=dilation, bias=False, dilation=dilation),
                BatchNorm(channels),
                nn.ReLU(inplace=True)])
            self.inplanes = channels
        return nn.Sequential(*modules)

    def forward(self, x):
        # y collects each stage's output for the optional out_middle return.
        y = list()

        if self.arch == 'C':
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
        elif self.arch == 'D':
            x = self.layer0(x)

        x = self.layer1(x)
        y.append(x)
        x = self.layer2(x)
        y.append(x)

        x = self.layer3(x)
        y.append(x)

        x = self.layer4(x)
        y.append(x)

        x = self.layer5(x)
        y.append(x)

        if self.layer6 is not None:
            x = self.layer6(x)
            y.append(x)

        if self.layer7 is not None:
            x = self.layer7(x)
            y.append(x)

        if self.layer8 is not None:
            x = self.layer8(x)
            y.append(x)

        # Dense per-location scores, or pooled per-image logits.
        if self.out_map:
            x = self.fc(x)
        else:
            x = self.avgpool(x)
            x = self.fc(x)
            x = x.view(x.size(0), -1)

        if self.out_middle:
            return x, y
        else:
            return x


class DRN_A(nn.Module):
    """DRN-A: a ResNet whose last two stages use dilation (2, then 4)
    instead of stride, keeping the feature map at 1/8 resolution before
    the 28x28 average pool."""

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(DRN_A, self).__init__()
        self.out_dim = 512 * block.expansion
        # Standard ResNet stem: 7x7/2 conv + 3x3/2 max pool -> 1/4 resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # Stages 3 and 4: stride 1 with growing dilation (the DRN-A change).
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilation=4)
        self.avgpool = nn.AvgPool2d(28, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-style init for convs, unit-gain init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one residual stage; the shortcut is projected when shape changes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                dilation=(dilation, dilation)))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Pool to a vector and classify.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def drn_a_50(pretrained=False, **kwargs):
    """ResNet-50-style DRN-A; optionally initialized from ImageNet ResNet-50 weights."""
    net = DRN_A(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        net.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return net


def drn_c_26(pretrained=False, **kwargs):
    """Build a DRN-C-26 network.

    When `pretrained` is True, the released ImageNet weights are loaded and
    the first conv layer's 3-channel kernel is widened to 4 channels by
    appending the per-filter mean over the RGB channels (the training script
    constructs this backbone with image_channels=4).
    """
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', **kwargs)
    if pretrained:
        # Bug fix: the old code pinned model_dir to a machine-specific
        # '/local-scratch/...' path; use torch's default checkpoint cache.
        pretrained_dict = model_zoo.load_url(model_urls['drn-c-26'])
        # NOTE(review): assumes the first state-dict entry is the first conv's
        # weight tensor — true for this checkpoint's ordering; confirm if the
        # checkpoint ever changes.
        first_layer_name, first_layer_weight = next(iter(pretrained_dict.items()))
        mean_weight = torch.mean(first_layer_weight, dim=1, keepdim=True)
        pretrained_dict[first_layer_name] = torch.cat([first_layer_weight, mean_weight], dim=1)
        model.load_state_dict(pretrained_dict)
    return model


def drn_c_42(pretrained=False, **kwargs):
    """DRN-C-42 (BasicBlock, arch 'C'); loads released weights when pretrained."""
    net = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['drn-c-42'])
        net.load_state_dict(state)
    return net


def drn_c_58(pretrained=False, **kwargs):
    """DRN-C-58 (Bottleneck, arch 'C'); loads released weights when pretrained."""
    net = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['drn-c-58'])
        net.load_state_dict(state)
    return net


def drn_d_22(pretrained=False, **kwargs):
    """DRN-D-22 (BasicBlock, arch 'D'); loads released weights when pretrained."""
    net = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['drn-d-22'])
        net.load_state_dict(state)
    return net


def drn_d_24(pretrained=False, **kwargs):
    """DRN-D-24 (BasicBlock, arch 'D'); loads released weights when pretrained."""
    net = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['drn-d-24'])
        net.load_state_dict(state)
    return net


def drn_d_38(pretrained=False, **kwargs):
    """DRN-D-38 (BasicBlock, arch 'D'); loads released weights when pretrained."""
    net = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['drn-d-38'])
        net.load_state_dict(state)
    return net


def drn_d_40(pretrained=False, **kwargs):
    """DRN-D-40 (BasicBlock, arch 'D'); loads released weights when pretrained."""
    net = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['drn-d-40'])
        net.load_state_dict(state)
    return net


def drn_d_54(pretrained=False, **kwargs):
    """DRN-D-54 (Bottleneck, arch 'D'); loads released weights when pretrained."""
    net = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['drn-d-54'])
        net.load_state_dict(state)
    return net


def drn_d_56(pretrained=False, **kwargs):
    """DRN-D-56 (Bottleneck, arch 'D'); loads released weights when pretrained."""
    net = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['drn-d-56'])
        net.load_state_dict(state)
    return net


def drn_d_105(pretrained=False, **kwargs):
    """DRN-D-105 (Bottleneck, arch 'D'); loads released weights when pretrained."""
    net = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['drn-d-105'])
        net.load_state_dict(state)
    return net


def drn_d_107(pretrained=False, **kwargs):
    """DRN-D-107 (Bottleneck, arch 'D'); loads released weights when pretrained."""
    net = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 2, 2], arch='D', **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['drn-d-107'])
        net.load_state_dict(state)
    return net

unet.py

import torch
import torch.nn as nn


def double_conv(in_channels, out_channels):
    """Two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
    layers = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


class UNet(nn.Module):
    """Small 3-level UNet with a residual skip from the first `outchannels` input planes."""

    def __init__(self, inchannels, outchannels):
        super(UNet, self).__init__()
        assert inchannels >= outchannels
        self.outchannels = outchannels

        # Encoder: channel-reduction stage, then three down stages + bottleneck.
        self.dconv_down0 = double_conv(inchannels, outchannels)
        self.dconv_down1 = double_conv(outchannels, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)
        self.dconv_down4 = double_conv(128, 128)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder: each stage consumes upsampled features concatenated with a skip.
        self.dconv_up3 = double_conv(128 + 128, 128)
        self.dconv_up2 = double_conv(128 + 64, 64)
        self.dconv_up1 = double_conv(64 + 32, 32)

        self.conv_last = nn.Conv2d(32, outchannels, 1)

    def forward(self, x):
        """Encode, decode with skip connections, and add the input residual."""
        # Residual is the leading `outchannels` planes of the raw input.
        identity = x[:, :self.outchannels, :, :]
        x = self.dconv_down0(x)

        # Encoder path; remember each stage's output as a skip tensor.
        skips = []
        for down in (self.dconv_down1, self.dconv_down2, self.dconv_down3):
            feat = down(x)
            skips.append(feat)
            x = self.maxpool(feat)

        x = self.dconv_down4(x)

        # Decoder path; pair deepest skip with the first up stage.
        for up, skip in zip((self.dconv_up3, self.dconv_up2, self.dconv_up1),
                            reversed(skips)):
            x = self.upsample(x)
            x = up(torch.cat([x, skip], dim=1))

        return self.conv_last(x) + identity

utils.py

import os

from config import *


def ensure_folder(folder):
    """Create `folder` (including parents) if it does not already exist."""
    # exist_ok=True avoids the check-then-create race of the old
    # `if not os.path.exists(...)` guard (two processes could both pass the
    # check, then one makedirs call would raise).
    os.makedirs(folder, exist_ok=True)


def adjust_learning_rate(optimizer, shrink_factor):
    """Multiply every parameter group's learning rate by `shrink_factor` in place."""
    print("\nDECAYING learning rate.")
    for group in optimizer.param_groups:
        group['lr'] *= shrink_factor
    # Report the (now decayed) rate of the first group.
    print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))


class ExpoAverageMeter(object):
    """Exponentially weighted moving average of a scalar metric.

    avg_t = beta * avg_{t-1} + (1 - beta) * val_t
    """

    def __init__(self, beta=0.9):
        # Bug fix: the `beta` argument used to be ignored — reset() always
        # overwrote it with the hard-coded 0.9.
        self.beta = beta
        self.reset()

    def reset(self):
        """Clear the running statistics; the smoothing factor is preserved."""
        self.val = 0
        self.avg = 0
        self.count = 0  # kept for interface compatibility; never incremented here

    def update(self, val):
        """Fold a new observation into the running average."""
        self.val = val
        self.avg = self.beta * self.avg + (1 - self.beta) * self.val


def save_checkpoint(epoch, model, optimizer, val_loss, is_best):
    """Write a per-epoch checkpoint; also refresh BEST_checkpoint.tar when improved."""
    ensure_folder(save_folder)
    # NOTE: the whole model/optimizer objects are pickled (not state_dicts);
    # the loader in main() reads checkpoint['model'].state_dict() and relies
    # on this format.
    payload = {'model': model,
               'optimizer': optimizer}
    # Per-epoch file tagged with the validation loss.
    torch.save(payload, '{0}/checkpoint_{1}_{2:.3f}.tar'.format(save_folder, epoch, val_loss))
    # Keep a separate copy of the best model so later, worse epochs cannot overwrite it.
    if is_best:
        torch.save(payload, '{}/BEST_checkpoint.tar'.format(save_folder))

train.py

import time
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
import random
from dataset import Graphdataset
from model import graphNetwork
from utils import *
import logging
from drn import drn_c_26


def train(epoch, train_loader, model, optimizer, criterion):
    """Run one training epoch.

    For the graph models (batch_size == 1) gradients are accumulated over
    `interval_training` batches before each optimizer step; the per-edge
    classifier steps after every batch.
    """
    model.train()

    batch_time = ExpoAverageMeter()  # forward prop. + back prop. time
    losses = ExpoAverageMeter()  # loss (per word decoded)
    start = time.time()
    #DATA_NUM = len(train_loader)
    #shuffle_sort = list(range(DATA_NUM))
    #random.shuffle(shuffle_sort)
    model.zero_grad()
    # Zero the parameter gradients once before gradient accumulation begins.
    for batch_i, data in enumerate(train_loader):
        # Set device options
        img = data['img'].to(device)
        # Zero gradients
        if not per_edge_classifier:
            edge_index = data['edge_index'][0].to(device)
            # Skip graphs with more than 105 candidate edges — presumably a
            # GPU-memory guard; TODO confirm the limit's origin.
            if (data['y'][0].shape[0] > 105):
                continue
            label = data['y'][0].to(device)
            edge_masks = data['x'][0].to(device)
            y_hat = model(img, edge_masks, edge_index)
        else:
            edge_masks = data['x'].to(device)
            y_hat = model(img, edge_masks, None)
            label = data['y'].to(device)
        loss = criterion(y_hat, label)
        if not per_edge_classifier:
            # Scale so the accumulated gradient equals the mean over the interval.
            loss = loss / interval_training
        loss.backward()

        # Step either every `interval_training` batches (accumulation) or
        # every batch for the per-edge classifier.
        if (batch_i + 1) % interval_training == 0 or per_edge_classifier:
            optimizer.step()
            model.zero_grad()

        # Drop references to GPU tensors promptly before the next batch.
        del img
        if not per_edge_classifier:
            del edge_index
        del label
        del edge_masks
        # Keep track of metrics
        if not per_edge_classifier:
            # Undo the accumulation scaling so the logged loss is per-batch.
            losses.update(loss.item() * interval_training)
        else:
            losses.update(loss.item())

        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if batch_i % print_freq == 0:
            logging.info('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, batch_i, len(train_loader),
                                                                  batch_time=batch_time,
                                                                  loss=losses))


def valid(val_loader, model, criterion):
    """Evaluate the model on `val_loader` and return the smoothed average loss."""
    model.eval()  # eval mode (no dropout or batchnorm)

    batch_time = ExpoAverageMeter()  # forward prop. + back prop. time
    losses = ExpoAverageMeter()  # loss (per word decoded)
    start = time.time()

    with torch.no_grad():
        # Batches
        for i_batch, data in enumerate(val_loader):
            img = data['img'].to(device)
            if not per_edge_classifier:
                # Same >105-edge skip as in train(); presumably a GPU-memory
                # guard — TODO confirm.
                if data['y'][0].shape[0] > 105:
                    continue
                edge_index = data['edge_index'][0].to(device)
                label = data['y'][0].to(device)
                edge_masks = data['x'][0].to(device)
                y_hat = model(img, edge_masks, edge_index)
            else:
                edge_masks = data['x'].to(device)
                y_hat = model(img, edge_masks, None)
                label = data['y'].to(device)
            loss = criterion(y_hat, label)

            losses.update(loss.item())
            batch_time.update(time.time() - start)

            start = time.time()

            # Print status
            if i_batch % print_freq == 0:
                logging.info('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(i_batch, len(val_loader),
                                                                      batch_time=batch_time,
                                                                      loss=losses))

    return losses.avg


def main():
    """Entry point: build datasets, DRN backbone and graph model, then train/validate.

    Alternates between a mixed-GT training set and a detection-only one,
    decays the LR on plateaus, and checkpoints every epoch.
    """
    DATAPATH = '/local-scratch/fza49/cities_dataset'
    DETCORNERPATH = '/local-scratch/fza49/nnauata/building_reconstruction/geometry-primitive-detector/det_final'

    # NOTE(review): the Graphdataset shown in dataset.py takes no `per_edge`
    # kwarg — confirm the dataset version in use accepts it.
    train_dataset = Graphdataset(DATAPATH, DETCORNERPATH, phase='train', mix_gt=True, per_edge=per_edge_classifier)
    train_dataset_2 = Graphdataset(DATAPATH, DETCORNERPATH, phase='train', mix_gt=False, per_edge=per_edge_classifier)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
    train_dataloader_2 = DataLoader(train_dataset_2, batch_size=batch_size, shuffle=True, num_workers=8)
    test_dataset = Graphdataset(DATAPATH, DETCORNERPATH, phase='test', per_edge=per_edge_classifier)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

    # Backbone: pretrained DRN-C-26 (4 input channels), keeping only the
    # early children as a feature extractor.
    drn = drn_c_26(pretrained=True, image_channels=4)
    drn = nn.Sequential(*list(drn.children())[:-7])
    model = graphNetwork(model_loop_time, drn, edge_feature_map_channel=edge_feature_channels,
                         gnn=gnn, conv_mpn=conv_mpn)

    model.double()
    model = model.to(device)
    model.change_device()
    if pretrain:
        # Bug fix: two dead assignments previously built a checkpoint path
        # from save_folder that was immediately overwritten; only this
        # hard-coded path was ever used.
        checkpoint = 'conv_mpn_loop_1/checkpoint_16_2.025.tar'
        print(checkpoint)
        checkpoint = torch.load(checkpoint, map_location=device)
        param = checkpoint['model'].state_dict()
        model.load_state_dict(param, strict=False)

    logging.info(model)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_loss = 100000
    epochs_since_improvement = 0
    # Class weights down-weight the majority (negative-edge) class.
    criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.33, 1.0]).double().to(device))
    # Epochs
    for epoch in range(start_epoch, epochs):
        # Decay learning rate if there is no improvement for 4 consecutive
        # epochs, and terminate training after 20 without improvement.
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 4 == 0:
            adjust_learning_rate(optimizer, 0.8)

        # One epoch's training: every third epoch uses the detection-only set.
        if epoch % 3 != 0:
            train(epoch, train_dataloader, model, optimizer, criterion)
        else:
            train(epoch, train_dataloader_2, model, optimizer, criterion)

        # One epoch's validation
        val_loss = valid(test_dataloader, model, criterion)
        logging.info('\n * LOSS - {loss:.3f}\n'.format(loss=val_loss))

        # Check if there was an improvement
        is_best = val_loss < best_loss
        best_loss = min(best_loss, val_loss)

        if not is_best:
            epochs_since_improvement += 1
            logging.info("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, model, optimizer, val_loss, is_best)


if __name__ == '__main__':
    # Surface INFO-level progress messages from train()/valid() on the root logger.
    logging.getLogger().setLevel(logging.INFO)
    main()
posted @ 2021-11-15 21:59  甫生  阅读(111)  评论(0编辑  收藏  举报