技術向上

プログラミングの学び、気になるテクノロジーやビジネストレンドを発信

SQLトランザクション【Go】

Goでのトランザクションの構築方法と、トランザクションの利用ケースについて触れます。

MySQLなどのデータベースにアクセスする際、Goではdatabaseパッケージを使うことができます。
databaseパッケージを用いてトランザクションを構築するには、Txを使います。

※下記の例では、SQLビルダーのSQuirrel名前空間"sq"と指定して利用しています。
※modelの定義を省略しています。
※クリーンアーキテクトによる実装をモデルとしており、SQL操作をするrepositoryとビジネスロジックを記述するservice層に分かれているコードを無理やり1つにまとめています。

// 共通処理を関数化
func ConnectDB() *sqlx.DB {
    // 任意の設定内容を指定
    c := mysql.Config{
            DBName:               os.Getenv("MYSQL_DATABASE"),
            User:                 os.Getenv("MYSQL_USER"),
            Passwd:               os.Getenv("MYSQL_PASSWORD"),
            Addr:                 os.Getenv("MYSQL_ADDRESS"),
            Net:                  "tcp",
            Loc:                  util.TimeZoneJST,
            ParseTime:            true,
            AllowNativePasswords: true,
        }
    dsn = c.FormatDSN()

    db, err := sqlx.Open("mysql", dsn)    // mysqlを開く

    if err != nil {
        log.Fatalf("Could not connect to mysql: %s", err)
    }
    if err := db.Ping(); err != nil {
        log.Fatalf("Could not connect to mysql: %s", err)
    }
}

func Transact(ctx context.Context, db Beginner, txFunc func(Tx) error) (err error) {
    defer func() {
        if p := recover(); p != nil {
            if err := tx.Rollback(); err != nil {
                l.Error("database: failed to rollback", zap.Error(err))
            }
            panic(p)
        }
        if err != nil {    // txFunc(tx)が返すerror
            if err := tx.Rollback(); err != nil {
                l.Error("database: failed to rollback", zap.Error(err))
            }
        } else {
            return tx.Commit()
        }
    }()

    return txFunc(tx)
}


// service ビジネスロジックを記述
type member struct {
    mRepo repository.Member
    db     database.Runner
}

func (s *member) Wrapper(ctx context.Context,whosid int64, whosname string) error {
    return Transact(ctx, s.db, func(tx database.Tx) error {
        return s.mRepo.Update(ctx, tx, &model.Member{
            ID:   whosid,
            Name: whosname,
        })
    })
}

// repository SQL操作を記述
func (r *member) Update(ctx context.Context, who *model.Member) error {
    db := ConnectDB()  // DBと接続
    defer db.Close()// 必ず接続を閉じる
    
    tx, err := db.Begin(ctx)  // トランザクションを開始
    if err != nil {
        return err
    }

    q, attrs, err := sq.    // SQuirrelを利用
        Update("members").
        SetMap(sq.Eq{
            "name":       who.Name,
        }).
        Where(sq.Eq{
            "id": who.ID,
        }).
        ToSql()

    if err != nil {
        return err
    }

    // ここで、sqlに関するlogの出力をした方が良いでしょう

    if _, err := sql.ExecContext(ctx, q, attrs...); err != nil {
        return err
    }
    return nil
}

上記は更新系SQL、しかも単文の例でしたが、たとえばSelectの結果を用いてUpdateやInsertをする場合にもトランザクションを構築します。

その際は、SelectメソッドとUpdateメソッドを内包するラッパーメソッドを作成することになると思います。Selectメソッド、Updateメソッドがsql処理を行うだけのメソッドだと仮定するならば、Selectメソッドの引数はdatabase.Txではなく、database.Queryerを用います。Select SQL自体には明示的にトランザクションを構築する必要がないためです。Select SQL文にはfor updateを用いて、抽出対象行を排他ロックすることでしょう。

SelectやUpdateのSQLを組み合わせる以外にも、例えば会員登録直後に送信する、24時間有効なリンクを貼った確認メールなどでもトランザクションが使えます。テーブルへの登録、更新処理と合わせてメール送信が完了しない場合にはロールバックをして、会員登録を無かったことにするのです。会員登録を済ませるにはどちらも必須条件ですから、トランザクションで囲むことが有効です。


Beginを明示的に指定してトランザクションを開始しない場合、MYSQLでは、Auto Commitの設定値が有効であれば、最初のSQLが発行されたタイミングでトランザクションが開始されます。これはSelect単文、しかもfor updateでない場合でもです。


Testコードを書く場合のテクニック

SQL操作部分のテストコードを書く場合も、いくつかのテストケースを用意するでしょう。
for文を回して、テストケースの数だけまとめて処理を実行することになるかと思いますが、各テストケースの内容が他のテストケースに影響を与えないように、各テストケースの実行完了後には、それが実行される前の状態に戻しておいたほうが確認がしやすいです。

そんな時は、意図的にエラーを起こすことで、トランザクションロールバックを利用します。
先ほどのTransact関数では、引数として渡した関数がerrorを返す場合、ロールバックされる仕様になっています。各testcaseが終了するたびにerrorを返すことで、処理直前の状態に戻すことができます。

func Test_customizeType_Update(t *testing.T) {
    dbConn, teardown := GetTestDBConn()   // テスト用のDB接続を行う任意の関数
    defer teardown()

    // 必要に応じてテーブルに必要なデータを挿入する処理を記述
    
    db := database.NewRunner(dbConn)
    type args struct {
        ctx           context.Context
        member *model.Member
    }
    cases := map[string]struct {
        args    args
        err     error
        success bool
    }{
        "Success": {
            // テストケース
        },

        "Duplicate name": {
            // テストケース
        },
        ...
    }
    for testname, testcase := range cases {
        t.Run(testname, func(t *testing.T) {
        ...
            Transact(testcase.args.ctx, db, func(tx database.Tx) error {
                err := repo.Update(testcase.args.ctx, tx, testcase.args.member)
                if !assert.Equal(t, testcase.err, errors.Topmost(err)) {
                    t.Logf("%#v failed: %#v", testname, err)
                }

                testContains(testcase.args.ctx, repo, t, tx, testcase.args.customizeType, testcase.success)    // Updateした結果がテーブルレコードに含まれるかをチェックする任意の関数

                return fmt.Errorf("rollback")    // errorを必ず発生させて、rollbackさせる
            })
        })
    }
}



参考

メールのトランザクション設計 - Qiita

MySQLのトランザクション制御がキモい話 - なからなLife