题解 分组

传送门

考试最后5min想到了k=1时的骗分方案,可惜没时间写了……
这题代码写得有点恶心了,数都很小,完全没必要开4个hush表

k=1时要求同一个组中不能有冲突元素,这里如果枚举组中元素判冲突就是\(n^2\)
发现两元素冲突当且仅当它们加和为某个正整数\(a\)的平方,
考虑不枚举元素,而是通过枚举这个\(a\)判冲突
那时间复杂度能开个根号
这里可以推广,如果有题要求在一个集合中找两个数,使它们运算后的值为某些特定数,
则如果特定数的数量较少,可以考虑「对于集合中的每个数,枚举所有特定数,并检查逆运算结果是否在集合中」的方式避免\(n^2\)的判断

k=2时允许每个组内再乱序划分,
原来是个二分图判定问题啊,有空看看二分图都有点什么性质(咕)
当年关押罪犯的坑在这里填了:
「敌对并查集」其实就是维护点之间的敌对关系,如果这些关系刚好能塞进一张二分图里,则这些点能被分进两个集合,否则不能
然后如果有重复元素,注意特判一下\(2x=a^2\)\(2x=a_1^2\)\(x+y=a_2^2\)的情况就好

严重写麻烦了的 Code:

#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 140010
#define ll long long 
#define ld long double
#define usd unsigned
#define ull unsigned long long
//#define int long long 

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
char buf[1<<21], *p1=buf, *p2=buf;
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, k, m=1;
int a[N], sta[N], top;

namespace task1{
	struct hash_map{
		static const int SIZE=13131;
		int head[SIZE], size;
		//int cnt[SIZE];
		hash_map():size(0){memset(head, 0, sizeof(head));}
		struct node{int dat, next;}e[N];
		inline bool operator [] (int q) {
			int t=q%SIZE;
			for (int i=head[t]; i; i=e[i].next) 
				if (e[i].dat==q) return 1;
			return 0;
		}
		inline void add(int q) {
			int t=q%SIZE;
			//++cnt[t];
			node* k=&e[++size]; k->dat=q; k->next=head[t]; head[t]=size;
		}
		inline void clear() {size=0; memset(head, 0, sizeof(head));}
		//void out() {int k=0; for (int i=1; i<=SIZE; ++i) cout<<cnt[i]<<endl, k=max(k, cnt[i]); cout<<"max: "<<k<<endl;}
	}mp;
	
	void solve() {
		for (int i=n,t; i; --i) {
			for (int j=512; j; --j) 
				if (t=j*j, t>a[i]) {
					if (mp[t-a[i]]) {
						++m;
						sta[++top]=i;
						mp.clear();
						break;
					}
				}
				else break;
			mp.add(a[i]);
		}
		printf("%d\n", m);
		while (top) printf("%d ", sta[top--]);
		printf("\n");
	}
}

namespace task2{
	int fa[N<<1];
	inline int find(int p) {return fa[p]==p?p:fa[p]=find(fa[p]);}
	
	struct hash_map1{
		static const int SIZE=13131;
		int head[SIZE], size;
		//int cnt[SIZE];
		hash_map1():size(0){memset(head, 0, sizeof(head));}
		struct node{int dat, rank, next;}e[N];
		inline bool operator [] (int q) {
			int t=q%SIZE;
			for (int i=head[t]; i; i=e[i].next) 
				if (e[i].dat==q) return 1;
			return 0;
		}
		inline void add(int r, int q) {
			int t=q%SIZE;
			//++cnt[t];
			node* k=&e[++size]; k->dat=q; k->rank=r; k->next=head[t]; head[t]=size;
		}
		inline void clear() {size=0; memset(head, 0, sizeof(head));}
		//void out() {int k=0; for (int i=1; i<=SIZE; ++i) k=max(k, cnt[i]); cout<<"max: "<<k<<endl;}
	}mp;
	
	struct hash_map2{
		static const int SIZE=557;
		int head[SIZE], size;
		//int cnt[SIZE];
		hash_map2():size(0){memset(head, 0, sizeof(head));}
		struct node{int dat, cnt, next;}e[N];
		inline bool operator [] (int q) {
			int t=q%SIZE;
			for (int i=head[t]; i; i=e[i].next) 
				if (e[i].dat==q) 
					if (e[i].cnt>=2) return 1;
					else return 0;
			return 0;
		}
		inline void add(int q) {
			int t=q%SIZE;
			//++cnt[t];
			for (int i=head[t]; i; i=e[i].next) 
				if (e[i].dat==q) {
					++e[i].cnt;
					return ;
				}
			node* k=&e[++size]; k->dat=q; k->cnt=1; k->next=head[t]; head[t]=size;
		}
		inline void clear() {size=0; memset(head, 0, sizeof(head));}
		//void out() {int k=0; for (int i=1; i<=SIZE; ++i) k=max(k, cnt[i]); cout<<"max: "<<k<<endl;}
	}mp2;
	
	struct hash_map3{
		static const int SIZE=557;
		int head[SIZE], size;
		//int cnt[SIZE];
		hash_map3():size(0){memset(head, 0, sizeof(head));}
		struct node{int dat, next;}e[N];
		inline bool operator [] (int q) {
			int t=q%SIZE;
			for (int i=head[t]; i; i=e[i].next) 
				if (e[i].dat==q) return 1;
			return 0;
		}
		inline void add(int q) {
			int t=q%SIZE;
			//++cnt[t];
			node* k=&e[++size]; k->dat=q; k->next=head[t]; head[t]=size;
		}
		//void out() {int k=0; for (int i=1; i<=SIZE; ++i) k=max(k, cnt[i]); cout<<"max: "<<k<<endl;}
	}mp3;
	
	void solve() {
		for (int i=1,lim=n*2; i<=lim; ++i) fa[i]=i;
		for (int i=1; i<=512; ++i) mp3.add(i*i);
		for (int i=n,t,t2,f1,f2; i; --i) {
			//cout<<i<<endl;
			if (mp3[a[i]*2]&&mp2[a[i]]) {
				++m;
				sta[++top]=i;
				mp.clear();
				mp2.clear();
				fa[i]=i; fa[i+n]=i+n;
				mp.add(i, a[i]);
				mp2.add(a[i]);
				continue;
			}
			for (int j=512; j; --j) 
				if (t=j*j, t>a[i]) {
					int q=t-a[i];
					//if (q==a[i]) continue;
					int p=(q)%mp.SIZE;
					for (int l=mp.head[p]; l; l=mp.e[l].next) 
						if (mp.e[l].dat==q) {
							f1=find(i); f2=find(mp.e[l].rank);
							//cout<<"rank: "<<mp.e[l].rank<<endl;
							//cout<<i<<endl;
							//cout<<f1<<' '<<f2<<endl;
							fa[find(i+n)]=f2;
							fa[find(mp.e[l].rank+n)]=f1;
							//cout<<"judge: "<<f1<<' '<<f2<<endl;
							if (find(i)==find(i+n)) {
								//cout<<"pos1"<<endl;
								++m;
								sta[++top]=i;
								mp.clear();
								mp2.clear();
								fa[i]=i; fa[i+n]=i+n;
								goto jump;
							}
						}
				}
				else break;
			jump: 
			mp.add(i, a[i]);
			if (mp3[a[i]*2]) mp2.add(a[i]); //cout<<"add2 "<<a[i]<<endl;
		}
		printf("%d\n", m);
		while (top) printf("%d ", sta[top--]);
		printf("\n");
		//cout<<mp3[131072*2]<<' '<<mp2[131072]<<endl;
		//cout<<a[2657]<<endl;
		//mp.out(); mp2.out(); mp3.out();
	}
}

signed main()
{
	#ifdef DEBUG
	freopen("1.in", "r", stdin);
	#endif
	
	n=read(); k=read();
	for (int i=1; i<=n; ++i) a[i]=read();
	if (n==1) {printf("1\n\n"); return 0;}
	if (k&1) task1::solve();
	else task2::solve();

	return 0;
}
posted @ 2021-06-22 16:20  Administrator-09  阅读(19)  评论(0编辑  收藏  举报