Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Make KeyIndex iterable #3326

Merged
merged 5 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/facade/facade_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ struct ArgRange {
return Range().second;
}

std::string_view operator[](size_t idx) const {
return std::visit([idx](const auto& span) { return facade::ToSV(span[idx]); }, span);
}

std::variant<CmdArgList, ArgSlice, OwnedArgSlice> span;
};
struct ConnectionStats {
Expand Down
22 changes: 5 additions & 17 deletions src/server/acl/validator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,11 @@ namespace dfly::acl {

bool keys_allowed = true;
if (!keys.all_keys && id.first_key_pos() != 0 && (is_read_command || is_write_command)) {
const auto keys_index = DetermineKeys(&id, tail_args).value();
const size_t end = keys_index.end;
if (keys_index.bonus) {
auto target = facade::ToSV(tail_args[*keys_index.bonus]);
if (!iterate_globs(target)) {
keys_allowed = false;
}
}
if (keys_allowed) {
for (size_t i = keys_index.start; i < end; i += keys_index.step) {
auto target = facade::ToSV(tail_args[i]);
if (!iterate_globs(target)) {
keys_allowed = false;
break;
}
}
}
auto keys_index = DetermineKeys(&id, tail_args);
DCHECK(keys_index);

for (std::string_view key : keys_index->Range(tail_args))
keys_allowed &= iterate_globs(key);
}

return {keys_allowed, AclLog::Reason::KEY};
Expand Down
1 change: 1 addition & 0 deletions src/server/cluster/cluster_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ TEST_F(ClusterFamilyTest, ClusterCrossSlot) {

EXPECT_THAT(Run({"MSET", "key", "value", "key2", "value2"}), ErrArg("CROSSSLOT"));
EXPECT_THAT(Run({"MGET", "key", "key2"}), ErrArg("CROSSSLOT"));
EXPECT_THAT(Run({"ZINTERSTORE", "key", "2", "key1", "key2"}), ErrArg("CROSSSLOT"));

EXPECT_EQ(Run({"MSET", "key{tag}", "value", "key2{tag}", "value2"}), "OK");
EXPECT_THAT(Run({"MGET", "key{tag}", "key2{tag}"}), RespArray(ElementsAre("value", "value2")));
Expand Down
28 changes: 5 additions & 23 deletions src/server/main_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ Transaction::MultiMode DeduceExecMode(ExecEvalState state,
StoredCmd cmd = scmd;
cmd.Fill(&arg_vec);
auto keys = DetermineKeys(scmd.Cid(), absl::MakeSpan(arg_vec));
transactional |= (keys && keys.value().num_args() > 0);
transactional |= (keys && keys.value().NumArgs() > 0);
} else {
transactional |= scmd.Cid()->IsTransactional();
}
Expand Down Expand Up @@ -943,10 +943,8 @@ optional<ErrorReply> Service::CheckKeysOwnership(const CommandId* cid, CmdArgLis
optional<cluster::SlotId> keys_slot;
bool cross_slot = false;
// Iterate keys and check to which slot they belong.
for (unsigned i = key_index.start; i < key_index.end; i += key_index.step) {
string_view key = ArgS(args, i);
cluster::SlotId slot = cluster::KeySlot(key);
if (keys_slot && slot != *keys_slot) {
for (string_view key : key_index.Range(args)) {
if (cluster::SlotId slot = cluster::KeySlot(key); keys_slot && slot != *keys_slot) {
cross_slot = true; // keys belong to different slots
break;
} else {
Expand Down Expand Up @@ -991,18 +989,7 @@ optional<ErrorReply> CheckKeysDeclared(const ConnectionState::ScriptInfo& eval_i
// TODO: Switch to transaction internal locked keys once single hop multi transactions are merged
// const auto& locked_keys = trans->GetMultiKeys();
const auto& locked_tags = eval_info.lock_tags;

const auto& key_index = *key_index_res;
for (unsigned i = key_index.start; i < key_index.end; ++i) {
string_view key = ArgS(args, i);
LockTag tag{key};
if (!locked_tags.contains(tag)) {
return ErrorReply(absl::StrCat(kUndeclaredKeyErr, ", key: ", key));
}
}

if (key_index.bonus) {
string_view key = ArgS(args, *key_index.bonus);
for (string_view key : key_index_res->Range(args)) {
if (!locked_tags.contains(LockTag{key})) {
return ErrorReply(absl::StrCat(kUndeclaredKeyErr, ", key: ", key));
}
Expand Down Expand Up @@ -2125,13 +2112,8 @@ template <typename F> void IterateAllKeys(ConnectionState::ExecInfo* exec_info,
if (!key_res.ok())
continue;

auto key_index = key_res.value();

for (unsigned i = key_index.start; i < key_index.end; i += key_index.step)
for (unsigned i : key_res->Range())
f(arg_vec[i]);

if (key_index.bonus)
f(arg_vec[*key_index.bonus]);
}
}

Expand Down
35 changes: 10 additions & 25 deletions src/server/multi_command_squasher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "server/conn_context.h"
#include "server/engine_shard_set.h"
#include "server/transaction.h"
#include "server/tx_base.h"

namespace dfly {

Expand All @@ -22,14 +23,6 @@ using namespace util;

namespace {

template <typename F> void IterateKeys(CmdArgList args, KeyIndex keys, F&& f) {
for (unsigned i = keys.start; i < keys.end; i += keys.step)
f(args[i]);

if (keys.bonus)
f(args[*keys.bonus]);
}

void CheckConnStateClean(const ConnectionState& state) {
DCHECK_EQ(state.exec_info.state, ConnectionState::ExecInfo::EXEC_INACTIVE);
DCHECK(state.exec_info.body.empty());
Expand Down Expand Up @@ -90,29 +83,21 @@ MultiCommandSquasher::SquashResult MultiCommandSquasher::TrySquash(StoredCmd* cm
auto keys = DetermineKeys(cmd->Cid(), args);
if (!keys.ok())
return SquashResult::ERROR;
if (keys->NumArgs() == 0)
return SquashResult::NOT_SQUASHED;

// Check if all commands belong to one shard
bool found_more = false;
cluster::UniqueSlotChecker slot_checker;
ShardId last_sid = kInvalidSid;
IterateKeys(args, *keys, [&last_sid, &found_more, &slot_checker](MutableSlice key) {
if (found_more)
return;

string_view key_sv = facade::ToSV(key);

slot_checker.Add(key_sv);

ShardId sid = Shard(key_sv, shard_set->size());
if (last_sid == kInvalidSid || last_sid == sid) {
for (string_view key : keys->Range(args)) {
slot_checker.Add(key);
ShardId sid = Shard(key, shard_set->size());
if (last_sid == kInvalidSid || last_sid == sid)
last_sid = sid;
return;
}
found_more = true;
});

if (found_more || last_sid == kInvalidSid)
return SquashResult::NOT_SQUASHED;
else
return SquashResult::NOT_SQUASHED; // at least two shards
}

auto& sinfo = PrepareShardInfo(last_sid, slot_checker.GetUniqueSlotId());

Expand Down
95 changes: 39 additions & 56 deletions src/server/transaction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,32 +190,23 @@ void Transaction::InitGlobal() {
}

void Transaction::BuildShardIndex(const KeyIndex& key_index, std::vector<PerShardCache>* out) {
// Because of the way we iterate in InitShardData
DCHECK(!key_index.bonus || key_index.step == 1);

auto& shard_index = *out;
for (unsigned i : key_index.Range()) {
string_view key = ArgS(full_args_, i);
unique_slot_checker_.Add(key);
ShardId sid = Shard(key, shard_data_.size());

auto add = [&shard_index](uint32_t sid, uint32_t b, uint32_t e) {
unsigned step = key_index.bonus ? 1 : key_index.step;
shard_index[sid].key_step = step;
auto& slices = shard_index[sid].slices;
if (!slices.empty() && slices.back().second == b) {
slices.back().second = e;
if (!slices.empty() && slices.back().second == i) {
slices.back().second = i + step;
} else {
slices.emplace_back(b, e);
slices.emplace_back(i, i + step);
}
};

if (key_index.bonus) {
DCHECK(key_index.step == 1);
string_view key = ArgS(full_args_, *key_index.bonus);
unique_slot_checker_.Add(key);
uint32_t sid = Shard(key, shard_data_.size());
add(sid, *key_index.bonus, *key_index.bonus + 1);
}

for (unsigned i = key_index.start; i < key_index.end; i += key_index.step) {
string_view key = ArgS(full_args_, i);
unique_slot_checker_.Add(key);
uint32_t sid = Shard(key, shard_data_.size());
shard_index[sid].key_step = key_index.step;

add(sid, i, i + key_index.step);
}
}

Expand Down Expand Up @@ -246,11 +237,9 @@ void Transaction::InitShardData(absl::Span<const PerShardCache> shard_index, siz
unique_shard_cnt_++;
unique_shard_id_ = i;

for (size_t j = 0; j < src.slices.size(); ++j) {
IndexSlice slice = src.slices[j];
args_slices_.push_back(slice);
for (uint32_t k = slice.first; k < slice.second; k += src.key_step) {
string_view key = ArgS(full_args_, k);
for (const auto& [start, end] : src.slices) {
args_slices_.emplace_back(start, end);
for (string_view key : KeyIndex(start, end, src.key_step).Range(full_args_)) {
kv_fp_.push_back(LockTag(key).Fingerprint());
sd.fp_count++;
}
Expand Down Expand Up @@ -278,10 +267,8 @@ void Transaction::StoreKeysInArgs(const KeyIndex& key_index) {

// even for a single key we may have multiple arguments per key (MSET).
args_slices_.emplace_back(key_index.start, key_index.end);
for (unsigned j = key_index.start; j < key_index.end; j += key_index.step) {
string_view key = ArgS(full_args_, j);
for (string_view key : key_index.Range(full_args_))
kv_fp_.push_back(LockTag(key).Fingerprint());
}
}

void Transaction::InitByKeys(const KeyIndex& key_index) {
Expand All @@ -295,14 +282,14 @@ void Transaction::InitByKeys(const KeyIndex& key_index) {
// Stub transactions always operate only on single shard.
bool is_stub = multi_ && multi_->role == SQUASHED_STUB;

if ((key_index.HasSingleKey() && !IsAtomicMulti()) || is_stub) {
if ((key_index.NumArgs() == 1 && !IsAtomicMulti()) || is_stub) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe rename the method to NumKeys instead of NumArgs?

Copy link
Contributor Author

@dranikpg dranikpg Jul 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to do it first, but it's also used as literally num args. And a single argument can only exist with a single key in that context, so num_args == 1 <=> num_keys == 1

DCHECK(!IsActiveMulti() || multi_->mode == NON_ATOMIC);

// We don't have to split the arguments by shards, so we can copy them directly.
StoreKeysInArgs(key_index);

unique_shard_cnt_ = 1;
string_view akey = ArgS(full_args_, key_index.start);
string_view akey = *key_index.Range(full_args_).begin();
if (is_stub) // stub transactions don't migrate
DCHECK_EQ(unique_shard_id_, Shard(akey, shard_set->size()));
else {
Expand All @@ -328,7 +315,7 @@ void Transaction::InitByKeys(const KeyIndex& key_index) {
BuildShardIndex(key_index, &shard_index);

// Initialize shard data based on distributed arguments.
InitShardData(shard_index, key_index.num_args());
InitShardData(shard_index, key_index.NumArgs());

DCHECK(!multi_ || multi_->mode != LOCK_AHEAD || !multi_->tag_fps.empty());

Expand Down Expand Up @@ -440,7 +427,7 @@ void Transaction::StartMultiLockedAhead(Namespace* ns, DbIndex dbid, CmdArgList
PrepareMultiFps(keys);

InitBase(ns, dbid, keys);
InitByKeys(KeyIndex::Range(0, keys.size()));
InitByKeys(KeyIndex(0, keys.size()));

if (!skip_scheduling)
ScheduleInternal();
Expand Down Expand Up @@ -1432,23 +1419,24 @@ bool Transaction::CanRunInlined() const {
}

OpResult<KeyIndex> DetermineKeys(const CommandId* cid, CmdArgList args) {
KeyIndex key_index;

if (cid->opt_mask() & (CO::GLOBAL_TRANS | CO::NO_KEY_TRANSACTIONAL))
return key_index;
return KeyIndex{};

int num_custom_keys = -1;

if (cid->opt_mask() & CO::VARIADIC_KEYS) {
unsigned start = 0, end = 0, step = 0;
std::optional<unsigned> bonus = std::nullopt;

if (cid->opt_mask() & CO::VARIADIC_KEYS) { // number of keys is not trivially deducable
// ZUNION/INTER <num_keys> <key1> [<key2> ...]
// EVAL <script> <num_keys>
// XREAD ... STREAMS ...
if (args.size() < 2) {
if (args.size() < 2)
return OpStatus::SYNTAX_ERR;
}

string_view name{cid->name()};

// Determine based on STREAMS argument position
if (name == "XREAD" || name == "XREADGROUP") {
for (size_t i = 0; i < args.size(); ++i) {
string_view arg = ArgS(args, i);
Expand All @@ -1457,24 +1445,20 @@ OpResult<KeyIndex> DetermineKeys(const CommandId* cid, CmdArgList args) {
if (left < 2 || left % 2 != 0)
return OpStatus::SYNTAX_ERR;

key_index.start = i + 1;
key_index.end = key_index.start + (left / 2);
key_index.step = 1;

return key_index;
return KeyIndex(i + 1, i + 1 + (left / 2));
}
}
return OpStatus::SYNTAX_ERR;
}

if (absl::EndsWith(name, "STORE"))
key_index.bonus = 0; // Z<xxx>STORE <key> commands
bonus = 0; // Z<xxx>STORE <key> commands

unsigned num_keys_index;
if (absl::StartsWith(name, "EVAL"))
num_keys_index = 1;
else
num_keys_index = key_index.bonus ? *key_index.bonus + 1 : 0;
num_keys_index = bonus ? *bonus + 1 : 0;

string_view num = ArgS(args, num_keys_index);
if (!absl::SimpleAtoi(num, &num_custom_keys) || num_custom_keys < 0)
Expand All @@ -1491,22 +1475,22 @@ OpResult<KeyIndex> DetermineKeys(const CommandId* cid, CmdArgList args) {
}

if (cid->first_key_pos() > 0) {
key_index.start = cid->first_key_pos() - 1;
start = cid->first_key_pos() - 1;
int last = cid->last_key_pos();

if (num_custom_keys >= 0) {
key_index.end = key_index.start + num_custom_keys;
end = start + num_custom_keys;
} else {
key_index.end = last > 0 ? last : (int(args.size()) + last + 1);
end = last > 0 ? last : (int(args.size()) + last + 1);
}
if (cid->opt_mask() & CO::INTERLEAVED_KEYS) {
if (cid->name() == "JSON.MSET") {
key_index.step = 3;
step = 3;
} else {
key_index.step = 2;
step = 2;
}
} else {
key_index.step = 1;
step = 1;
}

if (cid->opt_mask() & CO::STORE_LAST_KEY) {
Expand All @@ -1516,17 +1500,16 @@ OpResult<KeyIndex> DetermineKeys(const CommandId* cid, CmdArgList args) {
// key member radius .. STORE destkey
string_view opt = ArgS(args, args.size() - 2);
if (absl::EqualsIgnoreCase(opt, "STORE") || absl::EqualsIgnoreCase(opt, "STOREDIST")) {
key_index.bonus = args.size() - 1;
bonus = args.size() - 1;
}
}
}

return key_index;
return KeyIndex{start, end, step, bonus};
}

LOG(FATAL) << "TBD: Not supported " << cid->name();

return key_index;
return {};
}

std::vector<Transaction::PerShardCache>& Transaction::TLTmpSpace::GetShardIndex(unsigned size) {
Expand Down
Loading
Loading