2018.10.18--多校联测第三场测试总结

T1:容斥+二维偏序,考场上用cdq爆搞出36分。
T2:据说是容斥裸题,然而我至今都没写出来。考场上puts("0")拿了5分。
T3:树上背包优化,估计60分,实际60分。
今天三道题都没怎么见过类似的题…………
T1:
我们令\(x\)\(a_i<a_j,b_i<b_j\)的二元组\((i,j)\)的数量,\(y\)\(a_i<a_j,c_i<c_j\)的二元组\((i,j)\)的数量,\(z\)\((b_i<b_j,c_i<c_j)\)的二元组的数量,那么对于每一个合法三元组会在x,y,z中被算到三次,不合法的会被算到一次。所以答案就是\((x+y+z+C_n^2)/2\)
代码如下:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
#define low(i) ((i)&(-i))

const int N = 2e6+5;

ll ans;
int n,a[N],b[N],c[N];
unsigned int SA,SB,SC;

unsigned int rd(){
	SA^=SA<<16;SA^=SA>>5;SA^=SA<<1;
	unsigned int t=SA;SA=SB;SB=SC;SC^=t^SA;return SC;
}

void gen(int *P){
	for (int i=1;i<=n;++i) P[i]=i;
	for (int i=1;i<=n;++i) swap(P[i],P[1+rd()%n]);
}

struct fake {
	int a,b;
	
	bool operator<(const fake &gey)const {
		if(a==gey.a)return b<gey.b;
		return a<gey.a;
	}
}p[N];

struct Tree_array {
	int c[N];
	
	int query(int pos) {
		int res=0;
		for(int i=pos;i;i-=low(i))
			res+=c[i];
		return res;
	}
	
	void add(int pos) {
		for(int i=pos;i<=n;i+=low(i))
			c[i]++;
	}
}T[3];

ll calc(int id) {
	ll res=0;sort(p+1,p+n+1);
	for(int i=1;i<=n;i++) {
		res+=T[id].query(p[i].b);
		T[id].add(p[i].b);
	}
	return res;
}

int main(){
	freopen("cdq.in","r",stdin);
	freopen("cdq.out","w",stdout);
	scanf("%d%u%u%u",&n,&SA,&SB,&SC);
	gen(a);gen(b);gen(c);
	for(int i=1;i<=n;i++)
		p[i].a=a[i],p[i].b=b[i];
	ans+=calc(0);
	for(int i=1;i<=n;i++)
		p[i].a=a[i],p[i].b=c[i];
	ans+=calc(1);
	for(int i=1;i<=n;i++)
		p[i].a=b[i],p[i].b=c[i];
	ans+=calc(2);
	printf("%lld\n",(ans-1ll*n*(n-1)/2)/2);
	return 0;
}

T2:不会……贴个题解就走……毒瘤瘤……
题解:https://files.cnblogs.com/files/zhoushuyu/sol.pdf
T3:对于每个询问,我们在u的子树里和v的子树(根看情况改)里选k个点乘起来就行了。用背包做一下就行了,然后背包消除影响可以暴力\(O(d)\)的做。
代码如下:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
 
const int maxn=100005,pps=998244353;
 
int n,q,tot,tim;
int fac[maxn],inv[maxn];
int tmp[505],bag[maxn][505],deg[maxn];
int now[maxn],pre[maxn*2],son[maxn*2];
int dep[maxn],dfn[maxn],pos[maxn],siz[maxn];
 
int read() {
    int x=0,f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    return x*f;
}
 
void add(int a,int b) {
    pre[++tot]=now[a];
    now[a]=tot;son[tot]=b;
}
 
int quick(int a,int b) {
    int sum=1;
    while(b) {
        if(b&1)sum=1ll*sum*a%pps;
        a=1ll*a*a%pps;b>>=1;
    }
    return sum;
}
 
void dfs(int fa,int u) {
    dep[u]=dep[fa]+1;dfn[++tim]=u;
    bag[u][0]=1;siz[u]=1;pos[u]=tim;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(v!=fa) {
            dfs(u,v),siz[u]+=siz[v];
            deg[u]++;
            for(int i=deg[u];i;i--)
                bag[u][i]=(bag[u][i]+1ll*bag[u][i-1]*siz[v]%pps)%pps;
        }
}
 
int find(int u,int fake) {
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(pos[v]>pos[u]&&pos[v]+siz[v]-1>=pos[fake])
            return v;
    return 0;
}
 
int solve(int a,int b,int sum,int opt) {
    int oxy=0;
    if(pos[a]<pos[b]&&pos[a]+siz[a]-1>=pos[b])oxy=find(a,b);
    if(pos[b]<pos[a]&&pos[b]+siz[b]-1>=pos[a])oxy=find(b,a),swap(a,b);
    int ans=0;memcpy(tmp,bag[a],sizeof(tmp));
    if(oxy) {
        for(int i=1;i<=sum;i++) {
            tmp[i]=(tmp[i]-1ll*tmp[i-1]*siz[oxy]%pps)%pps;
            tmp[i]=(tmp[i]+pps)%pps;
        }
        for(int i=sum;i;i--)
            tmp[i]=(1ll*tmp[i-1]*(n-siz[a])%pps+tmp[i])%pps;
    }
    if(opt) {
        for(int i=sum;i;i--)
            tmp[i]=(tmp[i-1]+tmp[i])%pps;
    }
    if(opt)ans=1ll*tmp[sum]%pps;
    else for(int i=0;i<=sum;i++)ans=(ans+1ll*tmp[i]*inv[sum-i]%pps)%pps;
    int fake=0;memcpy(tmp,bag[b],sizeof(tmp));
    if(opt) {
        for(int i=sum;i;i--)
            tmp[i]=(tmp[i-1]+tmp[i])%pps;
    }
    if(opt)fake=1ll*tmp[sum]%pps;
    else for(int i=0;i<=sum;i++)fake=(fake+1ll*tmp[i]*inv[sum-i]%pps)%pps;
    return 1ll*fake*ans%pps*fac[sum]%pps*fac[sum]%pps;
}
 
void prepare() {
    fac[0]=fac[1]=1;
    for(int i=2;i<=n;i++)
        fac[i]=1ll*fac[i-1]*i%pps;
    inv[0]=1;inv[n]=quick(fac[n],pps-2);
    for(int i=n-1;i;i--)
        inv[i]=1ll*inv[i+1]*(i+1)%pps;
}
 
int main() {
    n=read(),q=read();prepare();
    for(int i=1;i<n;i++) {
        int a=read(),b=read();
        add(a,b);add(b,a);
    }dfs(0,1);
    for(int i=1;i<=q;i++) {
        int u=read(),v=read(),k=read(),opt=read();
        int ans=solve(u,v,k,opt);
        printf("%d\n",ans);
    }
    return 0;
}

posted on 2018-10-19 10:14  HYSBZ_mzf  阅读(192)  评论(0编辑  收藏  举报

导航