第二回全国統一プログラミング王決定戦予選 B - Counting of Trees

備忘録

問題

atcoder.jp

回答

import sys
import os
import math
import bisect
import itertools
import collections
import heapq
import queue
import array
import time
import datetime

# 時々使う
# import numpy as np
# from decimal import Decimal, ROUND_HALF_UP
# from scipy.sparse.csgraph import csgraph_from_dense, floyd_warshall

# 再帰の制限設定
sys.setrecursionlimit(10000000)


def ii(): return int(sys.stdin.buffer.readline().rstrip())
def il(): return list(map(int, sys.stdin.buffer.readline().split()))
def it(): return tuple(map(int, sys.stdin.buffer.readline().split()))
def fl(): return list(map(float, sys.stdin.buffer.readline().split()))
def iln(n): return [int(sys.stdin.buffer.readline().rstrip())
                    for _ in range(n)]


def iss(): return sys.stdin.buffer.readline().decode().rstrip()
def sl(): return list(map(str, sys.stdin.buffer.readline().decode().split()))
def isn(n): return [sys.stdin.buffer.readline().decode().rstrip()
                    for _ in range(n)]


def lcm(x, y): return (x * y) // math.gcd(x, y)


# MOD = 10 ** 9 + 7
MOD = 998244353
INF = float('inf')


def main():
    if os.getenv("LOCAL"):
        sys.stdin = open("input.txt", "r")

    N = ii()
    D = il()
    cnt = collections.Counter(D)

    if D[0] != 0 or cnt[0] > 1:
        print(0)
        exit()

    ret = 1
    for i in range(1, N-1):
        ret *= (cnt[i-1] ** cnt[i]) % MOD
        ret %= MOD
    print(ret)


if __name__ == '__main__':
    main()

考え方

まず、条件を満たすことができる木は、

  • D1がゼロ(頂点1が頂点1との距離がゼロであること)
  • ある頂点i (1 < i <= N)の距離がゼロではないこと(頂点1以外の頂点が根ではないこと)

が必須となります。
そのため条件を満たしていない場合には、必ず答えはゼロとなります。

あとは、頂点1を根として、子を連結させた場合に取りうる通り数を求めます。

頂点から距離1の位置に頂点が2つ、距離2の位置に頂点が3つ存在する場合、
距離2の頂点は距離1の頂点に存在する頂点のどちらかを親に必要があります。
そのため、距離2の各頂点はそれぞれ2通りの連結方法が存在するため、
距離2の頂点が取りうる通り数は2 ** 3通りです。

上記のように、距離iの頂点は距離i-1の頂点の数によって、
連結の通り数を求めることが可能です。