Skip to content

Commit

Permalink
chore: Refactor string span management (#3165)
Browse files Browse the repository at this point in the history
Signed-off-by: Vladislav Oleshko <[email protected]>
  • Loading branch information
dranikpg authored Jun 18, 2024
1 parent 6291c04 commit e45c1e9
Show file tree
Hide file tree
Showing 23 changed files with 198 additions and 273 deletions.
2 changes: 1 addition & 1 deletion helio
1 change: 1 addition & 0 deletions src/core/string_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class StringSet : public DenseSet {
}

uint32_t Scan(uint32_t, const std::function<void(sds)>&) const;

iterator Find(std::string_view member) {
return iterator{FindIt(&member, 1)};
}
Expand Down
54 changes: 48 additions & 6 deletions src/facade/facade_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string_view>
#include <variant>

#include "base/iterator.h"
#include "facade/op_status.h"

namespace facade {
Expand All @@ -37,6 +38,8 @@ enum class Protocol : uint8_t { MEMCACHE = 1, REDIS = 2 };
using MutableSlice = absl::Span<char>;
using CmdArgList = absl::Span<MutableSlice>;
using CmdArgVec = std::vector<MutableSlice>;
using ArgSlice = absl::Span<const std::string_view>;
using OwnedArgSlice = absl::Span<const std::string>;

inline std::string_view ToSV(MutableSlice slice) {
return std::string_view{slice.data(), slice.size()};
Expand All @@ -46,6 +49,50 @@ inline std::string_view ToSV(std::string_view slice) {
return slice;
}

inline std::string_view ToSV(const std::string& slice) {
return slice;
}

inline std::string_view ToSV(std::string&& slice) = delete;

constexpr auto kToSV = [](auto&& v) { return ToSV(std::forward<decltype(v)>(v)); };

inline std::string_view ArgS(CmdArgList args, size_t i) {
auto arg = args[i];
return {arg.data(), arg.size()};
}

inline auto ArgS(CmdArgList args) {
return base::it::Transform(kToSV, base::it::Range{args.begin(), args.end()});
}

struct ArgRange {
ArgRange(ArgRange&&) = default;
ArgRange(const ArgRange&) = default;
ArgRange(ArgRange& range) : ArgRange((const ArgRange&)range) {
}

template <typename T> ArgRange(T&& span) : span(std::forward<T>(span)) {
}

size_t Size() const {
return std::visit([](const auto& span) { return span.size(); }, span);
}

auto Range() const {
return base::it::Wrap(kToSV, span);
}

auto begin() const {
return Range().first;
}

auto end() const {
return Range().second;
}

std::variant<CmdArgList, ArgSlice, OwnedArgSlice> span;
};
struct ConnectionStats {
size_t read_buf_capacity = 0; // total capacity of input buffers
uint64_t dispatch_queue_entries = 0; // total number of dispatch queue entries
Expand Down Expand Up @@ -120,7 +167,7 @@ struct ErrorReply {
}

std::string_view ToSv() const {
return std::visit([](auto& str) { return std::string_view(str); }, message);
return std::visit(kToSV, message);
}

std::variant<std::string, std::string_view> message;
Expand All @@ -132,11 +179,6 @@ inline MutableSlice ToMSS(absl::Span<uint8_t> span) {
return MutableSlice{reinterpret_cast<char*>(span.data()), span.size()};
}

inline std::string_view ArgS(CmdArgList args, size_t i) {
auto arg = args[i];
return std::string_view(arg.data(), arg.size());
}

constexpr inline unsigned long long operator""_MB(unsigned long long x) {
return 1024L * 1024L * x;
}
Expand Down
31 changes: 9 additions & 22 deletions src/facade/reply_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ void SinkReplyBuilder::SendError(ErrorReply error) {
if (error.status)
return SendError(*error.status);

string_view message_sv = visit([](auto&& str) -> string_view { return str; }, error.message);
SendError(message_sv, error.kind);
SendError(error.ToSv(), error.kind);
}

void SinkReplyBuilder::SendError(OpStatus status) {
Expand Down Expand Up @@ -264,14 +263,6 @@ void MCReplyBuilder::SendNotFound() {
SendSimpleString("NOT_FOUND");
}

size_t RedisReplyBuilder::WrappedStrSpan::Size() const {
return visit([](auto arr) { return arr.size(); }, (const StrSpan&)*this);
}

string_view RedisReplyBuilder::WrappedStrSpan::operator[](size_t i) const {
return visit([i](auto arr) { return string_view{arr[i]}; }, (const StrSpan&)*this);
}

char* RedisReplyBuilder::FormatDouble(double val, char* dest, unsigned dest_len) {
StringBuilder sb(dest, dest_len);
CHECK(dfly_conv.ToShortest(val, &sb));
Expand Down Expand Up @@ -504,12 +495,9 @@ void RedisReplyBuilder::SendMGetResponse(MGetResponse resp) {
}

void RedisReplyBuilder::SendSimpleStrArr(StrSpan arr) {
WrappedStrSpan warr{arr};

string res = absl::StrCat("*", warr.Size(), kCRLF);

for (unsigned i = 0; i < warr.Size(); i++)
StrAppend(&res, "+", warr[i], kCRLF);
string res = absl::StrCat("*", arr.Size(), kCRLF);
for (std::string_view str : arr)
StrAppend(&res, "+", str, kCRLF);

SendRaw(res);
}
Expand All @@ -523,16 +511,15 @@ void RedisReplyBuilder::SendEmptyArray() {
}

void RedisReplyBuilder::SendStringArr(StrSpan arr, CollectionType type) {
WrappedStrSpan warr{arr};

if (type == ARRAY && warr.Size() == 0) {
if (type == ARRAY && arr.Size() == 0) {
SendRaw("*0\r\n");
return;
}

auto cb = [&](size_t i) { return warr[i]; };

SendStringArrInternal(warr.Size(), std::move(cb), type);
auto cb = [&](size_t i) {
return visit([i](auto& span) { return facade::ToSV(span[i]); }, arr.span);
};
SendStringArrInternal(arr.Size(), std::move(cb), type);
}

void RedisReplyBuilder::StartArray(unsigned len) {
Expand Down
8 changes: 1 addition & 7 deletions src/facade/reply_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class RedisReplyBuilder : public SinkReplyBuilder {

enum VerbatimFormat { TXT, MARKDOWN };

using StrSpan = std::variant<absl::Span<const std::string>, absl::Span<const std::string_view>>;
using StrSpan = facade::ArgRange;

RedisReplyBuilder(::io::Sink* stream);

Expand Down Expand Up @@ -242,12 +242,6 @@ class RedisReplyBuilder : public SinkReplyBuilder {

static char* FormatDouble(double val, char* dest, unsigned dest_len);

protected:
struct WrappedStrSpan : public StrSpan {
size_t Size() const;
std::string_view operator[](size_t index) const;
};

private:
void SendStringArrInternal(size_t size, absl::FunctionRef<std::string_view(unsigned)> producer,
CollectionType type);
Expand Down
19 changes: 3 additions & 16 deletions src/facade/reply_capture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ void CapturingReplyBuilder::SendError(std::string_view str, std::string_view typ

void CapturingReplyBuilder::SendError(ErrorReply error) {
SKIP_LESS(ReplyMode::ONLY_ERR);

string message =
visit([](auto&& str) -> string { return string{std::move(str)}; }, error.message);
Capture(Error{std::move(message), error.kind});
Capture(Error{error.ToSv(), error.kind});
}

void CapturingReplyBuilder::SendMGetResponse(MGetResponse resp) {
Expand All @@ -53,25 +50,15 @@ void CapturingReplyBuilder::SendSimpleStrArr(StrSpan arr) {
SKIP_LESS(ReplyMode::FULL);
DCHECK_EQ(current_.index(), 0u);

WrappedStrSpan warr{arr};
vector<string> sarr(warr.Size());
for (unsigned i = 0; i < warr.Size(); i++)
sarr[i] = warr[i];

Capture(StrArrPayload{true, ARRAY, std::move(sarr)});
Capture(StrArrPayload{true, ARRAY, {arr.begin(), arr.end()}});
}

void CapturingReplyBuilder::SendStringArr(StrSpan arr, CollectionType type) {
SKIP_LESS(ReplyMode::FULL);
DCHECK_EQ(current_.index(), 0u);

// TODO: 1. Allocate all strings at once 2. Allow movable types
WrappedStrSpan warr{arr};
vector<string> sarr(warr.Size());
for (unsigned i = 0; i < warr.Size(); i++)
sarr[i] = warr[i];

Capture(StrArrPayload{false, type, std::move(sarr)});
Capture(StrArrPayload{false, type, {arr.begin(), arr.end()}});
}

void CapturingReplyBuilder::SendNull() {
Expand Down
2 changes: 1 addition & 1 deletion src/server/acl/acl_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path,
std::vector<User::UpdateRequest> requests;

for (auto& cmds : *materialized) {
auto req = ParseAclSetUser<std::vector<std::string_view>&>(cmds, *cmd_registry_, true);
auto req = ParseAclSetUser(cmds, *cmd_registry_, true);
if (std::holds_alternative<ErrorReply>(req)) {
auto error = std::move(std::get<ErrorReply>(req));
LOG(WARNING) << "Error while parsing aclfile: " << error.ToSv();
Expand Down
26 changes: 3 additions & 23 deletions src/server/acl/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,12 @@ MaterializedContents MaterializeFileContents(std::vector<std::string>* usernames

using facade::ErrorReply;

template <typename T>
std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,
std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(facade::ArgRange args,
const CommandRegistry& registry,
bool hashed, bool has_all_keys) {
User::UpdateRequest req;

for (auto& arg : args) {
for (std::string_view arg : args) {
if (auto pass = MaybeParsePassword(facade::ToSV(arg), hashed); pass) {
if (req.password) {
return ErrorReply("Only one password is allowed");
Expand Down Expand Up @@ -291,18 +290,7 @@ std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,
continue;
}

std::string buffer;
std::string_view command;
if constexpr (std::is_same_v<T, facade::CmdArgList>) {
ToUpper(&arg);
command = facade::ToSV(arg);
} else {
// Guaranteed SSO because commands are small
buffer = arg;
absl::Span<char> view{buffer.data(), buffer.size()};
ToUpper(&view);
command = buffer;
}
std::string command = absl::AsciiStrToUpper(arg);

if (auto status = MaybeParseStatus(command); status) {
if (req.is_active) {
Expand Down Expand Up @@ -338,14 +326,6 @@ std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(T args,

using facade::CmdArgList;

template std::variant<User::UpdateRequest, ErrorReply>
ParseAclSetUser<std::vector<std::string_view>&>(std::vector<std::string_view>&,
const CommandRegistry& registry, bool hashed,
bool has_all_keys);

template std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser<CmdArgList>(
CmdArgList args, const CommandRegistry& registry, bool hashed, bool has_all_keys);

std::string AclKeysToString(const AclKeys& keys) {
if (keys.all_keys) {
return "~*";
Expand Down
4 changes: 2 additions & 2 deletions src/server/acl/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ using OptCommand = std::optional<std::pair<size_t, uint64_t>>;
std::pair<OptCommand, bool> MaybeParseAclCommand(std::string_view command,
const CommandRegistry& registry);

template <typename T>
std::variant<User::UpdateRequest, facade::ErrorReply> ParseAclSetUser(
T args, const CommandRegistry& registry, bool hashed = false, bool has_all_keys = false);
facade::ArgRange args, const CommandRegistry& registry, bool hashed = false,
bool has_all_keys = false);

using MaterializedContents = std::optional<std::vector<std::vector<std::string_view>>>;

Expand Down
9 changes: 4 additions & 5 deletions src/server/blocking_controller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ bool BlockingController::DbWatchTable::AddAwakeEvent(string_view key) {
}

// Removes tx from its watch queues if tx appears there.
void BlockingController::FinalizeWatched(const ShardArgs& args, Transaction* tx) {
void BlockingController::FinalizeWatched(Keys keys, Transaction* tx) {
DCHECK(tx);
VLOG(1) << "FinalizeBlocking [" << owner_->shard_id() << "]" << tx->DebugId();

Expand All @@ -135,7 +135,7 @@ void BlockingController::FinalizeWatched(const ShardArgs& args, Transaction* tx)

// Add keys of processed transaction so we could awake the next one in the queue
// in case those keys still exist.
for (string_view key : args) {
for (string_view key : base::it::Wrap(facade::kToSV, keys)) {
bool removed_awakened = wt.UnwatchTx(key, tx);
CHECK(!removed_awakened || removed)
<< tx->DebugId() << " " << key << " " << tx->DEBUG_GetLocalMask(owner_->shard_id());
Expand Down Expand Up @@ -197,16 +197,15 @@ void BlockingController::NotifyPending() {
awakened_indices_.clear();
}

void BlockingController::AddWatched(const ShardArgs& watch_keys, KeyReadyChecker krc,
Transaction* trans) {
void BlockingController::AddWatched(Keys watch_keys, KeyReadyChecker krc, Transaction* trans) {
auto [dbit, added] = watched_dbs_.emplace(trans->GetDbIndex(), nullptr);
if (added) {
dbit->second.reset(new DbWatchTable);
}

DbWatchTable& wt = *dbit->second;

for (auto key : watch_keys) {
for (auto key : base::it::Wrap(facade::kToSV, watch_keys)) {
auto [res, inserted] = wt.queue_map.emplace(key, nullptr);
if (inserted) {
res->second.reset(new WatchQueue);
Expand Down
6 changes: 4 additions & 2 deletions src/server/blocking_controller.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class BlockingController {
explicit BlockingController(EngineShard* owner);
~BlockingController();

using Keys = std::variant<ShardArgs, ArgSlice>;

bool HasAwakedTransaction() const {
return !awakened_transactions_.empty();
}
Expand All @@ -29,7 +31,7 @@ class BlockingController {
return awakened_transactions_;
}

void FinalizeWatched(const ShardArgs& args, Transaction* tx);
void FinalizeWatched(Keys keys, Transaction* tx);

// go over potential wakened keys, verify them and activate watch queues.
void NotifyPending();
Expand All @@ -38,7 +40,7 @@ class BlockingController {
// TODO: consider moving all watched functions to
// EngineShard with separate per db map.
//! AddWatched adds a transaction to the blocking queue.
void AddWatched(const ShardArgs& watch_keys, KeyReadyChecker krc, Transaction* me);
void AddWatched(Keys watch_keys, KeyReadyChecker krc, Transaction* me);

// Called from operations that create keys like lpush, rename etc.
void AwakeWatched(DbIndex db_index, std::string_view db_key);
Expand Down
6 changes: 3 additions & 3 deletions src/server/conn_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,15 @@ vector<unsigned> ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add,
ChannelStoreUpdater csu{pattern, to_add, conn, uint32_t(tid)};

// Gather all the channels we need to subscribe to / remove.
for (size_t i = 0; i < args.size(); ++i) {
string_view channel = ArgS(args, i);
size_t i = 0;
for (string_view channel : ArgS(args)) {
if (to_add && local_store.emplace(channel).second)
csu.Record(channel);
else if (!to_add && local_store.erase(channel) > 0)
csu.Record(channel);

if (to_reply)
result[i] = sinfo.SubscriptionCount();
result[i++] = sinfo.SubscriptionCount();
}

csu.Apply();
Expand Down
Loading

0 comments on commit e45c1e9

Please sign in to comment.