洛谷p1637 三元上升子序列(树状数组
题目描述
Erwin最近对一种叫"thair"的东西巨感兴趣。。。
在含有n个整数的序列a1,a2......an中,
三个数被称作"thair"当且仅当i<j<k且ai<aj<ak
求一个序列中"thair"的个数。
输入输出格式
输入格式:
开始一个正整数n,
以后n个数a1~an。
输出格式:
"thair"的个数
输入输出样例
输入样例#1:
4 50 18 3 4 6 8 14 15 16 17 21 25 26 Input 4 2 1 3 4 Output 2 Input 5 1 2 2 3 4 Output 7 对样例2的说明: 7个"thair"分别是 1 2 3 1 2 4 1 2 3 1 2 4 1 3 4 2 3 4 2 3 4
输出样例#1:
说明
约定 30%的数据n<=100
60%的数据n<=2000
100%的数据n<=30000
大数据随机生成
0<=a[i]<=maxlongint
那么如果我们考虑在输入时考虑当前的c,那么我们只需找两个小于c并且不同的数
如果位置小于c且值小于c的数没有重复,那么我们可以得到是,以c结尾的三元组数量是
n*(n-1)/2,
有重复元素怎么办呢,因为这样计数,1,2,2,3,4,计算以4结尾的三元组时,会算到2,2,4
那么怎么解决这个问题..
解决1:
换种计数方法,考虑中间元素b,我们只需考虑b之前有多少个严格小于它的元素数量u,之后有多少严格大于它的元素v
于是中间元素b的三元组对答案的贡献就是u*v
于是我们可以算两遍,第一遍算u第二遍算v
附上代码...
1 #include <iostream> 2 #include <cstdio> 3 #include <algorithm> 4 #include <cstring> 5 using namespace std; 6 const int maxn=1e5+7; 7 int N,w; 8 typedef long long ll; 9 ll t[maxn],u[maxn],v[maxn]; 10 struct node{ 11 int id,v;node(){};node(int id,int v):id(id),v(v){}; 12 }; 13 node a[maxn]; 14 int lowbit(int x){ 15 return x&-x; 16 } 17 void add(int n,int x){ 18 while(n<=N){ 19 t[n]+=x; 20 n+=lowbit(n); 21 } 22 } 23 int sum(int n){ 24 int ans=0; 25 while(n){ 26 ans+=t[n]; 27 n-=lowbit(n); 28 } 29 return ans; 30 } 31 bool cmp1(node a,node b){ 32 return a.v<b.v; 33 } 34 bool cmp2(node a,node b){ 35 return a.id<b.id; 36 } 37 int main(){ 38 int n,x;scanf("%d",&n); 39 for(int i=1;i<=n;++i){ 40 scanf("%d",&x); 41 a[i]=node(i,x); 42 } 43 sort(a+1,a+1+n,cmp1); 44 int cnt=1,st=1,pre=a[1].v; 45 for(int i=2;i<=n;++i){ 46 while(i<=n&&a[i].v==pre) i++; 47 for(int j=st;j<i;++j){ 48 a[j].v=cnt; 49 } 50 st=i;pre=a[i].v; 51 cnt++; 52 } 53 for(int j=st;j<=n;++j) a[j].v=cnt; 54 //for(int i=1;i<=n;++i) printf("%d,",a[i].v);printf("\n"); 55 N=cnt; 56 sort(a+1,a+1+n,cmp2); 57 ll ans=0; 58 for(int i=1;i<=n;++i){ 59 u[i]=sum(a[i].v-1); 60 add(a[i].v,1); 61 } 62 memset(t,0,sizeof(t)); 63 for(int i=n;i>=1;--i){ 64 v[i]=sum(N)-sum(a[i].v); 65 ans+=u[i]*v[i]; 66 add(a[i].v,1); 67 } 68 printf("%lld\n",ans); 69 return 0; 70 }
其实也可以这么写,
因为sum(N)=n-i的,因为是倒着插入的,所以当你插入n时,正好已经插入了n-n个元素,
插入n-1时,正好已经插入了一个元素,所以n-i-sum(a[i].v)的意思是,当前插入的所有元素减去小于等于v的元素个数,
那么剩下的一定都大于v,sum(N)=大于v的元素个数+小于等于v的元素个数
1 #include <iostream> 2 #include <cstdio> 3 #include <algorithm> 4 #include <cstring> 5 using namespace std; 6 const int maxn=1e5+7; 7 int N,w; 8 typedef long long ll; 9 ll t[maxn],u[maxn],v[maxn]; 10 struct node{ 11 int id,v;node(){};node(int id,int v):id(id),v(v){}; 12 }; 13 node a[maxn]; 14 int lowbit(int x){ 15 return x&-x; 16 } 17 void add(int n,int x){ 18 while(n<=N){ 19 t[n]+=x; 20 n+=lowbit(n); 21 } 22 } 23 int sum(int n){ 24 int ans=0; 25 while(n){ 26 ans+=t[n]; 27 n-=lowbit(n); 28 } 29 return ans; 30 } 31 bool cmp1(node a,node b){ 32 return a.v<b.v; 33 } 34 bool cmp2(node a,node b){ 35 return a.id<b.id; 36 } 37 int main(){ 38 int n,x;scanf("%d",&n); 39 for(int i=1;i<=n;++i){ 40 scanf("%d",&x); 41 a[i]=node(i,x); 42 } 43 sort(a+1,a+1+n,cmp1); 44 int cnt=1,st=1,pre=a[1].v; 45 for(int i=2;i<=n;++i){ 46 while(i<=n&&a[i].v==pre) i++; 47 for(int j=st;j<i;++j){ 48 a[j].v=cnt; 49 } 50 st=i;pre=a[i].v; 51 cnt++; 52 } 53 for(int j=st;j<=n;++j) a[j].v=cnt; 54 //for(int i=1;i<=n;++i) printf("%d,",a[i].v);printf("\n"); 55 N=cnt; 56 sort(a+1,a+1+n,cmp2); 57 ll ans=0; 58 for(int i=1;i<=n;++i){ 59 u[i]=sum(a[i].v-1); 60 add(a[i].v,1); 61 } 62 memset(t,0,sizeof(t)); 63 for(int i=n;i>=1;--i){ 64 v[i]=n-i-sum(a[i].v); 65 ans+=u[i]*v[i]; 66 add(a[i].v,1); 67 } 68 printf("%lld\n",ans); 69 return 0; 70 }