2018.10.18--多校联测第三场测试总结
T1:容斥+二维偏序,考场上用cdq爆搞出36分。
T2:据说是容斥裸题,然而我至今都没写出来。考场上puts("0")拿了5分。
T3:树上背包优化,估计60分,实际60分。
今天三道题都没怎么见过类似的题…………
T1:
我们令\(x\)为\(a_i<a_j,b_i<b_j\)的二元组\((i,j)\)的数量,\(y\)为\(a_i<a_j,c_i<c_j\)的二元组\((i,j)\)的数量,\(z\)为\((b_i<b_j,c_i<c_j)\)的二元组的数量,那么对于每一个合法三元组会在x,y,z中被算到三次,不合法的会被算到一次。所以答案就是\((x+y+z+C_n^2)/2\)。
代码如下:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
#define low(i) ((i)&(-i))
const int N = 2e6+5;
ll ans;
int n,a[N],b[N],c[N];
unsigned int SA,SB,SC;
unsigned int rd(){
SA^=SA<<16;SA^=SA>>5;SA^=SA<<1;
unsigned int t=SA;SA=SB;SB=SC;SC^=t^SA;return SC;
}
void gen(int *P){
for (int i=1;i<=n;++i) P[i]=i;
for (int i=1;i<=n;++i) swap(P[i],P[1+rd()%n]);
}
struct fake {
int a,b;
bool operator<(const fake &gey)const {
if(a==gey.a)return b<gey.b;
return a<gey.a;
}
}p[N];
struct Tree_array {
int c[N];
int query(int pos) {
int res=0;
for(int i=pos;i;i-=low(i))
res+=c[i];
return res;
}
void add(int pos) {
for(int i=pos;i<=n;i+=low(i))
c[i]++;
}
}T[3];
ll calc(int id) {
ll res=0;sort(p+1,p+n+1);
for(int i=1;i<=n;i++) {
res+=T[id].query(p[i].b);
T[id].add(p[i].b);
}
return res;
}
int main(){
freopen("cdq.in","r",stdin);
freopen("cdq.out","w",stdout);
scanf("%d%u%u%u",&n,&SA,&SB,&SC);
gen(a);gen(b);gen(c);
for(int i=1;i<=n;i++)
p[i].a=a[i],p[i].b=b[i];
ans+=calc(0);
for(int i=1;i<=n;i++)
p[i].a=a[i],p[i].b=c[i];
ans+=calc(1);
for(int i=1;i<=n;i++)
p[i].a=b[i],p[i].b=c[i];
ans+=calc(2);
printf("%lld\n",(ans-1ll*n*(n-1)/2)/2);
return 0;
}
T2:不会……贴个题解就走……毒瘤瘤……
题解:https://files.cnblogs.com/files/zhoushuyu/sol.pdf
T3:对于每个询问,我们在u的子树里和v的子树(根看情况改)里选k个点乘起来就行了。用背包做一下就行了,然后背包消除影响可以暴力\(O(d)\)的做。
代码如下:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=100005,pps=998244353;
int n,q,tot,tim;
int fac[maxn],inv[maxn];
int tmp[505],bag[maxn][505],deg[maxn];
int now[maxn],pre[maxn*2],son[maxn*2];
int dep[maxn],dfn[maxn],pos[maxn],siz[maxn];
int read() {
int x=0,f=1;char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
return x*f;
}
void add(int a,int b) {
pre[++tot]=now[a];
now[a]=tot;son[tot]=b;
}
int quick(int a,int b) {
int sum=1;
while(b) {
if(b&1)sum=1ll*sum*a%pps;
a=1ll*a*a%pps;b>>=1;
}
return sum;
}
void dfs(int fa,int u) {
dep[u]=dep[fa]+1;dfn[++tim]=u;
bag[u][0]=1;siz[u]=1;pos[u]=tim;
for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
if(v!=fa) {
dfs(u,v),siz[u]+=siz[v];
deg[u]++;
for(int i=deg[u];i;i--)
bag[u][i]=(bag[u][i]+1ll*bag[u][i-1]*siz[v]%pps)%pps;
}
}
int find(int u,int fake) {
for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
if(pos[v]>pos[u]&&pos[v]+siz[v]-1>=pos[fake])
return v;
return 0;
}
int solve(int a,int b,int sum,int opt) {
int oxy=0;
if(pos[a]<pos[b]&&pos[a]+siz[a]-1>=pos[b])oxy=find(a,b);
if(pos[b]<pos[a]&&pos[b]+siz[b]-1>=pos[a])oxy=find(b,a),swap(a,b);
int ans=0;memcpy(tmp,bag[a],sizeof(tmp));
if(oxy) {
for(int i=1;i<=sum;i++) {
tmp[i]=(tmp[i]-1ll*tmp[i-1]*siz[oxy]%pps)%pps;
tmp[i]=(tmp[i]+pps)%pps;
}
for(int i=sum;i;i--)
tmp[i]=(1ll*tmp[i-1]*(n-siz[a])%pps+tmp[i])%pps;
}
if(opt) {
for(int i=sum;i;i--)
tmp[i]=(tmp[i-1]+tmp[i])%pps;
}
if(opt)ans=1ll*tmp[sum]%pps;
else for(int i=0;i<=sum;i++)ans=(ans+1ll*tmp[i]*inv[sum-i]%pps)%pps;
int fake=0;memcpy(tmp,bag[b],sizeof(tmp));
if(opt) {
for(int i=sum;i;i--)
tmp[i]=(tmp[i-1]+tmp[i])%pps;
}
if(opt)fake=1ll*tmp[sum]%pps;
else for(int i=0;i<=sum;i++)fake=(fake+1ll*tmp[i]*inv[sum-i]%pps)%pps;
return 1ll*fake*ans%pps*fac[sum]%pps*fac[sum]%pps;
}
void prepare() {
fac[0]=fac[1]=1;
for(int i=2;i<=n;i++)
fac[i]=1ll*fac[i-1]*i%pps;
inv[0]=1;inv[n]=quick(fac[n],pps-2);
for(int i=n-1;i;i--)
inv[i]=1ll*inv[i+1]*(i+1)%pps;
}
int main() {
n=read(),q=read();prepare();
for(int i=1;i<n;i++) {
int a=read(),b=read();
add(a,b);add(b,a);
}dfs(0,1);
for(int i=1;i<=q;i++) {
int u=read(),v=read(),k=read(),opt=read();
int ans=solve(u,v,k,opt);
printf("%d\n",ans);
}
return 0;
}