Database transactions are great to have and use. The usual way to use them is the Unit of Work pattern: you wrap a transaction around a specific piece of business logic. Say a user signs up: you create a transaction that creates the user along with the other objects the user relies on, such as default preferences, or adding them to a workspace or team.

The example above is quite simplistic, but there are times when you want every database write made while handling a request to be truly atomic. Every insert, update, or delete performed at the database level should be reverted if any step fails, and committed only when all of them succeed.

Implementation

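The first step is a middleware, `withTransaction`, that begins a transaction for each request and stores it in the request context. It comes with two small helpers that put the transaction into a `context.Context` and read it back out: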

// txKey is the context key under which WithBunTx stores the request's
// bun.Tx. Using an unexported struct type keeps the key from colliding with
// context keys defined by other packages.
type txKey struct{}

// WithBunTx returns a child of ctx that carries tx under the package's
// private txKey, so it can later be recovered with TxFromContext.
func WithBunTx(ctx context.Context, tx bun.Tx) context.Context {
	var key txKey
	derived := context.WithValue(ctx, key, tx)
	return derived
}

// TxFromContext returns the transaction that WithBunTx stored in ctx.
// When ctx carries no transaction, the zero bun.Tx is returned; callers
// cannot distinguish that case from the return value alone.
func TxFromContext(ctx context.Context) bun.Tx {
	if stored, found := ctx.Value(txKey{}).(bun.Tx); found {
		return stored
	}
	return bun.Tx{}
}

// withTransaction returns middleware that begins a database transaction for
// every request and stores it in the request context via WithBunTx, so that
// downstream code (e.g. WrapHTTPHandler) can fetch it and commit or roll it
// back.
//
// If the transaction cannot be started, a 500 response is rendered and the
// next handler is not invoked.
//
// The transaction is begun with the request's context; database/sql rolls a
// transaction back on its own if that context is cancelled before the
// transaction finishes.
func withTransaction(db *bun.DB) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			tx, err := db.BeginTx(ctx, nil)
			if err != nil {
				// Best effort: nothing more can be done if writing the error fails.
				_ = render.Render(w, r, newAPIStatus(http.StatusInternalServerError, "failed to start processing request"))
				return
			}

			// Safety net: if next never commits or rolls back (it panics, or the
			// route is not wrapped with WrapHTTPHandler), release the transaction
			// and its connection as soon as the handler returns. After a
			// successful Commit or Rollback this returns sql.ErrTxDone, which is
			// deliberately ignored.
			defer func() { _ = tx.Rollback() }()

			next.ServeHTTP(w, r.WithContext(WithBunTx(ctx, tx)))
		})
	}
}

Register the middleware on your router:

	// Every request handled by this router now runs inside its own transaction.
	router.Use(withTransaction(db))

For simplicity and consistency, I wrap all my HTTP handlers in a common function that commits or rolls back the request's transaction based on the handler's result. The same pattern extends to any HTTP handler or function.

// WrapHTTPHandler adapts an HTTPHandler into an http.HandlerFunc that finishes
// the request-scoped transaction opened by withTransaction.
//
// For every request it:
//   - starts a span named spanName via getTracer and derives a logger tagged
//     with the request ID, plus the workspace and user IDs when present;
//   - calls handler, then commits the transaction on StatusSuccess or rolls it
//     back on StatusFailed or any other status;
//   - renders the handler's response, or a 500 when the status is unrecognised
//     or the commit fails.
//
// If the request context carries no transaction (withTransaction is not in the
// middleware chain), handler is not called and a 500 is rendered.
//
// The db parameter is not used; the transaction comes from the request context.
func WrapHTTPHandler(
	logger *zap.Logger,
	db *bun.DB,
	handler HTTPHandler,
	cfg *config.Config,
	spanName string) http.HandlerFunc {

	return func(w http.ResponseWriter, r *http.Request) {
		ctx, span, rid := getTracer(r.Context(), r, spanName, cfg.Otel.IsEnabled)
		defer span.End()

		logger := logger.With(zap.String("request_id", rid))

		if doesBusinessExistInContext(ctx) {
			workspace := getBusinessFromContext(ctx).ID.String()
			logger = logger.With(zap.String("workspace_id", workspace))
			span.SetAttributes(attribute.String("workspace_id", workspace))
		}

		if doesUserExistInContext(ctx) {
			userID := getUserFromContext(ctx).ID.String()
			logger = logger.With(zap.String("user_id", userID))
			span.SetAttributes(attribute.String("user_id", userID))
		}

		// Fail fast when no transaction was stored: the zero bun.Tx has no
		// underlying transaction, so Commit/Rollback on it cannot succeed, and
		// the handler's side effects would not be covered by any transaction.
		tx, ok := ctx.Value(txKey{}).(bun.Tx)
		if !ok {
			logger.Error("no transaction in request context; is withTransaction installed?")
			span.SetStatus(codes.Error, "missing transaction")
			_ = render.Render(w, r, newAPIStatus(http.StatusInternalServerError, "failed to start processing request"))
			return
		}

		rollback := func() {
			if err := tx.Rollback(); err != nil {
				logger.Error("failed to rollback transaction", zap.Error(err))
			}
		}

		resp, status := handler(ctx, span, logger, r)
		switch status {
		case StatusFailed:
			// The handler reported failure: undo everything it wrote, but still
			// render its response below.
			rollback()
			span.SetStatus(codes.Error, "")

		case StatusSuccess:
			// Make every write performed while handling the request durable.
			if err := tx.Commit(); err != nil {
				logger.Error("failed to commit transaction", zap.Error(err))
				span.SetStatus(codes.Error, "commit failed")
				_ = render.Render(w, r, newAPIStatus(http.StatusInternalServerError, "an error occurred while completing your response"))
				return
			}
			span.SetStatus(codes.Ok, "")

		default:
			// Unrecognised status: be conservative and discard all writes.
			rollback()
			span.SetStatus(codes.Error, "unknown handler status")
			_ = render.Render(w, r, newAPIStatus(http.StatusInternalServerError, "an error occurred"))
			return
		}

		if err := render.Render(w, r, resp); err != nil {
			logger.Error("could not write http response", zap.Error(err))
		}
	}
}

This way, every HTTP request starts a transaction. When the request finishes, the transaction is committed or rolled back depending on the status the handler returns.

Notes

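Again, you probably do not need this unless you have a use case that requires this level of atomicity! Keep in mind that each request then holds a database connection, along with any locks its transaction acquires, until the request finishes.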