[hdu4812]D Tree(点分治)
题意:问有多少条路径,符合路径上所有节点的权值乘积模1000003等于k。
解题关键:预处理阶乘逆元,然后通过hash和树形dp$O(1)$的判定乘积存在问题,注意此道题是如何处理路径保证不重复的,具有启发意义。
代码:2340ms,这段代码最重要的可取点就是如何通过操作省去memset的过程
复杂度:$O(n\log n)$
1 #pragma comment(linker,"/STACK:102400000,102400000") 2 #include<cstdio> 3 #include<cstring> 4 #include<algorithm> 5 #include<cstdlib> 6 #include<iostream> 7 #include<cmath> 8 #define maxn 100040 9 #define maxm 1000500 10 using namespace std; 11 typedef long long ll; 12 const ll mod=1000003; 13 const ll inf=1ll<<60; 14 ll n,k,ans,size,s[maxn],f[maxn],path[maxn],cr,val[maxn],inv[maxm],ansl,ansr; 15 ll head[maxn],cnt,root,pathid[maxn],flag[maxm],mp[maxm],ca; 16 bool vis[maxn]; 17 struct edge{ 18 ll to,nxt; 19 }e[maxn<<1]; 20 21 void add_edge(ll u,ll v){ 22 e[cnt].to=v; 23 e[cnt].nxt=head[u]; 24 head[u]=cnt++; 25 } 26 27 inline ll read(){ 28 char k=0;char ls;ls=getchar();for(;ls<'0'||ls>'9';k=ls,ls=getchar()); 29 ll x=0;for(;ls>='0'&&ls<='9';ls=getchar())x=(x<<3)+(x<<1)+ls-'0'; 30 if(k=='-')x=0-x;return x; 31 } 32 33 void get_root(ll u,ll fa){//get_root会用到size 34 s[u]=1;f[u]=0;//f是dp数组 35 for(ll i=head[u];i!=-1;i=e[i].nxt){ 36 ll v=e[i].to; 37 if(v==fa||vis[v]) continue; 38 get_root(v,u); 39 s[u]+=s[v]; 40 f[u]=max(f[u],s[v]); 41 } 42 f[u]=max(f[u],size-s[u]); 43 root=f[root]>f[u]?u:root; 44 } 45 46 void get_path_size(ll u,ll fa,ll dis){//同时获取size和depth,size是深度,depth是dis的意思 47 path[cr]=dis%mod*val[u]%mod; 48 pathid[cr]=u; 49 cr++; 50 s[u]=1; 51 ll tm=path[cr-1]%mod; 52 for(ll i=head[u];i!=-1;i=e[i].nxt){ 53 ll v=e[i].to; 54 if(v==fa||vis[v]) continue; 55 get_path_size(v,u,tm); 56 s[u]+=s[v]; 57 } 58 } 59 60 void getans(ll a,ll b){ 61 if(a>b) swap(a,b); 62 if(ansl>a) ansl=a,ansr=b; 63 else if(ansl==a&&ansr>b) ansr=b; 64 } 65 66 void work(ll u,ll fa){ 67 vis[u]=true; 68 for(ll i=head[u];i!=-1;i=e[i].nxt){ 69 ll v=e[i].to; 70 if(v==fa||vis[v]) continue; 71 cr=0; 72 get_path_size(v,u,1); 73 for(ll j=0;j<cr;j++){ 74 if(path[j]*val[u]%mod==k) getans(pathid[j],u); 75 ll tm=k*inv[path[j]*val[u]%mod]%mod; 76 if(flag[tm]!=ca) continue; 77 getans(mp[tm],pathid[j]); 78 } 79 for(int j=0;j<cr;j++){ 80 ll tm=path[j]; 81 if(flag[tm]!=ca||mp[tm]>pathid[j]) flag[tm]=ca,mp[tm]=pathid[j]; 82 } 83 } 84 ca++; 85 for(ll i=head[u];i!=-1;i=e[i].nxt){ 86 ll v=e[i].to; 87 if(vis[v]||v==fa) continue; 88 size=s[v],root=0; 89 get_root(v,u); 90 work(root,u); 91 } 92 // vis[u]=false; 93 } 94 95 void init(){ 96 memset(vis,0,sizeof vis); 97 memset(head,-1,sizeof head); 98 //memset(flag,0,sizeof flag); 99 ans=cnt=0; 100 //ca=1; 101 } 102 103 ll mod_pow(ll x,ll n,ll p){ 104 ll res=1; 105 while(n){ 106 if(n&1) res=res*x%p; 107 x=x*x%p; 108 n>>=1; 109 } 110 return res; 111 } 112 113 int main(){ 114 ll a,b; 115 f[0]=inf;inv[0]=1; 116 for(ll i=1;i<1000003;i++) inv[i]=mod_pow(i,mod-2,mod); 117 ca=0; 118 while(scanf("%lld%lld",&n,&k)!=EOF){ 119 ca++; 120 init(); 121 for(int i=1;i<=n;i++) val[i]=read()%mod; 122 for(int i=0;i<n-1;i++){ 123 a=read(),b=read(); 124 add_edge(a,b); 125 add_edge(b,a); 126 } 127 size=n,root=0; 128 get_root(1,-1); 129 ansl=ansr=inf; 130 work(root,-1); 131 if(ansl==inf) printf("No solution\n"); 132 else printf("%lld %lld\n",ansl,ansr); 133 } 134 return 0; 135 } 136