题目分析:O(nsqrt(n))复杂度的算法,跑的比较慢。。
本题一个比较特殊的地方就是比较是sqrt ( n )的,我们尝试每次比较的时候看第k个不同的数在前面出现的位置kth_diff[ k ](从当前位置往左数),这样可以列出一个dp方程:
dp[ i ] = min { dp[ kth_diff[ j + 1 ] ] + min ( i - kth_diff[ j + 1 ] , k * k ) | j <= sqrt ( n ) },需要保证当前最大的j+1的kth_diff为0,因为不存在。
重点就是怎么得到这每个数的kth_diff[ k ],首先我们假设我们已经得到了位置i的kth_diff数组,最大不同数为top,则如果第i+1个数为从未出现的,则所有的kth_diff[ k + 1 ] = kth_diff[ k ],然后kth_diff[ 1 ] = i + 1。
如果第i + 1数为出现过的,则找到kth_diff[ j ] > i + 1的j,有:kth_diff[ k + 1 ] = kth_diff[ k ] ( k + 1 <= j ),因为这个范围内的所有的i的kth_diff[ k ]就等于i + 1的kth_diff[ k + 1 ],然后kth_diff[ 1 ] = i + 1。
代码如下:
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
typedef long long LL ;
#define REP( i , a , b ) for ( int i = a ; i < b ; ++ i )
#define REV( i , a , b ) for ( int i = a ; i >= b ; -- i )
#define FOR( i , a , b ) for ( int i = a ; i <= b ; ++ i )
#define CLR( a , x ) memset ( a , x , sizeof a )
#define CPY( a , x ) memcpy ( a , x , sizeof a )
#define ls ( o << 1 )
#define rs ( o << 1 | 1 )
#define lson ls , l , m
#define rson rs , m + 1 , r
#define root 1 , 1 , n
#define rt o , l , r
#define mid ( ( l + r ) >> 1 )
const int MAXN = 50005 ;
const int INF = 0x3f3f3f3f ;
int a[MAXN] , cnt ;
int kth_diff[MAXN] ;
int pre[MAXN] ;
int dp[MAXN] ;
int num[MAXN] ;
int n ;
int unique ( int n ) {
int cnt = 1 ;
sort ( a + 1 , a + n + 1 ) ;
FOR ( i , 2 , n ) if ( a[i] != a[cnt] ) a[++ cnt] = a[i] ;
return cnt ;
}
int hash ( int x ) {
int l = 1 , r = cnt ;
while ( l < r ) {
int m = ( l + r ) >> 1 ;
if ( a[m] >= x ) r = m ;
else l = m + 1 ;
}
return l ;
}
void solve () {
int top = 0 ;
int sqr = sqrt ( 1.0 * n ) + 1 ;
CLR ( dp , INF ) ;
dp[0] = 0 ;
CLR ( pre , 0 ) ;
FOR ( i , 1 , n ) {
scanf ( "%d" , &num[i] ) ;
a[i] = num[i] ;
}
cnt = unique ( n ) ;
FOR ( i , 1 , n ) num[i] = hash ( num[i] ) ;
FOR ( i , 1 , n ) {
if ( !pre[num[i]] ) top = min ( top + 1 , sqr ) ;
REV ( k , top - 1 , 1 ) if ( kth_diff[k] > pre[num[i]] ) kth_diff[k + 1] = kth_diff[k] ;
kth_diff[1] = pre[num[i]] = i ;
kth_diff[top + 1] = 0 ;
FOR ( k , 1 , top ) {
if ( k * k > i ) break ;
dp[i] = min ( dp[i] , dp[kth_diff[k + 1]] + min ( i - kth_diff[k + 1] , k * k ) ) ;
}
}
printf ( "%d\n" , dp[n] ) ;
}
int main () {
while ( ~scanf ( "%d" , &n ) ) solve () ;
return 0 ;
}