Inverse of sum
Time Limit: 6000/3000 MS (Java/Others) Memory Limit: 524288/524288 K (Java/Others)
Problem Description
There are n nonnegative integers $a_1, \dots, a_n$, each less than p. HazelFan wants to know how many pairs (i, j) with $1 \le i < j \le n$ satisfy $\frac{1}{a_i+a_j} \equiv \frac{1}{a_i}+\frac{1}{a_j} \pmod p$, that is, the inverse of their sum equals the sum of their inverses. Note that zero has no inverse.
Input
The first line contains a positive integer T(1≤T≤5), denoting the number of test cases.
For each test case:
The first line contains two positive integers n and p ($1 \le n \le 10^5$, $2 \le p \le 10^{18}$). It is guaranteed that p is prime.
The second line contains n nonnegative integers $a_1, \dots, a_n$ ($0 \le a_i < p$).
Output
For each test case:
A single line contains a nonnegative integer, denoting the answer.
Sample Input
2
5 7
1 2 3 4 5
6 7
1 2 3 4 5 6
Sample Output
4
6
Problem summary
Given a sequence of length n, count the pairs (i, j) with i < j such that $\frac{1}{a_i+a_j} \equiv \frac{1}{a_i}+\frac{1}{a_j} \pmod p$.
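To check answers against the definition on small inputs, a brute-force O(n^2) reference can test every pair directly, computing inverses with Fermat's little theorem ($x^{-1} \equiv x^{p-2}$ for prime p). This is a testing sketch only, not the intended solution: the names `powMod` and `bruteCount` are hypothetical, and it assumes p is small (below 2^31) so that the product of two residues fits in `long long`, which the real limit of $10^{18}$ does not satisfy.

```cpp
#include <cstddef>
#include <vector>

// Modular exponentiation; assumes p < 2^31 so products of residues fit in 64 bits.
long long powMod(long long b, long long e, long long p) {
    long long r = 1 % p;
    b %= p;
    while (e > 0) {
        if (e & 1) r = r * b % p;
        b = b * b % p;
        e >>= 1;
    }
    return r;
}

// O(n^2) count straight from the definition. A pair is skipped when a_i, a_j,
// or a_i + a_j is zero mod p, because that term has no inverse.
long long bruteCount(const std::vector<long long>& a, long long p) {
    long long cnt = 0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        for (std::size_t j = i + 1; j < a.size(); ++j) {
            long long x = a[i], y = a[j], s = (x + y) % p;
            if (x == 0 || y == 0 || s == 0) continue;
            long long lhs = powMod(s, p - 2, p);                       // 1/(x+y)
            long long rhs = (powMod(x, p - 2, p) + powMod(y, p - 2, p)) % p;  // 1/x + 1/y
            if (lhs == rhs) ++cnt;
        }
    }
    return cnt;
}
```

On the two samples, `bruteCount({1, 2, 3, 4, 5}, 7)` returns 4 and `bruteCount({1, 2, 3, 4, 5, 6}, 7)` returns 6, matching the expected output.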
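Approach
Multiply both sides of $\frac{1}{a_i+a_j} \equiv \frac{1}{a_i}+\frac{1}{a_j} \pmod p$ by $a_i+a_j$:
$1 \equiv 1+\frac{a_j}{a_i}+1+\frac{a_i}{a_j} \pmod p$
Multiplying again by $a_i a_j$ gives $a_i a_j \equiv 2a_i a_j + a_i^2 + a_j^2$, that is,
$a_i^2+a_i a_j+a_j^2 \equiv 0 \pmod p$
Since $a_i, a_j \not\equiv 0$ (otherwise their inverses do not exist), both steps are reversible, with one caveat: the original condition also requires $a_i + a_j \not\equiv 0$. Every solution of the quadratic satisfies this automatically, because $a_j \equiv -a_i$ would reduce the left side to $a_i^2 \not\equiv 0$. Multiplying further by $a_i-a_j$ gives
$a_i^3-a_j^3 \equiv 0 \pmod p$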
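As a sanity check on the derivation (not part of the original solution), the sketch below exhaustively compares, for small primes, the original inverse condition with the quadratic condition, and for $x \ne y$ also with the cube condition. The helper names `powModSmall` and `derivationHolds` are hypothetical, and it assumes p is small enough that products of residues fit in 64 bits.

```cpp
// Modular exponentiation; assumes p < 2^31 so products of residues fit in 64 bits.
long long powModSmall(long long b, long long e, long long p) {
    long long r = 1 % p;
    b %= p;
    while (e > 0) {
        if (e & 1) r = r * b % p;
        b = b * b % p;
        e >>= 1;
    }
    return r;
}

// For every pair of nonzero residues x, y mod a prime p, checks that
//   1/(x+y) == 1/x + 1/y   <=>   x^2 + xy + y^2 == 0   (mod p),
// and that for x != y the quadratic condition is equivalent to x^3 == y^3.
// Pairs with x + y == 0 fail the original condition (the sum has no inverse).
bool derivationHolds(long long p) {
    for (long long x = 1; x < p; ++x) {
        for (long long y = 1; y < p; ++y) {
            long long s = (x + y) % p;
            bool original = false;
            if (s != 0) {
                long long lhs = powModSmall(s, p - 2, p);  // Fermat inverse of the sum
                long long rhs = (powModSmall(x, p - 2, p) + powModSmall(y, p - 2, p)) % p;
                original = (lhs == rhs);
            }
            bool quadratic = (x * x + x * y + y * y) % p == 0;
            if (original != quadratic) return false;
            bool cubesEqual = (x * x % p * x % p) == (y * y % p * y % p);
            if (x != y && quadratic != cubesEqual) return false;
        }
    }
    return true;
}
```

Calling `derivationHolds(p)` for each prime below 50 is a quick way to confirm that no step of the derivation loses or adds solutions, apart from the $a_i = a_j$ case that the cube step deliberately excludes here.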
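So for each nonzero $a_i$ we compute $a_i^3 \bmod p$ and keep the counts in a map: each element forms a candidate pair with every earlier element that has the same cube.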
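The map-based counting step can be isolated into a small sketch: each new key pairs with every earlier equal key, and the post-increment `seen[k]++` adds exactly that number before recording the new element. The function name `countEqualKeyPairs` is hypothetical. For the first sample with p = 7, the cubes of 1, 2, 3, 4, 5 are 1, 1, 6, 1, 6, which yields 3 + 1 = 4 pairs.

```cpp
#include <map>
#include <vector>

// Counts pairs i < j with keys[i] == keys[j] in a single pass.
// Before the increment, seen[k] is the number of earlier elements with key k,
// so adding it pairs the current element with each of them exactly once.
long long countEqualKeyPairs(const std::vector<long long>& keys) {
    std::map<long long, long long> seen;  // missing keys start at 0
    long long pairs = 0;
    for (long long k : keys) pairs += seen[k]++;
    return pairs;
}
```

For example, `countEqualKeyPairs({1, 1, 6, 1, 6})` returns 4, the first sample's answer.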
One caveat: when $a_i = a_j$, the factor $a_i-a_j$ is zero, so $a_i^3-a_j^3 \equiv 0$ holds trivially even if the quadratic condition fails. For equal values the quadratic condition becomes $3a_i^2 \equiv 0 \pmod p$; since $a_i \not\equiv 0$ and p is prime, this holds only when p = 3. Therefore, whenever $3a_i^2 \not\equiv 0$, the pairs of equal values that the cube map counts are spurious and must be subtracted, which the code below does by tracking how many earlier elements equal $a_i$.
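The equal-value correction can be exercised with a self-contained sketch that mirrors the solution's counting, restricted to small p so plain 64-bit products are safe, and compares it with a brute-force check of the definition. Both `fastCount` and `bruteCount` are hypothetical names. With p = 3 the pair (1, 1) is valid, since $2^{-1} \equiv 2 \equiv 1^{-1}+1^{-1}$, so equal values must stay counted; with p = 7 the same pair is invalid and the subtraction removes it.

```cpp
#include <cstddef>
#include <map>
#include <vector>

// Modular exponentiation; assumes p < 2^31 so products of residues fit in 64 bits.
long long powSmall(long long b, long long e, long long p) {
    long long r = 1 % p;
    b %= p;
    while (e > 0) {
        if (e & 1) r = r * b % p;
        b = b * b % p;
        e >>= 1;
    }
    return r;
}

// Reference count straight from the definition (inverses via Fermat).
long long bruteCount(const std::vector<long long>& a, long long p) {
    long long cnt = 0;
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = i + 1; j < a.size(); ++j) {
            long long x = a[i], y = a[j], s = (x + y) % p;
            if (x == 0 || y == 0 || s == 0) continue;  // some inverse is missing
            if (powSmall(s, p - 2, p) == (powSmall(x, p - 2, p) + powSmall(y, p - 2, p)) % p)
                ++cnt;
        }
    return cnt;
}

// Cube matching with the equal-value correction, as in the solution,
// but with plain 64-bit products (valid only for small p).
long long fastCount(const std::vector<long long>& a, long long p) {
    std::map<long long, long long> byCube, byValue;
    long long ans = 0;
    for (long long x : a) {
        if (x == 0) continue;                           // zero has no inverse
        if (3 * x % p * x % p != 0) ans -= byValue[x];  // equal pairs are spurious
        byValue[x]++;
        ans += byCube[x * x % p * x % p]++;             // pairs with equal cubes
    }
    return ans;
}
```

For instance, `fastCount({1, 1, 2}, 3)` returns 1 (only the pair of ones is valid, since 1 + 2 has no inverse mod 3), while `fastCount({1, 1}, 7)` returns 0.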
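#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <map>
using namespace std;

map<long long,int> num;  // num[c]: number of earlier nonzero a with a^3 mod p == c
map<long long,int> cnt;  // cnt[v]: number of earlier elements equal to v

// a*b mod p by repeated doubling (assumes 0 <= a, b < p). Every intermediate
// stays below 2p, so nothing overflows even when p is close to 1e18.
long long multi(long long a,long long b,long long p){
    long long ans=0;
    while(b){
        if(b%2==1)
            ans=(ans+a)%p;
        b/=2;
        a=(a+a)%p;
    }
    return ans;
}

int main()
{
    int t;
    scanf("%d",&t);
    while(t--)
    {
        long long n,p;
        scanf("%lld%lld",&n,&p);
        cnt.clear();
        num.clear();
        long long ans=0;
        for(int i=0;i<n;i++)
        {
            long long a;
            scanf("%lld",&a);
            if(a==0)  // zero has no inverse, so it never forms a valid pair
                continue;
            // If 3a^2 != 0 (mod p), equal values do not satisfy the quadratic
            // condition: cancel in advance the matches with earlier equal values
            // that the cube count below is about to add.
            if(multi(multi(a,a,p),3,p)!=0)
                ans-=cnt[a];
            cnt[a]++;
            // Pair a with every earlier element that has the same cube mod p.
            long long val=multi(multi(a,a,p),a,p);
            ans+=num[val]++;
        }
        printf("%lld\n",ans);
    }
    return 0;
}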
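The `multi` routine avoids 64-bit overflow by doubling, at the cost of roughly 60 loop iterations per product when p is near $10^{18}$. On GCC and Clang a common alternative is a 128-bit intermediate product. This is a sketch under that assumption: `unsigned __int128` is a compiler extension rather than standard C++, and the names `mulMod128` and `mulModDoubling` are hypothetical.

```cpp
// unsigned __int128 is a GCC/Clang extension; __extension__ silences -pedantic warnings.
__extension__ typedef unsigned __int128 u128;

// a*b mod p using a 128-bit intermediate product (assumes 0 <= a, b < p).
long long mulMod128(long long a, long long b, long long p) {
    return (long long)((u128)a * (u128)b % (u128)p);
}

// Doubling version equivalent to the solution's multi(), for comparison.
// Intermediates stay below 2p, so it is safe for p up to about 4.6e18.
long long mulModDoubling(long long a, long long b, long long p) {
    long long ans = 0;
    a %= p;
    while (b > 0) {
        if (b & 1) ans = (ans + a) % p;
        b >>= 1;
        a = (a + a) % p;
    }
    return ans;
}
```

Either function can stand in for `multi` without touching the map logic; the 128-bit version does one multiplication and one division instead of a loop, which matters when it is called three times per element.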