CodeForces - 557E(trie
题目:给出一个只有ab组成的串,求其中半回文串的第k大。
思路:暴力枚举出所有半回文串,然后暴力插到trie里查询第k大,原先写trie是抄模板的,这次自己写了一下。
#include <iostream> #include <map> #include <algorithm> #include <cstdio> #include <cstring> #include <cstdlib> #include <vector> #include <queue> #include <stack> #include <functional> #include <set> #include <cmath> #define pb push_back #define fs first #define se second #define sq(x) (x)*(x) #define eps 0.0000000001 #define IINF (1<<30) using namespace std; typedef long long ll; typedef pair<ll,ll> P; const int charsize=2; const int maxv=5005; const int maxpool=25000000; bool ispa[maxv][maxv]; char str[maxv]; int len,k; struct Node{ int num; int sum; Node *ch[charsize]; Node():num(0),sum(0){}; }pool[maxpool]; Node *null=new Node(); int poolh=0; Node *newNode(){ Node *n=&pool[poolh++]; for(int i=0;i<charsize;i++) n->ch[i]=null; n->sum=n->num=0; return n; } Node *root=newNode(); void insert(Node* n,int p,int s){ while(p<len){ if(n->ch[str[p]-'a']==null) n->ch[str[p]-'a']=newNode(); if(ispa[s][p]) n->ch[str[p]-'a']->num++; n=n->ch[str[p]-'a'],p++; } } void cul(Node *n){ if(n==null) return; n->sum=n->num; for(int i=0;i<charsize;i++){ cul(n->ch[i]); n->sum+=n->ch[i]->sum; } } void output(Node *n,int k){ if(n==null||k<=0) return; int sum=0; for(int i=0;i<charsize;i++){ if(sum+n->ch[i]->sum>=k){ printf("%c",(char)('a'+i)); output(n->ch[i],k-sum-n->ch[i]->num); return; } sum+=n->ch[i]->sum; } return; } void getpa(){ for(int i=0;i<len;i++){ for(int j=0;j<=len;j+=2){ if(i-j<0||i+j>=len){ break; } if(str[i+j]==str[i-j]){ ispa[i-j][i+j]=1; }else{ break; } } for(int j=1;j<=len;j+=2){ if(i-j<0||i+j>=len){ break; } if(str[i+j]==str[i-j]){ ispa[i-j][i+j]=1; }else{ break; } } for(int j=0;j<=len;j+=2){ if(i-j<0||i+j+1>=len){break;} if(str[i-j]==str[i+j+1]){ ispa[i-j][i+j+1]=1; }else{ break; } } for(int j=1;j<=len;j+=2){ if(i-j<0||i+j+1>=len){break;} if(str[i-j]==str[i+j+1]){ ispa[i-j][i+j+1]=1; }else{ break; } } } } int main(){ freopen("/home/files/CppFiles/in","r",stdin); /* std::ios::sync_with_stdio(false); std::cin.tie(0);*/ scanf("%s",str); len=strlen(str); cin>>k; getpa(); for(int i=0;i<len;i++) insert(root,i,i); cul(root); output(root,k); return 0; }