From 51ae74f1829002a64f6b0f61088ba3d745a34e5e Mon Sep 17 00:00:00 2001 From: kostas Date: Thu, 6 Jun 2024 20:19:03 +0300 Subject: [PATCH] fix(acl): acl compatibility --- src/facade/command_id.h | 1 - src/facade/dragonfly_connection.cc | 1 - src/facade/dragonfly_connection.h | 1 - src/server/acl/acl_commands_def.h | 51 +++++++++- src/server/acl/acl_family.cc | 49 +++++----- src/server/acl/acl_family.h | 2 +- src/server/acl/acl_family_test.cc | 38 ++++---- src/server/acl/helpers.cc | 134 +++++++++++++++++---------- src/server/acl/helpers.h | 5 +- src/server/acl/user.cc | 82 +++++++++++++--- src/server/acl/user.h | 65 +++++++++++-- src/server/acl/user_registry.cc | 19 +--- src/server/acl/user_registry.h | 2 - src/server/acl/user_registry_test.cc | 49 ---------- src/server/acl/validator.cc | 12 +-- src/server/acl/validator.h | 4 +- src/server/command_registry.cc | 2 +- src/server/main_service.cc | 2 +- tests/dragonfly/acl_family_test.py | 48 +++++----- 19 files changed, 344 insertions(+), 223 deletions(-) delete mode 100644 src/server/acl/user_registry_test.cc diff --git a/src/facade/command_id.h b/src/facade/command_id.h index 178a43262e8f..edbc180aab1c 100644 --- a/src/facade/command_id.h +++ b/src/facade/command_id.h @@ -7,7 +7,6 @@ #include #include "facade/facade_types.h" -#include "server/acl/acl_commands_def.h" namespace facade { diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index c16f0db9f519..e3499846f12c 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -430,7 +430,6 @@ void Connection::DispatchOperations::operator()(const MonitorMessage& msg) { void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) { if (self->cntx()) { if (msg.username == self->cntx()->authed_username) { - self->cntx()->acl_categories = msg.categories; self->cntx()->acl_commands = msg.commands; self->cntx()->keys = msg.keys; } diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 929ad5f8b5f0..e87f4e50dfa7 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -113,7 +113,6 @@ class Connection : public util::Connection { // ACL Update message, contains ACL updates to be applied to the connection. struct AclUpdateMessage { std::string username; - uint32_t categories; std::vector commands; dfly::acl::AclKeys keys; }; diff --git a/src/server/acl/acl_commands_def.h b/src/server/acl/acl_commands_def.h index 86324ab490bd..2cf3628448aa 100644 --- a/src/server/acl/acl_commands_def.h +++ b/src/server/acl/acl_commands_def.h @@ -5,7 +5,11 @@ #pragma once #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "base/logging.h" #include "facade/acl_commands_def.h" +#include "server/command_registry.h" +#include "server/conn_context.h" namespace dfly::acl { @@ -84,6 +88,15 @@ inline const std::vector REVERSE_CATEGORY_INDEX_TABLE{ "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "BLOOM", "FT_SEARCH", "THROTTLE", "JSON"}; +// bit index to index in the REVERSE_CATEGORY_INDEX_TABLE +using CategoryToIdxStore = absl::flat_hash_map; + +// inline const CategoryToIdxStore& CategoryToIdx(CategoryToIdxStore store = {}) { +inline const CategoryToIdxStore& CategoryToIdx(CategoryToIdxStore store = {}) { + static CategoryToIdxStore cat_idx = std::move(store); + return cat_idx; +} + using RevCommandField = std::vector; using RevCommandsIndexStore = std::vector; @@ -104,9 +117,45 @@ inline const RevCommandsIndexStore& CommandsRevIndexer(RevCommandsIndexStore sto return rev_index_store; } -inline void BuildIndexers(std::vector> families) { +using CategoryToCommandsIndexStore = absl::flat_hash_map>; + +inline const CategoryToCommandsIndexStore& CategoryToCommandsIndex( + CategoryToCommandsIndexStore store = {}) { + static CategoryToCommandsIndexStore index = std::move(store); + return index; +} + +inline void BuildIndexers(RevCommandsIndexStore families, CommandRegistry* cmd_registry) { acl::NumberOfFamilies(families.size()); acl::CommandsRevIndexer(std::move(families)); + CategoryToCommandsIndexStore index; + cmd_registry->Traverse([&](std::string_view name, auto& cid) { + auto cat = cid.acl_categories(); + for (size_t i = 0; i < 32; ++i) { + if (cat & (1 << i)) { + std::string_view cat_name = REVERSE_CATEGORY_INDEX_TABLE[i]; + bool is_acl = cid.name() == "ACL SETUSER"; + if (is_acl) { + LOG(INFO) << "CAUGHT"; + LOG(INFO) << "CAUGHT"; + LOG(INFO) << "CAUGHT"; + } + if (index[cat_name].empty()) { + index[cat_name].resize(CommandsRevIndexer().size()); + } + auto family = cid.GetFamily(); + auto bit_index = cid.GetBitIndex(); + index[cat_name][family] = index[cat_name][family] | bit_index; + } + } + }); + + CategoryToCommandsIndex(std::move(index)); + CategoryToIdxStore idx_store; + for (size_t i = 0; i < 32; ++i) { + idx_store[1 << i] = i; + } + CategoryToIdx(std::move(idx_store)); } } // namespace dfly::acl diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index 385363878baa..1feab44be067 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -66,22 +66,23 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) { std::string buffer = "user "; const std::string_view pass = user.Password(); const std::string password = pass == "nopass" ? "nopass" : PrettyPrintSha(pass); - const std::string acl_cat = AclCatToString(user.AclCategory()); - const std::string acl_commands = AclCommandToString(user.AclCommandsRef()); - const std::string maybe_space_com = acl_commands.empty() ? "" : " "; + + const std::string acl_cat_and_commands = + AclCatAndCommandToString(user.CatChanges(), user.CmdChanges()); + const std::string acl_keys = AclKeysToString(user.Keys()); - const std::string maybe_space = acl_keys.empty() ? "" : " "; + const std::string maybe_space_com = acl_keys.empty() ? "" : " "; using namespace std::string_view_literals; absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, " ", - acl_cat, maybe_space_com, acl_commands, maybe_space, acl_keys); + acl_cat_and_commands, maybe_space_com, acl_keys); cntx->SendSimpleString(buffer); } } -void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat, +void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, const Commands& update_commands, const AclKeys& update_keys) { auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) { @@ -90,7 +91,7 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, u if (connection->protocol() == facade::Protocol::REDIS && !connection->IsHttp() && connection->cntx()) { connection->SendAclUpdateAsync( - facade::Connection::AclUpdateMessage{user, update_cat, update_commands, update_keys}); + facade::Connection::AclUpdateMessage{user, update_commands, update_keys}); } }; @@ -113,10 +114,14 @@ void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) { auto update_case = [username, ®, cntx, this, exists](User::UpdateRequest&& req) { auto& user = reg.registry[username]; + if (!exists) { + User::UpdateRequest default_req; + default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}}; + user.Update(std::move(default_req)); + } user.Update(std::move(req)); if (exists) { - StreamUpdatesToAllProactorConnections(std::string(username), user.AclCategory(), - user.AclCommands(), user.Keys()); + StreamUpdatesToAllProactorConnections(std::string(username), user.AclCommands(), user.Keys()); } cntx->SendOk(); }; @@ -194,20 +199,15 @@ std::string AclFamily::RegistryToString() const { const std::string_view pass = user.Password(); const std::string password = pass == "nopass" ? "nopass " : absl::StrCat("#", PrettyPrintSha(pass, true), " "); - const std::string acl_cat = AclCatToString(user.AclCategory()); - const std::string acl_commands = AclCommandToString(user.AclCommandsRef()); - const std::string maybe_space_com = acl_commands.empty() ? "" : " "; + const std::string acl_cat_and_commands = + AclCatAndCommandToString(user.CatChanges(), user.CmdChanges()); const std::string acl_keys = AclKeysToString(user.Keys()); const std::string maybe_space = acl_keys.empty() ? "" : " "; using namespace std::string_view_literals; absl::StrAppend(&result, command, username, " ", user.IsActive() ? "ON "sv : "OFF "sv, password, - acl_cat, maybe_space_com, acl_commands, maybe_space, acl_keys, "\n"); - } - - if (!result.empty()) { - result.pop_back(); + acl_cat_and_commands, maybe_space, acl_keys, "\n"); } return result; @@ -298,7 +298,10 @@ GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path, } for (size_t i = 0; i < usernames.size(); ++i) { + User::UpdateRequest default_req; + default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}}; auto& user = registry[usernames[i]]; + user.Update(std::move(default_req)); user.Update(std::move(requests[i])); } @@ -446,6 +449,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) { const uint32_t cid_mask = CATEGORY_INDEX_TABLE.find(category)->second; std::vector results; + // TODO replace this with indexer auto cb = [cid_mask, &results](auto name, auto& cid) { if (cid_mask & cid.acl_categories()) { results.push_back(name); @@ -510,10 +514,10 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) { } rb->SendSimpleString("commands"); - std::string acl = absl::StrCat(AclCatToString(user.AclCategory()), " ", - AclCommandToString(user.AclCommandsRef())); + const std::string acl_cat_and_commands = + AclCatAndCommandToString(user.CatChanges(), user.CmdChanges()); - rb->SendSimpleString(acl); + rb->SendSimpleString(acl_cat_and_commands); rb->SendSimpleString("keys"); std::string keys = AclKeysToString(user.Keys()); @@ -572,9 +576,8 @@ void AclFamily::DryRun(CmdArgList args, ConnectionContext* cntx) { } const auto& user = registry.find(username)->second; - const bool is_allowed = IsUserAllowedToInvokeCommandGeneric( - user.AclCategory(), user.AclCommandsRef(), {{}, true}, {}, *cid) - .first; + const bool is_allowed = + IsUserAllowedToInvokeCommandGeneric(user.AclCommandsRef(), {{}, true}, {}, *cid).first; if (is_allowed) { cntx->SendOk(); return; diff --git a/src/server/acl/acl_family.h b/src/server/acl/acl_family.h index 3c63719d5578..61439f36d62f 100644 --- a/src/server/acl/acl_family.h +++ b/src/server/acl/acl_family.h @@ -49,7 +49,7 @@ class AclFamily final { // Helper function that updates all open connections and their // respective ACL fields on all the available proactor threads using Commands = std::vector; - void StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat, + void StreamUpdatesToAllProactorConnections(const std::string& user, const Commands& update_commands, const AclKeys& update_keys); diff --git a/src/server/acl/acl_family_test.cc b/src/server/acl/acl_family_test.cc index 27356b9d4e3c..c8354e440f71 100644 --- a/src/server/acl/acl_family_test.cc +++ b/src/server/acl/acl_family_test.cc @@ -46,16 +46,16 @@ TEST_F(AclFamilyTest, AclSetUser) { EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); auto vec = resp.GetVec(); - EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*", - "user vlad off nopass +@NONE")); + EXPECT_THAT( + vec, UnorderedElementsAre("user default on nopass +@ALL ~*", "user vlad off nopass -@ALL")); resp = Run({"ACL", "SETUSER", "vlad", "+ACL"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); vec = resp.GetVec(); - EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*", - "user vlad off nopass +@NONE +ACL")); + EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL ~*", + "user vlad off nopass -@ALL +ACL")); } TEST_F(AclFamilyTest, AclDelUser) { @@ -82,7 +82,7 @@ TEST_F(AclFamilyTest, AclDelUser) { EXPECT_THAT(resp, IntArg(0)); resp = Run({"ACL", "LIST"}); - EXPECT_THAT(resp.GetString(), "user default on nopass +@ALL +ALL ~*"); + EXPECT_THAT(resp.GetString(), "user default on nopass +@ALL ~*"); Run({"ACL", "SETUSER", "michael", "ON"}); Run({"ACL", "SETUSER", "kobe", "ON"}); @@ -103,9 +103,9 @@ TEST_F(AclFamilyTest, AclList) { resp = Run({"ACL", "LIST"}); auto vec = resp.GetVec(); - EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*", - "user kostas off d74ff0ee8da3b98 +@ADMIN", - "user adi off d74ff0ee8da3b98 +@FAST")); + EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL ~*", + "user kostas off d74ff0ee8da3b98 -@ALL +@ADMIN", + "user adi off d74ff0ee8da3b98 -@ALL +@FAST")); } TEST_F(AclFamilyTest, AclAuth) { @@ -154,16 +154,16 @@ TEST_F(AclFamilyTest, TestAllCategories) { resp = Run({"ACL", "LIST"}); EXPECT_THAT(resp.GetVec(), - UnorderedElementsAre("user default on nopass +@ALL +ALL ~*", - absl::StrCat("user kostas off nopass ", "+@", cat))); + UnorderedElementsAre("user default on nopass +@ALL ~*", + absl::StrCat("user kostas off nopass -@ALL ", "+@", cat))); resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-@", cat)}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); EXPECT_THAT(resp.GetVec(), - UnorderedElementsAre("user default on nopass +@ALL +ALL ~*", - absl::StrCat("user kostas off nopass ", "+@NONE"))); + UnorderedElementsAre("user default on nopass +@ALL ~*", + absl::StrCat("user kostas off nopass -@ALL ", "-@", cat))); resp = Run({"ACL", "DELUSER", "kostas"}); EXPECT_THAT(resp, IntArg(1)); @@ -201,16 +201,16 @@ TEST_F(AclFamilyTest, TestAllCommands) { EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); - EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL +ALL ~*", - absl::StrCat("user kostas off nopass +@NONE ", + EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL ~*", + absl::StrCat("user kostas off nopass -@ALL ", "+", command_name))); resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-", command_name)}); resp = Run({"ACL", "LIST"}); - EXPECT_THAT(resp.GetVec(), - UnorderedElementsAre("user default on nopass +@ALL +ALL ~*", - absl::StrCat("user kostas off nopass ", "+@NONE"))); + EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL ~*", + absl::StrCat("user kostas off nopass ", + "-@ALL ", "-", command_name))); resp = Run({"ACL", "DELUSER", "kostas"}); EXPECT_THAT(resp, IntArg(1)); @@ -259,7 +259,7 @@ TEST_F(AclFamilyTest, TestGetUser) { EXPECT_THAT(vec[2], "passwords"); EXPECT_TRUE(vec[3].GetVec().empty()); EXPECT_THAT(vec[4], "commands"); - EXPECT_THAT(vec[5], "+@ALL +ALL"); + EXPECT_THAT(vec[5], "+@ALL"); EXPECT_THAT(vec[6], "keys"); EXPECT_THAT(vec[7], "~*"); @@ -271,7 +271,7 @@ TEST_F(AclFamilyTest, TestGetUser) { EXPECT_THAT(kvec[2], "passwords"); EXPECT_TRUE(kvec[3].GetVec().empty()); EXPECT_THAT(kvec[4], "commands"); - EXPECT_THAT(kvec[5], "+@STRING +HSET"); + EXPECT_THAT(kvec[5], "-@ALL +@STRING +HSET"); } TEST_F(AclFamilyTest, TestDryRun) { diff --git a/src/server/acl/helpers.cc b/src/server/acl/helpers.cc index b8e4fabaeeb3..2165534d6168 100644 --- a/src/server/acl/helpers.cc +++ b/src/server/acl/helpers.cc @@ -12,64 +12,112 @@ #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" +#include "core/overloaded.h" +#include "facade/acl_commands_def.h" #include "server/acl/acl_commands_def.h" +#include "server/acl/user.h" #include "server/common.h" namespace dfly::acl { -std::string AclCatToString(uint32_t acl_category) { - std::string tmp; +namespace { +std::string AclCatToString(uint32_t acl_category, User::Sign sign) { + std::string res = sign == User::Sign::PLUS ? "+@" : "-@"; if (acl_category == acl::ALL) { - return "+@ALL"; + absl::StrAppend(&res, "ALL"); + return res; } - if (acl_category == acl::NONE) { - return "+@NONE"; + const auto& index = CategoryToIdx().at(acl_category); + absl::StrAppend(&res, REVERSE_CATEGORY_INDEX_TABLE[index]); + return res; +} + +std::string AclCommandToString(size_t family, uint64_t mask, User::Sign sign) { + // This is constant but can be optimized with an indexer + const auto& rev_index = CommandsRevIndexer(); + std::string res; + std::string prefix = (sign == User::Sign::PLUS) ? "+" : "-"; + if (mask == ALL_COMMANDS) { + for (const auto& cmd : rev_index[family]) { + absl::StrAppend(&res, prefix, cmd, " "); + } + res.pop_back(); + return res; } - const std::string prefix = "+@"; - const std::string postfix = " "; + size_t pos = 0; + while (mask != 0) { + ++pos; + mask = mask >> 1; + } + --pos; + absl::StrAppend(&res, prefix, rev_index[family][pos]); + return res; +} - for (uint32_t i = 0; i < 32; ++i) { - uint32_t cat_bit = 1ULL << i; - if (acl_category & cat_bit) { - absl::StrAppend(&tmp, prefix, REVERSE_CATEGORY_INDEX_TABLE[i], postfix); - } +struct CategoryAndMetadata { + User::CategoryChange change; + User::ChangeMetadata metadata; +}; + +struct CommandAndMetadata { + User::CommandChange change; + User::ChangeMetadata metadata; +}; + +using MergeResult = std::vector>; + +} // namespace + +// Merge Category and Command changes and sort them by global order seq_no +MergeResult MergeTables(const User::CategoryChanges& categories, + const User::CommandChanges& commands) { + MergeResult result; + for (auto [cat, meta] : categories) { + result.push_back(CategoryAndMetadata{cat, meta}); + } + + for (auto [cmd, meta] : commands) { + result.push_back(CommandAndMetadata{cmd, meta}); } - tmp.pop_back(); + std::sort(result.begin(), result.end(), [](const auto& l, const auto& r) { + auto fetch = [](const auto& l) { return l.metadata.seq_no; }; + return std::visit(fetch, l) < std::visit(fetch, r); + }); - return tmp; + return result; } -std::string AclCommandToString(const std::vector& acl_category) { +std::string AclCatAndCommandToString(const User::CategoryChanges& cat, + const User::CommandChanges& cmds) { std::string result; - const std::string prefix = "+"; - const std::string postfix = " "; - const auto& rev_index = CommandsRevIndexer(); - bool all = true; - - size_t family_id = 0; - for (auto family : acl_category) { - for (uint64_t i = 0; i < 64; ++i) { - const uint64_t cmd_bit = 1ULL << i; - if (family & cmd_bit && i < rev_index[family_id].size()) { - absl::StrAppend(&result, prefix, rev_index[family_id][i], postfix); - continue; - } - if (i < rev_index[family_id].size()) { - all = false; - } - } - ++family_id; + auto tables = MergeTables(cat, cmds); + + auto cat_visitor = [&result](const CategoryAndMetadata& val) { + const auto& [change, meta] = val; + absl::StrAppend(&result, AclCatToString(change.category, meta.sign), " "); + }; + + auto cmd_visitor = [&result](const CommandAndMetadata& val) { + const auto& [change, meta] = val; + absl::StrAppend(&result, AclCommandToString(change.family, change.bit_index, meta.sign), " "); + }; + + Overloaded visitor{cat_visitor, cmd_visitor}; + + for (auto change : tables) { + std::visit(visitor, change); } if (!result.empty()) { result.pop_back(); } - return all ? "+ALL" : result; + + return result; } std::string PrettyPrintSha(std::string_view pass, bool all) { @@ -157,21 +205,8 @@ std::pair MaybeParseAclCategory(std::string_view command) { return {}; } -bool IsIndexAllCommandsFlag(size_t index) { - return index == std::numeric_limits::max(); -} - std::pair MaybeParseAclCommand(std::string_view command, const CommandRegistry& registry) { - const auto all_commands = std::pair{std::numeric_limits::max(), 0}; - if (command == "+ALL") { - return {all_commands, true}; - } - - if (command == "-ALL") { - return {all_commands, false}; - } - if (absl::StartsWith(command, "+")) { auto res = registry.Find(command.substr(1)); if (!res) { @@ -281,7 +316,7 @@ std::variant ParseAclSetUser(T args, using Sign = User::Sign; using Val = std::pair; auto val = add ? Val{Sign::PLUS, *cat} : Val{Sign::MINUS, *cat}; - req.categories.push_back(val); + req.updates.push_back(val); continue; } @@ -292,10 +327,9 @@ std::variant ParseAclSetUser(T args, using Sign = User::Sign; using Val = User::UpdateRequest::CommandsValueType; - ; auto [index, bit] = *cmd; auto val = sign ? Val{Sign::PLUS, index, bit} : Val{Sign::MINUS, index, bit}; - req.commands.push_back(val); + req.updates.push_back(val); } return req; diff --git a/src/server/acl/helpers.h b/src/server/acl/helpers.h index a21fb77a6963..85483f40aa84 100644 --- a/src/server/acl/helpers.h +++ b/src/server/acl/helpers.h @@ -17,9 +17,8 @@ namespace dfly::acl { -std::string AclCatToString(uint32_t acl_category); - -std::string AclCommandToString(const std::vector& acl_category); +std::string AclCatAndCommandToString(const User::CategoryChanges& cat, + const User::CommandChanges& cmds); std::string PrettyPrintSha(std::string_view pass, bool all = false); diff --git a/src/server/acl/user.cc b/src/server/acl/user.cc index aa227618959c..863d9857f81e 100644 --- a/src/server/acl/user.cc +++ b/src/server/acl/user.cc @@ -9,6 +9,7 @@ #include #include "absl/strings/escaping.h" +#include "core/overloaded.h" #include "server/acl/helpers.h" namespace dfly::acl { @@ -33,20 +34,28 @@ void User::Update(UpdateRequest&& req) { SetPasswordHash(*req.password, req.is_hashed); } - for (auto [sign, category] : req.categories) { + auto cat_visitor = [this](UpdateRequest::CategoryValueType cat) { + auto [sign, category] = cat; if (sign == Sign::PLUS) { - SetAclCategories(category); - continue; + SetAclCategoriesAndIncrSeq(category); + return; } - UnsetAclCategories(category); - } + UnsetAclCategoriesAndIncrSeq(category); + }; - for (auto [sign, index, bit_index] : req.commands) { + auto cmd_visitor = [this](UpdateRequest::CommandsValueType cmd) { + auto [sign, index, bit_index] = cmd; if (sign == Sign::PLUS) { - SetAclCommands(index, bit_index); - continue; + SetAclCommandsAndIncrSeq(index, bit_index); + return; } - UnsetAclCommands(index, bit_index); + UnsetAclCommandsAndIncrSeq(index, bit_index); + }; + + Overloaded visitor{cat_visitor, cmd_visitor}; + + for (auto req : req.updates) { + std::visit(visitor, req); } if (!req.keys.empty()) { @@ -78,17 +87,42 @@ bool User::HasPassword(std::string_view password) const { return *password_hash_ == StringSHA256(password); } -void User::SetAclCategories(uint32_t cat) { +void User::SetAclCategoriesAndIncrSeq(uint32_t cat) { acl_categories_ |= cat; + if (cat == acl::ALL) { + SetAclCommands(std::numeric_limits::max(), 0); + } else { + auto id = CategoryToIdx().at(cat); + std::string_view name = REVERSE_CATEGORY_INDEX_TABLE[id]; + const auto& commands_group = CategoryToCommandsIndex().at(name); + for (size_t fam_id = 0; fam_id < commands_group.size(); ++fam_id) { + SetAclCommands(fam_id, commands_group[fam_id]); + } + } + + CategoryChange change{cat}; + cat_changes_[change] = ChangeMetadata{Sign::PLUS, seq_++}; } -void User::UnsetAclCategories(uint32_t cat) { - SetAclCategories(cat); +void User::UnsetAclCategoriesAndIncrSeq(uint32_t cat) { acl_categories_ ^= cat; + if (cat == acl::ALL) { + UnsetAclCommands(std::numeric_limits::max(), 0); + } else { + auto id = CategoryToIdx().at(cat); + std::string_view name = REVERSE_CATEGORY_INDEX_TABLE[id]; + const auto& commands_group = CategoryToCommandsIndex().at(name); + for (size_t fam_id = 0; fam_id < commands_group.size(); ++fam_id) { + UnsetAclCommands(fam_id, commands_group[fam_id]); + } + } + + CategoryChange change{cat}; + cat_changes_[change] = ChangeMetadata{Sign::MINUS, seq_++}; } void User::SetAclCommands(size_t index, uint64_t bit_index) { - if (IsIndexAllCommandsFlag(index)) { + if (index == std::numeric_limits::max()) { for (auto& family : commands_) { family = ALL_COMMANDS; } @@ -97,8 +131,14 @@ void User::SetAclCommands(size_t index, uint64_t bit_index) { commands_[index] |= bit_index; } +void User::SetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index) { + SetAclCommands(index, bit_index); + CommandChange change{index, bit_index}; + cmd_changes_[change] = ChangeMetadata{Sign::PLUS, seq_++}; +} + void User::UnsetAclCommands(size_t index, uint64_t bit_index) { - if (IsIndexAllCommandsFlag(index)) { + if (index == std::numeric_limits::max()) { for (auto& family : commands_) { family = NONE_COMMANDS; } @@ -108,6 +148,12 @@ void User::UnsetAclCommands(size_t index, uint64_t bit_index) { commands_[index] ^= bit_index; } +void User::UnsetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index) { + UnsetAclCommands(index, bit_index); + CommandChange change{index, bit_index}; + cmd_changes_[change] = ChangeMetadata{Sign::MINUS, seq_++}; +} + uint32_t User::AclCategory() const { return acl_categories_; } @@ -138,6 +184,14 @@ const AclKeys& User::Keys() const { return keys_; } +const User::CategoryChanges& User::CatChanges() const { + return cat_changes_; +} + +const User::CommandChanges& User::CmdChanges() const { + return cmd_changes_; +} + void User::SetKeyGlobs(std::vector keys) { for (auto& key : keys) { if (key.all_keys) { diff --git a/src/server/acl/user.h b/src/server/acl/user.h index 2ab15798e689..e29fd9630adc 100644 --- a/src/server/acl/user.h +++ b/src/server/acl/user.h @@ -33,16 +33,16 @@ class User final { struct UpdateRequest { std::optional password{}; - std::vector> categories; - std::optional is_active{}; bool is_hashed{false}; + // Categories and commands + using CategoryValueType = std::pair; // If index s numberic_limits::max() then it's a +all flag using CommandsValueType = std::tuple; - using CommandsUpdateType = std::vector; - CommandsUpdateType commands; + using UpdateType = std::vector>; + UpdateType updates; // keys std::vector keys; @@ -50,6 +50,38 @@ class User final { bool allow_all_keys{false}; }; + struct CategoryChange { + uint32_t category; + + // Customization point to make it hashable with absl containers + template friend H AbslHashValue(H h, const CategoryChange& c) { + return H::combine(std::move(h), c.category); + } + + friend bool operator==(CategoryChange c1, CategoryChange c2) { + return c1.category == c2.category; + } + }; + + struct CommandChange { + size_t family = 0; + uint64_t bit_index = 0; + + // Customization point to make it hashable with absl containers + template friend H AbslHashValue(H h, const CommandChange& c) { + return H::combine(std::move(h), c.family + c.bit_index); + } + + friend bool operator==(CommandChange c1, CommandChange c2) { + return (c1.family == c2.family) && (c1.bit_index == c2.bit_index); + } + }; + + struct ChangeMetadata { + Sign sign; + size_t seq_no; + }; + /* Used for default user * password = nopass * acl_categories = +@all @@ -80,15 +112,24 @@ class User final { const AclKeys& Keys() const; + using CategoryChanges = absl::flat_hash_map; + using CommandChanges = absl::flat_hash_map; + + const CategoryChanges& CatChanges() const; + + const CommandChanges& CmdChanges() const; + private: - // For ACL categories - void SetAclCategories(uint32_t cat); - void UnsetAclCategories(uint32_t cat); + void SetAclCategoriesAndIncrSeq(uint32_t cat); + void UnsetAclCategoriesAndIncrSeq(uint32_t cat); // For ACL commands void SetAclCommands(size_t index, uint64_t bit_index); void UnsetAclCommands(size_t index, uint64_t bit_index); + void SetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index); + void UnsetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index); + // For is_active flag void SetIsActive(bool is_active); @@ -108,6 +149,16 @@ class User final { // on how this mapping is built during the startup/registration of commands std::vector commands_; + // We also need to track all the explicit changes (ACL SETUSER) of acl's in-order. + // To speed up insertion we use the flat_hash_map and a seq_ variable which is a + // strictly monotonically increasing number that is used for ordering. Both of these + // indexers are merged and then sorted by the seq_ number when for example we print + // the ACL rules of each user via ACL LIST. + CategoryChanges cat_changes_; + CommandChanges cmd_changes_; + // Global modification order for changes in rules for acl commands and categories + size_t seq_ = 0; + // Glob patterns for the keys that a user is allowed to read/write AclKeys keys_; diff --git a/src/server/acl/user_registry.cc b/src/server/acl/user_registry.cc index 4a7e38e18e23..42d4525505be 100644 --- a/src/server/acl/user_registry.cc +++ b/src/server/acl/user_registry.cc @@ -8,6 +8,7 @@ #include #include "base/flags.h" +#include "facade/acl_commands_def.h" #include "facade/facade_types.h" #include "server/acl/acl_commands_def.h" @@ -71,29 +72,13 @@ UserRegistry::UserWithWriteLock::UserWithWriteLock(std::unique_lock lock(mu_); - const bool exists = registry_.contains(username); - auto& user = registry_[username]; - user.Update(std::move(req)); - return {std::move(lock), user, exists}; -} - User::UpdateRequest UserRegistry::DefaultUserUpdateRequest() const { - User::UpdateRequest::CommandsUpdateType tmp(NumberOfFamilies()); - size_t id = 0; - for (auto& elem : tmp) { - elem = {User::Sign::PLUS, id++, acl::ALL_COMMANDS}; - } std::pair acl{User::Sign::PLUS, acl::ALL}; auto key = User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}; - return {{}, {acl}, true, false, std::move(tmp), {std::move(key)}}; + return {{}, true, false, {std::move(acl)}, {std::move(key)}}; } void UserRegistry::Init() { - // Add default user - User::UpdateRequest::CommandsUpdateType tmp(NumberOfFamilies()); // if there exists an acl file to load from, requirepass // will not overwrite the default's user password loaded from // that file. Loading the default's user password from a file diff --git a/src/server/acl/user_registry.h b/src/server/acl/user_registry.h index a9ce519c275d..336a561d52f8 100644 --- a/src/server/acl/user_registry.h +++ b/src/server/acl/user_registry.h @@ -77,8 +77,6 @@ class UserRegistry { std::unique_lock registry_lk_; }; - UserWithWriteLock MaybeAddAndUpdateWithLock(std::string_view username, User::UpdateRequest req); - User::UpdateRequest DefaultUserUpdateRequest() const; private: diff --git a/src/server/acl/user_registry_test.cc b/src/server/acl/user_registry_test.cc deleted file mode 100644 index 0c065a6c7b80..000000000000 --- a/src/server/acl/user_registry_test.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2022, DragonflyDB authors. All rights reserved. -// See LICENSE for licensing terms. -// - -#include "server/acl/user_registry.h" - -#include -#include - -#include "base/gtest.h" -#include "base/logging.h" -#include "server/acl/acl_commands_def.h" -#include "server/acl/user.h" - -using namespace testing; - -namespace dfly::acl { - -class UserRegistryTest : public Test {}; - -TEST_F(UserRegistryTest, BasicOp) { - UserRegistry registry; - const std::string username = "kostas"; - const std::string pass = "mypass"; - - User::UpdateRequest req{pass, {}, true, false, {}, {}, false, false}; - registry.MaybeAddAndUpdate(username, std::move(req)); - CHECK_EQ(registry.AuthUser(username, pass), true); - CHECK_EQ(registry.IsUserActive(username), true); - - CHECK_EQ(registry.GetCredentials(username).acl_categories, NONE); - - using Sign = User::Sign; - std::vector> cat = {{Sign::PLUS, LIST}, {Sign::PLUS, SET}}; - req = User::UpdateRequest{{}, std::move(cat), true, false, {}, {}, false, false}; - registry.MaybeAddAndUpdate(username, std::move(req)); - auto acl_categories = registry.GetCredentials(username).acl_categories; - uint32_t expected_result = NONE | LIST | SET; - CHECK_EQ(acl_categories, expected_result); - - cat.push_back({Sign::MINUS, LIST}); - req = User::UpdateRequest{{}, std::move(cat), true, false, {}, {}, false, false}; - registry.MaybeAddAndUpdate(username, std::move(req)); - acl_categories = registry.GetCredentials(username).acl_categories; - expected_result = NONE | SET; - CHECK_EQ(acl_categories, expected_result); -} - -} // namespace dfly::acl diff --git a/src/server/acl/validator.cc b/src/server/acl/validator.cc index c75c87285e18..9482230e5f28 100644 --- a/src/server/acl/validator.cc +++ b/src/server/acl/validator.cc @@ -23,8 +23,8 @@ namespace dfly::acl { return true; } - const auto [is_authed, reason] = IsUserAllowedToInvokeCommandGeneric( - cntx.acl_categories, cntx.acl_commands, cntx.keys, tail_args, id); + const auto [is_authed, reason] = + IsUserAllowedToInvokeCommandGeneric(cntx.acl_commands, cntx.keys, tail_args, id); if (!is_authed) { auto& log = ServerState::tlocal()->acl_log; @@ -41,15 +41,13 @@ namespace dfly::acl { #endif [[nodiscard]] std::pair IsUserAllowedToInvokeCommandGeneric( - uint32_t acl_cat, const std::vector& acl_commands, const AclKeys& keys, - CmdArgList tail_args, const CommandId& id) { - const auto cat_credentials = id.acl_categories(); + const std::vector& acl_commands, const AclKeys& keys, CmdArgList tail_args, + const CommandId& id) { const size_t index = id.GetFamily(); const uint64_t command_mask = id.GetBitIndex(); DCHECK_LT(index, acl_commands.size()); - const bool command = - (acl_cat & cat_credentials) != 0 || (acl_commands[index] & command_mask) != 0; + const bool command = (acl_commands[index] & command_mask) != 0; if (!command) { return {false, AclLog::Reason::COMMAND}; diff --git a/src/server/acl/validator.h b/src/server/acl/validator.h index a1a12b5ef64a..89d6f77f492b 100644 --- a/src/server/acl/validator.h +++ b/src/server/acl/validator.h @@ -13,8 +13,8 @@ namespace dfly::acl { std::pair IsUserAllowedToInvokeCommandGeneric( - uint32_t acl_cat, const std::vector& acl_commands, const AclKeys& keys, - CmdArgList tail_args, const CommandId& id); + const std::vector& acl_commands, const AclKeys& keys, CmdArgList tail_args, + const CommandId& id); bool IsUserAllowedToInvokeCommand(const ConnectionContext& cntx, const CommandId& id, CmdArgList tail_args); diff --git a/src/server/command_registry.cc b/src/server/command_registry.cc index 1ad459ca90d7..1f426f824e22 100644 --- a/src/server/command_registry.cc +++ b/src/server/command_registry.cc @@ -132,7 +132,7 @@ CommandRegistry& CommandRegistry::operator<<(CommandId cmd) { } cmd.SetFamily(family_of_commands_.size() - 1); - if (!is_sub_command) { + if (!is_sub_command || absl::StartsWith(cmd.name(), "ACL")) { cmd.SetBitIndex(1ULL << bit_index_); family_of_commands_.back().push_back(std::string(k)); ++bit_index_; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 276fc5515a6f..ba095a080dfd 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -2644,7 +2644,7 @@ void Service::RegisterCommands() { cluster_family_.Register(®istry_); acl_family_.Register(®istry_); - acl::BuildIndexers(registry_.GetFamilies()); + acl::BuildIndexers(registry_.GetFamilies(), ®istry_); // Only after all the commands are registered registry_.Init(pp_.size()); diff --git a/tests/dragonfly/acl_family_test.py b/tests/dragonfly/acl_family_test.py index 60070756b7a0..08c14911dffc 100644 --- a/tests/dragonfly/acl_family_test.py +++ b/tests/dragonfly/acl_family_test.py @@ -8,65 +8,67 @@ import os from . import dfly_args +import logging + @pytest.mark.asyncio async def test_acl_setuser(async_client): await async_client.execute_command("ACL SETUSER kostas") result = await async_client.execute_command("ACL LIST") assert 2 == len(result) - assert "user kostas off nopass +@NONE" in result + assert "user kostas off nopass -@ALL" in result await async_client.execute_command("ACL SETUSER kostas ON") result = await async_client.execute_command("ACL LIST") - assert "user kostas on nopass +@NONE" in result + assert "user kostas on nopass -@ALL" in result await async_client.execute_command("ACL SETUSER kostas +@list +@string +@admin") result = await async_client.execute_command("ACL LIST") # TODO consider printing to lowercase - assert "user kostas on nopass +@LIST +@STRING +@ADMIN" in result + assert "user kostas on nopass -@ALL +@LIST +@STRING +@ADMIN" in result await async_client.execute_command("ACL SETUSER kostas -@list -@admin") result = await async_client.execute_command("ACL LIST") - assert "user kostas on nopass +@STRING" in result + assert "user kostas on nopass -@ALL +@STRING -@LIST -@ADMIN" in result # mix and match await async_client.execute_command("ACL SETUSER kostas +@list -@string") result = await async_client.execute_command("ACL LIST") - assert "user kostas on nopass +@LIST" in result + assert "user kostas on nopass -@ALL -@ADMIN +@LIST -@STRING" in result # mix and match interleaved await async_client.execute_command("ACL SETUSER kostas +@set -@set +@set") result = await async_client.execute_command("ACL LIST") - assert "user kostas on nopass +@SET +@LIST" in result + assert "user kostas on nopass -@ALL -@ADMIN +@LIST -@STRING +@SET" in result await async_client.execute_command("ACL SETUSER kostas +@all") result = await async_client.execute_command("ACL LIST") - assert "user kostas on nopass +@ALL" in result + assert "user kostas on nopass -@ADMIN +@LIST -@STRING +@SET +@ALL" in result # commands await async_client.execute_command("ACL SETUSER kostas +set +get +hset") result = await async_client.execute_command("ACL LIST") - assert "user kostas on nopass +@ALL +SET +GET +HSET" in result + assert "user kostas on nopass -@ADMIN +@LIST -@STRING +@SET +@ALL +SET +GET +HSET" in result await async_client.execute_command("ACL SETUSER kostas -set -get +hset") result = await async_client.execute_command("ACL LIST") - assert "user kostas on nopass +@ALL +HSET" in result + assert "user kostas on nopass -@ADMIN +@LIST -@STRING +@SET +@ALL -SET -GET +HSET" in result # interleaved await async_client.execute_command("ACL SETUSER kostas -hset +get -get -@all") result = await async_client.execute_command("ACL LIST") - assert "user kostas on nopass +@NONE" in result + assert "user kostas on nopass -@ADMIN +@LIST -@STRING +@SET -SET -HSET -GET -@ALL" in result # interleaved with categories await async_client.execute_command("ACL SETUSER kostas +@string +get -get +set") result = await async_client.execute_command("ACL LIST") - assert "user kostas on nopass +@STRING +SET" in result + assert "user kostas on nopass -@ADMIN +@LIST +@SET -HSET -@ALL +@STRING -GET +SET" in result @pytest.mark.asyncio async def test_acl_categories(async_client): await async_client.execute_command( - "ACL SETUSER vlad ON >mypass +@string +@list +@connection ~*" + "ACL SETUSER vlad ON >mypass -@ALL +@string +@list +@connection ~*" ) result = await async_client.execute_command("AUTH vlad mypass") @@ -80,7 +82,7 @@ async def test_acl_categories(async_client): # This should fail, vlad does not have @admin with pytest.raises(redis.exceptions.ResponseError): - await async_client.execute_command("ACL SETUSER vlad ON >mypass") + result = await async_client.execute_command("ACL SETUSER vlad ON >mypass") # This should fail, vlad does not have @sortedset with pytest.raises(redis.exceptions.ResponseError): @@ -116,7 +118,7 @@ async def test_acl_categories(async_client): @pytest.mark.asyncio async def test_acl_commands(async_client): - await async_client.execute_command("ACL SETUSER random ON >mypass +@NONE +set +get ~*") + await async_client.execute_command("ACL SETUSER random ON >mypass -@ALL +set +get ~*") result = await async_client.execute_command("AUTH random mypass") assert result == "OK" @@ -332,8 +334,8 @@ async def test_good_acl_file(df_local_factory, tmp_dir): await client.execute_command("ACL LOAD") result = await client.execute_command("ACL LIST") assert 2 == len(result) - assert "user MrFoo on ea71c25a7a60224 +@NONE" in result - assert "user default on nopass +@ALL +ALL ~*" in result + assert "user MrFoo on ea71c25a7a60224 -@ALL" in result + assert "user default on nopass +@ALL ~*" in result await client.execute_command("ACL DELUSER MrFoo") await client.execute_command("ACL SETUSER roy ON >mypass +@STRING +HSET") @@ -342,10 +344,10 @@ async def test_good_acl_file(df_local_factory, tmp_dir): result = await client.execute_command("ACL LIST") assert 4 == len(result) - assert "user roy on ea71c25a7a60224 +@STRING +HSET" in result - assert "user shahar off ea71c25a7a60224 +@SET" in result - assert "user vlad off nopass +@STRING ~foo ~bar*" in result - assert "user default on nopass +@ALL +ALL ~*" in result + assert "user roy on ea71c25a7a60224 -@ALL +@STRING +HSET" in result + assert "user shahar off ea71c25a7a60224 -@ALL +@SET" in result + assert "user vlad off nopass -@ALL +@STRING ~foo ~bar*" in result + assert "user default on nopass +@ALL ~*" in result result = await client.execute_command("ACL DELUSER shahar") assert result == 1 @@ -356,9 +358,9 @@ async def test_good_acl_file(df_local_factory, tmp_dir): result = await client.execute_command("ACL LIST") assert 3 == len(result) - assert "user roy on ea71c25a7a60224 +@STRING +HSET" in result - assert "user vlad off nopass +@STRING ~foo ~bar*" in result - assert "user default on nopass +@ALL +ALL ~*" in result + assert "user roy on ea71c25a7a60224 -@ALL +@STRING +HSET" in result + assert "user vlad off nopass -@ALL +@STRING ~foo ~bar*" in result + assert "user default on nopass +@ALL ~*" in result await client.close()