ZROI#987
差分+简单数学即可.
首先有个性质:
两条链相交等价于其中一条链的\(LCA\)在另一条链上.
于是我们就对每一条链的\(LCA\)都加\(1\).
最后对每一条链查询链上所有点的权值和即可(树链剖分+线段树实现).
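但这样会算重.首先,每条链一定会统计到它自己的\(LCA\),所以每条链的结果要减去\(1\).其次,若两条相交的链\(a,b\)的\(LCA\)不是同一个点,则只有较深的那个\(LCA\)落在另一条链上,这一对只会被统计一次;只有当它们的\(LCA\)是同一个点时,\((a,b)\)和\((b,a)\)才会各被统计一次,即算了两次.
所以不能直接把总数减掉一半:设以点\(x\)为\(LCA\)的链有\(c_x\)条,最后再减去\(\sum_x \binom{c_x}{2}\)即可.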
\(Code:\)
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>
#include <queue>
#include <cmath>
#include <ctime>
#include <map>
#include <set>
#define MEM(x,y) memset ( x , y , sizeof ( x ) )
#define rep(i,a,b) for (int i = (a) ; i <= (b) ; ++ i)
#define per(i,a,b) for (int i = (a) ; i >= (b) ; -- i)
#define pii pair < int , int >
#define X first
#define Y second
#define rint read<int>
#define int long long
#define pb push_back
#define ls ( rt << 1 )
#define rs ( rt << 1 | 1 )
#define mid ( ( l + r ) >> 1 )
using std::queue ;
using std::set ;
using std::pair ;
using std::max ;
using std::min ;
using std::priority_queue ;
using std::vector ;
using std::swap ;
using std::sort ;
using std::unique ;
using std::greater ;
// Read one integer of type T from stdin.
// Characters before the first digit are skipped; any '-' among them makes the
// result negative. Reading stops at the first non-digit after the number.
// Returns 0 if end of input is reached before any digit (instead of spinning
// forever on EOF).
template < class T >
inline T read () {
    T x = 0 , f = 1 ;
    // Must be an int (not char): getchar() returns EOF, which a char cannot
    // hold as a value distinct from every real character.
    int ch = getchar () ;
    while ( ch < '0' || ch > '9' ) {
        if ( ch == EOF ) return 0 ;
        if ( ch == '-' ) f = - 1 ;
        ch = getchar () ;
    }
    while ( ch >= '0' && ch <= '9' ) {
        x = x * 10 + ( ch - '0' ) ;
        ch = getchar () ;
    }
    return f * x ;
}
// Size bound for every per-vertex / per-path array (presumably the maximum n
// and m plus some slack -- confirm against the problem limits).
const int N = 1e6 + 100 ;
// G: adjacency lists of the input tree (filled with both directions in main).
vector < int > G[N] ;
// f: parent of each vertex (0 for the root); deep: depth (root has depth 1);
// ans: the answer being accumulated in main; idx: position of a vertex in the
// heavy-light DFS order; cnt: last position handed out by _dfs.
int f[N] , deep[N] , ans , idx[N] , cnt ;
// n: number of vertices; m: number of paths; p[i][0], p[i][1]: endpoints of
// path i; siz: subtree size; son: heavy child (stays 0 for a leaf);
// top: head vertex of the heavy chain containing the vertex.
int n , m , p[N][2] , siz[N] , son[N] , top[N] ;
// One segment-tree node covering positions [left, right] of the heavy-light
// order; `t` holds the whole tree with node rt's children at rt*2 and rt*2+1.
struct seg {
    int left ;   // first position covered by this node
    int right ;  // last position covered by this node
    int data ;   // sum of the values stored on [left, right]
    int tag ;    // pending add that has not yet been passed to the children
    // Number of positions covered by this node.
    int size () { return 1 + right - left ; }
} t[N<<2] ;
// Root the tree at `cur` (whose parent is `anc` and depth is `dep`) and fill,
// for every vertex reached: f (parent), deep (depth), siz (subtree size) and
// son (heavy child: the first child in adjacency order among those with the
// largest subtree; left untouched for leaves).
// Works with explicit stacks instead of recursion, so a very deep tree (e.g. a
// path of ~N vertices) cannot overflow the call stack.
inline void dfs (int cur , int anc , int dep) {
    f[cur] = anc ; deep[cur] = dep ;
    // Pass 1: preorder walk; every vertex is listed before all its descendants.
    std::vector < int > order , stk ;
    stk.push_back ( cur ) ;
    while ( ! stk.empty () ) {
        int u = stk.back () ; stk.pop_back () ;
        order.push_back ( u ) ;
        for (int k : G[u]) {
            if ( k == f[u] ) continue ;
            f[k] = u ; deep[k] = deep[u] + 1 ;
            stk.push_back ( k ) ;
        }
    }
    // Pass 2: reversed preorder finishes every child before its parent, so
    // children's subtree sizes are final when the parent sums them.
    for (auto it = order.rbegin () ; it != order.rend () ; ++ it) {
        int u = * it , maxson = - 1 ;
        siz[u] = 1 ;
        for (int k : G[u]) {
            if ( k == f[u] ) continue ;
            siz[u] += siz[k] ;
            // Strict '>' keeps the earliest such child on ties.
            if ( siz[k] > maxson ) maxson = siz[k] , son[u] = k ;
        }
    }
    return ;
}
// Heavy-light numbering of the subtree of `cur`, whose chain head is `topf`.
// Assigns idx (consecutive positions via ++cnt) and top (chain head) in the
// same preorder as a recursive walk that visits the heavy child first and then
// the light children in adjacency order. Uses an explicit stack so a very deep
// tree cannot overflow the call stack. Relies on f and son filled by dfs.
inline void _dfs (int cur , int topf) {
    std::vector < std::pair < int , int > > stk ;  // (vertex, head of its chain)
    stk.push_back ( std::make_pair ( cur , topf ) ) ;
    while ( ! stk.empty () ) {
        int u = stk.back ().first , head = stk.back ().second ;
        stk.pop_back () ;
        top[u] = head ; idx[u] = ++ cnt ;
        if ( ! son[u] ) continue ;  // no heavy child means no children at all
        // Light children are pushed in reverse so they pop in adjacency order,
        // each starting a new chain. The heavy child is pushed last so it pops
        // next and its whole subtree is numbered right after u, keeping every
        // heavy chain on contiguous positions.
        for (auto it = G[u].rbegin () ; it != G[u].rend () ; ++ it) {
            if ( * it == son[u] || * it == f[u] ) continue ;
            stk.push_back ( std::make_pair ( * it , * it ) ) ;
        }
        stk.push_back ( std::make_pair ( son[u] , head ) ) ;
    }
    return ;
}
// Recompute node rt's sum as the total of its two children's sums.
inline void pushup (int rt) {
    const int left_sum = t[ls].data ;
    const int right_sum = t[rs].data ;
    t[rt].data = left_sum + right_sum ;
}
// Initialise node rt to cover [l, r] with every sum zero and no pending add,
// recursing into both halves for internal nodes.
inline void build (int rt , int l , int r) {
    auto & node = t[rt] ;
    node.left = l ;
    node.right = r ;
    node.tag = 0 ;
    if ( l != r ) {
        build ( ls , l , mid ) ;
        build ( rs , mid + 1 , r ) ;
        pushup ( rt ) ;
    } else {
        node.data = 0 ;
    }
}
// Hand node rt's pending add to both children (their own pending adds and their
// sums, scaled by the number of positions each covers), then clear it.
inline void pushdown (int rt) {
    const int add = t[rt].tag ;
    t[rt].tag = 0 ;
    for (int child = ls ; child <= rs ; ++ child) {
        t[child].tag += add ;
        t[child].data += t[child].size () * add ;
    }
}
// Add `val` to every position in [ll, rr].
// [ll, rr] must lie inside node rt's range; callers start at the root (rt = 1).
inline void update (int rt , int ll , int rr , int val) {
    int l = t[rt].left , r = t[rt].right ;
    if ( l == ll && r == rr ) {
        // Fully covered: record the add for the children and grow this node's
        // sum once per covered position (consistent with pushdown), not once.
        t[rt].tag += val ;
        t[rt].data += val * t[rt].size () ;
        return ;
    }
    if ( t[rt].tag ) pushdown ( rt ) ;
    if ( rr <= mid ) update ( ls , ll , rr , val ) ;
    else if ( ll > mid ) update ( rs , ll , rr , val ) ;
    else { update ( ls , ll , mid , val ) ; update ( rs , mid + 1 , rr , val ) ; }
    pushup ( rt ) ; return ;
}
// Sum of the values on positions [ll, rr]; [ll, rr] must lie inside node rt's
// range. Pending adds are pushed down along the way.
inline int query (int rt , int ll , int rr) {
    const int l = t[rt].left , r = t[rt].right ;
    if ( ll == l && rr == r ) return t[rt].data ;
    if ( t[rt].tag != 0 ) pushdown ( rt ) ;
    if ( ll > mid ) return query ( rs , ll , rr ) ;
    if ( rr <= mid ) return query ( ls , ll , rr ) ;
    return query ( ls , ll , mid ) + query ( rs , mid + 1 , rr ) ;
}
// Sum of the stored values over every vertex on the tree path x -- y, using
// the heavy-light decomposition (each chain occupies contiguous positions).
inline int qrange (int x , int y) {
    int sum = 0 ;
    for ( ; top[x] != top[y] ; ) {
        // Always lift the endpoint whose chain head is deeper.
        if ( deep[top[y]] > deep[top[x]] ) swap ( x , y ) ;
        sum += query ( 1 , idx[top[x]] , idx[x] ) ;
        x = f[top[x]] ;
    }
    // Now both are on one chain: add the segment between them.
    const int lo = deep[x] <= deep[y] ? x : y ;
    const int hi = deep[x] <= deep[y] ? y : x ;
    return sum + query ( 1 , idx[lo] , idx[hi] ) ;
}
// Lowest common ancestor of x and y, found by jumping along heavy chains.
inline int LCA (int x , int y) {
    for ( ; top[x] != top[y] ; ) {
        if ( deep[top[x]] >= deep[top[y]] ) x = f[top[x]] ;
        else y = f[top[y]] ;
    }
    // Same chain: the shallower vertex is the ancestor.
    return deep[y] <= deep[x] ? y : x ;
}
signed main (int argc , char * argv[]) {
n = rint () ; m = rint () ;
rep ( i , 2 , n ) {
int u = rint () , v = rint () ;
G[u].pb ( v ) ; G[v].pb ( u ) ;
}
dfs ( 1 , 0 , 1 ) ; _dfs ( 1 , 1 ) ; build ( 1 , 1 , cnt ) ;
rep ( i , 1 , m ) {
p[i][0] = rint () ; p[i][1] = rint () ;
int t = LCA ( p[i][0] , p[i][1] ) ;
update ( 1 , idx[t] , idx[t] , 1 ) ;
}
rep ( i , 1 , m ) ans += ( qrange ( p[i][0] , p[i][1] ) - 1 ) ;
rep ( i , 1 , n ) {
int tmp = query ( 1 , idx[i] , idx[i] ) ;
ans -= tmp * ( tmp - 1 ) / 2 ;
}
printf ("%lld\n" , ans ) ;
return 0 ;
}
May you return with a young heart after years of fighting.