阅读 SRU 代码
1、
def read_corpus(path, eos="</s>", encoding="utf-8"):
    """Read a whitespace-tokenized text file into one flat list of tokens.

    Every line is split on whitespace and followed by ``eos``, so line
    boundaries survive as explicit end-of-line markers (an empty line
    contributes just ``eos``).

    Args:
        path: path of the text file to read.
        eos: marker appended after the tokens of each line.
        encoding: text encoding used to decode the file. Defaults to UTF-8
            instead of the locale-dependent platform default, so non-ASCII
            corpora decode identically on every machine.

    Returns:
        list of str: the tokens of all lines, each line terminated by ``eos``.
    """
    data = []
    with open(path, encoding=encoding) as fin:
        for line in fin:
            data.extend(line.split())
            data.append(eos)
    return data
来看一下这段代码运行后产生的数据是什么样子：
# Inline version of read_corpus, run directly to inspect what it produces.
data = []
eos = "</s>"
path = '/home/lai/下载/txt'
# Decode explicitly as UTF-8 so the result does not depend on the system locale.
with open(path, encoding="utf-8") as fin:
    for line in fin:
        # whitespace-split tokens of this line, followed by the end-of-line marker
        data += line.split() + [eos]
print(data)
这里使用的 txt 文件内容如下（共三行）：
no it was n't black monday
but while the new york stock exchange did n't fall apart friday as the dow jones industrial average plunged N points most of it in the final hour it barely managed to stay this side of chaos
some circuit breakers installed after the october N crash failed their first test traders say unable to cool the selling panic in both stocks and futures
结果:
['no', 'it', 'was', "n't", 'black', 'monday', '</s>', 'but', 'while', 'the', 'new', 'york', 'stock', 'exchange', 'did', "n't", 'fall', 'apart', 'friday', 'as', 'the', 'dow', 'jones', 'industrial', 'average', 'plunged', 'N', 'points', 'most', 'of', 'it', 'in', 'the', 'final', 'hour', 'it', 'barely', 'managed', 'to', 'stay', 'this', 'side', 'of', 'chaos', '</s>', 'some', 'circuit', 'breakers', 'installed', 'after', 'the', 'october', 'N', 'crash', 'failed', 'their', 'first', 'test', 'traders', 'say', 'unable', 'to', 'cool', 'the', 'selling', 'panic', 'in', 'both', 'stocks', 'and', 'futures', '</s>']
输出是由单个单词组成的一维列表，原文件每一行的末尾都追加了一个句末标记 </s>。
2.
class EmbeddingLayer(nn.Module):
    """Word-embedding lookup table built from a token sequence.

    Vocabulary ids are assigned in order of first appearance in ``words``
    (repeated tokens reuse their existing id). An ``nn.Embedding`` is then
    created with one ``n_d``-dimensional row per vocabulary entry.

    Args:
        n_d: embedding dimension (length of each word vector).
        words: iterable of tokens, e.g. the flat list returned by read_corpus.
        fix_emb: accepted for interface compatibility; this implementation
            does not read it.
    """

    def __init__(self, n_d, words, fix_emb=False):
        super().__init__()
        # dict.fromkeys keeps first occurrences in their original order, so
        # enumerating it gives each distinct word the id it would receive by
        # counting new words one at a time.
        vocab = {token: idx for idx, token in enumerate(dict.fromkeys(words))}
        self.word2id = vocab
        self.n_V = len(vocab)  # vocabulary size
        self.n_d = n_d  # embedding dimension
        # row i of the embedding weight is the vector of the word with id i
        self.embedding = nn.Embedding(self.n_V, n_d)

    def forward(self, x):
        """Look up the embedding vectors for the id tensor ``x``."""
        return self.embedding(x)

    def map_to_ids(self, text):
        """Map a token sequence to an int64 numpy array of vocabulary ids.

        Raises KeyError for any token that was not in ``words`` at construction.
        """
        lookup = self.word2id
        ids = [lookup[token] for token in text]
        return np.asarray(ids, dtype='int64')
为了便于理解，我构造了一个可以独立运行的简易程序：
import numpy as np

data = [("me gusta comer en la cafeteria".split(), "SPANISH"),
        ("Give it to me".split(), "ENGLISH"),
        ("No creo que sea una buena idea".split(), "SPANISH"),
        ("No it is not a good idea to get lost at sea".split(), "ENGLISH")]
test_data = [("Yo creo que si".split(), "SPANISH"),
             ("it is lost on me".split(), "ENGLISH")]

# Map text to numbers: every distinct word gets an integer id in order of
# first appearance across data + test_data.
word_to_ix = {}
for sent, _ in data + test_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
print(word_to_ix)

text = {'creo': 10, 'idea': 15, 'a': 18}
# Convert a sentence into a numeric id sequence via word_to_ix.
# (This line was prose pasted into the code without '#', which made the
# snippet a SyntaxError.) Iterating a dict yields its keys, so only the
# words of `text` are looked up.
print(np.asarray([word_to_ix[x] for x in text], dtype='int64'))
print(text)
结果:
{'Give': 6, 'lost': 21, 'No': 9, 'cafeteria': 5, 'comer': 2, 'en': 3, 'at': 22, 'not': 17, 'good': 19, 'to': 8, 'una': 13, 'Yo': 23, 'me': 0, 'a': 18, 'on': 25, 'creo': 10, 'get': 20, 'it': 7, 'idea': 15, 'buena': 14, 'is': 16, 'si': 24, 'que': 11, 'la': 4, 'gusta': 1, 'sea': 12} [15 10 18] {'idea': 15, 'creo': 10, 'a': 18}
所以这一部分先为每个不同的单词分配一个整数编号（按首次出现的顺序），然后通过 word_to_ix 把一个句子（单词序列）转换成数字化的 id 序列。
关于读入数据的总结
用代码中定义的类读入自己的数据
import time
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def read_corpus(path, eos="</s>", encoding="utf-8"):
    """Read a whitespace-tokenized text file into one flat list of tokens.

    Each line's tokens are followed by ``eos``. The file is decoded with
    ``encoding`` (UTF-8 by default) rather than the locale-dependent
    platform default.
    """
    data = []
    with open(path, encoding=encoding) as fin:
        for line in fin:
            data += line.split() + [eos]
    return data


def create_batches(data_text, map_to_ids, batch_size):
    """Turn a token list into (input, target) id tensors for next-token prediction.

    Args:
        data_text: list of tokens.
        map_to_ids: callable mapping the token list to a 1-D int64 numpy array.
        batch_size: number of parallel columns.

    Returns:
        (x, y): tensors of shape (L // batch_size, batch_size), where
        L = ((N - 1) // batch_size) * batch_size and N is the number of ids.
        Column j of x is a contiguous slice of the id sequence; y is the same
        slice shifted forward by one id, i.e. the target for each position.
    """
    data_ids = map_to_ids(data_text)
    print(data_ids)
    N = len(data_ids)
    # Largest multiple of batch_size that still leaves one id for the final target.
    L = ((N - 1) // batch_size) * batch_size
    # reshape(batch_size, -1) gives one contiguous chunk per row; .T makes the chunks columns.
    x = np.copy(data_ids[:L].reshape(batch_size, -1).T)
    y = np.copy(data_ids[1:L + 1].reshape(batch_size, -1).T)
    x, y = torch.from_numpy(x), torch.from_numpy(y)
    x, y = x.contiguous(), y.contiguous()
    return x, y


class EmbeddingLayer(nn.Module):
    """Word-embedding table: one n_d-dimensional vector per distinct word."""

    def __init__(self, n_d, words, fix_emb=False):
        # fix_emb is accepted but not read here.
        super(EmbeddingLayer, self).__init__()
        word2id = {}
        for w in words:
            if w not in word2id:
                word2id[w] = len(word2id)  # map each new word to the next integer id
        self.word2id = word2id
        # n_V: vocabulary size; n_d: embedding dimension
        self.n_V, self.n_d = len(word2id), n_d
        # row i holds the vector of the word with id i
        self.embedding = nn.Embedding(self.n_V, n_d)

    def forward(self, x):
        return self.embedding(x)

    def map_to_ids(self, text):
        """Map tokens to an int64 numpy array of ids (KeyError on unknown tokens)."""
        return np.asarray([self.word2id[x] for x in text], dtype='int64')


train = read_corpus('/home/lai/下载/train.txt')
print(train)
model = EmbeddingLayer(10, train)
print(model)
map_to_ids = model.map_to_ids
print(map_to_ids)
train = create_batches(train, map_to_ids, batch_size=45)
print(train)
print(model.embedding.weight)
结果
['no', 'it', 'was', "n't", 'black', 'monday', '</s>', 'but', 'while', 'the', 'new', 'york', 'stock', 'exchange', 'did', "n't", 'fall', 'apart', 'friday', 'as', 'the', 'dow', 'jones', 'industrial', 'average', 'plunged', 'N', 'points', 'most', 'of', 'it', 'in', 'the', 'final', 'hour', 'it', 'barely', 'managed', 'to', 'stay', 'this', 'side', 'of', 'chaos', '</s>', 'some', 'circuit', 'breakers', 'installed', 'after', 'the', 'october', 'N', 'crash', 'failed', 'their', 'first', 'test', 'traders', 'say', 'unable', 'to', 'cool', 'the', 'selling', 'panic', 'in', 'both', 'stocks', 'and', 'futures', '</s>'] EmbeddingLayer ( (embedding): Embedding(59, 10) ) <bound method EmbeddingLayer.map_to_ids of EmbeddingLayer ( (embedding): Embedding(59, 10) )> [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 3 15 16 17 18 9 19 20 21 22 23 24 25 26 27 1 28 9 29 30 1 31 32 33 34 35 36 27 37 6 38 39 40 41 42 9 43 24 44 45 46 47 48 49 50 51 33 52 9 53 54 28 55 56 57 58 6] ( Columns 0 to 12 0 1 2 3 4 5 6 7 8 9 10 11 12 Columns 13 to 25 13 14 3 15 16 17 18 9 19 20 21 22 23 Columns 26 to 38 24 25 26 27 1 28 9 29 30 1 31 32 33 Columns 39 to 44 34 35 36 27 37 6 [torch.LongTensor of size 1x45] , Columns 0 to 12 1 2 3 4 5 6 7 8 9 10 11 12 13 Columns 13 to 25 14 3 15 16 17 18 9 19 20 21 22 23 24 Columns 26 to 38 25 26 27 1 28 9 29 30 1 31 32 33 34 Columns 39 to 44 35 36 27 37 6 38 [torch.LongTensor of size 1x45] ) Parameter containing: 0.4376 -1.1509 -0.1407 -0.6956 -0.7292 -0.1944 0.8925 0.0688 -0.0560 2.5919 -0.7855 -0.0448 -0.8069 -1.4774 0.2366 0.3967 -0.0706 -0.4602 1.0099 -0.0734 -1.7748 -0.5265 0.4334 -0.7525 -0.0537 0.3966 -1.1800 0.2774 -2.2269 -0.4814 -0.9325 1.7541 0.6094 -0.1564 0.8379 -0.4577 -1.3616 -2.1115 -0.7025 -0.6662 1.0896 -0.1558 -1.1896 -0.0955 -2.7685 0.9485 1.1311 -1.1454 -0.4689 1.0410 1.2227 1.8617 0.9243 -0.3036 0.2639 -0.6933 -0.4147 -0.4482 2.7447 0.0573 1.0230 0.0484 -1.0139 -0.4291 0.6560 0.6911 -1.2519 0.9809 0.5843 0.2033 -0.1128 -0.2149 1.2092 1.5636 -0.6737 1.0226 1.0155 
-0.6230 -2.1714 -0.0226 0.1947 1.0509 0.8694 1.5002 -0.3447 -0.2618 1.3267 0.0795 0.5041 -0.9763 1.0146 0.9310 -1.2894 1.3288 -0.4146 0.1909 -0.3760 1.6011 0.7943 0.6290 -0.2122 -1.4665 1.4775 0.5200 1.2882 -0.4101 0.4479 0.4447 -0.9597 1.7938 0.8239 0.5278 -0.0036 0.8840 0.1069 0.2539 -0.7887 0.1271 0.8512 0.3766 -0.5573 0.6985 1.0623 -1.3442 1.0792 0.4055 0.3625 1.7664 -0.3776 0.0266 -0.2160 0.6872 1.6154 -0.5749 2.6781 1.1730 -0.9687 -1.2116 -0.9464 0.5248 0.0916 0.3761 -1.0593 -0.6794 1.6780 -0.2040 0.8541 -0.0384 1.5180 0.6114 -0.0321 0.5364 0.3896 -0.4864 -1.0080 -1.0698 0.1935 0.3896 -0.5745 -0.0273 1.6301 -0.2652 -0.5325 -0.9380 0.3457 -2.0038 -0.0775 -0.7555 -0.8524 -0.9321 0.0364 -0.4582 -0.3213 -0.9254 -1.0728 -0.1355 0.0993 -0.3186 2.3914 -1.5035 0.0652 0.7371 0.9628 1.1530 -0.4044 -0.7131 -0.8299 1.6627 -0.8451 -1.0463 -0.3744 0.6010 -2.4774 1.6569 -0.5589 -0.6512 -1.3728 -1.7573 1.1402 1.6838 0.2883 -1.3225 1.2454 0.4222 -0.5544 -1.5851 1.7119 1.3759 1.2300 -0.0676 0.6371 1.4258 -0.0222 1.2869 0.8767 -0.2959 -0.5973 -2.6143 -0.4366 0.9691 0.3215 0.6463 0.4688 0.4125 0.1800 0.0441 0.0375 0.4195 1.5675 0.7011 0.5407 1.4961 -1.5759 -1.7088 -0.5991 1.2169 0.9620 -1.7427 -0.0108 -0.3502 -0.0906 0.1109 -0.4118 1.0876 0.8098 -0.8063 -0.2878 0.8896 -0.6304 0.0683 0.6119 0.4786 0.6667 0.5702 -1.0531 0.4991 0.0538 1.1451 -0.7958 -0.0557 1.3344 1.7192 -1.9320 2.1928 -0.1014 0.6543 -0.1026 -0.6506 -0.2592 0.0537 -1.0320 1.9222 -0.6615 0.8046 -0.7667 -0.6775 -0.4904 0.6054 0.2837 -1.2075 0.6694 -0.7456 -0.9112 0.0961 0.3517 -0.6020 -0.9233 0.8343 0.0364 -0.5247 -1.4859 -0.8458 0.1642 0.2666 -2.9028 0.5945 0.0080 0.2036 1.9158 0.4553 1.9948 -0.1500 -1.9221 -0.2734 0.7872 0.1108 -0.1790 -0.0549 0.8124 0.1027 -0.8605 2.0634 -1.1081 0.3951 0.6214 0.1754 0.4764 0.9175 -0.3207 -0.3007 0.3095 1.4426 -0.6971 -1.1740 0.7263 0.0415 -0.4804 0.2983 0.9156 0.6196 -0.0862 -0.6351 -2.7732 1.2055 0.8422 -1.9189 1.4048 -0.8839 0.0811 -1.1528 -0.5930 1.2625 0.5828 -0.8534 0.5789 
-1.8812 1.2968 1.1347 -1.3243 0.5715 -0.3339 0.5853 0.1010 1.2207 1.0524 -1.5834 -2.1429 0.7626 1.6698 0.7554 -1.0038 1.6710 -0.6395 -0.3707 0.3491 0.0697 0.2043 0.2882 1.3192 -2.2766 1.1236 -0.3770 -0.4992 0.3957 -1.0027 0.7676 1.3439 1.1695 -0.0786 0.0372 0.1163 -0.4600 -1.2990 -0.6624 0.6378 0.4357 -0.2231 0.8826 0.7718 0.6312 -0.9322 0.7925 1.0265 -0.9309 0.3586 -0.2663 0.7529 -0.8931 0.3230 1.0597 0.0599 0.3668 0.2117 -0.3740 -1.2131 -0.7596 -0.1819 0.4357 3.0936 0.7486 -0.7667 -0.3219 -0.3511 -0.6781 0.8756 1.2539 0.7989 0.6129 0.3743 0.6551 0.8160 -0.3391 -0.4200 0.0984 0.0863 -1.1544 0.6204 -0.6724 0.2659 0.5388 0.4748 0.5738 -0.8648 0.3691 -0.3480 -0.1510 0.8260 0.6924 0.0053 -0.6213 0.2044 0.7698 0.7638 0.3532 0.7197 0.9445 -1.0761 0.0882 0.5684 0.4562 -1.0330 -1.0507 -1.1679 0.0608 1.3512 0.2507 0.1740 -0.1574 -0.0552 0.6377 1.3845 1.3252 2.5621 -0.5241 0.4334 -0.5092 0.1271 -1.3832 0.7112 0.1932 -0.1659 0.2740 -0.6393 -0.2937 -0.2887 -0.7221 -1.1947 -1.0431 1.1029 -1.1171 -0.2033 -0.5364 -0.4530 -2.4491 -1.2100 -1.5732 0.4191 -2.8109 0.3529 -0.7417 0.1667 -0.0072 0.8795 -0.1538 0.5413 1.1036 -0.5249 -0.8432 0.0563 -0.2998 -0.4226 0.6448 -0.4215 0.4342 -0.6593 -0.2078 1.4768 1.1829 0.8084 -2.0024 2.1950 0.8189 0.4104 0.4159 -1.1775 -2.3510 -0.5108 -2.5914 -0.5550 0.7188 -0.2978 0.1422 -0.0790 -1.6337 -0.4799 -0.9623 -0.9411 0.8321 -1.6386 -0.7785 -0.3109 0.5793 0.5437 0.3324 -0.9796 1.4794 0.0364 0.6472 0.7203 1.5878 0.0685 1.5637 -0.4545 -2.2541 0.5353 0.1305 1.3973 -1.2065 -0.5373 1.3352 0.0670 -0.6708 -0.4448 0.1797 -0.6935 1.4199 0.2560 0.3542 -1.0556 -1.1745 -0.3048 1.7749 -0.5777 -0.7029 0.9634 -0.9982 1.1929 1.5102 0.7618 -0.3569 0.1294 -1.6825 -0.8473 -0.7886 0.3286 -0.2387 -0.4245 -0.3130 0.2273 -1.0860 -0.7929 -1.0838 0.1994 -0.4874 0.6568 0.1065 1.8086 0.2142 -1.1657 -0.2313 [torch.FloatTensor of size 59x10]
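我把这个过程的中间结果全部打印出来，便于理解。在 create_batches 中，语料共有 N=72 个 token，L=((72-1)//45)*45=45，因此 x 和 y 都是 1x45 的 LongTensor，y 是 x 向后错开一个位置的结果（即每个位置的下一个词，作为预测目标）。model.embedding.weight 是一个 59x10 的矩阵（词表大小 n_V=59，向量维度 n_d=10），第 i 行就是编号为 i 的单词所对应的词向量；此时这些值还是随机初始化的，尚未经过训练。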
3.
def init_weights(self):
    """Re-initialize every parameter of the module in place.

    Parameters with more than one dimension (e.g. weight matrices or
    embedding tables) are drawn from U(-r, r) with r = sqrt(3 / self.n_d).
    Parameters with at most one dimension (e.g. bias vectors) are set to
    zero. Uniform noise on [-r, r] has variance r**2 / 3 = 1 / self.n_d.
    """
    bound = (3.0 / self.n_d) ** 0.5
    for param in self.parameters():
        tensor = param.data  # modify the underlying tensor in place
        if param.dim() <= 1:
            tensor.zero_()
        else:
            tensor.uniform_(-bound, bound)
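这里的关键是 p.data.uniform_(-val_range, val_range) 和 p.data.zero_() 两个原地操作：维度大于 1 的参数（矩阵）按均匀分布 U(-val_range, val_range) 初始化，其中 val_range = sqrt(3/n_d)，这样初始化值的方差正好是 val_range²/3 = 1/n_d；维度不大于 1 的参数（向量，如偏置）则全部置零。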
我自己构造了一个模型，用来探究这两个操作的作用：
import time
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def read_corpus(path, eos="</s>", encoding="utf-8"):
    """Read a whitespace-tokenized text file into one flat list of tokens.

    Each line's tokens are followed by ``eos``. The file is decoded with
    ``encoding`` (UTF-8 by default) rather than the locale-dependent
    platform default.
    """
    data = []
    with open(path, encoding=encoding) as fin:
        for line in fin:
            data += line.split() + [eos]
    return data


def create_batches(data_text, map_to_ids, batch_size):
    """Turn a token list into (input, target) id tensors for next-token prediction.

    Returns tensors of shape (L // batch_size, batch_size), with
    L = ((N - 1) // batch_size) * batch_size; y is x shifted forward by one id.
    """
    data_ids = map_to_ids(data_text)
    print(data_ids)
    N = len(data_ids)
    # Largest multiple of batch_size that still leaves one id for the final target.
    L = ((N - 1) // batch_size) * batch_size
    x = np.copy(data_ids[:L].reshape(batch_size, -1).T)
    y = np.copy(data_ids[1:L + 1].reshape(batch_size, -1).T)
    x, y = torch.from_numpy(x), torch.from_numpy(y)
    x, y = x.contiguous(), y.contiguous()
    return x, y


class EmbeddingLayer(nn.Module):
    """Word-embedding table: one n_d-dimensional vector per distinct word."""

    def __init__(self, n_d, words, fix_emb=False):
        # fix_emb is accepted but not read here.
        super(EmbeddingLayer, self).__init__()
        word2id = {}
        for w in words:
            if w not in word2id:
                word2id[w] = len(word2id)  # map each new word to the next integer id
        self.word2id = word2id
        # n_V: vocabulary size; n_d: embedding dimension
        self.n_V, self.n_d = len(word2id), n_d
        # row i holds the vector of the word with id i
        self.embedding = nn.Embedding(self.n_V, n_d)

    def forward(self, x):
        return self.embedding(x)

    def map_to_ids(self, text):
        """Map tokens to an int64 numpy array of ids (KeyError on unknown tokens)."""
        return np.asarray([self.word2id[x] for x in text], dtype='int64')


train = read_corpus('/home/lai/下载/train.txt')
print(train)
model = EmbeddingLayer(10, train)
for param in model.parameters():
    # uniform_ fills the tensor in place and returns it, so both prints show the same values
    print(param.data.uniform_(0, 2))
    print(param.data)
结果:
['no', 'it', 'was', "n't", 'black', 'monday', '</s>', 'but', 'while', 'the', 'new', 'york', 'stock', 'exchange', 'did', "n't", 'fall', 'apart', 'friday', 'as', 'the', 'dow', 'jones', 'industrial', 'average', 'plunged', 'N', 'points', 'most', 'of', 'it', 'in', 'the', 'final', 'hour', 'it', 'barely', 'managed', 'to', 'stay', 'this', 'side', 'of', 'chaos', '</s>', 'some', 'circuit', 'breakers', 'installed', 'after', 'the', 'october', 'N', 'crash', 'failed', 'their', 'first', 'test', 'traders', 'say', 'unable', 'to', 'cool', 'the', 'selling', 'panic', 'in', 'both', 'stocks', 'and', 'futures', '</s>'] 1.4317 0.6596 0.0516 1.0376 0.1926 1.2600 0.0494 0.8796 1.9962 1.2159 0.2419 0.6704 0.1465 1.6639 1.5062 1.6871 0.7300 1.6097 0.6998 1.1892 0.8882 0.7436 0.7304 0.6540 1.0289 0.7935 1.9055 1.5515 1.2066 1.7531 1.1168 1.8315 0.7545 1.8267 0.9284 0.4486 0.5175 0.0532 0.8085 1.3437 0.2860 0.2907 0.8077 1.9553 1.2979 1.1078 0.0623 1.8027 1.8158 0.0852 1.0238 0.3384 0.5703 1.5060 1.0183 0.2247 0.2230 0.7064 0.3984 1.6884 1.1680 1.5321 0.9316 1.9031 0.5216 0.8028 0.8465 0.5166 1.5459 0.2865 0.6001 1.1145 1.6196 1.7692 1.7195 1.3123 0.4399 0.4006 1.2029 1.6420 1.9466 1.9689 0.8811 0.2398 1.3328 0.5307 1.6048 0.9328 1.6946 0.5598 1.9595 0.3396 1.4121 0.1757 0.3677 0.5584 1.9388 1.2118 1.3966 1.4618 1.2004 0.8745 0.4966 1.5487 0.7805 1.0708 1.8857 0.1973 1.1339 1.0490 0.4731 0.2265 1.0293 0.7514 1.3949 1.5742 0.0032 1.0001 1.6449 1.4519 0.2014 0.0456 1.2669 1.2988 0.9432 1.0757 0.6428 1.3084 0.7477 0.3753 0.1086 0.1842 1.3811 1.4472 0.6998 0.0028 1.8839 1.0238 1.6243 1.3262 0.6383 1.4817 0.2363 1.7802 1.2998 1.8367 1.9967 0.5028 0.0819 1.4886 0.2979 0.3566 0.5144 0.6787 0.8583 0.9256 0.8171 0.0482 0.6638 1.3788 0.4180 1.5806 1.0489 0.6587 1.6041 1.0644 1.9635 1.4030 1.5242 1.9292 1.7177 1.0168 1.4879 1.5941 0.6318 0.4966 1.9573 1.0276 1.8955 0.9595 1.3229 0.5519 0.0796 1.0840 0.2204 0.7510 0.6440 0.7307 1.0064 1.0647 0.5325 1.1621 1.0669 1.2276 0.2488 1.6607 1.6797 1.7445 0.7051 
0.0290 1.9457 0.8071 1.9667 1.5591 1.6706 1.8955 0.2541 1.2218 0.5843 1.8493 0.8763 0.2127 0.5883 0.9636 1.9839 0.5030 0.8972 0.3293 1.1231 0.8687 1.3803 0.9248 1.3445 0.1882 1.3226 1.9621 1.0377 1.7566 1.6686 1.6855 1.9552 0.1764 0.6670 1.5401 0.4913 0.8954 0.3951 0.8991 1.5485 0.6603 0.5025 1.1702 1.8270 0.9304 0.4637 1.4306 0.5506 0.3712 0.0122 0.4379 0.2657 0.0599 1.8354 0.2358 1.7581 0.3380 0.9558 1.7275 0.5202 1.3801 0.7791 1.4060 0.6530 1.8742 0.5895 0.7742 1.7748 1.7141 1.2038 0.2918 1.0312 1.9371 0.8345 0.4569 0.0447 0.2415 1.3479 0.9809 0.0566 1.0656 0.3313 0.4801 0.3357 1.4143 0.6487 0.7692 1.0398 1.1538 0.8307 0.8231 1.4774 0.1299 1.1836 0.2659 1.4413 0.4059 0.2428 1.0973 0.5491 0.2169 1.8733 0.7073 0.6730 1.7413 1.1705 1.7082 1.0175 1.2589 1.9080 0.7648 1.0761 1.1880 1.5441 1.9458 0.5513 1.5324 1.3756 0.3201 1.6600 0.7143 1.8071 1.2422 1.5758 1.5677 1.5796 1.0328 0.3856 0.3648 0.5017 1.2543 1.8749 1.9269 0.2120 0.3971 0.4451 0.7651 0.6793 0.1512 1.7845 0.1911 1.2950 0.9356 1.0757 0.7603 0.6917 0.2891 1.3327 1.1102 0.3153 1.7074 0.9031 1.8973 1.6392 0.3516 0.4412 1.4444 1.4032 0.1110 1.1379 0.2283 0.4678 1.3409 0.6576 0.5351 1.2108 1.7777 0.5716 1.9060 1.4147 1.4487 0.9546 0.9840 0.3020 1.7696 0.9677 1.1206 1.5639 0.0437 0.1485 0.1437 1.0374 0.8910 1.7921 1.1207 0.4798 0.5863 0.0112 0.7735 0.8233 0.8936 1.1980 1.6834 0.5779 0.7173 1.5803 1.6196 0.1642 1.6706 1.9906 1.4089 0.2140 0.6833 1.6710 0.4645 0.0886 1.6945 0.8467 1.3290 1.7448 0.5405 1.2914 1.5487 0.8509 1.8434 1.3398 0.3215 0.5732 1.5421 1.5103 0.2807 1.4965 0.5448 1.0851 0.6836 1.4491 0.4040 1.8560 1.2288 1.4055 0.7298 0.6319 0.9501 0.5320 1.2168 0.0031 1.8810 1.5128 0.4442 1.3887 1.5603 0.5936 1.9980 1.4988 0.5884 1.9388 1.8275 0.1833 1.3767 1.2934 0.6319 0.2711 0.0854 0.7103 0.8877 1.9997 0.2341 0.7163 1.8445 1.4777 0.0532 1.1966 1.1512 1.8602 0.0552 1.7778 0.4180 1.0675 1.0646 1.6946 1.9979 1.4076 0.1683 0.6894 1.0616 1.8683 0.3648 0.9496 0.4799 1.5983 0.8257 1.5951 0.7438 0.4807 1.7440 
1.1139 1.5855 0.3561 0.5960 0.6389 1.7573 1.3262 1.5965 0.1100 1.0414 0.1697 1.8125 0.8135 0.1712 0.8863 0.5336 0.4490 0.1233 0.0136 1.3416 0.2668 0.2091 0.8900 0.3823 1.3197 1.4936 1.3607 0.6022 0.9031 0.7420 0.5538 1.5407 1.1918 0.5104 1.7564 0.1658 0.4650 0.4523 1.3443 1.5691 1.0239 0.5898 0.8882 0.1892 1.0721 1.6908 1.0479 1.9074 0.3732 1.8763 1.5337 0.2918 1.9343 1.6055 0.0709 0.9326 0.6884 1.6136 1.1970 1.0819 0.3358 0.0234 0.4381 1.2239 1.1829 1.1254 1.4076 0.4704 0.1724 0.5579 0.1318 0.5537 0.2435 0.8490 0.7200 1.5814 0.2753 0.4727 0.5446 1.7038 0.8742 1.2662 1.3187 0.5939 1.2068 0.3514 0.6184 1.6217 1.0503 1.0958 1.9824 0.6737 0.3009 0.7889 1.8378 1.7559 0.6418 1.8355 0.7340 0.7232 0.6433 0.0288 1.3672 0.6466 0.3574 1.0760 [torch.FloatTensor of size 59x10] 1.4317 0.6596 0.0516 1.0376 0.1926 1.2600 0.0494 0.8796 1.9962 1.2159 0.2419 0.6704 0.1465 1.6639 1.5062 1.6871 0.7300 1.6097 0.6998 1.1892 0.8882 0.7436 0.7304 0.6540 1.0289 0.7935 1.9055 1.5515 1.2066 1.7531 1.1168 1.8315 0.7545 1.8267 0.9284 0.4486 0.5175 0.0532 0.8085 1.3437 0.2860 0.2907 0.8077 1.9553 1.2979 1.1078 0.0623 1.8027 1.8158 0.0852 1.0238 0.3384 0.5703 1.5060 1.0183 0.2247 0.2230 0.7064 0.3984 1.6884 1.1680 1.5321 0.9316 1.9031 0.5216 0.8028 0.8465 0.5166 1.5459 0.2865 0.6001 1.1145 1.6196 1.7692 1.7195 1.3123 0.4399 0.4006 1.2029 1.6420 1.9466 1.9689 0.8811 0.2398 1.3328 0.5307 1.6048 0.9328 1.6946 0.5598 1.9595 0.3396 1.4121 0.1757 0.3677 0.5584 1.9388 1.2118 1.3966 1.4618 1.2004 0.8745 0.4966 1.5487 0.7805 1.0708 1.8857 0.1973 1.1339 1.0490 0.4731 0.2265 1.0293 0.7514 1.3949 1.5742 0.0032 1.0001 1.6449 1.4519 0.2014 0.0456 1.2669 1.2988 0.9432 1.0757 0.6428 1.3084 0.7477 0.3753 0.1086 0.1842 1.3811 1.4472 0.6998 0.0028 1.8839 1.0238 1.6243 1.3262 0.6383 1.4817 0.2363 1.7802 1.2998 1.8367 1.9967 0.5028 0.0819 1.4886 0.2979 0.3566 0.5144 0.6787 0.8583 0.9256 0.8171 0.0482 0.6638 1.3788 0.4180 1.5806 1.0489 0.6587 1.6041 1.0644 1.9635 1.4030 1.5242 1.9292 1.7177 1.0168 1.4879 1.5941 
0.6318 0.4966 1.9573 1.0276 1.8955 0.9595 1.3229 0.5519 0.0796 1.0840 0.2204 0.7510 0.6440 0.7307 1.0064 1.0647 0.5325 1.1621 1.0669 1.2276 0.2488 1.6607 1.6797 1.7445 0.7051 0.0290 1.9457 0.8071 1.9667 1.5591 1.6706 1.8955 0.2541 1.2218 0.5843 1.8493 0.8763 0.2127 0.5883 0.9636 1.9839 0.5030 0.8972 0.3293 1.1231 0.8687 1.3803 0.9248 1.3445 0.1882 1.3226 1.9621 1.0377 1.7566 1.6686 1.6855 1.9552 0.1764 0.6670 1.5401 0.4913 0.8954 0.3951 0.8991 1.5485 0.6603 0.5025 1.1702 1.8270 0.9304 0.4637 1.4306 0.5506 0.3712 0.0122 0.4379 0.2657 0.0599 1.8354 0.2358 1.7581 0.3380 0.9558 1.7275 0.5202 1.3801 0.7791 1.4060 0.6530 1.8742 0.5895 0.7742 1.7748 1.7141 1.2038 0.2918 1.0312 1.9371 0.8345 0.4569 0.0447 0.2415 1.3479 0.9809 0.0566 1.0656 0.3313 0.4801 0.3357 1.4143 0.6487 0.7692 1.0398 1.1538 0.8307 0.8231 1.4774 0.1299 1.1836 0.2659 1.4413 0.4059 0.2428 1.0973 0.5491 0.2169 1.8733 0.7073 0.6730 1.7413 1.1705 1.7082 1.0175 1.2589 1.9080 0.7648 1.0761 1.1880 1.5441 1.9458 0.5513 1.5324 1.3756 0.3201 1.6600 0.7143 1.8071 1.2422 1.5758 1.5677 1.5796 1.0328 0.3856 0.3648 0.5017 1.2543 1.8749 1.9269 0.2120 0.3971 0.4451 0.7651 0.6793 0.1512 1.7845 0.1911 1.2950 0.9356 1.0757 0.7603 0.6917 0.2891 1.3327 1.1102 0.3153 1.7074 0.9031 1.8973 1.6392 0.3516 0.4412 1.4444 1.4032 0.1110 1.1379 0.2283 0.4678 1.3409 0.6576 0.5351 1.2108 1.7777 0.5716 1.9060 1.4147 1.4487 0.9546 0.9840 0.3020 1.7696 0.9677 1.1206 1.5639 0.0437 0.1485 0.1437 1.0374 0.8910 1.7921 1.1207 0.4798 0.5863 0.0112 0.7735 0.8233 0.8936 1.1980 1.6834 0.5779 0.7173 1.5803 1.6196 0.1642 1.6706 1.9906 1.4089 0.2140 0.6833 1.6710 0.4645 0.0886 1.6945 0.8467 1.3290 1.7448 0.5405 1.2914 1.5487 0.8509 1.8434 1.3398 0.3215 0.5732 1.5421 1.5103 0.2807 1.4965 0.5448 1.0851 0.6836 1.4491 0.4040 1.8560 1.2288 1.4055 0.7298 0.6319 0.9501 0.5320 1.2168 0.0031 1.8810 1.5128 0.4442 1.3887 1.5603 0.5936 1.9980 1.4988 0.5884 1.9388 1.8275 0.1833 1.3767 1.2934 0.6319 0.2711 0.0854 0.7103 0.8877 1.9997 0.2341 0.7163 1.8445 1.4777 
0.0532 1.1966 1.1512 1.8602 0.0552 1.7778 0.4180 1.0675 1.0646 1.6946 1.9979 1.4076 0.1683 0.6894 1.0616 1.8683 0.3648 0.9496 0.4799 1.5983 0.8257 1.5951 0.7438 0.4807 1.7440 1.1139 1.5855 0.3561 0.5960 0.6389 1.7573 1.3262 1.5965 0.1100 1.0414 0.1697 1.8125 0.8135 0.1712 0.8863 0.5336 0.4490 0.1233 0.0136 1.3416 0.2668 0.2091 0.8900 0.3823 1.3197 1.4936 1.3607 0.6022 0.9031 0.7420 0.5538 1.5407 1.1918 0.5104 1.7564 0.1658 0.4650 0.4523 1.3443 1.5691 1.0239 0.5898 0.8882 0.1892 1.0721 1.6908 1.0479 1.9074 0.3732 1.8763 1.5337 0.2918 1.9343 1.6055 0.0709 0.9326 0.6884 1.6136 1.1970 1.0819 0.3358 0.0234 0.4381 1.2239 1.1829 1.1254 1.4076 0.4704 0.1724 0.5579 0.1318 0.5537 0.2435 0.8490 0.7200 1.5814 0.2753 0.4727 0.5446 1.7038 0.8742 1.2662 1.3187 0.5939 1.2068 0.3514 0.6184 1.6217 1.0503 1.0958 1.9824 0.6737 0.3009 0.7889 1.8378 1.7559 0.6418 1.8355 0.7340 0.7232 0.6433 0.0288 1.3672 0.6466 0.3574 1.0760 [torch.FloatTensor of size 59x10]
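param.data.uniform_(0,2) 会原地（in-place）修改参数，所以打印出的值都落在 [0, 2) 区间内；如果把范围改成 uniform_(-1,1)，得到的 tensor 里的值也会随之落在新的区间内。由于是原地操作，随后打印的 param.data 与 uniform_ 的返回值完全相同。model.parameters()（注意不是 model.parameter()）返回的是遍历模型所有参数的迭代器；这个模型只有 embedding 的权重这一个参数，size 为 59x10。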
在这里记录一个我刚观察到的知识：param.dim() 返回 tensor 的维数（即轴的个数），也就是 size 中数字的个数。size 为 5x1x2x2 的 tensor 是 4 维，5x1x2 是 3 维，以此类推。Conv2d 权重的 size 为 (out_channels, in_channels, kernel_h, kernel_w)。Conv2d 的前两个参数分别是输入通道数（RGB 图像为 3，灰度图像为 1）和输出通道数（即卷积核 filter 的个数），后两维由 kernel_size 决定。
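而参数中维度为 1 的 tensor 并不是卷积的计算结果，而是卷积层的偏置（bias）。每个 Conv2d 默认带有一个 4 维的权重和一个 1 维的偏置（长度等于输出通道数），这也正是 init_weights 中 p.dim() > 1 这一判断所区分的两类参数。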
4、
def init_hidden(self, batch_size):
    """Return an all-zero initial hidden state for a batch.

    The state has shape (self.depth, batch_size, self.n_d). It is created
    with ``.new`` on the model's first parameter, so it matches that
    parameter's tensor type. If ``self.args.lstm`` is truthy, the SAME zero
    tensor object is returned twice as a pair (an in-place update to one
    element would therefore affect the other); otherwise the single tensor
    is returned.
    """
    first_param = next(self.parameters()).data
    state = Variable(first_param.new(self.depth, batch_size, self.n_d).zero_())
    return (state, state) if self.args.lstm else state
关于weight = next(self.parameters()).data
看看基于上面那个模型得到的结果
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Two stacked Conv2d + ReLU layers, used to inspect model.parameters().

    conv1 maps 1 input channel to 6 channels (2x2 kernel, stride 2);
    conv2 consumes those 6 channels and produces 5 (2x2 kernel, stride 1).
    """

    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 2, 2)
        # in_channels must equal conv1's out_channels (6), otherwise
        # forward() fails with a channel-mismatch error.
        self.conv2 = nn.Conv2d(6, 5, 2, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))


model = Model()
print(model)
print('next')
# The first parameter yielded by model.parameters() is conv1.weight.
x = next(model.parameters()).data
print(x)
结果
Model ( (conv1): Conv2d(1, 6, kernel_size=(2, 2), stride=(2, 2)) (conv2): Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1)) ) next (0 ,0 ,.,.) = 0.2855 -0.0303 0.1428 -0.4025 (1 ,0 ,.,.) = -0.0901 0.2736 -0.1527 -0.2854 (2 ,0 ,.,.) = 0.2193 -0.3886 -0.4652 0.2307 (3 ,0 ,.,.) = 0.1918 0.4587 -0.0480 -0.0636 (4 ,0 ,.,.) = 0.4017 -0.4123 0.3016 -0.2714 (5 ,0 ,.,.) = 0.2053 0.1252 -0.2365 -0.3651 [torch.FloatTensor of size 6x1x2x2]
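输出的是模型的第一个参数，即 conv1.weight（size 为 6x1x2x2）的数据，因为 next() 从 model.parameters() 迭代器中取出的是第一个元素。init_hidden 正是借助它来获取模型参数的张量类型。weight.new(...) 会创建一个与 weight 数据类型（以及所在设备 CPU/GPU）相同的新 tensor，再用 zero_() 置零，作为初始隐状态。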