尝试解决cifar10问题
我理解这个问题和猫狗的不同,在于将2类扩展为10类,其它的地方我准备采用相同的方法。
注意事项:
1、我要用kaggle的数据集,而不是用其它的数据集;
2、最终得到的结果要以test为导向;
1、先打开jupyter,并且把数据集传到dl_machine上去。想办法读入数据
通过观察kaggle,可以发现pd的使用非常高,很大程度上是因为它对csv文件的支持非常好吧。
df
=pd.read_csv(
'trainLabels.csv',header
=
0,sep
=
',')
#filename可以直接从盘符开始,标明每一级的文件夹直到csv文件,header=0表示头部为空第一行为标题
#sep=','表示数据间分隔符是逗号
print df.head()
print df.tail()
2、能否将图片数据读入内存?
基本的思路就是遍历图片,然后根据名称去找类别。
这其实是经常会遇到的问题。
TRAIN_DIR
=
'./train/'
TEST_DIR
=
'./test/'
tmp
= df[(df.label
==
"airplane ")]
train_airplane
= [TRAIN_DIR
+
str(i)
+
'.png'
for i
in a.
id]
print(
"train_airplane",
len(train_airplane))
tmp
= df[(df.label
==
"automobile ")]
train_automobile
= [TRAIN_DIR
+
str(i)
+
'.png'
for i
in a.
id]
print(
"train_automobile",
len(train_automobile))
tmp
= df[(df.label
==
"bird ")]
train_bird
= [TRAIN_DIR
+
str(i)
+
'.png'
for i
in a.
id]
print(
"train_bird",
len(train_bird))
tmp
= df[(df.label
==
"cat")]
train_cat
= [TRAIN_DIR
+
str(i)
+
'.png'
for i
in a.
id]
print(
"train_cat",
len(train_cat))
tmp
= df[(df.label
==
"deer")]
train_deer
= [TRAIN_DIR
+
str(i)
+
'.png'
for i
in a.
id]
print(
"train_deer",
len(train_deer))
tmp
= df[(df.label
==
"dog")]
train_dog
= [TRAIN_DIR
+
str(i)
+
'.png'
for i
in a.
id]
print(
"train_dog",
len(train_dog))
tmp
= df[(df.label
==
"frog")]
train_frog
= [TRAIN_DIR
+
str(i)
+
'.png'
for i
in a.
id]
print(
"train_frog",
len(train_frog))
tmp
= df[(df.label
==
"horse")]
train_horse
= [TRAIN_DIR
+
str(i)
+
'.png'
for i
in a.
id]
print(
"train_horse",
len(train_horse))
tmp
= df[(df.label
==
"ship")]
train_ship
= [TRAIN_DIR
+
str(i)
+
'.png'
for i
in a.
id]
print(
"train_ship",
len(train_ship))
tmp
= df[(df.label
==
"truck")]
train_truck
= [TRAIN_DIR
+
str(i)
+
'.png'
for i
in a.
id]
print(
"train_truck",
len(train_truck))
test_images
= [TEST_DIR
+
str(i)
+
'.png'
for i
in
os.listdir(TEST_DIR)]
print(
"test_images",
len(test_images))
这个过程分为了a、获得文件名;b、读取文件。两个部分。CIFAR还只是10类的,还可以手工编码,如果是100位的,肯定就不能采用这种方法。
df
=pd.read_csv(
'trainLabels.csv',header
=
0,sep
=
',')
train_airplane
= [
str(i)
+
'.png'
for i
in df[(df.label
==
"airplane")].
id]
这种方法是正确、高效的,直接能够获得一个list,我希望的是能够直接包含这些文件的绝对地址。
简化的方法,当然是使用数组。但是现在我不适合手写,最好去参考比较成熟的代码。
3、看看,看看。我开始体会到为什么很多代码里面都有“看看”这个步骤,因为你在编写代码的时候只有这种方式才能确保你的代码编写是正确的。
def show_cifar10(idx)
:
airplane
= read_image(train_airplane[idx])
automobile
= read_image(train_automobile[idx])
bird
= read_image(train_bird[idx])
cat
= read_image(train_cat[idx])
deer
= read_image(train_deer[idx])
dog
= read_image(train_dog[idx])
frog
= read_image(train_frog[idx])
horse
= read_image(train_horse[idx])
ship
= read_image(train_ship[idx])
truck
= read_image(train_truck[idx])
pair
= np.concatenate((airplane, automobile,bird,cat,deer,dog,frog,horse,ship,truck), axis
=
1)
plt.figure(figsize
=(
10,
5))
plt.imshow(pair)
plt.show()
for idx
in
range(
0,
5)
:
show_cifar10(idx)

4、文件已经获得,是否已经可以塞到模型里面去??
如果要塞到模型中去,现有模式是采用直接解析目录文件的方式,为此广泛使用了软链接。基于之前获得的完全路径,这个地方其实是很好做的。需要注意的是塞进去之前,首先检验一下文件是否存在:
for filename
in train_truck[
:TESTNUM]
:
if(
os.path.exists(TRAIN_DIR
+filename))
:
os.symlink(TRAIN_DIR
+filename,
'./train2/truck/'
+filename);
5、训练过程中可能遇到的问题
现在看来,万事大吉:模型下载完成、数据也正确安置了(为此我一个文件夹一个文件夹地打开观察),下面调用之前在DogVSCat中正确运行的代码,训练一段时间后发现错误:
Unable to create link (Name already exists)
进一步修改代码,主要是文件的大小。因为我记得ResNet应该是有最小文件支持限制的,我改成了48*48,但是不行,resnet的限制应该在224,但是cifar10只有32,所以我将cifar10放大,并缩小数据集,然后是等待。
此外,还特别需要注意,文件初始化的时候这样来做:
也就是test2下面还要有一个test目录,作为预分类。
6、关于OS的总结
在目前的程序中,广泛地使用到了os来操作文件系统,应该说很有效果,包括:
os.listdir(TEST_DIR)
返回的显然是
/home
/helu
/cifar10
/test
/
203688.png
/home
/helu
/cifar10
/test
/
221824.png
/home
/helu
/cifar10
/test
/
289334.png
/home
/helu
/cifar10
/test
/
104194.png
/home
/helu
/cifar10
/test
/
30977.png
这种带后缀的完整目录里面文件的地址
os.path.exists(dirname):
os.listdir()
#不给参数默认输出当前路径下所有文件
os.listdir(
'/home/python')
#可以指定目录
简单的用来判断,一个目录下面的文件是否存在。
os.mkdir('test2/test')
创建一个新的目录,正如其名字一样。
os.symlink(TRAIN_DIR+filename, './train2/airplane/'+filename)
非常重要的,创建软连接。
此外
shutil.rmtree(dirname)
这个应该是删除一串文件的,并且进一步整合成这个函数,能够强制刷新文件目录。
def rmrf_mkdir(dirname)
:
if
os.path.exists(dirname)
:
shutil.rmtree(dirname)
os.mkdir(dirname)
7、其它一些可以被复用的东西
def show_cifar10(idx)
:
airplane
= read_image(TRAIN_DIR
+train_airplane[idx])
automobile
= read_image(TRAIN_DIR
+train_automobile[idx])
bird
= read_image(TRAIN_DIR
+train_bird[idx])
cat
= read_image(TRAIN_DIR
+train_cat[idx])
deer
= read_image(TRAIN_DIR
+train_deer[idx])
dog
= read_image(TRAIN_DIR
+train_dog[idx])
frog
= read_image(TRAIN_DIR
+train_frog[idx])
horse
= read_image(TRAIN_DIR
+train_horse[idx])
ship
= read_image(TRAIN_DIR
+train_ship[idx])
truck
= read_image(TRAIN_DIR
+train_truck[idx])
pair
= np.concatenate((airplane, automobile,bird,cat,deer,dog,frog,horse,ship,truck), axis
=
1)
plt.figure(figsize
=(
10,
5))
plt.imshow(pair)
plt.show()
for idx
in
range(
0,
5)
:
show_cifar10(idx)
用来显示已经保存到内存中数据的图片。

def CNNFeatureExtract(MODEL, image_size, lambda_func
=
None)
:
width
= image_size[
0]
#图像宽
height
= image_size[
1]
#图像高
input_tensor
= Input((height, width,
3))
x
= input_tensor
if lambda_func
:
x
= Lambda(lambda_func)(x)
base_model
= MODEL(input_tensor
=x, weights
=
'imagenet', include_top
=
False)
#这里全部使用no_top模型
model
= Model(base_model.
input, GlobalAveragePooling2D()(base_model.output))
gen
= ImageDataGenerator()
#使用了generate,并且使用的是文件夹模式
train_generator
= gen.flow_from_directory(
"train2", image_size, shuffle
=
False, batch_size
=
16)
test_generator
= gen.flow_from_directory(
"test2", image_size, shuffle
=
False, batch_size
=
16, class_mode
=
None)
train
= model.predict_generator(train_generator)
test
= model.predict_generator(test_generator)
with h5py.File(
"GoCifar10_%s.h5"
%MODEL.func_name) as h
:
h.create_dataset(
"train", data
=train)
h.create_dataset(
"test", data
=test)
h.create_dataset(
"label", data
=train_generator.classes)
强制的模型运算,帮助在dogsvscats上面进入10%,在cifar10上,我认为可以进入前20.
已经开始训练了。目前的算法虽然不流程,但是可以运行,最重要的是可控的。在这个层次上,我们可以继续前进。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 零经验选手,Compose 一天开发一款小游戏!
· 因为Apifox不支持离线,我果断选择了Apipost!
· 通过 API 将Deepseek响应流式内容输出到前端