Goのミドルウェアのテスト


概要

GoでWebアプリケーションを作るにあたり、gRPCとgrpc-gatewayを利用して作っています。
ここで何か全APIに共通の処理を書きたい場合、grpc-gatewayにミドルウェアを作成し、そこで処理をしてしまうことが多いです。今回はそのミドルウェアのテストを書くやり方をまとめます。
なおミドルウェアはgrpc-gatewayやgRPCに依存しているものではなく、net/httpを使っているミドルウェアであれば同様にテストが書けるはずです。

業務のロジックが含まれていて割愛している所も多く、また同様の事をしている例も多々あるかと思いますが、
実際に使われているものに近いミドルウェアとそのテストとして、何かしら参考になれば幸いです。

テストするミドルウェア

以下はアプリバージョンを渡してもらい、最低アプリバージョン以下だとエラーを返すというミドルウェアです。
実際のアプリケーションでは強制アップデートをかけるために利用しています。

appVersion.go
package gateway

import (
    "fmt"
    "net/http"
    "strconv"

    "github.com/andfactory/xxx-webapp/domain/model"

    "github.com/andfactory/xxx-webapp/domain/errors/code"

    "github.com/andfactory/xxx-webapp/domain/errors"
    "github.com/andfactory/xxx-webapp/library/env"
)

const (
    slackTitleAppVersionInvalid = "appVersion-invalid"
    headerKeyAppVersion         = "App-Version"
)

var minimumAppVersionIos int
var minimumAppVersionAndroid int

func init() {
    minimumAppVersionIos = env.GetMinimumAppVersionIos()
    minimumAppVersionAndroid = env.GetMinimumAppVersionAndroid()
}

// getAppVersionHeader クライアントのアプリバージョンチェックを実施するミドルウェアを取得する
func getAppVersionHeader(h http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

        //不要なログ出力を避けるため、healthCheckとドキュメントルートではこのチェックをおこなわない
        if r.RequestURI == "/health_check" || r.RequestURI == "/" {
            h.ServeHTTP(w, r)
            return
        }

        deviceTypeStr := r.Header.Get(headerKeyDeviceType)
        deviceType, err := model.ConvertStringToDeviceType(deviceTypeStr)
        if err != nil {
            err := errors.WrapApplicationError(err, code.InvalidDevice, fmt.Sprintf("invalid device type: '%v'", deviceTypeStr))
            setUnencryptedErrorResponse(w, slackTitleAppVersionInvalid, http.StatusBadRequest, err)
            return
        }

        appVersionStr := r.Header.Get(headerKeyAppVersion)
        appVersion, err := strconv.Atoi(appVersionStr)
        if err != nil {
            err := errors.WrapApplicationError(err, code.InvalidAppVersion, fmt.Sprintf("invalid application version: '%v'", appVersionStr))
            setUnencryptedErrorResponse(w, slackTitleAppVersionInvalid, http.StatusBadRequest, err)
            return
        }

        var minimumAppVersion int
        switch deviceType {
        case model.DeviceTypeIOS:
            minimumAppVersion = minimumAppVersionIos
        case model.DeviceTypeAndroid:
            minimumAppVersion = minimumAppVersionAndroid
        }

        if appVersion < minimumAppVersion {
            err := errors.NewApplicationError(code.NeedUpdateApplication, fmt.Sprintf("%s Application version too low. got %d want %d", deviceType, appVersion, minimumAppVersion))
            setUnencryptedErrorResponse(w, slackTitleAppVersionInvalid, http.StatusBadRequest, err)
            return
        }
        h.ServeHTTP(w, r)
    })
}

テストコード

上記のミドルウェアに対しては、以下のようにテストを書くことができます。

appVersion_test.go
package gateway_test

import (
    "bytes"
    "encoding/json"
    "io/ioutil"
    "net/http"
    "net/http/httptest"
    "testing"

    "github.com/andfactory/xxx-webapp/adapter/grpc/presenter"
    "github.com/andfactory/xxx-webapp/domain/errors/code"
    "github.com/andfactory/xxx-webapp/infra/grpc/gateway"
)

//TestAppVersionSkip 特定のpassで処理をスキップする部分のテスト
func TestAppVersionSkip(t *testing.T) {
    ts := httptest.NewServer(gateway.GetAppVersionHeader(GetTestHandler()))
    defer ts.Close()

    tests := []struct {
        name         string
        pass         string
        isError      bool
        expectedCode code.ErrorCode
    }{
        {
            name:    "ルート",
            pass:    "/",
            isError: false,
        },
        {
            name:    "ヘルスチェック",
            pass:    "/health_check",
            isError: false,
        },
        {
            name:    "通常",
            pass:    "/test",
            isError: true,
        },
    }
    for _, tt := range tests {

        t.Run(tt.name, func(t *testing.T) {

            var u bytes.Buffer
            u.WriteString(string(ts.URL))
            u.WriteString(tt.pass)

            req, _ := http.NewRequest("GET", u.String(), nil)
            req.Header.Set(gateway.GetHeaderKeyDeviceType(), "invalidDeviceType")
            req.Header.Set(gateway.GetHeaderKeyAppVersion(), "0")
            res, err := gateway.Client.Do(req)
            if err != nil {
                t.Fatalf("request faiulure %v", err)
            }
            if res != nil {
                defer res.Body.Close()
            }

            if tt.isError {
                var d presenter.ErrorResponse
                if err := json.NewDecoder(res.Body).Decode(&d); err != nil {
                    t.Fatalf("request faiulure %v", err)
                }
                if d.Body.ErrorCode != code.InvalidDevice {
                    t.Fatalf("return want to be %v but returned %v", tt.expectedCode, d.Body.ErrorCode)
                }
            } else {

                b, err := ioutil.ReadAll(res.Body)
                if err != nil {
                    t.Fatalf("request faiulure %v", err)
                }
                if string(b) != "OK" {
                    t.Fatalf("return want to be OK but returned %v", string(b))
                }
            }
        })
    }
}

//TestAppVersion appVersionでチェックする処理全般のテスト
func TestAppVersion(t *testing.T) {
    ts := httptest.NewServer(gateway.GetAppVersionHeader(GetTestHandler()))
    defer ts.Close()

    var u bytes.Buffer
    u.WriteString(string(ts.URL))
    u.WriteString("/test")

    gateway.SetMinimumAppVersionIos(50)
    gateway.SetMinimumAppVersionAndroid(150)

    tests := []struct {
        name         string
        deviceType   string
        appVersion   string
        isError      bool
        expectedCode code.ErrorCode
    }{
        {
            name:         "不正なデバイス",
            deviceType:   "",
            appVersion:   "50",
            isError:      true,
            expectedCode: code.InvalidDevice,
        },
        {
            name:         "不正なデバイス",
            deviceType:   "iOS",
            appVersion:   "50",
            isError:      true,
            expectedCode: code.InvalidDevice,
        },
        {
            name:         "不正なデバイス",
            deviceType:   "3",
            appVersion:   "50",
            isError:      true,
            expectedCode: code.InvalidDevice,
        },
        {
            name:         "iOS不正なバージョン",
            deviceType:   "1",
            appVersion:   "",
            isError:      true,
            expectedCode: code.InvalidAppVersion,
        },
        {
            name:         "iOS不正なバージョン",
            deviceType:   "1",
            appVersion:   "1.1.1",
            isError:      true,
            expectedCode: code.InvalidAppVersion,
        },
        {
            name:         "iOS強制アップデート",
            deviceType:   "1",
            appVersion:   "49",
            isError:      true,
            expectedCode: code.NeedUpdateApplication,
        },
        {
            name:       "iOSミニマム",
            deviceType: "1",
            appVersion: "50",
            isError:    false,
        },
        {
            name:       "iOSミニマムより大きい",
            deviceType: "1",
            appVersion: "51",
            isError:    false,
        },
        {
            name:         "android不正なバージョン",
            deviceType:   "2",
            appVersion:   "",
            isError:      true,
            expectedCode: code.InvalidAppVersion,
        },
        {
            name:         "android不正なバージョン",
            deviceType:   "2",
            appVersion:   "1.1.1",
            isError:      true,
            expectedCode: code.InvalidAppVersion,
        },
        {
            name:         "android強制アップデート",
            deviceType:   "2",
            appVersion:   "149",
            isError:      true,
            expectedCode: code.NeedUpdateApplication,
        },
        {
            name:       "androidミニマム",
            deviceType: "2",
            appVersion: "150",
            isError:    false,
        },
        {
            name:       "androidミニマムより大きい",
            deviceType: "2",
            appVersion: "151",
            isError:    false,
        },
    }

    for _, tt := range tests {

        t.Run(tt.name, func(t *testing.T) {

            req, _ := http.NewRequest("GET", u.String(), nil)
            req.Header.Set(gateway.GetHeaderKeyDeviceType(), tt.deviceType)
            req.Header.Set(gateway.GetHeaderKeyAppVersion(), tt.appVersion)
            res, err := gateway.Client.Do(req)
            if err != nil {
                t.Fatalf("request faiulure %v", err)
            }
            if res != nil {
                defer res.Body.Close()
            }

            if tt.isError {
                var d presenter.ErrorResponse
                if err := json.NewDecoder(res.Body).Decode(&d); err != nil {
                    t.Fatalf("request faiulure %v", err)
                }
                if d.Body.ErrorCode != tt.expectedCode {
                    t.Fatalf("return want to be %v but returned %v", tt.expectedCode, d.Body.ErrorCode)
                }
            } else {

                b, err := ioutil.ReadAll(res.Body)
                if err != nil {
                    t.Fatalf("request faiulure %v", err.Error())
                }
                if string(b) != "OK" {
                    t.Fatalf("return want to be OK but returned %v", string(b))
                }
            }
        })
    }
}

func GetTestHandler() http.HandlerFunc {
    fn := func(rw http.ResponseWriter, req *http.Request) {
        rw.Write([]byte("OK"))
        return
    }
    return http.HandlerFunc(fn)
}

privateな情報にテストからアクセスできるようにexport_test.goを作成します。

export_test.go
package gateway

import (
    "net/http"
)

var Client = new(http.Client)
var GetAppVersionHeader = getAppVersionHeader

func SetApplicationAppVersionIos(i int) {
    applicationAppVersionIos = i
}

func SetApplicationAppVersionAndroid(i int) {
    applicationAppVersionAndroid = i
}

func GetHeaderKeyDeviceType() string {
    return headerKeyDeviceType
}

func GetHeaderKeyAppVersion() string {
    return headerKeyAppVersion
}

解説

ミドルウェアのテストをするには、テストしたいミドルウェアのみを実行するサーバを作れば実現できます。

以下のようにエラーがなかった時用のハンドラを用意し、

func GetTestHandler() http.HandlerFunc {
    fn := func(rw http.ResponseWriter, req *http.Request) {
        rw.Write([]byte("OK"))
        return
    }
    return http.HandlerFunc(fn)
}

テストしたミドルウェアを通してサーバを立ててあげます。

    ts := httptest.NewServer(gateway.GetAppVersionHeader(GetTestHandler()))
    defer ts.Close()

urlの設定は以下のようにすれば実現できます

            var u bytes.Buffer
            u.WriteString(string(ts.URL))
            u.WriteString(tt.pass)

            req, _ := http.NewRequest("GET", u.String(), nil)

gRPCとgrpc-getewayを使うときは共通のパラメータを送るときはhttpHeaderに設定し、gRPC飲めたデータとして処理しています。headerへの設定は以下のようにします。

    req.Header.Set(gateway.GetHeaderKeyDeviceType(), tt.deviceType)
    req.Header.Set(gateway.GetHeaderKeyAppVersion(), tt.appVersion)

これで、APIにアクセスします。なおクライアントはexport_test.goで作成して使いまわしています。appVersion_test.goで作っても良いのですが、他のミドルウェアのテストでも活用したいのでこのようになってます。

            res, err := gateway.Client.Do(req)
            if err != nil {
                t.Fatalf("request faiulure %v", err)
            }
            if res != nil {
                defer res.Body.Close()
            }

あとはレスポンスの内容をチェックしてあげればOKです。
エラーの場合は特定の型のレスポンスを返すようにしてあるので、それをパースしてコードが意図したものになっていればOK。エラーでない場合はOKが返ってくれば正常です。

            if tt.isError {
                var d presenter.ErrorResponse
                if err := json.NewDecoder(res.Body).Decode(&d); err != nil {
                    t.Fatalf("request faiulure %v", err)
                }
                if d.Body.ErrorCode != tt.expectedCode {
                    t.Fatalf("return want to be %v but returned %v", tt.expectedCode, d.Body.ErrorCode)
                }
            } else {

                b, err := ioutil.ReadAll(res.Body)
                if err != nil {
                    t.Fatalf("request faiulure %v", err.Error())
                }
                if string(b) != "OK" {
                    t.Fatalf("return want to be OK but returned %v", string(b))
                }
            }

参考

export_test.goを作って非公開の変数や関数を扱うやり方は以下で詳しく解説されてます。
非公開(unexported)な機能を使ったテスト

以下の記事でも同様の事が書かれています。
Unit Testing Golang HTTP Middleware