python无任何导入包实现3*3卷积操作

纯python计算卷积

# -*- coding: utf-8 -*-
# @Time: 2022/11/09
# @Author: 
# @file: 3.py
'''
试题3:
考点:矩阵和深度学习基础
描述:卷积是神经网络中的基本单元,实现3*3卷积,步长固定为1,padding=1,输出特征维度和输入相同,不考虑偏置。(不能使用numpy等第三方库)
    例如:输入特征x=[[-1,0,1,2],[-2,1,2,0],[1,2,0,-1],[1,0,-1,2]],权重w=[[-1,0,2],[1,0,-1],[0,2,1]]
    输出y为 [[-3,2,2,1],[3,3,4,-1],[2,6,2,2],[4,1,-6,-1]]
自测结果:测试用例通过
'''

class Conv3x3():
    def mm(self, A, B):
        # 读取输入特征x的行长度和权重矩阵w列长度
        row_len = len(A)
        column_len = len(B)
        # padding
        pad_mat = [[0] * (row_len+2) for i in range(row_len+2)]
        for col in range(len(pad_mat)):
            for lin in range(len(pad_mat[0])):
                if  0<col<len(pad_mat)-1 and  0<lin<len(pad_mat)-1:
                    pad_mat[col][lin] = A[col-1][lin-1]
        # 矩阵初始化
        res_mat = [[0] * row_len for i in range(row_len)]
        # 滑动次数
        skip = len(pad_mat)-column_len+1
        for lin in range(skip):
            for col in range(skip):
                lin_star = lin
                lin_end = lin + column_len
                col_star = col
                col_end = col + column_len
                # 取值
                calu_val = self.get_value(pad_mat, lin_star, lin_end, col_star, col_end)
                # 计算
                cov_1 = self.dot(calu_val,B)
                res_mat[lin][col] = cov_1
        return res_mat

    def dot(self, list_1, list_2):
        #计算点乘
        sum = 0
        for k in range(len(list_1)):
            for s in range(len(list_1[0])):
                tmp = list_1[k][s] * list_2[k][s]
                sum += tmp
        return sum

    def get_value(self, list_1, lin_star, lin_end, col_star, col_end):
        # 对list取值
        list0 = []
        for i in range(lin_star, lin_end):
            list0.append(list_1[i][col_star:col_end])
        return list0


if __name__ == '__main__':
    #测试用例
    Input_x = [[-1, 0, 1, 2], [-2, 1, 2, 0], [1, 2, 0, -1], [1, 0, -1, 2]]
    Input_w = [[-1, 0, 2], [1, 0, -1], [0, 2, 1]]
    output = Conv3x3()
    ret = output.mm(Input_x, Input_w)
    print(ret)

 

 
posted @ 2022-11-10 11:19  海_纳百川  阅读(124)  评论(0编辑  收藏  举报
本站总访问量