WiderFace标注格式转PASCAL VOC2007标注格式

#coding=utf-8
import os
import cv2
from xml.dom.minidom import Document


def _add_text_element(doc, parent, tag, text):
    """Append ``<tag>text</tag>`` to *parent* and return the new element."""
    element = doc.createElement(tag)
    element.appendChild(doc.createTextNode(text))
    parent.appendChild(element)
    return element


def create_xml(boxes_dict, target_xml_dir, img_width=300, img_height=300,
               img_depth=3):
    """Write one PASCAL VOC-style annotation XML file for an image.

    Args:
        boxes_dict: dict with "filename" (image file name, e.g. "x.jpg") and
            "boxes" (iterable of boxes whose first four items are x, y, w, h
            in pixels).
        target_xml_dir: existing directory the XML is written to. The file is
            named after "filename" with its extension replaced by ".xml".
        img_width, img_height, img_depth: values written to <size>. The
            defaults (300x300x3) are the values this function always wrote.

    Every box becomes an <object> named "face" with difficult=0,
    truncated=0 and pose=undefined. xmin/ymin are clamped at 0, while
    xmax/ymax stay x+w / y+h of the unclamped box, so clamping trims the box
    instead of shifting it. The caller's boxes are not modified.
    """
    file_name = boxes_dict["filename"]
    # splitext keeps inner dots ("a.b.jpg" -> "a.b"); split('.')[0] gave "a".
    fname = os.path.splitext(file_name)[0]

    doc = Document()
    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)

    _add_text_element(doc, annotation, 'folder', 'widerface')
    _add_text_element(doc, annotation, 'filename', file_name)

    source = doc.createElement('source')
    _add_text_element(doc, source, 'database', 'baidu')
    annotation.appendChild(source)

    size = doc.createElement('size')
    _add_text_element(doc, size, 'width', str(img_width))
    _add_text_element(doc, size, 'height', str(img_height))
    _add_text_element(doc, size, 'depth', str(img_depth))
    annotation.appendChild(size)

    _add_text_element(doc, annotation, 'segmented', '0')

    # Write the coordinates of every bounding box.
    for b_box in boxes_dict["boxes"]:
        x, y, w, h = b_box[0], b_box[1], b_box[2], b_box[3]
        # Right/bottom edges come from the original box; only the left/top
        # edges are clamped, so a box sticking out on the left is trimmed.
        xmax_val = x + w
        ymax_val = y + h
        xmin_val = max(x, 0)
        ymin_val = max(y, 0)

        obj_node = doc.createElement('object')
        _add_text_element(doc, obj_node, 'name', 'face')
        _add_text_element(doc, obj_node, 'difficult', '0')
        _add_text_element(doc, obj_node, 'truncated', '0')
        _add_text_element(doc, obj_node, 'pose', 'undefined')

        bndbox = doc.createElement('bndbox')
        _add_text_element(doc, bndbox, 'xmin', str(xmin_val))
        _add_text_element(doc, bndbox, 'ymin', str(ymin_val))
        _add_text_element(doc, bndbox, 'xmax', str(xmax_val))
        _add_text_element(doc, bndbox, 'ymax', str(ymax_val))
        obj_node.appendChild(bndbox)

        annotation.appendChild(obj_node)

    target_xml_path = os.path.join(target_xml_dir, fname + '.xml')
    # toprettyxml(encoding=...) returns bytes, hence the binary file mode.
    with open(target_xml_path, 'wb') as f:
        f.write(doc.toprettyxml(indent='\t', encoding='utf-8'))



def draw_and_save(image_list, src_img_dir=None, tar_img_dir=None,
                  draw_boxes=True, list_name="val.txt"):
    """Resize WiderFace images to 300x300 and write a VOC-style dataset.

    For each entry in ``image_list``, the image ``src_img_dir/<path>`` is
    cropped to its top-left square, whose side is the shorter image side.
    Face boxes are rescaled into a 300x300 frame. Boxes that leave the square
    crop, or whose rescaled area is below 64 px^2, are dropped. Images that
    keep at least one box are:

    - resized to 300x300 and written to ``tar_img_dir/JPEGImages``;
    - annotated via ``create_xml`` into ``tar_img_dir/Annotations``;
    - listed by base name (extension removed) in ``tar_img_dir/<list_name>``.

    Args:
        image_list: list of dicts with "path" (image path relative to
            ``src_img_dir``, e.g. "event/name.jpg") and "boxes" (list of
            WiderFace box lines; the first four whitespace-separated integer
            fields are x, y, w, h).
        src_img_dir: root directory of the source images.
        tar_img_dir: output root directory; must already exist.
        draw_boxes: when True (the default, same as before), a green rectangle
            is drawn on the saved image for every kept box, for visual checks.
            Pass False to save clean images for training.
        list_name: name of the image-set list file. It is opened in append
            mode, so re-running the conversion appends duplicate entries.

    Images that cv2.imread cannot read (it returns None) are reported and
    skipped.
    """
    target_size = 300
    # Drop boxes with area < 64 px^2: the first feature map used for
    # detection has stride 8 (8 * 8 = 64).
    min_box_area = 64

    img_tar_dir = os.path.join(tar_img_dir, "JPEGImages")
    xml_tar_dir = os.path.join(tar_img_dir, "Annotations")
    for out_dir in (img_tar_dir, xml_tar_dir):
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)

    name_list = os.path.join(tar_img_dir, list_name)
    with open(name_list, 'a') as fw:
        for item in image_list:
            sub_path = item["path"]
            file_name = sub_path.split("/")[-1]
            path = os.path.join(src_img_dir, sub_path)
            img = cv2.imread(path)
            if img is None:  # cv2.imread signals failure by returning None
                print("Cannot read image: {}".format(path))
                continue
            height, width = img.shape[:2]
            wh = min(width, height)  # side of the top-left square crop
            scale = float(target_size) / wh

            new_boxes = []
            for box in item["boxes"]:
                x, y, w, h = [int(v) for v in box.split()[:4]]
                if x + w > wh or y + h > wh:  # face lies outside the square crop
                    print("Face has been out of picture")
                    continue
                new_box = [int(x * scale), int(y * scale),
                           int(w * scale), int(h * scale)]
                if new_box[2] * new_box[3] < min_box_area:
                    print("Box: (width: {} height: {}) is too small".format(
                        new_box[2], new_box[3]))
                    continue
                print(new_box)
                new_boxes.append(new_box)
            if not new_boxes:
                continue

            # Crop and resize exactly once per image, after the boxes are
            # known. Repeating it per box re-cropped an already-resized image
            # whenever the shorter side was below target_size.
            img = cv2.resize(img[0:wh, 0:wh], (target_size, target_size))
            if draw_boxes:
                for bx, by, bw, bh in new_boxes:
                    cv2.rectangle(img, (bx, by), (bx + bw, by + bh),
                                  (0, 255, 0), 1)
            cv2.imwrite(os.path.join(img_tar_dir, file_name), img)

            create_xml({"filename": file_name, "boxes": new_boxes}, xml_tar_dir)
            fw.write(os.path.splitext(file_name)[0] + '\n')
            fw.flush()





def _is_box_line(line):
    """Return True if *line* consists only of integer fields (a box row).

    Distinguishes a WiderFace box row from an image-path line, which always
    contains non-digit characters such as '/' or '.jpg'.
    """
    tokens = line.split()
    return bool(tokens) and all(t.lstrip('-').isdigit() for t in tokens)


def parse(label_file_path, src_img_dir, tar_img_dir):
    """Read a WiderFace ground-truth file and convert the listed images.

    The file is a sequence of records::

        <image path relative to src_img_dir>
        <number of boxes N>
        <box line>   (repeated N times)

    Parsing stops at end of file or at the first empty line. Each record
    becomes ``{"path": ..., "boxes": [box lines]}``, and the whole list is
    passed to ``draw_and_save(image_list, src_img_dir, tar_img_dir)``.

    WIDER FACE gt files may follow a count of 0 with a single placeholder row
    of zeros. Such a row is skipped when present. Otherwise it would be read
    as the next image path, and ``int()`` would then fail on the real path.
    """
    # `with` guarantees the annotation file is closed.
    with open(label_file_path, 'r') as fr:
        lines = [ln.rstrip() for ln in fr]

    image_list = []
    idx = 0
    total = len(lines)
    while idx < total and lines[idx]:
        path = lines[idx]
        num = int(lines[idx + 1])
        boxes_list = lines[idx + 2:idx + 2 + num]
        idx += 2 + num
        if num == 0 and idx < total and _is_box_line(lines[idx]):
            idx += 1  # skip the zero-face placeholder row
        image_list.append({"path": path, "boxes": boxes_list})
    draw_and_save(image_list, src_img_dir, tar_img_dir)


if __name__ == "__main__":
    # Command-line entry point. Every argument is optional; the defaults are
    # the paths this script previously hard-coded.
    import argparse

    arg_parser = argparse.ArgumentParser(
        description="Convert WiderFace annotations to PASCAL VOC2007 format.")
    arg_parser.add_argument(
        "--label-file",
        default="/projects/DSOD/wider_face/datasets/wider_face_split/wider_face_val_bbx_gt.txt",
        help="WiderFace *_bbx_gt.txt annotation file")
    arg_parser.add_argument(
        "--src-img-dir",
        default="/projects/DSOD/wider_face/datasets/val/images",
        help="root directory of the WiderFace images")
    arg_parser.add_argument(
        "--tar-img-dir",
        default="/projects/DSOD/wider_face/datasets/drew",
        help="output directory for JPEGImages/, Annotations/ and the image list")
    args = arg_parser.parse_args()

    # makedirs (unlike mkdir) also creates missing parent directories.
    if not os.path.exists(args.tar_img_dir):
        os.makedirs(args.tar_img_dir)
    parse(args.label_file, args.src_img_dir, args.tar_img_dir)

 

posted @ 2018-10-25 11:20  HOU_JUN  阅读(1439)  评论(0)  编辑  收藏  举报