2017 Multi-University Training Contest - Team 3 Kanade's trio(字典树+组合数学)
题解:
官方题解太简略了orz
具体实现的方式其实有很多
问题就在于确定A[j]以后,如何找符合条件的A[i]
这里其实就是要提前预处理好
我是倒序插入点的,所以要沿着A[k]爬树,找符合的A[i]
如果发现A[i]与A[k]的第p位不同,比如A[k]位1,A[i]为0,那么所有的在i右边的第p位为0的数就都可以充当A[j]
所以实际上就需要求出有多少点对(i, j),满足这个条件。
不妨用可持久化的思想考虑这个过程
倒序插入A[i]时,我们就能统计出来A[i]的第p位为0(或者为1)时,所有在i右边的第p位为0(或者为1)的数有多少个
但是,问题在于我们需要删除结点
这个过程就要倒着想
如果删除A[i]
1、对于删除的那条字典树的链,链上每个点减少的贡献为 “那个结点的子树大小”
2、对于非链上的点,如果这个点和A[i]相应的第p位相同,那么它减少的贡献也是“这个结点的子树大小”
但注意,1情况对应的子树大小实际上是要减1的,因为被删除了一个结点。
我们用一个数组记录第p位为0或1时删除了几次,就可以处理第二种情况
但是第一种情况是比较特殊的,所以我们对每个结点都记录一下它上次被删除是哪一次
这样就可以做了
#include <iostream> #include <cstdio> #include <cstring> #include <vector> #include <queue> #include <cstdlib> using namespace std; const int maxn = 5e5 + 200; typedef long long LL; struct Node{ Node* ch[2]; LL num, ans, Mv; }pool[maxn*31], *null; int tot, a[maxn], tt = 0; LL Minus[32][2], Plus[32][2]; inline Node* newnode(){ Node* x = &pool[tot++]; x->ch[0] = x->ch[1] = null; x->num = x->ans = x->Mv = 0; return x; } void pre(){ null = newnode(); null->ch[0] = null->ch[1] = null; null->num = 0; } inline void Insert(Node* root, int x){ Node* u = root; for(int i = 30; i >= 0; i--){ int c = (x&(1<<i)) ? 1 : 0; if(u->ch[c] == null){ u->ch[c] = newnode(); } u->num++; u->ch[c]->ans += Plus[i][c]; Plus[i][c]++; u = u->ch[c]; } u->num++; } inline void Erase(Node* root, int x){ Node* u = root; for(int i = 30; i >= 0; i--){ int c = (x&(1<<i)) ? 1 : 0; u->num--; Minus[i][c]++; u->ch[c]->ans -= (Minus[i][c] - u->ch[c]->Mv - 1)*u->ch[c]->num; u->ch[c]->ans -= u->ch[c]->num - 1; u->ch[c]->Mv = Minus[i][c]; u = u->ch[c]; } u->num--; } inline LL Find(Node* root, int x){ LL ans = 0; Node* u = root; for(int i = 30; i >= 0; i--){ int c = (x&(1<<i)) ? 1 : 0; LL v = u->ch[c^1]->Mv; LL rnum = u->ch[c^1]->ans - (Minus[i][c^1]-v)*u->ch[c^1]->num; ans += rnum; u = u->ch[c]; } return ans; } int main() { int T; cin>>T; for(; T; T--){ int n; cin>>n; LL ans = 0; tot = 0; pre(); Node* root = newnode(); memset(Minus, 0, sizeof(Minus)); memset(Plus, 0, sizeof(Plus)); for(int i = 1; i <= n; i++) scanf("%d", &a[i]); for(int i = n-1; i >= 1; i--) Insert(root, a[i]); for(int i = n; i >= 3; i--){ ans += Find(root, a[i]); Erase(root, a[i-1]); } cout<<ans<<endl; } }