ATC001 B - Union Find

備忘録

問題

atcoder.jp

回答

import sys

sys.setrecursionlimit(10000000)
import os
import math
import bisect
import collections
import itertools
import heapq
import re
import queue

# import fractions

ii = lambda: int(sys.stdin.buffer.readline().rstrip())
il = lambda: list(map(int, sys.stdin.buffer.readline().split()))
fl = lambda: list(map(float, sys.stdin.buffer.readline().split()))
iln = lambda n: [int(sys.stdin.buffer.readline().rstrip()) for _ in range(n)]

iss = lambda: sys.stdin.buffer.readline().decode().rstrip()
sl = lambda: list(map(str, sys.stdin.buffer.readline().decode().split()))
isn = lambda n: [sys.stdin.buffer.readline().decode().rstrip() for _ in range(n)]

lcm = lambda x, y: (x * y) // math.gcd(x, y)
# lcm = lambda x, y: (x * y) // fractions.gcd(x, y)

MOD = 10 ** 9 + 7
MAX = float('inf')


def ufinit(n):
    global uf
    uf = list(range(n))


def root(n):
    if uf[n] == n:
        return n
    else:
        uf[n] = root(uf[n])
        return uf[n]


def same(x, y):
    return root(x) == root(y)


def unite(x, y):
    x = root(x)
    y = root(y)
    if x == y:
        return
    else:
        uf[x] = y


def main():
    if os.getenv("LOCAL"):
        sys.stdin = open("input.txt", "r")

    N, Q = il()
    ufinit(N)

    for _ in range(Q):
        q, a, b = il()

        if q == 1:
            if same(a, b):
                print('Yes')
            else:
                print('No')
        else:
            unite(a, b)


if __name__ == '__main__':
    main()

考え方

UnionFind入門。
考え方と実装は下記を参考。

素集合データ構造 - Wikipedia Union-Find木の解説と例題 - Qiita

ノードの個数がNの場合、
要素をN個持った配列を用意し、
indexをノードの番号、valueをルートノードのindexとして扱う。

初期状態ではindex = valueとしておき、
全てのノードがルートかつ孤立している状態とする。

# ノードがN個の場合
# UnionFindの初期化
uf = list(range(N)) #[0, 1, 2, 3, 4, 5, 6, 7]

Union(unite関数)の実装は互いのノードのルートを探索し、
uf[xのルートノード] = yのルートノードとすることで結合できる。

Find(same関数)の実装は、互いのルートノードが同じか否かを判定する。
ルートノードの探索は、index = valueとなっている場合はルート。
それ以外の場合には、再帰的にvalue(自身がつながっているノードのindex)から探索を行なう。