

def read_corpus(path, eos="</s>"):
    """Read a whitespace-tokenized text file into one flat token list.

    Every line is split on whitespace and followed by the ``eos`` marker,
    so line (sentence) boundaries survive the flattening.

    Args:
        path: path of the text file to read.
        eos: token appended after each line (default ``"</s>"``).

    Returns:
        list: all tokens of the file in order, with ``eos`` after each line.
    """
    tokens = []
    with open(path) as handle:
        for raw_line in handle:
            tokens.extend(raw_line.split())
            tokens.append(eos)
    return tokens


# Inline version of read_corpus: flatten the file into tokens, appending the
# end-of-sentence marker after every line. ``eos`` must be bound here because
# at module level there is no function parameter providing it.
eos = "</s>"
data = [ ]
path = '/home/lai/下载/txt'
with open(path) as fin:
    for line in fin:
        data += line.split() + [ eos ]


no it was n't black monday 
 but while the new york stock exchange did n't fall apart friday as the dow jones industrial average plunged N points most of it in the final hour it barely managed to stay this side of chaos 
 some circuit breakers installed after the october N crash failed their first test traders say unable to cool the selling panic in both stocks and futures 


['no', 'it', 'was', "n't", 'black', 'monday', '</s>', 'but', 'while', 'the', 'new', 'york', 'stock', 'exchange', 'did', "n't", 'fall', 'apart', 'friday', 'as', 'the', 'dow', 'jones', 'industrial', 'average', 'plunged', 'N', 'points', 'most', 'of', 'it', 'in', 'the', 'final', 'hour', 'it', 'barely', 'managed', 'to', 'stay', 'this', 'side', 'of', 'chaos', '</s>', 'some', 'circuit', 'breakers', 'installed', 'after', 'the', 'october', 'N', 'crash', 'failed', 'their', 'first', 'test', 'traders', 'say', 'unable', 'to', 'cool', 'the', 'selling', 'panic', 'in', 'both', 'stocks', 'and', 'futures', '</s>']



class EmbeddingLayer(nn.Module):
    """Word-embedding layer that builds its own vocabulary from a token list.

    Each distinct word gets an integer id in order of first appearance, and an
    ``nn.Embedding`` table of shape ``(vocabulary size, n_d)`` holds one
    learnable vector per id.
    """

    def __init__(self, n_d, words, fix_emb=False):
        """Build the vocabulary and the embedding table.

        Args:
            n_d: embedding dimension (NOTE(review): presumably also the hidden
                size of the surrounding model — confirm).
            words: iterable of tokens; duplicates are allowed.
            fix_emb: if True, the embedding weights are frozen
                (``requires_grad = False``).
        """
        super(EmbeddingLayer, self).__init__()
        word2id = {}
        for w in words:
            # The first occurrence decides the id, so ids run 0..n_V-1.
            if w not in word2id:
                word2id[w] = len(word2id)

        self.word2id = word2id
        # n_V: vocabulary size; n_d: embedding dimension.
        self.n_V, self.n_d = len(word2id), n_d
        self.embedding = nn.Embedding(self.n_V, n_d)
        if fix_emb:
            self.embedding.weight.requires_grad = False

    def forward(self, x):
        """Look up the embedding vectors for a tensor of word ids."""
        return self.embedding(x)

    def map_to_ids(self, text):
        """Map a sequence of tokens to a 1-D int64 numpy array of their ids.

        Raises:
            KeyError: if a token is not in the vocabulary.
        """
        return np.asarray([self.word2id[x] for x in text], dtype="int64")




import numpy as np

# Toy language-identification data: (token list, label) pairs.
data = [ ("me gusta comer en la cafeteria".split(), "SPANISH"),
         ("Give it to me".split(), "ENGLISH"),
         ("No creo que sea una buena idea".split(), "SPANISH"),
         ("No it is not a good idea to get lost at sea".split(), "ENGLISH") ]

test_data = [("Yo creo que si".split(), "SPANISH"),
              ("it is lost on me".split(), "ENGLISH")]

# Give each distinct word an id in order of first appearance across both the
# training and the test sentences.
word_to_ix = {}
for sent, _ in data + test_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)

# Iterating a dict yields its keys, so only 'creo', 'idea' and 'a' are looked
# up here; the values stored in ``text`` are ignored.
text={'creo': 10, 'idea': 15, 'a': 18}
print(np.asarray([word_to_ix[x] for x in text], dtype="int64"))



{'Give': 6, 'lost': 21, 'No': 9, 'cafeteria': 5, 'comer': 2, 'en': 3, 'at': 22, 'not': 17, 'good': 19, 'to': 8, 'una': 13, 'Yo': 23, 'me': 0, 'a': 18, 'on': 25, 'creo': 10, 'get': 20, 'it': 7, 'idea': 15, 'buena': 14, 'is': 16, 'si': 24, 'que': 11, 'la': 4, 'gusta': 1, 'sea': 12}
[15 10 18]
{'idea': 15, 'creo': 10, 'a': 18}





import time
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
def read_corpus(path, eos="</s>"):
    """Read a whitespace-tokenized text file into one flat token list.

    Every line is split on whitespace and followed by the ``eos`` marker,
    so line (sentence) boundaries survive the flattening.

    Args:
        path: path of the text file to read.
        eos: token appended after each line (default ``"</s>"``).

    Returns:
        list: all tokens of the file in order, with ``eos`` after each line.
    """
    tokens = []
    with open(path) as handle:
        for raw_line in handle:
            tokens.extend(raw_line.split())
            tokens.append(eos)
    return tokens
def create_batches(data_text, map_to_ids, batch_size):
    """Turn a token stream into aligned (input, target) batches.

    The id sequence is trimmed so that ``batch_size`` streams of equal length
    fit, keeping one extra id for the shifted targets. It is cut into
    ``batch_size`` contiguous streams and transposed, so rows index time steps
    and columns index streams. Targets are the inputs shifted one step ahead.

    Args:
        data_text: sequence of tokens.
        map_to_ids: callable mapping ``data_text`` to a numpy array of ids.
        batch_size: number of parallel streams (columns).

    Returns:
        (x, y): two contiguous torch tensors of shape
        ``((N - 1) // batch_size, batch_size)``, where N is the number of ids.
    """
    ids = map_to_ids(data_text)
    steps = (len(ids) - 1) // batch_size
    usable = steps * batch_size

    def _to_columns(segment):
        # (batch_size, steps) -> (steps, batch_size); np.copy materialises the transpose.
        return torch.from_numpy(np.copy(segment.reshape(batch_size, -1).T)).contiguous()

    inputs = _to_columns(ids[:usable])
    targets = _to_columns(ids[1:usable + 1])
    return inputs, targets
class EmbeddingLayer(nn.Module):
    """Word-embedding layer that builds its own vocabulary from a token list.

    Each distinct word gets an integer id in order of first appearance, and an
    ``nn.Embedding`` table of shape ``(vocabulary size, n_d)`` holds one
    learnable vector per id.
    """

    def __init__(self, n_d, words, fix_emb=False):
        """Build the vocabulary and the embedding table.

        Args:
            n_d: embedding dimension (NOTE(review): presumably also the hidden
                size of the surrounding model — confirm).
            words: iterable of tokens; duplicates are allowed.
            fix_emb: if True, the embedding weights are frozen
                (``requires_grad = False``).
        """
        super(EmbeddingLayer, self).__init__()
        word2id = {}
        for w in words:
            # The first occurrence decides the id, so ids run 0..n_V-1.
            if w not in word2id:
                word2id[w] = len(word2id)
        self.word2id = word2id
        # n_V: vocabulary size; n_d: embedding dimension.
        self.n_V, self.n_d = len(word2id), n_d
        self.embedding = nn.Embedding(self.n_V, n_d)
        if fix_emb:
            self.embedding.weight.requires_grad = False

    def forward(self, x):
        """Look up the embedding vectors for a tensor of word ids."""
        return self.embedding(x)

    def map_to_ids(self, text):
        """Map a sequence of tokens to a 1-D int64 numpy array of their ids.

        Raises:
            KeyError: if a token is not in the vocabulary.
        """
        return np.asarray([self.word2id[x] for x in text], dtype="int64")
# Load the training tokens, build a 10-dimensional embedding over their
# vocabulary, and replace the token list with (x, y) batches of 45 streams.
train = read_corpus('/home/lai/下载/train.txt')
model = EmbeddingLayer(n_d=10, words=train)
map_to_ids = model.map_to_ids
train = create_batches(data_text=train, map_to_ids=map_to_ids, batch_size=45)



['no', 'it', 'was', "n't", 'black', 'monday', '</s>', 'but', 'while', 'the', 'new', 'york', 'stock', 'exchange', 'did', "n't", 'fall', 'apart', 'friday', 'as', 'the', 'dow', 'jones', 'industrial', 'average', 'plunged', 'N', 'points', 'most', 'of', 'it', 'in', 'the', 'final', 'hour', 'it', 'barely', 'managed', 'to', 'stay', 'this', 'side', 'of', 'chaos', '</s>', 'some', 'circuit', 'breakers', 'installed', 'after', 'the', 'october', 'N', 'crash', 'failed', 'their', 'first', 'test', 'traders', 'say', 'unable', 'to', 'cool', 'the', 'selling', 'panic', 'in', 'both', 'stocks', 'and', 'futures', '</s>']
EmbeddingLayer (
  (embedding): Embedding(59, 10)
<bound method EmbeddingLayer.map_to_ids of EmbeddingLayer (
  (embedding): Embedding(59, 10)
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14  3 15 16 17 18  9 19 20 21 22
 23 24 25 26 27  1 28  9 29 30  1 31 32 33 34 35 36 27 37  6 38 39 40 41 42
  9 43 24 44 45 46 47 48 49 50 51 33 52  9 53 54 28 55 56 57 58  6]

Columns 0 to 12 
    0     1     2     3     4     5     6     7     8     9    10    11    12

Columns 13 to 25 
   13    14     3    15    16    17    18     9    19    20    21    22    23

Columns 26 to 38 
   24    25    26    27     1    28     9    29    30     1    31    32    33

Columns 39 to 44 
   34    35    36    27    37     6
[torch.LongTensor of size 1x45]

Columns 0 to 12 
    1     2     3     4     5     6     7     8     9    10    11    12    13

Columns 13 to 25 
   14     3    15    16    17    18     9    19    20    21    22    23    24

Columns 26 to 38 
   25    26    27     1    28     9    29    30     1    31    32    33    34

Columns 39 to 44 
   35    36    27    37     6    38
[torch.LongTensor of size 1x45]
Parameter containing:
 0.4376 -1.1509 -0.1407 -0.6956 -0.7292 -0.1944  0.8925  0.0688 -0.0560  2.5919
-0.7855 -0.0448 -0.8069 -1.4774  0.2366  0.3967 -0.0706 -0.4602  1.0099 -0.0734
-1.7748 -0.5265  0.4334 -0.7525 -0.0537  0.3966 -1.1800  0.2774 -2.2269 -0.4814
-0.9325  1.7541  0.6094 -0.1564  0.8379 -0.4577 -1.3616 -2.1115 -0.7025 -0.6662
 1.0896 -0.1558 -1.1896 -0.0955 -2.7685  0.9485  1.1311 -1.1454 -0.4689  1.0410
 1.2227  1.8617  0.9243 -0.3036  0.2639 -0.6933 -0.4147 -0.4482  2.7447  0.0573
 1.0230  0.0484 -1.0139 -0.4291  0.6560  0.6911 -1.2519  0.9809  0.5843  0.2033
-0.1128 -0.2149  1.2092  1.5636 -0.6737  1.0226  1.0155 -0.6230 -2.1714 -0.0226
 0.1947  1.0509  0.8694  1.5002 -0.3447 -0.2618  1.3267  0.0795  0.5041 -0.9763
 1.0146  0.9310 -1.2894  1.3288 -0.4146  0.1909 -0.3760  1.6011  0.7943  0.6290
-0.2122 -1.4665  1.4775  0.5200  1.2882 -0.4101  0.4479  0.4447 -0.9597  1.7938
 0.8239  0.5278 -0.0036  0.8840  0.1069  0.2539 -0.7887  0.1271  0.8512  0.3766
-0.5573  0.6985  1.0623 -1.3442  1.0792  0.4055  0.3625  1.7664 -0.3776  0.0266
-0.2160  0.6872  1.6154 -0.5749  2.6781  1.1730 -0.9687 -1.2116 -0.9464  0.5248
 0.0916  0.3761 -1.0593 -0.6794  1.6780 -0.2040  0.8541 -0.0384  1.5180  0.6114
-0.0321  0.5364  0.3896 -0.4864 -1.0080 -1.0698  0.1935  0.3896 -0.5745 -0.0273
 1.6301 -0.2652 -0.5325 -0.9380  0.3457 -2.0038 -0.0775 -0.7555 -0.8524 -0.9321
 0.0364 -0.4582 -0.3213 -0.9254 -1.0728 -0.1355  0.0993 -0.3186  2.3914 -1.5035
 0.0652  0.7371  0.9628  1.1530 -0.4044 -0.7131 -0.8299  1.6627 -0.8451 -1.0463
-0.3744  0.6010 -2.4774  1.6569 -0.5589 -0.6512 -1.3728 -1.7573  1.1402  1.6838
 0.2883 -1.3225  1.2454  0.4222 -0.5544 -1.5851  1.7119  1.3759  1.2300 -0.0676
 0.6371  1.4258 -0.0222  1.2869  0.8767 -0.2959 -0.5973 -2.6143 -0.4366  0.9691
 0.3215  0.6463  0.4688  0.4125  0.1800  0.0441  0.0375  0.4195  1.5675  0.7011
 0.5407  1.4961 -1.5759 -1.7088 -0.5991  1.2169  0.9620 -1.7427 -0.0108 -0.3502
-0.0906  0.1109 -0.4118  1.0876  0.8098 -0.8063 -0.2878  0.8896 -0.6304  0.0683
 0.6119  0.4786  0.6667  0.5702 -1.0531  0.4991  0.0538  1.1451 -0.7958 -0.0557
 1.3344  1.7192 -1.9320  2.1928 -0.1014  0.6543 -0.1026 -0.6506 -0.2592  0.0537
-1.0320  1.9222 -0.6615  0.8046 -0.7667 -0.6775 -0.4904  0.6054  0.2837 -1.2075
 0.6694 -0.7456 -0.9112  0.0961  0.3517 -0.6020 -0.9233  0.8343  0.0364 -0.5247
-1.4859 -0.8458  0.1642  0.2666 -2.9028  0.5945  0.0080  0.2036  1.9158  0.4553
 1.9948 -0.1500 -1.9221 -0.2734  0.7872  0.1108 -0.1790 -0.0549  0.8124  0.1027
-0.8605  2.0634 -1.1081  0.3951  0.6214  0.1754  0.4764  0.9175 -0.3207 -0.3007
 0.3095  1.4426 -0.6971 -1.1740  0.7263  0.0415 -0.4804  0.2983  0.9156  0.6196
-0.0862 -0.6351 -2.7732  1.2055  0.8422 -1.9189  1.4048 -0.8839  0.0811 -1.1528
-0.5930  1.2625  0.5828 -0.8534  0.5789 -1.8812  1.2968  1.1347 -1.3243  0.5715
-0.3339  0.5853  0.1010  1.2207  1.0524 -1.5834 -2.1429  0.7626  1.6698  0.7554
-1.0038  1.6710 -0.6395 -0.3707  0.3491  0.0697  0.2043  0.2882  1.3192 -2.2766
 1.1236 -0.3770 -0.4992  0.3957 -1.0027  0.7676  1.3439  1.1695 -0.0786  0.0372
 0.1163 -0.4600 -1.2990 -0.6624  0.6378  0.4357 -0.2231  0.8826  0.7718  0.6312
-0.9322  0.7925  1.0265 -0.9309  0.3586 -0.2663  0.7529 -0.8931  0.3230  1.0597
 0.0599  0.3668  0.2117 -0.3740 -1.2131 -0.7596 -0.1819  0.4357  3.0936  0.7486
-0.7667 -0.3219 -0.3511 -0.6781  0.8756  1.2539  0.7989  0.6129  0.3743  0.6551
 0.8160 -0.3391 -0.4200  0.0984  0.0863 -1.1544  0.6204 -0.6724  0.2659  0.5388
 0.4748  0.5738 -0.8648  0.3691 -0.3480 -0.1510  0.8260  0.6924  0.0053 -0.6213
 0.2044  0.7698  0.7638  0.3532  0.7197  0.9445 -1.0761  0.0882  0.5684  0.4562
-1.0330 -1.0507 -1.1679  0.0608  1.3512  0.2507  0.1740 -0.1574 -0.0552  0.6377
 1.3845  1.3252  2.5621 -0.5241  0.4334 -0.5092  0.1271 -1.3832  0.7112  0.1932
-0.1659  0.2740 -0.6393 -0.2937 -0.2887 -0.7221 -1.1947 -1.0431  1.1029 -1.1171
-0.2033 -0.5364 -0.4530 -2.4491 -1.2100 -1.5732  0.4191 -2.8109  0.3529 -0.7417
 0.1667 -0.0072  0.8795 -0.1538  0.5413  1.1036 -0.5249 -0.8432  0.0563 -0.2998
-0.4226  0.6448 -0.4215  0.4342 -0.6593 -0.2078  1.4768  1.1829  0.8084 -2.0024
 2.1950  0.8189  0.4104  0.4159 -1.1775 -2.3510 -0.5108 -2.5914 -0.5550  0.7188
-0.2978  0.1422 -0.0790 -1.6337 -0.4799 -0.9623 -0.9411  0.8321 -1.6386 -0.7785
-0.3109  0.5793  0.5437  0.3324 -0.9796  1.4794  0.0364  0.6472  0.7203  1.5878
 0.0685  1.5637 -0.4545 -2.2541  0.5353  0.1305  1.3973 -1.2065 -0.5373  1.3352
 0.0670 -0.6708 -0.4448  0.1797 -0.6935  1.4199  0.2560  0.3542 -1.0556 -1.1745
-0.3048  1.7749 -0.5777 -0.7029  0.9634 -0.9982  1.1929  1.5102  0.7618 -0.3569
 0.1294 -1.6825 -0.8473 -0.7886  0.3286 -0.2387 -0.4245 -0.3130  0.2273 -1.0860
-0.7929 -1.0838  0.1994 -0.4874  0.6568  0.1065  1.8086  0.2142 -1.1657 -0.2313
[torch.FloatTensor of size 59x10]






def init_weights(self):
    """Re-initialise every parameter of the module in place.

    Parameters with more than one dimension (weight matrices) are drawn from
    U(-a, a) with a = sqrt(3 / n_d), i.e. variance 1 / n_d. All other
    parameters (for example bias vectors) are set to zero.

    Args:
        self: an ``nn.Module``-like object providing ``parameters()`` and an
            ``n_d`` attribute.
    """
    val_range = (3.0 / self.n_d) ** 0.5
    for p in self.parameters():
        if p.dim() > 1:  # matrix
            p.data.uniform_(-val_range, val_range)
        else:
            p.data.zero_()

关于 `p.data.uniform_(-val_range, val_range)` 和 `p.data.zero_()`:前者把维度大于 1 的参数(矩阵)原地用 [-val_range, val_range] 上的均匀分布填充,后者把其余参数(如偏置向量)原地置零。



import time
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
def read_corpus(path, eos="</s>"):
    """Read a whitespace-tokenized text file into one flat token list.

    Every line is split on whitespace and followed by the ``eos`` marker,
    so line (sentence) boundaries survive the flattening.

    Args:
        path: path of the text file to read.
        eos: token appended after each line (default ``"</s>"``).

    Returns:
        list: all tokens of the file in order, with ``eos`` after each line.
    """
    tokens = []
    with open(path) as handle:
        for raw_line in handle:
            tokens.extend(raw_line.split())
            tokens.append(eos)
    return tokens
def create_batches(data_text, map_to_ids, batch_size):
    """Turn a token stream into aligned (input, target) batches.

    The id sequence is trimmed so that ``batch_size`` streams of equal length
    fit, keeping one extra id for the shifted targets. It is cut into
    ``batch_size`` contiguous streams and transposed, so rows index time steps
    and columns index streams. Targets are the inputs shifted one step ahead.

    Args:
        data_text: sequence of tokens.
        map_to_ids: callable mapping ``data_text`` to a numpy array of ids.
        batch_size: number of parallel streams (columns).

    Returns:
        (x, y): two contiguous torch tensors of shape
        ``((N - 1) // batch_size, batch_size)``, where N is the number of ids.
    """
    ids = map_to_ids(data_text)
    steps = (len(ids) - 1) // batch_size
    usable = steps * batch_size

    def _to_columns(segment):
        # (batch_size, steps) -> (steps, batch_size); np.copy materialises the transpose.
        return torch.from_numpy(np.copy(segment.reshape(batch_size, -1).T)).contiguous()

    inputs = _to_columns(ids[:usable])
    targets = _to_columns(ids[1:usable + 1])
    return inputs, targets
class EmbeddingLayer(nn.Module):
    """Word-embedding layer that builds its own vocabulary from a token list.

    Each distinct word gets an integer id in order of first appearance, and an
    ``nn.Embedding`` table of shape ``(vocabulary size, n_d)`` holds one
    learnable vector per id.
    """

    def __init__(self, n_d, words, fix_emb=False):
        """Build the vocabulary and the embedding table.

        Args:
            n_d: embedding dimension (NOTE(review): presumably also the hidden
                size of the surrounding model — confirm).
            words: iterable of tokens; duplicates are allowed.
            fix_emb: if True, the embedding weights are frozen
                (``requires_grad = False``).
        """
        super(EmbeddingLayer, self).__init__()
        word2id = {}
        for w in words:
            # The first occurrence decides the id, so ids run 0..n_V-1.
            if w not in word2id:
                word2id[w] = len(word2id)
        self.word2id = word2id
        # n_V: vocabulary size; n_d: embedding dimension.
        self.n_V, self.n_d = len(word2id), n_d
        self.embedding = nn.Embedding(self.n_V, n_d)
        if fix_emb:
            self.embedding.weight.requires_grad = False

    def forward(self, x):
        """Look up the embedding vectors for a tensor of word ids."""
        return self.embedding(x)

    def map_to_ids(self, text):
        """Map a sequence of tokens to a 1-D int64 numpy array of their ids.

        Raises:
            KeyError: if a token is not in the vocabulary.
        """
        return np.asarray([self.word2id[x] for x in text], dtype="int64")
# Build the embedding layer over the training tokens and inspect its parameters.
train = read_corpus('/home/lai/下载/train.txt')
model = EmbeddingLayer(10,train)
for param in model.parameters():
    # NOTE(review): the loop body was missing (a bare ``for`` header is a
    # SyntaxError). The captured output looks like each parameter's raw data
    # being printed; adjust if something else was intended.
    print(param.data)




['no', 'it', 'was', "n't", 'black', 'monday', '</s>', 'but', 'while', 'the', 'new', 'york', 'stock', 'exchange', 'did', "n't", 'fall', 'apart', 'friday', 'as', 'the', 'dow', 'jones', 'industrial', 'average', 'plunged', 'N', 'points', 'most', 'of', 'it', 'in', 'the', 'final', 'hour', 'it', 'barely', 'managed', 'to', 'stay', 'this', 'side', 'of', 'chaos', '</s>', 'some', 'circuit', 'breakers', 'installed', 'after', 'the', 'october', 'N', 'crash', 'failed', 'their', 'first', 'test', 'traders', 'say', 'unable', 'to', 'cool', 'the', 'selling', 'panic', 'in', 'both', 'stocks', 'and', 'futures', '</s>']

 1.4317  0.6596  0.0516  1.0376  0.1926  1.2600  0.0494  0.8796  1.9962  1.2159
 0.2419  0.6704  0.1465  1.6639  1.5062  1.6871  0.7300  1.6097  0.6998  1.1892
 0.8882  0.7436  0.7304  0.6540  1.0289  0.7935  1.9055  1.5515  1.2066  1.7531
 1.1168  1.8315  0.7545  1.8267  0.9284  0.4486  0.5175  0.0532  0.8085  1.3437
 0.2860  0.2907  0.8077  1.9553  1.2979  1.1078  0.0623  1.8027  1.8158  0.0852
 1.0238  0.3384  0.5703  1.5060  1.0183  0.2247  0.2230  0.7064  0.3984  1.6884
 1.1680  1.5321  0.9316  1.9031  0.5216  0.8028  0.8465  0.5166  1.5459  0.2865
 0.6001  1.1145  1.6196  1.7692  1.7195  1.3123  0.4399  0.4006  1.2029  1.6420
 1.9466  1.9689  0.8811  0.2398  1.3328  0.5307  1.6048  0.9328  1.6946  0.5598
 1.9595  0.3396  1.4121  0.1757  0.3677  0.5584  1.9388  1.2118  1.3966  1.4618
 1.2004  0.8745  0.4966  1.5487  0.7805  1.0708  1.8857  0.1973  1.1339  1.0490
 0.4731  0.2265  1.0293  0.7514  1.3949  1.5742  0.0032  1.0001  1.6449  1.4519
 0.2014  0.0456  1.2669  1.2988  0.9432  1.0757  0.6428  1.3084  0.7477  0.3753
 0.1086  0.1842  1.3811  1.4472  0.6998  0.0028  1.8839  1.0238  1.6243  1.3262
 0.6383  1.4817  0.2363  1.7802  1.2998  1.8367  1.9967  0.5028  0.0819  1.4886
 0.2979  0.3566  0.5144  0.6787  0.8583  0.9256  0.8171  0.0482  0.6638  1.3788
 0.4180  1.5806  1.0489  0.6587  1.6041  1.0644  1.9635  1.4030  1.5242  1.9292
 1.7177  1.0168  1.4879  1.5941  0.6318  0.4966  1.9573  1.0276  1.8955  0.9595
 1.3229  0.5519  0.0796  1.0840  0.2204  0.7510  0.6440  0.7307  1.0064  1.0647
 0.5325  1.1621  1.0669  1.2276  0.2488  1.6607  1.6797  1.7445  0.7051  0.0290
 1.9457  0.8071  1.9667  1.5591  1.6706  1.8955  0.2541  1.2218  0.5843  1.8493
 0.8763  0.2127  0.5883  0.9636  1.9839  0.5030  0.8972  0.3293  1.1231  0.8687
 1.3803  0.9248  1.3445  0.1882  1.3226  1.9621  1.0377  1.7566  1.6686  1.6855
 1.9552  0.1764  0.6670  1.5401  0.4913  0.8954  0.3951  0.8991  1.5485  0.6603
 0.5025  1.1702  1.8270  0.9304  0.4637  1.4306  0.5506  0.3712  0.0122  0.4379
 0.2657  0.0599  1.8354  0.2358  1.7581  0.3380  0.9558  1.7275  0.5202  1.3801
 0.7791  1.4060  0.6530  1.8742  0.5895  0.7742  1.7748  1.7141  1.2038  0.2918
 1.0312  1.9371  0.8345  0.4569  0.0447  0.2415  1.3479  0.9809  0.0566  1.0656
 0.3313  0.4801  0.3357  1.4143  0.6487  0.7692  1.0398  1.1538  0.8307  0.8231
 1.4774  0.1299  1.1836  0.2659  1.4413  0.4059  0.2428  1.0973  0.5491  0.2169
 1.8733  0.7073  0.6730  1.7413  1.1705  1.7082  1.0175  1.2589  1.9080  0.7648
 1.0761  1.1880  1.5441  1.9458  0.5513  1.5324  1.3756  0.3201  1.6600  0.7143
 1.8071  1.2422  1.5758  1.5677  1.5796  1.0328  0.3856  0.3648  0.5017  1.2543
 1.8749  1.9269  0.2120  0.3971  0.4451  0.7651  0.6793  0.1512  1.7845  0.1911
 1.2950  0.9356  1.0757  0.7603  0.6917  0.2891  1.3327  1.1102  0.3153  1.7074
 0.9031  1.8973  1.6392  0.3516  0.4412  1.4444  1.4032  0.1110  1.1379  0.2283
 0.4678  1.3409  0.6576  0.5351  1.2108  1.7777  0.5716  1.9060  1.4147  1.4487
 0.9546  0.9840  0.3020  1.7696  0.9677  1.1206  1.5639  0.0437  0.1485  0.1437
 1.0374  0.8910  1.7921  1.1207  0.4798  0.5863  0.0112  0.7735  0.8233  0.8936
 1.1980  1.6834  0.5779  0.7173  1.5803  1.6196  0.1642  1.6706  1.9906  1.4089
 0.2140  0.6833  1.6710  0.4645  0.0886  1.6945  0.8467  1.3290  1.7448  0.5405
 1.2914  1.5487  0.8509  1.8434  1.3398  0.3215  0.5732  1.5421  1.5103  0.2807
 1.4965  0.5448  1.0851  0.6836  1.4491  0.4040  1.8560  1.2288  1.4055  0.7298
 0.6319  0.9501  0.5320  1.2168  0.0031  1.8810  1.5128  0.4442  1.3887  1.5603
 0.5936  1.9980  1.4988  0.5884  1.9388  1.8275  0.1833  1.3767  1.2934  0.6319
 0.2711  0.0854  0.7103  0.8877  1.9997  0.2341  0.7163  1.8445  1.4777  0.0532
 1.1966  1.1512  1.8602  0.0552  1.7778  0.4180  1.0675  1.0646  1.6946  1.9979
 1.4076  0.1683  0.6894  1.0616  1.8683  0.3648  0.9496  0.4799  1.5983  0.8257
 1.5951  0.7438  0.4807  1.7440  1.1139  1.5855  0.3561  0.5960  0.6389  1.7573
 1.3262  1.5965  0.1100  1.0414  0.1697  1.8125  0.8135  0.1712  0.8863  0.5336
 0.4490  0.1233  0.0136  1.3416  0.2668  0.2091  0.8900  0.3823  1.3197  1.4936
 1.3607  0.6022  0.9031  0.7420  0.5538  1.5407  1.1918  0.5104  1.7564  0.1658
 0.4650  0.4523  1.3443  1.5691  1.0239  0.5898  0.8882  0.1892  1.0721  1.6908
 1.0479  1.9074  0.3732  1.8763  1.5337  0.2918  1.9343  1.6055  0.0709  0.9326
 0.6884  1.6136  1.1970  1.0819  0.3358  0.0234  0.4381  1.2239  1.1829  1.1254
 1.4076  0.4704  0.1724  0.5579  0.1318  0.5537  0.2435  0.8490  0.7200  1.5814
 0.2753  0.4727  0.5446  1.7038  0.8742  1.2662  1.3187  0.5939  1.2068  0.3514
 0.6184  1.6217  1.0503  1.0958  1.9824  0.6737  0.3009  0.7889  1.8378  1.7559
 0.6418  1.8355  0.7340  0.7232  0.6433  0.0288  1.3672  0.6466  0.3574  1.0760
[torch.FloatTensor of size 59x10]

 1.4317  0.6596  0.0516  1.0376  0.1926  1.2600  0.0494  0.8796  1.9962  1.2159
 0.2419  0.6704  0.1465  1.6639  1.5062  1.6871  0.7300  1.6097  0.6998  1.1892
 0.8882  0.7436  0.7304  0.6540  1.0289  0.7935  1.9055  1.5515  1.2066  1.7531
 1.1168  1.8315  0.7545  1.8267  0.9284  0.4486  0.5175  0.0532  0.8085  1.3437
 0.2860  0.2907  0.8077  1.9553  1.2979  1.1078  0.0623  1.8027  1.8158  0.0852
 1.0238  0.3384  0.5703  1.5060  1.0183  0.2247  0.2230  0.7064  0.3984  1.6884
 1.1680  1.5321  0.9316  1.9031  0.5216  0.8028  0.8465  0.5166  1.5459  0.2865
 0.6001  1.1145  1.6196  1.7692  1.7195  1.3123  0.4399  0.4006  1.2029  1.6420
 1.9466  1.9689  0.8811  0.2398  1.3328  0.5307  1.6048  0.9328  1.6946  0.5598
 1.9595  0.3396  1.4121  0.1757  0.3677  0.5584  1.9388  1.2118  1.3966  1.4618
 1.2004  0.8745  0.4966  1.5487  0.7805  1.0708  1.8857  0.1973  1.1339  1.0490
 0.4731  0.2265  1.0293  0.7514  1.3949  1.5742  0.0032  1.0001  1.6449  1.4519
 0.2014  0.0456  1.2669  1.2988  0.9432  1.0757  0.6428  1.3084  0.7477  0.3753
 0.1086  0.1842  1.3811  1.4472  0.6998  0.0028  1.8839  1.0238  1.6243  1.3262
 0.6383  1.4817  0.2363  1.7802  1.2998  1.8367  1.9967  0.5028  0.0819  1.4886
 0.2979  0.3566  0.5144  0.6787  0.8583  0.9256  0.8171  0.0482  0.6638  1.3788
 0.4180  1.5806  1.0489  0.6587  1.6041  1.0644  1.9635  1.4030  1.5242  1.9292
 1.7177  1.0168  1.4879  1.5941  0.6318  0.4966  1.9573  1.0276  1.8955  0.9595
 1.3229  0.5519  0.0796  1.0840  0.2204  0.7510  0.6440  0.7307  1.0064  1.0647
 0.5325  1.1621  1.0669  1.2276  0.2488  1.6607  1.6797  1.7445  0.7051  0.0290
 1.9457  0.8071  1.9667  1.5591  1.6706  1.8955  0.2541  1.2218  0.5843  1.8493
 0.8763  0.2127  0.5883  0.9636  1.9839  0.5030  0.8972  0.3293  1.1231  0.8687
 1.3803  0.9248  1.3445  0.1882  1.3226  1.9621  1.0377  1.7566  1.6686  1.6855
 1.9552  0.1764  0.6670  1.5401  0.4913  0.8954  0.3951  0.8991  1.5485  0.6603
 0.5025  1.1702  1.8270  0.9304  0.4637  1.4306  0.5506  0.3712  0.0122  0.4379
 0.2657  0.0599  1.8354  0.2358  1.7581  0.3380  0.9558  1.7275  0.5202  1.3801
 0.7791  1.4060  0.6530  1.8742  0.5895  0.7742  1.7748  1.7141  1.2038  0.2918
 1.0312  1.9371  0.8345  0.4569  0.0447  0.2415  1.3479  0.9809  0.0566  1.0656
 0.3313  0.4801  0.3357  1.4143  0.6487  0.7692  1.0398  1.1538  0.8307  0.8231
 1.4774  0.1299  1.1836  0.2659  1.4413  0.4059  0.2428  1.0973  0.5491  0.2169
 1.8733  0.7073  0.6730  1.7413  1.1705  1.7082  1.0175  1.2589  1.9080  0.7648
 1.0761  1.1880  1.5441  1.9458  0.5513  1.5324  1.3756  0.3201  1.6600  0.7143
 1.8071  1.2422  1.5758  1.5677  1.5796  1.0328  0.3856  0.3648  0.5017  1.2543
 1.8749  1.9269  0.2120  0.3971  0.4451  0.7651  0.6793  0.1512  1.7845  0.1911
 1.2950  0.9356  1.0757  0.7603  0.6917  0.2891  1.3327  1.1102  0.3153  1.7074
 0.9031  1.8973  1.6392  0.3516  0.4412  1.4444  1.4032  0.1110  1.1379  0.2283
 0.4678  1.3409  0.6576  0.5351  1.2108  1.7777  0.5716  1.9060  1.4147  1.4487
 0.9546  0.9840  0.3020  1.7696  0.9677  1.1206  1.5639  0.0437  0.1485  0.1437
 1.0374  0.8910  1.7921  1.1207  0.4798  0.5863  0.0112  0.7735  0.8233  0.8936
 1.1980  1.6834  0.5779  0.7173  1.5803  1.6196  0.1642  1.6706  1.9906  1.4089
 0.2140  0.6833  1.6710  0.4645  0.0886  1.6945  0.8467  1.3290  1.7448  0.5405
 1.2914  1.5487  0.8509  1.8434  1.3398  0.3215  0.5732  1.5421  1.5103  0.2807
 1.4965  0.5448  1.0851  0.6836  1.4491  0.4040  1.8560  1.2288  1.4055  0.7298
 0.6319  0.9501  0.5320  1.2168  0.0031  1.8810  1.5128  0.4442  1.3887  1.5603
 0.5936  1.9980  1.4988  0.5884  1.9388  1.8275  0.1833  1.3767  1.2934  0.6319
 0.2711  0.0854  0.7103  0.8877  1.9997  0.2341  0.7163  1.8445  1.4777  0.0532
 1.1966  1.1512  1.8602  0.0552  1.7778  0.4180  1.0675  1.0646  1.6946  1.9979
 1.4076  0.1683  0.6894  1.0616  1.8683  0.3648  0.9496  0.4799  1.5983  0.8257
 1.5951  0.7438  0.4807  1.7440  1.1139  1.5855  0.3561  0.5960  0.6389  1.7573
 1.3262  1.5965  0.1100  1.0414  0.1697  1.8125  0.8135  0.1712  0.8863  0.5336
 0.4490  0.1233  0.0136  1.3416  0.2668  0.2091  0.8900  0.3823  1.3197  1.4936
 1.3607  0.6022  0.9031  0.7420  0.5538  1.5407  1.1918  0.5104  1.7564  0.1658
 0.4650  0.4523  1.3443  1.5691  1.0239  0.5898  0.8882  0.1892  1.0721  1.6908
 1.0479  1.9074  0.3732  1.8763  1.5337  0.2918  1.9343  1.6055  0.0709  0.9326
 0.6884  1.6136  1.1970  1.0819  0.3358  0.0234  0.4381  1.2239  1.1829  1.1254
 1.4076  0.4704  0.1724  0.5579  0.1318  0.5537  0.2435  0.8490  0.7200  1.5814
 0.2753  0.4727  0.5446  1.7038  0.8742  1.2662  1.3187  0.5939  1.2068  0.3514
 0.6184  1.6217  1.0503  1.0958  1.9824  0.6737  0.3009  0.7889  1.8378  1.7559
 0.6418  1.8355  0.7340  0.7232  0.6433  0.0288  1.3672  0.6466  0.3574  1.0760
[torch.FloatTensor of size 59x10]



在这里记录一个我刚观察到的知识:param.dim() 返回 tensor 的维数,即 size 中的维度个数。例如 torch.FloatTensor of size 5x1x2x2 是 4 维,size 为 5x1x2 是 3 维,以此类推。Conv2d 权重的 size 由它的参数决定:前两个参数分别是输入通道数(input image channel,RGB 图像为 3,灰度图像为 1)和输出通道数(output channel,即 filter 的个数),再加上 kernel_size。权重形状为 (输出通道数, 输入通道数, kernel 高, kernel 宽),例如 nn.Conv2d(1, 6, 2, 2) 的权重 size 为 6x1x2x2。



def init_hidden(self, batch_size):
    """Return an all-zero initial hidden state for a batch.

    Args:
        self: the model; must provide ``parameters()``, ``depth``, ``n_d``
            and ``args.lstm``.
        batch_size: number of sequences in the batch.

    Returns:
        A zero tensor of shape ``(self.depth, batch_size, self.n_d)``. When
        ``self.args.lstm`` is true, returns an ``(h, c)`` pair instead, holding
        that same tensor twice.
    """
    # The first parameter is borrowed only so that ``new`` creates the zeros
    # with the same dtype and device as the model's weights.
    weight = next(self.parameters()).data
    zeros = Variable(weight.new(self.depth, batch_size, self.n_d).zero_())
    if self.args.lstm:
        return (zeros, zeros)
    # Non-LSTM models get a single hidden-state tensor.
    return zeros

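关于 `weight = next(self.parameters()).data`:`self.parameters()` 返回一个生成器,`next()` 取出其中的第一个参数,`.data` 得到它的张量数据。下面用一个两层卷积的小模型验证,取到的第一个参数就是 conv1 的权重。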


import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Two stacked 2-D convolutions, each followed by ReLU.

    conv1: 1 input channel -> 6 channels, 2x2 kernel, stride 2.
    conv2: 6 channels -> 5 channels, 2x2 kernel, stride 1.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 2, 2)
        # in_channels must match conv1's out_channels (6); otherwise forward()
        # raises a channel-mismatch error for every input.
        self.conv2 = nn.Conv2d(6, 5, 2, 1)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU.

        ``x`` presumably has shape (batch, 1, H, W) — TODO confirm with callers.
        """
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
# Instantiate the demo model, then take its first registered parameter
# (conv1's weight).
model = Model()
print(model)
x = next(model.parameters()).data
print(x)


Model (
  (conv1): Conv2d(1, 6, kernel_size=(2, 2), stride=(2, 2))
  (conv2): Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))

(0 ,0 ,.,.) = 
  0.2855 -0.0303
  0.1428 -0.4025

(1 ,0 ,.,.) = 
 -0.0901  0.2736
 -0.1527 -0.2854

(2 ,0 ,.,.) = 
  0.2193 -0.3886
 -0.4652  0.2307

(3 ,0 ,.,.) = 
  0.1918  0.4587
 -0.0480 -0.0636

(4 ,0 ,.,.) = 
  0.4017 -0.4123
  0.3016 -0.2714

(5 ,0 ,.,.) = 
  0.2053  0.1252
 -0.2365 -0.3651
[torch.FloatTensor of size 6x1x2x2]



posted @ 2017-12-04 20:03  深度学习1  阅读(1174)  评论(0)  编辑  收藏  举报