[算法]求一个数的算数平方根

474 阅读2分钟

求一个数的算术平方根

求一个数的算术平方根有两种解法,一种是牛顿迭代法,是一种收敛算法,用于求解方程f(x)=0f(x) = 0的根的数值方法。通过迭代过程逼近方程的解,即找到xx的值,使得f(x)f(x)等于零,当满足给定的宽容度则结束迭代,一种是二分法,适用于不知道根的近似值或根的性质时,可以用于任何连续函数,适用性更广泛,但是收敛速度不如牛顿迭代法快。

题目

给定一个正实数 n ,计算并返回 n算术平方根,即 n\sqrt{n}

牛顿迭代法

牛顿迭代法公式

xn+1=xnf(xn)f(xn)x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}

思路

求解一个输的算术平方根,实际上是求解公式x2N=0x^2 - N = 0的值,NN即为我们求解的根。使用牛顿迭代法相比二分法,有更快的收敛速度。

N是我们要求解的根,在迭代开始阶段,也是我们对根的猜测,我们假定一个相对x2x^2较小的数,或者x2x^2的一半,如果你不确定这个值是多少,可以把猜测的值设置为x2x^2,从这个数本身开始迭代,不会影响迭代的结果。

牛顿迭代法不适用于所有的方程,当导数可能为0,或者有多个根时,函数可能不收敛或者收敛到错误的根。但是对于平方根,不会出现这样的问题。

确定求解公式之后可以对其进行推导,建立方程:

f(x)=x2N=0f(x) = x^2 - N = 0

带入牛顿迭代公式推导,计算方程的导数为f(x)=2xf'(x) = 2x,带入迭代公式

xn+1=xxxn2N2x=1/2(xn+Nxn)x_{n+1} = x_{x} - \frac{x_{n}^2-N}{2x} = 1/2(x_{n} + \frac{N}{x_{n}})

则推导结果为 xn+1=1/2(xn+Nxn)x_{n+1} = 1/2 * (x_{n} + \frac{N}{x_{n}})

收敛条件

设定一个容忍度cc作为收敛条件,当两次迭代的差值,小于cc的时候,我们就认为足够接近真是的平方根。

算法实现

def newton_sqrt(n, x0=1, tolerance=1e-10, max_iterations=10000):
    """
    牛顿迭代公式 x_(n+1) = x_n - f(x_n) / f'(x_n)
    推导结果公式 x_next = 0.5 * (x_n + n / x_n)
    Args:
        n: 需要求解的数
        x0: 迭代初值,默认为 1
        tolerance: 精度,两次迭代结果的最大差值,若小于该值则认为已收敛
        max_iterations: 最大迭代次数,防止死循环
    """
    if n < 0:
        # 负数在实数范围内没有平方根
        return None

    x_n = x0
    for _ in range(max_iterations):
        # 根据推导公式计算下一迭代值:x_(n+1) = 0.5 * (x_n + n / x_n)
        x_next = 0.5 * (x_n + n / x_n)
        if abs(x_next - x_n) < tolerance:
            # 差值小于精度,认为已收敛
            return x_next
        x_n = x_next

    # 如果达到最大迭代次数还没有收敛,则返回x_n
    return x_n

go版本

package main

import (
	"fmt"
	"math"
)

// 牛顿迭代法求平方根
// 推导公式:x_(n+1) = 0.5 * (x_n + n / x_n)
// 收敛条件:abs(x_n - x_(n+1)) < tolerance
// Parameters:
// - n: 需要求平方根的数
// - x0: 初始迭代值
// - tolerance: 收敛精度
// - max_iterations: 最大迭代次数
// Returns:
// - 平方根
func newtonSqrt(
	n float64,
	x0 float64,
	tolerance float64,
	maxIterations int,
) float64 {
	if n < 0 {
		return 0
	}
	if tolerance == 0 {
		tolerance = 1e-10
	}
	if maxIterations == 0 {
		maxIterations = 10000
	}

	xN := x0
	for i := 0; i < maxIterations; i++ {
		// 根据推导公式计算下一个迭代值:x_(n+1) = 0.5 * (x_n + n / x_n)
		xNext := 0.5 * (xN + n/xN)
		if math.Abs(xNext-xN) < tolerance {
			return xNext
		}

		xN = xNext
	}

	// 如果达到最大迭代次数还没有收敛,则返回x_n
	return xN
}

func main() {
	fmt.Println(newtonSqrt(21, 11, 1e-10, 10000))
}

rust版本:

/// 二分法求平方根
/// 推导公式 c = 0.5 * (a + b)
/// 
/// Args:
/// *  n: 需要求平方根的数
/// *  x0: 迭代初值,默认为 1
/// *  tolerance: 精度
/// *  max_iterations: 最大迭代次数
/// Returns:
///     None: 负数在实数范围内没有平方根
///     Some: 平方根
fn newton_sqrt(n: f64, x0: f64, tolerance: Option<f64>, max_iterations: Option<i32>) -> Option<f64> {
    if n < 0.0 {
        return None;
    }

    let tolerance = tolerance.unwrap_or(1e-10);
    let max_iterations = max_iterations.unwrap_or(10000);

    let mut x_n = x0;
    for _ in 0..max_iterations {
        // 根据推导公式计算下一迭代值:x_(n+1) = 0.5 * (x_n + n / x_n)
        let x_next = 0.5 * (x_n + n / x_n);

        // 判断是否满足精度要求
        if (x_next - x_n).abs() < tolerance {
            return Some(x_next);
        }

        x_n = x_next;
    }

    Some(x_n)
}

fn main() {
    let n = 21.0;
    let sqrt = newton_sqrt(n, 1.0, None, None);
    println!("sqrt({:?}) = {:?}", n, sqrt);
}

二分法

二分法即为取区间[a,b][a,b]的中间一点 cc 为猜测值,计算是否满足条件,不断缩小这个区间,从来获得更加精确合理的猜测值。

二分法的优点是不需要知道函数的导数,不需要提前准备迭代公式,只需要知道原始公式,即x2=Nx^2=N,二分法具有线型收敛速度,适用于任何连续函数,特别是不知道根的近似值或者性质,同事是全局收敛的,只要根在初始区间之内,就一定能收敛到根。同时二分法计算量较小,只需要根据原始公式对猜测值进行计算即可。

二分法的迭代公式,即为计算区间[a,b][a,b]的中点

m=a+b2m = \frac{a+b}{2}

主要逻辑

判断中点mm的平方和NN的关系:

  • 如果m2m^2接近或者等于NN,即在容忍度cc内,则mm可以作为平方根的近似值

  • 如果m2>Nm^2 > N,则NN的平方根大于mm,将区间右侧设置为mmb=mb = m,继续下一次迭代

  • 如果m2<Nm^2 < N,则NN的平方根小于mm,将区间左侧设置为mma=ma = m,继续下一次迭代

收敛条件

ba<=c|b - a| <= c

算法实现

def binary_sqrt(n, tolerance=1e-10, max_iterations=10000):
    """
    二分法求平方根
    推导公式 c = 0.5 * (a + b)
    收敛条件:abs(c * c - n) < tolerance
    Args:
        n: 需要求解的数
        x0: 迭代初值,默认为 1
        tolerance: 精度,两次迭代结果的最大差值,若小于该值则认为已收敛
        max_iterations: 最大迭代次数,防止死循环
    """
    if n < 0:
        # 负数在实数范围内没有平方根
        return None

    a = 0
    b = n
    for _ in range(max_iterations):
        m = (a + b) / 2
        guess_square = m * m
        if abs(guess_square - n) < tolerance:
            # 差值小于精度,认为已收敛
            return m
        if guess_square > n:
            b = m
        else:
            a = m

    # 如果达到最大迭代次数还没有收敛,则返回None
    return None

go版本:

package main

import (
	"fmt"
	"math"
)

// 二分法求平方根
// 推导公式:m = 0.5 * (a + b)
// 收敛条件:m*m - n < tolerance
// Parameters:
// - n: 需要求平方根的数
// - tolerance: 收敛精度
// - max_iterations: 最大迭代次数
// Returns:
// - 平方根
// - 如果达到最大迭代次数还没有收敛,则返回nil
// - 如果负数在实数范围内没有平方根,则返回nil
func binarySqrt(n float64, tolerance float64, maxIterations int) *float64 {
	if n < 0 {
		return nil
	}

	if tolerance == 0 {
		tolerance = 1e-10
	}
	if maxIterations == 0 {
		maxIterations = 10000
	}

	a := 0.0
	b := n

	for i := 0; i < maxIterations; i++ {
		// 根据推导公式计算下一个迭代值:m = 0.5 * (a + b)
		m := (a + b) / 2
		guessSquare := m * m
		if math.Abs(guessSquare-n) < tolerance {
			return &m
		}
		if guessSquare > n {
			b = m
		} else {
			a = m
		}

	}

	// 如果达到最大迭代次数还没有收敛,则返回nil
	return nil
}

func main() {
    let n = 21.0;
    println!("newton_sqrt({:?}) = {:?}", n, newton_sqrt(n, 1.0, None, None));
}

rust版本:

/// 二分法求平方根
/// Args:
/// *  n: 需要求平方根的数
/// *  tolerance: 精度
/// *  max_iterations: 最大迭代次数
/// Returns:
///     None: 负数在实数范围内没有平方根,或者达到最大迭代次数还没有收敛
///     Some: 平方根
fn binary_sqrt(n: f64, tolerance: Option<f64>, max_iterations: Option<i32>) -> Option<f64> {
    if n < 0.0 {
        return None;
    }

    let tolerance = tolerance.unwrap_or(1e-10);
    let max_iterations = max_iterations.unwrap_or(10000);

    let mut a = 0.0;
    let mut b = n;
    for _ in 0..max_iterations {
        // 根据推导公式计算下一迭代值:m = 0.5 * (a + b)
        let m = 0.5 * (a + b);
        let guess_square = m * m;
        if (guess_square - n).abs() < tolerance {
            return Some(m);
        }
        if guess_square > n {
            b = m;
        } else {
            a = m;
        }
    }
    return None;
}

fn main() {
    let n = 21.0;
    println!("binary_sqrt({:?}) = {:?}", n, binary_sqrt(n, None, None));
}