How to swap values in a NumPy array in Python


Prerequisites

The get_idxs function returns two random indices into the array.
If you only want the conclusion, skip ahead to the Summary.

Method 1

The C-style approach with a temporary variable (probably the first thing anyone thinks of).

def swap1(arr):
    idx0, idx1 = get_idxs(arr)
    tmp = arr[idx0]
    arr[idx0] = arr[idx1]
    arr[idx1] = tmp

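Method 2

This one is hard to explain in a few words: both sides index the array with a list of indices, and the right-hand side lists them in the opposite order.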

def swap2(arr):
    idx0, idx1 = get_idxs(arr)
    arr[[idx0, idx1]] = arr[[idx1, idx0]]
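
To make Method 2 a little less mysterious, here is a minimal sketch that uses fixed indices instead of get_idxs (the array contents and the indices 1 and 3 are made up for illustration). Indexing with a list of integers is what NumPy calls advanced ("fancy") indexing, and it returns a copy rather than a view, so the right-hand side is fully read out before anything is written back.

```python
import numpy as np

arr = np.arange(5)          # illustrative data: [0 1 2 3 4]
i, j = 1, 3                 # fixed indices chosen for this example

picked = arr[[j, i]]        # fancy indexing returns a new array (a copy)
picked_is_copy = not np.shares_memory(picked, arr)
print(picked_is_copy)       # True

arr[[i, j]] = picked        # write the two values back in swapped order
print(arr)                  # [0 3 2 1 4]
```

The extra temporary array is one plausible reason this form is slower than the other two for a single pair of elements.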

Method 3

With tuple unpacking.

def swap3(arr):
    idx0, idx1 = get_idxs(arr)
    arr[idx0], arr[idx1] = arr[idx1], arr[idx0]
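
One caveat worth keeping in mind: Methods 1 and 3 rely on arr[idx] being a plain scalar, which holds for the 1-D array used in this article. The sketch below (made-up data, not part of the original benchmark) shows what happens when the same tuple assignment is applied to the rows of a 2-D array. There, arr[0] and arr[1] are views, so the first assignment overwrites the data the second one reads. Method 2's fancy indexing copies the rows first and swaps them correctly.

```python
import numpy as np

# 1-D case (the case benchmarked in this article): elements are scalars.
a = np.arange(4)
a[0], a[3] = a[3], a[0]
print(a.tolist())             # [3, 1, 2, 0]

# 2-D case: b[0] and b[1] are views into the same buffer.
b = np.arange(6).reshape(3, 2)
b[0], b[1] = b[1], b[0]
print(b.tolist())             # [[2, 3], [2, 3], [4, 5]]  (original row 0 is lost)

# Fancy indexing copies the rows first, so the swap works.
c = np.arange(6).reshape(3, 2)
c[[0, 1]] = c[[1, 0]]
print(c.tolist())             # [[2, 3], [0, 1], [4, 5]]
```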

Benchmark

While we're at it, let's measure them.

Here is the full output:

## benchmarker:         release 4.0.1 (for python)
## python version:      3.7.1
## python compiler:     GCC 7.3.0
## python platform:     Linux-4.15.0-45-generic-x86_64-with-debian-buster-sid
## python executable:   /opt/anaconda3/bin/python
## cpu model:           Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz  # 3504.115 MHz
## parameters:          loop=1000000, cycle=10, extra=1

## (#1)                                  real    (total    = user    + sys)
(Empty)                                0.0156    0.0100    0.0100    0.0000
swap1                                  2.3436    2.3500    2.3500    0.0000
swap2                                  3.7167    3.7200    3.7200    0.0000
swap3                                  2.3666    2.3800    2.3800    0.0000

## (#2)                                  real    (total    = user    + sys)
(Empty)                                0.0154    0.0100    0.0100    0.0000
swap1                                  2.3326    2.3400    2.3400    0.0000
swap2                                  3.6803    3.6800    3.6800    0.0000
swap3                                  2.3584    2.3700    2.3700    0.0000

## (#3)                                  real    (total    = user    + sys)
(Empty)                                0.0157    0.0100    0.0100    0.0000
swap1                                  2.3131    2.3200    2.3200    0.0000
swap2                                  3.6810    3.6900    3.6900    0.0000
swap3                                  2.3164    2.3200    2.3200    0.0000

## (#4)                                  real    (total    = user    + sys)
(Empty)                                0.0177    0.0200    0.0200    0.0000
swap1                                  2.3225    2.3200    2.3200    0.0000
swap2                                  3.6397    3.6300    3.6300    0.0000
swap3                                  2.3224    2.3200    2.3200    0.0000

## (#5)                                  real    (total    = user    + sys)
(Empty)                                0.0152    0.0200    0.0200    0.0000
swap1                                  2.3206    2.3100    2.3100    0.0000
swap2                                  3.6659    3.6600    3.6600    0.0000
swap3                                  2.3865    2.3900    2.3900    0.0000

## (#6)                                  real    (total    = user    + sys)
(Empty)                                0.0165    0.0100    0.0100    0.0000
swap1                                  2.3145    2.3200    2.3200    0.0000
swap2                                  3.6422    3.6500    3.6500    0.0000
swap3                                  2.3354    2.3400    2.3400    0.0000

## (#7)                                  real    (total    = user    + sys)
(Empty)                                0.0162    0.0200    0.0200    0.0000
swap1                                  2.2964    2.2900    2.2900    0.0000
swap2                                  3.6841    3.6800    3.6800    0.0000
swap3                                  2.3225    2.3200    2.3200    0.0000

## (#8)                                  real    (total    = user    + sys)
(Empty)                                0.0151    0.0100    0.0100    0.0000
swap1                                  2.3031    2.3100    2.3100    0.0000
swap2                                  3.6650    3.6700    3.6700    0.0000
swap3                                  2.3700    2.3800    2.3800    0.0000

## (#9)                                  real    (total    = user    + sys)
(Empty)                                0.0158    0.0100    0.0100    0.0000
swap1                                  2.3015    2.3100    2.3100    0.0000
swap2                                  3.6395    3.6400    3.6400    0.0000
swap3                                  2.3125    2.3200    2.3200    0.0000

## (#10)                                 real    (total    = user    + sys)
(Empty)                                0.0153    0.0200    0.0200    0.0000
swap1                                  2.3200    2.3100    2.3100    0.0000
swap2                                  3.6493    3.6500    3.6500    0.0000
swap3                                  2.3936    2.3800    2.3800    0.0000

## (#11)                                 real    (total    = user    + sys)
(Empty)                                0.0149    0.0200    0.0200    0.0000
swap1                                  2.3522    2.3500    2.3500    0.0000
swap2                                  3.7416    3.7300    3.7300    0.0000
swap3                                  2.3533    2.3500    2.3500    0.0000

## (#12)                                 real    (total    = user    + sys)
(Empty)                                0.0159    0.0200    0.0200    0.0000
swap1                                  2.3102    2.3000    2.3000    0.0000
swap2                                  3.6899    3.6900    3.6900    0.0000
swap3                                  2.3286    2.3200    2.3200    0.0000

## Ignore min & max                       min     cycle       max     cycle
swap1                                  2.2964      (#7)    2.3522     (#11)
swap2                                  3.6395      (#9)    3.7416     (#11)
swap3                                  2.3125      (#9)    2.3936     (#10)

## Average of 10 (=12-2*1)               real    (total    = user    + sys)
swap1                                  2.3182    2.3190    2.3190    0.0000
swap2                                  3.6714    3.6720    3.6720    0.0000
swap3                                  2.3460    2.3490    2.3490    0.0000

## Ranking                               real
swap1                                  2.3182  (100.0) ********************
swap3                                  2.3460  ( 98.8) ********************
swap2                                  3.6714  ( 63.1) *************

## Matrix                                real    [01]    [02]    [03]
[01] swap1                             2.3182   100.0   101.2   158.4
[02] swap3                             2.3460    98.8   100.0   156.5
[03] swap2                             3.6714    63.1    63.9   100.0

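Surprise, surprise!
Method 1 turned out to be the fastest.
Method 3 comes next, and Method 2 comes last.

Methods 1 and 3 are neck and neck (about 1% apart on average).
Method 2 is consistently slower, taking roughly 1.6 times as long as Method 1.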
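
Note that the figures above include the two randint calls inside get_idxs, since every swap function calls it. As a rough cross-check, the sketch below times only the swap itself using the standard library's timeit with fixed indices. The function names swap_tmp / swap_fancy / swap_tuple, the indices 10 and 1500, and the iteration count are assumptions for illustration, and the absolute numbers will differ by machine and NumPy version.

```python
import timeit
import numpy as np


def swap_tmp(arr, i, j):      # Method 1 with explicit indices
    tmp = arr[i]
    arr[i] = arr[j]
    arr[j] = tmp


def swap_fancy(arr, i, j):    # Method 2 with explicit indices
    arr[[i, j]] = arr[[j, i]]


def swap_tuple(arr, i, j):    # Method 3 with explicit indices
    arr[i], arr[j] = arr[j], arr[i]


arr = np.arange(2000)   # same length as ARR_LEN in bm.py
N = 100_000             # fewer iterations than the article's 1,000,000

timings = {}
for func in (swap_tmp, swap_fancy, swap_tuple):
    # timeit runs immediately, so the lambda sees the current func
    timings[func.__name__] = timeit.timeit(lambda: func(arr, 10, 1500), number=N)

for name, seconds in timings.items():
    print(f"{name:12s} {seconds:.4f} s")
```

No expected output is shown because the timings depend on the environment.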

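Summary

Use Method 3: its syntax is the simplest, and it runs practically as fast as Method 1 (about 1% slower in this benchmark).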

Source code

bm.py
from numpy import arange
from benchmarker import Benchmarker

import swap


LOOP = 1000000
ARR_LEN = 2000
CYC = 10

with Benchmarker(LOOP, cycle=CYC, extra=1) as bench:
    n = arange(ARR_LEN)

    @bench(None)                ## empty loop
    def _(bm):
        for i in bm:
            pass

    @bench("swap1")
    def _(bm):
        for i in bm:
            swap.swap1(n)

    @bench("swap2")
    def _(bm):
        for i in bm:
            swap.swap2(n)

    @bench("swap3")
    def _(bm):
        for i in bm:
            swap.swap3(n)

swap.py
from numpy.random import randint


def swap1(arr):
    idx0, idx1 = get_idxs(arr)
    tmp = arr[idx0]
    arr[idx0] = arr[idx1]
    arr[idx1] = tmp


def swap2(arr):
    idx0, idx1 = get_idxs(arr)
    arr[[idx0, idx1]] = arr[[idx1, idx0]]


def swap3(arr):
    idx0, idx1 = get_idxs(arr)
    arr[idx0], arr[idx1] = arr[idx1], arr[idx0]


def get_idxs(arr):
    return randint(arr.shape[0]), randint(arr.shape[0])
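
As a quick sanity check, the three functions can be confirmed to give the same result when get_idxs is fed the same random sequence. The sketch below repeats the definitions from swap.py so that it runs on its own. Reseeding with numpy.random.seed, the seed value 0, the array length 10, and the 100 repetitions are all choices made here for illustration, not part of the original benchmark.

```python
import numpy as np
from numpy.random import randint, seed


def get_idxs(arr):
    return randint(arr.shape[0]), randint(arr.shape[0])


def swap1(arr):
    idx0, idx1 = get_idxs(arr)
    tmp = arr[idx0]
    arr[idx0] = arr[idx1]
    arr[idx1] = tmp


def swap2(arr):
    idx0, idx1 = get_idxs(arr)
    arr[[idx0, idx1]] = arr[[idx1, idx0]]


def swap3(arr):
    idx0, idx1 = get_idxs(arr)
    arr[idx0], arr[idx1] = arr[idx1], arr[idx0]


results = []
for swap in (swap1, swap2, swap3):
    seed(0)                   # same index sequence for every function
    arr = np.arange(10)
    for _ in range(100):
        swap(arr)
    results.append(arr)

all_equal = all(np.array_equal(results[0], r) for r in results[1:])
print(all_equal)              # True
```

Since get_idxs draws the two indices independently, they can coincide, in which case all three functions leave the array unchanged.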