Rust sqlxでデータベースに依存した部分のテストを書く

49591 ワード

はじめに

アプリケーションにおいてデータの永続化を実現しようとすると、DBとアクセスする層が必要になることが多いです。適切なインターフェースを定義すれば、DBにアクセスする層をモック化して、その層に依存する部分のテストを書くことができます。しかし時にはDBを直接扱う層のロジックをテストしたいときもあります。

例えばJavaであれば、H2を使ってテスト用のデータベースを立ち上げることができます。しかしRustでsqlxを採用した場合、どのようにすればDBに依存するテストが実現できるのでしょうか。

あまり情報が見つからなかったので、試行錯誤しながら得られた知見をまとめておきたいと思います。

前提

本稿で使用するバージョンは、Rust 1.60.0、sqlx 0.5.13です。

方針

テスト用のPostgreSQLをDockerを使って立ち上げます。テストが繰り返し実行できるように、毎回トランザクションを貼って最後にロールバックするようにします。

ソースコード

次のようなスキーマのPostgreSQLデータベースを対象とします。

migrations/20220306122339_create_tables.sql
CREATE TABLE bookshelf_user (
  id text NOT NULL PRIMARY KEY,
  created_at timestamp NOT NULL default current_timestamp,
  updated_at timestamp NOT NULL default current_timestamp
);

テスト用のデータベースはDockerで立ち上げます。

docker-compose-test.yml
services:
  db:
    image: postgres:latest
    ports:
      - "5432:5432"
    environment:
      - POSTGRES_PASSWORD=password
$ docker-compose -f docker-compose-test.yml up -d
$ sqlx migrate run

永続化対象のUser構造体はこちらです。New Typeパターンを使っているので、多少長めの実装になっています。

use validator::Validate;

use crate::domain::error::DomainError;

#[derive(Debug, Clone, PartialEq, Eq, Validate)]
pub struct UserId {
    #[validate(length(min = 1))]
    value: String,
}

impl UserId {
    pub fn new(id: String) -> Result<Self, DomainError> {
        let object = Self { value: id };
        object.validate()?;
        Ok(object)
    }

    pub fn as_str(&self) -> &str {
        &self.value
    }

    pub fn into_string(self) -> String {
        self.value
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct User {
    pub id: UserId,
}

impl User {
    pub fn new(id: UserId) -> User {
        User { id }
    }
}

次にUserRepositoryのtraitを定義しておきます。今回の話には関係ありませんが、UserRepositoryのモックを作成できるようにするためです。

use async_trait::async_trait;
use mockall::automock;

use crate::domain::{
    entity::user::{User, UserId},
    error::DomainError,
};

#[automock]
#[async_trait]
pub trait UserRepository: Send + Sync + 'static {
    async fn create(&self, user: &User) -> Result<(), DomainError>;
    async fn find_by_id(&self, id: &UserId) -> Result<Option<User>, DomainError>;
}

最後に本体の実装とテストです。PgUserRepositoryがtraitを実装した構造体です。

use async_trait::async_trait;
use sqlx::{PgConnection, PgPool};

use crate::domain::{
    entity::user::{User, UserId},
    error::DomainError,
    repository::user_repository::UserRepository,
};

#[derive(sqlx::FromRow)]
struct UserRow {
    id: String,
}

#[derive(Debug, Clone)]
pub struct PgUserRepository {
    pool: PgPool,
}

impl PgUserRepository {
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }
}

#[async_trait]
impl UserRepository for PgUserRepository {
    async fn create(&self, user: &User) -> Result<(), DomainError> {
        let mut conn = self.pool.acquire().await?;
        let result = InternalUserRepository::create(user, &mut conn).await?;
        Ok(result)
    }

    async fn find_by_id(&self, id: &UserId) -> Result<Option<User>, DomainError> {
        let mut conn = self.pool.acquire().await?;
        let user = InternalUserRepository::find_by_id(id, &mut conn).await?;
        Ok(user)
    }
}

pub(in crate::infrastructure) struct InternalUserRepository {}

impl InternalUserRepository {
    pub(in crate::infrastructure) async fn create(
        user: &User,
        conn: &mut PgConnection,
    ) -> Result<(), DomainError> {
        sqlx::query("INSERT INTO bookshelf_user (id) VALUES ($1)")
            .bind(user.id.as_str())
            .execute(conn)
            .await?;
        Ok(())
    }

    async fn find_by_id(id: &UserId, conn: &mut PgConnection) -> Result<Option<User>, DomainError> {
        let row: Option<UserRow> = sqlx::query_as("SELECT * FROM bookshelf_user WHERE id = $1")
            .bind(id.as_str())
            .fetch_optional(conn)
            .await?;

        let id = row.map(|row| UserId::new(row.id)).transpose()?;
        Ok(id.map(|id| User::new(id)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use sqlx::postgres::PgPoolOptions;

    #[tokio::test]
    async fn test_user_repository() -> anyhow::Result<()> {
        dotenv::dotenv().ok();

        let db_url = fetch_database_url();
        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect_timeout(Duration::from_secs(1))
            .connect(&db_url)
            .await?;
        let mut tx = pool.begin().await?;

        let id = UserId::new(String::from("foo"))?;
        let user = User::new(id.clone());

        let fetched_user = InternalUserRepository::find_by_id(&id, &mut tx).await?;
        assert!(fetched_user.is_none());

        InternalUserRepository::create(&user, &mut tx).await?;

        let fetched_user = InternalUserRepository::find_by_id(&id, &mut tx).await?;
        assert_eq!(fetched_user, Some(user));

        tx.rollback().await?;
        Ok(())
    }

    fn fetch_database_url() -> String {
        use std::env::VarError;

        match std::env::var("DATABASE_URL") {
            Ok(s) => s,
            Err(VarError::NotPresent) => panic!("Environment variable DATABASE_URL is required."),
            Err(VarError::NotUnicode(_)) => {
                panic!("Environment variable DATABASE_URL is not unicode.")
            }
        }
    }
}

解説

PgUserRepositoryが実際のアプリケーションで使われるリポジトリですが、こちらは直接テストしません。代わりに実際の処理をInternalUserRepositoryに移譲して、こちらをテストするようにします。一種のHunble Objectパターンとみなせるでしょう。

InternalUserRepositoryはメソッドを定義せず、関連関数のみを定義します。関数をまとめておくくらいの意味合いしかないので、moduleに関数をまとめておくくらいでも良いでしょう。

関連関数はSQLに必要な情報と、&mut PgConnectionを引数に受け取ります。これは公式ドキュメントにおいて、connectionとtransactionを両方受け取れるようにする方法として紹介されているものです。