Datawhale Al夏令营——siRNA药物药效预测Transformer模型搭建准备工作
数据分析¶
构建模型,我们首先要对数据进行充分的分析,通过可视化与表格的形式展现我们能够更加有效的将其用在模型中。
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
# 忽略警告
warnings.filterwarnings('ignore')
# 设置字体,避免中文字符显示问题
plt.rcParams['font.sans-serif'] = ['SimHei'] # 使用黑体
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# 加载数据
train_data = pd.read_csv('train_data.csv')
# 指定需要处理的列
columns = ['siRNA_antisense_seq', 'modified_siRNA_sense_seq_list', 'modified_siRNA_antisense_seq_list', 'gene_target_seq']
# 定义函数处理字符长度统计
def get_length(x, is_modified=False):
if pd.isna(x):
return 0
if is_modified:
return len(x.split(' '))
return len(x)
# 统计字符数
lengths = {}
for col in columns:
is_modified = 'modified' in col
lengths[col] = train_data[col].apply(lambda x: get_length(x, is_modified))
# 创建表格展示每个字符长度的频数,并绘制直方图
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()
for i, col in enumerate(columns):
length_counts = lengths[col].value_counts().sort_index()
length_counts_df = pd.DataFrame(length_counts).reset_index()
length_counts_df.columns = [f'{col} 长度', '频数']
print(f"{col} 各字符长度的频数:")
print(length_counts_df)
sns.histplot(lengths[col], bins=30, kde=True, ax=axes[i], edgecolor='black')
axes[i].set_title(f'{col} 字符长度分布')
axes[i].set_xlabel('字符长度')
axes[i].set_ylabel('频数')
plt.tight_layout()
plt.show()
siRNA_antisense_seq 各字符长度的频数: siRNA_antisense_seq 长度 频数 0 23 66 modified_siRNA_sense_seq_list 各字符长度的频数: modified_siRNA_sense_seq_list 长度 频数 0 22 66 modified_siRNA_antisense_seq_list 各字符长度的频数: modified_siRNA_antisense_seq_list 长度 频数 0 23 66 gene_target_seq 各字符长度的频数: gene_target_seq 长度 频数 0 1661 12 1 2226 54
# 指定需要统计频次的列
frequency_columns = ['gene_target_species', 'cell_line_donor', 'siRNA_concentration', 'concentration_unit', 'Transfection_method', 'Duration_after_transfection_h']
# 统计并展示每个变量的频次
for col in frequency_columns:
# 统计变量出现的次数
frequency_counts = train_data[col].value_counts()
# 创建频次统计的表格
frequency_df = pd.DataFrame(frequency_counts).reset_index()
frequency_df.columns = [col, '次数']
print(f"{col} 变量出现的次数:")
print(frequency_df)
# 绘制条形图
plt.figure(figsize=(10, 6))
sns.barplot(x=frequency_df[col], y=frequency_df['次数'])
plt.title(f'{col} 变量出现次数')
plt.xlabel(col)
plt.ylabel('次数')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
gene_target_species 变量出现的次数: gene_target_species 次数 0 Homo sapiens 54 1 Mus musculus 12
cell_line_donor 变量出现的次数: cell_line_donor 次数 0 Hep3B Cells 54 1 Primary Mouse Hepatocytes 12
siRNA_concentration 变量出现的次数: siRNA_concentration 次数 0 10.0 60 1 0.1 6
concentration_unit 变量出现的次数: concentration_unit 次数 0 nM 66
Transfection_method 变量出现的次数: Transfection_method 次数 0 Lipofectamine 66
Duration_after_transfection_h 变量出现的次数: Duration_after_transfection_h 次数 0 24 66
深度学习模型中张量的变化探索¶
import os # 文件操作
import torch # 深度学习框架
import random # 随机数生成
import numpy as np # 数值计算
import pandas as pd # 数据处理
import math # 导入math模块
import torch.nn as nn # 神经网络模块
import torch.optim as optim # 优化器模块
from tqdm import tqdm # 进度条显示
from rich import print # 美化打印输出
from collections import Counter # 计数器工具
from torch.utils.data import Dataset, DataLoader # 数据集和数据加载器
from sklearn.model_selection import train_test_split # 数据集划分
from sklearn.metrics import precision_score, recall_score, mean_absolute_error # 模型评估指标
class GenomicTokenizer:
def __init__(self, ngram=3, stride=1):
self.ngram = ngram
self.stride = stride
def tokenize(self, t):
t = t.upper()
if ' ' in t:
tokens = t.split()
if self.ngram > 1:
tokens = [''.join(tokens[i:i+self.ngram]) for i in range(0, len(tokens), self.stride)]
else:
if self.ngram == 1:
tokens = list(t)
else:
tokens = [t[i:i+self.ngram] for i in range(0, len(t), self.stride)]
tokens = [tok.ljust(self.ngram, '0') for tok in tokens]
return tokens
class GenomicVocab:
def __init__(self, itos):
self.itos = itos
self.stoi = {v: k for k, v in enumerate(self.itos)}
@classmethod
def create(cls, tokens, max_vocab, min_freq):
from collections import Counter
freq = Counter(tokens)
itos = ['<pad>'] + [o for o, c in freq.most_common(max_vocab - 1) if c >= min_freq]
return cls(itos)
class SiRNADataset(Dataset):
def __init__(self, df, columns, vocab_lists, tokenizer_lists, max_len, is_test=False):
self.df = df
self.columns = columns
self.vocab_lists = vocab_lists
self.tokenizer_lists = tokenizer_lists
self.max_len = max_len
self.is_test = is_test
def __len__(self):
return len(self.df)
def tokenize_and_encode(self, seq, vocab, tokenizer, max_len):
tokens = tokenizer.tokenize(seq)
token_indices = [vocab.stoi.get(token, vocab.stoi['<pad>']) for token in tokens]
if len(token_indices) < max_len:
token_indices += [vocab.stoi['<pad>']] * (max_len - len(token_indices))
else:
token_indices = token_indices[:max_len]
return torch.tensor(token_indices, dtype=torch.long)
def __getitem__(self, idx):
row = self.df.iloc[idx]
seqs_combined = []
for col, vocabs, tokenizers in zip(self.columns, self.vocab_lists, self.tokenizer_lists):
for vocab, tokenizer in zip(vocabs, tokenizers):
seqs_combined.append(self.tokenize_and_encode(row[col], vocab, tokenizer, self.max_len))
if self.is_test:
return seqs_combined
else:
target = torch.tensor(row['mRNA_remaining_pct'], dtype=torch.float)
return seqs_combined, target
# 设置参数
bs = 8
# 选择设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 加载数据
train_data = pd.read_csv('train_data.csv')
# 指定需要处理的列
columns = ['siRNA_antisense_seq', 'modified_siRNA_sense_seq_list', 'modified_siRNA_antisense_seq_list']
# columns = ['siRNA_antisense_seq', 'modified_siRNA_sense_seq_list', 'modified_siRNA_antisense_seq_list', 'gene_target_seq']
# 删除包含空值的行
train_data.dropna(subset=columns + ['mRNA_remaining_pct'], inplace=True)
# dropna(): 这是pandas DataFrame的一个方法,用于删除含有缺失值的行。如果DataFrame中某一行在指定的列中有缺失值(NaN),那么这一行将被删除。
# subset: 这是dropna()方法的一个参数,用于指定只考虑DataFrame中的某些列来查找缺失值。如果这些列中有缺失值,则对应的行将被删除。dropna()方法就会检查这三个列是否有缺失值
# inplace=True: 这个参数告诉dropna()方法直接在原始的train_data DataFrame上进行修改,而不是返回一个新的DataFrame。
# 将数据分为训练集和验证集
train_data, val_data = train_test_split(train_data, test_size=0.1, random_state=42)
# 初始化分词器
tokenizer_list = [GenomicTokenizer(ngram=i) for i in range(1, 4)]
# tokenizer_list = [GenomicTokenizer(ngram=3)]
# 计算每个列的最大长度
max_len = max(max(len(seq.split()) if ' ' in seq else len(tokenizer.tokenize(seq)) for seq in train_data[col]) for tokenizer in tokenizer_list for col in columns)
# 创建词汇表
vocab_lists = []
for col in columns:
vocab_list = []
for i, tokenizer in enumerate(tokenizer_list):
all_tokens = []
for seq in train_data[col]:
all_tokens.extend(tokenizer.tokenize(seq))
vocab = GenomicVocab.create(all_tokens, max_vocab=10000, min_freq=1)
vocab_list.append(vocab)
vocab_lists.append(vocab_list)
# 创建数据集
train_dataset = SiRNADataset(train_data, columns, vocab_lists, [tokenizer_list]*len(columns), max_len, is_test=False)
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
检查样本的分词和编码结果¶
- 不采用任何方式处理索引
# 检查样本的分词和编码结果
for col in columns:
sample_seq = train_data[col].iloc[1]
for tokenizer in tokenizer_list:
tokens = tokenizer.tokenize(sample_seq)
token_indices = [vocab_lists[columns.index(col)][tokenizer_list.index(tokenizer)].stoi.get(token, vocab_lists[columns.index(col)][tokenizer_list.index(tokenizer)].stoi['<pad>']) for token in tokens]
print(f"Column: {col}, Tokenizer: {tokenizer.ngram}, Tokens: {tokens[:10]}, Indices: {token_indices[:10]}, Total tokens: {len(tokens)}")
Column: siRNA_antisense_seq, Tokenizer: 1, Tokens: ['A', 'G', 'U', 'C', 'U', 'U', 'U', 'G', 'C', 'U'], Indices: [2, 3, 1, 4, 1, 1, 1, 3, 4, 1], Total tokens: 23
Column: siRNA_antisense_seq, Tokenizer: 2, Tokens: ['AG', 'GU', 'UC', 'CU', 'UU', 'UU', 'UG', 'GC', 'CU', 'UG'], Indices: [7, 12, 9, 10, 1, 1, 2, 15, 10, 2], Total tokens: 23
Column: siRNA_antisense_seq, Tokenizer: 3, Tokens: ['AGU', 'GUC', 'UCU', 'CUU', 'UUU', 'UUG', 'UGC', 'GCU', 'CUG', 'UGG'], Indices: [23, 59, 32, 22, 1, 2, 51, 56, 20, 34], Total tokens: 23
Column: modified_siRNA_sense_seq_list, Tokenizer: 1, Tokens: ['A', 'A', 'A', 'U', 'U', 'G', 'UF', 'C', 'UF', 'CF'], Indices: [1, 1, 1, 2, 2, 4, 5, 3, 5, 9], Total tokens: 22
Column: modified_siRNA_sense_seq_list, Tokenizer: 2, Tokens: ['AA', 'AA', 'AU', 'UU', 'UG', 'GUF', 'UFC', 'CUF', 'UFCF', 'CFCF'], Indices: [1, 1, 3, 9, 5, 33, 23, 24, 32, 59], Total tokens: 22
Column: modified_siRNA_sense_seq_list, Tokenizer: 3, Tokens: ['AAA', 'AAU', 'AUU', 'UUG', 'UGUF', 'GUFC', 'UFCUF', 'CUFCF', 'UFCFCF', 'CFCFA'], Indices: [3, 30, 24, 15, 109, 110, 72, 73, 111, 112], Total tokens: 22
Column: modified_siRNA_antisense_seq_list, Tokenizer: 1, Tokens: ['A', 'GF', 'U', 'C', 'U', 'UF', 'U', 'GF', 'CF', 'U'], Indices: [2, 7, 1, 3, 1, 5, 1, 7, 8, 1], Total tokens: 23
Column: modified_siRNA_antisense_seq_list, Tokenizer: 2, Tokens: ['AGF', 'GFU', 'UC', 'CU', 'UUF', 'UFU', 'UGF', 'GFCF', 'CFU', 'UG'], Indices: [19, 45, 9, 11, 8, 14, 29, 55, 23, 2], Total tokens: 23
Column: modified_siRNA_antisense_seq_list, Tokenizer: 3, Tokens: ['AGFU', 'GFUC', 'UCU', 'CUUF', 'UUFU', 'UFUGF', 'UGFCF', 'GFCFU', 'CFUG', 'UGG'], Indices: [81, 182, 49, 28, 8, 29, 183, 134, 82, 50], Total tokens: 23
- 当前列名(col)
- 分词器的n-gram长度(tokenizer.ngram)
- 前10个tokens(tokens[:10])
- 前10个tokens的索引(token_indices[:10])
- 总tokens数量(len(tokens))
- Tokens用前缀标识区分不同的 n-gram 分词器生成的 token
# class GenomicTokenizer:
# def __init__(self, ngram=3, stride=1):
# self.ngram = ngram
# self.stride = stride
# self.prefix = f'ng{ngram}_'
# def tokenize(self, t):
# t = t.upper()
# if ' ' in t:
# tokens = t.split()
# if self.ngram > 1:
# tokens = [' '.join(tokens[i:i+self.ngram]) for i in range(0, len(tokens), self.stride)]
# tokens = [token.replace(' ', '') for token in tokens]
# else:
# if self.ngram == 1:
# tokens = list(t)
# else:
# tokens = [t[i:i+self.ngram] for i in range(0, len(t), self.stride)]
# tokens = [tok.ljust(self.ngram, '0') for tok in tokens]
# tokens = [self.prefix + token for token in tokens]
# return tokens
# class SiRNADataset(Dataset):
# def __init__(self, df, columns, vocab_lists, tokenizer_lists, max_len_lists, is_test=False):
# self.df = df
# self.columns = columns
# self.vocab_lists = vocab_lists
# self.tokenizer_lists = tokenizer_lists
# self.max_len_lists = max_len_lists
# self.is_test = is_test
# def __len__(self):
# return len(self.df)
# def tokenize_and_encode(self, seq, vocab, tokenizer, max_len):
# tokens = tokenizer.tokenize(seq) # 修改了这里,因为有' ' 的情况被GenomicTokenizer纳入
# token_indices = [vocab.stoi.get(token, vocab.stoi['<pad>']) for token in tokens]
# if len(token_indices) < max_len:
# token_indices += [vocab.stoi['<pad>']] * (max_len - len(token_indices))
# else:
# token_indices = token_indices[:max_len]
# return torch.tensor(token_indices, dtype=torch.long)
# def __getitem__(self, idx):
# row = self.df.iloc[idx]
# seqs_combined = []
# for col, vocabs, tokenizers, max_lens in zip(self.columns, self.vocab_lists, self.tokenizer_lists, self.max_len_lists):
# for vocab, tokenizer, max_len in zip(vocabs, tokenizers, max_lens):
# seqs_combined.append(self.tokenize_and_encode(row[col], vocab, tokenizer, max_len))
# if self.is_test:
# return seqs_combined
# else:
# target = torch.tensor(row['mRNA_remaining_pct'], dtype=torch.float)
# return seqs_combined, target
# # 检查样本的分词和编码结果
# for col in columns:
# sample_seq = train_data[col].iloc[1]
# for tokenizer, max_len in zip(tokenizer_list, max_len_list[columns.index(col)]):
# tokens = tokenizer.tokenize(sample_seq)
# token_indices = [vocab_lists[columns.index(col)][tokenizer_list.index(tokenizer)].stoi.get(token, vocab_lists[columns.index(col)][tokenizer_list.index(tokenizer)].stoi['<pad>']) for token in tokens]
# print(f"Column: {col}, Tokenizer: {tokenizer.ngram}, Tokens: {tokens[:10]}, Indices: {token_indices[:10]}, Total tokens: {len(tokens)}")
查看样本词汇表¶
vocab_lists
[[<__main__.GenomicVocab at 0x2a839153e50>, <__main__.GenomicVocab at 0x2a839153010>, <__main__.GenomicVocab at 0x2a83929c070>], [<__main__.GenomicVocab at 0x2a83929f070>, <__main__.GenomicVocab at 0x2a83929f010>, <__main__.GenomicVocab at 0x2a83929d0c0>], [<__main__.GenomicVocab at 0x2a83929f1c0>, <__main__.GenomicVocab at 0x2a83929f430>, <__main__.GenomicVocab at 0x2a83929f340>]]
# 遍历 vocab_lists 并打印每个词汇表中的部分内容
for col_idx, vocab_list in enumerate(vocab_lists):
print(f"Column {col_idx}:")
for ngram_idx, vocab in enumerate(vocab_list):
print(f" N-gram {ngram_idx + 1}:")
print(f" Total tokens: {len(vocab.itos)}")
print(f" First 10 tokens and their indices:")
for i, token in enumerate(vocab.itos[:10]):
print(f" {token}: {vocab.stoi[token]}")
Column 0:
N-gram 1:
Total tokens: 5
First 10 tokens and their indices:
<pad>: 0
U: 1
A: 2
G: 3
C: 4
N-gram 2:
Total tokens: 21
First 10 tokens and their indices:
<pad>: 0
UU: 1
UG: 2
CA: 3
AU: 4
GA: 5
AA: 6
AG: 7
UA: 8
UC: 9
N-gram 3:
Total tokens: 78
First 10 tokens and their indices:
<pad>: 0
UUU: 1
UUG: 2
UGA: 3
GAU: 4
ACA: 5
UGU: 6
GUU: 7
UCA: 8
AAU: 9
Column 1:
N-gram 1:
Total tokens: 10
First 10 tokens and their indices:
<pad>: 0
A: 1
U: 2
C: 3
G: 4
UF: 5
AF: 6
L96: 7
GF: 8
CF: 9
N-gram 2:
Total tokens: 67
First 10 tokens and their indices:
<pad>: 0
AA: 1
CA: 2
AU: 3
CU: 4
UG: 5
UC: 6
UA: 7
L96: 8
UU: 9
N-gram 3:
Total tokens: 313
First 10 tokens and their indices:
<pad>: 0
L96: 1
AL96: 2
AAA: 3
CAA: 4
AUC: 5
ACA: 6
UGU: 7
AAC: 8
UCA: 9
Column 2:
N-gram 1:
Total tokens: 9
First 10 tokens and their indices:
<pad>: 0
U: 1
A: 2
C: 3
G: 4
UF: 5
AF: 6
GF: 7
CF: 8
N-gram 2:
Total tokens: 67
First 10 tokens and their indices:
<pad>: 0
UU: 1
UG: 2
CA: 3
GA: 4
AU: 5
AA: 6
AG: 7
UUF: 8
UC: 9
N-gram 3:
Total tokens: 357
First 10 tokens and their indices:
<pad>: 0
UUU: 1
A: 2
U: 3
UUG: 4
UGA: 5
GGA: 6
UGU: 7
UUFU: 8
ACA: 9
# 提取词汇表大小
vocab_sizes = [len(vocab.itos) for vocab_list in vocab_lists for vocab in vocab_list]
print(vocab_sizes)
len(vocab.itos)
[5, 21, 78, 10, 67, 313, 9, 67, 357]
357
- 外层循环:for vocab_list in vocab_lists
这个循环遍历 vocab_lists 中的每个 vocab_list
- 内层循环:for vocab in vocab_list
这个循环遍历每个 vocab_list 中的每个 vocab 对象
vocab_lists 是一个包含多个列的列表
- 顶层列表 vocab_lists:每个子列表代表一个列,每个列包含多个 N-gram 词汇表
- 子列表:每个列包含多个 vocab 对象,这些对象是 N-gram 词汇表
- 词汇表 vocab:包含两个主要属性
- itos: 一个包含词汇的列表(从索引到字符串的映射)
- stoi: 一个包含词汇和对应索引的字典(从字符串到索引的映射)
vocab_lists = [
[vocab_col0_ngram1, vocab_col0_ngram2, vocab_col0_ngram3],
[vocab_col1_ngram1, vocab_col1_ngram2, vocab_col1_ngram3],
[vocab_col2_ngram1, vocab_col2_ngram2, vocab_col2_ngram3],
# 更多列...
]
vocab_colX_ngramY 代表第 X 列的第 Y 个 N-gram 词汇表
判定基本输入的形状¶
# 计算每个列的最大长度
max_len_list = []
for col in columns:
max_len = max(len(seq.split()) if ' ' in seq else len(GenomicTokenizer(ngram=3).tokenize(seq)) for seq in train_data[col])
max_len_list.append(max_len)
print(f"max_len_list: {max_len_list}")
max_len_list: [23, 22, 23]
这里使用ngram=3可以得到max_len_list中出现17911,主要是因为在GenomicTokenizer中我进行了如果长度不足继续填补的设定,因此与ngram=1得值一致。
for inputs, target in train_loader:
print(f"inputs列表包括的张量数量: {len(inputs)}") # inputs列表包括的张量数量
print(f"列表1张量的形状: {inputs[0].shape}") # 列表1张量的形状
# 打印inputs的整体形状
if isinstance(inputs, list):
# 如果inputs是一个列表
print(f"inputs是一个列表,形状为: {[input.shape for input in inputs]}")
else:
# 如果inputs是一个张量
print(f"inputs的形状为: {inputs.shape}")
break
inputs列表包括的张量数量: 9
列表1张量的形状: torch.Size([8, 23])
inputs是一个列表,形状为: [torch.Size([8, 23]), torch.Size([8, 23]), torch.Size([8, 23]), torch.Size([8, 23]), torch.Size([8, 23]), torch.Size([8, 23]), torch.Size([8, 23]), torch.Size([8, 23]), torch.Size([8, 23])]
通过train_loader我们可以查看一个batch内的输入情况,这里的inputs就是深度学习模型最原始的输入。
我们首先发现inputs包含12个张量,它们分别对应的是参数columns使用的4个特征,每个特征使用3种不同的ngram进行分词,最终组成12个元素。
每个元素的尺寸都是64*17911,64代表batch的大小,17911为max_len_list中最大的值。
这里的17911我们可以在数据分析部分中发现,即为gene_target_seq中最长的长度。
for inputs, target in train_loader:
# 查看每个张量的输出
print(f"第1个张量中的第2行/个样本: {inputs[0][1]}")
print(f"第6个张量中的第7行/个样本: {inputs[5][6]}")
# 打印第一个输入的非零元素的坐标
non_zero_indices1 = (inputs[0][1] != 0).nonzero(as_tuple=True)
print(f"第1个张量第2个样本的非零元素的坐标: {non_zero_indices1}")
non_zero_indices2 = (inputs[5][6] != 0).nonzero(as_tuple=True)
print(f"第6个张量第7个样本的非零元素的坐标: {non_zero_indices2}")
print(f"target的形状: {target.shape}") # 输出target的形状
break
第1个张量中的第2行/个样本: tensor([1, 1, 4, 4, 1, 1, 1, 2, 3, 2, 2, 3, 2, 1, 1, 2, 1, 2, 2, 2, 1, 4, 2])
第6个张量中的第7行/个样本: tensor([ 3, 30, 5, 9, 98, 91, 92, 99, 62, 76, 156, 46, 56, 21, 58, 38, 4, 3, 3, 12, 2, 1, 0])
第1个张量第2个样本的非零元素的坐标: (tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]),)
第6个张量第7个样本的非零元素的坐标: (tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]),)
target的形状: torch.Size([8])
我们可以从inputs[0][1]看到每一行数据的siRNA_antisense_seq被向量化后的情况,这个例子中我们发现前面的23位是非零数,表示其序列编码后每一位的唯一标识;而后面都是0,这是因为深度学习模型的输入需要每个样本的长度一致,因此我们需要事先算出一个所有序列编码后的最大长度,然后补0。
搭建的Transformer模型张量变化情况一览¶
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=max_len):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
class SiRNAModel(nn.Module):
def __init__(self, vocab_sizes, max_len, embed_dim=200, nhead=8, num_encoder_layers=6, dim_feedforward=512, dropout=0.1):
super(SiRNAModel, self).__init__()
# 初始化多个嵌入层
self.embeddings = nn.ModuleList([nn.Embedding(vocab_size, embed_dim, padding_idx=0) for vocab_size in vocab_sizes])
# 初始化位置编码层
self.position_encoder = PositionalEncoding(embed_dim, dropout, max_len=max_len)
# 初始化Transformer编码器
self.transformer_encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout),
num_layers=num_encoder_layers
)
# 初始化全连接层
self.fc = nn.Linear(embed_dim * len(vocab_sizes), 1) # 调整输入维度为 embed_dim * vocab_sizes 的长度
# 初始化Dropout层
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 将输入序列传入相应的嵌入层
embedded = [self.embeddings[i](seq) for i, seq in enumerate(x)]
outputs = []
# 对每个嵌入的序列进行处理
for embed in embedded:
embed = self.position_encoder(embed)
x = self.transformer_encoder(embed.transpose(0, 1)).transpose(0, 1) # Transpose for transformer input
x = self.dropout(x[:, 0, :]) # Use the output of the first token (like CLS token in BERT)
outputs.append(x)
# 将所有序列的输出拼接起来
x = torch.cat(outputs, dim=1)
# 传入全连接层
x = self.fc(x)
# 返回结果
return x.squeeze()
# 提取词汇表大小
vocab_sizes = [len(vocab.itos) for vocab_list in vocab_lists for vocab in vocab_list]
# 计算最大序列长度
max_len = max(max(len(seq.split()) if ' ' in seq else len(tokenizer.tokenize(seq)) for seq in train_data[col]) for tokenizer in tokenizer_list for col in columns)
# 初始化模型
model = SiRNAModel(vocab_sizes, max_len)
# 示例数据加载和验证过程
for inputs, targets in train_loader:
inputs = [input_seq for input_seq in inputs] # 将输入移动到设备
targets = targets # 将目标值移动到设备
# 模型前向传播
outputs = model(inputs)
print(f"模型输出: {outputs}")
break
模型输出: tensor([ 0.6839, 0.5042, 0.1742, 0.9946, 0.4056, 0.9351, 1.0420, -0.0524], grad_fn=<SqueezeBackward0>)
for inputs, targets in train_loader: # tqdm用于在长循环中显示进度条,它可以帮助用户了解程序的执行进度
print(f"inputs: {inputs}")
print(f"inputs: {inputs[0]}")
inputs = [input_seq.to(device) for input_seq in inputs] # 将输入移动到设备
print(f"inputs: {inputs}")
print(f"inputs: {inputs[0]}")
targets = targets.to(device) # 将目标值移动到设备
inputs: [tensor([[1, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2], [1, 2, 4, 1, 4, 1, 4, 2, 2, 4, 4, 2, 4, 4, 1, 3, 4, 1, 1, 3, 1, 3, 2], [2, 3, 4, 1, 3, 1, 3, 3, 2, 1, 1, 1, 1, 2, 4, 2, 2, 2, 4, 4, 2, 1, 1], [1, 1, 3, 1, 1, 4, 1, 2, 2, 3, 3, 2, 2, 2, 2, 3, 3, 4, 1, 3, 4, 4, 2], [1, 1, 1, 4, 2, 2, 2, 4, 3, 3, 3, 4, 4, 1, 4, 1, 1, 4, 4, 1, 4, 2, 3], [1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2, 1, 4, 2, 3, 3, 2, 1], [1, 1, 1, 4, 2, 2, 1, 1, 1, 3, 1, 4, 1, 1, 4, 3, 2, 1, 3, 2, 4, 2, 1], [2, 1, 2, 3, 1, 1, 4, 2, 4, 2, 2, 2, 2, 1, 2, 2, 3, 2, 1, 4, 4, 1, 1]]), tensor([[ 9, 10, 2, 12, 1, 2, 12, 8, 7, 15, 13, 10, 8, 7, 5, 11, 3, 7, 12, 2, 5, 6, 16], [ 8, 11, 10, 9, 10, 9, 3, 6, 11, 13, 3, 11, 13, 10, 2, 15, 10, 1, 2, 12, 2, 5, 16], [ 7, 15, 10, 2, 12, 2, 14, 5, 4, 1, 1, 1, 8, 11, 3, 6, 6, 11, 13, 3, 4, 1, 17], [ 1, 2, 12, 1, 9, 10, 8, 6, 7, 14, 5, 6, 6, 6, 7, 14, 15, 10, 2, 15, 13, 3, 16], [ 1, 1, 9, 3, 6, 6, 11, 20, 14, 14, 15, 13, 10, 9, 10, 1, 9, 13, 10, 9, 3, 7, 19], [ 1, 8, 6, 4, 8, 11, 13, 3, 4, 9, 13, 3, 7, 15, 3, 4, 9, 3, 7, 14, 5, 4, 17], [ 1, 1, 9, 3, 6, 4, 1, 1, 2, 12, 9, 10, 1, 9, 20, 5, 4, 2, 5, 11, 3, 4, 17], [ 4, 8, 7, 12, 1, 9, 3, 11, 3, 6, 6, 6, 4, 8, 6, 7, 5, 4, 9, 13, 10, 1, 17]]), tensor([[32, 20, 6, 7, 2, 6, 52, 16, 30, 41, 10, 47, 16, 14, 35, 5, 31, 23, 43, 3, 33, 68, 29], [39, 26, 42, 32, 42, 8, 15, 50, 45, 27, 12, 45, 10, 20, 51, 56, 22, 2, 6, 43, 3, 60, 29], [30, 56, 20, 6, 43, 34, 11, 4, 24, 1, 1, 44, 39, 5, 15, 21, 50, 45, 27, 38, 24, 55, 37], [ 2, 6, 7, 18, 32, 47, 25, 17, 40, 11, 33, 21, 21, 17, 40, 62, 56, 20, 51, 41, 27, 54, 29], [ 1, 18, 8, 15, 21, 50, 73, 67, 63, 62, 41, 10, 42, 32, 22, 18, 36, 10, 42, 8, 31, 65, 58], [44, 25, 9, 13, 39, 45, 27, 38, 19, 36, 27, 31, 30, 46, 38, 19, 8, 31, 40, 11, 4, 64, 37], [ 1, 18, 8, 15, 9, 24, 1, 2, 6, 59, 32, 22, 18, 70, 77, 4, 28, 3, 35, 5, 38, 64, 37], [13, 16, 23, 7, 18, 8, 12, 5, 15, 21, 21, 9, 13, 25, 17, 14, 4, 19, 36, 10, 22, 55, 37]]), tensor([[3, 1, 3, 2, 4, 2, 9, 2, 6, 8, 8, 3, 2, 1, 3, 1, 1, 3, 1, 4, 1, 7, 0], [1, 3, 1, 1, 4, 3, 6, 4, 8, 5, 8, 4, 2, 2, 4, 1, 4, 1, 4, 2, 1, 7, 0], [2, 4, 4, 2, 2, 2, 8, 2, 6, 6, 6, 1, 2, 3, 3, 1, 3, 1, 4, 3, 2, 7, 0], [4, 3, 1, 4, 3, 3, 5, 2, 5, 5, 9, 3, 2, 2, 1, 4, 1, 1, 3, 1, 1, 7, 0], [4, 1, 4, 4, 1, 1, 8, 1, 8, 8, 9, 3, 3, 4, 2, 2, 2, 4, 1, 1, 1, 7, 0], [3, 3, 2, 4, 1, 2, 8, 3, 5, 8, 8, 1, 2, 4, 4, 2, 1, 2, 2, 1, 1, 7, 0], [4, 2, 3, 1, 2, 3, 8, 1, 6, 8, 6, 3, 1, 1, 1, 2, 2, 4, 1, 1, 1, 7, 0], [4, 4, 1, 2, 3, 2, 5, 1, 5, 5, 5, 2, 4, 2, 4, 1, 1, 3, 2, 1, 2, 7, 0]]), tensor([[ 2, 13, 4, 5, 12, 47, 35, 39, 45, 25, 50, 4, 7, 13, 2, 1, 13, 2, 11, 10, 14, 8, 0], [13, 2, 1, 11, 17, 37, 31, 53, 46, 26, 49, 12, 9, 5, 10, 11, 10, 11, 12, 7, 14, 8, 0], [ 5, 16, 12, 9, 9, 19, 44, 39, 36, 36, 29, 3, 6, 15, 2, 13, 2, 11, 17, 4, 18, 8, 0], [17, 2, 11, 17, 15, 24, 27, 38, 28, 32, 63, 4, 9, 7, 11, 10, 1, 13, 2, 1, 14, 8, 0], [10, 11, 16, 10, 1, 34, 21, 34, 25, 55, 63, 15, 61, 12, 9, 9, 5, 10, 1, 1, 14, 8, 0], [15, 4, 5, 10, 3, 19, 50, 24, 26, 25, 21, 3, 5, 16, 12, 7, 3, 9, 7, 1, 14, 8, 0], [12, 6, 2, 3, 6, 65, 21, 41, 45, 62, 56, 2, 1, 1, 3, 9, 5, 10, 1, 1, 14, 8, 0], [16, 10, 3, 6, 4, 38, 54, 22, 28, 28, 27, 5, 12, 5, 10, 1, 13, 4, 7, 3, 18, 8, 0]]), tensor([[ 44, 13, 33, 7, 265, 183, 174, 266, 165, 166, 167, 10, 54, 6, 4, 8, 6, 34, 32, 48, 2, 1, 0], [ 6, 4, 14, 52, 136, 137, 138, 139, 82, 140, 120, 43, 15, 11, 47, 32, 47, 31, 37, 26, 2, 1, 0], [ 51, 59, 43, 46, 293, 294, 295, 296, 192, 297, 298, 5, 21, 38, 44, 6, 34, 52, 40, 49, 25, 1, 0], [ 53, 34, 52, 66, 218, 219, 141, 88, 142, 143, 144, 28, 19, 50, 32, 20, 8, 6, 4, 12, 2, 1, 0], [ 47, 23, 17, 20, 245, 102, 171, 75, 62, 246, 247, 69, 172, 43, 46, 15, 11, 20, 3, 12, 2, 1, 0], [ 36, 33, 11, 22, 78, 121, 122, 204, 123, 124, 79, 42, 51, 59, 37, 16, 24, 19, 35, 12, 2, 1, 0], [ 39, 9, 29, 5, 224, 225, 89, 90, 147, 226, 117, 4, 3, 30, 24, 15, 11, 20, 3, 12, 2, 1, 0], [ 17, 22, 5, 18, 157, 232, 134, 70, 158, 233, 159, 7, 27, 11, 20, 8, 13, 10, 16, 160, 25, 1, 0]]), tensor([[1, 8, 1, 4, 1, 5, 4, 5, 6, 4, 3, 3, 1, 6, 4, 6, 3, 2, 4, 1, 4, 2, 2], [1, 6, 3, 1, 3, 5, 3, 6, 6, 3, 3, 2, 3, 8, 1, 7, 3, 1, 1, 4, 1, 4, 2], [2, 7, 3, 1, 4, 5, 4, 7, 6, 1, 1, 1, 1, 6, 3, 6, 2, 2, 3, 3, 2, 1, 1], [1, 5, 4, 1, 1, 8, 1, 6, 6, 4, 4, 2, 2, 6, 2, 7, 4, 3, 1, 4, 3, 3, 2], [1, 5, 1, 3, 2, 6, 2, 8, 7, 4, 4, 3, 3, 5, 3, 5, 1, 3, 3, 1, 3, 2, 4], [1, 5, 2, 2, 1, 6, 3, 8, 6, 1, 3, 3, 2, 7, 3, 6, 1, 3, 2, 4, 4, 2, 1], [1, 5, 1, 3, 2, 6, 1, 5, 5, 4, 1, 3, 1, 5, 3, 7, 2, 1, 4, 2, 3, 2, 1], [2, 5, 2, 4, 1, 5, 3, 6, 8, 2, 2, 2, 2, 5, 2, 6, 4, 2, 1, 3, 3, 1, 1]]), tensor([[31, 23, 2, 10, 8, 18, 39, 58, 34, 22, 12, 11, 16, 34, 40, 36, 3, 7, 10, 2, 4, 6, 27], [16, 36, 11, 9, 38, 30, 24, 53, 36, 12, 3, 13, 41, 23, 29, 46, 11, 1, 2, 10, 2, 4, 27], [19, 46, 11, 2, 39, 18, 42, 57, 20, 1, 1, 1, 16, 36, 24, 25, 6, 13, 12, 3, 5, 1, 32], [ 8, 18, 10, 1, 31, 23, 16, 53, 34, 21, 4, 6, 35, 25, 19, 44, 22, 11, 2, 22, 12, 3, 27], [ 8, 14, 9, 3, 35, 25, 28, 64, 44, 21, 22, 12, 38, 30, 38, 14, 9, 12, 11, 9, 3, 7, 49], [ 8, 33, 6, 5, 16, 36, 41, 48, 20, 9, 12, 3, 19, 46, 24, 20, 9, 3, 7, 21, 4, 5, 32], [ 8, 14, 9, 3, 35, 20, 8, 56, 18, 10, 9, 11, 8, 30, 61, 17, 5, 2, 4, 13, 3, 5, 32], [26, 33, 7, 10, 8, 30, 24, 54, 37, 6, 6, 6, 26, 33, 35, 34, 4, 5, 9, 12, 11, 1, 32]]), tensor([[ 78, 82, 7, 90, 38, 76, 235, 171, 151, 32, 10, 160, 60, 161, 237, 119, 25, 33, 34, 5, 12, 131, 2], [ 67, 204, 48, 155, 156, 47, 205, 273, 102, 16, 35, 206, 69, 157, 207, 154, 122, 4, 7, 34, 5, 59, 2], [ 63, 154, 40, 106, 234, 240, 241, 174, 170, 1, 1, 343, 67, 116, 146, 117, 96, 91, 16, 39, 11, 31, 3], [ 38, 62, 13, 208, 78, 209, 274, 275, 158, 6, 12, 159, 93, 210, 211, 212, 200, 40, 121, 32, 16, 19, 2], [ 8, 73, 15, 115, 93, 147, 299, 300, 301, 197, 32, 137, 156, 173, 163, 73, 36, 10, 48, 15, 25, 97, 55], [ 66, 88, 23, 259, 67, 260, 142, 261, 68, 36, 16, 190, 63, 191, 111, 68, 15, 25, 52, 6, 14, 89, 3], [ 8, 73, 15, 115, 94, 279, 193, 194, 62, 280, 49, 28, 44, 281, 282, 42, 18, 5, 41, 9, 39, 89, 3], [ 92, 107, 33, 90, 44, 47, 169, 123, 104, 46, 46, 61, 92, 227, 99, 139, 14, 37, 36, 10, 122, 31, 3]])]
inputs: tensor([[1, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2], [1, 2, 4, 1, 4, 1, 4, 2, 2, 4, 4, 2, 4, 4, 1, 3, 4, 1, 1, 3, 1, 3, 2], [2, 3, 4, 1, 3, 1, 3, 3, 2, 1, 1, 1, 1, 2, 4, 2, 2, 2, 4, 4, 2, 1, 1], [1, 1, 3, 1, 1, 4, 1, 2, 2, 3, 3, 2, 2, 2, 2, 3, 3, 4, 1, 3, 4, 4, 2], [1, 1, 1, 4, 2, 2, 2, 4, 3, 3, 3, 4, 4, 1, 4, 1, 1, 4, 4, 1, 4, 2, 3], [1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2, 1, 4, 2, 3, 3, 2, 1], [1, 1, 1, 4, 2, 2, 1, 1, 1, 3, 1, 4, 1, 1, 4, 3, 2, 1, 3, 2, 4, 2, 1], [2, 1, 2, 3, 1, 1, 4, 2, 4, 2, 2, 2, 2, 1, 2, 2, 3, 2, 1, 4, 4, 1, 1]])
inputs: [tensor([[1, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2], [1, 2, 4, 1, 4, 1, 4, 2, 2, 4, 4, 2, 4, 4, 1, 3, 4, 1, 1, 3, 1, 3, 2], [2, 3, 4, 1, 3, 1, 3, 3, 2, 1, 1, 1, 1, 2, 4, 2, 2, 2, 4, 4, 2, 1, 1], [1, 1, 3, 1, 1, 4, 1, 2, 2, 3, 3, 2, 2, 2, 2, 3, 3, 4, 1, 3, 4, 4, 2], [1, 1, 1, 4, 2, 2, 2, 4, 3, 3, 3, 4, 4, 1, 4, 1, 1, 4, 4, 1, 4, 2, 3], [1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2, 1, 4, 2, 3, 3, 2, 1], [1, 1, 1, 4, 2, 2, 1, 1, 1, 3, 1, 4, 1, 1, 4, 3, 2, 1, 3, 2, 4, 2, 1], [2, 1, 2, 3, 1, 1, 4, 2, 4, 2, 2, 2, 2, 1, 2, 2, 3, 2, 1, 4, 4, 1, 1]], device='cuda:0'), tensor([[ 9, 10, 2, 12, 1, 2, 12, 8, 7, 15, 13, 10, 8, 7, 5, 11, 3, 7, 12, 2, 5, 6, 16], [ 8, 11, 10, 9, 10, 9, 3, 6, 11, 13, 3, 11, 13, 10, 2, 15, 10, 1, 2, 12, 2, 5, 16], [ 7, 15, 10, 2, 12, 2, 14, 5, 4, 1, 1, 1, 8, 11, 3, 6, 6, 11, 13, 3, 4, 1, 17], [ 1, 2, 12, 1, 9, 10, 8, 6, 7, 14, 5, 6, 6, 6, 7, 14, 15, 10, 2, 15, 13, 3, 16], [ 1, 1, 9, 3, 6, 6, 11, 20, 14, 14, 15, 13, 10, 9, 10, 1, 9, 13, 10, 9, 3, 7, 19], [ 1, 8, 6, 4, 8, 11, 13, 3, 4, 9, 13, 3, 7, 15, 3, 4, 9, 3, 7, 14, 5, 4, 17], [ 1, 1, 9, 3, 6, 4, 1, 1, 2, 12, 9, 10, 1, 9, 20, 5, 4, 2, 5, 11, 3, 4, 17], [ 4, 8, 7, 12, 1, 9, 3, 11, 3, 6, 6, 6, 4, 8, 6, 7, 5, 4, 9, 13, 10, 1, 17]], device='cuda:0'), tensor([[32, 20, 6, 7, 2, 6, 52, 16, 30, 41, 10, 47, 16, 14, 35, 5, 31, 23, 43, 3, 33, 68, 29], [39, 26, 42, 32, 42, 8, 15, 50, 45, 27, 12, 45, 10, 20, 51, 56, 22, 2, 6, 43, 3, 60, 29], [30, 56, 20, 6, 43, 34, 11, 4, 24, 1, 1, 44, 39, 5, 15, 21, 50, 45, 27, 38, 24, 55, 37], [ 2, 6, 7, 18, 32, 47, 25, 17, 40, 11, 33, 21, 21, 17, 40, 62, 56, 20, 51, 41, 27, 54, 29], [ 1, 18, 8, 15, 21, 50, 73, 67, 63, 62, 41, 10, 42, 32, 22, 18, 36, 10, 42, 8, 31, 65, 58], [44, 25, 9, 13, 39, 45, 27, 38, 19, 36, 27, 31, 30, 46, 38, 19, 8, 31, 40, 11, 4, 64, 37], [ 1, 18, 8, 15, 9, 24, 1, 2, 6, 59, 32, 22, 18, 70, 77, 4, 28, 3, 35, 5, 38, 64, 37], [13, 16, 23, 7, 18, 8, 12, 5, 15, 21, 21, 9, 13, 25, 17, 14, 4, 19, 36, 10, 22, 55, 37]], device='cuda:0'), tensor([[3, 1, 3, 2, 4, 2, 9, 2, 6, 8, 8, 3, 2, 1, 3, 1, 1, 3, 1, 4, 1, 7, 0], [1, 3, 1, 1, 4, 3, 6, 4, 8, 5, 8, 4, 2, 2, 4, 1, 4, 1, 4, 2, 1, 7, 0], [2, 4, 4, 2, 2, 2, 8, 2, 6, 6, 6, 1, 2, 3, 3, 1, 3, 1, 4, 3, 2, 7, 0], [4, 3, 1, 4, 3, 3, 5, 2, 5, 5, 9, 3, 2, 2, 1, 4, 1, 1, 3, 1, 1, 7, 0], [4, 1, 4, 4, 1, 1, 8, 1, 8, 8, 9, 3, 3, 4, 2, 2, 2, 4, 1, 1, 1, 7, 0], [3, 3, 2, 4, 1, 2, 8, 3, 5, 8, 8, 1, 2, 4, 4, 2, 1, 2, 2, 1, 1, 7, 0], [4, 2, 3, 1, 2, 3, 8, 1, 6, 8, 6, 3, 1, 1, 1, 2, 2, 4, 1, 1, 1, 7, 0], [4, 4, 1, 2, 3, 2, 5, 1, 5, 5, 5, 2, 4, 2, 4, 1, 1, 3, 2, 1, 2, 7, 0]], device='cuda:0'), tensor([[ 2, 13, 4, 5, 12, 47, 35, 39, 45, 25, 50, 4, 7, 13, 2, 1, 13, 2, 11, 10, 14, 8, 0], [13, 2, 1, 11, 17, 37, 31, 53, 46, 26, 49, 12, 9, 5, 10, 11, 10, 11, 12, 7, 14, 8, 0], [ 5, 16, 12, 9, 9, 19, 44, 39, 36, 36, 29, 3, 6, 15, 2, 13, 2, 11, 17, 4, 18, 8, 0], [17, 2, 11, 17, 15, 24, 27, 38, 28, 32, 63, 4, 9, 7, 11, 10, 1, 13, 2, 1, 14, 8, 0], [10, 11, 16, 10, 1, 34, 21, 34, 25, 55, 63, 15, 61, 12, 9, 9, 5, 10, 1, 1, 14, 8, 0], [15, 4, 5, 10, 3, 19, 50, 24, 26, 25, 21, 3, 5, 16, 12, 7, 3, 9, 7, 1, 14, 8, 0], [12, 6, 2, 3, 6, 65, 21, 41, 45, 62, 56, 2, 1, 1, 3, 9, 5, 10, 1, 1, 14, 8, 0], [16, 10, 3, 6, 4, 38, 54, 22, 28, 28, 27, 5, 12, 5, 10, 1, 13, 4, 7, 3, 18, 8, 0]], device='cuda:0'), tensor([[ 44, 13, 33, 7, 265, 183, 174, 266, 165, 166, 167, 10, 54, 6, 4, 8, 6, 34, 32, 48, 2, 1, 0], [ 6, 4, 14, 52, 136, 137, 138, 139, 82, 140, 120, 43, 15, 11, 47, 32, 47, 31, 37, 26, 2, 1, 0], [ 51, 59, 43, 46, 293, 294, 295, 296, 192, 297, 298, 5, 21, 38, 44, 6, 34, 52, 40, 49, 25, 1, 0], [ 53, 34, 52, 66, 218, 219, 141, 88, 142, 143, 144, 28, 19, 50, 32, 20, 8, 6, 4, 12, 2, 1, 0], [ 47, 23, 17, 20, 245, 102, 171, 75, 62, 246, 247, 69, 172, 43, 46, 15, 11, 20, 3, 12, 2, 1, 0], [ 36, 33, 11, 22, 78, 121, 122, 204, 123, 124, 79, 42, 51, 59, 37, 16, 24, 19, 35, 12, 2, 1, 0], [ 39, 9, 29, 5, 224, 225, 89, 90, 147, 226, 117, 4, 3, 30, 24, 15, 11, 20, 3, 12, 2, 1, 0], [ 17, 22, 5, 18, 157, 232, 134, 70, 158, 233, 159, 7, 27, 11, 20, 8, 13, 10, 16, 160, 25, 1, 0]], device='cuda:0'), tensor([[1, 8, 1, 4, 1, 5, 4, 5, 6, 4, 3, 3, 1, 6, 4, 6, 3, 2, 4, 1, 4, 2, 2], [1, 6, 3, 1, 3, 5, 3, 6, 6, 3, 3, 2, 3, 8, 1, 7, 3, 1, 1, 4, 1, 4, 2], [2, 7, 3, 1, 4, 5, 4, 7, 6, 1, 1, 1, 1, 6, 3, 6, 2, 2, 3, 3, 2, 1, 1], [1, 5, 4, 1, 1, 8, 1, 6, 6, 4, 4, 2, 2, 6, 2, 7, 4, 3, 1, 4, 3, 3, 2], [1, 5, 1, 3, 2, 6, 2, 8, 7, 4, 4, 3, 3, 5, 3, 5, 1, 3, 3, 1, 3, 2, 4], [1, 5, 2, 2, 1, 6, 3, 8, 6, 1, 3, 3, 2, 7, 3, 6, 1, 3, 2, 4, 4, 2, 1], [1, 5, 1, 3, 2, 6, 1, 5, 5, 4, 1, 3, 1, 5, 3, 7, 2, 1, 4, 2, 3, 2, 1], [2, 5, 2, 4, 1, 5, 3, 6, 8, 2, 2, 2, 2, 5, 2, 6, 4, 2, 1, 3, 3, 1, 1]], device='cuda:0'), tensor([[31, 23, 2, 10, 8, 18, 39, 58, 34, 22, 12, 11, 16, 34, 40, 36, 3, 7, 10, 2, 4, 6, 27], [16, 36, 11, 9, 38, 30, 24, 53, 36, 12, 3, 13, 41, 23, 29, 46, 11, 1, 2, 10, 2, 4, 27], [19, 46, 11, 2, 39, 18, 42, 57, 20, 1, 1, 1, 16, 36, 24, 25, 6, 13, 12, 3, 5, 1, 32], [ 8, 18, 10, 1, 31, 23, 16, 53, 34, 21, 4, 6, 35, 25, 19, 44, 22, 11, 2, 22, 12, 3, 27], [ 8, 14, 9, 3, 35, 25, 28, 64, 44, 21, 22, 12, 38, 30, 38, 14, 9, 12, 11, 9, 3, 7, 49], [ 8, 33, 6, 5, 16, 36, 41, 48, 20, 9, 12, 3, 19, 46, 24, 20, 9, 3, 7, 21, 4, 5, 32], [ 8, 14, 9, 3, 35, 20, 8, 56, 18, 10, 9, 11, 8, 30, 61, 17, 5, 2, 4, 13, 3, 5, 32], [26, 33, 7, 10, 8, 30, 24, 54, 37, 6, 6, 6, 26, 33, 35, 34, 4, 5, 9, 12, 11, 1, 32]], device='cuda:0'), tensor([[ 78, 82, 7, 90, 38, 76, 235, 171, 151, 32, 10, 160, 60, 161, 237, 119, 25, 33, 34, 5, 12, 131, 2], [ 67, 204, 48, 155, 156, 47, 205, 273, 102, 16, 35, 206, 69, 157, 207, 154, 122, 4, 7, 34, 5, 59, 2], [ 63, 154, 40, 106, 234, 240, 241, 174, 170, 1, 1, 343, 67, 116, 146, 117, 96, 91, 16, 39, 11, 31, 3], [ 38, 62, 13, 208, 78, 209, 274, 275, 158, 6, 12, 159, 93, 210, 211, 212, 200, 40, 121, 32, 16, 19, 2], [ 8, 73, 15, 115, 93, 147, 299, 300, 301, 197, 32, 137, 156, 173, 163, 73, 36, 10, 48, 15, 25, 97, 55], [ 66, 88, 23, 259, 67, 260, 142, 261, 68, 36, 16, 190, 63, 191, 111, 68, 15, 25, 52, 6, 14, 89, 3], [ 8, 73, 15, 115, 94, 279, 193, 194, 62, 280, 49, 28, 44, 281, 282, 42, 18, 5, 41, 9, 39, 89, 3], [ 92, 107, 33, 90, 44, 47, 169, 123, 104, 46, 46, 61, 92, 227, 99, 139, 14, 37, 36, 10, 122, 31, 3]], device='cuda:0')]
inputs: tensor([[1, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2], [1, 2, 4, 1, 4, 1, 4, 2, 2, 4, 4, 2, 4, 4, 1, 3, 4, 1, 1, 3, 1, 3, 2], [2, 3, 4, 1, 3, 1, 3, 3, 2, 1, 1, 1, 1, 2, 4, 2, 2, 2, 4, 4, 2, 1, 1], [1, 1, 3, 1, 1, 4, 1, 2, 2, 3, 3, 2, 2, 2, 2, 3, 3, 4, 1, 3, 4, 4, 2], [1, 1, 1, 4, 2, 2, 2, 4, 3, 3, 3, 4, 4, 1, 4, 1, 1, 4, 4, 1, 4, 2, 3], [1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2, 1, 4, 2, 3, 3, 2, 1], [1, 1, 1, 4, 2, 2, 1, 1, 1, 3, 1, 4, 1, 1, 4, 3, 2, 1, 3, 2, 4, 2, 1], [2, 1, 2, 3, 1, 1, 4, 2, 4, 2, 2, 2, 2, 1, 2, 2, 3, 2, 1, 4, 4, 1, 1]], device='cuda:0')
inputs: [tensor([[2, 2, 1, 3, 4, 2, 2, 3, 3, 2, 2, 4, 2, 4, 1, 2, 2, 3, 3, 2, 2, 3, 2], [1, 2, 2, 3, 2, 1, 2, 4, 1, 3, 2, 1, 3, 3, 4, 2, 4, 2, 3, 3, 4, 4, 2], [1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2], [1, 4, 2, 4, 1, 1, 4, 4, 2, 2, 1, 1, 2, 4, 1, 4, 3, 3, 1, 1, 1, 1, 1], [1, 4, 4, 2, 2, 4, 2, 2, 4, 1, 3, 1, 2, 2, 1, 4, 1, 1, 2, 1, 1, 4, 1], [1, 2, 3, 4, 2, 1, 4, 2, 3, 3, 2, 1, 2, 1, 2, 3, 4, 1, 3, 1, 3, 3, 2], [1, 2, 1, 1, 4, 1, 3, 1, 4, 4, 4, 2, 2, 2, 2, 1, 3, 4, 2, 2, 3, 3, 2], [3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2, 4, 1, 4]]), tensor([[ 6, 4, 2, 15, 3, 6, 7, 14, 5, 6, 11, 3, 11, 10, 8, 6, 7, 14, 5, 6, 7, 5, 16], [ 8, 6, 7, 5, 4, 8, 11, 10, 2, 5, 4, 2, 14, 15, 3, 11, 3, 7, 14, 15, 13, 3, 16], [ 2, 5, 7, 12, 1, 1, 2, 5, 4, 9, 3, 11, 13, 10, 9, 3, 4, 8, 6, 7, 15, 3, 16], [ 9, 3, 11, 10, 1, 9, 13, 3, 6, 4, 1, 8, 11, 10, 9, 20, 14, 12, 1, 1, 1, 1, 17], [ 9, 13, 3, 6, 11, 3, 6, 11, 10, 2, 12, 8, 6, 4, 9, 10, 1, 8, 4, 1, 9, 10, 17], [ 8, 7, 15, 3, 4, 9, 3, 7, 14, 5, 4, 8, 4, 8, 7, 15, 10, 2, 12, 2, 14, 5, 16], [ 8, 4, 1, 9, 10, 2, 12, 9, 13, 13, 3, 6, 6, 6, 4, 2, 15, 3, 6, 7, 14, 5, 16], [12, 1, 1, 2, 5, 4, 9, 3, 11, 13, 10, 9, 3, 4, 8, 6, 7, 15, 3, 11, 10, 9, 18]]), tensor([[ 9, 28, 51, 46, 15, 17, 40, 11, 33, 50, 5, 12, 26, 47, 25, 17, 40, 11, 33, 17, 14, 60, 29], [25, 17, 14, 4, 13, 39, 26, 20, 3, 4, 28, 34, 62, 46, 12, 5, 31, 40, 62, 41, 27, 54, 29], [ 3, 49, 23, 7, 1, 2, 3, 4, 19, 8, 12, 45, 10, 42, 8, 38, 13, 25, 17, 30, 46, 54, 29], [ 8, 12, 26, 22, 18, 36, 27, 15, 9, 24, 44, 39, 26, 42, 70, 67, 57, 7, 1, 1, 1, 55, 37], [36, 27, 15, 50, 5, 15, 50, 26, 20, 6, 52, 25, 9, 19, 32, 22, 44, 48, 24, 18, 32, 69, 37], [16, 30, 46, 38, 19, 8, 31, 40, 11, 4, 13, 48, 13, 16, 30, 56, 20, 6, 43, 34, 11, 60, 29], [48, 24, 18, 32, 20, 6, 59, 36, 71, 27, 15, 21, 21, 9, 28, 51, 46, 15, 17, 40, 11, 60, 29], [ 7, 1, 2, 3, 4, 19, 8, 12, 45, 10, 42, 8, 38, 13, 25, 17, 30, 46, 12, 26, 42, 61, 53]]), tensor([[2, 2, 3, 3, 2, 2, 6, 4, 5, 8, 5, 2, 3, 3, 2, 2, 4, 3, 1, 2, 2, 7, 0], [4, 3, 3, 2, 4, 2, 8, 3, 9, 6, 5, 3, 1, 4, 2, 1, 2, 3, 2, 2, 1, 7, 0], [3, 2, 2, 1, 2, 4, 6, 4, 8, 5, 8, 1, 2, 3, 1, 1, 1, 3, 2, 3, 1, 7, 0], [1, 1, 1, 3, 3, 4, 6, 4, 5, 6, 6, 2, 2, 4, 4, 1, 1, 4, 2, 4, 1, 7, 0], [1, 1, 2, 1, 1, 4, 6, 2, 5, 6, 9, 1, 4, 2, 2, 4, 2, 2, 4, 4, 1, 7, 0], [3, 1, 3, 1, 4, 3, 5, 1, 5, 6, 5, 3, 3, 2, 4, 1, 2, 4, 3, 2, 1, 7, 0], [3, 2, 2, 4, 3, 1, 5, 2, 5, 5, 8, 4, 4, 1, 3, 1, 4, 1, 1, 2, 1, 7, 0], [4, 2, 4, 3, 2, 2, 6, 2, 8, 6, 8, 4, 2, 4, 1, 2, 3, 1, 1, 1, 3, 7, 0]]), tensor([[ 9, 6, 15, 4, 9, 39, 31, 33, 26, 46, 27, 6, 15, 4, 9, 5, 17, 2, 3, 9, 18, 8, 0], [17, 15, 4, 5, 12, 19, 50, 57, 20, 52, 23, 2, 11, 12, 7, 3, 6, 4, 9, 7, 14, 8, 0], [ 4, 9, 7, 3, 5, 58, 31, 53, 46, 26, 21, 3, 6, 2, 1, 1, 13, 4, 6, 2, 14, 8, 0], [ 1, 1, 13, 15, 61, 58, 31, 33, 43, 36, 40, 9, 5, 16, 10, 1, 11, 12, 5, 10, 14, 8, 0], [ 1, 3, 7, 1, 11, 58, 40, 38, 43, 42, 30, 11, 12, 9, 5, 12, 9, 5, 16, 10, 14, 8, 0], [ 2, 13, 2, 11, 17, 24, 54, 22, 43, 52, 23, 15, 4, 5, 10, 3, 5, 17, 4, 7, 14, 8, 0], [ 4, 9, 5, 17, 2, 22, 27, 38, 28, 26, 49, 16, 10, 13, 2, 11, 10, 1, 3, 7, 14, 8, 0], [12, 5, 17, 4, 9, 39, 40, 19, 62, 45, 49, 12, 5, 10, 3, 6, 2, 1, 1, 13, 64, 8, 0]]), tensor([[ 56, 21, 36, 28, 118, 291, 67, 149, 150, 292, 186, 21, 36, 28, 15, 41, 53, 29, 24, 93, 25, 1, 0], [ 66, 36, 33, 7, 104, 121, 211, 86, 129, 131, 81, 34, 31, 37, 16, 5, 18, 28, 19, 26, 2, 1, 0], [ 28, 19, 16, 42, 252, 94, 138, 139, 82, 173, 79, 5, 9, 4, 3, 8, 13, 55, 9, 65, 2, 1, 0], [ 3, 8, 68, 69, 154, 94, 67, 95, 155, 96, 97, 15, 51, 17, 20, 14, 31, 27, 11, 48, 2, 1, 0], [ 30, 45, 35, 14, 272, 273, 168, 184, 162, 100, 60, 31, 43, 15, 7, 43, 15, 51, 17, 48, 2, 1, 0], [ 44, 6, 34, 52, 114, 74, 134, 217, 87, 131, 135, 36, 33, 11, 22, 42, 41, 40, 10, 26, 2, 1, 0], [ 28, 15, 41, 53, 299, 300, 141, 88, 103, 140, 301, 17, 61, 6, 34, 32, 20, 30, 45, 26, 2, 1, 0], [ 27, 41, 40, 28, 118, 77, 199, 200, 201, 119, 120, 27, 11, 22, 5, 9, 4, 3, 8, 202, 203, 1, 0]]), tensor([[2, 6, 1, 4, 3, 6, 2, 7, 7, 2, 2, 3, 2, 8, 1, 6, 2, 4, 4, 2, 2, 4, 2], [1, 6, 2, 4, 2, 5, 2, 8, 5, 4, 2, 1, 4, 7, 3, 6, 3, 2, 4, 4, 3, 3, 2], [1, 7, 2, 4, 1, 5, 1, 7, 6, 1, 3, 2, 3, 8, 1, 8, 2, 1, 2, 2, 4, 3, 2], [1, 8, 2, 3, 1, 5, 3, 8, 6, 2, 1, 1, 2, 8, 1, 8, 4, 4, 1, 1, 1, 1, 1], [1, 8, 3, 2, 2, 8, 2, 6, 8, 1, 4, 1, 2, 6, 1, 8, 1, 1, 2, 1, 1, 3, 1], [1, 6, 4, 3, 2, 5, 3, 6, 7, 4, 2, 1, 2, 5, 2, 7, 3, 1, 4, 1, 4, 4, 2], [1, 6, 1, 1, 3, 5, 4, 5, 8, 3, 3, 2, 2, 6, 2, 5, 4, 3, 2, 2, 4, 4, 2], [4, 5, 1, 1, 4, 6, 1, 8, 6, 3, 3, 1, 3, 6, 1, 6, 2, 4, 3, 2, 3, 1, 3]]), tensor([[35, 20, 2, 22, 24, 25, 19, 65, 17, 6, 13, 3, 28, 23, 16, 25, 7, 21, 4, 6, 7, 4, 27], [16, 25, 7, 4, 26, 33, 28, 52, 18, 4, 5, 2, 42, 46, 24, 36, 3, 7, 21, 22, 12, 3, 27], [29, 17, 7, 10, 8, 14, 29, 57, 20, 9, 3, 13, 41, 23, 31, 37, 5, 15, 6, 7, 22, 3, 27], [31, 37, 13, 11, 8, 30, 41, 48, 25, 5, 1, 15, 28, 23, 31, 63, 21, 10, 1, 1, 1, 1, 32], [31, 47, 3, 6, 28, 37, 35, 54, 23, 2, 10, 15, 35, 20, 31, 23, 1, 15, 5, 1, 9, 11, 32], [16, 34, 22, 3, 26, 30, 24, 50, 44, 4, 5, 15, 26, 33, 19, 46, 11, 2, 10, 2, 21, 4, 27], [16, 20, 1, 9, 38, 18, 39, 66, 47, 12, 3, 6, 35, 25, 26, 18, 22, 3, 6, 7, 21, 4, 27], [39, 14, 1, 2, 40, 20, 31, 48, 36, 12, 11, 9, 24, 20, 16, 25, 7, 22, 3, 13, 11, 9, 43]]), tensor([[ 94, 118, 121, 340, 146, 210, 341, 342, 175, 96, 9, 217, 43, 209, 51, 87, 52, 6, 12, 20, 26, 59, 2], [ 51, 87, 26, 101, 92, 268, 149, 185, 70, 14, 18, 72, 269, 191, 57, 119, 25, 52, 197, 32, 16, 19, 2], [ 21, 236, 33, 90, 8, 29, 196, 174, 68, 15, 35, 206, 69, 120, 74, 302, 22, 75, 20, 24, 27, 19, 2], [ 74, 165, 113, 28, 44, 85, 142, 166, 54, 11, 126, 53, 43, 120, 221, 222, 223, 13, 1, 1, 1, 31, 3], [136, 321, 56, 231, 83, 150, 322, 323, 82, 7, 100, 249, 94, 186, 78, 133, 126, 124, 11, 45, 49, 144, 3], [ 60, 151, 27, 152, 80, 47, 201, 202, 153, 14, 22, 203, 92, 108, 63, 154, 40, 7, 34, 50, 6, 59, 2], [ 71, 170, 45, 155, 98, 76, 344, 345, 346, 16, 56, 159, 93, 347, 77, 145, 27, 56, 20, 52, 6, 59, 2], [ 58, 79, 4, 110, 64, 186, 187, 188, 102, 10, 48, 189, 111, 112, 51, 87, 24, 27, 35, 113, 48, 65, 17]])]
inputs: tensor([[2, 2, 1, 3, 4, 2, 2, 3, 3, 2, 2, 4, 2, 4, 1, 2, 2, 3, 3, 2, 2, 3, 2], [1, 2, 2, 3, 2, 1, 2, 4, 1, 3, 2, 1, 3, 3, 4, 2, 4, 2, 3, 3, 4, 4, 2], [1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2], [1, 4, 2, 4, 1, 1, 4, 4, 2, 2, 1, 1, 2, 4, 1, 4, 3, 3, 1, 1, 1, 1, 1], [1, 4, 4, 2, 2, 4, 2, 2, 4, 1, 3, 1, 2, 2, 1, 4, 1, 1, 2, 1, 1, 4, 1], [1, 2, 3, 4, 2, 1, 4, 2, 3, 3, 2, 1, 2, 1, 2, 3, 4, 1, 3, 1, 3, 3, 2], [1, 2, 1, 1, 4, 1, 3, 1, 4, 4, 4, 2, 2, 2, 2, 1, 3, 4, 2, 2, 3, 3, 2], [3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2, 4, 1, 4]])
inputs: [tensor([[2, 2, 1, 3, 4, 2, 2, 3, 3, 2, 2, 4, 2, 4, 1, 2, 2, 3, 3, 2, 2, 3, 2], [1, 2, 2, 3, 2, 1, 2, 4, 1, 3, 2, 1, 3, 3, 4, 2, 4, 2, 3, 3, 4, 4, 2], [1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2], [1, 4, 2, 4, 1, 1, 4, 4, 2, 2, 1, 1, 2, 4, 1, 4, 3, 3, 1, 1, 1, 1, 1], [1, 4, 4, 2, 2, 4, 2, 2, 4, 1, 3, 1, 2, 2, 1, 4, 1, 1, 2, 1, 1, 4, 1], [1, 2, 3, 4, 2, 1, 4, 2, 3, 3, 2, 1, 2, 1, 2, 3, 4, 1, 3, 1, 3, 3, 2], [1, 2, 1, 1, 4, 1, 3, 1, 4, 4, 4, 2, 2, 2, 2, 1, 3, 4, 2, 2, 3, 3, 2], [3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2, 4, 1, 4]], device='cuda:0'), tensor([[ 6, 4, 2, 15, 3, 6, 7, 14, 5, 6, 11, 3, 11, 10, 8, 6, 7, 14, 5, 6, 7, 5, 16], [ 8, 6, 7, 5, 4, 8, 11, 10, 2, 5, 4, 2, 14, 15, 3, 11, 3, 7, 14, 15, 13, 3, 16], [ 2, 5, 7, 12, 1, 1, 2, 5, 4, 9, 3, 11, 13, 10, 9, 3, 4, 8, 6, 7, 15, 3, 16], [ 9, 3, 11, 10, 1, 9, 13, 3, 6, 4, 1, 8, 11, 10, 9, 20, 14, 12, 1, 1, 1, 1, 17], [ 9, 13, 3, 6, 11, 3, 6, 11, 10, 2, 12, 8, 6, 4, 9, 10, 1, 8, 4, 1, 9, 10, 17], [ 8, 7, 15, 3, 4, 9, 3, 7, 14, 5, 4, 8, 4, 8, 7, 15, 10, 2, 12, 2, 14, 5, 16], [ 8, 4, 1, 9, 10, 2, 12, 9, 13, 13, 3, 6, 6, 6, 4, 2, 15, 3, 6, 7, 14, 5, 16], [12, 1, 1, 2, 5, 4, 9, 3, 11, 13, 10, 9, 3, 4, 8, 6, 7, 15, 3, 11, 10, 9, 18]], device='cuda:0'), tensor([[ 9, 28, 51, 46, 15, 17, 40, 11, 33, 50, 5, 12, 26, 47, 25, 17, 40, 11, 33, 17, 14, 60, 29], [25, 17, 14, 4, 13, 39, 26, 20, 3, 4, 28, 34, 62, 46, 12, 5, 31, 40, 62, 41, 27, 54, 29], [ 3, 49, 23, 7, 1, 2, 3, 4, 19, 8, 12, 45, 10, 42, 8, 38, 13, 25, 17, 30, 46, 54, 29], [ 8, 12, 26, 22, 18, 36, 27, 15, 9, 24, 44, 39, 26, 42, 70, 67, 57, 7, 1, 1, 1, 55, 37], [36, 27, 15, 50, 5, 15, 50, 26, 20, 6, 52, 25, 9, 19, 32, 22, 44, 48, 24, 18, 32, 69, 37], [16, 30, 46, 38, 19, 8, 31, 40, 11, 4, 13, 48, 13, 16, 30, 56, 20, 6, 43, 34, 11, 60, 29], [48, 24, 18, 32, 20, 6, 59, 36, 71, 27, 15, 21, 21, 9, 28, 51, 46, 15, 17, 40, 11, 60, 29], [ 7, 1, 2, 3, 4, 19, 8, 12, 45, 10, 42, 8, 38, 13, 25, 17, 30, 46, 12, 26, 42, 61, 53]], device='cuda:0'), tensor([[2, 2, 3, 3, 2, 2, 6, 4, 5, 8, 5, 2, 3, 3, 2, 2, 4, 3, 1, 2, 2, 7, 0], [4, 3, 3, 2, 4, 2, 8, 3, 9, 6, 5, 3, 1, 4, 2, 1, 2, 3, 2, 2, 1, 7, 0], [3, 2, 2, 1, 2, 4, 6, 4, 8, 5, 8, 1, 2, 3, 1, 1, 1, 3, 2, 3, 1, 7, 0], [1, 1, 1, 3, 3, 4, 6, 4, 5, 6, 6, 2, 2, 4, 4, 1, 1, 4, 2, 4, 1, 7, 0], [1, 1, 2, 1, 1, 4, 6, 2, 5, 6, 9, 1, 4, 2, 2, 4, 2, 2, 4, 4, 1, 7, 0], [3, 1, 3, 1, 4, 3, 5, 1, 5, 6, 5, 3, 3, 2, 4, 1, 2, 4, 3, 2, 1, 7, 0], [3, 2, 2, 4, 3, 1, 5, 2, 5, 5, 8, 4, 4, 1, 3, 1, 4, 1, 1, 2, 1, 7, 0], [4, 2, 4, 3, 2, 2, 6, 2, 8, 6, 8, 4, 2, 4, 1, 2, 3, 1, 1, 1, 3, 7, 0]], device='cuda:0'), tensor([[ 9, 6, 15, 4, 9, 39, 31, 33, 26, 46, 27, 6, 15, 4, 9, 5, 17, 2, 3, 9, 18, 8, 0], [17, 15, 4, 5, 12, 19, 50, 57, 20, 52, 23, 2, 11, 12, 7, 3, 6, 4, 9, 7, 14, 8, 0], [ 4, 9, 7, 3, 5, 58, 31, 53, 46, 26, 21, 3, 6, 2, 1, 1, 13, 4, 6, 2, 14, 8, 0], [ 1, 1, 13, 15, 61, 58, 31, 33, 43, 36, 40, 9, 5, 16, 10, 1, 11, 12, 5, 10, 14, 8, 0], [ 1, 3, 7, 1, 11, 58, 40, 38, 43, 42, 30, 11, 12, 9, 5, 12, 9, 5, 16, 10, 14, 8, 0], [ 2, 13, 2, 11, 17, 24, 54, 22, 43, 52, 23, 15, 4, 5, 10, 3, 5, 17, 4, 7, 14, 8, 0], [ 4, 9, 5, 17, 2, 22, 27, 38, 28, 26, 49, 16, 10, 13, 2, 11, 10, 1, 3, 7, 14, 8, 0], [12, 5, 17, 4, 9, 39, 40, 19, 62, 45, 49, 12, 5, 10, 3, 6, 2, 1, 1, 13, 64, 8, 0]], device='cuda:0'), tensor([[ 56, 21, 36, 28, 118, 291, 67, 149, 150, 292, 186, 21, 36, 28, 15, 41, 53, 29, 24, 93, 25, 1, 0], [ 66, 36, 33, 7, 104, 121, 211, 86, 129, 131, 81, 34, 31, 37, 16, 5, 18, 28, 19, 26, 2, 1, 0], [ 28, 19, 16, 42, 252, 94, 138, 139, 82, 173, 79, 5, 9, 4, 3, 8, 13, 55, 9, 65, 2, 1, 0], [ 3, 8, 68, 69, 154, 94, 67, 95, 155, 96, 97, 15, 51, 17, 20, 14, 31, 27, 11, 48, 2, 1, 0], [ 30, 45, 35, 14, 272, 273, 168, 184, 162, 100, 60, 31, 43, 15, 7, 43, 15, 51, 17, 48, 2, 1, 0], [ 44, 6, 34, 52, 114, 74, 134, 217, 87, 131, 135, 36, 33, 11, 22, 42, 41, 40, 10, 26, 2, 1, 0], [ 28, 15, 41, 53, 299, 300, 141, 88, 103, 140, 301, 17, 61, 6, 34, 32, 20, 30, 45, 26, 2, 1, 0], [ 27, 41, 40, 28, 118, 77, 199, 200, 201, 119, 120, 27, 11, 22, 5, 9, 4, 3, 8, 202, 203, 1, 0]], device='cuda:0'), tensor([[2, 6, 1, 4, 3, 6, 2, 7, 7, 2, 2, 3, 2, 8, 1, 6, 2, 4, 4, 2, 2, 4, 2], [1, 6, 2, 4, 2, 5, 2, 8, 5, 4, 2, 1, 4, 7, 3, 6, 3, 2, 4, 4, 3, 3, 2], [1, 7, 2, 4, 1, 5, 1, 7, 6, 1, 3, 2, 3, 8, 1, 8, 2, 1, 2, 2, 4, 3, 2], [1, 8, 2, 3, 1, 5, 3, 8, 6, 2, 1, 1, 2, 8, 1, 8, 4, 4, 1, 1, 1, 1, 1], [1, 8, 3, 2, 2, 8, 2, 6, 8, 1, 4, 1, 2, 6, 1, 8, 1, 1, 2, 1, 1, 3, 1], [1, 6, 4, 3, 2, 5, 3, 6, 7, 4, 2, 1, 2, 5, 2, 7, 3, 1, 4, 1, 4, 4, 2], [1, 6, 1, 1, 3, 5, 4, 5, 8, 3, 3, 2, 2, 6, 2, 5, 4, 3, 2, 2, 4, 4, 2], [4, 5, 1, 1, 4, 6, 1, 8, 6, 3, 3, 1, 3, 6, 1, 6, 2, 4, 3, 2, 3, 1, 3]], device='cuda:0'), tensor([[35, 20, 2, 22, 24, 25, 19, 65, 17, 6, 13, 3, 28, 23, 16, 25, 7, 21, 4, 6, 7, 4, 27], [16, 25, 7, 4, 26, 33, 28, 52, 18, 4, 5, 2, 42, 46, 24, 36, 3, 7, 21, 22, 12, 3, 27], [29, 17, 7, 10, 8, 14, 29, 57, 20, 9, 3, 13, 41, 23, 31, 37, 5, 15, 6, 7, 22, 3, 27], [31, 37, 13, 11, 8, 30, 41, 48, 25, 5, 1, 15, 28, 23, 31, 63, 21, 10, 1, 1, 1, 1, 32], [31, 47, 3, 6, 28, 37, 35, 54, 23, 2, 10, 15, 35, 20, 31, 23, 1, 15, 5, 1, 9, 11, 32], [16, 34, 22, 3, 26, 30, 24, 50, 44, 4, 5, 15, 26, 33, 19, 46, 11, 2, 10, 2, 21, 4, 27], [16, 20, 1, 9, 38, 18, 39, 66, 47, 12, 3, 6, 35, 25, 26, 18, 22, 3, 6, 7, 21, 4, 27], [39, 14, 1, 2, 40, 20, 31, 48, 36, 12, 11, 9, 24, 20, 16, 25, 7, 22, 3, 13, 11, 9, 43]], device='cuda:0'), tensor([[ 94, 118, 121, 340, 146, 210, 341, 342, 175, 96, 9, 217, 43, 209, 51, 87, 52, 6, 12, 20, 26, 59, 2], [ 51, 87, 26, 101, 92, 268, 149, 185, 70, 14, 18, 72, 269, 191, 57, 119, 25, 52, 197, 32, 16, 19, 2], [ 21, 236, 33, 90, 8, 29, 196, 174, 68, 15, 35, 206, 69, 120, 74, 302, 22, 75, 20, 24, 27, 19, 2], [ 74, 165, 113, 28, 44, 85, 142, 166, 54, 11, 126, 53, 43, 120, 221, 222, 223, 13, 1, 1, 1, 31, 3], [136, 321, 56, 231, 83, 150, 322, 323, 82, 7, 100, 249, 94, 186, 78, 133, 126, 124, 11, 45, 49, 144, 3], [ 60, 151, 27, 152, 80, 47, 201, 202, 153, 14, 22, 203, 92, 108, 63, 154, 40, 7, 34, 50, 6, 59, 2], [ 71, 170, 45, 155, 98, 76, 344, 345, 346, 16, 56, 159, 93, 347, 77, 145, 27, 56, 20, 52, 6, 59, 2], [ 58, 79, 4, 110, 64, 186, 187, 188, 102, 10, 48, 189, 111, 112, 51, 87, 24, 27, 35, 113, 48, 65, 17]], device='cuda:0')]
inputs: tensor([[2, 2, 1, 3, 4, 2, 2, 3, 3, 2, 2, 4, 2, 4, 1, 2, 2, 3, 3, 2, 2, 3, 2], [1, 2, 2, 3, 2, 1, 2, 4, 1, 3, 2, 1, 3, 3, 4, 2, 4, 2, 3, 3, 4, 4, 2], [1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2], [1, 4, 2, 4, 1, 1, 4, 4, 2, 2, 1, 1, 2, 4, 1, 4, 3, 3, 1, 1, 1, 1, 1], [1, 4, 4, 2, 2, 4, 2, 2, 4, 1, 3, 1, 2, 2, 1, 4, 1, 1, 2, 1, 1, 4, 1], [1, 2, 3, 4, 2, 1, 4, 2, 3, 3, 2, 1, 2, 1, 2, 3, 4, 1, 3, 1, 3, 3, 2], [1, 2, 1, 1, 4, 1, 3, 1, 4, 4, 4, 2, 2, 2, 2, 1, 3, 4, 2, 2, 3, 3, 2], [3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2, 4, 1, 4]], device='cuda:0')
inputs: [tensor([[1, 1, 3, 1, 2, 4, 1, 1, 3, 2, 4, 2, 2, 1, 3, 1, 1, 3, 3, 3, 2, 2, 1], [2, 3, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2], [1, 3, 2, 1, 3, 1, 3, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2], [1, 2, 1, 3, 1, 3, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1], [1, 1, 3, 3, 2, 1, 1, 1, 2, 1, 2, 4, 2, 4, 1, 3, 3, 2, 1, 4, 4, 4, 2], [1, 3, 2, 4, 2, 2, 4, 2, 1, 3, 4, 2, 4, 2, 2, 4, 4, 1, 4, 4, 2, 4, 4], [2, 4, 1, 1, 3, 3, 2, 1, 2, 3, 1, 1, 3, 3, 1, 1, 3, 4, 2, 1, 1, 3, 1], [1, 1, 1, 4, 1, 2, 2, 4, 1, 1, 4, 2, 3, 3, 2, 3, 1, 1, 3, 2, 1, 3, 1]]), tensor([[ 1, 2, 12, 8, 11, 10, 1, 2, 5, 11, 3, 6, 4, 2, 12, 1, 2, 14, 14, 5, 6, 4, 17], [ 7, 5, 6, 4, 9, 13, 10, 2, 12, 1, 2, 12, 8, 7, 15, 13, 10, 8, 7, 5, 11, 3, 16], [ 2, 5, 4, 2, 12, 2, 12, 8, 7, 15, 13, 10, 1, 1, 2, 5, 7, 12, 1, 1, 2, 5, 16], [ 8, 4, 2, 12, 2, 12, 8, 7, 15, 13, 10, 1, 1, 2, 5, 7, 12, 1, 1, 2, 5, 4, 17], [ 1, 2, 14, 5, 4, 1, 1, 8, 4, 8, 11, 3, 11, 10, 2, 14, 5, 4, 9, 13, 13, 3, 16], [ 2, 5, 11, 3, 6, 11, 3, 4, 2, 15, 3, 11, 3, 6, 11, 13, 10, 9, 13, 3, 11, 13, 18], [11, 10, 1, 2, 14, 5, 4, 8, 7, 12, 1, 2, 14, 12, 1, 2, 15, 3, 4, 1, 2, 12, 17], [ 1, 1, 9, 10, 8, 6, 11, 10, 1, 9, 3, 7, 14, 5, 7, 12, 1, 2, 5, 4, 2, 12, 17]]), tensor([[ 2, 6, 52, 39, 26, 22, 2, 3, 35, 5, 15, 9, 28, 6, 7, 2, 34, 63, 11, 33, 9, 64, 37], [14, 33, 9, 19, 36, 10, 20, 6, 7, 2, 6, 52, 16, 30, 41, 10, 47, 16, 14, 35, 5, 54, 29], [ 3, 4, 28, 6, 43, 6, 52, 16, 30, 41, 10, 22, 1, 2, 3, 49, 23, 7, 1, 2, 3, 60, 29], [48, 28, 6, 43, 6, 52, 16, 30, 41, 10, 22, 1, 2, 3, 49, 23, 7, 1, 2, 3, 4, 64, 37], [ 2, 34, 11, 4, 24, 1, 44, 48, 13, 39, 5, 12, 26, 20, 34, 11, 4, 19, 36, 71, 27, 54, 29], [ 3, 35, 5, 15, 50, 5, 38, 28, 51, 46, 12, 5, 15, 50, 45, 10, 42, 36, 27, 12, 45, 76, 53], [26, 22, 2, 34, 11, 4, 13, 16, 23, 7, 2, 34, 57, 7, 2, 51, 46, 38, 24, 2, 6, 74, 37], [ 1, 18, 32, 47, 25, 50, 26, 22, 18, 8, 31, 40, 11, 49, 23, 7, 2, 3, 4, 28, 6, 74, 37]]), tensor([[2, 3, 3, 3, 1, 1, 9, 1, 5, 5, 8, 2, 3, 1, 1, 4, 2, 1, 3, 1, 1, 7, 0], [2, 3, 2, 1, 4, 4, 9, 2, 6, 9, 6, 1, 3, 1, 4, 4, 1, 2, 2, 3, 2, 7, 0], [1, 1, 1, 3, 2, 3, 6, 1, 6, 8, 8, 3, 2, 1, 3, 1, 3, 1, 2, 3, 1, 7, 0], [3, 1, 1, 1, 3, 2, 9, 1, 6, 6, 8, 4, 3, 2, 1, 3, 1, 3, 1, 2, 1, 7, 0], [4, 4, 1, 2, 3, 3, 6, 4, 5, 8, 5, 1, 2, 1, 1, 1, 2, 3, 3, 1, 1, 7, 0], [2, 4, 4, 1, 4, 4, 5, 2, 8, 5, 8, 3, 1, 2, 4, 2, 2, 4, 2, 3, 1, 7, 0], [1, 1, 2, 4, 3, 1, 6, 3, 9, 6, 6, 3, 2, 1, 2, 3, 3, 1, 1, 4, 2, 7, 0], [1, 2, 3, 1, 1, 3, 5, 3, 9, 5, 8, 1, 1, 4, 2, 2, 1, 4, 1, 1, 1, 7, 0]]), tensor([[ 6, 15, 15, 2, 1, 48, 30, 22, 28, 26, 44, 6, 2, 1, 11, 12, 7, 13, 2, 1, 14, 8, 0], [ 6, 4, 7, 11, 16, 66, 35, 39, 42, 20, 29, 13, 2, 11, 16, 10, 3, 9, 6, 4, 18, 8, 0], [ 1, 1, 13, 4, 6, 37, 29, 41, 45, 25, 50, 4, 7, 13, 2, 13, 2, 3, 6, 2, 14, 8, 0], [ 2, 1, 1, 13, 4, 47, 30, 41, 36, 45, 49, 17, 4, 7, 13, 2, 13, 2, 3, 7, 14, 8, 0], [16, 10, 3, 6, 15, 37, 31, 33, 26, 46, 54, 3, 7, 1, 1, 3, 6, 15, 2, 1, 14, 8, 0], [ 5, 16, 10, 11, 16, 33, 27, 19, 46, 26, 50, 2, 3, 5, 12, 9, 5, 12, 6, 2, 14, 8, 0], [ 1, 3, 5, 17, 2, 41, 56, 57, 20, 36, 56, 4, 7, 3, 6, 15, 2, 1, 11, 12, 18, 8, 0], [ 3, 6, 2, 1, 13, 24, 23, 57, 51, 26, 21, 1, 11, 12, 9, 7, 11, 10, 1, 1, 14, 8, 0]]), tensor([[ 21, 58, 38, 4, 177, 178, 101, 70, 103, 179, 180, 9, 4, 14, 31, 37, 54, 6, 4, 12, 2, 1, 0], [ 18, 10, 50, 23, 253, 254, 174, 255, 175, 71, 176, 6, 34, 23, 17, 22, 24, 56, 18, 49, 25, 1, 0], [ 3, 8, 13, 55, 64, 57, 80, 90, 165, 166, 167, 10, 54, 6, 44, 6, 29, 5, 9, 65, 2, 1, 0], [ 4, 3, 8, 13, 307, 170, 308, 191, 309, 119, 310, 40, 10, 54, 6, 44, 6, 29, 45, 26, 2, 1, 0], [ 17, 22, 5, 21, 148, 137, 67, 149, 150, 227, 228, 45, 35, 3, 30, 5, 21, 38, 4, 12, 2, 1, 0], [ 51, 17, 47, 23, 206, 207, 208, 127, 82, 209, 210, 29, 42, 7, 43, 15, 7, 39, 9, 65, 2, 1, 0], [ 30, 42, 41, 53, 274, 275, 276, 86, 115, 116, 277, 10, 16, 5, 21, 38, 4, 14, 31, 278, 25, 1, 0], [ 5, 9, 4, 8, 260, 261, 182, 262, 263, 173, 264, 14, 31, 43, 19, 50, 32, 20, 3, 12, 2, 1, 0]]), tensor([[1, 5, 4, 1, 2, 8, 1, 5, 7, 2, 3, 2, 2, 5, 4, 5, 1, 4, 4, 4, 2, 2, 1], [2, 7, 2, 2, 1, 8, 3, 5, 7, 1, 1, 4, 1, 6, 4, 8, 3, 1, 2, 4, 2, 3, 2], [1, 7, 2, 1, 4, 5, 4, 5, 6, 4, 3, 3, 1, 5, 1, 7, 2, 4, 1, 1, 1, 4, 2], [1, 6, 1, 4, 1, 7, 1, 6, 7, 3, 3, 1, 1, 5, 4, 6, 4, 1, 1, 1, 4, 2, 1], [1, 5, 4, 4, 2, 5, 1, 5, 6, 1, 2, 3, 2, 8, 1, 7, 4, 2, 1, 3, 3, 3, 2], [1, 7, 2, 3, 2, 6, 3, 6, 5, 4, 3, 2, 3, 6, 2, 8, 3, 1, 3, 3, 2, 3, 3], [2, 8, 1, 1, 4, 7, 2, 5, 6, 4, 1, 1, 4, 7, 1, 5, 4, 3, 2, 1, 1, 4, 1], [1, 5, 1, 3, 1, 6, 2, 8, 5, 1, 3, 2, 4, 7, 2, 7, 1, 1, 4, 2, 1, 4, 1]]), tensor([[ 8, 18, 10, 15, 28, 23, 8, 59, 17, 13, 3, 6, 26, 18, 39, 14, 2, 21, 21, 4, 6, 5, 32], [19, 17, 6, 5, 31, 47, 38, 59, 45, 1, 2, 10, 16, 34, 62, 47, 11, 15, 7, 4, 13, 3, 27], [29, 17, 5, 2, 39, 18, 39, 58, 34, 22, 12, 11, 8, 14, 29, 17, 7, 10, 1, 1, 2, 4, 27], [16, 20, 2, 10, 29, 45, 16, 50, 46, 12, 11, 1, 8, 18, 40, 34, 10, 1, 1, 2, 4, 5, 32], [ 8, 18, 21, 4, 26, 14, 8, 58, 20, 15, 13, 3, 28, 23, 29, 44, 4, 5, 9, 12, 12, 3, 27], [29, 17, 13, 3, 35, 36, 24, 60, 18, 22, 3, 13, 24, 25, 28, 47, 11, 9, 12, 3, 13, 12, 43], [28, 23, 1, 2, 42, 17, 26, 58, 34, 10, 1, 2, 42, 45, 8, 18, 22, 3, 5, 1, 2, 10, 32], [ 8, 14, 9, 11, 16, 25, 28, 52, 14, 9, 3, 7, 42, 17, 19, 45, 1, 2, 4, 5, 2, 10, 32]]), tensor([[ 38, 62, 100, 53, 43, 143, 176, 244, 84, 9, 56, 61, 77, 76, 58, 127, 50, 245, 6, 12, 23, 89, 3], [ 30, 175, 23, 243, 136, 303, 304, 305, 130, 4, 7, 306, 60, 307, 308, 105, 141, 86, 26, 41, 9, 19, 2], [ 21, 42, 18, 106, 234, 76, 235, 171, 151, 32, 10, 28, 8, 29, 21, 236, 33, 13, 1, 4, 5, 59, 2], [ 71, 118, 7, 247, 352, 353, 177, 354, 140, 10, 122, 178, 38, 172, 355, 128, 13, 1, 4, 5, 14, 89, 3], [ 38, 214, 6, 101, 125, 215, 283, 284, 95, 216, 9, 217, 43, 157, 218, 153, 14, 37, 36, 219, 16, 19, 2], [ 21, 84, 9, 115, 262, 116, 263, 264, 145, 27, 35, 195, 146, 147, 148, 105, 48, 36, 16, 35, 91, 265, 17], [ 43, 133, 4, 72, 129, 132, 324, 171, 128, 13, 4, 72, 325, 326, 38, 145, 27, 39, 11, 4, 7, 248, 3], [ 8, 73, 49, 160, 51, 147, 149, 312, 73, 15, 25, 313, 129, 179, 81, 130, 4, 5, 14, 18, 7, 248, 3]])]
inputs: tensor([[1, 1, 3, 1, 2, 4, 1, 1, 3, 2, 4, 2, 2, 1, 3, 1, 1, 3, 3, 3, 2, 2, 1], [2, 3, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2], [1, 3, 2, 1, 3, 1, 3, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2], [1, 2, 1, 3, 1, 3, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1], [1, 1, 3, 3, 2, 1, 1, 1, 2, 1, 2, 4, 2, 4, 1, 3, 3, 2, 1, 4, 4, 4, 2], [1, 3, 2, 4, 2, 2, 4, 2, 1, 3, 4, 2, 4, 2, 2, 4, 4, 1, 4, 4, 2, 4, 4], [2, 4, 1, 1, 3, 3, 2, 1, 2, 3, 1, 1, 3, 3, 1, 1, 3, 4, 2, 1, 1, 3, 1], [1, 1, 1, 4, 1, 2, 2, 4, 1, 1, 4, 2, 3, 3, 2, 3, 1, 1, 3, 2, 1, 3, 1]])
inputs: [tensor([[1, 1, 3, 1, 2, 4, 1, 1, 3, 2, 4, 2, 2, 1, 3, 1, 1, 3, 3, 3, 2, 2, 1], [2, 3, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2], [1, 3, 2, 1, 3, 1, 3, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2], [1, 2, 1, 3, 1, 3, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1], [1, 1, 3, 3, 2, 1, 1, 1, 2, 1, 2, 4, 2, 4, 1, 3, 3, 2, 1, 4, 4, 4, 2], [1, 3, 2, 4, 2, 2, 4, 2, 1, 3, 4, 2, 4, 2, 2, 4, 4, 1, 4, 4, 2, 4, 4], [2, 4, 1, 1, 3, 3, 2, 1, 2, 3, 1, 1, 3, 3, 1, 1, 3, 4, 2, 1, 1, 3, 1], [1, 1, 1, 4, 1, 2, 2, 4, 1, 1, 4, 2, 3, 3, 2, 3, 1, 1, 3, 2, 1, 3, 1]], device='cuda:0'), tensor([[ 1, 2, 12, 8, 11, 10, 1, 2, 5, 11, 3, 6, 4, 2, 12, 1, 2, 14, 14, 5, 6, 4, 17], [ 7, 5, 6, 4, 9, 13, 10, 2, 12, 1, 2, 12, 8, 7, 15, 13, 10, 8, 7, 5, 11, 3, 16], [ 2, 5, 4, 2, 12, 2, 12, 8, 7, 15, 13, 10, 1, 1, 2, 5, 7, 12, 1, 1, 2, 5, 16], [ 8, 4, 2, 12, 2, 12, 8, 7, 15, 13, 10, 1, 1, 2, 5, 7, 12, 1, 1, 2, 5, 4, 17], [ 1, 2, 14, 5, 4, 1, 1, 8, 4, 8, 11, 3, 11, 10, 2, 14, 5, 4, 9, 13, 13, 3, 16], [ 2, 5, 11, 3, 6, 11, 3, 4, 2, 15, 3, 11, 3, 6, 11, 13, 10, 9, 13, 3, 11, 13, 18], [11, 10, 1, 2, 14, 5, 4, 8, 7, 12, 1, 2, 14, 12, 1, 2, 15, 3, 4, 1, 2, 12, 17], [ 1, 1, 9, 10, 8, 6, 11, 10, 1, 9, 3, 7, 14, 5, 7, 12, 1, 2, 5, 4, 2, 12, 17]], device='cuda:0'), tensor([[ 2, 6, 52, 39, 26, 22, 2, 3, 35, 5, 15, 9, 28, 6, 7, 2, 34, 63, 11, 33, 9, 64, 37], [14, 33, 9, 19, 36, 10, 20, 6, 7, 2, 6, 52, 16, 30, 41, 10, 47, 16, 14, 35, 5, 54, 29], [ 3, 4, 28, 6, 43, 6, 52, 16, 30, 41, 10, 22, 1, 2, 3, 49, 23, 7, 1, 2, 3, 60, 29], [48, 28, 6, 43, 6, 52, 16, 30, 41, 10, 22, 1, 2, 3, 49, 23, 7, 1, 2, 3, 4, 64, 37], [ 2, 34, 11, 4, 24, 1, 44, 48, 13, 39, 5, 12, 26, 20, 34, 11, 4, 19, 36, 71, 27, 54, 29], [ 3, 35, 5, 15, 50, 5, 38, 28, 51, 46, 12, 5, 15, 50, 45, 10, 42, 36, 27, 12, 45, 76, 53], [26, 22, 2, 34, 11, 4, 13, 16, 23, 7, 2, 34, 57, 7, 2, 51, 46, 38, 24, 2, 6, 74, 37], [ 1, 18, 32, 47, 25, 50, 26, 22, 18, 8, 31, 40, 11, 49, 23, 7, 2, 3, 4, 28, 6, 74, 37]], device='cuda:0'), tensor([[2, 3, 3, 3, 1, 1, 9, 1, 5, 5, 8, 2, 3, 1, 1, 4, 2, 1, 3, 1, 1, 7, 0], [2, 3, 2, 1, 4, 4, 9, 2, 6, 9, 6, 1, 3, 1, 4, 4, 1, 2, 2, 3, 2, 7, 0], [1, 1, 1, 3, 2, 3, 6, 1, 6, 8, 8, 3, 2, 1, 3, 1, 3, 1, 2, 3, 1, 7, 0], [3, 1, 1, 1, 3, 2, 9, 1, 6, 6, 8, 4, 3, 2, 1, 3, 1, 3, 1, 2, 1, 7, 0], [4, 4, 1, 2, 3, 3, 6, 4, 5, 8, 5, 1, 2, 1, 1, 1, 2, 3, 3, 1, 1, 7, 0], [2, 4, 4, 1, 4, 4, 5, 2, 8, 5, 8, 3, 1, 2, 4, 2, 2, 4, 2, 3, 1, 7, 0], [1, 1, 2, 4, 3, 1, 6, 3, 9, 6, 6, 3, 2, 1, 2, 3, 3, 1, 1, 4, 2, 7, 0], [1, 2, 3, 1, 1, 3, 5, 3, 9, 5, 8, 1, 1, 4, 2, 2, 1, 4, 1, 1, 1, 7, 0]], device='cuda:0'), tensor([[ 6, 15, 15, 2, 1, 48, 30, 22, 28, 26, 44, 6, 2, 1, 11, 12, 7, 13, 2, 1, 14, 8, 0], [ 6, 4, 7, 11, 16, 66, 35, 39, 42, 20, 29, 13, 2, 11, 16, 10, 3, 9, 6, 4, 18, 8, 0], [ 1, 1, 13, 4, 6, 37, 29, 41, 45, 25, 50, 4, 7, 13, 2, 13, 2, 3, 6, 2, 14, 8, 0], [ 2, 1, 1, 13, 4, 47, 30, 41, 36, 45, 49, 17, 4, 7, 13, 2, 13, 2, 3, 7, 14, 8, 0], [16, 10, 3, 6, 15, 37, 31, 33, 26, 46, 54, 3, 7, 1, 1, 3, 6, 15, 2, 1, 14, 8, 0], [ 5, 16, 10, 11, 16, 33, 27, 19, 46, 26, 50, 2, 3, 5, 12, 9, 5, 12, 6, 2, 14, 8, 0], [ 1, 3, 5, 17, 2, 41, 56, 57, 20, 36, 56, 4, 7, 3, 6, 15, 2, 1, 11, 12, 18, 8, 0], [ 3, 6, 2, 1, 13, 24, 23, 57, 51, 26, 21, 1, 11, 12, 9, 7, 11, 10, 1, 1, 14, 8, 0]], device='cuda:0'), tensor([[ 21, 58, 38, 4, 177, 178, 101, 70, 103, 179, 180, 9, 4, 14, 31, 37, 54, 6, 4, 12, 2, 1, 0], [ 18, 10, 50, 23, 253, 254, 174, 255, 175, 71, 176, 6, 34, 23, 17, 22, 24, 56, 18, 49, 25, 1, 0], [ 3, 8, 13, 55, 64, 57, 80, 90, 165, 166, 167, 10, 54, 6, 44, 6, 29, 5, 9, 65, 2, 1, 0], [ 4, 3, 8, 13, 307, 170, 308, 191, 309, 119, 310, 40, 10, 54, 6, 44, 6, 29, 45, 26, 2, 1, 0], [ 17, 22, 5, 21, 148, 137, 67, 149, 150, 227, 228, 45, 35, 3, 30, 5, 21, 38, 4, 12, 2, 1, 0], [ 51, 17, 47, 23, 206, 207, 208, 127, 82, 209, 210, 29, 42, 7, 43, 15, 7, 39, 9, 65, 2, 1, 0], [ 30, 42, 41, 53, 274, 275, 276, 86, 115, 116, 277, 10, 16, 5, 21, 38, 4, 14, 31, 278, 25, 1, 0], [ 5, 9, 4, 8, 260, 261, 182, 262, 263, 173, 264, 14, 31, 43, 19, 50, 32, 20, 3, 12, 2, 1, 0]], device='cuda:0'), tensor([[1, 5, 4, 1, 2, 8, 1, 5, 7, 2, 3, 2, 2, 5, 4, 5, 1, 4, 4, 4, 2, 2, 1], [2, 7, 2, 2, 1, 8, 3, 5, 7, 1, 1, 4, 1, 6, 4, 8, 3, 1, 2, 4, 2, 3, 2], [1, 7, 2, 1, 4, 5, 4, 5, 6, 4, 3, 3, 1, 5, 1, 7, 2, 4, 1, 1, 1, 4, 2], [1, 6, 1, 4, 1, 7, 1, 6, 7, 3, 3, 1, 1, 5, 4, 6, 4, 1, 1, 1, 4, 2, 1], [1, 5, 4, 4, 2, 5, 1, 5, 6, 1, 2, 3, 2, 8, 1, 7, 4, 2, 1, 3, 3, 3, 2], [1, 7, 2, 3, 2, 6, 3, 6, 5, 4, 3, 2, 3, 6, 2, 8, 3, 1, 3, 3, 2, 3, 3], [2, 8, 1, 1, 4, 7, 2, 5, 6, 4, 1, 1, 4, 7, 1, 5, 4, 3, 2, 1, 1, 4, 1], [1, 5, 1, 3, 1, 6, 2, 8, 5, 1, 3, 2, 4, 7, 2, 7, 1, 1, 4, 2, 1, 4, 1]], device='cuda:0'), tensor([[ 8, 18, 10, 15, 28, 23, 8, 59, 17, 13, 3, 6, 26, 18, 39, 14, 2, 21, 21, 4, 6, 5, 32], [19, 17, 6, 5, 31, 47, 38, 59, 45, 1, 2, 10, 16, 34, 62, 47, 11, 15, 7, 4, 13, 3, 27], [29, 17, 5, 2, 39, 18, 39, 58, 34, 22, 12, 11, 8, 14, 29, 17, 7, 10, 1, 1, 2, 4, 27], [16, 20, 2, 10, 29, 45, 16, 50, 46, 12, 11, 1, 8, 18, 40, 34, 10, 1, 1, 2, 4, 5, 32], [ 8, 18, 21, 4, 26, 14, 8, 58, 20, 15, 13, 3, 28, 23, 29, 44, 4, 5, 9, 12, 12, 3, 27], [29, 17, 13, 3, 35, 36, 24, 60, 18, 22, 3, 13, 24, 25, 28, 47, 11, 9, 12, 3, 13, 12, 43], [28, 23, 1, 2, 42, 17, 26, 58, 34, 10, 1, 2, 42, 45, 8, 18, 22, 3, 5, 1, 2, 10, 32], [ 8, 14, 9, 11, 16, 25, 28, 52, 14, 9, 3, 7, 42, 17, 19, 45, 1, 2, 4, 5, 2, 10, 32]], device='cuda:0'), tensor([[ 38, 62, 100, 53, 43, 143, 176, 244, 84, 9, 56, 61, 77, 76, 58, 127, 50, 245, 6, 12, 23, 89, 3], [ 30, 175, 23, 243, 136, 303, 304, 305, 130, 4, 7, 306, 60, 307, 308, 105, 141, 86, 26, 41, 9, 19, 2], [ 21, 42, 18, 106, 234, 76, 235, 171, 151, 32, 10, 28, 8, 29, 21, 236, 33, 13, 1, 4, 5, 59, 2], [ 71, 118, 7, 247, 352, 353, 177, 354, 140, 10, 122, 178, 38, 172, 355, 128, 13, 1, 4, 5, 14, 89, 3], [ 38, 214, 6, 101, 125, 215, 283, 284, 95, 216, 9, 217, 43, 157, 218, 153, 14, 37, 36, 219, 16, 19, 2], [ 21, 84, 9, 115, 262, 116, 263, 264, 145, 27, 35, 195, 146, 147, 148, 105, 48, 36, 16, 35, 91, 265, 17], [ 43, 133, 4, 72, 129, 132, 324, 171, 128, 13, 4, 72, 325, 326, 38, 145, 27, 39, 11, 4, 7, 248, 3], [ 8, 73, 49, 160, 51, 147, 149, 312, 73, 15, 25, 313, 129, 179, 81, 130, 4, 5, 14, 18, 7, 248, 3]], device='cuda:0')]
inputs: tensor([[1, 1, 3, 1, 2, 4, 1, 1, 3, 2, 4, 2, 2, 1, 3, 1, 1, 3, 3, 3, 2, 2, 1], [2, 3, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2], [1, 3, 2, 1, 3, 1, 3, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2], [1, 2, 1, 3, 1, 3, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1], [1, 1, 3, 3, 2, 1, 1, 1, 2, 1, 2, 4, 2, 4, 1, 3, 3, 2, 1, 4, 4, 4, 2], [1, 3, 2, 4, 2, 2, 4, 2, 1, 3, 4, 2, 4, 2, 2, 4, 4, 1, 4, 4, 2, 4, 4], [2, 4, 1, 1, 3, 3, 2, 1, 2, 3, 1, 1, 3, 3, 1, 1, 3, 4, 2, 1, 1, 3, 1], [1, 1, 1, 4, 1, 2, 2, 4, 1, 1, 4, 2, 3, 3, 2, 3, 1, 1, 3, 2, 1, 3, 1]], device='cuda:0')
inputs: [tensor([[2, 2, 3, 2, 1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2, 1, 4, 2], [1, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4, 2], [1, 2, 2, 1, 4, 1, 1, 1, 3, 3, 1, 3, 1, 1, 4, 1, 2, 2, 3, 3, 2, 2, 2], [1, 2, 4, 2, 4, 1, 3, 3, 2, 1, 4, 4, 4, 2, 3, 3, 2, 1, 3, 1, 3, 2, 4], [1, 2, 4, 4, 1, 1, 4, 2, 4, 2, 2, 3, 3, 1, 4, 1, 3, 2, 3, 2, 1, 1, 4], [1, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4], [1, 2, 2, 2, 4, 3, 3, 3, 4, 4, 1, 4, 1, 1, 4, 4, 1, 4, 2, 3, 2, 2, 3], [2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3]]), tensor([[ 6, 7, 5, 4, 1, 8, 6, 4, 8, 11, 13, 3, 4, 9, 13, 3, 7, 15, 3, 4, 9, 3, 16], [ 1, 8, 7, 5, 11, 3, 7, 12, 2, 5, 6, 6, 4, 2, 5, 4, 8, 4, 2, 5, 11, 3, 16], [ 8, 6, 4, 9, 10, 1, 1, 2, 14, 12, 2, 12, 1, 9, 10, 8, 6, 7, 14, 5, 6, 6, 16], [ 8, 11, 3, 11, 10, 2, 14, 5, 4, 9, 13, 13, 3, 7, 14, 5, 4, 2, 12, 2, 5, 11, 18], [ 8, 11, 13, 10, 1, 9, 3, 11, 3, 6, 7, 14, 12, 9, 10, 2, 5, 7, 5, 4, 1, 9, 18], [ 9, 10, 8, 7, 5, 11, 3, 7, 12, 2, 5, 6, 6, 4, 2, 5, 4, 8, 4, 2, 5, 11, 18], [ 8, 6, 6, 11, 20, 14, 14, 15, 13, 10, 9, 10, 1, 9, 13, 10, 9, 3, 7, 5, 6, 7, 19], [ 7, 15, 13, 10, 8, 7, 5, 11, 3, 7, 12, 2, 5, 6, 6, 4, 2, 5, 4, 8, 4, 2, 19]]), tensor([[17, 14, 4, 24, 44, 25, 9, 13, 39, 45, 27, 38, 19, 36, 27, 31, 30, 46, 38, 19, 8, 54, 29], [44, 16, 14, 35, 5, 31, 23, 43, 3, 33, 21, 9, 28, 3, 4, 13, 48, 28, 3, 35, 5, 54, 29], [25, 9, 19, 32, 22, 1, 2, 34, 57, 43, 6, 7, 18, 32, 47, 25, 17, 40, 11, 33, 21, 68, 29], [39, 5, 12, 26, 20, 34, 11, 4, 19, 36, 71, 27, 31, 40, 11, 4, 28, 6, 43, 3, 35, 66, 53], [39, 45, 10, 22, 18, 8, 12, 5, 15, 17, 40, 57, 59, 32, 20, 3, 49, 14, 4, 24, 18, 61, 53], [32, 47, 16, 14, 35, 5, 31, 23, 43, 3, 33, 21, 9, 28, 3, 4, 13, 48, 28, 3, 35, 66, 53], [25, 21, 50, 73, 67, 63, 62, 41, 10, 42, 32, 22, 18, 36, 10, 42, 8, 31, 14, 33, 17, 65, 58], [30, 41, 10, 47, 16, 14, 35, 5, 31, 23, 43, 3, 33, 21, 9, 28, 3, 4, 13, 48, 28, 72, 58]]), tensor([[1, 2, 4, 3, 2, 4, 8, 1, 5, 8, 8, 2, 1, 2, 2, 1, 1, 2, 3, 2, 2, 7, 0], [2, 3, 1, 2, 1, 2, 9, 1, 5, 5, 5, 3, 1, 3, 2, 4, 2, 3, 2, 1, 1, 7, 0], [2, 3, 3, 2, 2, 1, 8, 1, 6, 9, 6, 3, 3, 1, 1, 1, 4, 1, 2, 2, 1, 7, 0], [3, 1, 3, 1, 2, 3, 9, 2, 8, 8, 8, 1, 2, 3, 3, 1, 4, 2, 4, 2, 1, 7, 0], [1, 2, 3, 2, 3, 1, 8, 1, 9, 9, 5, 2, 4, 2, 4, 1, 1, 4, 4, 2, 1, 7, 0], [3, 1, 2, 1, 2, 3, 6, 2, 5, 5, 9, 1, 3, 2, 4, 2, 3, 2, 1, 4, 1, 7, 0], [2, 3, 2, 4, 1, 4, 8, 1, 6, 8, 6, 4, 4, 3, 3, 3, 4, 2, 2, 2, 1, 7, 0], [2, 1, 2, 3, 1, 2, 5, 2, 9, 6, 9, 2, 4, 2, 3, 2, 1, 4, 4, 3, 2, 7, 0]]), tensor([[ 3, 5, 17, 4, 5, 53, 21, 22, 26, 25, 44, 7, 3, 9, 7, 1, 3, 6, 4, 9, 18, 8, 0], [ 6, 2, 3, 7, 3, 47, 30, 22, 28, 28, 23, 2, 13, 4, 5, 12, 6, 4, 7, 1, 14, 8, 0], [ 6, 15, 4, 9, 7, 34, 21, 41, 42, 20, 56, 15, 2, 1, 1, 11, 10, 3, 9, 7, 14, 8, 0], [ 2, 13, 2, 3, 6, 57, 35, 19, 25, 25, 21, 3, 6, 15, 2, 11, 12, 5, 12, 7, 14, 8, 0], [ 3, 6, 4, 6, 2, 34, 21, 48, 59, 51, 27, 5, 12, 5, 10, 1, 11, 16, 12, 7, 14, 8, 0], [ 2, 3, 7, 3, 6, 37, 40, 38, 28, 32, 30, 13, 4, 5, 12, 6, 4, 7, 11, 10, 14, 8, 0], [ 6, 4, 5, 10, 11, 53, 21, 41, 45, 62, 31, 16, 17, 15, 15, 61, 12, 9, 9, 7, 14, 8, 0], [ 7, 3, 6, 2, 3, 38, 27, 47, 20, 42, 35, 5, 12, 6, 4, 7, 11, 16, 17, 4, 18, 8, 0]]), tensor([[ 42, 41, 40, 33, 304, 190, 106, 305, 123, 306, 193, 16, 24, 19, 35, 30, 5, 18, 28, 93, 25, 1, 0], [ 9, 29, 45, 16, 169, 170, 101, 70, 158, 153, 81, 44, 13, 33, 7, 39, 18, 10, 35, 12, 2, 1, 0], [ 21, 36, 28, 19, 152, 102, 89, 125, 175, 302, 303, 38, 4, 3, 14, 32, 22, 24, 19, 26, 2, 1, 0], [ 44, 6, 29, 5, 248, 249, 250, 99, 251, 124, 79, 5, 21, 38, 34, 31, 27, 7, 37, 26, 2, 1, 0], [ 5, 18, 55, 9, 98, 102, 289, 189, 290, 185, 159, 7, 27, 11, 20, 14, 23, 59, 37, 26, 2, 1, 0], [ 29, 45, 16, 5, 64, 242, 168, 88, 142, 243, 244, 13, 33, 7, 39, 18, 10, 50, 32, 48, 2, 1, 0], [ 18, 33, 11, 47, 161, 190, 89, 90, 147, 285, 286, 63, 66, 58, 69, 172, 43, 46, 19, 26, 2, 1, 0], [ 16, 5, 9, 29, 220, 221, 222, 84, 145, 223, 146, 7, 39, 18, 10, 50, 23, 63, 40, 49, 25, 1, 0]]), tensor([[2, 6, 4, 2, 1, 5, 2, 6, 5, 2, 3, 3, 2, 5, 3, 8, 2, 4, 3, 2, 1, 3, 2], [1, 5, 2, 4, 2, 8, 2, 7, 5, 4, 2, 2, 2, 5, 4, 6, 1, 2, 1, 4, 2, 3, 2], [1, 6, 2, 1, 3, 5, 1, 5, 7, 4, 1, 4, 1, 5, 3, 5, 2, 2, 4, 4, 2, 2, 2], [1, 6, 3, 2, 3, 5, 4, 7, 6, 1, 3, 3, 3, 6, 4, 7, 2, 1, 4, 1, 4, 2, 3], [1, 6, 3, 3, 1, 5, 3, 6, 8, 2, 2, 4, 4, 5, 3, 5, 4, 2, 4, 2, 1, 1, 3], [1, 8, 1, 2, 4, 6, 3, 6, 7, 1, 4, 2, 2, 6, 1, 7, 2, 1, 2, 1, 4, 2, 3], [1, 6, 2, 2, 3, 7, 4, 7, 8, 3, 1, 3, 1, 5, 3, 8, 1, 3, 2, 4, 2, 2, 4], [2, 7, 3, 3, 1, 6, 4, 6, 8, 2, 4, 1, 4, 6, 2, 6, 1, 4, 2, 1, 2, 1, 4]]), tensor([[35, 34, 4, 5, 8, 33, 35, 60, 33, 13, 12, 3, 26, 30, 41, 37, 7, 22, 3, 5, 9, 3, 27], [ 8, 33, 7, 4, 28, 37, 19, 51, 18, 4, 6, 6, 26, 18, 40, 20, 15, 5, 2, 4, 13, 3, 27], [16, 25, 5, 9, 38, 14, 8, 59, 44, 10, 2, 10, 8, 30, 38, 33, 6, 7, 21, 4, 6, 6, 27], [16, 36, 3, 13, 38, 18, 42, 57, 20, 9, 12, 12, 24, 34, 42, 17, 5, 2, 10, 2, 4, 13, 43], [16, 36, 12, 11, 8, 30, 24, 54, 37, 6, 7, 21, 39, 30, 38, 18, 4, 7, 4, 5, 1, 9, 43], [31, 23, 15, 7, 40, 36, 24, 50, 45, 2, 4, 6, 35, 20, 29, 17, 5, 15, 5, 2, 4, 13, 43], [16, 25, 6, 13, 61, 44, 42, 55, 47, 11, 9, 11, 8, 30, 41, 23, 9, 3, 7, 4, 6, 7, 49], [19, 46, 12, 11, 16, 34, 40, 54, 37, 7, 10, 2, 40, 25, 35, 20, 2, 4, 5, 15, 5, 2, 49]]), tensor([[ 99, 139, 14, 349, 66, 227, 350, 351, 253, 91, 16, 152, 80, 85, 199, 162, 24, 27, 39, 37, 15, 19, 2], [ 66, 107, 26, 297, 83, 239, 180, 298, 70, 12, 46, 61, 77, 172, 64, 95, 124, 18, 5, 41, 9, 19, 2], [ 51, 54, 37, 155, 163, 215, 176, 348, 252, 34, 7, 90, 44, 173, 138, 88, 20, 52, 6, 12, 46, 131, 2], [ 67, 119, 35, 220, 98, 240, 241, 174, 68, 36, 219, 168, 233, 242, 129, 42, 18, 7, 34, 5, 41, 109, 17], [ 67, 102, 10, 28, 44, 47, 169, 123, 104, 20, 52, 251, 339, 173, 98, 70, 114, 26, 14, 11, 45, 65, 17], [ 78, 228, 86, 294, 237, 116, 201, 238, 295, 5, 12, 159, 94, 296, 21, 42, 22, 124, 18, 5, 41, 109, 17], [ 51, 117, 96, 331, 332, 333, 334, 335, 105, 48, 49, 28, 44, 85, 69, 336, 15, 25, 26, 12, 20, 97, 55], [ 63, 140, 10, 160, 60, 161, 276, 123, 162, 33, 34, 110, 277, 278, 94, 118, 5, 14, 22, 124, 18, 213, 55]])]
inputs: tensor([[2, 2, 3, 2, 1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2, 1, 4, 2], [1, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4, 2], [1, 2, 2, 1, 4, 1, 1, 1, 3, 3, 1, 3, 1, 1, 4, 1, 2, 2, 3, 3, 2, 2, 2], [1, 2, 4, 2, 4, 1, 3, 3, 2, 1, 4, 4, 4, 2, 3, 3, 2, 1, 3, 1, 3, 2, 4], [1, 2, 4, 4, 1, 1, 4, 2, 4, 2, 2, 3, 3, 1, 4, 1, 3, 2, 3, 2, 1, 1, 4], [1, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4], [1, 2, 2, 2, 4, 3, 3, 3, 4, 4, 1, 4, 1, 1, 4, 4, 1, 4, 2, 3, 2, 2, 3], [2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3]])
inputs: [tensor([[2, 2, 3, 2, 1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2, 1, 4, 2], [1, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4, 2], [1, 2, 2, 1, 4, 1, 1, 1, 3, 3, 1, 3, 1, 1, 4, 1, 2, 2, 3, 3, 2, 2, 2], [1, 2, 4, 2, 4, 1, 3, 3, 2, 1, 4, 4, 4, 2, 3, 3, 2, 1, 3, 1, 3, 2, 4], [1, 2, 4, 4, 1, 1, 4, 2, 4, 2, 2, 3, 3, 1, 4, 1, 3, 2, 3, 2, 1, 1, 4], [1, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4], [1, 2, 2, 2, 4, 3, 3, 3, 4, 4, 1, 4, 1, 1, 4, 4, 1, 4, 2, 3, 2, 2, 3], [2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3]], device='cuda:0'), tensor([[ 6, 7, 5, 4, 1, 8, 6, 4, 8, 11, 13, 3, 4, 9, 13, 3, 7, 15, 3, 4, 9, 3, 16], [ 1, 8, 7, 5, 11, 3, 7, 12, 2, 5, 6, 6, 4, 2, 5, 4, 8, 4, 2, 5, 11, 3, 16], [ 8, 6, 4, 9, 10, 1, 1, 2, 14, 12, 2, 12, 1, 9, 10, 8, 6, 7, 14, 5, 6, 6, 16], [ 8, 11, 3, 11, 10, 2, 14, 5, 4, 9, 13, 13, 3, 7, 14, 5, 4, 2, 12, 2, 5, 11, 18], [ 8, 11, 13, 10, 1, 9, 3, 11, 3, 6, 7, 14, 12, 9, 10, 2, 5, 7, 5, 4, 1, 9, 18], [ 9, 10, 8, 7, 5, 11, 3, 7, 12, 2, 5, 6, 6, 4, 2, 5, 4, 8, 4, 2, 5, 11, 18], [ 8, 6, 6, 11, 20, 14, 14, 15, 13, 10, 9, 10, 1, 9, 13, 10, 9, 3, 7, 5, 6, 7, 19], [ 7, 15, 13, 10, 8, 7, 5, 11, 3, 7, 12, 2, 5, 6, 6, 4, 2, 5, 4, 8, 4, 2, 19]], device='cuda:0'), tensor([[17, 14, 4, 24, 44, 25, 9, 13, 39, 45, 27, 38, 19, 36, 27, 31, 30, 46, 38, 19, 8, 54, 29], [44, 16, 14, 35, 5, 31, 23, 43, 3, 33, 21, 9, 28, 3, 4, 13, 48, 28, 3, 35, 5, 54, 29], [25, 9, 19, 32, 22, 1, 2, 34, 57, 43, 6, 7, 18, 32, 47, 25, 17, 40, 11, 33, 21, 68, 29], [39, 5, 12, 26, 20, 34, 11, 4, 19, 36, 71, 27, 31, 40, 11, 4, 28, 6, 43, 3, 35, 66, 53], [39, 45, 10, 22, 18, 8, 12, 5, 15, 17, 40, 57, 59, 32, 20, 3, 49, 14, 4, 24, 18, 61, 53], [32, 47, 16, 14, 35, 5, 31, 23, 43, 3, 33, 21, 9, 28, 3, 4, 13, 48, 28, 3, 35, 66, 53], [25, 21, 50, 73, 67, 63, 62, 41, 10, 42, 32, 22, 18, 36, 10, 42, 8, 31, 14, 33, 17, 65, 58], [30, 41, 10, 47, 16, 14, 35, 5, 31, 23, 43, 3, 33, 21, 9, 28, 3, 4, 13, 48, 28, 72, 58]], device='cuda:0'), tensor([[1, 2, 4, 3, 2, 4, 8, 1, 5, 8, 8, 2, 1, 2, 2, 1, 1, 2, 3, 2, 2, 7, 0], [2, 3, 1, 2, 1, 2, 9, 1, 5, 5, 5, 3, 1, 3, 2, 4, 2, 3, 2, 1, 1, 7, 0], [2, 3, 3, 2, 2, 1, 8, 1, 6, 9, 6, 3, 3, 1, 1, 1, 4, 1, 2, 2, 1, 7, 0], [3, 1, 3, 1, 2, 3, 9, 2, 8, 8, 8, 1, 2, 3, 3, 1, 4, 2, 4, 2, 1, 7, 0], [1, 2, 3, 2, 3, 1, 8, 1, 9, 9, 5, 2, 4, 2, 4, 1, 1, 4, 4, 2, 1, 7, 0], [3, 1, 2, 1, 2, 3, 6, 2, 5, 5, 9, 1, 3, 2, 4, 2, 3, 2, 1, 4, 1, 7, 0], [2, 3, 2, 4, 1, 4, 8, 1, 6, 8, 6, 4, 4, 3, 3, 3, 4, 2, 2, 2, 1, 7, 0], [2, 1, 2, 3, 1, 2, 5, 2, 9, 6, 9, 2, 4, 2, 3, 2, 1, 4, 4, 3, 2, 7, 0]], device='cuda:0'), tensor([[ 3, 5, 17, 4, 5, 53, 21, 22, 26, 25, 44, 7, 3, 9, 7, 1, 3, 6, 4, 9, 18, 8, 0], [ 6, 2, 3, 7, 3, 47, 30, 22, 28, 28, 23, 2, 13, 4, 5, 12, 6, 4, 7, 1, 14, 8, 0], [ 6, 15, 4, 9, 7, 34, 21, 41, 42, 20, 56, 15, 2, 1, 1, 11, 10, 3, 9, 7, 14, 8, 0], [ 2, 13, 2, 3, 6, 57, 35, 19, 25, 25, 21, 3, 6, 15, 2, 11, 12, 5, 12, 7, 14, 8, 0], [ 3, 6, 4, 6, 2, 34, 21, 48, 59, 51, 27, 5, 12, 5, 10, 1, 11, 16, 12, 7, 14, 8, 0], [ 2, 3, 7, 3, 6, 37, 40, 38, 28, 32, 30, 13, 4, 5, 12, 6, 4, 7, 11, 10, 14, 8, 0], [ 6, 4, 5, 10, 11, 53, 21, 41, 45, 62, 31, 16, 17, 15, 15, 61, 12, 9, 9, 7, 14, 8, 0], [ 7, 3, 6, 2, 3, 38, 27, 47, 20, 42, 35, 5, 12, 6, 4, 7, 11, 16, 17, 4, 18, 8, 0]], device='cuda:0'), tensor([[ 42, 41, 40, 33, 304, 190, 106, 305, 123, 306, 193, 16, 24, 19, 35, 30, 5, 18, 28, 93, 25, 1, 0], [ 9, 29, 45, 16, 169, 170, 101, 70, 158, 153, 81, 44, 13, 33, 7, 39, 18, 10, 35, 12, 2, 1, 0], [ 21, 36, 28, 19, 152, 102, 89, 125, 175, 302, 303, 38, 4, 3, 14, 32, 22, 24, 19, 26, 2, 1, 0], [ 44, 6, 29, 5, 248, 249, 250, 99, 251, 124, 79, 5, 21, 38, 34, 31, 27, 7, 37, 26, 2, 1, 0], [ 5, 18, 55, 9, 98, 102, 289, 189, 290, 185, 159, 7, 27, 11, 20, 14, 23, 59, 37, 26, 2, 1, 0], [ 29, 45, 16, 5, 64, 242, 168, 88, 142, 243, 244, 13, 33, 7, 39, 18, 10, 50, 32, 48, 2, 1, 0], [ 18, 33, 11, 47, 161, 190, 89, 90, 147, 285, 286, 63, 66, 58, 69, 172, 43, 46, 19, 26, 2, 1, 0], [ 16, 5, 9, 29, 220, 221, 222, 84, 145, 223, 146, 7, 39, 18, 10, 50, 23, 63, 40, 49, 25, 1, 0]], device='cuda:0'), tensor([[2, 6, 4, 2, 1, 5, 2, 6, 5, 2, 3, 3, 2, 5, 3, 8, 2, 4, 3, 2, 1, 3, 2], [1, 5, 2, 4, 2, 8, 2, 7, 5, 4, 2, 2, 2, 5, 4, 6, 1, 2, 1, 4, 2, 3, 2], [1, 6, 2, 1, 3, 5, 1, 5, 7, 4, 1, 4, 1, 5, 3, 5, 2, 2, 4, 4, 2, 2, 2], [1, 6, 3, 2, 3, 5, 4, 7, 6, 1, 3, 3, 3, 6, 4, 7, 2, 1, 4, 1, 4, 2, 3], [1, 6, 3, 3, 1, 5, 3, 6, 8, 2, 2, 4, 4, 5, 3, 5, 4, 2, 4, 2, 1, 1, 3], [1, 8, 1, 2, 4, 6, 3, 6, 7, 1, 4, 2, 2, 6, 1, 7, 2, 1, 2, 1, 4, 2, 3], [1, 6, 2, 2, 3, 7, 4, 7, 8, 3, 1, 3, 1, 5, 3, 8, 1, 3, 2, 4, 2, 2, 4], [2, 7, 3, 3, 1, 6, 4, 6, 8, 2, 4, 1, 4, 6, 2, 6, 1, 4, 2, 1, 2, 1, 4]], device='cuda:0'), tensor([[35, 34, 4, 5, 8, 33, 35, 60, 33, 13, 12, 3, 26, 30, 41, 37, 7, 22, 3, 5, 9, 3, 27], [ 8, 33, 7, 4, 28, 37, 19, 51, 18, 4, 6, 6, 26, 18, 40, 20, 15, 5, 2, 4, 13, 3, 27], [16, 25, 5, 9, 38, 14, 8, 59, 44, 10, 2, 10, 8, 30, 38, 33, 6, 7, 21, 4, 6, 6, 27], [16, 36, 3, 13, 38, 18, 42, 57, 20, 9, 12, 12, 24, 34, 42, 17, 5, 2, 10, 2, 4, 13, 43], [16, 36, 12, 11, 8, 30, 24, 54, 37, 6, 7, 21, 39, 30, 38, 18, 4, 7, 4, 5, 1, 9, 43], [31, 23, 15, 7, 40, 36, 24, 50, 45, 2, 4, 6, 35, 20, 29, 17, 5, 15, 5, 2, 4, 13, 43], [16, 25, 6, 13, 61, 44, 42, 55, 47, 11, 9, 11, 8, 30, 41, 23, 9, 3, 7, 4, 6, 7, 49], [19, 46, 12, 11, 16, 34, 40, 54, 37, 7, 10, 2, 40, 25, 35, 20, 2, 4, 5, 15, 5, 2, 49]], device='cuda:0'), tensor([[ 99, 139, 14, 349, 66, 227, 350, 351, 253, 91, 16, 152, 80, 85, 199, 162, 24, 27, 39, 37, 15, 19, 2], [ 66, 107, 26, 297, 83, 239, 180, 298, 70, 12, 46, 61, 77, 172, 64, 95, 124, 18, 5, 41, 9, 19, 2], [ 51, 54, 37, 155, 163, 215, 176, 348, 252, 34, 7, 90, 44, 173, 138, 88, 20, 52, 6, 12, 46, 131, 2], [ 67, 119, 35, 220, 98, 240, 241, 174, 68, 36, 219, 168, 233, 242, 129, 42, 18, 7, 34, 5, 41, 109, 17], [ 67, 102, 10, 28, 44, 47, 169, 123, 104, 20, 52, 251, 339, 173, 98, 70, 114, 26, 14, 11, 45, 65, 17], [ 78, 228, 86, 294, 237, 116, 201, 238, 295, 5, 12, 159, 94, 296, 21, 42, 22, 124, 18, 5, 41, 109, 17], [ 51, 117, 96, 331, 332, 333, 334, 335, 105, 48, 49, 28, 44, 85, 69, 336, 15, 25, 26, 12, 20, 97, 55], [ 63, 140, 10, 160, 60, 161, 276, 123, 162, 33, 34, 110, 277, 278, 94, 118, 5, 14, 22, 124, 18, 213, 55]], device='cuda:0')]
inputs: tensor([[2, 2, 3, 2, 1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2, 1, 4, 2], [1, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4, 2], [1, 2, 2, 1, 4, 1, 1, 1, 3, 3, 1, 3, 1, 1, 4, 1, 2, 2, 3, 3, 2, 2, 2], [1, 2, 4, 2, 4, 1, 3, 3, 2, 1, 4, 4, 4, 2, 3, 3, 2, 1, 3, 1, 3, 2, 4], [1, 2, 4, 4, 1, 1, 4, 2, 4, 2, 2, 3, 3, 1, 4, 1, 3, 2, 3, 2, 1, 1, 4], [1, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4], [1, 2, 2, 2, 4, 3, 3, 3, 4, 4, 1, 4, 1, 1, 4, 4, 1, 4, 2, 3, 2, 2, 3], [2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3]], device='cuda:0')
inputs: [tensor([[1, 1, 2, 1, 2, 4, 1, 4, 1, 4, 1, 3, 4, 4, 2, 2, 2, 1, 4, 1, 3, 4, 1], [1, 2, 3, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4], [2, 1, 3, 1, 2, 4, 2, 4, 1, 2, 4, 1, 3, 2, 1, 2, 1, 2, 3, 1, 1, 4, 2], [1, 1, 4, 4, 1, 1, 1, 2, 3, 2, 2, 3, 2, 1, 1, 2, 1, 2, 2, 2, 1, 4, 2], [1, 1, 1, 3, 3, 2, 1, 2, 3, 1, 1, 3, 3, 1, 1, 3, 4, 2, 1, 1, 3, 1, 1], [1, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2, 4], [2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4, 2, 1, 4], [2, 3, 1, 4, 1, 1, 1, 3, 4, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1]]), tensor([[ 1, 8, 4, 8, 11, 10, 9, 10, 9, 10, 2, 15, 13, 3, 6, 6, 4, 9, 10, 2, 15, 10, 17], [ 8, 7, 5, 6, 4, 9, 13, 10, 2, 12, 1, 2, 12, 8, 7, 15, 13, 10, 8, 7, 5, 11, 18], [ 4, 2, 12, 8, 11, 3, 11, 10, 8, 11, 10, 2, 5, 4, 8, 4, 8, 7, 12, 1, 9, 3, 16], [ 1, 9, 13, 10, 1, 1, 8, 7, 5, 6, 7, 5, 4, 1, 8, 4, 8, 6, 6, 4, 9, 3, 16], [ 1, 1, 2, 14, 5, 4, 8, 7, 12, 1, 2, 14, 12, 1, 2, 15, 3, 4, 1, 2, 12, 1, 17], [ 8, 7, 12, 1, 1, 2, 5, 4, 9, 3, 11, 13, 10, 9, 3, 4, 8, 6, 7, 15, 3, 11, 18], [ 7, 5, 11, 3, 7, 12, 2, 5, 6, 6, 4, 2, 5, 4, 8, 4, 2, 5, 11, 3, 4, 9, 18], [ 7, 12, 9, 10, 1, 1, 2, 15, 10, 2, 14, 5, 7, 5, 11, 3, 6, 4, 1, 1, 1, 1, 17]]), tensor([[44, 48, 13, 39, 26, 42, 32, 42, 32, 20, 51, 41, 27, 15, 21, 9, 19, 32, 20, 51, 56, 69, 37], [16, 14, 33, 9, 19, 36, 10, 20, 6, 7, 2, 6, 52, 16, 30, 41, 10, 47, 16, 14, 35, 66, 53], [28, 6, 52, 39, 5, 12, 26, 47, 39, 26, 20, 3, 4, 13, 48, 13, 16, 23, 7, 18, 8, 54, 29], [18, 36, 10, 22, 1, 44, 16, 14, 33, 17, 14, 4, 24, 44, 48, 13, 25, 21, 9, 19, 8, 54, 29], [ 1, 2, 34, 11, 4, 13, 16, 23, 7, 2, 34, 57, 7, 2, 51, 46, 38, 24, 2, 6, 7, 55, 37], [16, 23, 7, 1, 2, 3, 4, 19, 8, 12, 45, 10, 42, 8, 38, 13, 25, 17, 30, 46, 12, 66, 53], [14, 35, 5, 31, 23, 43, 3, 33, 21, 9, 28, 3, 4, 13, 48, 28, 3, 35, 5, 38, 19, 61, 53], [23, 59, 32, 22, 1, 2, 51, 56, 20, 34, 11, 49, 14, 35, 5, 15, 9, 24, 1, 1, 1, 55, 37]]), tensor([[3, 1, 4, 1, 2, 2, 5, 4, 8, 9, 6, 4, 1, 4, 1, 4, 2, 1, 2, 1, 1, 7, 0], [3, 2, 1, 4, 4, 3, 5, 1, 9, 6, 6, 3, 1, 4, 4, 1, 2, 2, 3, 2, 1, 7, 0], [1, 1, 3, 2, 1, 2, 6, 2, 9, 6, 8, 2, 1, 4, 2, 4, 2, 1, 3, 1, 2, 7, 0], [1, 2, 2, 2, 1, 2, 6, 1, 5, 9, 5, 2, 3, 2, 1, 1, 1, 4, 4, 1, 1, 7, 0], [3, 1, 1, 2, 4, 3, 6, 1, 9, 9, 6, 1, 3, 2, 1, 2, 3, 3, 1, 1, 1, 7, 0], [4, 3, 2, 2, 1, 2, 8, 1, 8, 8, 5, 4, 1, 2, 3, 1, 1, 1, 3, 2, 1, 7, 0], [2, 4, 2, 3, 1, 2, 6, 2, 9, 6, 5, 2, 2, 3, 1, 3, 2, 4, 2, 3, 2, 7, 0], [1, 1, 1, 2, 2, 4, 5, 3, 5, 9, 9, 1, 4, 3, 1, 1, 1, 4, 1, 3, 2, 7, 0]]), tensor([[ 2, 11, 10, 3, 9, 38, 60, 53, 55, 20, 31, 10, 11, 10, 11, 12, 7, 3, 7, 1, 14, 8, 0], [ 4, 7, 11, 16, 17, 24, 54, 48, 20, 36, 56, 2, 11, 16, 10, 3, 9, 6, 4, 7, 14, 8, 0], [ 1, 13, 4, 7, 3, 39, 40, 47, 20, 45, 44, 7, 11, 12, 5, 12, 7, 13, 2, 3, 18, 8, 0], [ 3, 9, 9, 7, 3, 39, 29, 22, 32, 51, 27, 6, 4, 7, 1, 1, 11, 16, 10, 1, 14, 8, 0], [ 2, 1, 3, 5, 17, 37, 29, 48, 59, 20, 29, 13, 4, 7, 3, 6, 15, 2, 1, 1, 14, 8, 0], [17, 4, 9, 7, 3, 19, 21, 34, 25, 46, 60, 10, 3, 6, 2, 1, 1, 13, 4, 7, 14, 8, 0], [ 5, 12, 6, 2, 3, 39, 40, 47, 20, 52, 27, 9, 6, 2, 13, 4, 5, 12, 6, 4, 18, 8, 0], [ 1, 1, 3, 9, 5, 33, 23, 24, 32, 59, 30, 11, 17, 2, 1, 1, 11, 10, 13, 4, 18, 8, 0]]), tensor([[ 34, 32, 22, 24, 212, 213, 214, 215, 216, 132, 133, 47, 32, 47, 31, 37, 16, 45, 35, 12, 2, 1, 0], [ 10, 50, 23, 63, 114, 74, 197, 198, 115, 116, 117, 34, 23, 17, 22, 24, 56, 18, 10, 26, 2, 1, 0], [ 8, 13, 10, 16, 83, 77, 128, 84, 311, 312, 193, 50, 31, 27, 7, 37, 54, 6, 29, 160, 25, 1, 0], [ 24, 46, 19, 16, 83, 269, 270, 107, 271, 185, 186, 18, 10, 35, 3, 14, 23, 17, 20, 12, 2, 1, 0], [ 4, 30, 42, 41, 136, 57, 151, 189, 284, 71, 176, 13, 10, 16, 5, 21, 38, 4, 3, 12, 2, 1, 0], [ 40, 28, 19, 16, 78, 105, 171, 75, 287, 288, 181, 22, 5, 9, 4, 3, 8, 13, 10, 26, 2, 1, 0], [ 7, 39, 9, 29, 83, 77, 128, 84, 129, 85, 130, 56, 9, 44, 13, 33, 7, 39, 18, 49, 25, 1, 0], [ 3, 30, 24, 15, 109, 110, 72, 73, 111, 112, 60, 52, 53, 4, 3, 14, 32, 61, 13, 49, 25, 1, 0]]), tensor([[1, 5, 2, 1, 2, 8, 1, 8, 5, 3, 1, 4, 3, 8, 2, 6, 2, 1, 3, 1, 4, 3, 1], [1, 6, 4, 2, 2, 5, 3, 8, 5, 4, 1, 1, 4, 5, 2, 7, 3, 3, 1, 2, 4, 2, 3], [2, 5, 4, 1, 2, 8, 2, 8, 5, 2, 3, 1, 4, 6, 1, 6, 1, 2, 4, 1, 1, 3, 2], [1, 5, 3, 3, 1, 5, 1, 6, 7, 2, 2, 4, 2, 5, 1, 6, 1, 2, 2, 2, 1, 3, 2], [1, 5, 1, 4, 4, 6, 1, 6, 7, 1, 1, 4, 4, 5, 1, 7, 3, 2, 1, 1, 4, 1, 1], [1, 6, 4, 1, 1, 5, 4, 6, 5, 3, 2, 3, 3, 5, 3, 6, 1, 2, 2, 4, 3, 2, 3], [2, 7, 2, 3, 2, 7, 1, 7, 6, 2, 2, 1, 4, 6, 1, 6, 1, 4, 2, 3, 2, 1, 3], [2, 7, 1, 3, 1, 5, 1, 7, 8, 1, 4, 4, 2, 7, 2, 8, 2, 2, 1, 1, 1, 1, 1]]), tensor([[ 8, 33, 5, 15, 28, 23, 31, 52, 30, 11, 2, 22, 41, 37, 35, 25, 5, 9, 11, 2, 22, 11, 32], [16, 34, 4, 6, 26, 30, 41, 52, 18, 10, 1, 2, 39, 33, 19, 46, 12, 11, 15, 7, 4, 13, 43], [26, 18, 10, 15, 28, 37, 28, 52, 33, 13, 11, 2, 40, 20, 16, 20, 15, 7, 10, 1, 9, 3, 27], [ 8, 30, 12, 11, 8, 14, 16, 50, 17, 6, 7, 4, 26, 14, 16, 20, 15, 6, 6, 5, 9, 3, 27], [ 8, 14, 2, 21, 40, 20, 16, 50, 45, 1, 2, 21, 39, 14, 29, 46, 3, 5, 1, 2, 10, 1, 32], [16, 34, 10, 1, 8, 18, 40, 60, 30, 3, 13, 12, 38, 30, 24, 20, 15, 6, 7, 22, 3, 13, 43], [19, 17, 13, 3, 19, 45, 29, 57, 25, 6, 5, 2, 40, 20, 16, 20, 2, 4, 13, 3, 5, 9, 43], [19, 45, 9, 11, 8, 14, 29, 55, 23, 2, 21, 4, 19, 17, 28, 37, 6, 5, 1, 1, 1, 1, 32]]), tensor([[ 66, 270, 22, 53, 43, 120, 271, 272, 198, 40, 121, 192, 199, 150, 93, 54, 37, 49, 40, 121, 200, 144, 3], [ 60, 139, 12, 61, 80, 85, 257, 185, 62, 13, 4, 106, 258, 108, 63, 140, 10, 141, 86, 26, 41, 109, 17], [ 77, 62, 100, 53, 83, 229, 149, 356, 253, 113, 40, 110, 64, 112, 71, 95, 86, 33, 13, 45, 15, 19, 2], [ 44, 317, 10, 28, 8, 250, 177, 318, 175, 20, 26, 101, 125, 250, 71, 95, 75, 46, 23, 37, 15, 19, 2], [ 8, 127, 50, 329, 64, 112, 177, 238, 130, 4, 50, 251, 58, 29, 207, 330, 39, 11, 4, 7, 13, 31, 3], [ 60, 128, 13, 178, 38, 172, 337, 338, 246, 35, 91, 137, 156, 47, 111, 95, 75, 20, 24, 27, 35, 109, 17], [ 30, 84, 9, 190, 81, 266, 196, 267, 117, 23, 18, 110, 64, 112, 71, 118, 5, 41, 9, 39, 37, 65, 17], [ 81, 182, 49, 28, 8, 29, 183, 134, 82, 50, 6, 103, 30, 135, 83, 104, 23, 11, 1, 1, 1, 31, 3]])]
inputs: tensor([[1, 1, 2, 1, 2, 4, 1, 4, 1, 4, 1, 3, 4, 4, 2, 2, 2, 1, 4, 1, 3, 4, 1], [1, 2, 3, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4], [2, 1, 3, 1, 2, 4, 2, 4, 1, 2, 4, 1, 3, 2, 1, 2, 1, 2, 3, 1, 1, 4, 2], [1, 1, 4, 4, 1, 1, 1, 2, 3, 2, 2, 3, 2, 1, 1, 2, 1, 2, 2, 2, 1, 4, 2], [1, 1, 1, 3, 3, 2, 1, 2, 3, 1, 1, 3, 3, 1, 1, 3, 4, 2, 1, 1, 3, 1, 1], [1, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2, 4], [2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4, 2, 1, 4], [2, 3, 1, 4, 1, 1, 1, 3, 4, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1]])
inputs: [tensor([[1, 1, 2, 1, 2, 4, 1, 4, 1, 4, 1, 3, 4, 4, 2, 2, 2, 1, 4, 1, 3, 4, 1], [1, 2, 3, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4], [2, 1, 3, 1, 2, 4, 2, 4, 1, 2, 4, 1, 3, 2, 1, 2, 1, 2, 3, 1, 1, 4, 2], [1, 1, 4, 4, 1, 1, 1, 2, 3, 2, 2, 3, 2, 1, 1, 2, 1, 2, 2, 2, 1, 4, 2], [1, 1, 1, 3, 3, 2, 1, 2, 3, 1, 1, 3, 3, 1, 1, 3, 4, 2, 1, 1, 3, 1, 1], [1, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2, 4], [2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4, 2, 1, 4], [2, 3, 1, 4, 1, 1, 1, 3, 4, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1]], device='cuda:0'), tensor([[ 1, 8, 4, 8, 11, 10, 9, 10, 9, 10, 2, 15, 13, 3, 6, 6, 4, 9, 10, 2, 15, 10, 17], [ 8, 7, 5, 6, 4, 9, 13, 10, 2, 12, 1, 2, 12, 8, 7, 15, 13, 10, 8, 7, 5, 11, 18], [ 4, 2, 12, 8, 11, 3, 11, 10, 8, 11, 10, 2, 5, 4, 8, 4, 8, 7, 12, 1, 9, 3, 16], [ 1, 9, 13, 10, 1, 1, 8, 7, 5, 6, 7, 5, 4, 1, 8, 4, 8, 6, 6, 4, 9, 3, 16], [ 1, 1, 2, 14, 5, 4, 8, 7, 12, 1, 2, 14, 12, 1, 2, 15, 3, 4, 1, 2, 12, 1, 17], [ 8, 7, 12, 1, 1, 2, 5, 4, 9, 3, 11, 13, 10, 9, 3, 4, 8, 6, 7, 15, 3, 11, 18], [ 7, 5, 11, 3, 7, 12, 2, 5, 6, 6, 4, 2, 5, 4, 8, 4, 2, 5, 11, 3, 4, 9, 18], [ 7, 12, 9, 10, 1, 1, 2, 15, 10, 2, 14, 5, 7, 5, 11, 3, 6, 4, 1, 1, 1, 1, 17]], device='cuda:0'), tensor([[44, 48, 13, 39, 26, 42, 32, 42, 32, 20, 51, 41, 27, 15, 21, 9, 19, 32, 20, 51, 56, 69, 37], [16, 14, 33, 9, 19, 36, 10, 20, 6, 7, 2, 6, 52, 16, 30, 41, 10, 47, 16, 14, 35, 66, 53], [28, 6, 52, 39, 5, 12, 26, 47, 39, 26, 20, 3, 4, 13, 48, 13, 16, 23, 7, 18, 8, 54, 29], [18, 36, 10, 22, 1, 44, 16, 14, 33, 17, 14, 4, 24, 44, 48, 13, 25, 21, 9, 19, 8, 54, 29], [ 1, 2, 34, 11, 4, 13, 16, 23, 7, 2, 34, 57, 7, 2, 51, 46, 38, 24, 2, 6, 7, 55, 37], [16, 23, 7, 1, 2, 3, 4, 19, 8, 12, 45, 10, 42, 8, 38, 13, 25, 17, 30, 46, 12, 66, 53], [14, 35, 5, 31, 23, 43, 3, 33, 21, 9, 28, 3, 4, 13, 48, 28, 3, 35, 5, 38, 19, 61, 53], [23, 59, 32, 22, 1, 2, 51, 56, 20, 34, 11, 49, 14, 35, 5, 15, 9, 24, 1, 1, 1, 55, 37]], device='cuda:0'), tensor([[3, 1, 4, 1, 2, 2, 5, 4, 8, 9, 6, 4, 1, 4, 1, 4, 2, 1, 2, 1, 1, 7, 0], [3, 2, 1, 4, 4, 3, 5, 1, 9, 6, 6, 3, 1, 4, 4, 1, 2, 2, 3, 2, 1, 7, 0], [1, 1, 3, 2, 1, 2, 6, 2, 9, 6, 8, 2, 1, 4, 2, 4, 2, 1, 3, 1, 2, 7, 0], [1, 2, 2, 2, 1, 2, 6, 1, 5, 9, 5, 2, 3, 2, 1, 1, 1, 4, 4, 1, 1, 7, 0], [3, 1, 1, 2, 4, 3, 6, 1, 9, 9, 6, 1, 3, 2, 1, 2, 3, 3, 1, 1, 1, 7, 0], [4, 3, 2, 2, 1, 2, 8, 1, 8, 8, 5, 4, 1, 2, 3, 1, 1, 1, 3, 2, 1, 7, 0], [2, 4, 2, 3, 1, 2, 6, 2, 9, 6, 5, 2, 2, 3, 1, 3, 2, 4, 2, 3, 2, 7, 0], [1, 1, 1, 2, 2, 4, 5, 3, 5, 9, 9, 1, 4, 3, 1, 1, 1, 4, 1, 3, 2, 7, 0]], device='cuda:0'), tensor([[ 2, 11, 10, 3, 9, 38, 60, 53, 55, 20, 31, 10, 11, 10, 11, 12, 7, 3, 7, 1, 14, 8, 0], [ 4, 7, 11, 16, 17, 24, 54, 48, 20, 36, 56, 2, 11, 16, 10, 3, 9, 6, 4, 7, 14, 8, 0], [ 1, 13, 4, 7, 3, 39, 40, 47, 20, 45, 44, 7, 11, 12, 5, 12, 7, 13, 2, 3, 18, 8, 0], [ 3, 9, 9, 7, 3, 39, 29, 22, 32, 51, 27, 6, 4, 7, 1, 1, 11, 16, 10, 1, 14, 8, 0], [ 2, 1, 3, 5, 17, 37, 29, 48, 59, 20, 29, 13, 4, 7, 3, 6, 15, 2, 1, 1, 14, 8, 0], [17, 4, 9, 7, 3, 19, 21, 34, 25, 46, 60, 10, 3, 6, 2, 1, 1, 13, 4, 7, 14, 8, 0], [ 5, 12, 6, 2, 3, 39, 40, 47, 20, 52, 27, 9, 6, 2, 13, 4, 5, 12, 6, 4, 18, 8, 0], [ 1, 1, 3, 9, 5, 33, 23, 24, 32, 59, 30, 11, 17, 2, 1, 1, 11, 10, 13, 4, 18, 8, 0]], device='cuda:0'), tensor([[ 34, 32, 22, 24, 212, 213, 214, 215, 216, 132, 133, 47, 32, 47, 31, 37, 16, 45, 35, 12, 2, 1, 0], [ 10, 50, 23, 63, 114, 74, 197, 198, 115, 116, 117, 34, 23, 17, 22, 24, 56, 18, 10, 26, 2, 1, 0], [ 8, 13, 10, 16, 83, 77, 128, 84, 311, 312, 193, 50, 31, 27, 7, 37, 54, 6, 29, 160, 25, 1, 0], [ 24, 46, 19, 16, 83, 269, 270, 107, 271, 185, 186, 18, 10, 35, 3, 14, 23, 17, 20, 12, 2, 1, 0], [ 4, 30, 42, 41, 136, 57, 151, 189, 284, 71, 176, 13, 10, 16, 5, 21, 38, 4, 3, 12, 2, 1, 0], [ 40, 28, 19, 16, 78, 105, 171, 75, 287, 288, 181, 22, 5, 9, 4, 3, 8, 13, 10, 26, 2, 1, 0], [ 7, 39, 9, 29, 83, 77, 128, 84, 129, 85, 130, 56, 9, 44, 13, 33, 7, 39, 18, 49, 25, 1, 0], [ 3, 30, 24, 15, 109, 110, 72, 73, 111, 112, 60, 52, 53, 4, 3, 14, 32, 61, 13, 49, 25, 1, 0]], device='cuda:0'), tensor([[1, 5, 2, 1, 2, 8, 1, 8, 5, 3, 1, 4, 3, 8, 2, 6, 2, 1, 3, 1, 4, 3, 1], [1, 6, 4, 2, 2, 5, 3, 8, 5, 4, 1, 1, 4, 5, 2, 7, 3, 3, 1, 2, 4, 2, 3], [2, 5, 4, 1, 2, 8, 2, 8, 5, 2, 3, 1, 4, 6, 1, 6, 1, 2, 4, 1, 1, 3, 2], [1, 5, 3, 3, 1, 5, 1, 6, 7, 2, 2, 4, 2, 5, 1, 6, 1, 2, 2, 2, 1, 3, 2], [1, 5, 1, 4, 4, 6, 1, 6, 7, 1, 1, 4, 4, 5, 1, 7, 3, 2, 1, 1, 4, 1, 1], [1, 6, 4, 1, 1, 5, 4, 6, 5, 3, 2, 3, 3, 5, 3, 6, 1, 2, 2, 4, 3, 2, 3], [2, 7, 2, 3, 2, 7, 1, 7, 6, 2, 2, 1, 4, 6, 1, 6, 1, 4, 2, 3, 2, 1, 3], [2, 7, 1, 3, 1, 5, 1, 7, 8, 1, 4, 4, 2, 7, 2, 8, 2, 2, 1, 1, 1, 1, 1]], device='cuda:0'), tensor([[ 8, 33, 5, 15, 28, 23, 31, 52, 30, 11, 2, 22, 41, 37, 35, 25, 5, 9, 11, 2, 22, 11, 32], [16, 34, 4, 6, 26, 30, 41, 52, 18, 10, 1, 2, 39, 33, 19, 46, 12, 11, 15, 7, 4, 13, 43], [26, 18, 10, 15, 28, 37, 28, 52, 33, 13, 11, 2, 40, 20, 16, 20, 15, 7, 10, 1, 9, 3, 27], [ 8, 30, 12, 11, 8, 14, 16, 50, 17, 6, 7, 4, 26, 14, 16, 20, 15, 6, 6, 5, 9, 3, 27], [ 8, 14, 2, 21, 40, 20, 16, 50, 45, 1, 2, 21, 39, 14, 29, 46, 3, 5, 1, 2, 10, 1, 32], [16, 34, 10, 1, 8, 18, 40, 60, 30, 3, 13, 12, 38, 30, 24, 20, 15, 6, 7, 22, 3, 13, 43], [19, 17, 13, 3, 19, 45, 29, 57, 25, 6, 5, 2, 40, 20, 16, 20, 2, 4, 13, 3, 5, 9, 43], [19, 45, 9, 11, 8, 14, 29, 55, 23, 2, 21, 4, 19, 17, 28, 37, 6, 5, 1, 1, 1, 1, 32]], device='cuda:0'), tensor([[ 66, 270, 22, 53, 43, 120, 271, 272, 198, 40, 121, 192, 199, 150, 93, 54, 37, 49, 40, 121, 200, 144, 3], [ 60, 139, 12, 61, 80, 85, 257, 185, 62, 13, 4, 106, 258, 108, 63, 140, 10, 141, 86, 26, 41, 109, 17], [ 77, 62, 100, 53, 83, 229, 149, 356, 253, 113, 40, 110, 64, 112, 71, 95, 86, 33, 13, 45, 15, 19, 2], [ 44, 317, 10, 28, 8, 250, 177, 318, 175, 20, 26, 101, 125, 250, 71, 95, 75, 46, 23, 37, 15, 19, 2], [ 8, 127, 50, 329, 64, 112, 177, 238, 130, 4, 50, 251, 58, 29, 207, 330, 39, 11, 4, 7, 13, 31, 3], [ 60, 128, 13, 178, 38, 172, 337, 338, 246, 35, 91, 137, 156, 47, 111, 95, 75, 20, 24, 27, 35, 109, 17], [ 30, 84, 9, 190, 81, 266, 196, 267, 117, 23, 18, 110, 64, 112, 71, 118, 5, 41, 9, 39, 37, 65, 17], [ 81, 182, 49, 28, 8, 29, 183, 134, 82, 50, 6, 103, 30, 135, 83, 104, 23, 11, 1, 1, 1, 31, 3]], device='cuda:0')]
inputs: tensor([[1, 1, 2, 1, 2, 4, 1, 4, 1, 4, 1, 3, 4, 4, 2, 2, 2, 1, 4, 1, 3, 4, 1], [1, 2, 3, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4], [2, 1, 3, 1, 2, 4, 2, 4, 1, 2, 4, 1, 3, 2, 1, 2, 1, 2, 3, 1, 1, 4, 2], [1, 1, 4, 4, 1, 1, 1, 2, 3, 2, 2, 3, 2, 1, 1, 2, 1, 2, 2, 2, 1, 4, 2], [1, 1, 1, 3, 3, 2, 1, 2, 3, 1, 1, 3, 3, 1, 1, 3, 4, 2, 1, 1, 3, 1, 1], [1, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2, 2, 3, 4, 2, 4], [2, 3, 2, 4, 2, 3, 1, 3, 2, 2, 2, 1, 3, 2, 1, 2, 1, 3, 2, 4, 2, 1, 4], [2, 3, 1, 4, 1, 1, 1, 3, 4, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1]], device='cuda:0')
inputs: [tensor([[1, 3, 2, 1, 2, 1, 2, 3, 4, 1, 3, 1, 3, 3, 2, 1, 1, 1, 1, 2, 4, 2, 2], [2, 3, 3, 4, 2, 1, 3, 1, 1, 4, 2, 3, 1, 3, 2, 2, 3, 3, 2, 3, 4, 4, 2], [2, 3, 1, 4, 1, 1, 1, 3, 4, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1], [1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2], [1, 4, 2, 4, 1, 1, 4, 4, 2, 2, 1, 1, 2, 4, 1, 4, 3, 3, 1, 1, 1, 1, 1], [1, 1, 2, 2, 4, 1, 1, 4, 2, 3, 3, 2, 3, 1, 1, 3, 2, 1, 3, 1, 1, 1, 1], [2, 2, 2, 1, 3, 4, 2, 2, 3, 3, 2, 2, 4, 2, 4, 1, 2, 2, 3, 3, 2, 2, 3], [1, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3]]), tensor([[ 2, 5, 4, 8, 4, 8, 7, 15, 10, 2, 12, 2, 14, 5, 4, 1, 1, 1, 8, 11, 3, 6, 16], [ 7, 14, 15, 3, 4, 2, 12, 1, 9, 3, 7, 12, 2, 5, 6, 7, 14, 5, 7, 15, 13, 3, 16], [ 7, 12, 9, 10, 1, 1, 2, 15, 10, 2, 14, 5, 7, 5, 11, 3, 6, 4, 1, 1, 1, 1, 17], [ 9, 13, 10, 2, 12, 1, 2, 12, 8, 7, 15, 13, 10, 8, 7, 5, 11, 3, 7, 12, 2, 5, 16], [ 9, 3, 11, 10, 1, 9, 13, 3, 6, 4, 1, 8, 11, 10, 9, 20, 14, 12, 1, 1, 1, 1, 17], [ 1, 8, 6, 11, 10, 1, 9, 3, 7, 14, 5, 7, 12, 1, 2, 5, 4, 2, 12, 1, 1, 1, 17], [ 6, 6, 4, 2, 15, 3, 6, 7, 14, 5, 6, 11, 3, 11, 10, 8, 6, 7, 14, 5, 6, 7, 19], [ 8, 6, 4, 9, 13, 10, 2, 12, 1, 2, 12, 8, 7, 15, 13, 10, 8, 7, 5, 11, 3, 7, 19]]), tensor([[ 3, 4, 13, 48, 13, 16, 30, 56, 20, 6, 43, 34, 11, 4, 24, 1, 1, 44, 39, 5, 15, 68, 29], [40, 62, 46, 38, 28, 6, 7, 18, 8, 31, 23, 43, 3, 33, 17, 40, 11, 49, 30, 41, 27, 54, 29], [23, 59, 32, 22, 1, 2, 51, 56, 20, 34, 11, 49, 14, 35, 5, 15, 9, 24, 1, 1, 1, 55, 37], [36, 10, 20, 6, 7, 2, 6, 52, 16, 30, 41, 10, 47, 16, 14, 35, 5, 31, 23, 43, 3, 60, 29], [ 8, 12, 26, 22, 18, 36, 27, 15, 9, 24, 44, 39, 26, 42, 70, 67, 57, 7, 1, 1, 1, 55, 37], [44, 25, 50, 26, 22, 18, 8, 31, 40, 11, 49, 23, 7, 2, 3, 4, 28, 6, 7, 1, 1, 55, 37], [21, 9, 28, 51, 46, 15, 17, 40, 11, 33, 50, 5, 12, 26, 47, 25, 17, 40, 11, 33, 17, 65, 58], [25, 9, 19, 36, 10, 20, 6, 7, 2, 6, 52, 16, 30, 41, 10, 47, 16, 14, 35, 5, 31, 65, 58]]), tensor([[4, 2, 1, 1, 1, 1, 5, 3, 9, 6, 9, 1, 4, 3, 2, 1, 2, 1, 2, 3, 1, 7, 0], [4, 3, 2, 3, 3, 2, 5, 3, 6, 9, 5, 4, 1, 1, 3, 1, 2, 4, 3, 3, 2, 7, 0], [1, 1, 1, 2, 2, 4, 5, 3, 5, 9, 9, 1, 4, 3, 1, 1, 1, 4, 1, 3, 2, 7, 0], [1, 3, 2, 4, 2, 3, 5, 1, 8, 8, 9, 2, 1, 3, 1, 1, 3, 1, 4, 4, 1, 7, 0], [1, 1, 1, 3, 3, 4, 6, 4, 5, 6, 6, 2, 2, 4, 4, 1, 1, 4, 2, 4, 1, 7, 0], [1, 1, 3, 1, 2, 3, 6, 1, 9, 5, 9, 3, 2, 4, 1, 1, 4, 2, 2, 1, 1, 7, 0], [2, 3, 3, 2, 2, 1, 8, 2, 8, 5, 5, 3, 3, 2, 2, 4, 3, 1, 2, 2, 2, 7, 0], [4, 2, 3, 2, 1, 4, 8, 3, 5, 6, 9, 1, 1, 3, 1, 4, 4, 1, 2, 2, 1, 7, 0]]), tensor([[12, 7, 1, 1, 1, 22, 23, 57, 20, 42, 30, 11, 17, 4, 7, 3, 7, 3, 6, 2, 14, 8, 0], [17, 4, 6, 15, 4, 38, 23, 37, 42, 51, 60, 10, 1, 13, 2, 3, 5, 17, 15, 4, 18, 8, 0], [ 1, 1, 3, 9, 5, 33, 23, 24, 32, 59, 30, 11, 17, 2, 1, 1, 11, 10, 13, 4, 18, 8, 0], [13, 4, 5, 12, 6, 24, 54, 34, 25, 55, 35, 7, 13, 2, 1, 13, 2, 11, 16, 10, 14, 8, 0], [ 1, 1, 13, 15, 61, 58, 31, 33, 43, 36, 40, 9, 5, 16, 10, 1, 11, 12, 5, 10, 14, 8, 0], [ 1, 13, 2, 3, 6, 37, 29, 48, 51, 32, 63, 4, 5, 10, 1, 11, 12, 9, 7, 1, 14, 8, 0], [ 6, 15, 4, 9, 7, 34, 44, 19, 46, 28, 23, 15, 4, 9, 5, 17, 2, 3, 9, 9, 18, 8, 0], [12, 6, 4, 7, 11, 53, 50, 24, 43, 42, 30, 1, 13, 2, 11, 16, 10, 3, 9, 7, 14, 8, 0]]), tensor([[ 37, 35, 3, 3, 187, 188, 182, 86, 145, 100, 60, 52, 40, 10, 16, 45, 16, 5, 9, 65, 2, 1, 0], [ 40, 55, 21, 36, 157, 256, 257, 258, 126, 259, 181, 20, 8, 6, 29, 42, 41, 66, 36, 49, 25, 1, 0], [ 3, 30, 24, 15, 109, 110, 72, 73, 111, 112, 60, 52, 53, 4, 3, 14, 32, 61, 13, 49, 25, 1, 0], [ 13, 33, 7, 39, 113, 74, 195, 75, 62, 76, 196, 54, 6, 4, 8, 6, 34, 23, 17, 48, 2, 1, 0], [ 3, 8, 68, 69, 154, 94, 67, 95, 155, 96, 97, 15, 51, 17, 20, 14, 31, 27, 11, 48, 2, 1, 0], [ 8, 6, 29, 5, 64, 57, 151, 229, 230, 143, 144, 33, 11, 20, 14, 31, 43, 19, 35, 12, 2, 1, 0], [ 21, 36, 28, 19, 152, 91, 92, 127, 231, 153, 135, 36, 28, 15, 41, 53, 29, 24, 46, 93, 25, 1, 0], [ 39, 18, 10, 50, 161, 234, 122, 235, 162, 100, 236, 8, 6, 34, 23, 17, 22, 24, 19, 26, 2, 1, 0]]), tensor([[1, 7, 2, 1, 2, 5, 2, 7, 8, 1, 4, 1, 4, 7, 2, 5, 1, 1, 1, 2, 3, 2, 2], [2, 7, 4, 3, 2, 5, 4, 5, 5, 3, 2, 4, 1, 7, 2, 6, 4, 4, 2, 4, 3, 3, 2], [2, 7, 1, 3, 1, 5, 1, 7, 8, 1, 4, 4, 2, 7, 2, 8, 2, 2, 1, 1, 1, 1, 1], [1, 8, 3, 1, 4, 5, 1, 7, 5, 2, 4, 3, 3, 5, 2, 7, 2, 3, 2, 4, 1, 4, 2], [1, 8, 2, 3, 1, 5, 3, 8, 6, 2, 1, 1, 2, 8, 1, 8, 4, 4, 1, 1, 1, 1, 1], [1, 5, 2, 2, 3, 5, 1, 8, 6, 4, 4, 2, 4, 5, 1, 7, 2, 1, 4, 1, 1, 1, 1], [2, 6, 2, 1, 4, 8, 2, 6, 7, 4, 2, 2, 3, 6, 3, 5, 2, 2, 4, 4, 2, 2, 4], [1, 6, 2, 1, 3, 8, 1, 7, 5, 1, 4, 1, 2, 7, 3, 8, 1, 2, 4, 2, 3, 2, 4]]), tensor([[29, 17, 5, 15, 26, 33, 19, 55, 23, 2, 10, 2, 42, 17, 26, 14, 1, 1, 15, 13, 3, 6, 27], [19, 44, 22, 3, 26, 18, 39, 56, 30, 3, 7, 10, 29, 17, 35, 34, 21, 4, 7, 22, 12, 3, 27], [19, 45, 9, 11, 8, 14, 29, 55, 23, 2, 21, 4, 19, 17, 28, 37, 6, 5, 1, 1, 1, 1, 32], [31, 47, 11, 2, 39, 14, 29, 51, 33, 7, 22, 12, 38, 33, 19, 17, 13, 3, 7, 10, 2, 4, 27], [31, 37, 13, 11, 8, 30, 41, 48, 25, 5, 1, 15, 28, 23, 31, 63, 21, 10, 1, 1, 1, 1, 32], [ 8, 33, 6, 13, 38, 14, 31, 48, 34, 21, 4, 7, 39, 14, 29, 17, 5, 2, 10, 1, 1, 1, 32], [35, 25, 5, 2, 62, 37, 35, 50, 44, 4, 6, 13, 24, 36, 38, 33, 6, 7, 21, 4, 6, 7, 49], [16, 25, 5, 9, 41, 23, 29, 51, 14, 2, 10, 15, 19, 46, 41, 23, 15, 7, 4, 13, 3, 7, 49]]), tensor([[ 21, 42, 22, 203, 92, 108, 327, 134, 82, 7, 34, 72, 129, 132, 125, 79, 1, 126, 216, 9, 56, 131, 2], [211, 212, 27, 152, 77, 76, 309, 310, 246, 25, 33, 247, 21, 311, 99, 158, 6, 114, 24, 32, 16, 19, 2], [ 81, 182, 49, 28, 8, 29, 183, 134, 82, 50, 6, 103, 30, 135, 83, 104, 23, 11, 1, 1, 1, 31, 3], [136, 105, 40, 106, 58, 29, 184, 256, 107, 24, 32, 137, 138, 108, 30, 84, 9, 25, 33, 34, 5, 59, 2], [ 74, 165, 113, 28, 44, 85, 142, 166, 54, 11, 126, 53, 43, 120, 221, 222, 223, 13, 1, 1, 1, 31, 3], [ 66, 88, 96, 220, 163, 285, 187, 286, 158, 6, 114, 287, 58, 29, 21, 42, 18, 7, 13, 1, 1, 31, 3], [ 93, 54, 18, 288, 289, 150, 290, 202, 153, 12, 96, 195, 57, 164, 138, 88, 20, 52, 6, 12, 20, 97, 55], [ 51, 54, 37, 291, 69, 157, 184, 181, 127, 7, 100, 292, 63, 293, 69, 228, 86, 26, 41, 9, 25, 97, 55]])]
inputs: tensor([[1, 3, 2, 1, 2, 1, 2, 3, 4, 1, 3, 1, 3, 3, 2, 1, 1, 1, 1, 2, 4, 2, 2], [2, 3, 3, 4, 2, 1, 3, 1, 1, 4, 2, 3, 1, 3, 2, 2, 3, 3, 2, 3, 4, 4, 2], [2, 3, 1, 4, 1, 1, 1, 3, 4, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1], [1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2], [1, 4, 2, 4, 1, 1, 4, 4, 2, 2, 1, 1, 2, 4, 1, 4, 3, 3, 1, 1, 1, 1, 1], [1, 1, 2, 2, 4, 1, 1, 4, 2, 3, 3, 2, 3, 1, 1, 3, 2, 1, 3, 1, 1, 1, 1], [2, 2, 2, 1, 3, 4, 2, 2, 3, 3, 2, 2, 4, 2, 4, 1, 2, 2, 3, 3, 2, 2, 3], [1, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3]])
inputs: [tensor([[1, 3, 2, 1, 2, 1, 2, 3, 4, 1, 3, 1, 3, 3, 2, 1, 1, 1, 1, 2, 4, 2, 2], [2, 3, 3, 4, 2, 1, 3, 1, 1, 4, 2, 3, 1, 3, 2, 2, 3, 3, 2, 3, 4, 4, 2], [2, 3, 1, 4, 1, 1, 1, 3, 4, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1], [1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2], [1, 4, 2, 4, 1, 1, 4, 4, 2, 2, 1, 1, 2, 4, 1, 4, 3, 3, 1, 1, 1, 1, 1], [1, 1, 2, 2, 4, 1, 1, 4, 2, 3, 3, 2, 3, 1, 1, 3, 2, 1, 3, 1, 1, 1, 1], [2, 2, 2, 1, 3, 4, 2, 2, 3, 3, 2, 2, 4, 2, 4, 1, 2, 2, 3, 3, 2, 2, 3], [1, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3]], device='cuda:0'), tensor([[ 2, 5, 4, 8, 4, 8, 7, 15, 10, 2, 12, 2, 14, 5, 4, 1, 1, 1, 8, 11, 3, 6, 16], [ 7, 14, 15, 3, 4, 2, 12, 1, 9, 3, 7, 12, 2, 5, 6, 7, 14, 5, 7, 15, 13, 3, 16], [ 7, 12, 9, 10, 1, 1, 2, 15, 10, 2, 14, 5, 7, 5, 11, 3, 6, 4, 1, 1, 1, 1, 17], [ 9, 13, 10, 2, 12, 1, 2, 12, 8, 7, 15, 13, 10, 8, 7, 5, 11, 3, 7, 12, 2, 5, 16], [ 9, 3, 11, 10, 1, 9, 13, 3, 6, 4, 1, 8, 11, 10, 9, 20, 14, 12, 1, 1, 1, 1, 17], [ 1, 8, 6, 11, 10, 1, 9, 3, 7, 14, 5, 7, 12, 1, 2, 5, 4, 2, 12, 1, 1, 1, 17], [ 6, 6, 4, 2, 15, 3, 6, 7, 14, 5, 6, 11, 3, 11, 10, 8, 6, 7, 14, 5, 6, 7, 19], [ 8, 6, 4, 9, 13, 10, 2, 12, 1, 2, 12, 8, 7, 15, 13, 10, 8, 7, 5, 11, 3, 7, 19]], device='cuda:0'), tensor([[ 3, 4, 13, 48, 13, 16, 30, 56, 20, 6, 43, 34, 11, 4, 24, 1, 1, 44, 39, 5, 15, 68, 29], [40, 62, 46, 38, 28, 6, 7, 18, 8, 31, 23, 43, 3, 33, 17, 40, 11, 49, 30, 41, 27, 54, 29], [23, 59, 32, 22, 1, 2, 51, 56, 20, 34, 11, 49, 14, 35, 5, 15, 9, 24, 1, 1, 1, 55, 37], [36, 10, 20, 6, 7, 2, 6, 52, 16, 30, 41, 10, 47, 16, 14, 35, 5, 31, 23, 43, 3, 60, 29], [ 8, 12, 26, 22, 18, 36, 27, 15, 9, 24, 44, 39, 26, 42, 70, 67, 57, 7, 1, 1, 1, 55, 37], [44, 25, 50, 26, 22, 18, 8, 31, 40, 11, 49, 23, 7, 2, 3, 4, 28, 6, 7, 1, 1, 55, 37], [21, 9, 28, 51, 46, 15, 17, 40, 11, 33, 50, 5, 12, 26, 47, 25, 17, 40, 11, 33, 17, 65, 58], [25, 9, 19, 36, 10, 20, 6, 7, 2, 6, 52, 16, 30, 41, 10, 47, 16, 14, 35, 5, 31, 65, 58]], device='cuda:0'), tensor([[4, 2, 1, 1, 1, 1, 5, 3, 9, 6, 9, 1, 4, 3, 2, 1, 2, 1, 2, 3, 1, 7, 0], [4, 3, 2, 3, 3, 2, 5, 3, 6, 9, 5, 4, 1, 1, 3, 1, 2, 4, 3, 3, 2, 7, 0], [1, 1, 1, 2, 2, 4, 5, 3, 5, 9, 9, 1, 4, 3, 1, 1, 1, 4, 1, 3, 2, 7, 0], [1, 3, 2, 4, 2, 3, 5, 1, 8, 8, 9, 2, 1, 3, 1, 1, 3, 1, 4, 4, 1, 7, 0], [1, 1, 1, 3, 3, 4, 6, 4, 5, 6, 6, 2, 2, 4, 4, 1, 1, 4, 2, 4, 1, 7, 0], [1, 1, 3, 1, 2, 3, 6, 1, 9, 5, 9, 3, 2, 4, 1, 1, 4, 2, 2, 1, 1, 7, 0], [2, 3, 3, 2, 2, 1, 8, 2, 8, 5, 5, 3, 3, 2, 2, 4, 3, 1, 2, 2, 2, 7, 0], [4, 2, 3, 2, 1, 4, 8, 3, 5, 6, 9, 1, 1, 3, 1, 4, 4, 1, 2, 2, 1, 7, 0]], device='cuda:0'), tensor([[12, 7, 1, 1, 1, 22, 23, 57, 20, 42, 30, 11, 17, 4, 7, 3, 7, 3, 6, 2, 14, 8, 0], [17, 4, 6, 15, 4, 38, 23, 37, 42, 51, 60, 10, 1, 13, 2, 3, 5, 17, 15, 4, 18, 8, 0], [ 1, 1, 3, 9, 5, 33, 23, 24, 32, 59, 30, 11, 17, 2, 1, 1, 11, 10, 13, 4, 18, 8, 0], [13, 4, 5, 12, 6, 24, 54, 34, 25, 55, 35, 7, 13, 2, 1, 13, 2, 11, 16, 10, 14, 8, 0], [ 1, 1, 13, 15, 61, 58, 31, 33, 43, 36, 40, 9, 5, 16, 10, 1, 11, 12, 5, 10, 14, 8, 0], [ 1, 13, 2, 3, 6, 37, 29, 48, 51, 32, 63, 4, 5, 10, 1, 11, 12, 9, 7, 1, 14, 8, 0], [ 6, 15, 4, 9, 7, 34, 44, 19, 46, 28, 23, 15, 4, 9, 5, 17, 2, 3, 9, 9, 18, 8, 0], [12, 6, 4, 7, 11, 53, 50, 24, 43, 42, 30, 1, 13, 2, 11, 16, 10, 3, 9, 7, 14, 8, 0]], device='cuda:0'), tensor([[ 37, 35, 3, 3, 187, 188, 182, 86, 145, 100, 60, 52, 40, 10, 16, 45, 16, 5, 9, 65, 2, 1, 0], [ 40, 55, 21, 36, 157, 256, 257, 258, 126, 259, 181, 20, 8, 6, 29, 42, 41, 66, 36, 49, 25, 1, 0], [ 3, 30, 24, 15, 109, 110, 72, 73, 111, 112, 60, 52, 53, 4, 3, 14, 32, 61, 13, 49, 25, 1, 0], [ 13, 33, 7, 39, 113, 74, 195, 75, 62, 76, 196, 54, 6, 4, 8, 6, 34, 23, 17, 48, 2, 1, 0], [ 3, 8, 68, 69, 154, 94, 67, 95, 155, 96, 97, 15, 51, 17, 20, 14, 31, 27, 11, 48, 2, 1, 0], [ 8, 6, 29, 5, 64, 57, 151, 229, 230, 143, 144, 33, 11, 20, 14, 31, 43, 19, 35, 12, 2, 1, 0], [ 21, 36, 28, 19, 152, 91, 92, 127, 231, 153, 135, 36, 28, 15, 41, 53, 29, 24, 46, 93, 25, 1, 0], [ 39, 18, 10, 50, 161, 234, 122, 235, 162, 100, 236, 8, 6, 34, 23, 17, 22, 24, 19, 26, 2, 1, 0]], device='cuda:0'), tensor([[1, 7, 2, 1, 2, 5, 2, 7, 8, 1, 4, 1, 4, 7, 2, 5, 1, 1, 1, 2, 3, 2, 2], [2, 7, 4, 3, 2, 5, 4, 5, 5, 3, 2, 4, 1, 7, 2, 6, 4, 4, 2, 4, 3, 3, 2], [2, 7, 1, 3, 1, 5, 1, 7, 8, 1, 4, 4, 2, 7, 2, 8, 2, 2, 1, 1, 1, 1, 1], [1, 8, 3, 1, 4, 5, 1, 7, 5, 2, 4, 3, 3, 5, 2, 7, 2, 3, 2, 4, 1, 4, 2], [1, 8, 2, 3, 1, 5, 3, 8, 6, 2, 1, 1, 2, 8, 1, 8, 4, 4, 1, 1, 1, 1, 1], [1, 5, 2, 2, 3, 5, 1, 8, 6, 4, 4, 2, 4, 5, 1, 7, 2, 1, 4, 1, 1, 1, 1], [2, 6, 2, 1, 4, 8, 2, 6, 7, 4, 2, 2, 3, 6, 3, 5, 2, 2, 4, 4, 2, 2, 4], [1, 6, 2, 1, 3, 8, 1, 7, 5, 1, 4, 1, 2, 7, 3, 8, 1, 2, 4, 2, 3, 2, 4]], device='cuda:0'), tensor([[29, 17, 5, 15, 26, 33, 19, 55, 23, 2, 10, 2, 42, 17, 26, 14, 1, 1, 15, 13, 3, 6, 27], [19, 44, 22, 3, 26, 18, 39, 56, 30, 3, 7, 10, 29, 17, 35, 34, 21, 4, 7, 22, 12, 3, 27], [19, 45, 9, 11, 8, 14, 29, 55, 23, 2, 21, 4, 19, 17, 28, 37, 6, 5, 1, 1, 1, 1, 32], [31, 47, 11, 2, 39, 14, 29, 51, 33, 7, 22, 12, 38, 33, 19, 17, 13, 3, 7, 10, 2, 4, 27], [31, 37, 13, 11, 8, 30, 41, 48, 25, 5, 1, 15, 28, 23, 31, 63, 21, 10, 1, 1, 1, 1, 32], [ 8, 33, 6, 13, 38, 14, 31, 48, 34, 21, 4, 7, 39, 14, 29, 17, 5, 2, 10, 1, 1, 1, 32], [35, 25, 5, 2, 62, 37, 35, 50, 44, 4, 6, 13, 24, 36, 38, 33, 6, 7, 21, 4, 6, 7, 49], [16, 25, 5, 9, 41, 23, 29, 51, 14, 2, 10, 15, 19, 46, 41, 23, 15, 7, 4, 13, 3, 7, 49]], device='cuda:0'), tensor([[ 21, 42, 22, 203, 92, 108, 327, 134, 82, 7, 34, 72, 129, 132, 125, 79, 1, 126, 216, 9, 56, 131, 2], [211, 212, 27, 152, 77, 76, 309, 310, 246, 25, 33, 247, 21, 311, 99, 158, 6, 114, 24, 32, 16, 19, 2], [ 81, 182, 49, 28, 8, 29, 183, 134, 82, 50, 6, 103, 30, 135, 83, 104, 23, 11, 1, 1, 1, 31, 3], [136, 105, 40, 106, 58, 29, 184, 256, 107, 24, 32, 137, 138, 108, 30, 84, 9, 25, 33, 34, 5, 59, 2], [ 74, 165, 113, 28, 44, 85, 142, 166, 54, 11, 126, 53, 43, 120, 221, 222, 223, 13, 1, 1, 1, 31, 3], [ 66, 88, 96, 220, 163, 285, 187, 286, 158, 6, 114, 287, 58, 29, 21, 42, 18, 7, 13, 1, 1, 31, 3], [ 93, 54, 18, 288, 289, 150, 290, 202, 153, 12, 96, 195, 57, 164, 138, 88, 20, 52, 6, 12, 20, 97, 55], [ 51, 54, 37, 291, 69, 157, 184, 181, 127, 7, 100, 292, 63, 293, 69, 228, 86, 26, 41, 9, 25, 97, 55]], device='cuda:0')]
inputs: tensor([[1, 3, 2, 1, 2, 1, 2, 3, 4, 1, 3, 1, 3, 3, 2, 1, 1, 1, 1, 2, 4, 2, 2], [2, 3, 3, 4, 2, 1, 3, 1, 1, 4, 2, 3, 1, 3, 2, 2, 3, 3, 2, 3, 4, 4, 2], [2, 3, 1, 4, 1, 1, 1, 3, 4, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1], [1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3, 1, 3, 2], [1, 4, 2, 4, 1, 1, 4, 4, 2, 2, 1, 1, 2, 4, 1, 4, 3, 3, 1, 1, 1, 1, 1], [1, 1, 2, 2, 4, 1, 1, 4, 2, 3, 3, 2, 3, 1, 1, 3, 2, 1, 3, 1, 1, 1, 1], [2, 2, 2, 1, 3, 4, 2, 2, 3, 3, 2, 2, 4, 2, 4, 1, 2, 2, 3, 3, 2, 2, 3], [1, 2, 2, 1, 4, 4, 1, 3, 1, 1, 3, 1, 2, 3, 4, 4, 1, 2, 3, 2, 4, 2, 3]], device='cuda:0')
inputs: [tensor([[1, 2, 1, 1, 1, 4, 2, 4, 2, 4, 1, 2, 2, 4, 4, 2, 3, 1, 1, 3, 2, 2, 3], [2, 3, 1, 1, 4, 2, 4, 2, 2, 2, 2, 1, 2, 2, 3, 2, 1, 4, 4, 1, 1, 1, 3], [1, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2], [1, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1], [1, 4, 2, 4, 2, 2, 3, 3, 1, 4, 1, 3, 2, 3, 2, 1, 1, 4, 4, 2, 1, 1, 4], [1, 1, 1, 1, 3, 3, 3, 2, 2, 2, 3, 4, 4, 2, 4, 1, 3, 2, 1, 1, 1, 1, 4], [1, 1, 3, 1, 2, 4, 1, 1, 3, 2, 4, 2, 2, 1, 3, 1, 1, 3, 3, 3, 2, 2, 1], [1, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1, 3, 3, 1, 3, 1, 1, 1, 1]]), tensor([[ 8, 4, 1, 1, 9, 3, 11, 3, 11, 10, 8, 6, 11, 13, 3, 7, 12, 1, 2, 5, 6, 7, 19], [ 7, 12, 1, 9, 3, 11, 3, 6, 6, 6, 4, 8, 6, 7, 5, 4, 9, 13, 10, 1, 1, 2, 19], [ 9, 10, 1, 1, 2, 5, 7, 12, 1, 1, 2, 5, 4, 9, 3, 11, 13, 10, 9, 3, 4, 8, 16], [ 1, 8, 7, 15, 13, 10, 1, 1, 2, 5, 7, 12, 1, 1, 2, 5, 4, 9, 3, 11, 13, 10, 17], [ 9, 3, 11, 3, 6, 7, 14, 12, 9, 10, 2, 5, 7, 5, 4, 1, 9, 13, 3, 4, 1, 9, 18], [ 1, 1, 1, 2, 14, 14, 5, 6, 6, 7, 15, 13, 3, 11, 10, 2, 5, 4, 1, 1, 1, 9, 18], [ 1, 2, 12, 8, 11, 10, 1, 2, 5, 11, 3, 6, 4, 2, 12, 1, 2, 14, 14, 5, 6, 4, 17], [ 1, 2, 14, 5, 7, 5, 11, 3, 6, 4, 1, 1, 1, 1, 2, 14, 12, 2, 12, 1, 1, 1, 17]]), tensor([[48, 24, 1, 18, 8, 12, 5, 12, 26, 47, 25, 50, 45, 27, 31, 23, 7, 2, 3, 33, 17, 65, 58], [23, 7, 18, 8, 12, 5, 15, 21, 21, 9, 13, 25, 17, 14, 4, 19, 36, 10, 22, 1, 2, 72, 58], [32, 22, 1, 2, 3, 49, 23, 7, 1, 2, 3, 4, 19, 8, 12, 45, 10, 42, 8, 38, 13, 75, 29], [44, 16, 30, 41, 10, 22, 1, 2, 3, 49, 23, 7, 1, 2, 3, 4, 19, 8, 12, 45, 10, 69, 37], [ 8, 12, 5, 15, 17, 40, 57, 59, 32, 20, 3, 49, 14, 4, 24, 18, 36, 27, 38, 24, 18, 61, 53], [ 1, 1, 2, 34, 63, 11, 33, 21, 17, 30, 41, 27, 12, 26, 20, 3, 4, 24, 1, 1, 18, 61, 53], [ 2, 6, 52, 39, 26, 22, 2, 3, 35, 5, 15, 9, 28, 6, 7, 2, 34, 63, 11, 33, 9, 64, 37], [ 2, 34, 11, 49, 14, 35, 5, 15, 9, 24, 1, 1, 1, 2, 34, 57, 43, 6, 7, 1, 1, 55, 37]]), tensor([[2, 3, 1, 1, 3, 2, 8, 4, 5, 5, 6, 4, 2, 4, 2, 4, 1, 1, 1, 2, 1, 7, 0], [1, 1, 4, 4, 1, 2, 9, 2, 5, 6, 5, 2, 2, 2, 4, 2, 4, 1, 1, 3, 2, 7, 0], [2, 4, 1, 4, 4, 2, 8, 1, 5, 9, 6, 1, 1, 3, 2, 3, 1, 1, 1, 4, 1, 7, 0], [4, 2, 4, 1, 2, 3, 6, 1, 6, 9, 5, 3, 1, 1, 1, 4, 4, 3, 2, 1, 1, 7, 0], [1, 2, 4, 4, 1, 1, 5, 3, 5, 9, 6, 4, 1, 3, 3, 2, 2, 4, 2, 4, 1, 7, 0], [1, 1, 1, 2, 3, 1, 8, 2, 8, 8, 9, 2, 2, 2, 3, 3, 3, 1, 1, 1, 1, 7, 0], [2, 3, 3, 3, 1, 1, 9, 1, 5, 5, 8, 2, 3, 1, 1, 4, 2, 1, 3, 1, 1, 7, 0], [1, 1, 3, 1, 3, 3, 6, 1, 6, 6, 6, 2, 2, 4, 2, 3, 2, 3, 3, 1, 1, 7, 0]]), tensor([[ 6, 2, 1, 13, 4, 19, 49, 33, 28, 43, 31, 12, 5, 12, 5, 10, 1, 1, 3, 7, 14, 8, 0], [ 1, 11, 16, 10, 3, 47, 35, 38, 43, 52, 27, 9, 9, 5, 12, 5, 10, 1, 13, 4, 18, 8, 0], [ 5, 10, 11, 16, 12, 19, 21, 22, 32, 20, 29, 1, 13, 4, 6, 2, 1, 1, 11, 10, 14, 8, 0], [12, 5, 10, 3, 6, 37, 29, 41, 42, 51, 23, 2, 1, 1, 11, 16, 17, 4, 7, 1, 14, 8, 0], [ 3, 5, 16, 10, 1, 22, 23, 24, 32, 20, 31, 10, 13, 15, 4, 9, 5, 12, 5, 10, 14, 8, 0], [ 1, 1, 3, 6, 2, 34, 44, 19, 25, 55, 35, 9, 9, 6, 15, 15, 2, 1, 1, 1, 14, 8, 0], [ 6, 15, 15, 2, 1, 48, 30, 22, 28, 26, 44, 6, 2, 1, 11, 12, 7, 13, 2, 1, 14, 8, 0], [ 1, 13, 2, 13, 15, 37, 29, 41, 36, 36, 40, 9, 5, 12, 6, 4, 6, 15, 2, 1, 14, 8, 0]]), tensor([[ 9, 4, 8, 13, 237, 163, 164, 238, 239, 240, 241, 27, 7, 27, 11, 20, 3, 30, 45, 26, 2, 1, 0], [ 14, 23, 17, 22, 169, 183, 268, 184, 87, 85, 130, 46, 15, 7, 27, 11, 20, 8, 13, 49, 25, 1, 0], [ 11, 47, 23, 59, 104, 105, 106, 107, 108, 71, 194, 8, 13, 55, 9, 4, 3, 14, 32, 48, 2, 1, 0], [ 27, 11, 22, 5, 64, 57, 80, 125, 126, 205, 81, 4, 3, 14, 23, 63, 40, 10, 35, 12, 2, 1, 0], [ 42, 51, 17, 20, 187, 188, 72, 73, 108, 132, 133, 61, 68, 36, 28, 15, 7, 27, 11, 48, 2, 1, 0], [ 3, 30, 5, 9, 98, 91, 92, 99, 62, 76, 156, 46, 56, 21, 58, 38, 4, 3, 3, 12, 2, 1, 0], [ 21, 58, 38, 4, 177, 178, 101, 70, 103, 179, 180, 9, 4, 14, 31, 37, 54, 6, 4, 12, 2, 1, 0], [ 8, 6, 44, 68, 148, 57, 80, 191, 192, 96, 97, 15, 7, 39, 18, 55, 21, 38, 4, 12, 2, 1, 0]]), tensor([[1, 6, 1, 1, 1, 8, 2, 8, 6, 3, 1, 2, 2, 8, 3, 6, 4, 1, 1, 4, 2, 2, 4], [2, 7, 1, 1, 3, 6, 3, 6, 6, 2, 2, 1, 2, 6, 4, 6, 1, 3, 3, 1, 1, 1, 4], [1, 8, 1, 1, 1, 7, 2, 7, 5, 1, 1, 4, 2, 5, 3, 6, 3, 3, 1, 3, 2, 1, 2], [1, 5, 2, 4, 3, 8, 1, 5, 5, 4, 2, 4, 1, 5, 1, 7, 2, 1, 3, 2, 3, 3, 1], [1, 8, 2, 3, 2, 6, 4, 7, 5, 3, 1, 4, 2, 7, 2, 5, 1, 3, 3, 2, 1, 1, 3], [1, 5, 1, 1, 4, 7, 4, 6, 6, 2, 4, 3, 3, 6, 3, 5, 4, 2, 1, 1, 1, 1, 3], [1, 5, 4, 1, 2, 8, 1, 5, 7, 2, 3, 2, 2, 5, 4, 5, 1, 4, 4, 4, 2, 2, 1], [1, 5, 4, 4, 2, 7, 2, 8, 6, 2, 1, 1, 1, 5, 1, 7, 4, 1, 4, 1, 1, 1, 1]]), tensor([[16, 20, 1, 1, 31, 37, 28, 48, 36, 11, 15, 6, 28, 47, 24, 34, 10, 1, 2, 4, 6, 7, 49], [19, 45, 1, 9, 24, 36, 24, 53, 25, 6, 5, 15, 35, 34, 40, 20, 9, 12, 11, 1, 1, 2, 49], [31, 23, 1, 1, 29, 17, 19, 51, 14, 1, 2, 4, 26, 30, 24, 36, 12, 11, 9, 3, 5, 15, 27], [ 8, 33, 7, 22, 41, 23, 8, 56, 18, 4, 7, 10, 8, 14, 29, 17, 5, 9, 3, 13, 12, 11, 32], [31, 37, 13, 3, 35, 34, 42, 51, 30, 11, 2, 4, 19, 17, 26, 14, 9, 12, 3, 5, 1, 9, 43], [ 8, 14, 1, 2, 42, 44, 40, 53, 25, 7, 22, 12, 24, 36, 38, 18, 4, 5, 1, 1, 1, 9, 43], [ 8, 18, 10, 15, 28, 23, 8, 59, 17, 13, 3, 6, 26, 18, 39, 14, 2, 21, 21, 4, 6, 5, 32], [ 8, 18, 21, 4, 19, 17, 28, 48, 25, 5, 1, 1, 8, 14, 29, 44, 10, 2, 10, 1, 1, 1, 32]]), tensor([[ 71, 170, 1, 208, 74, 229, 230, 188, 204, 141, 75, 231, 148, 232, 233, 128, 13, 4, 5, 12, 20, 97, 55], [ 81, 130, 45, 189, 57, 116, 205, 167, 117, 23, 22, 249, 99, 161, 64, 68, 36, 10, 122, 1, 4, 213, 55], [ 78, 133, 1, 254, 21, 179, 180, 181, 79, 4, 5, 101, 80, 47, 57, 102, 10, 48, 15, 39, 22, 255, 2], [ 66, 107, 24, 192, 69, 143, 193, 194, 70, 114, 33, 90, 8, 29, 21, 42, 37, 15, 35, 91, 10, 144, 3], [ 74, 165, 9, 115, 99, 242, 319, 320, 198, 40, 5, 103, 30, 132, 125, 73, 36, 16, 39, 11, 45, 65, 17], [ 8, 79, 4, 72, 224, 225, 226, 167, 87, 24, 32, 168, 57, 164, 98, 70, 14, 11, 1, 1, 45, 65, 17], [ 38, 62, 100, 53, 43, 143, 176, 244, 84, 9, 56, 61, 77, 76, 58, 127, 50, 245, 6, 12, 23, 89, 3], [ 38, 214, 6, 103, 30, 135, 230, 166, 54, 11, 1, 178, 8, 29, 218, 252, 34, 7, 13, 1, 1, 31, 3]])]
inputs: tensor([[1, 2, 1, 1, 1, 4, 2, 4, 2, 4, 1, 2, 2, 4, 4, 2, 3, 1, 1, 3, 2, 2, 3], [2, 3, 1, 1, 4, 2, 4, 2, 2, 2, 2, 1, 2, 2, 3, 2, 1, 4, 4, 1, 1, 1, 3], [1, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2], [1, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1], [1, 4, 2, 4, 2, 2, 3, 3, 1, 4, 1, 3, 2, 3, 2, 1, 1, 4, 4, 2, 1, 1, 4], [1, 1, 1, 1, 3, 3, 3, 2, 2, 2, 3, 4, 4, 2, 4, 1, 3, 2, 1, 1, 1, 1, 4], [1, 1, 3, 1, 2, 4, 1, 1, 3, 2, 4, 2, 2, 1, 3, 1, 1, 3, 3, 3, 2, 2, 1], [1, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1, 3, 3, 1, 3, 1, 1, 1, 1]])
inputs: [tensor([[1, 2, 1, 1, 1, 4, 2, 4, 2, 4, 1, 2, 2, 4, 4, 2, 3, 1, 1, 3, 2, 2, 3], [2, 3, 1, 1, 4, 2, 4, 2, 2, 2, 2, 1, 2, 2, 3, 2, 1, 4, 4, 1, 1, 1, 3], [1, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2], [1, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1], [1, 4, 2, 4, 2, 2, 3, 3, 1, 4, 1, 3, 2, 3, 2, 1, 1, 4, 4, 2, 1, 1, 4], [1, 1, 1, 1, 3, 3, 3, 2, 2, 2, 3, 4, 4, 2, 4, 1, 3, 2, 1, 1, 1, 1, 4], [1, 1, 3, 1, 2, 4, 1, 1, 3, 2, 4, 2, 2, 1, 3, 1, 1, 3, 3, 3, 2, 2, 1], [1, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1, 3, 3, 1, 3, 1, 1, 1, 1]], device='cuda:0'), tensor([[ 8, 4, 1, 1, 9, 3, 11, 3, 11, 10, 8, 6, 11, 13, 3, 7, 12, 1, 2, 5, 6, 7, 19], [ 7, 12, 1, 9, 3, 11, 3, 6, 6, 6, 4, 8, 6, 7, 5, 4, 9, 13, 10, 1, 1, 2, 19], [ 9, 10, 1, 1, 2, 5, 7, 12, 1, 1, 2, 5, 4, 9, 3, 11, 13, 10, 9, 3, 4, 8, 16], [ 1, 8, 7, 15, 13, 10, 1, 1, 2, 5, 7, 12, 1, 1, 2, 5, 4, 9, 3, 11, 13, 10, 17], [ 9, 3, 11, 3, 6, 7, 14, 12, 9, 10, 2, 5, 7, 5, 4, 1, 9, 13, 3, 4, 1, 9, 18], [ 1, 1, 1, 2, 14, 14, 5, 6, 6, 7, 15, 13, 3, 11, 10, 2, 5, 4, 1, 1, 1, 9, 18], [ 1, 2, 12, 8, 11, 10, 1, 2, 5, 11, 3, 6, 4, 2, 12, 1, 2, 14, 14, 5, 6, 4, 17], [ 1, 2, 14, 5, 7, 5, 11, 3, 6, 4, 1, 1, 1, 1, 2, 14, 12, 2, 12, 1, 1, 1, 17]], device='cuda:0'), tensor([[48, 24, 1, 18, 8, 12, 5, 12, 26, 47, 25, 50, 45, 27, 31, 23, 7, 2, 3, 33, 17, 65, 58], [23, 7, 18, 8, 12, 5, 15, 21, 21, 9, 13, 25, 17, 14, 4, 19, 36, 10, 22, 1, 2, 72, 58], [32, 22, 1, 2, 3, 49, 23, 7, 1, 2, 3, 4, 19, 8, 12, 45, 10, 42, 8, 38, 13, 75, 29], [44, 16, 30, 41, 10, 22, 1, 2, 3, 49, 23, 7, 1, 2, 3, 4, 19, 8, 12, 45, 10, 69, 37], [ 8, 12, 5, 15, 17, 40, 57, 59, 32, 20, 3, 49, 14, 4, 24, 18, 36, 27, 38, 24, 18, 61, 53], [ 1, 1, 2, 34, 63, 11, 33, 21, 17, 30, 41, 27, 12, 26, 20, 3, 4, 24, 1, 1, 18, 61, 53], [ 2, 6, 52, 39, 26, 22, 2, 3, 35, 5, 15, 9, 28, 6, 7, 2, 34, 63, 11, 33, 9, 64, 37], [ 2, 34, 11, 49, 14, 35, 5, 15, 9, 24, 1, 1, 1, 2, 34, 57, 43, 6, 7, 1, 1, 55, 37]], device='cuda:0'), tensor([[2, 3, 1, 1, 3, 2, 8, 4, 5, 5, 6, 4, 2, 4, 2, 4, 1, 1, 1, 2, 1, 7, 0], [1, 1, 4, 4, 1, 2, 9, 2, 5, 6, 5, 2, 2, 2, 4, 2, 4, 1, 1, 3, 2, 7, 0], [2, 4, 1, 4, 4, 2, 8, 1, 5, 9, 6, 1, 1, 3, 2, 3, 1, 1, 1, 4, 1, 7, 0], [4, 2, 4, 1, 2, 3, 6, 1, 6, 9, 5, 3, 1, 1, 1, 4, 4, 3, 2, 1, 1, 7, 0], [1, 2, 4, 4, 1, 1, 5, 3, 5, 9, 6, 4, 1, 3, 3, 2, 2, 4, 2, 4, 1, 7, 0], [1, 1, 1, 2, 3, 1, 8, 2, 8, 8, 9, 2, 2, 2, 3, 3, 3, 1, 1, 1, 1, 7, 0], [2, 3, 3, 3, 1, 1, 9, 1, 5, 5, 8, 2, 3, 1, 1, 4, 2, 1, 3, 1, 1, 7, 0], [1, 1, 3, 1, 3, 3, 6, 1, 6, 6, 6, 2, 2, 4, 2, 3, 2, 3, 3, 1, 1, 7, 0]], device='cuda:0'), tensor([[ 6, 2, 1, 13, 4, 19, 49, 33, 28, 43, 31, 12, 5, 12, 5, 10, 1, 1, 3, 7, 14, 8, 0], [ 1, 11, 16, 10, 3, 47, 35, 38, 43, 52, 27, 9, 9, 5, 12, 5, 10, 1, 13, 4, 18, 8, 0], [ 5, 10, 11, 16, 12, 19, 21, 22, 32, 20, 29, 1, 13, 4, 6, 2, 1, 1, 11, 10, 14, 8, 0], [12, 5, 10, 3, 6, 37, 29, 41, 42, 51, 23, 2, 1, 1, 11, 16, 17, 4, 7, 1, 14, 8, 0], [ 3, 5, 16, 10, 1, 22, 23, 24, 32, 20, 31, 10, 13, 15, 4, 9, 5, 12, 5, 10, 14, 8, 0], [ 1, 1, 3, 6, 2, 34, 44, 19, 25, 55, 35, 9, 9, 6, 15, 15, 2, 1, 1, 1, 14, 8, 0], [ 6, 15, 15, 2, 1, 48, 30, 22, 28, 26, 44, 6, 2, 1, 11, 12, 7, 13, 2, 1, 14, 8, 0], [ 1, 13, 2, 13, 15, 37, 29, 41, 36, 36, 40, 9, 5, 12, 6, 4, 6, 15, 2, 1, 14, 8, 0]], device='cuda:0'), tensor([[ 9, 4, 8, 13, 237, 163, 164, 238, 239, 240, 241, 27, 7, 27, 11, 20, 3, 30, 45, 26, 2, 1, 0], [ 14, 23, 17, 22, 169, 183, 268, 184, 87, 85, 130, 46, 15, 7, 27, 11, 20, 8, 13, 49, 25, 1, 0], [ 11, 47, 23, 59, 104, 105, 106, 107, 108, 71, 194, 8, 13, 55, 9, 4, 3, 14, 32, 48, 2, 1, 0], [ 27, 11, 22, 5, 64, 57, 80, 125, 126, 205, 81, 4, 3, 14, 23, 63, 40, 10, 35, 12, 2, 1, 0], [ 42, 51, 17, 20, 187, 188, 72, 73, 108, 132, 133, 61, 68, 36, 28, 15, 7, 27, 11, 48, 2, 1, 0], [ 3, 30, 5, 9, 98, 91, 92, 99, 62, 76, 156, 46, 56, 21, 58, 38, 4, 3, 3, 12, 2, 1, 0], [ 21, 58, 38, 4, 177, 178, 101, 70, 103, 179, 180, 9, 4, 14, 31, 37, 54, 6, 4, 12, 2, 1, 0], [ 8, 6, 44, 68, 148, 57, 80, 191, 192, 96, 97, 15, 7, 39, 18, 55, 21, 38, 4, 12, 2, 1, 0]], device='cuda:0'), tensor([[1, 6, 1, 1, 1, 8, 2, 8, 6, 3, 1, 2, 2, 8, 3, 6, 4, 1, 1, 4, 2, 2, 4], [2, 7, 1, 1, 3, 6, 3, 6, 6, 2, 2, 1, 2, 6, 4, 6, 1, 3, 3, 1, 1, 1, 4], [1, 8, 1, 1, 1, 7, 2, 7, 5, 1, 1, 4, 2, 5, 3, 6, 3, 3, 1, 3, 2, 1, 2], [1, 5, 2, 4, 3, 8, 1, 5, 5, 4, 2, 4, 1, 5, 1, 7, 2, 1, 3, 2, 3, 3, 1], [1, 8, 2, 3, 2, 6, 4, 7, 5, 3, 1, 4, 2, 7, 2, 5, 1, 3, 3, 2, 1, 1, 3], [1, 5, 1, 1, 4, 7, 4, 6, 6, 2, 4, 3, 3, 6, 3, 5, 4, 2, 1, 1, 1, 1, 3], [1, 5, 4, 1, 2, 8, 1, 5, 7, 2, 3, 2, 2, 5, 4, 5, 1, 4, 4, 4, 2, 2, 1], [1, 5, 4, 4, 2, 7, 2, 8, 6, 2, 1, 1, 1, 5, 1, 7, 4, 1, 4, 1, 1, 1, 1]], device='cuda:0'), tensor([[16, 20, 1, 1, 31, 37, 28, 48, 36, 11, 15, 6, 28, 47, 24, 34, 10, 1, 2, 4, 6, 7, 49], [19, 45, 1, 9, 24, 36, 24, 53, 25, 6, 5, 15, 35, 34, 40, 20, 9, 12, 11, 1, 1, 2, 49], [31, 23, 1, 1, 29, 17, 19, 51, 14, 1, 2, 4, 26, 30, 24, 36, 12, 11, 9, 3, 5, 15, 27], [ 8, 33, 7, 22, 41, 23, 8, 56, 18, 4, 7, 10, 8, 14, 29, 17, 5, 9, 3, 13, 12, 11, 32], [31, 37, 13, 3, 35, 34, 42, 51, 30, 11, 2, 4, 19, 17, 26, 14, 9, 12, 3, 5, 1, 9, 43], [ 8, 14, 1, 2, 42, 44, 40, 53, 25, 7, 22, 12, 24, 36, 38, 18, 4, 5, 1, 1, 1, 9, 43], [ 8, 18, 10, 15, 28, 23, 8, 59, 17, 13, 3, 6, 26, 18, 39, 14, 2, 21, 21, 4, 6, 5, 32], [ 8, 18, 21, 4, 19, 17, 28, 48, 25, 5, 1, 1, 8, 14, 29, 44, 10, 2, 10, 1, 1, 1, 32]], device='cuda:0'), tensor([[ 71, 170, 1, 208, 74, 229, 230, 188, 204, 141, 75, 231, 148, 232, 233, 128, 13, 4, 5, 12, 20, 97, 55], [ 81, 130, 45, 189, 57, 116, 205, 167, 117, 23, 22, 249, 99, 161, 64, 68, 36, 10, 122, 1, 4, 213, 55], [ 78, 133, 1, 254, 21, 179, 180, 181, 79, 4, 5, 101, 80, 47, 57, 102, 10, 48, 15, 39, 22, 255, 2], [ 66, 107, 24, 192, 69, 143, 193, 194, 70, 114, 33, 90, 8, 29, 21, 42, 37, 15, 35, 91, 10, 144, 3], [ 74, 165, 9, 115, 99, 242, 319, 320, 198, 40, 5, 103, 30, 132, 125, 73, 36, 16, 39, 11, 45, 65, 17], [ 8, 79, 4, 72, 224, 225, 226, 167, 87, 24, 32, 168, 57, 164, 98, 70, 14, 11, 1, 1, 45, 65, 17], [ 38, 62, 100, 53, 43, 143, 176, 244, 84, 9, 56, 61, 77, 76, 58, 127, 50, 245, 6, 12, 23, 89, 3], [ 38, 214, 6, 103, 30, 135, 230, 166, 54, 11, 1, 178, 8, 29, 218, 252, 34, 7, 13, 1, 1, 31, 3]], device='cuda:0')]
inputs: tensor([[1, 2, 1, 1, 1, 4, 2, 4, 2, 4, 1, 2, 2, 4, 4, 2, 3, 1, 1, 3, 2, 2, 3], [2, 3, 1, 1, 4, 2, 4, 2, 2, 2, 2, 1, 2, 2, 3, 2, 1, 4, 4, 1, 1, 1, 3], [1, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1, 4, 2, 1, 2], [1, 1, 2, 3, 4, 4, 1, 1, 1, 3, 2, 3, 1, 1, 1, 3, 2, 1, 4, 2, 4, 4, 1], [1, 4, 2, 4, 2, 2, 3, 3, 1, 4, 1, 3, 2, 3, 2, 1, 1, 4, 4, 2, 1, 1, 4], [1, 1, 1, 1, 3, 3, 3, 2, 2, 2, 3, 4, 4, 2, 4, 1, 3, 2, 1, 1, 1, 1, 4], [1, 1, 3, 1, 2, 4, 1, 1, 3, 2, 4, 2, 2, 1, 3, 1, 1, 3, 3, 3, 2, 2, 1], [1, 1, 3, 3, 2, 3, 2, 4, 2, 2, 1, 1, 1, 1, 1, 3, 3, 1, 3, 1, 1, 1, 1]], device='cuda:0')
inputs: [tensor([[1, 2, 4, 2, 2, 3, 2, 1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2], [1, 1, 1, 1, 3, 3, 3, 2, 2, 2, 3, 4, 4, 2, 4, 1, 3, 2, 1, 1, 1, 1, 4], [1, 1, 1, 2, 2, 1, 4, 2, 4, 2, 3, 2, 1, 4, 2, 3, 2, 1, 2, 2, 2, 2, 2]]), tensor([[ 8, 11, 3, 6, 7, 5, 4, 1, 8, 6, 4, 8, 11, 13, 3, 4, 9, 13, 3, 7, 15, 3, 16], [ 1, 1, 1, 2, 14, 14, 5, 6, 6, 7, 15, 13, 3, 11, 10, 2, 5, 4, 1, 1, 1, 9, 18], [ 1, 1, 8, 6, 4, 9, 3, 11, 3, 7, 5, 4, 9, 3, 7, 5, 4, 8, 6, 6, 6, 6, 16]]), tensor([[39, 5, 15, 17, 14, 4, 24, 44, 25, 9, 13, 39, 45, 27, 38, 19, 36, 27, 31, 30, 46, 54, 29], [ 1, 1, 2, 34, 63, 11, 33, 21, 17, 30, 41, 27, 12, 26, 20, 3, 4, 24, 1, 1, 18, 61, 53], [ 1, 44, 25, 9, 19, 8, 12, 5, 31, 14, 4, 19, 8, 31, 14, 4, 13, 25, 21, 21, 21, 68, 29]]), tensor([[3, 2, 4, 4, 1, 2, 8, 4, 5, 6, 5, 2, 1, 1, 2, 3, 2, 2, 4, 2, 1, 7, 0], [1, 1, 1, 2, 3, 1, 8, 2, 8, 8, 9, 2, 2, 2, 3, 3, 3, 1, 1, 1, 1, 7, 0], [2, 2, 2, 1, 2, 3, 5, 4, 6, 5, 9, 2, 4, 2, 4, 1, 2, 2, 1, 1, 1, 7, 0]]), tensor([[ 4, 5, 16, 10, 3, 19, 49, 33, 43, 52, 27, 7, 1, 3, 6, 4, 9, 5, 12, 7, 14, 8, 0], [ 1, 1, 3, 6, 2, 34, 44, 19, 25, 55, 35, 9, 9, 6, 15, 15, 2, 1, 1, 1, 14, 8, 0], [ 9, 9, 7, 3, 6, 24, 60, 58, 52, 32, 35, 5, 12, 5, 10, 3, 9, 7, 1, 1, 14, 8, 0]]), tensor([[ 33, 51, 17, 22, 78, 163, 164, 95, 87, 85, 267, 35, 30, 5, 18, 28, 15, 7, 37, 26, 2, 1, 0], [ 3, 30, 5, 9, 98, 91, 92, 99, 62, 76, 156, 46, 56, 21, 58, 38, 4, 3, 3, 12, 2, 1, 0], [ 46, 19, 16, 5, 113, 279, 280, 281, 282, 283, 146, 7, 27, 11, 22, 24, 19, 35, 3, 12, 2, 1, 0]]), tensor([[1, 6, 3, 2, 2, 7, 2, 5, 5, 2, 2, 1, 2, 8, 3, 6, 1, 3, 3, 2, 4, 3, 2], [1, 5, 1, 1, 4, 7, 4, 6, 6, 2, 4, 3, 3, 6, 3, 5, 4, 2, 1, 1, 1, 1, 3], [1, 5, 1, 2, 2, 5, 3, 6, 8, 2, 4, 2, 1, 8, 2, 7, 2, 1, 2, 2, 2, 2, 2]]), tensor([[16, 36, 3, 6, 19, 17, 26, 56, 33, 6, 5, 15, 28, 47, 24, 20, 9, 12, 3, 7, 22, 3, 27], [ 8, 14, 1, 2, 42, 44, 40, 53, 25, 7, 22, 12, 24, 36, 38, 18, 4, 5, 1, 1, 1, 9, 43], [ 8, 14, 15, 6, 26, 30, 24, 54, 37, 7, 4, 5, 31, 37, 19, 17, 5, 15, 6, 6, 6, 6, 27]]), tensor([[ 67, 119, 56, 314, 30, 132, 315, 316, 88, 23, 22, 53, 148, 232, 111, 68, 36, 16, 25, 24, 27, 19, 2], [ 8, 79, 4, 72, 224, 225, 226, 167, 87, 24, 32, 168, 57, 164, 98, 70, 14, 11, 1, 1, 45, 65, 17], [ 8, 328, 75, 61, 80, 47, 169, 123, 162, 26, 14, 243, 74, 239, 30, 42, 22, 75, 46, 46, 46, 131, 2]])]
inputs: tensor([[1, 2, 4, 2, 2, 3, 2, 1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2], [1, 1, 1, 1, 3, 3, 3, 2, 2, 2, 3, 4, 4, 2, 4, 1, 3, 2, 1, 1, 1, 1, 4], [1, 1, 1, 2, 2, 1, 4, 2, 4, 2, 3, 2, 1, 4, 2, 3, 2, 1, 2, 2, 2, 2, 2]])
inputs: [tensor([[1, 2, 4, 2, 2, 3, 2, 1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2], [1, 1, 1, 1, 3, 3, 3, 2, 2, 2, 3, 4, 4, 2, 4, 1, 3, 2, 1, 1, 1, 1, 4], [1, 1, 1, 2, 2, 1, 4, 2, 4, 2, 3, 2, 1, 4, 2, 3, 2, 1, 2, 2, 2, 2, 2]], device='cuda:0'), tensor([[ 8, 11, 3, 6, 7, 5, 4, 1, 8, 6, 4, 8, 11, 13, 3, 4, 9, 13, 3, 7, 15, 3, 16], [ 1, 1, 1, 2, 14, 14, 5, 6, 6, 7, 15, 13, 3, 11, 10, 2, 5, 4, 1, 1, 1, 9, 18], [ 1, 1, 8, 6, 4, 9, 3, 11, 3, 7, 5, 4, 9, 3, 7, 5, 4, 8, 6, 6, 6, 6, 16]], device='cuda:0'), tensor([[39, 5, 15, 17, 14, 4, 24, 44, 25, 9, 13, 39, 45, 27, 38, 19, 36, 27, 31, 30, 46, 54, 29], [ 1, 1, 2, 34, 63, 11, 33, 21, 17, 30, 41, 27, 12, 26, 20, 3, 4, 24, 1, 1, 18, 61, 53], [ 1, 44, 25, 9, 19, 8, 12, 5, 31, 14, 4, 19, 8, 31, 14, 4, 13, 25, 21, 21, 21, 68, 29]], device='cuda:0'), tensor([[3, 2, 4, 4, 1, 2, 8, 4, 5, 6, 5, 2, 1, 1, 2, 3, 2, 2, 4, 2, 1, 7, 0], [1, 1, 1, 2, 3, 1, 8, 2, 8, 8, 9, 2, 2, 2, 3, 3, 3, 1, 1, 1, 1, 7, 0], [2, 2, 2, 1, 2, 3, 5, 4, 6, 5, 9, 2, 4, 2, 4, 1, 2, 2, 1, 1, 1, 7, 0]], device='cuda:0'), tensor([[ 4, 5, 16, 10, 3, 19, 49, 33, 43, 52, 27, 7, 1, 3, 6, 4, 9, 5, 12, 7, 14, 8, 0], [ 1, 1, 3, 6, 2, 34, 44, 19, 25, 55, 35, 9, 9, 6, 15, 15, 2, 1, 1, 1, 14, 8, 0], [ 9, 9, 7, 3, 6, 24, 60, 58, 52, 32, 35, 5, 12, 5, 10, 3, 9, 7, 1, 1, 14, 8, 0]], device='cuda:0'), tensor([[ 33, 51, 17, 22, 78, 163, 164, 95, 87, 85, 267, 35, 30, 5, 18, 28, 15, 7, 37, 26, 2, 1, 0], [ 3, 30, 5, 9, 98, 91, 92, 99, 62, 76, 156, 46, 56, 21, 58, 38, 4, 3, 3, 12, 2, 1, 0], [ 46, 19, 16, 5, 113, 279, 280, 281, 282, 283, 146, 7, 27, 11, 22, 24, 19, 35, 3, 12, 2, 1, 0]], device='cuda:0'), tensor([[1, 6, 3, 2, 2, 7, 2, 5, 5, 2, 2, 1, 2, 8, 3, 6, 1, 3, 3, 2, 4, 3, 2], [1, 5, 1, 1, 4, 7, 4, 6, 6, 2, 4, 3, 3, 6, 3, 5, 4, 2, 1, 1, 1, 1, 3], [1, 5, 1, 2, 2, 5, 3, 6, 8, 2, 4, 2, 1, 8, 2, 7, 2, 1, 2, 2, 2, 2, 2]], device='cuda:0'), tensor([[16, 36, 3, 6, 19, 17, 26, 56, 33, 6, 5, 15, 28, 47, 24, 20, 9, 12, 3, 7, 22, 3, 27], [ 8, 14, 1, 2, 42, 44, 40, 53, 25, 7, 22, 12, 24, 36, 38, 18, 4, 5, 1, 1, 1, 9, 43], [ 8, 14, 15, 6, 26, 30, 24, 54, 37, 7, 4, 5, 31, 37, 19, 17, 5, 15, 6, 6, 6, 6, 27]], device='cuda:0'), tensor([[ 67, 119, 56, 314, 30, 132, 315, 316, 88, 23, 22, 53, 148, 232, 111, 68, 36, 16, 25, 24, 27, 19, 2], [ 8, 79, 4, 72, 224, 225, 226, 167, 87, 24, 32, 168, 57, 164, 98, 70, 14, 11, 1, 1, 45, 65, 17], [ 8, 328, 75, 61, 80, 47, 169, 123, 162, 26, 14, 243, 74, 239, 30, 42, 22, 75, 46, 46, 46, 131, 2]], device='cuda:0')]
inputs: tensor([[1, 2, 4, 2, 2, 3, 2, 1, 1, 2, 2, 1, 2, 4, 4, 2, 1, 4, 4, 2, 3, 4, 2], [1, 1, 1, 1, 3, 3, 3, 2, 2, 2, 3, 4, 4, 2, 4, 1, 3, 2, 1, 1, 1, 1, 4], [1, 1, 1, 2, 2, 1, 4, 2, 4, 2, 3, 2, 1, 4, 2, 3, 2, 1, 2, 2, 2, 2, 2]], device='cuda:0')
inputs 是一个包含 12 个张量的列表,每个张量形状为 [batch_size, max_len]
targets 是目标值张量
1、嵌入层¶
在模型中,首先第一步将得到的索引进行了embedding,token的embedding是将离散的符号(如单词、字符、或基因序列片段)映射到连续的向量空间的过程。这个过程通过将高维的稀疏表示(如独热编码)转换为低维的密集向量表示,使得相似的符号在向量空间中距离更近。
embedded = [self.embedding(seq) for seq in x]
嵌入层的输出是一个列表,包含 12 个张量,每个张量的形状为 [batch_size, max_len, embed_dim]
2、位置编码层¶
每个嵌入的序列会通过位置编码层,添加位置信息
for embed in embedded:
embed = self.position_encoder(embed)
位置编码层的输出形状保持不变,为 [batch_size, max_len, embed_dim]
3、Transformer 编码器¶
通过 Transformer 编码器,为了适应 Transformer 编码器的输入,需要对张量进行转置。
x = self.transformer_encoder(embed.transpose(0, 1)).transpose(0, 1)
形状从 [batch_size, max_len, embed_dim] 变为 [max_len, batch_size, embed_dim],然后 Transformer 编码器处理后再转置回来
4、Dropout 层¶
只保留每个序列的第一个 token 的输出(类似于 BERT 中的 [CLS] token)
x = self.dropout(x[:, 0, :])
形状为 [batch_size, embed_dim]
5、全连接层¶
所有序列的输出拼接起来,通过全连接层进行预测
x = torch.cat(outputs, dim=1)
x = self.fc(x)
输入形状为 [batch_size, embed_dim * len(tokenizer_list)],输出形状为 [batch_size, 1]