用矩阵快速幂计算斐波那契数列

261 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

背景介绍

递推式和矩阵乘法

斐波那契数列有递推公式

Fn+2=Fn+1+FnnNF_{n+2}=F_{n+1}+F_{n} \enspace n \in \mathbb{N}

我们可以把这个计算过程抽象成一个矩阵运算的过程。

[Fn+2Fn+1]=[1110][Fn+1Fn]\begin{bmatrix} F_{n+2}\\ F_{n+1} \end{bmatrix} = \begin{bmatrix} 1\enspace 1\\ 1\enspace 0 \end{bmatrix} \cdot \begin{bmatrix} F_{n+1}\\ F_{n} \end{bmatrix}

那么对于第nn项,我们有:

[FnFn1]=[1110]n1[F1F0]\begin{bmatrix} F_{n}\\ F_{n-1} \end{bmatrix} = \begin{bmatrix} 1\enspace 1\\ 1\enspace 0 \end{bmatrix}^{n-1} \cdot \begin{bmatrix} F_{1}\\ F_{0} \end{bmatrix}

快速幂

对于一个指数为正整数的幂运算,我们有:

XT=(X2)T2T2,4,6,XT=X(X2)T2T1,3,5,X^T=(X^2)^{\frac{T}{2}}\enspace T \in {2, 4, 6, \cdots}\\ X^T=X\cdot (X^2)^{\lfloor \frac{T}{2} \rfloor}\enspace T \in {1, 3, 5, \cdots}\\

依次递推,我们可以把幂运算的复杂度,从O(n)O(n)降低到O(log2n)O(log_2n)。 而我们又知道矩阵乘法运算是符合结合律的,所以可以使用快速幂。

代码实现

实现2阶矩阵

这里我们简单用一维列表来表示222\cdot2矩阵,重载加减乘运算符,并用快速幂重载幂运算运算符。

class matrix:
    def __init__(self, list:list):
        self.number = [0, 0, 0, 0]
        self.number[0] = list[0]
        self.number[1] = list[1]
        self.number[2] = list[2]
        self.number[3] = list[3]

    def __add__(self, other):
        return matrix([self.number[0] + other.number[0], self.number[1] + other.number[1], self.number[2] + other.number[2], self.number[3] + other.number[3]])

    def __sub__(self, other):
        return matrix([self.number[0] - other.number[0], self.number[1] - other.number[1], self.number[2] - other.number[2], self.number[3] - other.number[3]])

    def __mul__(self, other):
        '''
        a0 a1    b0 b1
        a2 a3    b2 b3
        a0*b0+a1*b2 a0*b1+a1*b3
        a2*b0+a3*b2 a2*b1+a3*b3
        '''
        list = [0, 0, 0, 0]
        list[0] = self.number[0] * other.number[0] + self.number[1] * other.number[2]
        list[1] = self.number[0] * other.number[1] + self.number[1] * other.number[3]
        list[2] = self.number[2] * other.number[0] + self.number[3] * other.number[2]
        list[3] = self.number[2] * other.number[1] + self.number[3] * other.number[3]
        return matrix(list)

    def __pow__(self, n):
        if n == 0:
            return matrix([1, 0, 0, 1])
        if n == 1:
            return self
        if n % 2 == 0:
            return (self * self) ** (n // 2)
        else:
            return self * (self * self) ** ((n - 1) // 2)

计算斐波那契数列

def getFib(n:int):
    m1 = matrix([1, 1, 1, 0])
    m1 = m1 ** (n-1)
    m1 = m1 * matrix([1, 0, 1, 0])
    return m1.number[0]

简单测试

在这里插入图片描述 由于python中的int是不限长度的,所以可以计算比较高位,例如第10000项。 在这里插入图片描述

完整代码

import decimal
from decimal import Decimal

decimal.getcontext().prec = 32000

class matrix:
    def __init__(self, list:list):
        self.number = [0, 0, 0, 0]
        self.number[0] = list[0]
        self.number[1] = list[1]
        self.number[2] = list[2]
        self.number[3] = list[3]

    def __add__(self, other):
        return matrix([self.number[0] + other.number[0], self.number[1] + other.number[1], self.number[2] + other.number[2], self.number[3] + other.number[3]])

    def __sub__(self, other):
        return matrix([self.number[0] - other.number[0], self.number[1] - other.number[1], self.number[2] - other.number[2], self.number[3] - other.number[3]])

    def __mul__(self, other):
        '''
        a0 a1    b0 b1
        a2 a3    b2 b3
        a0*b0+a1*b2 a0*b1+a1*b3
        a2*b0+a3*b2 a2*b1+a3*b3
        '''
        list = [0, 0, 0, 0]
        list[0] = self.number[0] * other.number[0] + self.number[1] * other.number[2]
        list[1] = self.number[0] * other.number[1] + self.number[1] * other.number[3]
        list[2] = self.number[2] * other.number[0] + self.number[3] * other.number[2]
        list[3] = self.number[2] * other.number[1] + self.number[3] * other.number[3]
        return matrix(list)

    def __pow__(self, n):
        if n == 0:
            return matrix([1, 0, 0, 1])
        if n == 1:
            return self
        if n % 2 == 0:
            return (self * self) ** (n // 2)
        else:
            return self * (self * self) ** ((n - 1) // 2)

def getFib(n:int):
    m1 = matrix([1, 1, 1, 0])
    m1 = m1 ** (n-1)
    m1 = m1 * matrix([1, 0, 1, 0])
    return m1.number[0]

if __name__ == '__main__':
    for i in range(10000, 10001):
        print('F['+str(i)+']=', str(getFib(i)))