ABC182 C - To 3

備忘録

問題

atcoder.jp

回答

bit全探索を用いた回答(実行時間: 871ms)

import sys
import os
import math
import bisect
import itertools
import collections
import heapq
import queue
import array
import time


# 時々使う
# import numpy as np
# from decimal import Decimal, ROUND_HALF_UP
# from scipy.sparse.csgraph import csgraph_from_dense, floyd_warshall
# from collections import defaultdict, deque

# 再帰の制限設定
sys.setrecursionlimit(10000000)


def ii(): return int(sys.stdin.buffer.readline().rstrip())
def il(): return list(map(int, sys.stdin.buffer.readline().split()))
def fl(): return list(map(float, sys.stdin.buffer.readline().split()))
def iln(n): return [int(sys.stdin.buffer.readline().rstrip())
                    for _ in range(n)]


def iss(): return sys.stdin.buffer.readline().decode().rstrip()
def sl(): return list(map(str, sys.stdin.buffer.readline().decode().split()))
def isn(n): return [sys.stdin.buffer.readline().decode().rstrip()
                    for _ in range(n)]


def lcm(x, y): return (x * y) // math.gcd(x, y)


# MOD = 10 ** 9 + 7
MOD = 998244353
INF = float('inf')


def main():
    if os.getenv("LOCAL"):
        sys.stdin = open("input.txt", "r")

    S = iss()
    N = len(S)

    ret = INF
    for i in range(1 << N):
        cnt, sm = 0, 0
        for j in range(N):
            if (i >> j) & 1:
                sm += int(S[j])
            else:
                cnt += 1
        else:
            if sm % 3 == 0 and cnt < N and sm > 0:
                ret = min(ret, cnt)

    print(-1 if ret == INF else ret)


if __name__ == '__main__':
    main()

ケース分けを行う回答(実行時間: 34ms)

import sys
import os
import math
import bisect
import itertools
import collections
import heapq
import queue
import array
import time


# 時々使う
# import numpy as np
# from decimal import Decimal, ROUND_HALF_UP
# from scipy.sparse.csgraph import csgraph_from_dense, floyd_warshall
# from collections import defaultdict, deque

# 再帰の制限設定
sys.setrecursionlimit(10000000)


def ii(): return int(sys.stdin.buffer.readline().rstrip())
def il(): return list(map(int, sys.stdin.buffer.readline().split()))
def fl(): return list(map(float, sys.stdin.buffer.readline().split()))
def iln(n): return [int(sys.stdin.buffer.readline().rstrip())
                    for _ in range(n)]


def iss(): return sys.stdin.buffer.readline().decode().rstrip()
def sl(): return list(map(str, sys.stdin.buffer.readline().decode().split()))
def isn(n): return [sys.stdin.buffer.readline().decode().rstrip()
                    for _ in range(n)]


def lcm(x, y): return (x * y) // math.gcd(x, y)


# MOD = 10 ** 9 + 7
MOD = 998244353
INF = float('inf')


def digit_sum(n):
    ans = 0
    while n > 0:
        ans += n % 10
        n //= 10
    return ans


def main():
    if os.getenv("LOCAL"):
        sys.stdin = open("input.txt", "r")

    N = ii()
    L = len(str(N))
    DSUM = digit_sum(N)

    if DSUM % 3 == 0:
        print(0)
    elif DSUM % 3 == 1:
        while N > 0:
            m = (N % 10) % 3
            N //= 10
            if m == 1 and L > 1:
                print(1)
                break
        else:
            if L > 2:
                print(2)
            else:
                print(-1)
    elif DSUM % 3 == 2:
        while N > 0:
            m = (N % 10) % 3
            N //= 10
            if m == 2 and L > 1:
                print(1)
                break
        else:
            if L > 2:
                print(2)
            else:
                print(-1)


if __name__ == '__main__':
    main()

考え方

3の倍数か否かの判定は、各桁の和で判断することができます。
各桁の和が3の倍数の場合には、元の値は3の倍数です。

3の倍数でない場合には3で割ったときの余りが1または2となります。
3の倍数ではない元の値を3の倍数にするためには、
各桁で余りが発生する値を取り除くことによって3の倍数にすることが出来ます。

回答方法を2種類記載しています。
1つ目はbit全探索を用いて、各桁を和に使用する場合と使用しない場合で判定する方法です。
こちらは元の値が最大18桁なので、
最大で2 ** 18回判定が必要になり、実行時間は800ms程度です。

2つ目は公式の解説通り、愚直にケース分けを行う回答です。
処理は若干分かり辛いですが、実行時間は30ms程度です。