[caffe学习笔记][03][生成配置文件]
说明:
Caffe通过prototxt配置文件来描述网络结构，借助Python接口生成网络配置文件比较简单。这里生成train.prototxt和test.prototxt两个文件，分别用于训练阶段和测试（验证）阶段。
步骤:
1.生成配置文件
touch create_train_val_prototxt.py
spyder create_train_val_prototxt.py
# -*- coding: utf-8 -*-
"""Generate Caffe train/test network definitions (prototxt) for MNIST.

Writes two files under ``path``:
  * ``train.prototxt`` -- network ending in SoftmaxWithLoss (training phase);
  * ``test.prototxt``  -- same network plus an Accuracy layer (test phase).
"""

import os

from caffe import layers as L, params as P, to_proto

path = '/home/yuandanfei/work/caffe/mnist/'             # root path
train_lmdb = os.path.join(path, 'train_lmdb')            # training LMDB
test_lmdb = os.path.join(path, 'test_lmdb')              # test LMDB
mean_file = os.path.join(path, 'mean.binaryproto')       # mean image file
train_proto = os.path.join(path, 'train.prototxt')       # output: training net
test_proto = os.path.join(path, 'test.prototxt')         # output: test net


def create_net(lmdb, batch_size, include_acc=False, mirror=True):
    """Build the MNIST classification network.

    Args:
        lmdb: path of the LMDB database read by the Data layer.
        batch_size: number of samples per batch.
        include_acc: if True, also emit an Accuracy layer (test net);
            otherwise only the loss layer is emitted (training net).
        mirror: value of ``transform_param.mirror``. Caffe applies this random
            horizontal flip in every phase, including TEST, so evaluation
            nets should pass ``mirror=False``.

    Returns:
        The NetParameter built by ``to_proto``; ``str()`` of it is the
        prototxt text.
    """
    # Input layer: two tops (image, label). Assuming 28x28 MNIST digits,
    # crop_size=28 keeps the whole image; the mean image is subtracted.
    data, label = L.Data(source=lmdb, backend=P.Data.LMDB, batch_size=batch_size, ntop=2,
                         transform_param=dict(crop_size=28, mean_file=mean_file, mirror=mirror))

    # Output-size rules (per spatial dimension):
    #   Convolution: out = floor((in + 2*pad - kernel_size) / stride) + 1
    #   Pooling:     out = ceil((in + 2*pad - kernel_size) / stride) + 1
    # With stride=1 and pad=(kernel_size-1)/2 a convolution keeps the size.

    # conv1: 16 channels, 5x5, pad 2 -> spatial size unchanged (28x28 for MNIST)
    conv1 = L.Convolution(data, kernel_size=5, stride=1, pad=2, num_output=16,
                          weight_filler=dict(type='xavier'))
    # relu1 (in place)
    relu1 = L.ReLU(conv1, in_place=True)
    # pool1: 3x3 max, stride 2 -> ceil((28-3)/2)+1 = 14
    pool1 = L.Pooling(relu1, pool=P.Pooling.MAX, kernel_size=3, stride=2)

    # conv2: 32 channels, 3x3, pad 1 -> spatial size unchanged (14x14)
    conv2 = L.Convolution(pool1, kernel_size=3, stride=1, pad=1, num_output=32,
                          weight_filler=dict(type='xavier'))
    # relu2 (in place)
    relu2 = L.ReLU(conv2, in_place=True)
    # pool2: 3x3 max, stride 2 -> ceil((14-3)/2)+1 = 7
    pool2 = L.Pooling(relu2, pool=P.Pooling.MAX, kernel_size=3, stride=2)

    # fc3: fully connected, 1024 outputs, followed by ReLU and Dropout (in place)
    fc3 = L.InnerProduct(pool2, num_output=1024, weight_filler=dict(type='xavier'))
    relu3 = L.ReLU(fc3, in_place=True)
    drop3 = L.Dropout(relu3, in_place=True)

    # fc4: one score per digit class (10 outputs)
    fc4 = L.InnerProduct(drop3, num_output=10, weight_filler=dict(type='xavier'))

    # Softmax + multinomial logistic loss
    loss = L.SoftmaxWithLoss(fc4, label)

    if include_acc:
        # Test net: also report classification accuracy.
        acc = L.Accuracy(fc4, label)
        return to_proto(loss, acc)
    # Training net: loss only.
    return to_proto(loss)


def write_net():
    """Write the training and test network definitions to disk."""
    # Training net keeps the original random mirroring.
    # NOTE(review): left-right flipped handwritten digits are not valid digits;
    # consider mirror=False for training as well.
    with open(train_proto, 'w') as f:
        f.write(str(create_net(train_lmdb, batch_size=64)))

    # Test net: add the Accuracy layer and disable mirroring so evaluation
    # images are not randomly flipped (Caffe does not restrict mirror to TRAIN).
    with open(test_proto, 'w') as f:
        f.write(str(create_net(test_lmdb, batch_size=32, include_acc=True, mirror=False)))


if __name__ == '__main__':
    write_net()
2.绘制网络模型
touch draw_net.sh
vim draw_net.sh
#!/usr/bin/env bash
# Render a Caffe network definition (prototxt) to an image using Caffe's
# python/draw_net.py.
#
# Usage:
#   ./draw_net.sh              # draws ../out/train.prototxt -> ../out/train.png
#   DATA=test ./draw_net.sh    # draws ../out/test.prototxt  -> ../out/test.png
#
# NOTE(review): ../out is resolved relative to the current directory; make sure
# it holds the prototxt files produced by create_train_val_prototxt.py.
set -e

DATA=${DATA:-train}
BUILD=/home/yuandanfei/caffe/python/draw_net.py

# --rankdir=BT lays the graph out bottom-to-top (input layer at the bottom).
python "$BUILD" "../out/$DATA.prototxt" "../out/$DATA.png" --rankdir=BT
参考资料:
https://www.cnblogs.com/denny402/p/5679037.html
https://www.cnblogs.com/denny402/p/5106764.html