题目大意
对称群记为 $S_n$,记 $T=\{p^m\mid p\in S_n\}$,求 $|T|$(对 $998244353$ 取模)。$n,m\le 5\times 10^4$。
思路
一个置换可以拆成若干个互不相交的循环。
一个长度为 $k$ 的循环的 $m$ 次方是 $\gcd(k,m)$ 个长度为 $\frac{k}{\gcd(k,m)}$ 的循环。
对任意 $q\in T$,若 $q$ 中长度为 $x$ 的循环有 $y$ 个,则存在可重集 $D$,使得 $y=\sum_{d\in D} d$ 且 $\forall d\in D,\ \gcd(xd,m)=d$(每个 $d$ 对应原置换中一个长度为 $xd$ 的循环)。
由于 $\gcd(xd,m)=d \iff d\mid m$ 且 $\gcd\!\left(x,\frac{m}{d}\right)=1$,合法的 $d$ 必须包含 $m$ 中所有与 $x$ 有公共素因子的极大素数幂,因此它们都是最小者 $\min\{d\mid \gcd(xd,m)=d\}=\gcd(x^{+\infty},m)$ 的倍数,从而 $\gcd(x^{+\infty},m)\mid y$。
记 $t_x=\gcd(x^{+\infty},m)$,则 $t_x\mid y$;反之,取 $D$ 为 $y/t_x$ 个 $t_x$ 即可构造出对应的原置换,所以该条件也是充分的。
答案为
$$n!\,[z^n]\prod_{x=1}^{+\infty}\sum_{j=0}^{+\infty}\frac{\left(\frac{z^x}{x}\right)^{t_xj}}{(t_xj)!}=n!\,[z^n]\prod_{x=1}^{n}G_{t_x}\!\left(\frac{z^x}{x}\right)$$
(其中 $G_i(z)=\sum_{j=0}^{+\infty}\frac{z^{ij}}{(ij)!}$)
连乘通过对每个因子取 $\ln$、求和后再取 $\exp$ 来计算:令 $u=w^{t}$,则 $\ln G_{t}(w)$ 是关于 $u$ 的幂级数,再代入 $w=\frac{z^x}{x}$ 即可。
$t_x$ 相同的因子一起处理(它们的 $\ln G_{t_x}$ 相同)。
代码
#include <bits/stdc++.h>
// Inclusive loop helpers: rep counts up from l to r, per counts down from r to l.
// Arguments are parenthesized so that expression arguments (e.g. containing ?:, &, |)
// are compared as a whole rather than being split by operator precedence.
#define rep(i, l, r) for (int i = (l); i <= (r); ++i)
#define per(i, r, l) for (int i = (r); i >= (l); --i)
using namespace std;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; DFT below uses 3 as its root generator.
constexpr int mod = 998244353;
// Polynomials / power series are stored as coefficient vectors, lowest degree first.
using vi = vector<int>;
// Debug helper: writes every coefficient of a followed by a space, then a newline.
void print(vi a) {
  for (size_t idx = 0; idx < a.size(); ++idx) printf("%d ", a[idx]);
  printf("\n");
}
// Euclid's algorithm, iterative form: gcd(x, 0) == x.
int gcd(int x, int y) {
  while (y != 0) {
    const int r = x % y;
    x = y;
    y = r;
  }
  return x;
}
// Modular addition; intended for operands already reduced into [0, mod).
int add(int x, int y) {
  const int s = x + y;
  return s % mod;
}
// Modular subtraction; intended for operands in [0, mod), result in [0, mod).
int sub(int x, int y) {
  const int diff = x - y + mod;
  return diff % mod;
}
// Modular multiplication through a 64-bit intermediate product.
int mul(int x, int y) {
  const long long prod = static_cast<long long>(x) * y;
  return static_cast<int>(prod % mod);
}
// Binary exponentiation: x^y mod mod (y is expected to be non-negative).
int pw(int x, int y) {
  int result = 1;
  for (int base = x; y; y >>= 1) {
    if (y & 1) result = mul(result, base);
    base = mul(base, base);
  }
  return result;
}
// Modular inverse via Fermat's little theorem (mod is prime); x must be nonzero mod mod.
int inv(int x) {
  const int fermat_exp = mod - 2;
  return pw(x, fermat_exp);
}
vi fac, invfac;
void get_fac(int n) {
fac.resize(n + 1);
invfac.resize(n + 1);
fac[0] = 1;
rep(i, 1, n) fac[i] = mul(fac[i - 1], i);
invfac[n] = inv(fac[n]);
per(i, n - 1, 0) invfac[i] = mul(invfac[i + 1], i + 1);
}
vi rev;  // bit-reversal permutation of 0..n-1 for the most recent n
// Builds rev for transform length n (a power of two); does nothing if the previous
// call used the same n.
void get_rev(int n) {
  static int cached_n = -1;
  if (cached_n == n) return;
  cached_n = n;
  rev.resize(n);
  int bits = 0;
  for (; (1 << bits) < n; ++bits) {
  }
  rev[0] = 0;
  for (int i = 1; i < n; ++i) {
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));
  }
}
// In-place iterative NTT of a[0..n-1] modulo mod, using powers of 3 as roots of unity.
// dir = 1: forward transform. dir = -1: inverse transform (inverted roots, then every
// entry scaled by n^{-1}). n must be a power of two and a.size() >= n.
void DFT(vi &a, int n, int dir = 1) {
  get_rev(n);
  for (int i = 0; i < n; ++i) {
    if (i < rev[i]) swap(a[i], a[rev[i]]);
  }
  for (int half = 1; 2 * half <= n; half *= 2) {
    const int span = 2 * half;
    int root = pw(3, (mod - 1) / span);
    if (dir == -1) root = inv(root);
    for (int start = 0; start < n; start += span) {
      int w = 1;
      for (int k = 0; k < half; ++k) {
        int &lo = a[start + k];
        int &hi = a[start + k + half];
        const int t = mul(hi, w);
        hi = sub(lo, t);
        lo = add(lo, t);
        w = mul(w, root);
      }
    }
  }
  if (dir == -1) {
    const int scale = inv(n);
    for (int i = 0; i < n; ++i) a[i] = mul(a[i], scale);
  }
}
// Coefficient-wise sum a + b; the result is as long as the longer operand.
vi operator+(vi a, vi b) {
  a.resize(max(a.size(), b.size()));
  for (size_t i = 0; i < b.size(); ++i) a[i] = add(a[i], b[i]);
  return a;
}
// Coefficient-wise difference a - b; the result is as long as the longer operand.
vi operator-(vi a, vi b) {
  a.resize(max(a.size(), b.size()));
  for (size_t i = 0; i < b.size(); ++i) a[i] = sub(a[i], b[i]);
  return a;
}
// Polynomial product via NTT. For non-empty a and b the result has a power-of-two
// length >= a.size() + b.size() - 1 and is zero beyond the product's degree.
vi operator*(vi a, vi b) {
  const int need = (int)(a.size() + b.size()) - 1;
  int lim = 1;
  while (lim < need) lim *= 2;
  a.resize(lim);
  b.resize(lim);
  DFT(a, lim);
  DFT(b, lim);
  for (int i = 0; i < lim; ++i) a[i] = mul(a[i], b[i]);
  DFT(a, lim, -1);
  return a;
}
// Power-series inverse by Newton iteration: returns b with a * b == 1 (mod z^lim),
// where lim is the smallest power of two >= a.size(). Requires a non-empty, a[0] != 0.
// Each round doubles the precision: b <- b * (2 - a * b) mod z^len.
vi inv(vi a) {
  int lim = 1;
  while (lim < (int)a.size()) lim <<= 1;
  vi b(1, inv(a[0]));
  for (int len = 2; len <= lim; len <<= 1) {
    const int sz = len << 1;
    vi x(a.begin(), a.begin() + min<size_t>(a.size(), len));  // a mod z^len
    x.resize(sz);
    b.resize(sz);
    DFT(x, sz);
    DFT(b, sz);
    for (int i = 0; i < sz; ++i) {
      const int xb = mul(x[i], b[i]);
      b[i] = mul(sub(2, xb), b[i]);
    }
    DFT(b, sz, -1);
    b.resize(len);
  }
  return b;
}
// Polynomial quotient of a by b; b's highest coefficient must be nonzero.
// Returns {0} when a is shorter than b. Reversing both operands turns the quotient
// into a truncated power-series product with the inverse of reversed b.
vi operator/(vi a, vi b) {
  if (a.size() < b.size()) return vi(1, 0);
  const int q = a.size() - b.size() + 1;
  vi ra(a.rbegin(), a.rend());
  vi rb(b.rbegin(), b.rend());
  ra.resize(q);
  rb.resize(q);
  vi c = ra * inv(rb);
  c.resize(q);
  return vi(c.rbegin(), c.rend());
}
// Polynomial remainder a - (a / b) * b, returned with max(b.size() - 1, 1) coefficients.
vi operator%(vi a, vi b) {
  const vi quotient = a / b;
  vi r = a - quotient * b;
  r.resize(max((int)b.size() - 1, 1));
  return r;
}
// Formal derivative: returns a' with one coefficient fewer than a
// (empty for a constant or empty input).
vi dw(vi a) {
  if (a.empty()) return a;  // otherwise resize(n - 1) would request SIZE_MAX elements
  const int n = a.size();
  rep(i, 0, n - 2) a[i] = mul(a[i + 1], i + 1);
  a.resize(n - 1);
  return a;
}
// Formal integral with zero constant term: returns b of length a.size() + 1 with
// b[0] = 0 and b[i] = a[i - 1] / i.
vi up(vi a) {
  const int n = a.size();
  // Inverses of 1..n in O(n) instead of one modular exponentiation per coefficient:
  // from mod = (mod / i) * i + mod % i it follows that 1/i = -(mod / i) / (mod % i).
  vi iv(n + 1, 1);
  rep(i, 2, n) iv[i] = mul(mod - mod / i, iv[mod % i]);
  a.resize(n + 1);
  per(i, n, 1) a[i] = mul(a[i - 1], iv[i]);
  a[0] = 0;
  return a;
}
// Power-series logarithm as the integral of a'/a (constant term 0); this is ln a when
// a[0] == 1. The returned vector is longer than a; rely only on its first a.size() terms.
vi ln(vi a) {
  vi deriv = dw(a);
  vi recip = inv(a);
  return up(deriv * recip);
}
// Power-series exponential: returns b with b == exp(a) (mod z^lim), where lim is the
// smallest power of two >= a.size(). Requires a[0] == 0 (the iteration starts at b = 1).
//
// Newton step: if b == exp(a) (mod z^{len/2}), then b * (1 + a - ln b) == exp(a)
// (mod z^len), provided ln b is accurate mod z^len. ln is therefore evaluated on b
// zero-padded to len terms; on the unpadded b (len/2 terms) the inverse inside ln is
// only accurate to z^{len/2}, so each step would gain a single coefficient and the
// loop would need to run up to 2 * lim.
vi exp(vi a) {
  int lim = 1;
  while (lim < (int)a.size()) lim <<= 1;
  vi b(1, 1);
  for (int len = 2; len <= lim; len <<= 1) {
    vi padded(b);
    padded.resize(len);
    // Only a mod z^len influences this step; truncating keeps the product O(len).
    vi head(a.begin(), a.begin() + min((int)a.size(), len));
    vi step = head + vi(1, 1) - ln(padded);
    step.resize(len);
    b = b * step;
    b.resize(len);
  }
  return b;
}
// vi pw(vi a, int k, vi b) {
// vi ret(1, 1);
// while (k) {
// if (k & 1) ret = ret * a % b;
// a = a * a % b;
// k >>= 1;
// }
// // for (auto v : a) printf("%d ", v);
// return ret;
// }
// int linear(vi g, vi a, int n) {
// // for (auto v : g) printf("%d ", v);
// int k = g.size() - 1;
// vi t{0, 1};
// vi r = pw(t, n, g);
// int ret = 0;
// rep(i, 0, k - 1) { ret = (ret + mul(r[i], a[i])) % mod; }
// return ret;
// }
// void test_linear() {
// int n, k;
// scanf("%d%d", &n, &k);
// int x;
// vi f(1, 0);
// rep(i, 1, k) {
// scanf("%d", &x);
// f.push_back((x + mod) % mod);
// }
// vi a;
// rep(i, 0, k - 1) {
// scanf("%d", &x);
// a.push_back((x + mod) % mod);
// }
// vi g;
// per(i, k, 1) { g.push_back(sub(0, f[i])); }
// g.push_back(1);
// int ret = linear(g, a, n);
// printf("%d", ret);
// }
// Coefficients of H(u) = sum_{i=0}^{len} u^i / (tx * i)!, i.e. G_tx(w) rewritten in
// u = w^tx and truncated after len + 1 terms. Needs invfac filled up to tx * len.
vi G(int tx, int len) {
  vi coef;
  coef.reserve(len + 1);
  for (int i = 0; i <= len; ++i) coef.push_back(invfac[tx * i]);
  return coef;
}
void work() {
int n, m;
scanf("%d%d", &n, &m);
get_fac(n);
vi t = vi();
int tmp = m;
rep(i, 2, tmp) {
int now = 1;
while (tmp % i == 0) {
tmp /= i;
now *= i;
}
if (now > 1) t.push_back(now);
}
vi f = vi(n + 1, 0);
rep(x, 1, n) {
int tx = 1;
for (auto y : t) {
if (gcd(x, y) > 1) tx *= y;
}
// printf("%d: %d\n", x, tx);
int len = n / x / tx;
vi now = ln(G(tx, len));
// z -> z^{x*tx}/x
rep(i, 0, len) {
int cur = mul(now[i], pw(x, mod - 1 - i * tx));
// printf("%d,", cur);
f[i * x * tx] = add(f[i * x * tx], cur);
}
// printf("\n");
}
f = exp(f);
printf("%d\n", mul(f[n], fac[n]));
}
// Entry point: solves the single instance read from stdin (main implicitly returns 0).
int main() {
  work();
}