QBXT模拟赛T3
set+树上操作
这道题的主要思路是:用 set 按连续段(每段记录 id、l、r)维护每个点子树对应的 01 序列,然后做启发式合并(把较小的集合合并进较大的集合)。
但是我tm(拒绝口吐莲花)调了好几个小时!!!!!!!!!!!!!!!!!!!!!!!!!
主要是set的使用问题:
1、set 里一定要先插入一个极小值作为哨兵(代码里是 {-1,-1,-1}),否则对 begin() 做 -- 时会越界导致 RE;
2、set 的迭代器很好用,但它只是双向迭代器,只能 ++ 或 --(不能 +1、-1 随机访问);需要前驱/后继时,最好先用中间变量存一下;
3、set的迭代器注意不要用混了!!!
4、用 insert 插入结构体时,注意各字段(id, l, r)的顺序不要写错了。
代码:
#include<iostream>
#include<cstdio>
#include<queue>
#include<algorithm>
#include<cmath>
#include<set>
#include<cstring>
// Contest shorthand macros.
// NOTE(review): max/min are not parenthesised (e.g. 2*max(a,b) mis-expands) and
// would also break any later std::max(...)/std::min(...) call; the code shown
// here does not use them.
#define max(a,b) a>b?a:b
#define min(a,b) a<b?a:b
// Capacity of the node-indexed arrays (sets are indexed up to n+2).
#define maxn 700006
// mod and pb are not used by the code shown here.
#define mod 1000000007
#define pb push_back
// all{...} expands to the compound literal (zmd){...} (a GCC extension in C++).
#define all (zmd)
// Iterator over a run set.
#define pot set<zmd>::iterator
// Inclusive loop i = a..b.
#define rep(i,a,b) for (int i=a;i<=b;++i)
// Walk the adjacency list of node a; -1 terminates the list.
#define erep(i,a) for (int i=head[a];i!=-1;i=e[i].next)
using namespace std;
// From here on every `int` is long long (undone only around main's signature).
#define int long long
// One run of a 0/1 sequence: id is the bit value of the run (0 or 1; the
// sentinel record uses -1) and [l, r] is its inclusive range of positions.
// Runs are ordered by their right endpoint only.
struct zmd
{
    int id, l, r;

    bool operator<(const zmd &other) const
    {
        return other.r > r;
    }
};
// st[x]: a 0/1 sequence over positions 1..n stored as alternating maximal runs,
// ordered by right endpoint. Every live set keeps the {-1,-1,-1} sentinel as its
// first element. Indices 1..n are per-node sets; n+1 and n+2 are scratch sets
// used by dfs.
set<zmd>st[maxn];
// Adjacency-list edge: `to` is the endpoint, `next` the index of the next edge
// leaving the same node (-1 ends the list; head[] is memset to -1 in main).
// Every tree edge is stored twice (add(a,b) and add(b,a)), i.e. 2*(n-1) slots,
// while the node arrays allow n up to maxn-3. Sizing e[] by maxn alone
// overflowed once n > maxn/2, so it gets 2*maxn slots.
struct hzw
{
    int to,next;
}e[2 * maxn];
// Number of non-empty contiguous sub-intervals of a run of length k,
// i.e. k*(k+1)/2.
inline int dc(int k)
{
    const int pairs = k * (k + 1);
    return pairs / 2;
}
// Global state (every `int` here is long long via the macro):
//   head[v] - first edge index of v's adjacency list, -1 when empty
//   cur     - next free slot in e[]
//   n       - number of tree nodes
//   real[v] - index of the set st[] holding v's subtree once dfs has processed v
//   ans     - accumulated answer: cost[real[v]] summed over every v != 1
//   size[x] - merge weight of st[x] used to pick the heavy child; starts at 1 and
//             only grows (+2 when a 0-run is split in three), so it is an upper
//             estimate of the run count, not the exact set size
//   cost[x] - sum of dc(run length) over all runs in st[x]
// NOTE(review): under C++17, <set> plus `using namespace std` also exposes
// std::size, so an unqualified `size[...]` is ambiguous; refer to it as ::size.
int head[maxn],cur,n,real[maxn],ans,size[maxn],cost[maxn];
inline void add(int a,int b)
{
e[cur].to=b;
e[cur].next=head[a];
head[a]=cur++;
}
inline void FFT(int a,int b)
{
for (pot it = st[a].begin();it != st[a].end();++it)
{
if (it == st[a].begin())continue;
if (!it->id) continue;
int ll = it->l,rr = it->r;//dang qian yao cha ru de qu jian
pot it2 = st[b].lower_bound(all{1,ll,rr}); //fa xian yi ge di yi ge da yu de weizhi
int nowl = it2->l,nowr = it2->r; //now
pot up = ++it2;it2--;
pot down = --it2;it2++;
cost[b]-=dc(nowr-nowl+1);
if (rr != nowr && ll != nowl)
{
size[b] += 2;
st[b].erase(it2);
st[b].insert(all{1,ll,rr});
cost[b]+=dc(rr-ll+1);
st[b].insert(all{0,nowl,ll-1});
cost[b]+=dc(ll-nowl);
st[b].insert(all{0,rr+1,nowr});
cost[b]+=dc(nowr-rr);
}
else if (ll == nowl && rr == nowr)
{
if ( down != st[b].begin())
{
ll = down->l;
cost[b]-=dc(down->r-down->l+1);
st[b].erase(down);
}
if (up != st[b].end())
{
rr = up->r;
cost[b]-=dc(up->r-up->l+1);
st[b].erase(up);
}
st[b].erase(it2);
st[b].insert(all{1,ll,rr});
cost[b]+=dc(rr-ll+1);
}
else
{
if (down != st[b].begin() && nowl == ll)
{
ll = (down)->l;
cost[b]-=dc(down->r-down->l+1);
st[b].erase(down);
st[b].erase(it2);
cost[b]+=dc(rr-ll+1);
st[b].insert(all{1,ll,rr});
cost[b]+=dc(nowr-rr);
st[b].insert(all{0,rr+1,nowr});
}
else if (up != st[b].end() && nowr == rr)
{
rr = (up)->r;
cost[b]-=dc(up->r-up->l+1);
st[b].erase(up);
st[b].erase(it2);
cost[b]+=dc(rr-ll+1);
st[b].insert(all{1,ll,rr});
cost[b]+=dc(ll-nowl);
st[b].insert(all{0,nowl,ll-1});
}
else
{
st[b].erase(it2);
if (nowl == ll)
{
st[b].insert(all{1,ll,rr});
st[b].insert(all{0,rr+1,nowr});
cost[b]+=dc(rr-ll+1)+dc(nowr-rr);
}
else
{
st[b].insert(all{1,ll,rr});
st[b].insert(all{0,nowl,ll-1});
cost[b]+=dc(rr-ll+1)+dc(ll-nowl);
}
}
}
}
}
inline void dfs(int s,int fa)
{
int mx = -1,id = s;
erep(i,s)
{
if (e[i].to == fa) continue;
dfs(e[i].to,s);
if (size[real[e[i].to]]>mx)
{
mx = size[real[e[i].to]];
id = real[e[i].to];
}
}
real[s]=id;
erep(i,s)
{
if (e[i].to==fa) continue;
if (real[e[i].to] == id) continue;
FFT(real[e[i].to],id);
st[real[e[i].to]].clear();
}
st[n+1].clear(),st[n+2].clear();
st[n+1].insert(all{0,1,n}),st[n+1].insert(all{-1,-1,-1});
st[n+2].insert(all{1,s,s}),st[n+2].insert(all{-1,-1,-1});
FFT(n+2,n+1);
FFT(n+1,id);
if (s!=1) ans += cost[id];
}
#undef int
// Reads the tree from bye.in, runs the subtree-merging DFS from node 1 and
// writes the accumulated answer to bye.out.
int main()
{
#define int long long
    freopen("bye.in","r",stdin);
    freopen("bye.out","w",stdout);
    memset(head,-1,sizeof(head));
    cin>>n;
    // Every set 1..n+2 starts as the sentinel plus a single 0-run [1, n].
    // ::size is qualified because, under C++17, std::size (visible through
    // `using namespace std`) makes a plain `size` ambiguous.
    rep(i,1,n+2) st[i].insert(all{-1,-1,-1}),st[i].insert(all{0,1,n}),cost[i] = dc(n),::size[i] = 1;
    rep(i,1,n-1)
    {
        int a,b;
        scanf("%lld%lld",&a,&b);
        add(a,b);
        add(b,a);
    }
    dfs(1,1);
    cout<<ans;
}
代码2(与上面的代码相同,只是压行写法):
#include<iostream>
#include<cstdio>
#include<queue>
#include<algorithm>
#include<cmath>
#include<set>
#include<cstring>
// Contest shorthand macros.
// NOTE(review): max/min are not parenthesised (e.g. 2*max(a,b) mis-expands) and
// would also break any later std::max(...)/std::min(...) call; the code shown
// here does not use them.
#define max(a,b) a>b?a:b
#define min(a,b) a<b?a:b
// Capacity of the node-indexed arrays (sets are indexed up to n+2).
#define maxn 700006
// mod and pb are not used by the code shown here.
#define mod 1000000007
#define pb push_back
// all{...} expands to the compound literal (zmd){...} (a GCC extension in C++).
#define all (zmd)
// Iterator over a run set.
#define pot set<zmd>::iterator
// Inclusive loop i = a..b.
#define rep(i,a,b) for (int i=a;i<=b;++i)
// Walk the adjacency list of node a; -1 terminates the list.
#define erep(i,a) for (int i=head[a];i!=-1;i=e[i].next)
using namespace std;
// From here on every `int` is long long (undone only around main's signature).
#define int long long
// A run [l, r] of equal bits (id = 0 or 1; id = -1 marks the sentinel).
// Runs compare by right endpoint only.
struct zmd {
    int id, l, r;
    bool operator<(const zmd &rhs) const { return rhs.r > r; }
};
// st[x]: a 0/1 sequence over positions 1..n stored as alternating maximal runs,
// ordered by right endpoint. Every live set keeps the {-1,-1,-1} sentinel first.
// Indices 1..n are per-node sets; n+1 and n+2 are scratch sets used by dfs.
set<zmd>st[maxn];
// Adjacency-list edge (to = endpoint, next = next edge of the same node, -1 ends).
// Each tree edge is stored twice (2*(n-1) slots) while n may reach maxn-3, so
// e[] needs 2*maxn slots; e[maxn] overflowed once n > maxn/2.
struct hzw {int to,next;} e[2 * maxn];
// Number of non-empty contiguous sub-intervals of a run of length k: k*(k+1)/2.
inline int dc(int k) { return k * (k + 1) / 2; }
// Global state (every `int` here is long long via the macro):
//   head[v] adjacency-list head (-1 = empty), cur next free edge slot, n node count,
//   real[v] index of the set holding v's subtree once dfs has processed v,
//   ans sum of cost[real[v]] over v != 1,
//   size[x] merge weight of st[x] (starts at 1, +2 per 0-run split; an upper
//           estimate of the run count, not the exact set size),
//   cost[x] sum of dc(run length) over all runs in st[x].
// NOTE(review): under C++17, <set> plus `using namespace std` also exposes
// std::size, so an unqualified `size[...]` is ambiguous; refer to it as ::size.
int head[maxn],cur,n,real[maxn],ans,size[maxn],cost[maxn];
inline void add(int a,int b) {e[cur].to=b;e[cur].next=head[a];head[a]=cur++;}
inline void FFT(int a,int b) {
for (pot it = st[a].begin(); it != st[a].end(); ++it) {
if (it == st[a].begin())continue;
if (!it->id) continue;
int ll = it->l,rr = it->r;//dang qian yao cha ru de qu jian
pot it2 = st[b].lower_bound(all {1,ll,rr}); //fa xian yi ge di yi ge da yu de weizhi
int nowl = it2->l,nowr = it2->r; //now
pot up = ++it2;it2--;pot down = --it2;it2++;cost[b]-=dc(nowr-nowl+1);
if (rr != nowr && ll != nowl) {
size[b] += 2;st[b].erase(it2);st[b].insert(all {1,ll,rr});cost[b]+=dc(rr-ll+1);
st[b].insert(all {0,nowl,ll-1});cost[b]+=dc(ll-nowl);st[b].insert(all {0,rr+1,nowr});cost[b]+=dc(nowr-rr);
} else if (ll == nowl && rr == nowr) {
if ( down != st[b].begin()) {ll = down->l;cost[b]-=dc(down->r-down->l+1);st[b].erase(down);}
if (up != st[b].end()) {rr = up->r;cost[b]-=dc(up->r-up->l+1);st[b].erase(up);}
st[b].erase(it2);st[b].insert(all {1,ll,rr});cost[b]+=dc(rr-ll+1);
} else {
if (down != st[b].begin() && nowl == ll) {
ll = (down)->l;cost[b]-=dc(down->r-down->l+1);st[b].erase(down);st[b].erase(it2);
cost[b]+=dc(rr-ll+1);st[b].insert(all {1,ll,rr});cost[b]+=dc(nowr-rr);st[b].insert(all {0,rr+1,nowr});
} else if (up != st[b].end() && nowr == rr) {
rr = (up)->r;cost[b]-=dc(up->r-up->l+1);st[b].erase(up);st[b].erase(it2);cost[b]+=dc(rr-ll+1);
st[b].insert(all {1,ll,rr});cost[b]+=dc(ll-nowl);st[b].insert(all {0,nowl,ll-1});
} else {
st[b].erase(it2);
if (nowl == ll) {st[b].insert(all {1,ll,rr});st[b].insert(all {0,rr+1,nowr});cost[b]+=dc(rr-ll+1)+dc(nowr-rr);}
else {st[b].insert(all {1,ll,rr});st[b].insert(all {0,nowl,ll-1});cost[b]+=dc(rr-ll+1)+dc(ll-nowl);}
}
}
}
}
inline void dfs(int s,int fa) {
int mx = -1,id = s;
erep(i,s) {if (e[i].to == fa) continue;dfs(e[i].to,s);if (size[real[e[i].to]]>mx) {mx = size[real[e[i].to]];id = real[e[i].to];}}
real[s]=id;
erep(i,s) {if (e[i].to==fa) continue;if (real[e[i].to] == id) continue;FFT(real[e[i].to],id);st[real[e[i].to]].clear();}
st[n+1].clear(),st[n+2].clear();st[n+1].insert(all {0,1,n}),st[n+1].insert(all {-1,-1,-1});st[n+2].insert(all {1,s,s}),st[n+2].insert(all {-1,-1,-1});
FFT(n+2,n+1);FFT(n+1,id);if (s!=1) ans += cost[id];
}
#undef int
// Reads the tree from bye.in, runs the merging DFS from node 1, writes ans to bye.out.
int main() {
#define int long long
    freopen("bye.in","r",stdin);freopen("bye.out","w",stdout);
    memset(head,-1,sizeof(head));cin>>n;
    // Sets 1..n+2 start as sentinel + one 0-run [1, n]; ::size avoids the C++17
    // ambiguity with std::size brought in by `using namespace std`.
    rep(i,1,n+2) st[i].insert(all {-1,-1,-1}),st[i].insert(all {0,1,n}),cost[i] = dc(n),::size[i] = 1;
    rep(i,1,n-1) {int a,b;scanf("%lld%lld",&a,&b);add(a,b);add(b,a);}
    dfs(1,1);cout<<ans;
}