「Gym102979E」Expected Distance
题目
点这里看题目。
分析
对答案变形:
因此我们需要关心的是最终路径上非 \(u\) 且非 \(v\) 的点的点权和。
假设 \(u<v\),我们先画一条路径来分析一下:
考虑路径出现的概率,对于任何一条边 \((p,q),p<q\),设 \(s_n=\sum_{i=1}^na_i\),则该边出现的概率为 \(\frac{a_p}{s_{q-1}}\)。每条边的概率之积则为该路径出现的概率。而我们可以重新排列,使得各个点贡献独立。比如上面这条路径,\(u,v\) 分别贡献 \(s^{-1}_{u-1},s_{v-1}^{-1}\),\(lca\) 会贡献 \(a^2_{lca}\),而路径上的点 \(w\in \{a,b,c,d,e\}\) 则会贡献 \(\frac{a_w}{s_{w-1}}\)。
所以,当我们确定了 \(u,v,lca\) 之后,剩余的点可出现可不出现,系数易于计算。因此自然的思路是尝试枚举 \(P\setminus\{u,v,lca\}\),反向构造出路径来。
在此之前,我们需要明确当路径以 \(lca\) 为根的时候,路径在标号上满足堆性质。因此假如我们知道了 \(u\) 到 \(lca\) 上的点和 \(v\) 到 \(lca\) 上的点,则两条链上点的顺序是确定的。那么对于 \(w\in P\setminus \{u,v,lca\}\),如果有 \(w<u\),则在反向构建的时候,\(w\) 既可以被放入 \(u\) 链,也可以被放入 \(v\) 链,贡献 2 的系数;而如果 \(u<w<v\),则 \(w\) 不能被放入 \(u\) 链,只能被放入 \(v\) 链,贡献 1 的系数。
此外,注意到概率是乘法的关系,而最终的点权需要求和,自然的想法便是将点权放到指数上,使用概率生成函数。
因此,对于路径 \((u,v)\),我们可以构建 \(F_k(x)\) 为 \(lca=k\) 时该路径的点权的概率生成函数,根据上面的分析不难得出:
根据结论,计算出 \(\sum_{1\le k\le u}F_k'(1)\) 就是答案,使用若干个前缀数组记录一下即可。
计算量确实比较大,但是带入值的时候,由于 \(c_j\ge 1\),细节不多。
小结:
-
注意对问题下手的角度,这里是对链本身考虑,而另一种更简单的方法是将贡献拆分成 \(dep_{u}+dep_{v}-2dep_{lca}\) 的类似的形式,于是就转化到了祖先与子孙上。
虽然这个做法比较复杂,但是我觉得挺漂亮的; -
注意细节,算错好多次导致浪费了不少时间,比如最开始的时候就没有意识到 2 的方案系数。这种题可以通过手玩关注到不少细节。
另一方面,这其实也说明计数的思路不太完整,忽略了路径的点在分配时候带来的方案系数。
这个其实随便画一画 \(n=4\) 都可以发现好嘛。
代码
#include <cstdio>
#include <iostream>
#define rep( i, a, b ) for( int i = (a) ; i <= (b) ; i ++ )
#define per( i, a, b ) for( int i = (a) ; i >= (b) ; i -- )
const int mod = 1e9 + 7;
const int MAXN = 3e5 + 5;
template<typename _T>
void read( _T &x )
{
x = 0; char s = getchar(); int f = 1;
while( ! ( '0' <= s && s <= '9' ) ) { f = 1; if( s == '-' ) f = -1; s = getchar(); }
while( '0' <= s && s <= '9' ) { x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar(); }
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ) putchar( '-' ), x = -x;
if( 9 < x ) write( x / 10 );
putchar( x % 10 + '0' );
}
int P[MAXN], Pder[MAXN];
int H[MAXN], Hder[MAXN];
int T[MAXN], TDer[MAXN];
int R[MAXN], RDer[MAXN];
int pre[MAXN], preDer1[MAXN], preDer2[MAXN];
int a[MAXN], s[MAXN], c[MAXN], invS[MAXN];
int N, Q;
inline int Qkpow( int, int );
inline int Mul( int x, int v ) { return 1ll * x * v % mod; }
inline int Inv( const int a ) { return Qkpow( a, mod - 2 ); }
inline int Sub( int x, int v ) { return ( x -= v ) < 0 ? x + mod : x; }
inline int Add( int x, int v ) { return ( x += v ) >= mod ? x - mod : x; }
inline int Sqr( const int x ) { return Mul( x, x ); }
inline int Qkpow( int base, int indx )
{
int ret = 1;
while( indx )
{
if( indx & 1 ) ret = Mul( ret, base );
base = Mul( base, base ), indx >>= 1;
}
return ret;
}
int main()
{
read( N ), read( Q );
rep( i, 1, N - 1 ) read( a[i] );
rep( i, 1, N ) read( c[i] );
rep( i, 1, N ) invS[i] = Inv( s[i] = Add( s[i - 1], a[i] ) );
R[0] = H[0] = 1;
rep( i, 1, N )
{
T[i] = Add( 1, Mul( a[i], invS[i - 1] ) );
TDer[i] = Mul( Mul( a[i], c[i] ), invS[i - 1] );
R[i] = Mul( R[i - 1], T[i] );
RDer[i] = Add( Mul( RDer[i - 1], T[i] ), Mul( R[i - 1], TDer[i] ) );
P[i] = Add( 1, Mul( Mul( 2, a[i] ), invS[i - 1] ) );
Pder[i] = Mul( Mul( Mul( 2, a[i] ), c[i] ), invS[i - 1] );
H[i] = Mul( H[i - 1], P[i] );
Hder[i] = Add( Mul( Hder[i - 1], P[i] ), Mul( H[i - 1], Pder[i] ) );
int inv = Inv( H[i] );
pre[i] = Add( pre[i - 1], Mul( inv, Sqr( a[i] ) ) );
preDer1[i] = Add( preDer1[i - 1], Mul( Sqr( inv ), Mul( Hder[i], Sqr( a[i] ) ) ) );
preDer2[i] = Add( preDer2[i - 1], Mul( inv, Mul( Sqr( a[i] ), c[i] ) ) );
}
while( Q -- )
{
int u, v;
read( u ), read( v );
if( u > v ) std :: swap( u, v );
if( u == v ) { puts( "0" ); continue; }
int ans = 0, res = 0, G;
ans = Mul( Mul( a[u], invS[v - 1] ), Mul( Sqr( Inv( R[u] ) ),
Sub( Mul( RDer[v - 1], R[u] ), Mul( R[v - 1], RDer[u] ) ) ) );
G = Mul( Sqr( Inv( R[u] ) ), Sub( Add( Mul( Mul( Hder[u - 1], R[v - 1] ), R[u] ),
Mul( Mul( H[u - 1], RDer[v - 1] ), R[u] ) ),
Mul( Mul( H[u - 1], R[v - 1] ), RDer[u] ) ) );
res = Add( res, Mul( G, pre[u - 1] ) );
G = Mul( Mul( H[u - 1], R[v - 1] ), Inv( R[u] ) );
res = Add( res, Mul( G, Sub( preDer2[u - 1], preDer1[u - 1] ) ) );
ans = Add( ans, Mul( res, Mul( invS[u - 1], invS[v - 1] ) ) );
write( Add( Mul( ans, 2 ), Add( c[u], c[v] ) ) ), putchar( '\n' );
}
return 0;
}