ABC098 D - Xor Sum 2

備忘録

問題

atcoder.jp

回答

import sys
import os
import math
import bisect
import collections
import itertools
import heapq
import re
import queue
from decimal import Decimal

# import fractions

sys.setrecursionlimit(10000000)

ii = lambda: int(sys.stdin.buffer.readline().rstrip())
il = lambda: list(map(int, sys.stdin.buffer.readline().split()))
fl = lambda: list(map(float, sys.stdin.buffer.readline().split()))
iln = lambda n: [int(sys.stdin.buffer.readline().rstrip()) for _ in range(n)]

iss = lambda: sys.stdin.buffer.readline().decode().rstrip()
sl = lambda: list(map(str, sys.stdin.buffer.readline().decode().split()))
isn = lambda n: [sys.stdin.buffer.readline().decode().rstrip() for _ in range(n)]

lcm = lambda x, y: (x * y) // math.gcd(x, y)
# lcm = lambda x, y: (x * y) // fractions.gcd(x, y)

MOD = 10 ** 9 + 7
MAX = float('inf')


def main():
    if os.getenv("LOCAL"):
        sys.stdin = open("input.txt", "r")

    N = ii()
    A = il()

    ret, r, sm = 0, 0, 0
    for l in range(N):
        while r < N and A[r] ^ sm == sm + A[r]:
            sm += A[r]
            r += 1

        ret += r - l
        if l == r:
            r += 1
        else:
            sm -= A[l]
    print(ret)


if __name__ == '__main__':
    main()

考え方

ちょっと条件が面倒な尺取り法。
左端から右端までの合計値をsmとしたとき、
sm + 次の値 = sm xor 次の値であることを確認する。

左端(l)と右端(r)を0からスタートし、
左右の範囲の合計とA[r+1]の合算と、
左右の範囲の合計とA[r+1]xorが一致した場合のみ右端(r)をインクリメントすることができる。

右端(r)がインクリメントされることによって、A[l:r]の合計が増加していくが、
条件式により、単純な和(A[l:r] + A[r+1])とxorの和(A[l:r] xor A[r+1])は等しい範囲のみ選択されている。
そのため、右端(r)のインクリメントされる条件はA[r] ^ sm == sm + A[r]で満たすことができる。

# 例
# 左右の範囲内で
# 単純な和とxor和が一致していることが示されている場合
2 + 5 = 7
2 ^ 5 = 7
# 単純な和とのxorを行った結果と、
# 全てのxor和の結果は必ず一致する
7 + 8 = 15
2 ^ 5 ^ 8 = 15

右端(r)をインクリメントする条件さえ明確にすることが出来れば、
あとは普通に尺取り法に乗っ取った実装をするだけ。