杭电多校第六场-J-Ridiculous Netizens

Problem Description

Mr. Bread has a tree $T$ with $n$ vertices, labeled $1, 2, \ldots, n$. Each vertex $i$ of the tree has a positive integer value $w_i$.

The value of a non-empty tree $T$ is the product of its vertex values, $w_1 \times w_2 \times \cdots \times w_n$. A subtree of $T$ is a connected subgraph of $T$ that is also a tree.

Please write a program to calculate the number of non-empty subtrees of T whose values are not larger than a given number m.

Input

The first line of the input contains an integer $T$ ($1 \le T \le 10$), denoting the number of test cases.

In each test case, the first line contains two integers $n, m$ ($1 \le n \le 2000$, $1 \le m \le 10^6$), denoting the number of vertices and the upper bound.

The second line contains $n$ integers $w_1, w_2, \ldots, w_n$ ($1 \le w_i \le m$), denoting the value of each vertex.

Each of the following $n-1$ lines contains two integers $u_i, v_i$ ($1 \le u_i, v_i \le n$, $u_i \ne v_i$), denoting a bidirectional edge between vertices $u_i$ and $v_i$.

Output

For each test case, print a single line containing an integer, denoting the number of valid non-empty subtrees. As the answer can be very large, output it modulo 10^9+7.

Sample Input

1
5 6
1 2 1 2 3
1 2
1 3
2 4
2 5

Sample Output

14

 

题意:
一棵无根树,每个点有权值,询问有多少个连通子图的点权乘积不超过 m(答案对 1e9+7 取模)

思路
https://www.cnblogs.com/hua-dong/p/11320013.html
考虑对某点,连通块要么经过它要么不经过它 ——> 点分治
对于经过该点的用dp求解
在dfs序上dp,类似于树形依赖背包
dp[i][j] 表示从 dfs 序第 i 位起往后做决策、已选点乘积为 j 的方案数
可知 dp[i][j]=(dp[i+1][j/a[dfn[i]]]+dp[i+son[dfn[i]]][j]) //当前点选(进入下一位)/不选(跳过它的整棵子树)
但第二维为m不可行
考虑把乘积 ≤ sqrt(m) 的状态和乘积 > sqrt(m) 的状态分开保存:前者就是正常的背包,下标表示当前乘积;后者下标记为 ⌊m/乘积⌋,即以后还可以乘的上限(利用 ⌊⌊m/a⌋/b⌋ = ⌊m/(ab)⌋)。两部分的第二维都只需 O(sqrt(m))
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
#define ll long long
using namespace std;
// Row bound for all per-vertex arrays. The statement guarantees n <= 2000 and
// cal() touches DP rows 1..tim+1 <= n+1, so a small margin is enough.
// (The former bound 1e5+10 made dp1/dp2 alone ~800 MB of static storage.)
const int N=2010;
// Column bound for dp1/dp2: cal() indexes columns 1..floor(sqrt(m)),
// which is at most 1000 for m <= 1e6.
const int SQ=1005;
const int p=1e9+7;
// T: test count; n, m: input sizes; cnt: edges added so far; sum: size of the
// component getroot() is searching; root: best centroid found by getroot();
// tim: dfs-order counter used by dfs(); ans: running answer modulo p.
int T,n,m,cnt,sum,root,tim,ans;
// Adjacency-list edge: target vertex v and index of the next edge (0 ends the list).
struct orz{
    int v,nex;}e[N*2];
// a: vertex weights; last: head edge per vertex; son: subtree sizes;
// f: largest remaining piece if the vertex were removed (getroot);
// dfn: dfs position -> vertex; dp1/dp2: the two halves of cal()'s DP.
int a[N],last[N],son[N],f[N],dfn[N],dp1[N][SQ],dp2[N][SQ];
// vis[v] != 0 once v has been used as a centroid (removed from the tree).
bool vis[N];
void add(int x,int y)
{
    cnt++;
    e[cnt].v=y;
    e[cnt].nex=last[x];
    last[x]=cnt;
}
// Walk the not-yet-removed component containing x (of total size `sum`),
// filling son[] with subtree sizes and f[] with the largest piece left if the
// vertex were deleted; `root` ends up at the vertex with the smallest f[].
void getroot(int x,int fa)
{
    son[x]=1;
    f[x]=0;
    for (int i=last[x];i!=0;i=e[i].nex)
    {
        const int to=e[i].v;
        if (to!=fa && !vis[to])
        {
            getroot(to,x);
            son[x]+=son[to];
            if (son[to]>f[x]) f[x]=son[to];
        }
    }
    // The piece "above" x (everything outside its subtree) also counts.
    if (sum-son[x]>f[x]) f[x]=sum-son[x];
    if (f[x]<f[root]) root=x;
}
// Assign x the next preorder position (dfn[tim] = x) and compute subtree
// sizes son[] over the not-yet-removed component, skipping vis[] vertices.
void dfs(int x,int fa)
{
    son[x]=1;
    dfn[++tim]=x;
    for (int i=last[x];i!=0;i=e[i].nex)
    {
        const int to=e[i].v;
        if (to==fa) continue;
        if (vis[to]) continue;
        dfs(to,x);
        son[x]+=son[to];
    }
}
// Count the connected vertex sets of the current component that contain
// dfn[1] (the vertex dfs() was started from) and whose weight product is
// <= m, adding the count to ans modulo p.
// Precondition: dfs() has just filled tim, dfn[1..tim] and son[] for this component.
//
// DP runs backwards over dfs positions. At position i the vertex dfn[i] is
// either taken (continue at i+1) or skipped together with its whole subtree
// (jump to i+son[dfn[i]]), so every produced set is closed under "parent taken".
//   dp1[i][j] (j <= mm): ways for positions i..tim whose product is exactly j.
//   dp2[i][r]          : ways whose product P > mm, stored by r = floor(m/P).
// NOTE(review): divisions by x assume every weight a[] >= 1 (the statement's w_i >= 1).
void cal()
{
    int mm=sqrt(m);  // floor(sqrt(m)): split point between the two DP halves
    for (int i=1;i<=tim+1;i++)
    {
        memset(dp1[i],0,sizeof(dp1[i]));
        memset(dp2[i],0,sizeof(dp2[i]));
    }
    dp1[tim+1][1]=1;  // nothing chosen after the last position: product 1
    for (int i=tim;i>=1;i--)
    {
        int x=a[dfn[i]];
        // Take dfn[i] from a small-product state j; j <= m/x keeps j*x <= m.
        for (int j=1;j<=min(mm,m/x);j++)
        {
            int k=j*x;
            if (k<=mm) dp1[i][k]=(dp1[i][k]+dp1[i+1][j])%p;
            // k > mm = floor(sqrt(m)) implies m/k <= mm, so the index fits.
            else dp2[i][m/k]=(dp2[i][m/k]+dp1[i+1][j])%p;
        }
        // Take dfn[i] from a large-product state: floor(m/(P*x)) == floor(floor(m/P)/x);
        // j >= x keeps the new bound >= 1, i.e. the product stays <= m.
        for (int j=x;j<=mm;j++)
        {
            dp2[i][j/x]=(dp2[i][j/x]+dp2[i+1][j])%p;
        }
        // Skip dfn[i] and its entire subtree.
        for (int j=1;j<=mm;j++)
        {
            dp1[i][j]=(dp1[i][j]+dp1[i+son[dfn[i]]][j])%p;
            dp2[i][j]=(dp2[i][j]+dp2[i+son[dfn[i]]][j])%p;
        }
    }
    for (int i=1;i<=mm;i++)
    {
        ans=(ans+dp1[1][i])%p;
        ans=(ans+dp2[1][i])%p;
    }
    // Skipping position 1 skips the whole component (son[dfn[1]] == tim),
    // contributing the empty set once; remove it.
    ans=(ans-1+p)%p;
}
// Centroid-decomposition step: count the valid connected sets containing the
// centroid x, remove x, then recurse on the centroid of each remaining piece.
void work(int x)
{
    vis[x]=1; tim=0;
    // Traverse from x itself; the old dfs(root,0) was only correct because
    // every caller happened to pass the global `root`.
    dfs(x,0);
    cal();
    for (int i=last[x];i;i=e[i].nex)
    {
        int to=e[i].v;
        if (vis[to]) continue;
        // dfs(x,0) above set son[to] to the size of to's piece; recursion into
        // earlier siblings only touches their own (disjoint) pieces.
        sum=son[to];
        root=0;
        getroot(to,0);
        work(root);
    }
}
// Reset per-test state: answer, edge counter, adjacency heads and removal flags
// for vertices 1..n.
void init()
{
    ans=0;
    cnt=0;
    for (int v=1;v<=n;++v)
    {
        last[v]=0;
        vis[v]=0;
    }
}
// Read T test cases; for each, build the tree, run centroid decomposition
// from the centroid of the whole tree, and print the count modulo p.
int main()
{
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%d",&n,&m);
        init();
        for (int i=1;i<=n;i++) scanf("%d",&a[i]);
        for (int i=1;i<n;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            add(x,y); add(y,x);
        }
        // root must be reset: after the previous test it holds that test's last
        // centroid, whose stale f[] (typically 0) no vertex can beat. That can
        // leave root pointing outside the current tree.
        sum=n; f[0]=inf; root=0;
        getroot(1,0);
        work(root);
        printf("%d\n",ans);
    }
    return 0;
}
View Code

 


 

posted @ 2019-08-18 09:47  特特w  阅读(196)  评论(0编辑  收藏  举报