Skip to content

Commit 883a231

Browse files
refactor: wrap GetUserByEmailOrUsername in transaction
1 parent fa9d426 commit 883a231

File tree

1 file changed

+52
-61
lines changed

1 file changed

+52
-61
lines changed

coderd/userauth.go

Lines changed: 52 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -341,66 +341,57 @@ func (api *API) postChangePasswordWithOneTimePasscode(rw http.ResponseWriter, r
341341
return
342342
}
343343

344-
//nolint:gocritic // In order to change a user's password, we need to get the user first - and can only do that in the system auth context.
345-
user, err := api.Database.GetUserByEmailOrUsername(dbauthz.AsSystemRestricted(ctx), database.GetUserByEmailOrUsernameParams{
346-
Email: req.Email,
347-
})
348-
if err != nil && !errors.Is(err, sql.ErrNoRows) {
349-
logger.Error(ctx, "unable to fetch user by email", slog.Error(err))
350-
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
351-
Message: "Internal error.",
344+
err = api.Database.InTx(func(tx database.Store) error {
345+
//nolint:gocritic // In order to change a user's password, we need to get the user first - and can only do that in the system auth context.
346+
user, err := tx.GetUserByEmailOrUsername(dbauthz.AsSystemRestricted(ctx), database.GetUserByEmailOrUsernameParams{
347+
Email: req.Email,
352348
})
353-
return
354-
}
355-
aReq.Old = user
349+
if err != nil && !errors.Is(err, sql.ErrNoRows) {
350+
logger.Error(ctx, "unable to fetch user by email", slog.F("email", req.Email), slog.Error(err))
351+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
352+
Message: "Internal error.",
353+
})
354+
return nil
355+
}
356+
aReq.Old = user
356357

357-
equal, err := userpassword.Compare(string(user.HashedOneTimePasscode), req.OneTimePasscode)
358-
if err != nil {
359-
logger.Error(ctx, "unable to compare passwords", slog.Error(err))
360-
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
361-
Message: "Internal error.",
362-
})
363-
return
364-
}
358+
equal, err := userpassword.Compare(string(user.HashedOneTimePasscode), req.OneTimePasscode)
359+
if err != nil {
360+
return xerrors.Errorf("compare one time passcode: %w", err)
361+
}
365362

366-
if !equal {
367-
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
368-
Message: "Incorrect email or one-time-passcode.",
369-
})
370-
return
371-
}
363+
if !equal {
364+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
365+
Message: "Incorrect email or one-time-passcode.",
366+
})
367+
return nil
368+
}
372369

373-
if err := userpassword.Validate(req.Password); err != nil {
374-
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
375-
Message: "Invalid password.",
376-
Validations: []codersdk.ValidationError{
377-
{
378-
Field: "password",
379-
Detail: err.Error(),
370+
if err := userpassword.Validate(req.Password); err != nil {
371+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
372+
Message: "Invalid password.",
373+
Validations: []codersdk.ValidationError{
374+
{
375+
Field: "password",
376+
Detail: err.Error(),
377+
},
380378
},
381-
},
382-
})
383-
return
384-
}
379+
})
380+
return nil
381+
}
385382

386-
if equal, _ = userpassword.Compare(string(user.HashedPassword), req.Password); equal {
387-
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
388-
Message: "New password cannot match old password.",
389-
})
390-
return
391-
}
383+
if equal, _ = userpassword.Compare(string(user.HashedPassword), req.Password); equal {
384+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
385+
Message: "New password cannot match old password.",
386+
})
387+
return nil
388+
}
392389

393-
newHashedPassword, err := userpassword.Hash(req.Password)
394-
if err != nil {
395-
logger.Error(ctx, "unable to hash new user password", slog.Error(err))
396-
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
397-
Message: "Internal error hashing new password.",
398-
Detail: err.Error(),
399-
})
400-
return
401-
}
390+
newHashedPassword, err := userpassword.Hash(req.Password)
391+
if err != nil {
392+
return xerrors.Errorf("hash user password: %w", err)
393+
}
402394

403-
err = api.Database.InTx(func(tx database.Store) error {
404395
//nolint:gocritic // We need the system auth context to be able to update the user's password.
405396
err = tx.UpdateUserHashedPassword(dbauthz.AsSystemRestricted(ctx), database.UpdateUserHashedPasswordParams{
406397
ID: user.ID,
@@ -416,24 +407,24 @@ func (api *API) postChangePasswordWithOneTimePasscode(rw http.ResponseWriter, r
416407
return xerrors.Errorf("delete api keys for user: %w", err)
417408
}
418409

410+
auditUser := user
411+
auditUser.HashedPassword = []byte(newHashedPassword)
412+
auditUser.OneTimePasscodeExpiresAt = sql.NullTime{}
413+
auditUser.HashedOneTimePasscode = nil
414+
aReq.New = auditUser
415+
416+
rw.WriteHeader(http.StatusOK)
417+
419418
return nil
420419
}, nil)
421420
if err != nil {
422421
logger.Error(ctx, "unable to update user's password", slog.Error(err))
423422
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
424-
Message: "Internal error updating user's password.",
423+
Message: "Internal error.",
425424
Detail: err.Error(),
426425
})
427426
return
428427
}
429-
430-
auditUser := user
431-
auditUser.HashedPassword = []byte(newHashedPassword)
432-
auditUser.OneTimePasscodeExpiresAt = sql.NullTime{}
433-
auditUser.HashedOneTimePasscode = nil
434-
aReq.New = auditUser
435-
436-
rw.WriteHeader(http.StatusOK)
437428
}
438429

439430
// Authenticates the user with an email and password.

0 commit comments

Comments
 (0)