[競プロ]繰り返し2乗法【Java】【Python】


久々の競プロ記事です。

今回は繰り返し2乗法についてです。Javaで書いてみます。また、繰り返し2乗法を使用した問題をJava, Pythonで解いてみます。

繰り返し2乗法とは

べき乗の計算量を減らすテクニックです。以下のように、指数を2のべき乗表記をして累乗計算をします。
N乗の計算がO(N)からO(logN)になります。

計算方法

3^10を求めます。
10 = 2^3 + 2^1と表せるため、
3^10 = 3^(2^3 + 2^1) = 3^(2^3) * 3^(2^1)と表せます。

きれいに書くと以下のとおりです。

3^{10} = 3^{2^3} * 3^{2^1}

「それはわかるけどなんでそう考えると計算量が減るの?」という声が聞こえて来そうです。
もうちょっと丁寧に説明しますね。
2進数のbit演算を使用しています。
こんな感じです↓

計算の過程は以下のとおりです。
tmp・・・一時変数
ans・・・3^10の解が入る変数

# 今参照しているbit 説明
1 最下位bit bitが0のため、ansの更新はしない(ans=1)
tmpにtmpをかけて9とする(3*3=9)
2 下から2番目 bitが1のため、ans*=tmpとする(ans=1*9=9)
tmpにtmpをかけて81とする(9*9=81)
3 下から3番目 bitが0のため、ansの更新はしない(ans=9)
tmpにtmpをかけて6561とする(81*81=6561)
4 下から4番目 bitが1のため、ansの更新をする(ans=9*6561=59049)

こんな感じです。
tmpが重要ですね。最初は3だったのですが、上位bitを参照していくたびに2乗ずつして増えていきます。
2進数の桁が増えるのと同様、tmpも増えていく感じです。

実装

では、実装してみます。

    public static long myPow(long a, long n) {
        // a^nを計算
        long ans = 1l;
        long tmp = a;

        // わかりやすくfor文の中ですべて処理
        for (;;) {
            // すべての桁を見終わったら終了
            if (n < 1l) {
                break;
            }

            // 最下位bitが1かどうかの判定
            if (n % 2l == 1l) {
                ans *= tmp;
            }

            // tmpの更新
            tmp *= tmp;

            // nのbitを一つずらす
            n = n >> 1;
        }
        return ans;
    }

はい。こんな感じです。

myPow(3,10)を呼ぶとしっかり59049が返ってきます。

問題演習

2021/04/17(土)のAtCoder、第二回日本最強プログラマー学生選手権のD問題、Nowhere Pを解いてみます。

問題はリンクのとおりですが、答えは

(P-1)(P-2)^{n-1} mod 1000000007

を計算すれば良いです。
ただ、nの制約は最高で10^9で、単純にやるとTLEになる可能性が高いです。
なので繰り返し2乗法の登場です。

Javaの回答例

import java.util.Scanner;

public class Main {

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        long n = sc.nextLong();
        long p = sc.nextLong();
        long MOD = 1000000007l;

        long ans = (p - 1) * modPow(p - 2, n - 1, MOD) % MOD;

        System.out.println(ans);
    }

    public static long modPow(long a, long n, long mod) {
        // a^nを計算
        long ans = 1l;
        long tmp = a;

        // わかりやすくfor文の中ですべて処理
        for (;;) {
            // すべての桁を見終わったら終了
            if (n < 1l) {
                break;
            }

            // 最下位bitが1かどうかの判定
            if (n % 2l == 1l) {
                ans *= tmp;
                ans %= mod;
            }

            // tmpの更新
            tmp *= tmp;
            tmp %= mod;

            // nのbitを一つずらす
            n = n >> 1;
        }
        return ans;
    }

}

先程のmyPow関数をちょっと変えてmodPowにしました。計算の途中でmodとってるだけですね。
これで問題が解けました。

Python の回答例

ちなみに、Pythonのpow関数は繰り返し2乗法が実装されているので特に気にせず実装が出来ます。
また、 powの第三引数に値を入れるとmodPowの実装になります。

実装例(公式の解説とほぼ同じですが)

N, P = map(int, input().split())
MOD = 1000000007
ans = (P-1) * pow(P-2, N-1, MOD) % MOD
print(ans)

こんな感じです。

逆に競プロ以外で繰り返し2乗法や modPowが必要な場面を教えてほしいくらいですが・・・。

割と簡単なアルゴリズムだったので覚えておきたいですね。

繰り返し2乗法を紹介しました。今回の記事はここまでです。