题面
题解
以下的数都是在 b b b 进制意义下讨论。
默认 n ≥ b n\geq b n≥b,否则 n < b n< b n<b 可以特判答案为 1 1 1。
考虑 DP,设 d r d_r dr 表示所有模 n n n 余 r r r 的正整数中非零位个数的最小值,那么我们要求的即为 d 0 d_0 d0。
我们考虑从 d r d_r dr 转移出去:
-
我们可以考虑把这个模 n n n 余 r r r 的数末尾添上一个 0 0 0,此时余数变为了 b r m o d n br\bmod{n} brmodn,非零位个数不变,故:
d b r m o d n ← d r d_{br\bmod n}\gets d_r dbrmodn←dr -
我们也可以考虑把这个模 n n n 余 r r r 的数末尾添上一个非零数 s s s( 1 ≤ s < b 1\leq s<b 1≤s<b),此时余数变为了 ( b r + s ) m o d n (br+s)\bmod{n} (br+s)modn,非零位个数加 1 1 1,故:
d ( b r + s ) m o d n ← d r + 1 d_{(br+s)\bmod n}\gets d_r+1 d(br+s)modn←dr+1
那么初始状态也就出来了:
d
r
←
1
(
1
≤
r
<
b
)
d_r\gets 1\quad(1\leq r<b)
dr←1(1≤r<b)
暴力建边是
O
(
n
b
)
O(nb)
O(nb) 的,接下来考虑如何优化转移。
我们把 d r d_r dr 看做点 r r r,那么 d r d_r dr 的转移就是在图上跑 01 01 01 最短路。
利用到边权只能为 0 0 0 或 1 1 1 的优秀性质,我们考虑模拟最短路中的 bfs \operatorname{bfs} bfs:大概思路是先一直贪心走 0 0 0 边并对没访问过的点进行更新,再对刚刚所有更新的点走一次 1 1 1 边并对没访问过的点进行更新,再重复上述过程。
注意到一个点的 0 0 0 边只有一条,所以走 0 0 0 边时如果走到的点已经被更新了,那么它往后走 0 0 0 边的点也肯定被更新了,就无需继续更新。所以走 0 0 0 边的部分我们直接暴力更新即可。总时间复杂度 O ( n ) O(n) O(n)。
而对于每一个要向外更新的点,我们需要找到它的 1 1 1 边中所有未被更新的点。注意到每个点向外连的 1 1 1 边是一段区间,所以我们用并查集维护即可。总时间复杂度 O ( n ⋅ α ( n ) ) O(n\cdot\alpha (n)) O(n⋅α(n))。
代码如下:
#include<bits/stdc++.h>
#define N 10000010
#define re register
using namespace std;
namespace modular
{
int mod;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
}using namespace modular;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
int n,num,b,d[N],nxt[N];
int fa[N],tor[N];
bool flag;
int find(int x)
{
return x==fa[x]?x:(fa[x]=find(fa[x]));
}
inline void merge(int x,int y)
{
int a=find(x),b=find(y);
fa[b]=a,tor[a]=max(tor[a],tor[b]);
}
inline void change(int i)
{
if(i-1>=0&&d[i-1]) merge(i-1,i);
if(i+1<n&&d[i+1]) merge(i,i+1);
}
queue<int>q[2];
void update(int u)
{
int now=u;
do
{
d[now]=num;
change(now);
q[flag^1].push(now);
now=nxt[now];
}while(!d[now]);
}
void work(int l,int r)
{
for(re int i=l;i<=r;i++)
{
if(d[i])
{
int rt=find(i);
i=tor[rt];
continue;
}
d[i]=num;
update(i);
}
}
int main()
{
mod=n=read(),b=read();
if(n<b)
{
puts("1");
return 0;
}
for(re int i=0;i<n;i++) fa[i]=tor[i]=i;
for(re int i=0;i<n;i++) nxt[i]=1ll*i*b%n;
num=1,flag=0;
work(1,b-1);
while(!d[0])
{
num++,flag^=1;
while(!q[flag].empty())
{
int u=q[flag].front();
q[flag].pop();
int l=add(nxt[u],1),r=add(nxt[u],b-1);
if(l<=r) work(l,r);
else work(l,n-1),work(0,r);
}
}
printf("%d\n",d[0]);
return 0;
}
/*
79 2
*/