CF696B Puzzles 题解

题目链接:CF 或者 洛谷

注意一些计算即可

  1. \(期望=概率\times 出现次数\)

  2. \(dfs\) 一个点,会 \(dfs\) 完以这个点为根的子树,再 \(dfs\) 下一个兄弟节点为根的子树。

  3. 对于两个兄弟节点而言,它们的相对 \(dfs\) 序仅有两种,谁前谁后。

对于本题而言,优先关注时间戳的计算:\(时间戳=dfs中出现在自己之前的点数\)。现在直接去算过于困难,我们使用贡献法去思考,哪些信息会影响到它的时间戳。

如图所示,红色为当前计算点,显然的是,\(fa\) 的期望值假如已经知道了,那么待算点的期望值是由 \(fa\) 的期望值 \(+其余兄弟节点的贡献+1\) 计算得到。期望在本题当中也可以理解为平均出现次数。

考虑树形 \(dp[i]\),表示 \(i\) 节点的期望,那么与 \(dp[i]=dp[fa]+兄弟子树贡献+1\),最后加 \(1\) 是因为当前子树的平均期望时间戳,等于在它之前 \(dfs\) 中出现的点期望个数 \(+1\),即为当前时间戳。考虑兄弟节点贡献,基于第二点,如果我们选择 \(dfs\) 了一个兄弟节点,它将会产生 \(size[son]\) 的贡献,因为会 \(dfs\) 完这棵子树才能选择 \(dfs\) 下一个点,这就是它的出现次数。基于第三点,我们知道,只有 \(dfs\) 序在当前节点之前的才会有贡献,在它之后 \(dfs\) 的并无贡献,在当前节点之前 \(dfs\) 的概率显然为 \(\frac{1}{2}\),所以每个兄弟节点的贡献为 \(\frac{1}{2} \times size[son]\),它们的总贡献即为:\(\frac{1}{2}\sum size[son]\),而 \(\sum size[son]=size[fa]-size[curr]-1\),即以父节点为根的子树大小去掉当前需要计算的子树大小,再去掉父节点这个根的单点,即为所求。

即,紫色减去红色即为绿色部分。所以我们预处理出子树大小,直接做 \(dp\) 即可:

\[dp[curr]=dp[fa]+\frac{size[fa]-size[curr]-1}{2}+1 \]

注意到本题特殊限制:

\(fa[i]<i\),这也意味着,回忆我们 \(dfs\) 处理子树大小,显然是从叶子节点往上统计,所以我们也可以直接倒序遍历,不需要 \(dfs\) 也是满足无后效性,以叶子节点往上统计。同理,\(dp\) 更新,正序即可。

C++参照代码
#include <bits/stdc++.h>

// #pragma GCC optimize(2)
// #pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")

// #define isPbdsFile

#ifdef isPbdsFile

#include <bits/extc++.h>

#else

#include <ext/pb_ds/priority_queue.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/trie_policy.hpp>
#include <ext/pb_ds/tag_and_trait.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/list_update_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/exception.hpp>
#include <ext/rope>

#endif

using namespace std;
using namespace __gnu_cxx;
using namespace __gnu_pbds;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef tuple<int, int, int> tii;
typedef tuple<ll, ll, ll> tll;
typedef unsigned int ui;
typedef unsigned long long ull;
typedef __int128 i128;
#define hash1 unordered_map
#define hash2 gp_hash_table
#define hash3 cc_hash_table
#define stdHeap std::priority_queue
#define pbdsHeap __gnu_pbds::priority_queue
#define sortArr(a, n) sort(a+1,a+n+1)
#define all(v) v.begin(),v.end()
#define yes cout<<"YES"
#define no cout<<"NO"
#define Spider ios_base::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
#define MyFile freopen("..\\input.txt", "r", stdin),freopen("..\\output.txt", "w", stdout);
#define forn(i, a, b) for(int i = a; i <= b; i++)
#define forv(i, a, b) for(int i=a;i>=b;i--)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define endl '\n'
//用于Miller-Rabin
[[maybe_unused]] static int Prime_Number[13] = {0, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

template <typename T>
int disc(T* a, int n)
{
    return unique(a + 1, a + n + 1) - (a + 1);
}

template <typename T>
T lowBit(T x)
{
    return x & -x;
}

template <typename T>
T Rand(T l, T r)
{
    static mt19937 Rand(time(nullptr));
    uniform_int_distribution<T> dis(l, r);
    return dis(Rand);
}

template <typename T1, typename T2>
T1 modt(T1 a, T2 b)
{
    return (a % b + b) % b;
}

template <typename T1, typename T2, typename T3>
T1 qPow(T1 a, T2 b, T3 c)
{
    a %= c;
    T1 ans = 1;
    for (; b; b >>= 1, (a *= a) %= c)if (b & 1)(ans *= a) %= c;
    return modt(ans, c);
}

template <typename T>
void read(T& x)
{
    x = 0;
    T sign = 1;
    char ch = getchar();
    while (!isdigit(ch))
    {
        if (ch == '-')sign = -1;
        ch = getchar();
    }
    while (isdigit(ch))
    {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    x *= sign;
}

template <typename T, typename... U>
void read(T& x, U&... y)
{
    read(x);
    read(y...);
}

template <typename T>
void write(T x)
{
    if (typeid(x) == typeid(char))return;
    if (x < 0)x = -x, putchar('-');
    if (x > 9)write(x / 10);
    putchar(x % 10 ^ 48);
}

template <typename C, typename T, typename... U>
void write(C c, T x, U... y)
{
    write(x), putchar(c);
    write(c, y...);
}


template <typename T11, typename T22, typename T33>
struct T3
{
    T11 one;
    T22 tow;
    T33 three;

    bool operator<(const T3 other) const
    {
        if (one == other.one)
        {
            if (tow == other.tow)return three < other.three;
            return tow < other.tow;
        }
        return one < other.one;
    }

    T3() { one = tow = three = 0; }

    T3(T11 one, T22 tow, T33 three) : one(one), tow(tow), three(three)
    {
    }
};

template <typename T1, typename T2>
void uMax(T1& x, T2 y)
{
    if (x < y)x = y;
}

template <typename T1, typename T2>
void uMin(T1& x, T2 y)
{
    if (x > y)x = y;
}

constexpr int N = 1e5 + 10;
int fa[N], siz[N];
float dp[N];
int n;

inline void solve()
{
    cout << fixed << setprecision(1);
    cin >> n, siz[0] = siz[1] = 1;
    forn(i, 2, n)cin >> fa[i], siz[i] = 1;
    forv(i, n, 1)siz[fa[i]] += siz[i];
    forn(i, 1, n)cout << (dp[i] = dp[fa[i]] + (siz[fa[i]] - siz[i] - 1) * 0.5 + 1) << ' ';
}

signed int main()
{
    // MyFile
    Spider
    //------------------------------------------------------
    // clock_t start = clock();
    int test = 1;
    //    read(test);
    // cin >> test;
    forn(i, 1, test)solve();
    //    while (cin >> n, n)solve();
    //    while (cin >> test)solve();
    // clock_t end = clock();
    // cerr << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl;
}

Go 参照代码
package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
	"runtime/debug"
)

func run(_r io.Reader, _w io.Writer) {
	in := bufio.NewScanner(_r)
	in.Split(bufio.ScanWords)
	out := bufio.NewWriter(_w)
	defer func(out *bufio.Writer) {
		_ = out.Flush()
	}(out)
	read := func() (x int) {
		in.Scan()
		tmp := in.Bytes()
		if tmp[0] == '-' {
			for _, b := range tmp[1:] {
				x = x*10 + int(b&15)
			}
			return -x
		} else {
			for _, b := range in.Bytes() {
				x = x*10 + int(b&15)
			}
		}
		return
	}
	//-------------------------------------------------------------
	var n = read()
	var fa = make([]int, n+1)
	var siz = make([]int, n+1)
	siz[1] = 1
	for i := 2; i <= n; i++ {
		fa[i] = read()
		siz[i] = 1
	}
	var dp = make([]float64, n+1)
	for i := n; i >= 1; i-- {
		siz[fa[i]] += siz[i]
	}
	dp[1] = 1
	for i := 2; i <= n; i++ {
		dp[i] = dp[fa[i]] + float64(siz[fa[i]]-siz[i]-1)/2 + 1
	}
	for i := 1; i <= n; i++ {
		_, _ = fmt.Fprintf(out, "%.1f ", dp[i])
	}
}

func main() {
	debug.SetGCPercent(-1)
	run(os.Stdin, os.Stdout)
}

C# 参照代码
using CompLib.Util;
using static System.Console;

namespace CompLib.Util
{
    using System;
    using System.Linq;

    internal class Scanner
    {
        private string[] _line;
        private int _index;
        private const char Separator = ' ';

        public Scanner()
        {
            _line = System.Array.Empty<string>();
            _index = 0;
        }

        private string Next()
        {
            if (_index < _line.Length) return _line[_index++];
            string s;
            do
            {
                s = Console.ReadLine() ?? string.Empty;
            } while (s.Length == 0);

            _line = s.Split(Separator);
            _index = 0;

            return _line[_index++];
        }

        public string ReadLine()
        {
            _index = _line.Length;
            return Console.ReadLine() ?? string.Empty;
        }

        public int NextInt() => int.Parse(Next());
        public long NextLong() => long.Parse(Next());
        public double NextDouble() => double.Parse(Next());
        public decimal NextDecimal() => decimal.Parse(Next());
        public char NextChar() => Next()[0];
        public char[] NextCharArray() => Next().ToCharArray();

        private IEnumerable<string> Array()
        {
            var s = Console.ReadLine() ?? string.Empty;
            _line = s.Length == 0 ? System.Array.Empty<string>() : s.Split(Separator);
            _index = _line.Length;
            return _line;
        }

        public int[] IntArray() => Array().AsParallel().Select(int.Parse).ToArray();
        public long[] LongArray() => Array().AsParallel().Select(long.Parse).ToArray();
        public double[] DoubleArray() => Array().AsParallel().Select(double.Parse).ToArray();
        public decimal[] DecimalArray() => Array().AsParallel().Select(decimal.Parse).ToArray();
    }
}

internal abstract class Program
{
    private static void Solve()
    {
        var sc = new Scanner();
        var t = 1;
        // t = sc.NextInt();
        for (var i = 0; i < t; i++) Query(sc);
        Out.Flush();
    }

    private const int Inf = 1000_000_007;

    private static void Query(Scanner sc)
    {
        var n = sc.NextInt();
        var fa = new int[n + 1];
        for (var i = 2; i <= n; i++) fa[i] = sc.NextInt();
        var siz = new int[n + 1];
        Array.Fill(siz, 1);
        var dp = new double[n + 1];
        dp[1] = 1;
        for (var i = n; i >= 1; i--) siz[fa[i]] += siz[i];
        for (var i = 2; i <= n; i++) dp[i] = dp[fa[i]] + (siz[fa[i]] - siz[i] - 1) * 0.5 + 1;
        for (var i = 1; i <= n; i++) Write($"{dp[i]:F1} ");
    }

    public static void Main() => Solve();
    // public static void Main() => new Thread(Solve, 1 << 27).Start();
}

\[时间复杂度为:\ O(n) \]

posted @ 2024-02-28 12:42  Athanasy  阅读(21)  评论(0编辑  收藏  举报