diff --git a/backend/consts/user.go b/backend/consts/user.go index a2ab249..36089c9 100644 --- a/backend/consts/user.go +++ b/backend/consts/user.go @@ -28,3 +28,10 @@ type OAuthKind string const ( OAuthKindSignUpOrIn OAuthKind = "signup_or_in" ) + +type InviteCodeStatus string + +const ( + InviteCodeStatusPending InviteCodeStatus = "pending" + InviteCodeStatusUsed InviteCodeStatus = "used" +) diff --git a/backend/db/invitecode.go b/backend/db/invitecode.go index 3235da8..abeed70 100644 --- a/backend/db/invitecode.go +++ b/backend/db/invitecode.go @@ -9,6 +9,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db/invitecode" "github.com/google/uuid" ) @@ -22,10 +23,14 @@ type InviteCode struct { AdminID uuid.UUID `json:"admin_id,omitempty"` // Code holds the value of the "code" field. Code string `json:"code,omitempty"` + // Status holds the value of the "status" field. + Status consts.InviteCodeStatus `json:"status,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. - UpdatedAt time.Time `json:"updated_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + // ExpiredAt holds the value of the "expired_at" field. + ExpiredAt time.Time `json:"expired_at,omitempty"` selectValues sql.SelectValues } @@ -34,9 +39,9 @@ func (*InviteCode) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case invitecode.FieldCode: + case invitecode.FieldCode, invitecode.FieldStatus: values[i] = new(sql.NullString) - case invitecode.FieldCreatedAt, invitecode.FieldUpdatedAt: + case invitecode.FieldCreatedAt, invitecode.FieldUpdatedAt, invitecode.FieldExpiredAt: values[i] = new(sql.NullTime) case invitecode.FieldID, invitecode.FieldAdminID: values[i] = new(uuid.UUID) @@ -73,6 +78,12 @@ func (ic *InviteCode) assignValues(columns []string, values []any) error { } else if value.Valid { ic.Code = value.String } + case invitecode.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + ic.Status = consts.InviteCodeStatus(value.String) + } case invitecode.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -85,6 +96,12 @@ func (ic *InviteCode) assignValues(columns []string, values []any) error { } else if value.Valid { ic.UpdatedAt = value.Time } + case invitecode.FieldExpiredAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expired_at", values[i]) + } else if value.Valid { + ic.ExpiredAt = value.Time + } default: ic.selectValues.Set(columns[i], values[i]) } @@ -127,11 +144,17 @@ func (ic *InviteCode) String() string { builder.WriteString("code=") builder.WriteString(ic.Code) builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(fmt.Sprintf("%v", ic.Status)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(ic.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") builder.WriteString("updated_at=") builder.WriteString(ic.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("expired_at=") + builder.WriteString(ic.ExpiredAt.Format(time.ANSIC)) builder.WriteByte(')') return builder.String() } diff --git a/backend/db/invitecode/invitecode.go b/backend/db/invitecode/invitecode.go index c376405..5f13725 100644 --- a/backend/db/invitecode/invitecode.go +++ b/backend/db/invitecode/invitecode.go @@ -6,6 +6,7 @@ import ( "time" "entgo.io/ent/dialect/sql" + "github.com/chaitin/MonkeyCode/backend/consts" ) const ( @@ -17,10 +18,14 @@ const ( FieldAdminID = "admin_id" // FieldCode holds the string denoting the code field in the database. FieldCode = "code" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. FieldUpdatedAt = "updated_at" + // FieldExpiredAt holds the string denoting the expired_at field in the database. + FieldExpiredAt = "expired_at" // Table holds the table name of the invitecode in the database. Table = "invite_codes" ) @@ -30,8 +35,10 @@ var Columns = []string{ FieldID, FieldAdminID, FieldCode, + FieldStatus, FieldCreatedAt, FieldUpdatedAt, + FieldExpiredAt, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -45,6 +52,8 @@ func ValidColumn(column string) bool { } var ( + // DefaultStatus holds the default value on creation for the "status" field. + DefaultStatus consts.InviteCodeStatus // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -71,6 +80,11 @@ func ByCode(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCode, opts...).ToFunc() } +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() @@ -80,3 +94,8 @@ func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() } + +// ByExpiredAt orders the results by the expired_at field. +func ByExpiredAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiredAt, opts...).ToFunc() +} diff --git a/backend/db/invitecode/where.go b/backend/db/invitecode/where.go index 0503279..ad05473 100644 --- a/backend/db/invitecode/where.go +++ b/backend/db/invitecode/where.go @@ -6,6 +6,7 @@ import ( "time" "entgo.io/ent/dialect/sql" + "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db/predicate" "github.com/google/uuid" ) @@ -65,6 +66,12 @@ func Code(v string) predicate.InviteCode { return predicate.InviteCode(sql.FieldEQ(FieldCode, v)) } +// Status applies equality check predicate on the "status" field. It's identical to StatusEQ. +func Status(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldEQ(FieldStatus, vc)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.InviteCode { return predicate.InviteCode(sql.FieldEQ(FieldCreatedAt, v)) @@ -75,6 +82,11 @@ func UpdatedAt(v time.Time) predicate.InviteCode { return predicate.InviteCode(sql.FieldEQ(FieldUpdatedAt, v)) } +// ExpiredAt applies equality check predicate on the "expired_at" field. It's identical to ExpiredAtEQ. +func ExpiredAt(v time.Time) predicate.InviteCode { + return predicate.InviteCode(sql.FieldEQ(FieldExpiredAt, v)) +} + // AdminIDEQ applies the EQ predicate on the "admin_id" field. func AdminIDEQ(v uuid.UUID) predicate.InviteCode { return predicate.InviteCode(sql.FieldEQ(FieldAdminID, v)) @@ -180,6 +192,90 @@ func CodeContainsFold(v string) predicate.InviteCode { return predicate.InviteCode(sql.FieldContainsFold(FieldCode, v)) } +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldEQ(FieldStatus, vc)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldNEQ(FieldStatus, vc)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...consts.InviteCodeStatus) predicate.InviteCode { + v := make([]any, len(vs)) + for i := range v { + v[i] = string(vs[i]) + } + return predicate.InviteCode(sql.FieldIn(FieldStatus, v...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...consts.InviteCodeStatus) predicate.InviteCode { + v := make([]any, len(vs)) + for i := range v { + v[i] = string(vs[i]) + } + return predicate.InviteCode(sql.FieldNotIn(FieldStatus, v...)) +} + +// StatusGT applies the GT predicate on the "status" field. +func StatusGT(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldGT(FieldStatus, vc)) +} + +// StatusGTE applies the GTE predicate on the "status" field. +func StatusGTE(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldGTE(FieldStatus, vc)) +} + +// StatusLT applies the LT predicate on the "status" field. +func StatusLT(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldLT(FieldStatus, vc)) +} + +// StatusLTE applies the LTE predicate on the "status" field. +func StatusLTE(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldLTE(FieldStatus, vc)) +} + +// StatusContains applies the Contains predicate on the "status" field. +func StatusContains(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldContains(FieldStatus, vc)) +} + +// StatusHasPrefix applies the HasPrefix predicate on the "status" field. +func StatusHasPrefix(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldHasPrefix(FieldStatus, vc)) +} + +// StatusHasSuffix applies the HasSuffix predicate on the "status" field. +func StatusHasSuffix(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldHasSuffix(FieldStatus, vc)) +} + +// StatusEqualFold applies the EqualFold predicate on the "status" field. +func StatusEqualFold(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldEqualFold(FieldStatus, vc)) +} + +// StatusContainsFold applies the ContainsFold predicate on the "status" field. +func StatusContainsFold(v consts.InviteCodeStatus) predicate.InviteCode { + vc := string(v) + return predicate.InviteCode(sql.FieldContainsFold(FieldStatus, vc)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.InviteCode { return predicate.InviteCode(sql.FieldEQ(FieldCreatedAt, v)) @@ -260,6 +356,46 @@ func UpdatedAtLTE(v time.Time) predicate.InviteCode { return predicate.InviteCode(sql.FieldLTE(FieldUpdatedAt, v)) } +// ExpiredAtEQ applies the EQ predicate on the "expired_at" field. +func ExpiredAtEQ(v time.Time) predicate.InviteCode { + return predicate.InviteCode(sql.FieldEQ(FieldExpiredAt, v)) +} + +// ExpiredAtNEQ applies the NEQ predicate on the "expired_at" field. +func ExpiredAtNEQ(v time.Time) predicate.InviteCode { + return predicate.InviteCode(sql.FieldNEQ(FieldExpiredAt, v)) +} + +// ExpiredAtIn applies the In predicate on the "expired_at" field. +func ExpiredAtIn(vs ...time.Time) predicate.InviteCode { + return predicate.InviteCode(sql.FieldIn(FieldExpiredAt, vs...)) +} + +// ExpiredAtNotIn applies the NotIn predicate on the "expired_at" field. +func ExpiredAtNotIn(vs ...time.Time) predicate.InviteCode { + return predicate.InviteCode(sql.FieldNotIn(FieldExpiredAt, vs...)) +} + +// ExpiredAtGT applies the GT predicate on the "expired_at" field. +func ExpiredAtGT(v time.Time) predicate.InviteCode { + return predicate.InviteCode(sql.FieldGT(FieldExpiredAt, v)) +} + +// ExpiredAtGTE applies the GTE predicate on the "expired_at" field. +func ExpiredAtGTE(v time.Time) predicate.InviteCode { + return predicate.InviteCode(sql.FieldGTE(FieldExpiredAt, v)) +} + +// ExpiredAtLT applies the LT predicate on the "expired_at" field. +func ExpiredAtLT(v time.Time) predicate.InviteCode { + return predicate.InviteCode(sql.FieldLT(FieldExpiredAt, v)) +} + +// ExpiredAtLTE applies the LTE predicate on the "expired_at" field. +func ExpiredAtLTE(v time.Time) predicate.InviteCode { + return predicate.InviteCode(sql.FieldLTE(FieldExpiredAt, v)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.InviteCode) predicate.InviteCode { return predicate.InviteCode(sql.AndPredicates(predicates...)) diff --git a/backend/db/invitecode_create.go b/backend/db/invitecode_create.go index 1c249d3..758b10f 100644 --- a/backend/db/invitecode_create.go +++ b/backend/db/invitecode_create.go @@ -12,6 +12,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db/invitecode" "github.com/google/uuid" ) @@ -36,6 +37,20 @@ func (icc *InviteCodeCreate) SetCode(s string) *InviteCodeCreate { return icc } +// SetStatus sets the "status" field. +func (icc *InviteCodeCreate) SetStatus(ccs consts.InviteCodeStatus) *InviteCodeCreate { + icc.mutation.SetStatus(ccs) + return icc +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (icc *InviteCodeCreate) SetNillableStatus(ccs *consts.InviteCodeStatus) *InviteCodeCreate { + if ccs != nil { + icc.SetStatus(*ccs) + } + return icc +} + // SetCreatedAt sets the "created_at" field. func (icc *InviteCodeCreate) SetCreatedAt(t time.Time) *InviteCodeCreate { icc.mutation.SetCreatedAt(t) @@ -64,6 +79,12 @@ func (icc *InviteCodeCreate) SetNillableUpdatedAt(t *time.Time) *InviteCodeCreat return icc } +// SetExpiredAt sets the "expired_at" field. +func (icc *InviteCodeCreate) SetExpiredAt(t time.Time) *InviteCodeCreate { + icc.mutation.SetExpiredAt(t) + return icc +} + // SetID sets the "id" field. func (icc *InviteCodeCreate) SetID(u uuid.UUID) *InviteCodeCreate { icc.mutation.SetID(u) @@ -105,6 +126,10 @@ func (icc *InviteCodeCreate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (icc *InviteCodeCreate) defaults() { + if _, ok := icc.mutation.Status(); !ok { + v := invitecode.DefaultStatus + icc.mutation.SetStatus(v) + } if _, ok := icc.mutation.CreatedAt(); !ok { v := invitecode.DefaultCreatedAt() icc.mutation.SetCreatedAt(v) @@ -123,12 +148,18 @@ func (icc *InviteCodeCreate) check() error { if _, ok := icc.mutation.Code(); !ok { return &ValidationError{Name: "code", err: errors.New(`db: missing required field "InviteCode.code"`)} } + if _, ok := icc.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`db: missing required field "InviteCode.status"`)} + } if _, ok := icc.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "InviteCode.created_at"`)} } if _, ok := icc.mutation.UpdatedAt(); !ok { return &ValidationError{Name: "updated_at", err: errors.New(`db: missing required field "InviteCode.updated_at"`)} } + if _, ok := icc.mutation.ExpiredAt(); !ok { + return &ValidationError{Name: "expired_at", err: errors.New(`db: missing required field "InviteCode.expired_at"`)} + } return nil } @@ -173,6 +204,10 @@ func (icc *InviteCodeCreate) createSpec() (*InviteCode, *sqlgraph.CreateSpec) { _spec.SetField(invitecode.FieldCode, field.TypeString, value) _node.Code = value } + if value, ok := icc.mutation.Status(); ok { + _spec.SetField(invitecode.FieldStatus, field.TypeString, value) + _node.Status = value + } if value, ok := icc.mutation.CreatedAt(); ok { _spec.SetField(invitecode.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -181,6 +216,10 @@ func (icc *InviteCodeCreate) createSpec() (*InviteCode, *sqlgraph.CreateSpec) { _spec.SetField(invitecode.FieldUpdatedAt, field.TypeTime, value) _node.UpdatedAt = value } + if value, ok := icc.mutation.ExpiredAt(); ok { + _spec.SetField(invitecode.FieldExpiredAt, field.TypeTime, value) + _node.ExpiredAt = value + } return _node, _spec } @@ -257,6 +296,18 @@ func (u *InviteCodeUpsert) UpdateCode() *InviteCodeUpsert { return u } +// SetStatus sets the "status" field. +func (u *InviteCodeUpsert) SetStatus(v consts.InviteCodeStatus) *InviteCodeUpsert { + u.Set(invitecode.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *InviteCodeUpsert) UpdateStatus() *InviteCodeUpsert { + u.SetExcluded(invitecode.FieldStatus) + return u +} + // SetCreatedAt sets the "created_at" field. func (u *InviteCodeUpsert) SetCreatedAt(v time.Time) *InviteCodeUpsert { u.Set(invitecode.FieldCreatedAt, v) @@ -281,6 +332,18 @@ func (u *InviteCodeUpsert) UpdateUpdatedAt() *InviteCodeUpsert { return u } +// SetExpiredAt sets the "expired_at" field. +func (u *InviteCodeUpsert) SetExpiredAt(v time.Time) *InviteCodeUpsert { + u.Set(invitecode.FieldExpiredAt, v) + return u +} + +// UpdateExpiredAt sets the "expired_at" field to the value that was provided on create. +func (u *InviteCodeUpsert) UpdateExpiredAt() *InviteCodeUpsert { + u.SetExcluded(invitecode.FieldExpiredAt) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. // Using this option is equivalent to using: // @@ -357,6 +420,20 @@ func (u *InviteCodeUpsertOne) UpdateCode() *InviteCodeUpsertOne { }) } +// SetStatus sets the "status" field. +func (u *InviteCodeUpsertOne) SetStatus(v consts.InviteCodeStatus) *InviteCodeUpsertOne { + return u.Update(func(s *InviteCodeUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *InviteCodeUpsertOne) UpdateStatus() *InviteCodeUpsertOne { + return u.Update(func(s *InviteCodeUpsert) { + s.UpdateStatus() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *InviteCodeUpsertOne) SetCreatedAt(v time.Time) *InviteCodeUpsertOne { return u.Update(func(s *InviteCodeUpsert) { @@ -385,6 +462,20 @@ func (u *InviteCodeUpsertOne) UpdateUpdatedAt() *InviteCodeUpsertOne { }) } +// SetExpiredAt sets the "expired_at" field. +func (u *InviteCodeUpsertOne) SetExpiredAt(v time.Time) *InviteCodeUpsertOne { + return u.Update(func(s *InviteCodeUpsert) { + s.SetExpiredAt(v) + }) +} + +// UpdateExpiredAt sets the "expired_at" field to the value that was provided on create. +func (u *InviteCodeUpsertOne) UpdateExpiredAt() *InviteCodeUpsertOne { + return u.Update(func(s *InviteCodeUpsert) { + s.UpdateExpiredAt() + }) +} + // Exec executes the query. func (u *InviteCodeUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -628,6 +719,20 @@ func (u *InviteCodeUpsertBulk) UpdateCode() *InviteCodeUpsertBulk { }) } +// SetStatus sets the "status" field. +func (u *InviteCodeUpsertBulk) SetStatus(v consts.InviteCodeStatus) *InviteCodeUpsertBulk { + return u.Update(func(s *InviteCodeUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *InviteCodeUpsertBulk) UpdateStatus() *InviteCodeUpsertBulk { + return u.Update(func(s *InviteCodeUpsert) { + s.UpdateStatus() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *InviteCodeUpsertBulk) SetCreatedAt(v time.Time) *InviteCodeUpsertBulk { return u.Update(func(s *InviteCodeUpsert) { @@ -656,6 +761,20 @@ func (u *InviteCodeUpsertBulk) UpdateUpdatedAt() *InviteCodeUpsertBulk { }) } +// SetExpiredAt sets the "expired_at" field. +func (u *InviteCodeUpsertBulk) SetExpiredAt(v time.Time) *InviteCodeUpsertBulk { + return u.Update(func(s *InviteCodeUpsert) { + s.SetExpiredAt(v) + }) +} + +// UpdateExpiredAt sets the "expired_at" field to the value that was provided on create. +func (u *InviteCodeUpsertBulk) UpdateExpiredAt() *InviteCodeUpsertBulk { + return u.Update(func(s *InviteCodeUpsert) { + s.UpdateExpiredAt() + }) +} + // Exec executes the query. func (u *InviteCodeUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/db/invitecode_update.go b/backend/db/invitecode_update.go index 6862245..0ab209f 100644 --- a/backend/db/invitecode_update.go +++ b/backend/db/invitecode_update.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db/invitecode" "github.com/chaitin/MonkeyCode/backend/db/predicate" "github.com/google/uuid" @@ -58,6 +59,20 @@ func (icu *InviteCodeUpdate) SetNillableCode(s *string) *InviteCodeUpdate { return icu } +// SetStatus sets the "status" field. +func (icu *InviteCodeUpdate) SetStatus(ccs consts.InviteCodeStatus) *InviteCodeUpdate { + icu.mutation.SetStatus(ccs) + return icu +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (icu *InviteCodeUpdate) SetNillableStatus(ccs *consts.InviteCodeStatus) *InviteCodeUpdate { + if ccs != nil { + icu.SetStatus(*ccs) + } + return icu +} + // SetCreatedAt sets the "created_at" field. func (icu *InviteCodeUpdate) SetCreatedAt(t time.Time) *InviteCodeUpdate { icu.mutation.SetCreatedAt(t) @@ -78,6 +93,20 @@ func (icu *InviteCodeUpdate) SetUpdatedAt(t time.Time) *InviteCodeUpdate { return icu } +// SetExpiredAt sets the "expired_at" field. +func (icu *InviteCodeUpdate) SetExpiredAt(t time.Time) *InviteCodeUpdate { + icu.mutation.SetExpiredAt(t) + return icu +} + +// SetNillableExpiredAt sets the "expired_at" field if the given value is not nil. +func (icu *InviteCodeUpdate) SetNillableExpiredAt(t *time.Time) *InviteCodeUpdate { + if t != nil { + icu.SetExpiredAt(*t) + } + return icu +} + // Mutation returns the InviteCodeMutation object of the builder. func (icu *InviteCodeUpdate) Mutation() *InviteCodeMutation { return icu.mutation @@ -140,12 +169,18 @@ func (icu *InviteCodeUpdate) sqlSave(ctx context.Context) (n int, err error) { if value, ok := icu.mutation.Code(); ok { _spec.SetField(invitecode.FieldCode, field.TypeString, value) } + if value, ok := icu.mutation.Status(); ok { + _spec.SetField(invitecode.FieldStatus, field.TypeString, value) + } if value, ok := icu.mutation.CreatedAt(); ok { _spec.SetField(invitecode.FieldCreatedAt, field.TypeTime, value) } if value, ok := icu.mutation.UpdatedAt(); ok { _spec.SetField(invitecode.FieldUpdatedAt, field.TypeTime, value) } + if value, ok := icu.mutation.ExpiredAt(); ok { + _spec.SetField(invitecode.FieldExpiredAt, field.TypeTime, value) + } _spec.AddModifiers(icu.modifiers...) if n, err = sqlgraph.UpdateNodes(ctx, icu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -196,6 +231,20 @@ func (icuo *InviteCodeUpdateOne) SetNillableCode(s *string) *InviteCodeUpdateOne return icuo } +// SetStatus sets the "status" field. +func (icuo *InviteCodeUpdateOne) SetStatus(ccs consts.InviteCodeStatus) *InviteCodeUpdateOne { + icuo.mutation.SetStatus(ccs) + return icuo +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (icuo *InviteCodeUpdateOne) SetNillableStatus(ccs *consts.InviteCodeStatus) *InviteCodeUpdateOne { + if ccs != nil { + icuo.SetStatus(*ccs) + } + return icuo +} + // SetCreatedAt sets the "created_at" field. func (icuo *InviteCodeUpdateOne) SetCreatedAt(t time.Time) *InviteCodeUpdateOne { icuo.mutation.SetCreatedAt(t) @@ -216,6 +265,20 @@ func (icuo *InviteCodeUpdateOne) SetUpdatedAt(t time.Time) *InviteCodeUpdateOne return icuo } +// SetExpiredAt sets the "expired_at" field. +func (icuo *InviteCodeUpdateOne) SetExpiredAt(t time.Time) *InviteCodeUpdateOne { + icuo.mutation.SetExpiredAt(t) + return icuo +} + +// SetNillableExpiredAt sets the "expired_at" field if the given value is not nil. +func (icuo *InviteCodeUpdateOne) SetNillableExpiredAt(t *time.Time) *InviteCodeUpdateOne { + if t != nil { + icuo.SetExpiredAt(*t) + } + return icuo +} + // Mutation returns the InviteCodeMutation object of the builder. func (icuo *InviteCodeUpdateOne) Mutation() *InviteCodeMutation { return icuo.mutation @@ -308,12 +371,18 @@ func (icuo *InviteCodeUpdateOne) sqlSave(ctx context.Context) (_node *InviteCode if value, ok := icuo.mutation.Code(); ok { _spec.SetField(invitecode.FieldCode, field.TypeString, value) } + if value, ok := icuo.mutation.Status(); ok { + _spec.SetField(invitecode.FieldStatus, field.TypeString, value) + } if value, ok := icuo.mutation.CreatedAt(); ok { _spec.SetField(invitecode.FieldCreatedAt, field.TypeTime, value) } if value, ok := icuo.mutation.UpdatedAt(); ok { _spec.SetField(invitecode.FieldUpdatedAt, field.TypeTime, value) } + if value, ok := icuo.mutation.ExpiredAt(); ok { + _spec.SetField(invitecode.FieldExpiredAt, field.TypeTime, value) + } _spec.AddModifiers(icuo.modifiers...) _node = &InviteCode{config: icuo.config} _spec.Assign = _node.assignValues diff --git a/backend/db/migrate/schema.go b/backend/db/migrate/schema.go index 0c15719..33faaab 100644 --- a/backend/db/migrate/schema.go +++ b/backend/db/migrate/schema.go @@ -159,8 +159,10 @@ var ( {Name: "id", Type: field.TypeUUID}, {Name: "admin_id", Type: field.TypeUUID}, {Name: "code", Type: field.TypeString, Unique: true}, + {Name: "status", Type: field.TypeString, Default: "pending"}, {Name: "created_at", Type: field.TypeTime}, {Name: "updated_at", Type: field.TypeTime}, + {Name: "expired_at", Type: field.TypeTime}, } // InviteCodesTable holds the schema information for the "invite_codes" table. InviteCodesTable = &schema.Table{ diff --git a/backend/db/mutation.go b/backend/db/mutation.go index 5d0b18e..0140a20 100644 --- a/backend/db/mutation.go +++ b/backend/db/mutation.go @@ -5785,8 +5785,10 @@ type InviteCodeMutation struct { id *uuid.UUID admin_id *uuid.UUID code *string + status *consts.InviteCodeStatus created_at *time.Time updated_at *time.Time + expired_at *time.Time clearedFields map[string]struct{} done bool oldValue func(context.Context) (*InviteCode, error) @@ -5969,6 +5971,42 @@ func (m *InviteCodeMutation) ResetCode() { m.code = nil } +// SetStatus sets the "status" field. +func (m *InviteCodeMutation) SetStatus(ccs consts.InviteCodeStatus) { + m.status = &ccs +} + +// Status returns the value of the "status" field in the mutation. +func (m *InviteCodeMutation) Status() (r consts.InviteCodeStatus, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the InviteCode entity. +// If the InviteCode object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *InviteCodeMutation) OldStatus(ctx context.Context) (v consts.InviteCodeStatus, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *InviteCodeMutation) ResetStatus() { + m.status = nil +} + // SetCreatedAt sets the "created_at" field. func (m *InviteCodeMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -6041,6 +6079,42 @@ func (m *InviteCodeMutation) ResetUpdatedAt() { m.updated_at = nil } +// SetExpiredAt sets the "expired_at" field. +func (m *InviteCodeMutation) SetExpiredAt(t time.Time) { + m.expired_at = &t +} + +// ExpiredAt returns the value of the "expired_at" field in the mutation. +func (m *InviteCodeMutation) ExpiredAt() (r time.Time, exists bool) { + v := m.expired_at + if v == nil { + return + } + return *v, true +} + +// OldExpiredAt returns the old "expired_at" field's value of the InviteCode entity. +// If the InviteCode object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *InviteCodeMutation) OldExpiredAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiredAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiredAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiredAt: %w", err) + } + return oldValue.ExpiredAt, nil +} + +// ResetExpiredAt resets all changes to the "expired_at" field. +func (m *InviteCodeMutation) ResetExpiredAt() { + m.expired_at = nil +} + // Where appends a list predicates to the InviteCodeMutation builder. func (m *InviteCodeMutation) Where(ps ...predicate.InviteCode) { m.predicates = append(m.predicates, ps...) @@ -6075,19 +6149,25 @@ func (m *InviteCodeMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *InviteCodeMutation) Fields() []string { - fields := make([]string, 0, 4) + fields := make([]string, 0, 6) if m.admin_id != nil { fields = append(fields, invitecode.FieldAdminID) } if m.code != nil { fields = append(fields, invitecode.FieldCode) } + if m.status != nil { + fields = append(fields, invitecode.FieldStatus) + } if m.created_at != nil { fields = append(fields, invitecode.FieldCreatedAt) } if m.updated_at != nil { fields = append(fields, invitecode.FieldUpdatedAt) } + if m.expired_at != nil { + fields = append(fields, invitecode.FieldExpiredAt) + } return fields } @@ -6100,10 +6180,14 @@ func (m *InviteCodeMutation) Field(name string) (ent.Value, bool) { return m.AdminID() case invitecode.FieldCode: return m.Code() + case invitecode.FieldStatus: + return m.Status() case invitecode.FieldCreatedAt: return m.CreatedAt() case invitecode.FieldUpdatedAt: return m.UpdatedAt() + case invitecode.FieldExpiredAt: + return m.ExpiredAt() } return nil, false } @@ -6117,10 +6201,14 @@ func (m *InviteCodeMutation) OldField(ctx context.Context, name string) (ent.Val return m.OldAdminID(ctx) case invitecode.FieldCode: return m.OldCode(ctx) + case invitecode.FieldStatus: + return m.OldStatus(ctx) case invitecode.FieldCreatedAt: return m.OldCreatedAt(ctx) case invitecode.FieldUpdatedAt: return m.OldUpdatedAt(ctx) + case invitecode.FieldExpiredAt: + return m.OldExpiredAt(ctx) } return nil, fmt.Errorf("unknown InviteCode field %s", name) } @@ -6144,6 +6232,13 @@ func (m *InviteCodeMutation) SetField(name string, value ent.Value) error { } m.SetCode(v) return nil + case invitecode.FieldStatus: + v, ok := value.(consts.InviteCodeStatus) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil case invitecode.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -6158,6 +6253,13 @@ func (m *InviteCodeMutation) SetField(name string, value ent.Value) error { } m.SetUpdatedAt(v) return nil + case invitecode.FieldExpiredAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiredAt(v) + return nil } return fmt.Errorf("unknown InviteCode field %s", name) } @@ -6213,12 +6315,18 @@ func (m *InviteCodeMutation) ResetField(name string) error { case invitecode.FieldCode: m.ResetCode() return nil + case invitecode.FieldStatus: + m.ResetStatus() + return nil case invitecode.FieldCreatedAt: m.ResetCreatedAt() return nil case invitecode.FieldUpdatedAt: m.ResetUpdatedAt() return nil + case invitecode.FieldExpiredAt: + m.ResetExpiredAt() + return nil } return fmt.Errorf("unknown InviteCode field %s", name) } diff --git a/backend/db/runtime/runtime.go b/backend/db/runtime/runtime.go index 246e327..c1ba6bb 100644 --- a/backend/db/runtime/runtime.go +++ b/backend/db/runtime/runtime.go @@ -140,12 +140,16 @@ func init() { extension.DefaultCreatedAt = extensionDescCreatedAt.Default.(func() time.Time) invitecodeFields := schema.InviteCode{}.Fields() _ = invitecodeFields + // invitecodeDescStatus is the schema descriptor for status field. + invitecodeDescStatus := invitecodeFields[3].Descriptor() + // invitecode.DefaultStatus holds the default value on creation for the status field. + invitecode.DefaultStatus = consts.InviteCodeStatus(invitecodeDescStatus.Default.(string)) // invitecodeDescCreatedAt is the schema descriptor for created_at field. - invitecodeDescCreatedAt := invitecodeFields[3].Descriptor() + invitecodeDescCreatedAt := invitecodeFields[4].Descriptor() // invitecode.DefaultCreatedAt holds the default value on creation for the created_at field. invitecode.DefaultCreatedAt = invitecodeDescCreatedAt.Default.(func() time.Time) // invitecodeDescUpdatedAt is the schema descriptor for updated_at field. - invitecodeDescUpdatedAt := invitecodeFields[4].Descriptor() + invitecodeDescUpdatedAt := invitecodeFields[5].Descriptor() // invitecode.DefaultUpdatedAt holds the default value on creation for the updated_at field. invitecode.DefaultUpdatedAt = invitecodeDescUpdatedAt.Default.(func() time.Time) // invitecode.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. diff --git a/backend/ent/schema/invitecode.go b/backend/ent/schema/invitecode.go index c266e71..dc141cf 100644 --- a/backend/ent/schema/invitecode.go +++ b/backend/ent/schema/invitecode.go @@ -7,6 +7,7 @@ import ( "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/field" + "github.com/chaitin/MonkeyCode/backend/consts" "github.com/google/uuid" ) @@ -29,8 +30,10 @@ func (InviteCode) Fields() []ent.Field { field.UUID("id", uuid.UUID{}), field.UUID("admin_id", uuid.UUID{}), field.String("code").Unique(), + field.String("status").GoType(consts.InviteCodeStatus("")).Default(string(consts.InviteCodeStatusPending)), field.Time("created_at").Default(time.Now), field.Time("updated_at").Default(time.Now).UpdateDefault(time.Now), + field.Time("expired_at"), } } diff --git a/backend/internal/user/repo/user.go b/backend/internal/user/repo/user.go index 34d46d4..2706e98 100644 --- a/backend/internal/user/repo/user.go +++ b/backend/internal/user/repo/user.go @@ -3,6 +3,7 @@ package repo import ( "context" "errors" + "time" "github.com/google/uuid" @@ -65,7 +66,32 @@ func (r *UserRepo) GetByName(ctx context.Context, username string) (*db.User, er } func (r *UserRepo) ValidateInviteCode(ctx context.Context, code string) (*db.InviteCode, error) { - return r.db.InviteCode.Query().Where(invitecode.Code(code)).Only(ctx) + var res *db.InviteCode + err := entx.WithTx(ctx, r.db, func(tx *db.Tx) error { + ic, err := tx.InviteCode.Query().Where(invitecode.Code(code)).Only(ctx) + if err != nil { + return err + } + + if ic.ExpiredAt.Before(time.Now()) { + return errors.New("invite code has expired") + } + if ic.Status == consts.InviteCodeStatusUsed { + return errors.New("invite code has been used") + } + + ic, err = tx.InviteCode.UpdateOneID(ic.ID). + SetStatus(consts.InviteCodeStatusUsed). + Save(ctx) + + if err != nil { + return err + } + + res = ic + return nil + }) + return res, err } func (r *UserRepo) CreateUser(ctx context.Context, user *db.User) (*db.User, error) { @@ -97,6 +123,8 @@ func (r *UserRepo) CreateInviteCode(ctx context.Context, userID string, code str return r.db.InviteCode.Create(). SetAdminID(adminID). SetCode(code). + SetStatus(consts.InviteCodeStatusPending). + SetExpiredAt(time.Now().Add(15 * time.Minute)). Save(ctx) } diff --git a/backend/migration/000005_alter_invite_codes_table.down.sql b/backend/migration/000005_alter_invite_codes_table.down.sql new file mode 100644 index 0000000..e69de29 diff --git a/backend/migration/000005_alter_invite_codes_table.up.sql b/backend/migration/000005_alter_invite_codes_table.up.sql new file mode 100644 index 0000000..2ca4e86 --- /dev/null +++ b/backend/migration/000005_alter_invite_codes_table.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE invite_codes ADD COLUMN status VARCHAR(255); +UPDATE invite_codes SET status = 'used' WHERE status IS NULL; +ALTER TABLE invite_codes ADD COLUMN expired_at TIMESTAMPTZ; \ No newline at end of file