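[Public NOIP Round #3] Counting Circles (数圈圈)
A textbook exercise in 2D "cat tree" (猫树) divide and conquer.
Cut the rectangle with a line through its middle and count the circles that cross this line; the circles lying entirely on one side are handled recursively.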
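Assume the line is vertical, at column \(mid\). A circle crossing it splits into a 匚-shaped left half and a mirrored right half, so for each pair of top and bottom rows it suffices to count the left halves and the right halves separately and multiply the two counts.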
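Let \(L_{i,j}, R_{i,j}, U_{i,j}, D_{i,j}\) be the farthest column (for \(L, R\)) or row (for \(U, D\)) reachable from \((i,j)\) going left/right/up/down while staying on cells with the same letter as \((i,j)\). These are stored as indices, not distances.
For the left half, fix the top row \(u\) and the bottom row \(v\) \((u<v)\). The number of valid left-edge columns \(i\) is given below; inside a sub-rectangle, the lower limit is additionally clamped to that sub-rectangle's left boundary: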
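As a sketch (0-indexed, unlike the 1-indexed solution; `Extents` and `computeExtents` are hypothetical names), each table is filled by one linear scan in its direction: a cell either extends its neighbor's run or starts a new one at its own index.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical container for the four farthest-index tables (0-indexed).
struct Extents {
    std::vector<std::vector<int>> L, R, U, D;
};

// L[i][j]: smallest column j' such that g[i][j'..j] all equal g[i][j].
// R, U, D: largest column, smallest row, largest row, defined the same way.
Extents computeExtents(const std::vector<std::string>& g) {
    assert(!g.empty() && !g[0].empty());
    int n = g.size(), m = g[0].size();
    Extents e;
    e.L = e.R = e.U = e.D = std::vector<std::vector<int>>(n, std::vector<int>(m));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++)       // left to right: extend the run from the left
            e.L[i][j] = (j > 0 && g[i][j - 1] == g[i][j]) ? e.L[i][j - 1] : j;
        for (int j = m - 1; j >= 0; j--)  // right to left: extend the run from the right
            e.R[i][j] = (j + 1 < m && g[i][j + 1] == g[i][j]) ? e.R[i][j + 1] : j;
    }
    for (int j = 0; j < m; j++) {
        for (int i = 0; i < n; i++)       // top to bottom
            e.U[i][j] = (i > 0 && g[i - 1][j] == g[i][j]) ? e.U[i - 1][j] : i;
        for (int i = n - 1; i >= 0; i--)  // bottom to top
            e.D[i][j] = (i + 1 < n && g[i + 1][j] == g[i][j]) ? e.D[i + 1][j] : i;
    }
    return e;
}
```

For the grid `aab`/`abb`, this gives `e.L[1][2] == 1` (the run `bb` starts at column 1) and `e.D[0][0] == 1` (column 0 is `a` in both rows).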
\[\sum_{i=\max(L_{u,mid},L_{v,mid})}^{mid} [D_{u,i} \ge v]
\]
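To make the sum concrete, here is a direct evaluation of it in \(\mathcal O(\text{width})\) per pair. It is a sketch for reference only (evaluating it for every pair \((u,v)\) is too slow); `buildLD`, `leftCount` and the parameter `y1` (left boundary of the current sub-rectangle) are hypothetical names, and indices are 0-based.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

using Table = std::vector<std::vector<int>>;  // hypothetical alias, 0-indexed

// Hypothetical helper: L (leftmost same-letter column) and D (lowest same-letter row).
void buildLD(const std::vector<std::string>& g, Table& L, Table& D) {
    assert(!g.empty() && !g[0].empty());
    int n = g.size(), m = g[0].size();
    L.assign(n, std::vector<int>(m));
    D.assign(n, std::vector<int>(m));
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
            L[i][j] = (j > 0 && g[i][j - 1] == g[i][j]) ? L[i][j - 1] : j;
    for (int i = n - 1; i >= 0; i--)
        for (int j = 0; j < m; j++)
            D[i][j] = (i + 1 < n && g[i + 1][j] == g[i][j]) ? D[i + 1][j] : i;
}

// Left-half count for rows u < v and cut column mid:
//   #{ i in [max(L[u][mid], L[v][mid], y1), mid] : D[u][i] >= v }
int leftCount(const Table& L, const Table& D, int u, int v, int mid, int y1) {
    int lo = std::max({L[u][mid], L[v][mid], y1});
    int cnt = 0;
    for (int i = lo; i <= mid; i++)
        if (D[u][i] >= v) cnt++;  // column i connects row u to row v
    return cnt;
}
```

On the grid `aaa`/`aba`/`aaa` with \(u=0\), \(v=2\), \(mid=2\), \(y1=0\), the count is 2: column 0 (the outer frame) and column 2 itself.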
When \(L_{u,mid} \ge L_{v,mid}\), this becomes:
\[\sum_{i=L_{u,mid}}^{mid} [D_{u,i} \ge v]
\]
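For a fixed \(u\), bucket the values \(D_{u,i}\) and take suffix sums; the count for each \(v\) is then an \(\mathcal O(1)\) lookup.
When \(L_{u,mid} < L_{v,mid}\), note that \(D_{u,i} \ge v\) is equivalent to \(U_{v,i} \le u\), so the sum becomes: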
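For a fixed \(u\), the bucket-and-suffix-sum step can be sketched as follows (0-indexed; `buildLD`, `suffixCounts` and the parameters `x2`, `y1` for the sub-rectangle's bottom and left boundaries are hypothetical names mirroring the solution's loop). It answers all \(v\) with one pass over \(i\) plus one pass over \(v\).

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

using Table = std::vector<std::vector<int>>;  // hypothetical alias, 0-indexed

// Hypothetical helper: L (leftmost same-letter column) and D (lowest same-letter row).
void buildLD(const std::vector<std::string>& g, Table& L, Table& D) {
    assert(!g.empty() && !g[0].empty());
    int n = g.size(), m = g[0].size();
    L.assign(n, std::vector<int>(m));
    D.assign(n, std::vector<int>(m));
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
            L[i][j] = (j > 0 && g[i][j - 1] == g[i][j]) ? L[i][j - 1] : j;
    for (int i = n - 1; i >= 0; i--)
        for (int j = 0; j < m; j++)
            D[i][j] = (i + 1 < n && g[i + 1][j] == g[i][j]) ? D[i + 1][j] : i;
}

// For a fixed top row u, returns cnt with
//   cnt[v] = #{ i in [max(L[u][mid], y1), mid] : D[u][i] >= v }   for u < v <= x2.
// This equals the left-half count only when L[u][mid] >= L[v][mid].
std::vector<int> suffixCounts(const Table& L, const Table& D, int u, int x2, int mid, int y1) {
    std::vector<int> buc(x2 + 2, 0);  // buc[x2 + 1] stays 0 and ends the suffix sums
    for (int i = std::max(L[u][mid], y1); i <= mid; i++)
        buc[std::min(D[u][i], x2)]++;   // rows below x2 are outside the sub-rectangle anyway
    for (int v = x2; v > u; v--)
        buc[v] += buc[v + 1];           // now buc[v] = #{ i : D[u][i] >= v }
    return buc;
}
```

On the grid `aaa`/`aba`/`aaa` with \(u=0\), \(mid=2\), both `cnt[1]` and `cnt[2]` come out as 2. Only `cnt[2]` is a valid answer: for \(v=1\) we have \(L_{0,2}=0<L_{1,2}=2\), the true count is 1, and that pair must be taken from the other case instead.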
\[\sum_{i=L_{v,mid}}^{mid} [U_{v,i} \le u]
\]
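This is handled the same way, with prefix sums over \(u\) for a fixed \(v\). The right half is counted symmetrically using \(R\). If the left and right counts for the pair \((u,v)\) are \(f\) and \(g\), the pair contributes \(f\cdot g - [D_{u,mid} \ge v]\): the subtracted term removes the choice where both the left and the right edge sit on column \(mid\), which is only one column wide.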
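The mirrored case can be sketched for a fixed \(v\) under the same caveats (0-indexed; `buildLU`, `prefixCounts` and the parameters `x1`, `y1` for the sub-rectangle's top and left boundaries are hypothetical names):

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

using Table = std::vector<std::vector<int>>;  // hypothetical alias, 0-indexed

// Hypothetical helper: L (leftmost same-letter column) and U (highest same-letter row).
void buildLU(const std::vector<std::string>& g, Table& L, Table& U) {
    assert(!g.empty() && !g[0].empty());
    int n = g.size(), m = g[0].size();
    L.assign(n, std::vector<int>(m));
    U.assign(n, std::vector<int>(m));
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++) {  // row-major order: left and upper neighbors are ready
            L[i][j] = (j > 0 && g[i][j - 1] == g[i][j]) ? L[i][j - 1] : j;
            U[i][j] = (i > 0 && g[i - 1][j] == g[i][j]) ? U[i - 1][j] : i;
        }
}

// For a fixed bottom row v, returns cnt with
//   cnt[u] = #{ i in [max(L[v][mid], y1), mid] : U[v][i] <= u }   for x1 <= u < v.
// This equals the left-half count only when L[v][mid] > L[u][mid].
std::vector<int> prefixCounts(const Table& L, const Table& U, int v, int x1, int mid, int y1) {
    assert(x1 <= v);
    std::vector<int> buc(v + 1, 0);     // U[v][i] <= v, so indices stay in range
    for (int i = std::max(L[v][mid], y1); i <= mid; i++)
        buc[std::max(U[v][i], x1)]++;   // rows above x1 are outside the sub-rectangle anyway
    for (int u = x1 + 1; u < v; u++)
        buc[u] += buc[u - 1];           // now buc[u] = #{ i : U[v][i] <= u }
    return buc;
}
```

On the grid `aaa`/`aba`/`baa` with \(v=2\), \(mid=2\), we have \(L_{0,2}=0<L_{2,2}=1\), and `cnt[0]` comes out as 1, matching a direct evaluation of the sum (only column 2 connects rows 0 and 2).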
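Always cut parallel to the shorter side, i.e. halve the longer side. A sub-rectangle whose shorter and longer sides are \(x\) and \(y\) then costs \(\mathcal O(x^2+xy)\), and the total time complexity is \(\mathcal O(nm \log nm)\).
The implementation is long but has few pitfalls; just be careful with the boundaries.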
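One way to make the bound precise (a sketch of the usual area argument, using the halving rule in the code): since \(x \le y\), a single sub-rectangle costs time linear in its area, and both halves produced by cutting at \(md\) have width at most \(\lfloor y/2 \rfloor\):

\[
\mathrm{cost}(x\times y)=\mathcal O(x^2+xy)=\mathcal O(xy),\qquad
\text{area of each child}\le x\left\lfloor \tfrac{y}{2}\right\rfloor\le \tfrac{xy}{2}.
\]

Sub-rectangles at the same recursion depth are disjoint, so each depth costs \(\mathcal O(nm)\) in total, and since the area at least halves per level, the depth is at most \(\lfloor \log_2 nm \rfloor\):

\[
\sum_{k=0}^{\lfloor \log_2 nm\rfloor}\mathcal O(nm)=\mathcal O(nm\log nm).
\]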
```cpp
#include <algorithm>
#include <cstdio>
using namespace std;
#define ll long long
const int MAXN = 2000;
int n , m , s[ MAXN + 5 ][ MAXN + 5 ];
// U/D/L/R[i][j]: farthest row/column reachable from (i,j) going up/down/left/right
// through cells with the same letter as s[i][j] (an index, not a distance)
int U[ MAXN + 5 ][ MAXN + 5 ] , D[ MAXN + 5 ][ MAXN + 5 ] , L[ MAXN + 5 ][ MAXN + 5 ] , R[ MAXN + 5 ][ MAXN + 5 ];
ll ans;
// buc: scratch buckets for the suffix/prefix sums (always all-zero outside the loops)
// f/g: left/right (or upper/lower) counts for each pair of rows (or columns)
int buc[ MAXN + 5 ] , f[ MAXN + 5 ][ MAXN + 5 ] , g[ MAXN + 5 ][ MAXN + 5 ];
// Counts the circles lying entirely inside rows [x1, x2] and columns [y1, y2]
void Solve( int x1 , int y1 , int x2 , int y2 ) {
    int lenx = x2 - x1 + 1 , leny = y2 - y1 + 1;
    if( lenx < 1 || leny < 1 ) return;
    if( lenx <= leny ) { // vertical cut at column md, parallel to the shorter side
        int md = ( y1 + y2 ) >> 1;
        // Left half, case L[u][md] >= L[v][md]: bucket D[u][i], suffix sums over v
        for( int u = x1 ; u <= x2 ; u ++ ) {
            for( int i = max( L[ u ][ md ] , y1 ) ; i <= md ; i ++ ) buc[ min( D[ u ][ i ] , x2 ) ] ++;
            for( int v = x2 ; v > u ; v -- ) {
                buc[ v ] += buc[ v + 1 ];
                if( L[ u ][ md ] >= L[ v ][ md ] ) f[ u ][ v ] = buc[ v ];
            }
            for( int i = x1 ; i <= x2 ; i ++ ) buc[ i ] = 0;
        }
        // Left half, case L[v][md] > L[u][md]: bucket U[v][i], prefix sums over u
        for( int v = x2 ; v >= x1 ; v -- ) {
            for( int i = max( L[ v ][ md ] , y1 ) ; i <= md ; i ++ ) buc[ max( U[ v ][ i ] , x1 ) ] ++;
            for( int u = x1 ; u < v ; u ++ ) {
                buc[ u ] += buc[ u - 1 ];
                if( L[ v ][ md ] > L[ u ][ md ] ) f[ u ][ v ] = buc[ u ];
            }
            for( int i = x1 ; i <= x2 ; i ++ ) buc[ i ] = 0;
        }
        // Right half: the same two cases, using R instead of L
        for( int u = x1 ; u <= x2 ; u ++ ) {
            for( int i = md ; i <= min( R[ u ][ md ] , y2 ) ; i ++ ) buc[ min( D[ u ][ i ] , x2 ) ] ++;
            for( int v = x2 ; v > u ; v -- ) {
                buc[ v ] += buc[ v + 1 ];
                if( R[ u ][ md ] <= R[ v ][ md ] ) g[ u ][ v ] = buc[ v ];
            }
            for( int i = x1 ; i <= x2 ; i ++ ) buc[ i ] = 0;
        }
        for( int v = x2 ; v >= x1 ; v -- ) {
            for( int i = md ; i <= min( R[ v ][ md ] , y2 ) ; i ++ ) buc[ max( U[ v ][ i ] , x1 ) ] ++;
            for( int u = x1 ; u < v ; u ++ ) {
                buc[ u ] += buc[ u - 1 ];
                if( R[ v ][ md ] < R[ u ][ md ] ) g[ u ][ v ] = buc[ u ];
            }
            for( int i = x1 ; i <= x2 ; i ++ ) buc[ i ] = 0;
        }
        // Combine: f * g (left edge, right edge) choices, minus the width-1 choice (md, md)
        for( int u = x1 ; u <= x2 ; u ++ )
            for( int v = u + 1 ; v <= x2 ; v ++ )
                if( f[ u ][ v ] && g[ u ][ v ] ) ans += 1ll * f[ u ][ v ] * g[ u ][ v ] - ( D[ u ][ md ] >= v );
        Solve( x1 , y1 , x2 , md - 1 ); Solve( x1 , md + 1 , x2 , y2 );
    }
    else { // horizontal cut at row md: same as above with rows and columns swapped
        int md = ( x1 + x2 ) >> 1;
        // Upper half, case U[md][u] >= U[md][v]: bucket R[i][u], suffix sums over v
        for( int u = y1 ; u <= y2 ; u ++ ) {
            for( int i = max( U[ md ][ u ] , x1 ) ; i <= md ; i ++ ) buc[ min( R[ i ][ u ] , y2 ) ] ++;
            for( int v = y2 ; v > u ; v -- ) {
                buc[ v ] += buc[ v + 1 ];
                if( U[ md ][ u ] >= U[ md ][ v ] ) f[ u ][ v ] = buc[ v ];
            }
            for( int i = y1 ; i <= y2 ; i ++ ) buc[ i ] = 0;
        }
        // Upper half, case U[md][v] > U[md][u]: bucket L[i][v], prefix sums over u
        for( int v = y2 ; v >= y1 ; v -- ) {
            for( int i = max( U[ md ][ v ] , x1 ) ; i <= md ; i ++ ) buc[ max( L[ i ][ v ] , y1 ) ] ++;
            for( int u = y1 ; u < v ; u ++ ) {
                buc[ u ] += buc[ u - 1 ];
                if( U[ md ][ v ] > U[ md ][ u ] ) f[ u ][ v ] = buc[ u ];
            }
            for( int i = y1 ; i <= y2 ; i ++ ) buc[ i ] = 0;
        }
        // Lower half: the same two cases, using D instead of U
        for( int u = y1 ; u <= y2 ; u ++ ) {
            for( int i = md ; i <= min( D[ md ][ u ] , x2 ) ; i ++ ) buc[ min( R[ i ][ u ] , y2 ) ] ++;
            for( int v = y2 ; v > u ; v -- ) {
                buc[ v ] += buc[ v + 1 ];
                if( D[ md ][ u ] <= D[ md ][ v ] ) g[ u ][ v ] = buc[ v ];
            }
            for( int i = y1 ; i <= y2 ; i ++ ) buc[ i ] = 0;
        }
        for( int v = y2 ; v >= y1 ; v -- ) {
            for( int i = md ; i <= min( D[ md ][ v ] , x2 ) ; i ++ ) buc[ max( L[ i ][ v ] , y1 ) ] ++;
            for( int u = y1 ; u < v ; u ++ ) {
                buc[ u ] += buc[ u - 1 ];
                if( D[ md ][ v ] < D[ md ][ u ] ) g[ u ][ v ] = buc[ u ];
            }
            for( int i = y1 ; i <= y2 ; i ++ ) buc[ i ] = 0;
        }
        // Combine, minus the height-1 choice (md, md)
        for( int u = y1 ; u <= y2 ; u ++ )
            for( int v = u + 1 ; v <= y2 ; v ++ )
                if( f[ u ][ v ] && g[ u ][ v ] ) ans += 1ll * f[ u ][ v ] * g[ u ][ v ] - ( R[ md ][ u ] >= v );
        Solve( x1 , y1 , md - 1 , y2 ); Solve( md + 1 , y1 , x2 , y2 );
    }
}
int main( ) {
    // freopen("circle.in","r",stdin);
    // freopen("circle.out","w",stdout);
    scanf("%d %d",&n,&m);
    for( int i = 1 ; i <= n ; i ++ )
        for( int j = 1 ; j <= m ; j ++ ) {
            char c = getchar(); while( c < 'a' || c > 'z' ) c = getchar(); // skip whitespace
            s[ i ][ j ] = c;
        }
    // Row/column 0 and n+1/m+1 hold 0, which never equals a letter, so runs stop at the border
    for( int i = 1 ; i <= n ; i ++ )
        for( int j = 1 ; j <= m ; j ++ )
            U[ i ][ j ] = s[ i - 1 ][ j ] == s[ i ][ j ] ? U[ i - 1 ][ j ] : i;
    for( int i = n ; i >= 1 ; i -- )
        for( int j = 1 ; j <= m ; j ++ )
            D[ i ][ j ] = s[ i + 1 ][ j ] == s[ i ][ j ] ? D[ i + 1 ][ j ] : i;
    for( int j = 1 ; j <= m ; j ++ )
        for( int i = 1 ; i <= n ; i ++ )
            L[ i ][ j ] = s[ i ][ j - 1 ] == s[ i ][ j ] ? L[ i ][ j - 1 ] : j;
    for( int j = m ; j >= 1 ; j -- )
        for( int i = 1 ; i <= n ; i ++ )
            R[ i ][ j ] = s[ i ][ j + 1 ] == s[ i ][ j ] ? R[ i ][ j + 1 ] : j;
    Solve( 1 , 1 , n , m );
    printf("%lld\n", ans );
    return 0;
}
```
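Given how easy it is to get a boundary wrong, a brute-force counter for small random grids is handy for cross-checking. This is a hedged sketch that assumes, based on what the solution counts (the `u < v` loops and the subtracted \((md, md)\) term), that a circle is a rectangle with at least two rows and two columns whose whole border carries one letter; `bruteCircles` is a hypothetical name, the grid is 0-indexed, and it runs in \(\mathcal O(n^2 m^2 (n+m))\).

```cpp
#include <cassert>
#include <string>
#include <vector>

// Brute-force reference: enumerate every rectangle [u, v] x [a, b] with u < v and a < b
// and check that all cells on its border carry the same letter.
long long bruteCircles(const std::vector<std::string>& g) {
    assert(!g.empty() && !g[0].empty());
    int n = g.size(), m = g[0].size();
    long long cnt = 0;
    for (int u = 0; u < n; u++)
        for (int v = u + 1; v < n; v++)
            for (int a = 0; a < m; a++)
                for (int b = a + 1; b < m; b++) {
                    char c = g[u][a];
                    bool ok = true;
                    for (int j = a; j <= b && ok; j++)  // top and bottom edges
                        ok = g[u][j] == c && g[v][j] == c;
                    for (int i = u; i <= v && ok; i++)  // left and right edges
                        ok = g[i][a] == c && g[i][b] == c;
                    if (ok) cnt++;
                }
    return cnt;
}
```

For example, on `aaa`/`aba`/`aaa` only the outer 3×3 frame qualifies, so it returns 1, while an all-`a` 3×3 grid gives 9 (3 row pairs times 3 column pairs).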