Skip to content

Commit

Permalink
fix(acl): acl compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
kostasrim committed Jun 7, 2024
1 parent 3924fca commit 51ae74f
Show file tree
Hide file tree
Showing 19 changed files with 344 additions and 223 deletions.
1 change: 0 additions & 1 deletion src/facade/command_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <string_view>

#include "facade/facade_types.h"
#include "server/acl/acl_commands_def.h"

namespace facade {

Expand Down
1 change: 0 additions & 1 deletion src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,6 @@ void Connection::DispatchOperations::operator()(const MonitorMessage& msg) {
void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) {
if (self->cntx()) {
if (msg.username == self->cntx()->authed_username) {
self->cntx()->acl_categories = msg.categories;
self->cntx()->acl_commands = msg.commands;
self->cntx()->keys = msg.keys;
}
Expand Down
1 change: 0 additions & 1 deletion src/facade/dragonfly_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ class Connection : public util::Connection {
// ACL Update message, contains ACL updates to be applied to the connection.
struct AclUpdateMessage {
std::string username;
uint32_t categories;
std::vector<uint64_t> commands;
dfly::acl::AclKeys keys;
};
Expand Down
51 changes: 50 additions & 1 deletion src/server/acl/acl_commands_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
#pragma once

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "base/logging.h"
#include "facade/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/conn_context.h"

namespace dfly::acl {

Expand Down Expand Up @@ -84,6 +88,15 @@ inline const std::vector<std::string> REVERSE_CATEGORY_INDEX_TABLE{
"_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED",
"BLOOM", "FT_SEARCH", "THROTTLE", "JSON"};

// bit index to index in the REVERSE_CATEGORY_INDEX_TABLE
using CategoryToIdxStore = absl::flat_hash_map<uint32_t, uint32_t>;

// inline const CategoryToIdxStore& CategoryToIdx(CategoryToIdxStore store = {}) {
inline const CategoryToIdxStore& CategoryToIdx(CategoryToIdxStore store = {}) {
static CategoryToIdxStore cat_idx = std::move(store);
return cat_idx;
}

using RevCommandField = std::vector<std::string>;
using RevCommandsIndexStore = std::vector<RevCommandField>;

Expand All @@ -104,9 +117,45 @@ inline const RevCommandsIndexStore& CommandsRevIndexer(RevCommandsIndexStore sto
return rev_index_store;
}

inline void BuildIndexers(std::vector<std::vector<std::string>> families) {
using CategoryToCommandsIndexStore = absl::flat_hash_map<std::string, std::vector<uint64_t>>;

inline const CategoryToCommandsIndexStore& CategoryToCommandsIndex(
CategoryToCommandsIndexStore store = {}) {
static CategoryToCommandsIndexStore index = std::move(store);
return index;
}

inline void BuildIndexers(RevCommandsIndexStore families, CommandRegistry* cmd_registry) {
acl::NumberOfFamilies(families.size());
acl::CommandsRevIndexer(std::move(families));
CategoryToCommandsIndexStore index;
cmd_registry->Traverse([&](std::string_view name, auto& cid) {
auto cat = cid.acl_categories();
for (size_t i = 0; i < 32; ++i) {
if (cat & (1 << i)) {
std::string_view cat_name = REVERSE_CATEGORY_INDEX_TABLE[i];
bool is_acl = cid.name() == "ACL SETUSER";
if (is_acl) {
LOG(INFO) << "CAUGHT";
LOG(INFO) << "CAUGHT";
LOG(INFO) << "CAUGHT";
}
if (index[cat_name].empty()) {
index[cat_name].resize(CommandsRevIndexer().size());
}
auto family = cid.GetFamily();
auto bit_index = cid.GetBitIndex();
index[cat_name][family] = index[cat_name][family] | bit_index;
}
}
});

CategoryToCommandsIndex(std::move(index));
CategoryToIdxStore idx_store;
for (size_t i = 0; i < 32; ++i) {
idx_store[1 << i] = i;
}
CategoryToIdx(std::move(idx_store));
}

} // namespace dfly::acl
49 changes: 26 additions & 23 deletions src/server/acl/acl_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,23 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
std::string buffer = "user ";
const std::string_view pass = user.Password();
const std::string password = pass == "nopass" ? "nopass" : PrettyPrintSha(pass);
const std::string acl_cat = AclCatToString(user.AclCategory());
const std::string acl_commands = AclCommandToString(user.AclCommandsRef());
const std::string maybe_space_com = acl_commands.empty() ? "" : " ";

const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());

const std::string acl_keys = AclKeysToString(user.Keys());
const std::string maybe_space = acl_keys.empty() ? "" : " ";
const std::string maybe_space_com = acl_keys.empty() ? "" : " ";

using namespace std::string_view_literals;

absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, " ",
acl_cat, maybe_space_com, acl_commands, maybe_space, acl_keys);
acl_cat_and_commands, maybe_space_com, acl_keys);

cntx->SendSimpleString(buffer);
}
}

void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user,
const Commands& update_commands,
const AclKeys& update_keys) {
auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) {
Expand All @@ -90,7 +91,7 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, u
if (connection->protocol() == facade::Protocol::REDIS && !connection->IsHttp() &&
connection->cntx()) {
connection->SendAclUpdateAsync(
facade::Connection::AclUpdateMessage{user, update_cat, update_commands, update_keys});
facade::Connection::AclUpdateMessage{user, update_commands, update_keys});
}
};

Expand All @@ -113,10 +114,14 @@ void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {

auto update_case = [username, &reg, cntx, this, exists](User::UpdateRequest&& req) {
auto& user = reg.registry[username];
if (!exists) {
User::UpdateRequest default_req;
default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}};
user.Update(std::move(default_req));
}
user.Update(std::move(req));
if (exists) {
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCategory(),
user.AclCommands(), user.Keys());
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCommands(), user.Keys());
}
cntx->SendOk();
};
Expand Down Expand Up @@ -194,20 +199,15 @@ std::string AclFamily::RegistryToString() const {
const std::string_view pass = user.Password();
const std::string password =
pass == "nopass" ? "nopass " : absl::StrCat("#", PrettyPrintSha(pass, true), " ");
const std::string acl_cat = AclCatToString(user.AclCategory());
const std::string acl_commands = AclCommandToString(user.AclCommandsRef());
const std::string maybe_space_com = acl_commands.empty() ? "" : " ";
const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());
const std::string acl_keys = AclKeysToString(user.Keys());
const std::string maybe_space = acl_keys.empty() ? "" : " ";

using namespace std::string_view_literals;

absl::StrAppend(&result, command, username, " ", user.IsActive() ? "ON "sv : "OFF "sv, password,
acl_cat, maybe_space_com, acl_commands, maybe_space, acl_keys, "\n");
}

if (!result.empty()) {
result.pop_back();
acl_cat_and_commands, maybe_space, acl_keys, "\n");
}

return result;
Expand Down Expand Up @@ -298,7 +298,10 @@ GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path,
}

for (size_t i = 0; i < usernames.size(); ++i) {
User::UpdateRequest default_req;
default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}};
auto& user = registry[usernames[i]];
user.Update(std::move(default_req));
user.Update(std::move(requests[i]));
}

Expand Down Expand Up @@ -446,6 +449,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {

const uint32_t cid_mask = CATEGORY_INDEX_TABLE.find(category)->second;
std::vector<std::string_view> results;
// TODO replace this with indexer
auto cb = [cid_mask, &results](auto name, auto& cid) {
if (cid_mask & cid.acl_categories()) {
results.push_back(name);
Expand Down Expand Up @@ -510,10 +514,10 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
}
rb->SendSimpleString("commands");

std::string acl = absl::StrCat(AclCatToString(user.AclCategory()), " ",
AclCommandToString(user.AclCommandsRef()));
const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());

rb->SendSimpleString(acl);
rb->SendSimpleString(acl_cat_and_commands);

rb->SendSimpleString("keys");
std::string keys = AclKeysToString(user.Keys());
Expand Down Expand Up @@ -572,9 +576,8 @@ void AclFamily::DryRun(CmdArgList args, ConnectionContext* cntx) {
}

const auto& user = registry.find(username)->second;
const bool is_allowed = IsUserAllowedToInvokeCommandGeneric(
user.AclCategory(), user.AclCommandsRef(), {{}, true}, {}, *cid)
.first;
const bool is_allowed =
IsUserAllowedToInvokeCommandGeneric(user.AclCommandsRef(), {{}, true}, {}, *cid).first;
if (is_allowed) {
cntx->SendOk();
return;
Expand Down
2 changes: 1 addition & 1 deletion src/server/acl/acl_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class AclFamily final {
// Helper function that updates all open connections and their
// respective ACL fields on all the available proactor threads
using Commands = std::vector<uint64_t>;
void StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
void StreamUpdatesToAllProactorConnections(const std::string& user,
const Commands& update_commands,
const AclKeys& update_keys);

Expand Down
38 changes: 19 additions & 19 deletions src/server/acl/acl_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ TEST_F(AclFamilyTest, AclSetUser) {
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
auto vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
"user vlad off nopass +@NONE"));
EXPECT_THAT(
vec, UnorderedElementsAre("user default on nopass +@ALL ~*", "user vlad off nopass -@ALL"));

resp = Run({"ACL", "SETUSER", "vlad", "+ACL"});
EXPECT_THAT(resp, "OK");

resp = Run({"ACL", "LIST"});
vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
"user vlad off nopass +@NONE +ACL"));
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL ~*",
"user vlad off nopass -@ALL +ACL"));
}

TEST_F(AclFamilyTest, AclDelUser) {
Expand All @@ -82,7 +82,7 @@ TEST_F(AclFamilyTest, AclDelUser) {
EXPECT_THAT(resp, IntArg(0));

resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetString(), "user default on nopass +@ALL +ALL ~*");
EXPECT_THAT(resp.GetString(), "user default on nopass +@ALL ~*");

Run({"ACL", "SETUSER", "michael", "ON"});
Run({"ACL", "SETUSER", "kobe", "ON"});
Expand All @@ -103,9 +103,9 @@ TEST_F(AclFamilyTest, AclList) {

resp = Run({"ACL", "LIST"});
auto vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
"user kostas off d74ff0ee8da3b98 +@ADMIN",
"user adi off d74ff0ee8da3b98 +@FAST"));
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL ~*",
"user kostas off d74ff0ee8da3b98 -@ALL +@ADMIN",
"user adi off d74ff0ee8da3b98 -@ALL +@FAST"));
}

TEST_F(AclFamilyTest, AclAuth) {
Expand Down Expand Up @@ -154,16 +154,16 @@ TEST_F(AclFamilyTest, TestAllCategories) {

resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass ", "+@", cat)));
UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass -@ALL ", "+@", cat)));

resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-@", cat)});
EXPECT_THAT(resp, "OK");

resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass ", "+@NONE")));
UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass -@ALL ", "-@", cat)));

resp = Run({"ACL", "DELUSER", "kostas"});
EXPECT_THAT(resp, IntArg(1));
Expand Down Expand Up @@ -201,16 +201,16 @@ TEST_F(AclFamilyTest, TestAllCommands) {
EXPECT_THAT(resp, "OK");

resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass +@NONE ",
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass -@ALL ",
"+", command_name)));

resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-", command_name)});

resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass ", "+@NONE")));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass ",
"-@ALL ", "-", command_name)));

resp = Run({"ACL", "DELUSER", "kostas"});
EXPECT_THAT(resp, IntArg(1));
Expand Down Expand Up @@ -259,7 +259,7 @@ TEST_F(AclFamilyTest, TestGetUser) {
EXPECT_THAT(vec[2], "passwords");
EXPECT_TRUE(vec[3].GetVec().empty());
EXPECT_THAT(vec[4], "commands");
EXPECT_THAT(vec[5], "+@ALL +ALL");
EXPECT_THAT(vec[5], "+@ALL");
EXPECT_THAT(vec[6], "keys");
EXPECT_THAT(vec[7], "~*");

Expand All @@ -271,7 +271,7 @@ TEST_F(AclFamilyTest, TestGetUser) {
EXPECT_THAT(kvec[2], "passwords");
EXPECT_TRUE(kvec[3].GetVec().empty());
EXPECT_THAT(kvec[4], "commands");
EXPECT_THAT(kvec[5], "+@STRING +HSET");
EXPECT_THAT(kvec[5], "-@ALL +@STRING +HSET");
}

TEST_F(AclFamilyTest, TestDryRun) {
Expand Down
Loading

0 comments on commit 51ae74f

Please sign in to comment.