CF1868C Travel Plan
题目大意
给定一颗 \(n\) 个节点的完全二叉树,每个点有权值 \(a_i \in [1,m]\),定义从 \(i\) 到 \(j\) 的路径的权值 \(s_{i,j}\) 为路径上的最大点权。
求所有树(\(m^n\) 种点权)的 \(\sum_{i=1}^n \sum_{j=i}^n s_{i,j}\) 的和,模 \(998244353\)。
\[n\leq 10^{18},m\leq 10^5
\]
题解
我们可以分别计算每一条长度为\(l\)的路径的贡献
\(\max\leq i\)的方案数为\(i^l\),则\(\max=i\)的方案数为\(i^l-(i-1)^l\),乘上\(i\),非路径上的点随意的方案数\(m^{(n-l)}\)即为长度为\(l\)的路径的贡献,这部分可以用\(O(m\log n)\)的时间复杂度处理出来
然后就是计算长度为\(l\)的路径有多少条,可以枚举二叉树的高度\(d\),计算完全二叉树最底下高度为\(d\)的子树,只计算经过子树顶点的路径,枚举左子树中路径长度和右子树路径长度,左右两边的子树都是满二叉树,只有中间的一个子树可能不是满二叉树,单独计算即可,这部分时间复杂度为\(O(\log^3 n)\)
总时间复杂度为\(O(\sum m\log n+\log^3n)\)
也可以\(dp\),设\(f_k\)为大小为\(k\)的子树的贡献,观察一下发现只会有\(O(\log n)\)级别的状态,也是可以做的
\(\text{code}\)
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
const ll mod=998244353;
const int L=120;
ll n;int m;
ll ksm(ll a,ll b)
{
if(b==0) return 1;
ll tmp=ksm(a,b>>1);
if(b&1) return tmp*tmp%mod*a%mod;
else return tmp*tmp%mod;
}
ll val[L+10],g[L+10];
void Add(ll &a,ll b){a+=b;if(a>=mod) a-=mod;}
ll f(int i){return i==0?1ll:(1ll<<i-1);}
ll calc(int d)
{
ll res=0;
for(int i=0;i<=d;i++)
for(int j=0;j<=d;j++)
{
ll k1=f(i)%mod,k2=f(j)%mod;
Add(res,k1*k2%mod*val[i+j+1]%mod);
}
return res;
}
ll calc1(int d,ll n)
{
// printf("calc1 %d %lld\n",d,n);
ll res=0;
for(int i=0;i<=d;i++)
for(int j=0;j<=d;j++)
{
ll k1=f(i),k2=f(j);
if(i==d) k1=min(k1,n);
if(j==d) k2=min(k2,max(n-k2,0ll));
k1%=mod,k2%=mod;
Add(res,val[i+j+1]*k1%mod*k2%mod);
}
return res;
}
int main()
{
// freopen("e.in","r",stdin);
int T;
scanf("%d",&T);
for(;T--;)
{
scanf("%lld %d",&n,&m);
for(int i=1;i<=min(n,1ll*L);i++) val[i]=0;
for(int k=1;k<=m;k++)
{
ll k1=1,k2=1;
for(int i=1;i<=min(n,1ll*L);i++)
{
k1=k1*k%mod,k2=k2*(k-1)%mod;
Add(val[i],(k1-k2+mod)%mod*k%mod);
}
}
for(int i=1;i<=min(n,1ll*L);i++) val[i]=val[i]*ksm(m,n-i)%mod;
int d=0;
for(d=0;;d++)
{
if(n>=(1ll<<d)) n-=(1ll<<d);
else break;
}
if(n==0) d--,n=(1ll<<d);
// printf("fuc %d %lld\n",d,n);
// printf("qnm %lld\n",calc1(1,1));
ll full=(1ll<<d);
ll ans=0;
for(int i=0;i<=d;i++)
{
ll tmp=(1ll<<i);
ll k1=n/tmp,k2=(full-n)/tmp,k3=n%tmp;
Add(ans,k1%mod*calc(i)%mod);
if(i>0) Add(ans,k2%mod*calc(i-1)%mod);
if(k3!=0) Add(ans,calc1(i,k3));
}
printf("%lld\n",ans);
}
return 0;
}