Rustの型推論のおかげで逆にデバッグに苦労した。


RustでAtCoderに挑み始めて2週間弱の初心者です。
Rustは型に厳密なおかげでバグにも気づきやすくて便利だなーと思っていたら、なかなか解けないバグに遭遇したので紹介します。

問題

AtCoderの過去問を解いていたら、こんなのに出会いました。1

ABC159 D - Banned K

解き方は会っているはずなのに、なぜか提出すると幾つかのケースでWAになってしまう。TLEじゃなくてWA、ということでかなり悩みました。ついでにJavaScriptで書いてみたら簡単にパスするし、ロジックは正しい筈…。

かなり悩んだのですが、こういう時は他人のコードを参考にするに限ります。少しずつACのコードとの差分を縮めていったところ、ついに通るコードとダメなコードの差分を発見しました。2
さて、下記のコードの何が問題だったでしょう?

ダメなコード

use std::collections::HashMap;
fn main() {
    proconio::input! {
        n:usize, A:[usize;n]
    }
    let mut B = HashMap::new();
    for i in 0..n {
        *B.entry(A[i]).or_insert(0) += 1;
    }
    let mut sum = 0;
    for (_, &v) in B.iter() {
        sum += v * (v - 1) / 2; // <----注目
    }
    for i in 0..n {
        let val = B.get(&A[i]).unwrap();
        let ans = sum + 1 - val;
        println!("{}", ans);
    }
}

OKなコード

use std::collections::HashMap;
fn main() {
    proconio::input! {
        n:usize, A:[usize;n]
    }
    let mut B = HashMap::new();
    for i in 0..n {
        *B.entry(A[i]).or_insert(0) += 1;
    }
    let mut sum = 0;
    for (_, &v) in B.iter() {
        sum += combination(v); // <----注目
    }
    for i in 0..n {
        let val = B.get(&A[i]).unwrap();
        let ans = sum + 1 - val;
        println!("{}", ans);
    }
}

fn combination(v: usize) -> usize {
    v * (v - 1) / 2
}

見てわかるように、2つ目のfor文の中で、${}_n\mathrm{C}_2$を計算する際、インラインで処理しているか関数を呼び出しているかという、ただそれだけの違いです。
ここから小一時間悩んで、漸く原因に思い当たりました。

原因判明

一言でいうと、

  • vとsumがi32で扱われるためオーバーフローが発生する

というのが原因です。3
一方、OKな例では、combination関数の引数と戻り値が usize であるため、コンパイル時に行われる型推論によって、vとsumも usizeになります。これは実質的にu64です。

実際、それぞれの最大値を考えてみると下記のようになります。

  • i32の場合: $2^{32-1}-1 \risingdotseq 2 \times 10^9$
  • u64の場合: $2^{64}-1 \risingdotseq 16 \times 10^{18}$

一方この問題の制約では、$N$は$2\times 10^5$(20万)以下となっています。したがって$A_i$がすべて同じ値の時に$v=N$となり、$v^2$は最大で$N^2=4 \times 10^{10}$になるので、確かにi32だとオーバーフローしてしまいます。

というわけで、型推論の手がかりさえ与えてあげれば、関数呼び出しを行う必要はありません。
下記のコードでもACになります。この場合、sumの型がi32ではなくusizeと推論され、そこからvもusizeと推論されるため、オーバーフローは発生しません。

use std::collections::HashMap;
fn main() {
    proconio::input! {
        n:usize,  A:[usize;n]
    }
    let mut B = HashMap::new();
    for i in 0..n {
        *B.entry(A[i]).or_insert(0) += 1;
    }
    let mut sum = 0usize;   // <--型を指定
    for (_, &v) in B.iter() {
        sum += v * (v - 1) / 2;
    }
    for i in 0..n {
        let val = B.get(&A[i]).unwrap();
        let ans = sum + 1 - val;
        println!("{}", ans);
    }
}

まとめ

Rustの型推論は強力ですが、手がかりの無い数値がi32になることを忘れると痛い目に遭う、というお話でした。4


  1. 実はこの問題に最初に遭遇したのはこちらじゃなくて、A - Zero-Sum Rangesでした。こちらではRustでの回答者自体が他にいなかったため、ACコードを参考にすることもできず、とりあえずJavaScriptでACすることだけ確認できたので、原因不明のままいったん諦めました……。 

  2. 見やすくするために実際に提出したコードよりちょっと整理してます。 

  3. 最初は容量を宣言していないHashMapあたりが原因かと思っていたのですが、問題はそこじゃありませんでした。HashMapの定義時に型を宣言していないために気付きにくかったという点では、無関係とも言えませんが…。 

  4. まあ競プロじゃなければ入力値もわかるため、もう少し早めに気付いた気もしますが……。