diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index 3a58e77fb072..d7b18fe0b359 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -64,8 +64,7 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) { for (const auto& [username, user] : registry) { std::string buffer = "user "; - const std::string_view pass = user.Password(); - const std::string password = pass == "nopass" ? "nopass" : PrettyPrintSha(pass); + const std::string password = PasswordsToString(user.Passwords(), user.HasNopass(), false); const std::string acl_keys = AclKeysToString(user.Keys()); const std::string maybe_space_com = acl_keys.empty() ? "" : " "; @@ -75,7 +74,7 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) { using namespace std::string_view_literals; - absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, " ", + absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, acl_keys, maybe_space_com, acl_cat_and_commands); cntx->SendSimpleString(buffer); @@ -196,9 +195,7 @@ std::string AclFamily::RegistryToString() const { std::string result; for (auto& [username, user] : registry) { std::string command = "USER "; - const std::string_view pass = user.Password(); - const std::string password = - pass == "nopass" ? "nopass " : absl::StrCat("#", PrettyPrintSha(pass, true), " "); + const std::string password = PasswordsToString(user.Passwords(), user.HasNopass(), true); const std::string acl_keys = AclKeysToString(user.Keys()); const std::string maybe_space = acl_keys.empty() ? "" : " "; @@ -495,7 +492,10 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) { } auto& user = registry.find(username)->second; std::string status = user.IsActive() ? "on" : "off"; - auto pass = user.Password(); + auto pass = PasswordsToString(user.Passwords(), user.HasNopass(), false); + if (!pass.empty()) { + pass.pop_back(); + } auto* rb = static_cast(cntx->reply_builder()); rb->StartArray(8); @@ -509,7 +509,7 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) { } rb->SendSimpleString("passwords"); - if (pass != "nopass") { + if (pass != "nopass" && !pass.empty()) { rb->SendSimpleString(pass); } else { rb->SendEmptyArray(); @@ -647,7 +647,7 @@ void AclFamily::Init(facade::Listener* main_listener, UserRegistry* registry) { registry_ = registry; config_registry.RegisterMutable("requirepass", [this](const absl::CommandLineFlag& flag) { User::UpdateRequest rqst; - rqst.password = flag.CurrentValue(); + rqst.passwords.push_back({flag.CurrentValue()}); registry_->MaybeAddAndUpdate("default", std::move(rqst)); return true; }); diff --git a/src/server/acl/acl_family_test.cc b/src/server/acl/acl_family_test.cc index 0aac28dd3f1e..cbee4061fe18 100644 --- a/src/server/acl/acl_family_test.cc +++ b/src/server/acl/acl_family_test.cc @@ -47,16 +47,67 @@ TEST_F(AclFamilyTest, AclSetUser) { EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); auto vec = resp.GetVec(); - EXPECT_THAT( - vec, UnorderedElementsAre("user default on nopass ~* +@all", "user vlad off nopass -@all")); + EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* +@all", "user vlad off -@all")); resp = Run({"ACL", "SETUSER", "vlad", "+ACL"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); vec = resp.GetVec(); - EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* +@all", - "user vlad off nopass -@all +acl")); + EXPECT_THAT(vec, + UnorderedElementsAre("user default on nopass ~* +@all", "user vlad off -@all +acl")); + + resp = Run({"ACL", "SETUSER", "vlad", "on", ">pass", ">temp"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"ACL", "LIST"}); + vec = resp.GetVec(); + EXPECT_THAT(vec.size(), 2); + auto contains_vlad = [](const auto& vec) { + const std::string default_user = "user default on nopass ~* +@all"; + const std::string a_permutation = "user vlad on #a6864eb339b0e1f #d74ff0ee8da3b98 -@all +acl"; + const std::string b_permutation = "user vlad on #d74ff0ee8da3b98 #a6864eb339b0e1f -@all +acl"; + std::string_view other; + if (vec[0] == default_user) { + other = vec[1].GetView(); + } else if (vec[1] == default_user) { + other = vec[0].GetView(); + } else { + return false; + } + + return other == a_permutation || other == b_permutation; + }; + + EXPECT_THAT(contains_vlad(vec), true); + + resp = Run({"AUTH", "vlad", "pass"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"AUTH", "vlad", "temp"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"AUTH", "default", R"("")"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"ACL", "SETUSER", "vlad", ">another"}); + EXPECT_THAT(resp, "OK"); + + resp = Run({"ACL", "SETUSER", "vlad", " MaybeParseAclKey(std::string_view command) { return ParseKeyResult{std::string(key), op}; } -std::optional MaybeParsePassword(std::string_view command, bool hashed) { +std::optional MaybeParsePassword(std::string_view command, bool hashed) { + using UpPass = User::UpdatePass; if (command == "nopass") { - return std::string(command); + return UpPass{"", false, true}; + } + + if (command == "resetpass") { + return UpPass{"", false, false, true}; } if (command[0] == '>' || (hashed && command[0] == '#')) { - return std::string(command.substr(1)); + return UpPass{std::string(command.substr(1))}; + } + + if (command[0] == '<') { + return UpPass{std::string(command.substr(1)), true}; } return {}; @@ -261,10 +270,8 @@ std::variant ParseAclSetUser(facade::ArgRange a for (std::string_view arg : args) { if (auto pass = MaybeParsePassword(facade::ToSV(arg), hashed); pass) { - if (req.password) { - return ErrorReply("Only one password is allowed"); - } - req.password = std::move(pass); + req.passwords.push_back(std::move(*pass)); + if (hashed && absl::StartsWith(facade::ToSV(arg), "#")) { req.is_hashed = hashed; } @@ -346,4 +353,16 @@ std::string AclKeysToString(const AclKeys& keys) { return result; } +std::string PasswordsToString(const absl::flat_hash_set& passwords, bool nopass, + bool full_sha) { + if (nopass) { + return "nopass "; + } + std::string result; + for (const auto& pass : passwords) { + absl::StrAppend(&result, "#", PrettyPrintSha(pass, full_sha), " "); + } + + return result; +} } // namespace dfly::acl diff --git a/src/server/acl/helpers.h b/src/server/acl/helpers.h index 0840ab817e70..75cbd4d8b491 100644 --- a/src/server/acl/helpers.h +++ b/src/server/acl/helpers.h @@ -10,6 +10,7 @@ #include #include +#include "absl/container/flat_hash_set.h" #include "facade/facade_types.h" #include "server/acl/acl_log.h" #include "server/acl/user.h" @@ -23,7 +24,7 @@ std::string AclCatAndCommandToString(const User::CategoryChanges& cat, std::string PrettyPrintSha(std::string_view pass, bool all = false); // When hashed is true, we allow passwords that start with both # and > -std::optional MaybeParsePassword(std::string_view command, bool hashed = false); +std::optional MaybeParsePassword(std::string_view command, bool hashed = false); std::optional MaybeParseStatus(std::string_view command); @@ -55,4 +56,8 @@ struct ParseKeyResult { std::optional MaybeParseAclKey(std::string_view command); std::string AclKeysToString(const AclKeys& keys); + +std::string PasswordsToString(const absl::flat_hash_set& passwords, bool nopass, + bool full_sha); + } // namespace dfly::acl diff --git a/src/server/acl/user.cc b/src/server/acl/user.cc index 863d9857f81e..48a3be8b326b 100644 --- a/src/server/acl/user.cc +++ b/src/server/acl/user.cc @@ -8,6 +8,7 @@ #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/escaping.h" #include "core/overloaded.h" #include "server/acl/helpers.h" @@ -30,8 +31,20 @@ User::User() { } void User::Update(UpdateRequest&& req) { - if (req.password) { - SetPasswordHash(*req.password, req.is_hashed); + for (auto& pass : req.passwords) { + if (pass.nopass) { + SetNopass(); + continue; + } + if (pass.unset) { + UnsetPassword(pass.password); + continue; + } + if (pass.reset_password) { + password_hashes_.clear(); + continue; + } + SetPasswordHash(pass.password, req.is_hashed); } auto cat_visitor = [this](UpdateRequest::CategoryValueType cat) { @@ -68,23 +81,23 @@ void User::Update(UpdateRequest&& req) { } void User::SetPasswordHash(std::string_view password, bool is_hashed) { - if (password == "nopass") { - return; - } - + nopass_ = false; if (is_hashed) { - password_hash_ = absl::HexStringToBytes(password); + password_hashes_.insert(absl::HexStringToBytes(password)); return; } - password_hash_ = StringSHA256(password); + password_hashes_.insert(StringSHA256(password)); +} + +void User::UnsetPassword(std::string_view password) { + password_hashes_.erase(StringSHA256(password)); } bool User::HasPassword(std::string_view password) const { - if (!password_hash_) { + if (nopass_) { return true; } - // hash password and compare - return *password_hash_ == StringSHA256(password); + return password_hashes_.contains(StringSHA256(password)); } void User::SetAclCategoriesAndIncrSeq(uint32_t cat) { @@ -174,10 +187,12 @@ bool User::IsActive() const { return is_active_; } -static const std::string_view default_pass = "nopass"; +const absl::flat_hash_set& User::Passwords() const { + return password_hashes_; +} -std::string_view User::Password() const { - return password_hash_ ? *password_hash_ : default_pass; +bool User::HasNopass() const { + return nopass_; } const AclKeys& User::Keys() const { @@ -206,4 +221,9 @@ void User::SetKeyGlobs(std::vector keys) { } } +void User::SetNopass() { + nopass_ = true; + password_hashes_.clear(); +} + } // namespace dfly::acl diff --git a/src/server/acl/user.h b/src/server/acl/user.h index fd3e84a3ff76..3e66491f08c9 100644 --- a/src/server/acl/user.h +++ b/src/server/acl/user.h @@ -14,6 +14,7 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/hash/hash.h" #include "server/acl/acl_commands_def.h" @@ -30,8 +31,16 @@ class User final { bool reset_keys = false; }; + struct UpdatePass { + std::string password; + // Set to denote remove password + bool unset{false}; + bool nopass{false}; + bool reset_password{false}; + }; + struct UpdateRequest { - std::optional password{}; + std::vector passwords; std::optional is_active{}; @@ -48,6 +57,8 @@ class User final { std::vector keys; bool reset_all_keys{false}; bool allow_all_keys{false}; + // TODO allow reset all + // bool reset_all{false}; }; using CategoryChange = uint32_t; @@ -80,7 +91,9 @@ class User final { bool IsActive() const; - std::string_view Password() const; + const absl::flat_hash_set& Passwords() const; + + bool HasNopass() const; // Selector maps a command string (like HSET, SET etc) to // its respective ID within the commands vector. @@ -111,13 +124,19 @@ class User final { // For passwords void SetPasswordHash(std::string_view password, bool is_hashed); + void UnsetPassword(std::string_view password); // For ACL key globs void SetKeyGlobs(std::vector keys); - // when optional is empty, the special `nopass` password is implied - // password hashed with xx64 - std::optional password_hash_; + // Set NOPASS and remove all passwords + void SetNopass(); + + // Passwords for each user + absl::flat_hash_set password_hashes_; + // if `nopass` is used + bool nopass_ = false; + uint32_t acl_categories_{NONE}; // Each element index in the vector corresponds to a familly of commands // Each bit in the uin64_t field at index id, corresponds to a specific diff --git a/src/server/acl/user_registry.cc b/src/server/acl/user_registry.cc index 54510344e86c..9bd9645bff6c 100644 --- a/src/server/acl/user_registry.cc +++ b/src/server/acl/user_registry.cc @@ -75,7 +75,8 @@ UserRegistry::UserWithWriteLock::UserWithWriteLock(std::unique_lock acl{User::Sign::PLUS, acl::ALL}; auto key = User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}; - return {{}, true, false, {std::move(acl)}, {std::move(key)}}; + auto pass = std::vector{{"", false, true}}; + return {std::move(pass), true, false, {std::move(acl)}, {std::move(key)}}; } void UserRegistry::Init() { @@ -86,11 +87,14 @@ void UserRegistry::Init() { auto default_user = DefaultUserUpdateRequest(); auto maybe_password = absl::GetFlag(FLAGS_requirepass); if (!maybe_password.empty()) { - default_user.password = std::move(maybe_password); + default_user.passwords.front().password = std::move(maybe_password); + default_user.passwords.front().nopass = false; } else if (const char* env_var = getenv("DFLY_PASSWORD"); env_var) { - default_user.password = env_var; + default_user.passwords.front().password = env_var; + default_user.passwords.front().nopass = false; } else if (const char* env_var = getenv("DFLY_requirepass"); env_var) { - default_user.password = env_var; + default_user.passwords.front().password = env_var; + default_user.passwords.front().nopass = false; } MaybeAddAndUpdate("default", std::move(default_user)); } diff --git a/tests/dragonfly/acl_family_test.py b/tests/dragonfly/acl_family_test.py index f044269e91c7..efc17787205a 100644 --- a/tests/dragonfly/acl_family_test.py +++ b/tests/dragonfly/acl_family_test.py @@ -14,53 +14,53 @@ async def test_acl_setuser(async_client): await async_client.execute_command("ACL SETUSER kostas") result = await async_client.execute_command("ACL list") assert 2 == len(result) - assert "user kostas off nopass -@all" in result + assert "user kostas off -@all" in result await async_client.execute_command("ACL SETUSER kostas ON") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@all" in result + assert "user kostas on -@all" in result await async_client.execute_command("ACL SETUSER kostas +@list +@string +@admin") result = await async_client.execute_command("ACL list") # TODO consider printing to lowercase - assert "user kostas on nopass -@all +@list +@string +@admin" in result + assert "user kostas on -@all +@list +@string +@admin" in result await async_client.execute_command("ACL SETUSER kostas -@list -@admin") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@all +@string -@list -@admin" in result + assert "user kostas on -@all +@string -@list -@admin" in result # mix and match await async_client.execute_command("ACL SETUSER kostas +@list -@string") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@all -@admin +@list -@string" in result + assert "user kostas on -@all -@admin +@list -@string" in result # mix and match interleaved await async_client.execute_command("ACL SETUSER kostas +@set -@set +@set") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@all -@admin +@list -@string +@set" in result + assert "user kostas on -@all -@admin +@list -@string +@set" in result await async_client.execute_command("ACL SETUSER kostas +@all") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@admin +@list -@string +@set +@all" in result + assert "user kostas on -@admin +@list -@string +@set +@all" in result # commands await async_client.execute_command("ACL SETUSER kostas +set +get +hset") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@admin +@list -@string +@set +@all +set +get +hset" in result + assert "user kostas on -@admin +@list -@string +@set +@all +set +get +hset" in result await async_client.execute_command("ACL SETUSER kostas -set -get +hset") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@admin +@list -@string +@set +@all -set -get +hset" in result + assert "user kostas on -@admin +@list -@string +@set +@all -set -get +hset" in result # interleaved await async_client.execute_command("ACL SETUSER kostas -hset +get -get -@all") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@admin +@list -@string +@set -set -hset -get -@all" in result + assert "user kostas on -@admin +@list -@string +@set -set -hset -get -@all" in result # interleaved with categories await async_client.execute_command("ACL SETUSER kostas +@string +get -get +set") result = await async_client.execute_command("ACL list") - assert "user kostas on nopass -@admin +@list +@set -hset -@all +@string -get +set" in result + assert "user kostas on -@admin +@list +@set -hset -@all +@string -get +set" in result @pytest.mark.asyncio @@ -332,7 +332,7 @@ async def test_good_acl_file(df_local_factory, tmp_dir): await client.execute_command("ACL LOAD") result = await client.execute_command("ACL list") assert 2 == len(result) - assert "user MrFoo on ea71c25a7a60224 -@all" in result + assert "user MrFoo on #ea71c25a7a60224 -@all" in result assert "user default on nopass ~* +@all" in result await client.execute_command("ACL DELUSER MrFoo") @@ -342,9 +342,9 @@ async def test_good_acl_file(df_local_factory, tmp_dir): result = await client.execute_command("ACL list") assert 4 == len(result) - assert "user roy on ea71c25a7a60224 -@all +@string +hset" in result - assert "user shahar off ea71c25a7a60224 -@all +@set" in result - assert "user vlad off nopass ~foo ~bar* -@all +@string" in result + assert "user roy on #ea71c25a7a60224 -@all +@string +hset" in result + assert "user shahar off #ea71c25a7a60224 -@all +@set" in result + assert "user vlad off ~foo ~bar* -@all +@string" in result assert "user default on nopass ~* +@all" in result result = await client.execute_command("ACL DELUSER shahar") @@ -356,8 +356,8 @@ async def test_good_acl_file(df_local_factory, tmp_dir): result = await client.execute_command("ACL list") assert 3 == len(result) - assert "user roy on ea71c25a7a60224 -@all +@string +hset" in result - assert "user vlad off nopass ~foo ~bar* -@all +@string" in result + assert "user roy on #ea71c25a7a60224 -@all +@string +hset" in result + assert "user vlad off ~foo ~bar* -@all +@string" in result assert "user default on nopass ~* +@all" in result await client.close()