ABC172 C - Tsundoku

備忘録

問題

atcoder.jp

解答

import sys, os, math, bisect, itertools, collections, heapq, queue
# from scipy.sparse.csgraph import csgraph_from_dense, floyd_warshall
from decimal import Decimal
from collections import defaultdict, deque

# import fractions

sys.setrecursionlimit(10000000)


def ii():
    """Read one line from stdin and return it as a single int."""
    return int(sys.stdin.buffer.readline().rstrip())


def il():
    """Read one line from stdin and return its whitespace-separated ints as a list."""
    return list(map(int, sys.stdin.buffer.readline().split()))


def fl():
    """Read one line from stdin and return its whitespace-separated floats as a list."""
    return list(map(float, sys.stdin.buffer.readline().split()))


def iln(n):
    """Read ``n`` lines from stdin (one int per line) and return them as a list."""
    return [int(sys.stdin.buffer.readline().rstrip()) for _ in range(n)]


def iss():
    """Read one line from stdin and return it as a str without trailing whitespace."""
    return sys.stdin.buffer.readline().decode().rstrip()


def sl():
    """Read one line from stdin and return its whitespace-separated tokens as strs."""
    # str.split() already returns a fresh list of str, so no map(str, ...) is needed.
    return sys.stdin.buffer.readline().decode().split()


def isn(n):
    """Read ``n`` lines from stdin and return them as strs without trailing whitespace."""
    return [sys.stdin.buffer.readline().decode().rstrip() for _ in range(n)]


def lcm(x, y):
    """Return the least common multiple of ints ``x`` and ``y``.

    ``lcm(0, 0)`` is defined as 0 instead of dividing by ``gcd(0, 0) == 0``.
    """
    if x == 0 or y == 0:
        return 0
    return (x * y) // math.gcd(x, y)
# Note: on interpreters without math.gcd, fractions.gcd can be used instead.


MOD = 10 ** 9 + 7
MAX = float('inf')  # float +infinity sentinel


def _solve_with_bisect():
    """Read one ABC172 C case from stdin and print the answer (binary-search version).

    Input: ``N M K`` on the first line, then the N reading times of desk A,
    then the M reading times of desk B. Prints the maximum total number of
    books that can be read within ``K`` minutes, taking books from the top of
    each desk.
    """
    N, _, K = il()  # M is not needed: it is implied by the length of the B line
    # Prefix sums with a leading 0: A[i] (B[i]) is the time needed to read
    # the top i books of desk A (desk B).
    A = list(itertools.accumulate([0] + il()))
    B = list(itertools.accumulate([0] + il()))

    ret = 0
    for n in range(N + 1):
        # Assumes reading times are non-negative (problem constraint), so A is
        # non-decreasing: once A[n] exceeds K, no larger n can fit either.
        if A[n] > K:
            break
        # bisect_right counts the prefix sums of B that fit in the remaining
        # time (B must be non-decreasing, same assumption). That count
        # includes B[0] == 0, i.e. "zero books", hence the "- 1".
        b = bisect.bisect_right(B, K - A[n])
        ret = max(ret, n + b - 1)
    print(ret)


def main():
    """Solve one case from stdin, or from ``input.txt`` when ``LOCAL`` is set.

    The local input file is closed after use and ``sys.stdin`` is restored.
    """
    if not os.getenv("LOCAL"):
        _solve_with_bisect()
        return

    saved_stdin = sys.stdin
    with open("input.txt", "r") as local_input:
        # The reader helpers (il etc.) look up sys.stdin on every call, so
        # rebinding it redirects them to the local file.
        sys.stdin = local_input
        try:
            _solve_with_bisect()
        finally:
            sys.stdin = saved_stdin


if __name__ == '__main__':
    main()

考え方

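予め、机Aに積まれたN冊の本と机Bに積まれたM冊の本それぞれについて、
上からi冊目まで読むのにかかる時間の累積和A[i], B[i]を求めておく(A[0] = B[0] = 0)。
次に、机Aの本をn(0 <= n <= N)冊読んだ場合、残りの時間(K - A[n])で
机Bに積まれた本を何冊読むことが出来るか、累積和Bに対する二分探索で求める。
A[n]がKを超えた時点でそれ以上nを増やしても読めないので打ち切り、
各nについての「n + 机Bで読める冊数」の最大値が答えとなる。
計算量はO(N log M)。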

itertoolsを使用した累積和の求め方
すごいぞitertoolsくん - Qiita

bisectを使用した二分探索
Python で二分探索 bisect | 民主主義に乾杯

ちなみに、机Bで読む冊数mの右端をMから始め、
nが増えるにつれて残り時間K - A[n]は減る一方なのでmを左に動かすだけでよい、
という尺取り法でも解答することが出来る(計算量はO(N + M))。

import sys, os, math, bisect, itertools, collections, heapq, queue
# from scipy.sparse.csgraph import csgraph_from_dense, floyd_warshall
from decimal import Decimal
from collections import defaultdict, deque

# import fractions

sys.setrecursionlimit(10000000)


def ii():
    """Read one line from stdin and return it as a single int."""
    return int(sys.stdin.buffer.readline().rstrip())


def il():
    """Read one line from stdin and return its whitespace-separated ints as a list."""
    return list(map(int, sys.stdin.buffer.readline().split()))


def fl():
    """Read one line from stdin and return its whitespace-separated floats as a list."""
    return list(map(float, sys.stdin.buffer.readline().split()))


def iln(n):
    """Read ``n`` lines from stdin (one int per line) and return them as a list."""
    return [int(sys.stdin.buffer.readline().rstrip()) for _ in range(n)]


def iss():
    """Read one line from stdin and return it as a str without trailing whitespace."""
    return sys.stdin.buffer.readline().decode().rstrip()


def sl():
    """Read one line from stdin and return its whitespace-separated tokens as strs."""
    # str.split() already returns a fresh list of str, so no map(str, ...) is needed.
    return sys.stdin.buffer.readline().decode().split()


def isn(n):
    """Read ``n`` lines from stdin and return them as strs without trailing whitespace."""
    return [sys.stdin.buffer.readline().decode().rstrip() for _ in range(n)]


def lcm(x, y):
    """Return the least common multiple of ints ``x`` and ``y``.

    ``lcm(0, 0)`` is defined as 0 instead of dividing by ``gcd(0, 0) == 0``.
    """
    if x == 0 or y == 0:
        return 0
    return (x * y) // math.gcd(x, y)
# Note: on interpreters without math.gcd, fractions.gcd can be used instead.


MOD = 10 ** 9 + 7
MAX = float('inf')  # float +infinity sentinel


def _solve_with_two_pointers():
    """Read one ABC172 C case from stdin and print the answer (two-pointer version).

    Input: ``N M K`` on the first line, then the N reading times of desk A,
    then the M reading times of desk B. Prints the maximum total number of
    books that can be read within ``K`` minutes, taking books from the top of
    each desk.
    """
    N, M, K = il()
    # Prefix sums with a leading 0: A[i] (B[i]) is the time needed to read
    # the top i books of desk A (desk B).
    A = list(itertools.accumulate([0] + il()))
    B = list(itertools.accumulate([0] + il()))

    ret, m = 0, M
    for n in range(N + 1):
        # Assumes reading times are non-negative (problem constraint), so A is
        # non-decreasing: once A[n] exceeds K, no larger n can fit either.
        if A[n] > K:
            break
        # The remaining time K - A[n] never grows as n increases, so the best
        # m for this n is never larger than for the previous n: m only moves
        # left. It cannot go below 0 because B[0] == 0 <= K - A[n] here.
        while B[m] > K - A[n]:
            m -= 1
        ret = max(ret, n + m)
    print(ret)


def main():
    """Solve one case from stdin, or from ``input.txt`` when ``LOCAL`` is set.

    The local input file is closed after use and ``sys.stdin`` is restored.
    """
    if not os.getenv("LOCAL"):
        _solve_with_two_pointers()
        return

    saved_stdin = sys.stdin
    with open("input.txt", "r") as local_input:
        # The reader helpers (il etc.) look up sys.stdin on every call, so
        # rebinding it redirects them to the local file.
        sys.stdin = local_input
        try:
            _solve_with_two_pointers()
        finally:
            sys.stdin = saved_stdin


if __name__ == '__main__':
    main()