Codeforces Round #510 (Div. 2) D. Petya and Array 座圧とセグメントツリー


思ったより実装に時間かかった。

題意

  • $n \leq 10^5$個の各要素 $- 10^9 \leq a_i \leq 10^9 $からなる配列$a$が与えられます。
  • この連続している部分配列の和が、$- 10^{12} \leq t \leq 10^{12} $未満のものは全部でいくつあるか?(配列は重複あり、ソートなし)

こう考えた

問題は言い換えると、区間$[l, r)$の部分和が$t$未満の部分和となる、$l,r$の数を求めたい。と言えます。

累積和を取りセグメントツリーに入れます。つまり、$a_0$からある$a_i$までの和をindexとして、出現回数を値にするのセグメントツリー$st$を持ちます。ただし、値は負を取りうることと、値の範囲がとても大きいため、座圧を行います。例えば、
例題を例にとります。$[5,-1,3,4,-1]$であれば、累積和は$[5,4,7,11,10]$となるので、座圧のためのテーブル$[4,5,7,10,11]$を持ち、$1,0,2,4,3$に変換します。
この時、stは$[1,1,1,1,1]$というテーブルになります。

ここで、$l=0$と固定したとします。この時、$t=4$であると、累積和のテーブルの中でt未満となる数を列挙すればよいので、上記で計算した(累積和のカウントを座圧した数値をセグメントツリーにしている)stにクエリすればよいです。クエリの対象は、$t=4$未満ですが、座圧後の値にする必要があります。$t=0$とは、座圧するとindexが$0$です。$st.query[0,0) = 0$が答えになります。

つぎに、$l=1$と固定したとしましょう。まず、この処理に入る前に、$i=0$に対応する累積和の要素を消して、セグメントツリーの該当要素を$-1$します。
この時、再度、累積和を計算したくなりますが、これを繰り返すと$O(N^2)$の時間がかかります。このため、逆に、$t$を累積和に合わせて変更します。さて、最初に計算した累積和$[5,4,7,11,10]$ですが、$l=1$から累積和を行うと$[-1, 2, 6, 5]$です。さて、これを見ると、最初の累積和の2要素目以降から1要素目の$5$を引いたものがわかります。累積和の性質を考えると1要素目が抜けたため、明らかです。
ということは、$l=1$のとき最初に求めた各要素からa_0の要素の値(この場合は5)を引いて、t=4未満の要素であればいいので、言い換えれば、最初に求めた各要素からa_0の要素の値がt=9未満であればよいです。
ここで、座圧テーブル$[4,5,7,10,11]$で$9$未満のindexを考えると$3$です。この際、lower bound(Pythonならbisect left)で考えます。このまま、[l, index)のクエリを行うと、まるで、$l=0$の時のクエリと同じように、$l=1$の時の組み合わせが求められます。

図に示すと次のようになります。

実装

def do():
    st = segmentTreeSum()
    n, t = map(int, input().split())
    dat = list(map(int, input().split()))
    dattotal = []
    total = 0
    segtreeList = [0] * (200000 + 10)
    zatsu = set()
    for x in dat:
        total += x
        dattotal.append(total)
        zatsu.add(total)
    zatsu = list(zatsu)
    zatsu.sort()
    zatsuTable = dict()
    zatsuTableRev = dict()
    for ind, val in enumerate(zatsu):
        zatsuTable[val] = ind
        zatsuTableRev[ind] = val
    buf = []
    for x in dattotal:
        buf.append(zatsuTable[x])
        segtreeList[zatsuTable[x]] += 1
    st.load(segtreeList)

    from bisect import bisect_left, bisect_right

    offset = 0
    res = 0
    for i in range(n): # x = total from 0 to curren
        x = dattotal[i]
        curvalind = zatsuTable[x]
        targetval = t + offset# target val
        targetind = bisect_left(zatsu, targetval)
        cnt = st.query(0, targetind )
        res += cnt
        st.addValue(curvalind, -1)
        offset += dat[i] # for next offset
    print(res)
do()