diff --git a/backend/internal/user/repo/user.go b/backend/internal/user/repo/user.go index 908c0db..29d831d 100644 --- a/backend/internal/user/repo/user.go +++ b/backend/internal/user/repo/user.go @@ -347,9 +347,23 @@ func (r *UserRepo) OAuthLogin(ctx context.Context, platform consts.UserPlatform, if err != nil { return nil, errcode.ErrNotInvited.Wrap(err) } + if ui.AvatarURL != req.AvatarURL { + if err = entx.WithTx(ctx, r.db, func(tx *db.Tx) error { + return r.updateAvatar(ctx, tx, ui, req.AvatarURL) + }); err != nil { + return nil, err + } + } return ui.Edges.User, nil } +func (r *UserRepo) updateAvatar(ctx context.Context, tx *db.Tx, ui *db.UserIdentity, avatar string) error { + if err := tx.UserIdentity.UpdateOneID(ui.ID).SetAvatarURL(avatar).Exec(ctx); err != nil { + return err + } + return tx.User.UpdateOneID(ui.UserID).SetAvatarURL(avatar).Exec(ctx) +} + func (r *UserRepo) SignUpOrIn(ctx context.Context, platform consts.UserPlatform, req *domain.OAuthUserInfo) (*db.User, error) { var u *db.User err := entx.WithTx(ctx, r.db, func(tx *db.Tx) error { @@ -359,6 +373,11 @@ func (r *UserRepo) SignUpOrIn(ctx context.Context, platform consts.UserPlatform, First(ctx) if err == nil { u = ui.Edges.User + if ui.AvatarURL != req.AvatarURL { + if err = r.updateAvatar(ctx, tx, ui, req.AvatarURL); err != nil { + return err + } + } return nil } if !db.IsNotFound(err) {