数据集 train/test 划分脚本（以 NYU 数据集为例）

# coding: UTF-8
# Pass a different fraction to classify() to change how each scene's images are split between the train set and the test set.
# Run this script from $project/tools; the data paths below are relative to that directory.
import os


def find_last(string, str):
    """Return the index of the last occurrence of ``str`` in ``string``.

    This is exactly ``string.rfind(str)``:
    - It returns -1 when ``str`` does not occur.
    - It returns ``len(string)`` when ``str`` is empty.
    - Overlapping occurrences count, e.g. ``find_last('aaa', 'aa') == 1``.

    The parameter name ``str`` (which shadows the builtin inside this function)
    is kept so that keyword callers keep working.
    """
    return string.rfind(str)

# Directory holding one sub-directory per recording; like every path in this
# script it is relative to $project/tools.
filePath = '../data/images/'
# Sorted recording directories. Plain files (a stray README, .DS_Store, ...)
# are skipped because classify() calls os.listdir() on these entries.
pathDir = sorted(d for d in os.listdir(filePath)
                 if os.path.isdir(os.path.join(filePath, d)))
sceneSum = len(pathDir)  # number of recording directories

# A directory's scene is its name up to the last '_' (presumably names look
# like '<scene>_<index>', e.g. 'living_room_0001' -> 'living_room').
# A name without '_' is kept whole. Slicing with a -1 "not found" index
# would instead chop off its last character.
sceneList = sorted(set(d.rpartition('_')[0] if '_' in d else d
                       for d in pathDir))
sceneNum = len(sceneList)
print("NYU data sets have %d scenes, they are: " % sceneNum)
print(sceneList)


def classify(train_test=0.6, train_path='../data/train.txt',
             test_path='../data/test.txt'):
    """Split every scene's images into a train list and a test list.

    For each scene in the module-level ``sceneList``, the images in all
    directories of ``pathDir`` (under ``filePath``) that belong to that scene
    are counted. The first ``int(train_test * count)`` images are written to
    ``train_path`` and the rest to ``test_path``. Images are taken in
    ``pathDir`` order, then in sorted file-name order within each directory.
    Each line is ``<directory>/<file>``.

    Args:
        train_test: fraction of each scene's images assigned to the train set.
        train_path: output file for the train list.
        test_path: output file for the test list.
    """
    def scene_of(dir_name):
        # Scene = name up to the last '_'; a name without '_' is its own scene.
        head, sep, _ = dir_name.rpartition('_')
        return head if sep else dir_name

    # Group directories by their exact scene name. A substring test would also
    # put e.g. 'home_office_0001' under scene 'office' and count it twice.
    sceneDirs = {}
    for dirName in pathDir:
        completeDir = filePath + dirName
        if not os.path.isdir(completeDir):
            continue  # stray file in the images directory; nothing to list
        # List each directory once, sorted. The count used for the split and
        # the names written then come from the same, reproducible listing.
        files = sorted(os.listdir(completeDir))
        sceneDirs.setdefault(scene_of(dirName), []).append((dirName, files))

    eachScene = [sum(len(files) for _, files in sceneDirs.get(scene, []))
                 for scene in sceneList]
    print('Each scenes has images:')
    # Single formatted string: same output under Python 2 and Python 3.
    print('%s Total in %d images' % (eachScene, sum(eachScene)))

    trainNum = 0
    testNum = 0
    with open(train_path, 'w') as txtTrain, open(test_path, 'w') as txtTest:
        for scene, sceneTotal in zip(sceneList, eachScene):
            classifyTrain = int(train_test * sceneTotal)
            written = 0
            for dirName, files in sceneDirs.get(scene, []):
                for fileName in files:
                    writeLine = dirName + '/' + fileName + '\n'
                    if written < classifyTrain:
                        txtTrain.write(writeLine)
                        trainNum += 1
                    else:
                        txtTest.write(writeLine)
                        testNum += 1
                    written += 1
    print('The sum images of train set is %d' % trainNum)
    print('The sum images of test set is %d' % testNum)

# Write ../data/train.txt and ../data/test.txt, putting the first 60% (rounded
# down) of each scene's images in the train list and the rest in the test list.
classify(train_test=0.6)

本文为原创，转载请注明出处！

posted on 2017-09-24 16:06  萝卜丶爱  阅读(386)  评论(0)  编辑  收藏  举报

导航