Skip to content

Commit

Permalink
fix(acl): acl compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
kostasrim committed Jun 6, 2024
1 parent 3924fca commit e2fe306
Show file tree
Hide file tree
Showing 15 changed files with 284 additions and 127 deletions.
1 change: 0 additions & 1 deletion src/facade/command_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <string_view>

#include "facade/facade_types.h"
#include "server/acl/acl_commands_def.h"

namespace facade {

Expand Down
1 change: 0 additions & 1 deletion src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,6 @@ void Connection::DispatchOperations::operator()(const MonitorMessage& msg) {
void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) {
if (self->cntx()) {
if (msg.username == self->cntx()->authed_username) {
self->cntx()->acl_categories = msg.categories;
self->cntx()->acl_commands = msg.commands;
self->cntx()->keys = msg.keys;
}
Expand Down
1 change: 0 additions & 1 deletion src/facade/dragonfly_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ class Connection : public util::Connection {
// ACL Update message, contains ACL updates to be applied to the connection.
struct AclUpdateMessage {
std::string username;
uint32_t categories;
std::vector<uint64_t> commands;
dfly::acl::AclKeys keys;
};
Expand Down
38 changes: 37 additions & 1 deletion src/server/acl/acl_commands_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
#pragma once

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "facade/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/conn_context.h"

namespace dfly::acl {

Expand Down Expand Up @@ -84,6 +87,15 @@ inline const std::vector<std::string> REVERSE_CATEGORY_INDEX_TABLE{
"_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED",
"BLOOM", "FT_SEARCH", "THROTTLE", "JSON"};

// bit index to index in the REVERSE_CATEGORY_INDEX_TABLE
using CategoryToIdxStore = absl::flat_hash_map<uint32_t, uint32_t>;

// inline const CategoryToIdxStore& CategoryToIdx(CategoryToIdxStore store = {}) {
inline const CategoryToIdxStore& CategoryToIdx(CategoryToIdxStore store = {}) {
static CategoryToIdxStore cat_idx = std::move(store);
return cat_idx;
}

using RevCommandField = std::vector<std::string>;
using RevCommandsIndexStore = std::vector<RevCommandField>;

Expand All @@ -104,9 +116,33 @@ inline const RevCommandsIndexStore& CommandsRevIndexer(RevCommandsIndexStore sto
return rev_index_store;
}

inline void BuildIndexers(std::vector<std::vector<std::string>> families) {
using CategoryToCommandsIndexStore = absl::flat_hash_map<std::string, std::vector<uint64_t>>;

inline const CategoryToCommandsIndexStore& CategoryToCommandsIndex(
CategoryToCommandsIndexStore store = {}) {
static CategoryToCommandsIndexStore index = std::move(store);
return index;
}

inline void BuildIndexers(RevCommandsIndexStore families, CommandRegistry* cmd_registry) {
acl::NumberOfFamilies(families.size());
acl::CommandsRevIndexer(std::move(families));
CategoryToCommandsIndexStore index;
cmd_registry->Traverse([&](std::string_view name, auto& cid) {
if (index[name].empty()) {
index[name].resize(CommandsRevIndexer().size());
}
auto family = cid.GetFamily();
auto bit_index = cid.GetBitIndex();
index[name][family] = index[name][family] & (1u << bit_index);
});

CategoryToCommandsIndex(std::move(index));
CategoryToIdxStore idx_store;
for (size_t i = 0; i < 32; ++i) {
idx_store[1 << i] = i;
}
CategoryToIdx(std::move(idx_store));
}

} // namespace dfly::acl
43 changes: 24 additions & 19 deletions src/server/acl/acl_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,23 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
std::string buffer = "user ";
const std::string_view pass = user.Password();
const std::string password = pass == "nopass" ? "nopass" : PrettyPrintSha(pass);
const std::string acl_cat = AclCatToString(user.AclCategory());
const std::string acl_commands = AclCommandToString(user.AclCommandsRef());
const std::string maybe_space_com = acl_commands.empty() ? "" : " ";

const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());

const std::string maybe_space_com = acl_cat_and_commands.empty() ? "" : " ";
const std::string acl_keys = AclKeysToString(user.Keys());
const std::string maybe_space = acl_keys.empty() ? "" : " ";

using namespace std::string_view_literals;

absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, " ",
acl_cat, maybe_space_com, acl_commands, maybe_space, acl_keys);
acl_cat_and_commands, maybe_space_com, acl_keys);

cntx->SendSimpleString(buffer);
}
}

void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user,
const Commands& update_commands,
const AclKeys& update_keys) {
auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) {
Expand All @@ -90,7 +91,7 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, u
if (connection->protocol() == facade::Protocol::REDIS && !connection->IsHttp() &&
connection->cntx()) {
connection->SendAclUpdateAsync(
facade::Connection::AclUpdateMessage{user, update_cat, update_commands, update_keys});
facade::Connection::AclUpdateMessage{user, update_commands, update_keys});
}
};

Expand All @@ -113,10 +114,14 @@ void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {

auto update_case = [username, &reg, cntx, this, exists](User::UpdateRequest&& req) {
auto& user = reg.registry[username];
if (!exists) {
User::UpdateRequest default_req;
default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}};
user.Update(std::move(default_req));
}
user.Update(std::move(req));
if (exists) {
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCategory(),
user.AclCommands(), user.Keys());
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCommands(), user.Keys());
}
cntx->SendOk();
};
Expand Down Expand Up @@ -194,16 +199,16 @@ std::string AclFamily::RegistryToString() const {
const std::string_view pass = user.Password();
const std::string password =
pass == "nopass" ? "nopass " : absl::StrCat("#", PrettyPrintSha(pass, true), " ");
const std::string acl_cat = AclCatToString(user.AclCategory());
const std::string acl_commands = AclCommandToString(user.AclCommandsRef());
const std::string maybe_space_com = acl_commands.empty() ? "" : " ";
const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());
const std::string maybe_space_com = acl_cat_and_commands.empty() ? "" : " ";
const std::string acl_keys = AclKeysToString(user.Keys());
const std::string maybe_space = acl_keys.empty() ? "" : " ";

using namespace std::string_view_literals;

absl::StrAppend(&result, command, username, " ", user.IsActive() ? "ON "sv : "OFF "sv, password,
acl_cat, maybe_space_com, acl_commands, maybe_space, acl_keys, "\n");
acl_cat_and_commands, maybe_space_com, acl_keys, "\n");
}

if (!result.empty()) {
Expand Down Expand Up @@ -446,6 +451,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {

const uint32_t cid_mask = CATEGORY_INDEX_TABLE.find(category)->second;
std::vector<std::string_view> results;
// TODO replace this with indexer
auto cb = [cid_mask, &results](auto name, auto& cid) {
if (cid_mask & cid.acl_categories()) {
results.push_back(name);
Expand Down Expand Up @@ -510,10 +516,10 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
}
rb->SendSimpleString("commands");

std::string acl = absl::StrCat(AclCatToString(user.AclCategory()), " ",
AclCommandToString(user.AclCommandsRef()));
const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());

rb->SendSimpleString(acl);
rb->SendSimpleString(acl_cat_and_commands);

rb->SendSimpleString("keys");
std::string keys = AclKeysToString(user.Keys());
Expand Down Expand Up @@ -572,9 +578,8 @@ void AclFamily::DryRun(CmdArgList args, ConnectionContext* cntx) {
}

const auto& user = registry.find(username)->second;
const bool is_allowed = IsUserAllowedToInvokeCommandGeneric(
user.AclCategory(), user.AclCommandsRef(), {{}, true}, {}, *cid)
.first;
const bool is_allowed =
IsUserAllowedToInvokeCommandGeneric(user.AclCommandsRef(), {{}, true}, {}, *cid).first;
if (is_allowed) {
cntx->SendOk();
return;
Expand Down
2 changes: 1 addition & 1 deletion src/server/acl/acl_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class AclFamily final {
// Helper function that updates all open connections and their
// respective ACL fields on all the available proactor threads
using Commands = std::vector<uint64_t>;
void StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
void StreamUpdatesToAllProactorConnections(const std::string& user,
const Commands& update_commands,
const AclKeys& update_keys);

Expand Down
134 changes: 84 additions & 50 deletions src/server/acl/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,64 +12,112 @@
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "core/overloaded.h"
#include "facade/acl_commands_def.h"
#include "server/acl/acl_commands_def.h"
#include "server/acl/user.h"
#include "server/common.h"

namespace dfly::acl {

std::string AclCatToString(uint32_t acl_category) {
std::string tmp;
namespace {

std::string AclCatToString(uint32_t acl_category, User::Sign sign) {
std::string res = sign == User::Sign::PLUS ? "+@" : "-@";
if (acl_category == acl::ALL) {
return "+@ALL";
absl::StrAppend(&res, "ALL");
return res;
}

if (acl_category == acl::NONE) {
return "+@NONE";
const auto& index = CategoryToIdx().at(acl_category);
absl::StrAppend(&res, REVERSE_CATEGORY_INDEX_TABLE[index]);
return res;
}

std::string AclCommandToString(size_t family, uint64_t mask, User::Sign sign) {
// This is constant but can be optimized with an indexer
const auto& rev_index = CommandsRevIndexer();
std::string res;
std::string prefix = (sign == User::Sign::PLUS) ? "+" : "-";
if (mask == ALL_COMMANDS) {
for (const auto& cmd : rev_index[family]) {
absl::StrAppend(&res, prefix, cmd, " ");
}
res.pop_back();
return res;
}

const std::string prefix = "+@";
const std::string postfix = " ";
size_t pos = 0;
while (mask != 0) {
++pos;
mask = mask >> 1;
}
--pos;
absl::StrAppend(&res, prefix, rev_index[family][pos]);
return res;
}

for (uint32_t i = 0; i < 32; ++i) {
uint32_t cat_bit = 1ULL << i;
if (acl_category & cat_bit) {
absl::StrAppend(&tmp, prefix, REVERSE_CATEGORY_INDEX_TABLE[i], postfix);
}
struct CategoryAndMetadata {
User::CategoryChange change;
User::ChangeMetadata metadata;
};

struct CommandAndMetadata {
User::CommandChange change;
User::ChangeMetadata metadata;
};

using MergeResult = std::vector<std::variant<CategoryAndMetadata, CommandAndMetadata>>;

} // namespace

// Merge Category and Command changes and sort them by global order seq_no
MergeResult MergeTables(const User::CategoryChanges& categories,
const User::CommandChanges& commands) {
MergeResult result;
for (auto [cat, meta] : categories) {
result.push_back(CategoryAndMetadata{cat, meta});
}

for (auto [cmd, meta] : commands) {
result.push_back(CommandAndMetadata{cmd, meta});
}

tmp.pop_back();
std::sort(result.begin(), result.end(), [](const auto& l, const auto& r) {
auto fetch = [](const auto& l) { return l.metadata.seq_no; };
return std::visit(fetch, l) < std::visit(fetch, r);
});

return tmp;
return result;
}

std::string AclCommandToString(const std::vector<uint64_t>& acl_category) {
std::string AclCatAndCommandToString(const User::CategoryChanges& cat,
const User::CommandChanges& cmds) {
std::string result;

const std::string prefix = "+";
const std::string postfix = " ";
const auto& rev_index = CommandsRevIndexer();
bool all = true;

size_t family_id = 0;
for (auto family : acl_category) {
for (uint64_t i = 0; i < 64; ++i) {
const uint64_t cmd_bit = 1ULL << i;
if (family & cmd_bit && i < rev_index[family_id].size()) {
absl::StrAppend(&result, prefix, rev_index[family_id][i], postfix);
continue;
}
if (i < rev_index[family_id].size()) {
all = false;
}
}
++family_id;
auto tables = MergeTables(cat, cmds);

auto cat_visitor = [&result](const CategoryAndMetadata& val) {
const auto& [change, meta] = val;
absl::StrAppend(&result, AclCatToString(change.category, meta.sign), " ");
};

auto cmd_visitor = [&result](const CommandAndMetadata& val) {
const auto& [change, meta] = val;
absl::StrAppend(&result, AclCommandToString(change.family, change.bit_index, meta.sign), " ");
};

Overloaded visitor{cat_visitor, cmd_visitor};

for (auto change : tables) {
std::visit(visitor, change);
}

if (!result.empty()) {
result.pop_back();
}
return all ? "+ALL" : result;

return result;
}

std::string PrettyPrintSha(std::string_view pass, bool all) {
Expand Down Expand Up @@ -157,21 +205,8 @@ std::pair<OptCat, bool> MaybeParseAclCategory(std::string_view command) {
return {};
}

bool IsIndexAllCommandsFlag(size_t index) {
return index == std::numeric_limits<size_t>::max();
}

std::pair<OptCommand, bool> MaybeParseAclCommand(std::string_view command,
const CommandRegistry& registry) {
const auto all_commands = std::pair<size_t, uint64_t>{std::numeric_limits<size_t>::max(), 0};
if (command == "+ALL") {
return {all_commands, true};
}

if (command == "-ALL") {
return {all_commands, false};
}

if (absl::StartsWith(command, "+")) {
auto res = registry.Find(command.substr(1));
if (!res) {
Expand Down Expand Up @@ -281,7 +316,7 @@ std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,
using Sign = User::Sign;
using Val = std::pair<Sign, uint32_t>;
auto val = add ? Val{Sign::PLUS, *cat} : Val{Sign::MINUS, *cat};
req.categories.push_back(val);
req.updates.push_back(val);
continue;
}

Expand All @@ -292,10 +327,9 @@ std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,

using Sign = User::Sign;
using Val = User::UpdateRequest::CommandsValueType;
;
auto [index, bit] = *cmd;
auto val = sign ? Val{Sign::PLUS, index, bit} : Val{Sign::MINUS, index, bit};
req.commands.push_back(val);
req.updates.push_back(val);
}

return req;
Expand Down
Loading

0 comments on commit e2fe306

Please sign in to comment.