矩阵乘法

819 阅读2分钟

这是我参与11月更文挑战的第7天,活动详情查看:2021最后一次更文挑战

问题描述

给定两个n * n 的矩阵A和B,求A * B。

示例:

输入:[[1,2],[3,2]],[[3,4],[2,1]]

输出:[[7,6],[13,14]]

分析问题

我们可以使用矩阵乘法规则来求解,对于n * n的矩阵A和矩阵B相乘,所得矩阵C的第i行第j列的元素可以表示为Ci,j=Ai,1* B1,j + Ai,2 * B2,j + ... + Ai,n * Bn,j , 即等于A的第i行和B的第j列对应元素的乘积之和。

image-20211104150541839

class Solution:
    def solve(self , a, b):
        # write code here
        #矩阵a和矩阵bn*n的矩阵
        n=len(a)
        res=[[0] * n for _ in range(n)]

        for i in range(0,n):
            for j in range(0,n):
                for k in range(0,n):
                    #C的第i行第j列的元素为
                    #A的第i行和B的第j列对应元素乘积的和
                    res[i][j] += a[i][k]*b[k][j]
        return res

该算法的时间复杂度是O(N^3),空间复杂度是O(N^2)。

我们都知道对于二维数组来说,在计算机的内存中实际上是顺序存储的,如下所示:

image-20211104151742743

因为操作系统加载数据到缓存中时,都是把命中数据附近的一批数据一起加载到缓存中,因为操作系统认为如果一个内存位置被引用了,那么程序很可能在不久的未来引用附近的一个内存位置。所以我们通过调整数组的读取顺序来进行优化,使得矩阵A和B顺序读取,然后相继送入CPU中进行计算,最后使得运行时间能够更快。下面我们来看一下具体做法:

class Solution:
    def solve(self , a, b):
        # write code here
        #矩阵a和矩阵bn*n的矩阵
        n=len(a)
        res=[[0] * n for _ in range(n)]

        for i in range(0,n):
            for j in range(0,n):
                #顺序访问矩阵A的元素
                temp=a[i][j]

                for k in range(0,n):
                    #矩阵b的元素也是顺序访问的
                    res[i][k] += temp * b[j][k]

        return res

该算法的时间复杂度是O(N^3),当时该算法利用了缓存优化,顺序读取数组A和数组B中的元素,因此一般会比第一种方法运行更快。该算法的空间复杂度是O(N^2)。