题目:计算下面两个64位数的乘积。
So: what's the product of the following two 64-digit numbers?
3141592653589793238462643383279502884197169399375105820974944592
2718281828459045235360287471352662497757247093699959574966967627
思路分析
举例:5678 可分为 56,78 两半;1234 可分为 12,34两半;
5678 = 10^2 * 56 + 78; 即 a=56, b=78;
1234 = 10^2 * 12 + 34; 即 c=12, d=34
步骤1:计算 a*c = 56 * 12 = 672;
步骤2:计算 b*d = 78 * 34 = 2652;
步骤3:计算 (a+b)(c+d) = 134 * 46 = 6164;
步骤4:计算 (a+b)(c+d) -ac - bd = 6164 - 672 - 2652 = 2840 = ad + bc;
步骤5:计算 10^4 * ac + 10^2 * [ (a+b)(c+d) -ac - bd] + bd = 70066552.
n 为数字长度。
递归做法 Karatsuba Multiplication
递归的基准条件:
当 x 或 y 为个位数时,返回 x*y;
否则,递归执行数据拆分,直到将数据拆为个位数,返回乘积。
Python
a = 3141592653589793238462643383279502884197169399375105820974944592
b = 2718281828459045235360287471352662497757247093699959574966967627
def mul(x, y):
if len(str(x)) == 1 or len(str(y)) == 1:
return x * y
n = max(len(str(x)), len(str(y)))
nhalf = n // 2
a = x // (10 ** (nhalf))
b = x % (10 ** (nhalf))
c = y // (10 ** (nhalf))
d = y % (10 ** (nhalf))
# 递归,分别计算出 ac,bd,abcd 的乘积
ac = kmult(a, c)
bd = kmult(b, d)
abcd = kmult(a + b, c + d)
# 根据 ac,bd,abcd 的值,得出 result 值,返回给ac or bd or abcd. 直到递归结束,返回最终的result
result = (10 ** (2 * nhalf)) * ac + (10 ** nhalf) * (abcd - ac - bd) + bd
return result
c = mul(a, b)
print(int(c))
Java
int数据类型32位,最小值是 -2,147,483,648(-2^31);最大值是 2,147,483,647(2^31 - 1);
long数据类型64位,最小值是 -9,223,372,036,854,775,808(-2^63);最大值是 9,223,372,036,854,775,807(2^63 -1);
BigInteger存放大数,整数
BigDecimal小数
pow 次方
BigInteger ten = BigInteger.valueOf(10);
BigInteger pow = ten.pow(2); // 计算 10 的 2次方
mod 取余
BigInteger divide = ten.divide(BigInteger.valueOf(3)); // 10 / 3 = 3
BigInteger mod = ten.mod(BigInteger.valueOf(3)); // 10 % 3 = 1
System.out.println("divide=" + divide + ",,, mod=" +mod ); // 3 , 1
使用BigInteger计算大数乘法
最直接方法:
public BigInteger mul3(BigInteger x, BigInteger y) {
return x.multiply(y);
}
算法方法:
public class Multiply{
BigInteger ten = BigInteger.valueOf(10);
public BigInteger mul3(BigInteger x, BigInteger y) {
int xl = x.toString().length();
int yl = y.toString().length();
System.out.println("xl=" + xl + ", yl=" + yl);
if (xl == 1 || yl == 1) { // 递归结束条件,有个位数,返回两数相乘
return x.multiply(y);
}
int l = Math.max(xl, yl); // 取两个数中最大的数字长度
BigInteger pow = ten.pow(l / 2); // 10 ^ (l/2)
BigInteger a = x.divide(pow); // 取 x 前半部分
BigInteger b = x.mod(pow); // 取 x 后半部分
BigInteger c = y.divide(pow); // 取 y 前半部分
BigInteger d = y.mod(pow); // 取 y 后半部分
BigInteger ac = mul3(a, c); // 递归得到ac值,ac不断递归得到result值,返回给ac对象
BigInteger bd = mul3(b, d); // 递归得到bd值,bd不断递归得到result值,返回给bd对象
BigInteger abcd = mul3(a.add(b), c.add(d)); // 不断递归得到result值,返回给abcd对象
BigInteger mid = abcd.subtract(ac).subtract(bd); // abcd - ac - bd
System.out.println("ac=" + ac + ", bd=" + bd + ", abcd= " + abcd + ", mid=" + mid);
BigInteger result = ac.multiply(pow.pow(2)).add(mid.multiply(pow)).add(bd);// 上图伪代码中的公式
return result;
}
}
public class Algorithm{
public static void main(String[] args) {
String key1 = "3141592653589793238462643383279502884197169399375105820974944592";
String key2 = "2718281828459045235360287471352662497757247093699959574966967627";
BigInteger bi1 = new BigInteger(key1);
BigInteger bi2 = new BigInteger(key2);
Multiply multiply = new Multiply();
BigInteger result = multiply.mul3(bi1, bi2);
System.out.println(">> the end result=" + result);
}
}
使用Long
使用 long 数据类型 的前提条件是,最终乘积结果 小于 9,223,372,036,854,775,807 因此下列方法不适用于这道作业题,只是给出一个代码作为日后自己参考
/**
* < 9,223,372,036,854,775,807
*
* @param x
* @param y
* @return
*/
public long mul2(long x, long y) {
// 递归到某一位为个位数时,返回它们的乘积 (a*c , b*d , ab*cd)
if (x <10 || y <10) {
return x * y;
}
// 将 x 拆分成两半 a,b ; 将 y 拆分成两半 c,d ; x = 10^(n/2) * a + b;
int n = Math.max(String.valueOf(x).length(), String.valueOf(y).length()); // 数字长度, 如 123456 则 n=6
int half = n / 2; // 如 123456 则 half=3
int pow = (int) Math.pow(10, half); // 如 10的三次方,1000
long a = (long) (x / pow); // 前半,取整 ,如 123456 / 1000 = 123
long b = (long) (x % pow); // 后半, 取余 ,如 123456 % 1000 = 456
long c = (long) (y / pow);
long d = (long) (y % pow);
long ac = mul2(a, c);
long bd = mul2(b, d);
long abcd = mul2(a + b, c + d);
long mid = abcd - ac - bd;
System.out.println("ac=" + ac + ", bd=" + bd + ", abcd= " + abcd + ", mid=" + mid);
long result = (long) ((Math.pow(10, half * 2)) * ac + (Math.pow(10, half)) * mid + bd);
return result;
}