PyTorch 和 TensorFlow 的 upsampling（上采样）互通代码

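PyTorch 实现上采样

注意：nn.Upsample 处理一维数据时要求输入为通道优先的 (batch, channels, length)，而 Keras 的 UpSampling1D 使用 (batch, steps, features)。因此下面先用 transpose(1, 2) 把数据转成 (2, 3, 2)，上采样后再转回 (batch, steps, features)，以便与 TensorFlow 的结果直接对比。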

点击查看代码
import numpy as np
import torch.nn.functional as F
import torch
from torch import nn


# nn.Upsample works on channels-first data: for 1D it expects (batch, channels, length).
# Keras' UpSampling1D uses (batch, steps, features) instead. So build the same
# (2, 2, 3) tensor as the TensorFlow example, then transpose it to (2, 3, 2) so
# the "steps" axis becomes the last (length) axis that nn.Upsample stretches.
inputs = torch.arange(0, 12, dtype=torch.float32).view(2, 2, 3).transpose(1, 2)
# size and scale_factor are mutually exclusive: pass exactly one of them.
sample_layer = nn.Upsample(scale_factor=2, mode='nearest')
print(inputs)
# Run the layer once, then transpose back to (batch, steps, features) so the
# result lines up with UpSampling1D's output layout.
upsampled = sample_layer(inputs).transpose(1, 2)
print(upsampled, upsampled.shape)


输出

点击查看代码

tensor([[[ 0.,  3.],
         [ 1.,  4.],
         [ 2.,  5.]],

        [[ 6.,  9.],
         [ 7., 10.],
         [ 8., 11.]]])
tensor([[[ 0.,  1.,  2.],
         [ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 6.,  7.,  8.],
         [ 9., 10., 11.],
         [ 9., 10., 11.]]]) torch.Size([2, 4, 3])

Process finished with exit code 0


TensorFlow 的实现

点击查看代码

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import UpSampling1D

# Keras UpSampling1D, as used below:
#   Arguments:
#     size: Integer. Upsampling factor.
#
#   Input shape:
#     3D tensor with shape: `(batch_size, steps, features)`.
#
#   Output shape:
#     3D tensor with shape: `(batch_size, upsampled_steps, features)`.
#
# Keras is channels-last, so unlike the PyTorch snippet no transpose is needed:
# the data is built directly as (batch=2, steps=2, features=3).
input_shape = (2, 2, 3)
x = np.arange(np.prod(input_shape)).reshape(input_shape)
print(x)
# [[[ 0  1  2]
#   [ 3  4  5]]
#  [[ 6  7  8]
#   [ 9 10 11]]]
# Use the UpSampling1D name imported at the top of the file. As the output
# below shows, each step along axis 1 is repeated `size` times.
y = UpSampling1D(size=2)(x)
print(y)
# tf.Tensor(
#   [[[ 0  1  2]
#     [ 0  1  2]
#     [ 3  4  5]
#     [ 3  4  5]]
#    [[ 6  7  8]
#     [ 6  7  8]
#     [ 9 10 11]
#     [ 9 10 11]]], shape=(2, 4, 3), dtype=int64)


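两者的输出数值完全一致：mode='nearest' 的 nn.Upsample(scale_factor=2) 与 UpSampling1D(size=2) 都把每个时间步重复 2 次。区别在于 PyTorch 需要通道优先的输入布局（因此要 transpose），且本例中两者的 dtype 不同（float32 与 int64）。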

posted @ 2022-09-16 22:24  筷点雪糕侠  阅读(122)  评论(0编辑  收藏  举报