Javaで関数型っぽくクイックソートを実装する


はじめに

Javaでlambda式やStreamを使ってクイックソートを実装してみました.
この記事の対象読者は次のような人です.

  • Scalaを習得している人
  • Javaの関数型っぽい機能に入門したい人

lambda式

lambda式は(引数)->戻り値の形式で書くことができ,applyメソッドで関数を呼び出すことができます.

import java.util.function.Function;

public class Main {
    public static void main(String[] args) {
        Function<String, String> addHoge = (final String str) -> str + "hoge";
        System.out.println(addHoge.apply("fuga"));
    }
}

また,式を複数書きたい場合はブロックを利用することが可能です.

import java.util.function.Function;

public class Main {
    public static void main(String[] args) {
        Function<String, String> addHoge = (final String str) -> {
            String hoge = "hoge";
            return str + hoge;
        };
        System.out.println(addHoge.apply("fuga"));
    }
}

Stream

Streamとはコレクションの要素に対して関数型の操作をサポートするクラスです.コレクションクラスのstream()メソッドを呼び出すことにより,Streamを利用することができます.

stream()を使って2の倍数を表示するプログラムを書くと,次のようになります.

Java
Arrays.asList(3, 21, 34, 0).stream().filter(n -> n % 2 == 0).forEach(System.out::println);

Scalaだと次のようになります.Javaと比較するとstream()を書く必要がない分,コードがシンプルになります.

Scala
List(3, 21,34, 0).filter(_ % 2 == 0).foreach(println)

ソースコード(Scala)

Javaでクイックソートを実装する前に,Scalaで実装したコードを示します.
クイックソートのpivotは配列の先頭に指定しています.

object Main extends App {
    println(quickSort(List(3, 21, 34, 0, -30, 55, 10)))

    def quickSort(nums: List[Int]): List[Int] = {
      nums.headOption.fold(List[Int]()){ pivot =>
        val left = nums.filter(_ < pivot)
        val right = nums.filter(pivot < _)
        quickSort(left) ++ List(pivot) ++ quickSort(right)
      }
    }
}

ソースコード(Java)

次にJavaでクイックソートを実装したものを示します.

Scalaのコードと見比べると,Javaのコードの多さに圧倒されますが,ロジック自体はほぼ同じような実装になっていると思います.

ロジックに関して,唯一異なる点をあげるとしたら,foldメソッドです.Scalaのコードにはfoldメソッドを使っている箇所が存在しますが,JavaのOptional型(ScalaでいうOption型)にはfoldメソッドは存在しません.その代わりにmapメソッドとOrElseGetメソッド(ScalaでいうgetOrElseメソッド)を組み合わせて使用しています.

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import static java.util.stream.Collectors.toList;
import java.util.stream.Stream;

public class Main {
    public static void main(String[] args) {
        final List<Integer> nums = Arrays.asList(3, 21, 34, 0, -30, 55, 10);
        final List<Integer> sorted = quickSort(nums);
        System.out.println(sorted.toString());
    }

    private static List<Integer> quickSort(final List<Integer> nums) {
        return nums.stream().findFirst().map((final Integer pivot) -> {
            final List<Integer> left = nums.stream().filter(n -> n < pivot).collect(toList());
            final List<Integer> middle = Collections.singletonList(pivot);
            final List<Integer> right = nums.stream().filter(n -> pivot < n).collect(toList());
            return Stream.of(quickSort(left), middle, quickSort(right)).flatMap(Collection::stream).collect(toList());
        }).orElseGet(Collections::emptyList);
    }
}