, n <= 1e9, m <= 1e3
由于 n 太大，无法逐项枚举，我们需要一种复杂度只与 m 有关的算法。比较容易想到固定底数 m，由较低次幂的答案递推到高一次幂的答案：
设 f(i) = Σ_{k=0}^{n} k^i · m^k，用 f(0), …, f(i-1) 推出 f(i)，然后就可以 O(m^2) DP 了。
#include<bits/stdc++.h>
#define N 1050
using namespace std;
// Prime modulus used for every result in this file; primality is what makes
// power(x, Mod - 2) a modular inverse (Fermat's little theorem).
const int Mod = 1000000007;
typedef long long ll;

// Canonical residue of x in [0, Mod). C++'s % keeps the sign of the dividend,
// so a negative x needs one corrective +Mod.
static ll canon(ll x){
    x %= Mod;
    return x < 0 ? x + Mod : x;
}

// Returns (a + b) mod Mod in [0, Mod). Operands are reduced first, so the
// intermediate sum stays below 2 * Mod for any long long inputs.
ll add(ll a, ll b){ return (canon(a) + canon(b)) % Mod; }

// Returns (a * b) mod Mod in [0, Mod). Operands are reduced first, so the
// intermediate product stays below Mod^2 (< 2^63) for any long long inputs.
ll mul(ll a, ll b){ return canon(a) * canon(b) % Mod; }

// Returns a^b mod Mod in [0, Mod) by binary exponentiation, for b >= 0
// (power(x, 0) is 1 for every x). A negative exponent is outside the
// contract; the loop runs only while b > 0, so such a call returns 1
// instead of looping forever on the sign bit.
ll power(ll a, ll b){
    ll ans = 1;
    a = canon(a);
    while (b > 0){
        if (b & 1) ans = ans * a % Mod;
        a = a * a % Mod;
        b >>= 1;
    }
    return ans;
}
// Inputs read in main: n is the summation bound, m is both the exponent of
// the requested sum and the base of its geometric factor.
ll n, m;
// c[i][j]: binomial coefficient C(i, j) modulo Mod (Pascal's triangle, built in main).
ll c[N][N];
// f[i]: sum over k = 0..n of k^i * m^k modulo Mod, with 0^0 counted as 1;
// main prints f[m].
ll f[N];
// pw[i]: n^i modulo Mod.
ll pw[N];
int main(){
scanf("%lld%lld", &n, &m);
if(m == 1){ printf("%lld", (n * (n + 1) / 2) % Mod); return 0;}
c[0][0] = 1;
for(int i = 1; i <= m; i++){
c[i][0] = 1; for(int j = 1; j <= i; j++) c[i][j] = add(c[i-1][j-1], c[i-1][j]);
}
f[0] = mul(add(power(m, n + 1), Mod - 1), power(m - 1, Mod - 2));
ll A = power(m, n);
pw[0] = 1; for(int i = 1; i <= m; i++) pw[i] = mul(pw[i-1], n);
ll inv = power(Mod + 1 - m, Mod - 2);
for(int i = 1; i <= m; i++){
ll sum = 0;
for(int j = 0; j < i; j++){
sum = add(sum, mul(c[i][j], add(f[j], Mod - mul(pw[j], A))));
} sum = add(sum, Mod - mul(pw[i], A));
sum = mul(sum, mul(m, inv)); f[i] = sum;
} cout << f[m]; return 0;
}