Skip to content

Commit

Permalink
Key off incoming host instead of user entry.
Browse files Browse the repository at this point in the history
  • Loading branch information
joel-rieke committed Mar 7, 2024
1 parent a6f6d95 commit a6e07b1
Showing 1 changed file with 18 additions and 22 deletions.
40 changes: 18 additions & 22 deletions sql/mysql_db/mysql_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,18 +243,26 @@ type LockUserMap struct {
sync.Map
}

func (m *LockUserMap) SetUser(readUserEntry *User, value time.Time) {
m.Set(fmt.Sprintf("%s-%s", readUserEntry.User, readUserEntry.Host), value)
func (m *LockUserMap) lockUser(user string, host string, value time.Time) {
m.Set(fmt.Sprintf("%s-%s", user, host), value)
}

func (m *LockUserMap) Set(key string, value time.Time) {
m.Store(key, value)
}

func (m *LockUserMap) GetUser(readUserEntry *User) (time.Time, bool) {
func (m *LockUserMap) GetUser(readUserEntry *User, host string) (time.Time, bool) {
return m.Get(fmt.Sprintf("%s-%s", readUserEntry.User, readUserEntry.Host))
}

func (db *MySQLDb) AddUser(readUserEntry *User, host string) {
if readUserEntry == nil {
return
} else {
lockUserMap.lockUser(readUserEntry.User, host, time.Now())
}
}

func (m *LockUserMap) Get(key string) (time.Time, bool) {
val, ok := m.Load(key)
if !ok {
Expand All @@ -263,24 +271,12 @@ func (m *LockUserMap) Get(key string) (time.Time, bool) {
return val.(time.Time), true
}

func (m *LockUserMap) RemoveUser(readUserEntry *User) {
m.Delete(fmt.Sprintf("%s-%s", readUserEntry.User, readUserEntry.Host))
}

func (m *LockUserMap) Remove(key string) {
m.Delete(key)
func (m *LockUserMap) RemoveUser(readUserEntry *User, host string) {
m.Delete(fmt.Sprintf("%s-%s", readUserEntry.User, host))
}

var lockUserMap = &LockUserMap{}

func (db *MySQLDb) LockUser(readUserEntry *User) {
if readUserEntry == nil {
return
} else {
lockUserMap.SetUser(readUserEntry, time.Now())
}
}

// GetUser returns a user matching the given user and host if it exists. Due to the slight difference between users and
// roles, roleSearch changes whether the search matches against user or role rules.
func (db *MySQLDb) GetUser(user string, host string, roleSearch bool, skipCidrChecks bool) *User {
Expand All @@ -307,10 +303,10 @@ func (db *MySQLDb) GetUser(user string, host string, roleSearch bool, skipCidrCh
return readUserEntry
}

if lockTime, isLocked := lockUserMap.GetUser(readUserEntry); isLocked {
if lockTime, isLocked := lockUserMap.GetUser(readUserEntry, host); isLocked {
if time.Since(lockTime) > time.Hour {
readUserEntry.Locked = false
lockUserMap.RemoveUser(readUserEntry)
lockUserMap.RemoveUser(readUserEntry, host)
} else {
readUserEntry.Locked = true
return readUserEntry
Expand Down Expand Up @@ -344,10 +340,10 @@ func (db *MySQLDb) GetUser(user string, host string, roleSearch bool, skipCidrCh
return readUserEntry
}

if lockTime, isLocked := lockUserMap.GetUser(readUserEntry); isLocked {
if lockTime, isLocked := lockUserMap.GetUser(readUserEntry, host); isLocked {
if time.Since(lockTime) > time.Hour {
readUserEntry.Locked = false
lockUserMap.RemoveUser(readUserEntry)
lockUserMap.RemoveUser(readUserEntry, host)
} else {
readUserEntry.Locked = true
return readUserEntry
Expand Down Expand Up @@ -530,7 +526,7 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
}
if len(userEntry.Password) > 0 {
if !validateMysqlNativePassword(authResponse, salt, userEntry.Password) {
db.LockUser(userEntry)
db.AddUser(userEntry, host)
return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}
} else if len(authResponse) > 0 { // password is nil or empty, therefore no password is set
Expand Down

0 comments on commit a6e07b1

Please sign in to comment.