Pass!
f ( 1 ) = 0 , f ( 2 ) = n − 1 , f ( t ) = ( n − 2 ) × f ( t − 1 ) + ( t − 1 ) × f ( t − 2 ) f(1) = 0, f(2) = n - 1, f(t) = (n - 2) \times f(t - 1) + (t - 1) \times f(t - 2) f(1)=0,f(2)=n−1,f(t)=(n−2)×f(t−1)+(t−1)×f(t−2),考虑对通项两边同时加一个 x × f ( t − 1 ) x \times f(t - 1) x×f(t−1)。
可以得到 f ( t ) + x × f ( t − 1 ) = ( n − 1 + x ) × ( f ( t − 1 ) + f ( t − 2 ) ) f(t) + x \times f(t - 1) = (n - 1 + x) \times (f(t - 1) + f(t - 2)) f(t)+x×f(t−1)=(n−1+x)×(f(t−1)+f(t−2)),所以可以得到两个 x x x,然后得到 f ( t ) = ( n − 1 ) t + ( n − 1 ) × ( − 1 ) t n f(t) = \frac{(n - 1) ^ t + (n - 1) \times (-1) ^ t}{n} f(t)=n(n−1)t+(n−1)×(−1)t。
接下来特判几个解,即可分奇偶即可进行 B S G S BSGS BSGS求解,整体复杂度 T × σ × m o d T \times \sigma \times \sqrt {mod} T×σ×mod。
( n − 1 ) t + ( n − 1 ) × ( − 1 ) t = n × x t = i × m − j , 可 以 考 虑 取 m 为 偶 数 , i ≥ 1 ( n − 1 ) i × m − j = n × x − ( n − 1 ) × ( − 1 ) j ( n − 1 ) i × m = ( n × x − ( n − 1 ) × ( − 1 ) j ) × ( n − 1 ) j (n - 1) ^ t + (n - 1) \times (-1) ^ t = n \times x\\ t = i \times m - j, 可以考虑取m为偶数, i \geq 1\\ (n - 1) ^{i \times m - j} = n \times x - (n - 1) \times (-1) ^ j\\ (n - 1) ^{i \times m} = (n \times x - (n - 1) \times (-1) ^ j) \times (n - 1) ^ j\\ (n−1)t+(n−1)×(−1)t=n×xt=i×m−j,可以考虑取m为偶数,i≥1(n−1)i×m−j=n×x−(n−1)×(−1)j(n−1)i×m=(n×x−(n−1)×(−1)j)×(n−1)j
取 m > m o d m > \sqrt {mod} m>mod,可以发现 j j j最多有 m m m个取值即可,同时, i i i也最多有 m m m个取值,用 m p a mpa mpa存下右边的 m m m个值,然后枚举左边即可。
#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353, M = 31596;
int n, x;
unordered_map<int, int> mp;
inline int add(int x, int y) {
return x + y < mod ? x + y : x + y - mod;
}
inline int sub(int x, int y) {
return x >= y ? x - y : x - y + mod;
}
int main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
int T;
scanf("%d", &T);
while (T--) {
scanf("%d %d", &n, &x);
if (x == 1) {
puts("0");
continue;
}
int base = 1, p = n - 1;
for (int i = 0, s = 1; i < M; i++, s = 1ll * s * p % mod) {
base = 1ll * base * p % mod;
if (i & 1) {
int cur = 1ll * add(1ll * n * x % mod, p) * s % mod;
mp[cur] = i;
}
else {
int cur = 1ll * sub(1ll * n * x % mod, p) * s % mod;
mp[cur] = i;
}
}
int ans = -1;
for (int i = 1, s = base; i <= M; i++, s = 1ll * s * base % mod) {
if (mp.count(s)) {
ans = i * M - mp[s];
break;
}
}
printf("%d\n", ans);
mp.clear();
}
return 0;
}