RefineDet: notes on some PyTorch and Python syntax
Repository (PyTorch implementation):
https://github.com/luuuyi/RefineDet.PyTorch
- itertools.product
- Parsing VOC XML
- np.hstack() / np.vstack(), e.g. target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
- a[::-1]
- zip, e.g. for (x, l, c) in zip(sources, self.arm_loc, self.arm_conf):
- torch.max(), e.g. best_truth_overlap, best_truth_idx = overlap.max(0, keepdim=True)
- index_fill_(dim, index, val), e.g. best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # ensure best prior
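itertools.product
from itertools import product

# the original snippet looped over [10, 8, 5, 3];
# the output below was produced with [3, 5] to keep it short
for k, f in enumerate([3, 5]):
    print("f:=====", f)
    for i, j in product(range(f), repeat=2):
        print(i, j)
Output: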
f:===== 3
0 0
0 1
0 2
1 0
1 1
1 2
2 0
2 1
2 2
f:===== 5
0 0
0 1
0 2
0 3
0 4
1 0
1 1
1 2
1 3
1 4
2 0
2 1
2 2
2 3
2 4
3 0
3 1
3 2
3 3
3 4
4 0
4 1
4 2
4 3
4 4
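product(range(f), repeat=2) is just a flattened pair of nested loops over an f x f grid, which is why each f yields f*f (i, j) pairs in row-major order. A minimal sketch of that equivalence follows. The helper names are hypothetical, and the cell-center formula in cell_centers is an assumption in the style of SSD-like prior generation (centers at (j + 0.5) / f and (i + 0.5) / f), not code taken from the repository.

```python
from itertools import product


def grid_pairs(f):
    """All (i, j) cells of an f x f feature map, row by row."""
    return list(product(range(f), repeat=2))


def grid_pairs_nested(f):
    """The same cells written as explicit nested loops."""
    pairs = []
    for i in range(f):
        for j in range(f):
            pairs.append((i, j))
    return pairs


def cell_centers(f):
    """Hypothetical normalized (cx, cy) center of every cell, ASSUMING
    centers sit at (j + 0.5) / f horizontally and (i + 0.5) / f vertically."""
    return [((j + 0.5) / f, (i + 0.5) / f) for i, j in product(range(f), repeat=2)]


for f in [3, 5]:
    assert grid_pairs(f) == grid_pairs_nested(f)
    print(f, len(grid_pairs(f)))   # 9 pairs for f=3, 25 for f=5
print(cell_centers(2))
```

Because i is the outer index, it plays the role of the row (y) and j the column (x).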
Parsing VOC XML
A test case I wrote based on the repository's code.
For example, an XML annotation file in VOC format contains:
<annotation>
<folder>VOC2007</folder>
<filename>seat_190530_623.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>329145082</flickrid>
</source>
<owner>
<flickrid>hiromori2</flickrid>
<name>Hiroyuki Mori</name>
</owner>
<size>
<width>1024</width>
<height>768</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>zuoyianquandai</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>476</xmin>
<ymin>276</ymin>
<xmax>562</xmax>
<ymax>372</ymax>
</bndbox>
</object>
<object>
<name>zuoyianquandai</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>440</xmin>
<ymin>271</ymin>
<xmax>506</xmax>
<ymax>372</ymax>
</bndbox>
</object>
<object>
<name>zuoyianquandai</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>622</xmin>
<ymin>616</ymin>
<xmax>726</xmax>
<ymax>717</ymax>
</bndbox>
</object>
<object>
<name>zuoyianquandai</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>348</xmin>
<ymin>598</ymin>
<xmax>456</xmax>
<ymax>720</ymax>
</bndbox>
</object>
<object>
<name>seat</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>270</xmin>
<ymin>15</ymin>
<xmax>825</xmax>
<ymax>367</ymax>
</bndbox>
</object>
<object>
<name>seat</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>72</xmin>
<ymin>66</ymin>
<xmax>492</xmax>
<ymax>683</ymax>
</bndbox>
</object>
<object>
<name>seat</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>612</xmin>
<ymin>0</ymin>
<xmax>1024</xmax>
<ymax>704</ymax>
</bndbox>
</object>
</annotation>
The code:
import os
import xml.etree.ElementTree as ET

root_dir = "/data_2/project_2021/refinedet/pytorch_refinedet/data/VOCdevkit/VOC2007/Annotations/"
list_xml = os.listdir(root_dir)
for cnt, xml_name in enumerate(list_xml):
    print(cnt, xml_name)
    path_xml = os.path.join(root_dir, xml_name)
    target = ET.parse(path_xml).getroot()
    res = []  # rebuilt for every file
    for obj in target.iter('object'):
        # skip objects marked as difficult
        difficult = int(obj.find('difficult').text) == 1
        if difficult:
            continue
        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')
        pts = ['xmin', 'ymin', 'xmax', 'ymax']
        bndbox = []
        for i, pt in enumerate(pts):
            # round to the nearest integer, then subtract 1
            cur_pt = int(float(bbox.find(pt).text) + 0.5) - 1
            # scale height or width
            # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
            bndbox.append(cur_pt)
        # in RefineDet the class name is mapped to an index:
        # label_idx = self.class_to_ind[name]
        # bndbox.append(label_idx)
        bndbox.append(name)  # here the class name itself is kept for readability
        res += [bndbox]  # [xmin, ymin, xmax, ymax, label]
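The value of res for the XML above. The "- 1" shifts VOC's 1-based pixel coordinates to 0-based, which is why ymin = 0 in the last seat box becomes -1: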
<class 'list'>: [[475, 275, 561, 371, 'zuoyianquandai'], [439, 270, 505, 371, 'zuoyianquandai'], [621, 615, 725, 716, 'zuoyianquandai'], [347, 597, 455, 719, 'zuoyianquandai'], [269, 14, 824, 366, 'seat'], [71, 65, 491, 682, 'seat'], [611, -1, 1023, 703, 'seat']]
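To try the parsing logic without an Annotations directory, the same loop can be wrapped in a function and fed an XML string through ET.fromstring. This is a sketch: parse_voc, the class_to_ind mapping {'zuoyianquandai': 0, 'seat': 1} and the third "difficult" object are hypothetical, added only to show the label-index path and the difficult-object skip.

```python
import xml.etree.ElementTree as ET

# A trimmed copy of the annotation above; the third object is a made-up
# "difficult" box added only to show that such objects are skipped.
SAMPLE_XML = """<annotation>
  <object>
    <name>zuoyianquandai</name>
    <difficult>0</difficult>
    <bndbox><xmin>476</xmin><ymin>276</ymin><xmax>562</xmax><ymax>372</ymax></bndbox>
  </object>
  <object>
    <name>seat</name>
    <difficult>0</difficult>
    <bndbox><xmin>612</xmin><ymin>0</ymin><xmax>1024</xmax><ymax>704</ymax></bndbox>
  </object>
  <object>
    <name>seat</name>
    <difficult>1</difficult>
    <bndbox><xmin>10</xmin><ymin>10</ymin><xmax>20</xmax><ymax>20</ymax></bndbox>
  </object>
</annotation>"""


def parse_voc(root, class_to_ind=None, keep_difficult=False):
    """Return [[xmin, ymin, xmax, ymax, label], ...] for one annotation root.

    label is the class name, or its index when class_to_ind is given.
    """
    res = []
    for obj in root.iter('object'):
        difficult_node = obj.find('difficult')
        difficult = difficult_node is not None and int(difficult_node.text) == 1
        if difficult and not keep_difficult:
            continue
        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')
        box = []
        for pt in ('xmin', 'ymin', 'xmax', 'ymax'):
            # round to the nearest integer, then shift from 1-based to 0-based
            box.append(int(float(bbox.find(pt).text) + 0.5) - 1)
        box.append(class_to_ind[name] if class_to_ind is not None else name)
        res.append(box)
    return res


root = ET.fromstring(SAMPLE_XML)
print(parse_voc(root))
print(parse_voc(root, class_to_ind={'zuoyianquandai': 0, 'seat': 1}))
```

The first two boxes reproduce the corresponding entries of res above, including the -1 for ymin = 0.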
np.hstack() np.vstack() target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
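np.vstack(): stacks arrays vertically (one on top of the other, adding rows)
np.hstack(): joins arrays horizontally (side by side; 1-D arrays are simply concatenated)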
import numpy as np
arr1=np.array([1,2,3])
arr2=np.array([4,5,6])
print(np.vstack)
print (np.vstack((arr1,arr2)))
print(np.hstack)
print (np.hstack((arr1,arr2)))
Output:
<function vstack at 0x7ff6e333d0e0>
[[1 2 3]
[4 5 6]]
<function hstack at 0x7ff6e333d290>
[1 2 3 4 5 6]
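target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
boxes has shape [5, 4].
labels has shape [5]; np.expand_dims(labels, axis=1) turns it into shape [5, 1].
Stacking the two horizontally gives target with shape [5, 5]: each row is [xmin, ymin, xmax, ymax, label].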
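The shape bookkeeping can be checked directly. In this sketch the 5 boxes and their labels are made-up values. It also shows why expand_dims is needed: hstack on a 2-D and a 1-D array raises a ValueError, because the arrays must have the same number of dimensions.

```python
import numpy as np

# Hypothetical data: 5 boxes as [xmin, ymin, xmax, ymax] and 5 class labels.
boxes = np.arange(20, dtype=np.float32).reshape(5, 4)
labels = np.array([0, 1, 1, 0, 2], dtype=np.float32)

labels_col = np.expand_dims(labels, axis=1)   # (5,) -> (5, 1)
target = np.hstack((boxes, labels_col))       # (5, 4) + (5, 1) -> (5, 5)
print(labels_col.shape, target.shape)
print(target[0])                              # first box followed by its label

# Without expand_dims the dimensions do not match
try:
    np.hstack((boxes, labels))
    mismatch_error = None
except ValueError as e:
    mismatch_error = str(e)
print("ValueError:", mismatch_error)

# For 2-D inputs, vstack adds rows and hstack adds columns
a = np.zeros((2, 3))
b = np.ones((2, 3))
print(np.vstack((a, b)).shape, np.hstack((a, b)).shape)  # (4, 3) (2, 6)
```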
a[::-1]
a = [1,2,3,4,5]
b = a[::-1]
print(a)
print(b)
#[1, 2, 3, 4, 5]
#[5, 4, 3, 2, 1]
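a[::-1] is the slice start:stop:step with both bounds omitted and step -1, so it walks the whole sequence backwards and returns a reversed copy of a list. The same slice works per axis on NumPy arrays. Reversing the channel axis of an image (BGR to RGB) is used below only as an illustrative case; it is not taken from the repository.

```python
import numpy as np

a = [1, 2, 3, 4, 5]
# step -1 walks backwards over the whole list
assert a[::-1] == list(reversed(a))
print(a[::2], a[::-2])         # every second element, forwards and backwards

# Per-axis on NumPy: reverse only the last (channel) axis of an H x W x 3 array
img = np.arange(12).reshape(2, 2, 3)
rgb = img[:, :, ::-1]
print(img[0, 0], rgb[0, 0])    # [0 1 2] [2 1 0]
```

On a NumPy array the slice returns a view with a negative stride, not a copy. torch.from_numpy rejects such arrays, so call .copy() before converting.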
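zip, e.g.: for (x, l, c) in zip(sources, self.arm_loc, self.arm_conf):
sources, self.arm_loc and self.arm_conf are sequences of the same length. sources holds the feature maps (the data), while arm_loc and arm_conf hold layers such as Conv2d. zip pairs their k-th elements, so each feature map goes through its own loc layer and its own conf layer.
# arm_loc / arm_conf below are local lists that collect the outputs
for (x, l, c) in zip(sources, self.arm_loc, self.arm_conf):
    # permute NCHW -> NHWC; contiguous() makes a contiguous copy so the result can be reshaped with view()
    arm_loc.append(l(x).permute(0, 2, 3, 1).contiguous())
    arm_conf.append(c(x).permute(0, 2, 3, 1).contiguous())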
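A self-contained version of this loop, with dummy data. Everything about the setup is an assumption for illustration: two 16-channel feature maps of size 10x10 and 5x5, 3 anchors per cell, 2 classes, and 3x3 conv heads. The final flatten-and-concatenate step is the usual SSD-style way of merging the per-map outputs. It is included here to show why contiguous() matters (view() fails on the non-contiguous result of permute), not as a copy of the repository's code.

```python
import torch
import torch.nn as nn

# Hypothetical setup: 3 anchors per cell, 2 classes, 16-channel feature maps
num_anchors, num_classes = 3, 2
sources = [torch.randn(1, 16, 10, 10), torch.randn(1, 16, 5, 5)]
arm_loc_layers = nn.ModuleList(
    [nn.Conv2d(16, num_anchors * 4, kernel_size=3, padding=1) for _ in sources])
arm_conf_layers = nn.ModuleList(
    [nn.Conv2d(16, num_anchors * num_classes, kernel_size=3, padding=1) for _ in sources])

arm_loc, arm_conf = [], []
# zip pairs the k-th feature map with the k-th loc head and the k-th conf head
for x, l, c in zip(sources, arm_loc_layers, arm_conf_layers):
    # NCHW -> NHWC, then contiguous() so that view() works below
    arm_loc.append(l(x).permute(0, 2, 3, 1).contiguous())
    arm_conf.append(c(x).permute(0, 2, 3, 1).contiguous())

print([tuple(t.shape) for t in arm_loc])   # (1, 10, 10, 12) and (1, 5, 5, 12)

# Flatten each map per image, concatenate, then one row of 4 offsets per anchor
loc = torch.cat([o.view(o.size(0), -1) for o in arm_loc], 1).view(1, -1, 4)
print(tuple(loc.shape))                    # (10*10 + 5*5) * 3 = 375 anchors
```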
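torch.max(), e.g. best_truth_overlap, best_truth_idx = overlap.max(0, keepdim=True), where the returned indices look like tensor([[6, 3, 0, ..., 6, 0, 2]])
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
Returns the maximum values along dimension dim, together with their indices.
torch.max(a, 0) returns the largest element of each column and its index (the row index of that element within its column). The values and the indices are each a tensor, and together they form a tuple (Tensor, LongTensor). In recent PyTorch versions this is a named tuple with fields values and indices, as the output below shows.
torch.max(a, 1) returns the largest element of each row and its index (the column index of that element within its row).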
import torch
a = torch.rand(3,5)
print(a)
print("========================")
print("a.max(0)")
print(a.max(0))
print("========================")
print("a.max(1)")
print(a.max(1))
tensor([[0.2695, 0.3127, 0.5122, 0.4659, 0.8935],
[0.8419, 0.1534, 0.4232, 0.7792, 0.4795],
[0.9919, 0.9686, 0.1972, 0.2406, 0.4112]])
========================
a.max(0)
torch.return_types.max(
values=tensor([0.9919, 0.9686, 0.5122, 0.7792, 0.8935]),
indices=tensor([2, 2, 0, 1, 0]))
========================
a.max(1)
torch.return_types.max(
values=tensor([0.8935, 0.8419, 0.9919]),
indices=tensor([4, 0, 0]))
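This confused me at first: does 0 mean columns and 1 mean rows? A clearer way to read it is that dim is the dimension being reduced. max(0) collapses dimension 0 (the rows), so you get one maximum per column, and the indices are row indices. max(1) collapses dimension 1, giving one maximum per row, with column indices.
A tensor of shape [3, 5] therefore becomes shape [5] after max(0), or [1, 5] with keepdim=True.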
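A quick check of the shapes with and without keepdim, and of what the returned indices point to. The input is random, as in the example above.

```python
import torch

a = torch.rand(3, 5)

v0, i0 = a.max(0)                    # reduce dim 0 (rows): one result per column
v0k, i0k = a.max(0, keepdim=True)    # same values, dim 0 kept with size 1
v1, i1 = a.max(1)                    # reduce dim 1 (columns): one result per row

print(v0.shape, v0k.shape, v1.shape)  # torch.Size([5]) torch.Size([1, 5]) torch.Size([3])

# The indices point back into the reduced dimension
assert torch.equal(a[i0, torch.arange(5)], v0)   # row index for each column
assert torch.equal(a[torch.arange(3), i1], v1)   # column index for each row
```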
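In RefineDet, consider the following code:
import torch

overlap = torch.rand(7, 6375)  # stand-in for the IoU matrix: 7 ground truths x 6375 priors
best_prior_overlap, best_prior_idx = overlap.max(1, keepdim=True)
best_truth_overlap, best_truth_idx = overlap.max(0, keepdim=True)
overlap holds the IoU (intersection over union) between 7 ground-truth boxes and 6375 priors. So what is the shape of best_prior_overlap, and what does it mean?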
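best_prior_overlap has shape [7, 1].
best_prior_idx has shape [7, 1], with values in [0, 6375).
Together they give, for each ground truth, the prior with the highest IoU (its index) and that IoU value.
best_truth_overlap has shape [1, 6375].
best_truth_idx has shape [1, 6375], with values in [0, 7).
Together they give, for each prior, the ground truth with the highest IoU and that IoU value.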
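The same two reductions on a small, made-up IoU matrix (2 ground truths x 4 priors) make the values easy to check by hand.

```python
import torch

# Hypothetical IoU matrix: 2 ground truths (rows) x 4 priors (columns)
overlap = torch.tensor([[0.10, 0.60, 0.30, 0.00],
                        [0.20, 0.35, 0.05, 0.40]])

best_prior_overlap, best_prior_idx = overlap.max(1, keepdim=True)  # [2, 1]
best_truth_overlap, best_truth_idx = overlap.max(0, keepdim=True)  # [1, 4]

print(best_prior_idx.view(-1).tolist())   # [1, 3]: best prior for GT 0 and GT 1
print(best_truth_idx.view(-1).tolist())   # [1, 0, 0, 1]: best GT for each prior
print(best_truth_overlap)                 # 0.20, 0.60, 0.30, 0.40
```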
index_fill_(dim, index, val), e.g. best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # ensure best prior
import torch

x = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
index = torch.tensor([0, 2])
# fill columns 0 and 2 (dim=1) with 8, in place
x.index_fill_(1, index, 8)
# tensor([[8., 2., 8.],
#         [8., 5., 8.],
#         [8., 8., 8.]])
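In the RefineDet code:
best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # ensure best prior
aaa = best_truth_overlap[best_prior_idx[0].type(torch.LongTensor)]  # equals 2? yes!
This one is more interesting. best_truth_overlap holds IoU values between 0 and 1. It is the column-wise maximum over 6375 priors: for each prior, its highest IoU with any ground truth.
best_prior_idx has shape [7, 1] with values in [0, 6375). It is the position of the row-wise maximum, i.e. the best prior for each ground truth. (In the matching code both tensors are squeezed to 1-D, [6375] and [7], before index_fill_ is called.)
So best_truth_overlap.index_fill_(0, best_prior_idx, 2) writes 2 into best_truth_overlap at every position listed in best_prior_idx. It is essentially best_truth_overlap[best_prior_idx] = 2.
Overall this does what the comment # ensure best prior says. A real IoU never exceeds 1, so the value 2 guarantees that each ground truth's best prior stays above any IoU threshold applied afterwards and is kept as a match, even when that best IoU is low.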
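A small sketch of both points, under stated assumptions. The IoU values are made up, and the positive-match threshold of 0.5 is an assumed value, not read from the repository. The code checks that index_fill_ along dim 0 gives the same result as indexed assignment, and shows that GT 1, whose best IoU is only 0.40, gets a prior above the threshold once its best prior is filled with 2.

```python
import torch

# Hypothetical IoU matrix: 2 ground truths x 4 priors. GT 1 overlaps no prior
# by more than 0.40, so a 0.5 threshold alone would leave it without a match.
overlap = torch.tensor([[0.10, 0.60, 0.30, 0.00],
                        [0.20, 0.35, 0.05, 0.40]])
threshold = 0.5  # assumed positive-match IoU threshold

_, best_prior_idx = overlap.max(1, keepdim=True)        # [2, 1]
best_truth_overlap, _ = overlap.max(0, keepdim=True)    # [1, 4]
# squeeze to 1-D so that dim 0 indexes priors and the index is 1-D
best_prior_idx = best_prior_idx.squeeze(1)              # tensor([1, 3])
best_truth_overlap = best_truth_overlap.squeeze(0)      # [4]

before = (best_truth_overlap >= threshold).tolist()

# index_fill_ should match plain indexed assignment
alt = best_truth_overlap.clone()
alt[best_prior_idx] = 2

best_truth_overlap.index_fill_(0, best_prior_idx, 2)    # ensure best prior
assert torch.equal(best_truth_overlap, alt)

after = (best_truth_overlap >= threshold).tolist()
print(before)   # [False, True, False, False]: GT 1 has no prior above 0.5
print(after)    # [False, True, False, True]: prior 3 (GT 1's best) now passes
```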