给定a1, a2, a3, a4, a5,求符合条件的解(x1, x2, x3, x4, x5)的个数,使得a1*x1^3 + a2*x2^3 + a3*x3^3 + a4*x4^3 + a5*x5^3 = 0。其中-50 <= x1, x2, x3, x4, x5 <= 50且x1-x5都不为0。
x1-x5各有100个可能值,相互的解法是使用5重循环逐一判断,但这样用的时间太长了。因此可以把它们分为两部分(x1, x2)和(x3, x4, x5),求出所有的sum = a1*x1^3 + a2*x2^3,将非负的sum值和负的sum值映射到非负数和负数两个哈希表上,把sum相同的压缩到一个结点,只记录相同的sum的个数。然后求出所有的sum = a3*x3^3 + a4*x4^3 + a5*x5^3,然后分别到非负哈希表和负数哈希表上查找对应的(x1, x2)组合的个数。把所有的结果累加起来就得到最终的答案了。
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 10005;
const int H = 9997;
struct Node
{
long sum;
int cnt;
int next;
};
Node nodeP[N], nodeN[N];
int curP, curN;
int hashTableP[H], hashTableN[H];
int a1, a2, a3, a4, a5;
int cube[101];
long long ans;
void initHash()
{
curP = curN = 0;
ans = 0;
for (int i = 0; i < H; ++i) hashTableP[i] = hashTableN[i] = -1;
for (int i = 0; i <= 100; ++i)
{
int t = i - 50;
cube[i] = t * t * t;
}
}
void insertHash(int i, int j)
{
long sum = cube[i] * a1 + cube[j] * a2;
int h = sum % H;
int next;
if (h >= 0)
{
next = hashTableP[h];
while (next != -1)
{
if (nodeP[next].sum == sum)
{
++nodeP[next].cnt;
return;
}
next = nodeP[next].next;
}
nodeP[curP].cnt = 1;
nodeP[curP].sum = sum;
nodeP[curP].next = hashTableP[h];
hashTableP[h] = curP;
++curP;
}
else
{
h = -h;
next = hashTableN[h];
while (next != -1)
{
if (nodeN[next].sum == sum)
{
++nodeN[next].cnt;
return;
}
next = nodeN[next].next;
}
nodeN[curN].cnt = 1;
nodeN[curN].sum = sum;
nodeN[curN].next = hashTableN[h];
hashTableN[h] = curN;
++curN;
}
}
void getAns(int i, int j, int k)
{
long sum = cube[i] * a3 + cube[j] * a4 + cube[k] * a5;
int h = sum % H;
int next;
if (h > 0)
{
next = hashTableN[h];
while (next != -1)
{
if (nodeN[next].sum + sum == 0)
{
ans += nodeN[next].cnt;
return;
}
next = nodeN[next].next;
}
}
else
{
h = -h;
next = hashTableP[h];
while (next != -1)
{
if (nodeP[next].sum + sum == 0)
{
ans += nodeP[next].cnt;
return;
}
next = nodeP[next].next;
}
}
}
int main()
{
initHash();
scanf("%d%d%d%d%d", &a1, &a2, &a3, &a4, &a5);
for (int i = 0; i <= 100; ++i)
{
if (i == 50) continue;
for (int j = 0; j <= 100; ++j)
{
if (j == 50) continue;
insertHash(i, j);
}
}
for (int i = 0; i <= 100; ++i)
{
if (i == 50) continue;
for (int j = 0; j <= 100; ++j)
{
if (j == 50) continue;
for (int k = 0; k <= 100; ++k)
{
if (k == 50) continue;
getAns(i, j, k);
}
}
}
printf("%lld\n", ans);
return 0;
}