18.10.18队测
T1 : 题意简述 : 三维偏序问题,每一维都是一个排列,求满足\((a_i<a_j,b_i<b_j,c_i<c_j)\)的有序二元组个数。
其中:\(n\leq2e6\).
这题的题目名字就叫cdq。。。显然我们不能cdq,时间复杂度不允许。
由于是排列,没有重复的数,所以可以容斥。先按\((a,b),(a,c),(b,c)\)做二维偏序,设答案为\(x,y,z\),那么不合法的肯定只会在任意一种二维偏序中出现一次,所以总答案为
\[ans=\frac{1}{2}(x+y+z-\binom{n}{2})
\]
#include<bits/stdc++.h>
#pragma GCC optimize(3)
#define ll long long
using namespace std;
void read(int &x){
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-f;
for(;isdigit(ch);ch=getchar())x=(x<<1)+(x<<3)+ch-'0';x*=f;
}
#define maxn 2000050
#define write(x) printf("%lld\n",x)
int n,k,tot;
struct data{int x,y,z;}s[maxn],p[maxn];
struct binary_index_tree{
int tree[maxn];void clear() {memset(tree,0,sizeof tree);}
void modify(int x,int v){for(int i=x;i<=n;i+=i&-i)tree[i]+=v;}
int query(int x,int ans=0){for(int i=x;i;i-=i&-i)ans+=tree[i];return ans;}
}T;
namespace make {
const int N = 2e6+5;
unsigned int SA,SB,SC;int a[N],b[N],c[N];
unsigned int rd(){
SA^=SA<<16;SA^=SA>>5;SA^=SA<<1;
unsigned int t=SA;SA=SB;SB=SC;SC^=t^SA;return SC;
}
void gen(int *P){
for (int i=1;i<=n;++i) P[i]=i;
for (int i=1;i<=n;++i) swap(P[i],P[1+rd()%n]);
}
void get(){
scanf("%d%u%u%u",&n,&SA,&SB,&SC);
gen(a);gen(b);gen(c);
for(int i=1;i<=n;i++) s[i].x=a[i],s[i].y=b[i],s[i].z=c[i];
}
}
int cmp_x(data x,data y) {return x.x<y.x;}
int cmp_y(data x,data y) {return x.y<y.y;}
int main(){
make :: get();sort(s+1,s+n+1,cmp_x);long long ans=0;
for(int i=1;i<=n;i++) ans+=T.query(s[i].y),T.modify(s[i].y,1);T.clear();
for(int i=1;i<=n;i++) ans+=T.query(s[i].z),T.modify(s[i].z,1);T.clear();
sort(s+1,s+n+1,cmp_y);for(int i=1;i<=n;i++) ans+=T.query(s[i].z),T.modify(s[i].z,1);
write((ans-1ll*n*(n-1)/2)/2);
return 0;
}
T2过于毒瘤,不会。。。。。
T3 : 题意简述 : 给出一颗树,有\(q\)个询问,每次询问给出一条链\((u,v)\),要找出\(k\)个点对,使得两两点对的公共路径为\((u,v)\),对于每次询问会要求可否选重复的点,求方案数。
设\(f[u][k]\)为\(u\)的子树中选\(k\)个点,且每个\(u\)的儿子的子树中最多只能有一个点的方案数,考虑已经算完了\(u\)的\(1\)~ \(~n-1\) 的儿子对它的贡献,现在要新增一个点,显然可以得到转移\(f[u][k]+=f[u][k-1]*sz[e[i].to]\).要逆序转移。
对于\(u,v\)没有祖先关系,直接拿\(f\)统计就好了。
否则,设\(u\)为\(v\)的祖先,我们就以\(v\)的方向为根,对与\(u\)类似与\(f\)反向处理出新的\(f\)统计答案,具体的细节看代码吧。。
#include<bits/stdc++.h>
#pragma GCC optimize(3)
#define int long long
using namespace std;
void read(int &x){
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-f;
for(;isdigit(ch);ch=getchar())x=(x<<1)+(x<<3)+ch-'0';x*=f;
}
#define write(x) printf("%lld\n",x)
#define mod 998244353
#define maxn 100050
int n,k,tot,q,head[maxn],sz[maxn],f[maxn][501],dfn[maxn],dfn_cnt,fac[maxn],ifac[maxn],g[501],Fa[maxn];
struct edge{int to,nxt;}e[maxn<<1];
int qpow(int a,int x,int res=1) {for(;x;x>>=1,a=a*a%mod) if(x&1) res=res*a%mod;return res;}
void ins(int u,int v) {e[++tot].to=v,e[tot].nxt=head[u],head[u]=tot;}
void dfs(int x,int fa) {
dfn[x]=++dfn_cnt;sz[x]=1;f[x][0]=1;int top=0;Fa[x]=fa;
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=fa) {
dfs(e[i].to,x);sz[x]+=sz[e[i].to];
for(int j=top;~j;j--) (f[x][j+1]+=sz[e[i].to]*f[x][j])%=mod;top++;
}
}
void solve(int u,int v,int k) {
int y=0;memcpy(g,f[u],sizeof g);
for(int i=head[u];i;i=e[i].nxt) if(e[i].to!=Fa[u]) if(dfn[e[i].to]+sz[e[i].to]-1>=dfn[v]&&dfn[e[i].to]<=dfn[v]) {y=e[i].to;break;}//write(y);
for(int i=1;i<=k;i++) g[i]=((g[i]-g[i-1]*sz[y])%mod+mod)%mod;
for(int i=k;i;i--) g[i]=(g[i]+g[i-1]*(n-sz[u]))%mod;
}
signed main() {
read(n),read(q);int x,y;
for(int i=1;i<n;i++) read(x),read(y),ins(x,y),ins(y,x);
fac[0]=ifac[0]=1;for(int i=1;i<=500;i++) fac[i]=fac[i-1]*i%mod;
ifac[500]=qpow(fac[500],mod-2);for(int i=500-1;i;i--) ifac[i]=ifac[i+1]*(i+1)%mod;
dfs(1,0);
while(q--) {
int k,op,u,v;read(u),read(v),read(k),read(op);
if(dfn[u]>dfn[v]) swap(u,v);
if(op==1) {
if(dfn[v]>=dfn[u]+sz[u]) write((f[u][k]+f[u][k-1])*(f[v][k]+f[v][k-1])%mod*fac[k]%mod*fac[k]%mod);//,puts("A");
else {
int ans=(f[v][k]+f[v][k-1])*fac[k]%mod;
solve(u,v,k);ans=ans*(g[k]+g[k-1])%mod*fac[k]%mod;write(ans);//puts("B");
}
} else {
if(dfn[v]>=dfn[u]+sz[u]) {
int ans0=0,ans1=0;
for(int i=0;i<=k;i++) ans0=(ans0+f[u][i]*fac[k]%mod*ifac[k-i]%mod)%mod;
for(int i=0;i<=k;i++) ans1=(ans1+f[v][i]*fac[k]%mod*ifac[k-i]%mod)%mod;
write(ans0*ans1%mod);//puts("C");
} else {
int ans0=0,ans1=0;
for(int i=0;i<=k;i++) ans1=(ans1+f[v][i]*fac[k]%mod*ifac[k-i]%mod)%mod;
solve(u,v,k);for(int i=0;i<=k;i++) ans0=(ans0+g[i]*fac[k]%mod*ifac[k-i]%mod)%mod;
write(ans0*ans1%mod);//puts("D");
}
}
}
return 0;
}