分治法的应用
【算法】
Mul(A[0…n-1], B[0…n-1], n)
//计算两个大整数A[], B[]的乘积
//输入:字符数组(或字串)表示的两个大整数
//输出:以字串形式输出的两个大整数的乘积
if (n == 1)
return A[0] * B[0];
//高位补0,使n成为偶数(二分需要)
if (n%2 == 0)
{
A[0…n] = ‘0’ + A[0…n-1] ;
B[0…n] = ‘0’ + B[0…n-1] ;
n++;
}
//进行二分
a1 = A[0, n/2]; //A的前半部分
a0 = A[n/2, n-1]; //A的后半部分
b1 = B[0, n/2]; //B的前半部分
b0 = B[n/2, n-1]; //B的后半部分
/*那么A = a1*10^n/2 + a0 B = b1*10^n/2 + b0
利用与计算两位数相同的方法可以得到:
c = a*b = (a1*10^n/2 + a0)*(b1*10^n/2 + b0)
= (a1*b1)10^n + (a1*b0 + a0*b1)10^n/2 + (a0*b0)
=c2*10^n + c1*10^n/2 + c0
其中,c2 = a1 * b1, 是它们前半部分的积,c0 = a0 * b0,是它们后半部分的积
c1 = (a1*b0 + a0*b1) = (a1 + a0) * (b1 + b0) – (c2 + c0) ----利用已经算出来的积(c2, c0),减少乘法的数量(2次 à 1 次)
*/
c2 = Mul(a1, b1);
c0 = Mul(a0, b0);
c1 = Mul(a1+a0, b1+b0) – (c2 + c0);
return c2*10^n + c1*10^n/2 + c0;
【效率分析】
该算法会做多少次位乘呢?因为n位数的乘法需要对n/2位数做三次乘法运算,乘法次数M(n)的递推式将会是:
当n>1时,M(n) = 3M(n/2), M(1) = 1
当n = 2^k时, 我们可以利用反向替换法对它求解:
M(2^k) = 3M( 2^(k-1) ) = 3^2 M( 2^(k-2)
= 3^k M(2^(k-k)) = 3^k
因为k = log2n,
M(n) = 3log2n = n^1.585;
【C语言实现】
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>
//支持的大整数的位数
#define N 1000
/*
@brief reverse a string
*/
void reverseStr(char *A, int n)
{
int i = 0, j = n-1;
char temp;
while (i < j)
{
temp = A[i];
A[i] = A[j];
A[j] = temp;
i++;
j--;
}
}
/*
@brief return the maximum of three numbers
*/
int max3(int a, int b, int c)
{
if (a>b)
{
return a>c?a:c;
}
else
return b>c?b:c;
}
/*
@brief caculate the difference of A, B, restore the result in C.
the three numbers are represented in the form of string
*/
void substract2str(const char *A, const char *B, int n, char *C)
{
int i, borrow = 0;
for (i = 0; i<n; i++)
{
if (A[i] - borrow >= B[i])
{
C[i] = (A[i] - B[i] - borrow) + '0';
borrow = 0;
}
else
{
C[i] = (A[i] +10 - B[i] - borrow) + '0';
borrow = 1;
}
}
}
/*
@brief caculate the sum of A, B, restore the result in C.
the three numbers are represented in the form of string
*/
void add2str(const char *A, const char *B, int n, char *C)
{
int i, sum, carry;
carry = 0;
for (i = 0; i<n; i++)
{
sum = (A[i] - '0') + (B[i] - '0')+carry;
C[i] = sum%10 + '0'; //int --> char
carry = sum/10;
}
if (carry)
C[i] = '1';
}
/*
@brief return the product of two big integers A, B, which are represented in
the form of string, also the product.
*/
char* Mul(char *A, char *B, size_t n)
{
//一位数相乘,直接返回结果
if (n == 1)
{
char *rst = (char *)malloc(3*sizeof(char));
if (rst == NULL)
{
printf("Allocate memory failed!\n");
exit(1);
}
rst[2] = 0;
int temp = (A[0] - '0') * (B[0] - '0');
rst[0] = temp%10 + '0';
if (temp/10)
rst[1] = temp/10 + '0';
else
rst[1] = 0;
return rst;
}
//高位补0,使n为偶数
if (n % 2 != 0)
{
*(A+n) = '0';
*(B+n) = '0';
n++;
}
char *a0 = A; //A 的后半部分
char *a1 = A + n/2; //A 的前半部分
char *b0 = B; //B 的后半部分
char *b1 = B + n/2; //B 的前半部分
size_t i;
//多分配一位,因为下一层次的计算中有可能要补一位
char *tp1 = (char *)calloc(n/2+2, sizeof(char));
char *tp2 = (char *)calloc(n/2+2, sizeof(char));
strncpy(tp1, a1, n/2);
strncpy(tp2, b1, n/2);
char *c2 = Mul(tp1, tp2, n/2);
strncpy(tp1, a0, n/2);
strncpy(tp2, b0, n/2);
char *c0 = Mul(tp1, tp2, n/2);
free(tp1); tp1 = NULL;
free(tp2); tp2 = NULL;
//两个乘积对齐
if (strlen(c2) > strlen(c0))
{
for (i=strlen(c0); i<strlen(c2); i++)
{
c0[i] = '0';
}
}
else if (strlen(c2) < strlen(c0))
{
for (i=strlen(c2); i<strlen(c0); i++)
{
c2[i] = '0';
}
}
char *pa;
//两个n/2位数的和,可能为n/2 + 1位
size_t len = n/2+2;
pa = (char *)calloc(len, sizeof(char));
if (pa == NULL)
{
printf("Allocate memory failed!\n");
exit(1);
}
add2str(a1, a0, n/2, pa);
char *pb = (char *)calloc(len, sizeof(char));
if (pb == NULL)
{
printf("Allocate memory failed!\n");
exit(1);
}
add2str(b1, b0, n/2, pb);
len = strlen(pa)>strlen(pb)?strlen(pa):strlen(pb);
//对齐pa和pb
if (len > strlen(pa))
{
pa[len-1] = '0';
}
else if (len > strlen(pb))
{
pb[len-1] = '0';
}
char *pd;
pd = Mul(pa, pb, len);
len = strlen(pd);
free(pa); pa=NULL;
free(pb); pb=NULL;
//两个n位数的积的位数肯定大于两个n位数的和的位数
char *pc = (char *)calloc((len+1), sizeof(char));
if (pc == NULL)
{
printf("Allocate memory failed!\n");
exit(1);
}
add2str(c2, c0, strlen(c0), pc);
//pd pc 对齐
for (i=strlen(pc); i<len; i++)
pc[i] = '0';
//两个n位数的差最多为n位数
char *c1 = (char *)calloc((len+1), sizeof(char));
if (c1 == NULL)
{
printf("Allocate memory failed!\n");
exit(1);
}
substract2str(pd, pc, len, c1);
free(pd); pd = NULL;
free(pc); pc = NULL;
//两个n位数的乘积最多为2n位数
char *tpc0 = (char *)calloc((2*n+1), sizeof(char));
char *tpc1 = (char *)calloc((2*n+1), sizeof(char));
char *tpc2 = (char *)calloc((2*n+1), sizeof(char));
if (tpc0 == NULL || tpc1 == NULL || tpc2 == NULL)
{
printf("Allocate memory failed!\n");
exit(1);
}
len = 2*n;
//为了计算方便,都补齐到2n位
strcpy(tpc0, c0);
for (i=strlen(c0); i<len; i++)
tpc0[i] = '0';
tpc0[len] = 0;
strncpy(tpc1+n/2, c1, len-n/2);
for (i=0; i<n/2; i++)
tpc1[i] = '0';
for (i=n/2+strlen(c1); i<len; i++)
tpc1[i] = '0';
tpc1[len] = 0;
//在此遇到内存错误,如果用strcpy(tpc+n, c2) -- C语言一定要注意内存的操作,确定不会越界吗?
strncpy(tpc2+n, c2, len-n);
for (i=0; i<n; i++)
tpc2[i] = '0';
for (i=n+strlen(c2); i<len; i++)
tpc2[i] = '0';
tpc2[len] = 0;
free(c0); c0=NULL;
free(c1); c1=NULL;
free(c2); c2=NULL;
char *rst = (char *)calloc((2*n+1), sizeof(char));
char *tmp = (char *)calloc((2*n+1), sizeof(char));
if (rst == NULL || tmp == NULL)
{
printf("Allocate memory failed!\n");
exit(1);
}
add2str(tpc2, tpc0, len, tmp);
add2str(tmp, tpc1, strlen(tmp), rst);
free(tpc0); tpc0 = NULL;
free(tpc1); tpc1 = NULL;
free(tpc2); tpc2 = NULL;
free(tmp); tmp = NULL;
return rst;
}
int main()
{
char *A = (char *)calloc(N+1, sizeof(char));
char *B = (char *)calloc(N+1, sizeof(char));
size_t i;
char *rst;
do
{
printf("input A: ");
scanf("%s", A);
printf("input B: ");
scanf("%s", B);
/*
如果我们是按高位 --> 地位顺序输入, 则A的低地址存的是最高位的数。
比如输入95432,那么内存中会是这样的A[0] A[1] ... A[0] : 9 5 ... 2
而我们的算法是建立在低地址存最低位基础上的,所以需要一个reverse操作。
*/
reverseStr(A, strlen(A));
reverseStr(B, strlen(B));
//对齐A和B --> 通过高位补零
if (strlen(A) > strlen(B))
{
for (i=strlen(B); i<strlen(A); i++)
B[i] = '0';
}
else if (strlen(A) < strlen(B))
{
for (i=strlen(A); i<strlen(B); i++)
A[i] = '0';
}
rst = Mul(A, B, strlen(A));
reverseStr(rst, strlen(rst));
//remove the leading zero
for (i=0; i<strlen(rst); i++)
if (rst[i] != '0')
break;
if (i!=strlen(rst))
puts(rst+i);
else
puts("0");
free(rst);
} while( !((A[0] == '0' && strlen(A) == 1) && (B[0] == '0' && strlen(B) == 1))); //输入零结束
free(A);
free(B);
return 0;
}