CF715C&&CF716E Digit Tree 点分治

题目大意

给出一个树,每条边上写了一个数字,给出一个G,求有多少条路径按顺序读出的数字可以被G整除。保证G与10互质。

题解

双倍经验~

首先一条路径顺着读和逆着读是视为两条不同的路径的,即使值一样。

同时要注意一条路径顺着读和逆着读不一定都满足要求,比如14能整出7而41不能。

于是我们可以把一条路径断开,设前半段的值为x,后半段的值为y。于是得到:

\[x*10^{len[y]}+y\equiv 0(\bmod\ G) \]

\[x\equiv \frac{-y}{10^{len[y]}} (\bmod\ G) \]

然后我们用map记录x的值(即从任意点到root的值),用dis数组记录y的值(即root到任意点的值)。
最后查找与\(\cfrac{-y}{10^{len(y)}}\)值相同的有多少更新答案即可。

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#define int long long
using namespace std;
const int maxn=1e5+10;
const int maxm=1e7+10;
const int INF=0x3f3f3f3f; 
int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int n,m;
int head[maxn],cnt;
struct edge{int to,nxt,w;}e[maxn*2];
void add(int x,int y,int w){
    e[++cnt]=(edge){y,head[x],w%m};
    head[x]=cnt;
}
int vis[maxn],ll[maxn];
int size[maxn],son[maxn];
int S;
int mx,root;
map<int,int>mp;
int c[maxn],pow10[maxn];
void getroot(int x,int f){
    size[x]=1;son[x]=0;
    for(int i=head[x];i;i=e[i].nxt){
        int u=e[i].to;
        if(u==f||vis[u])continue;
        getroot(u,x);
        size[x]+=size[u];
        son[x]=max(son[x],size[u]);
    }
    son[x]=max(son[x],S-size[x]);
    if(son[x]<mx){mx=son[x];root=x;}  
}
int tot;
int dis[maxn];
void exgcd(int a,int b,int &d,int &x,int &y){
    if(!b){d=a;x=1;y=0;}
    else{exgcd(b,a%b,d,y,x);y-=x*(a/b);}
}
int inv(int a,int p){
    int x,y,d;
    exgcd(a,p,d,x,y);
    return d==1?(x%p+p)%p:-1;
}
void getdis(int x,int f,int v1,int v2,int dep){
    if(dep){
        mp[v2%m]++;dis[++tot]=v1;ll[tot]=dep;
    }
    for(int i=head[x];i;i=e[i].nxt){
        int u=e[i].to;
        if(u==f||vis[u])continue;
        getdis(u,x,(v1*10%m+e[i].w)%m,(v2+e[i].w*pow10[dep]%m)%m,dep+1);
    }
}
int ans;
int calc(int x,int v,int opt){
    int res=0;tot=0;
    mp.clear();
    getdis(x,0,v,v,opt);
    for(int i=1;i<=tot;i++){
    	int val=(m-dis[i]*inv(pow10[ll[i]],m)%m)%m;
    	if(mp.find(val)!=mp.end())res+=mp[val];
    	if(opt==0)res+=(dis[i]==0);
    }
    if(opt==0)res+=mp[0];
    return res;
}
void dfs(int x){
    ans+=calc(x,0,0);
    vis[x]=1;
    for(int i=head[x];i;i=e[i].nxt){
        int u=e[i].to;
        if(vis[u])continue;
        ans-=calc(u,e[i].w,1);
        root=0;S=size[u];mx=INF;getroot(u,0);
        dfs(root);
    }
}
#undef int 
int main()
#define int long long
{
    n=read();m=read(); 
    for(int i=1;i<n;i++){
        int x=read()+1,y=read()+1,w=read();
        add(x,y,w);add(y,x,w);
    }
    pow10[0]=1LL;
    for(int i=1;i<=n;i++)pow10[i]=1LL*pow10[i-1]*10%m;
    S=n;mx=INF;getroot(1,0);
    dfs(root);
    cout<<ans;
    return 0;
}
posted @ 2018-10-02 11:42  南城ㄱ  阅读(212)  评论(0编辑  收藏  举报