grpc-goのMiddlewareについて調べてみた


grpc-goで認証したりログを出したりする場合はミドルウェアを使用します。

GitHub - grpc-ecosystem/go-grpc-middleware: Golang gRPC Middlewares: interceptor chaining, auth, logging, retries and more.

ドキュメントを見ながら使おうとしてみたのですが、いまいち勝手が分からなかったので、基本に立ち戻ってgrpc-goのソースを理解するところから始めてみました。

まずは定義の確認

最初に、サーバ起動時に指定する引数の定義を確認します。

grpc/server.go
func NewServer(opt ...ServerOption) *Server
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption

サーバオブジェクトを作るときにServerOptionが指定できます。
UnaryServerInterceptorはインターセプタの関数をServerOptionに変換する処理です。インターセプタを作るにはUnaryServerInterceptorの型に合わせたオブジェクトを作る必要があることが分かります。

次に、UnaryServerInterceptorの定義を見ます。

grpc/interceptor.go
// UnaryServerInterceptor provides a hook to intercept the execution of a unary RPC on the server. info
// contains all the information of this RPC the interceptor can operate on. And handler is the wrapper
// of the service method implementation. It is the responsibility of the interceptor to invoke handler
// to complete the RPC.
type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)

この定義の型とコメントから、以下のようなインターセプタの役割と出来ることが読み取れます。

  • UnaryServerInterceptorはRPCをインターセプトするフックを提供する。
  • ハンドラを呼び出してRPCを完了させるのはインターセプタの役割。
  • ハンドラを呼び出すのがインターセプタの役割ということは、ハンドラを呼ばずに握りつぶす(実際に処理はせずに正常終了させる)ことやエラーにすることが可能。
  • ハンドラを呼び出すタイミングの前または後に、独自の処理を入れることが可能。(呼ぶ前に認証、読んだあとにログ出力とか)

Stream指定があるRPC用にStreamServerInterceptorもありますが、基本は同じです。

何もしないインターセプタを定義してみる

基本を押さえるため、まずはシンプルな定義で確認します。

まず、UnaryServerInterceptorを返す関数を定義します。

func unaryServerInterceptor() grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        resp, err := handler(ctx, req)
        return resp, err
    }
}

これをgRPCサーバのオブジェクト作成時に引数として指定すれば登録完了です。

func main() {
    lis, err := net.Listen("tcp", port)
    if err != nil {
        log.Fatalf("failed to listen: %v", err)
    }

    // サーバオブジェクト作成
    s := grpc.NewServer(grpc.UnaryInterceptor(
        unaryServerInterceptor()))

    pb.RegisterGreeterServer(s, &server{})

    if err := s.Serve(lis); err != nil {
        log.Fatalf("failed to serve: %v", err)
    }
}

簡易的な認証を追加する

gRPCリクエストヘッダのキーにauthorization、値にパスワードを入れてリクエストする場合の例です。(エラー処理はサボってます)

func unaryServerInterceptor() grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        md, _ := metadata.FromIncomingContext(ctx)

        password := md["authorization"]
        if len(password) < 1 || password[0] != "xxxxx" {
            return nil, errors.New("Incorrect password")
        }

        resp, err := handler(ctx, req)
        return resp, err
    }
}

リクエストヘッダの内容はコンテキストからメタデータとして取得できます。
引数のgrpc.UnaryServerInfoからはgRPCのメソッド名などの情報が取れるので、特定のメソッドの場合だけ認証しないようにすることも可能です。

ソース全体

package main

import (
    "errors"
    "log"
    "net"

    "golang.org/x/net/context"
    "google.golang.org/grpc"
    pb "google.golang.org/grpc/examples/helloworld/helloworld"
    "google.golang.org/grpc/metadata"
)

const (
    port = ":50051"
)

type server struct{}

func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
    return &pb.HelloReply{Message: "Hello " + in.Name}, nil
}

func unaryServerInterceptor() grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        md, _ := metadata.FromIncomingContext(ctx)

        password := md["authorization"]
        if len(password) < 1 || password[0] != "xxxxx" {
            return nil, errors.New("Incorrect password")
        }

        resp, err := handler(ctx, req)
        return resp, err
    }
}

func main() {
    lis, err := net.Listen("tcp", port)
    if err != nil {
        log.Fatalf("failed to listen: %v", err)
    }

    // サーバオブジェクト作成
    s := grpc.NewServer(grpc.UnaryInterceptor(
        unaryServerInterceptor()))

    pb.RegisterGreeterServer(s, &server{})

    if err := s.Serve(lis); err != nil {
        log.Fatalf("failed to serve: %v", err)
    }
}

とりあえず、このぐらいまで押さえておけばMiddlewareが何をしているかが見えてきた気がします。