Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support of new command: ssubscribe and sunsubscribe #2003

Merged
merged 1 commit into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions src/commands/cmd_pubsub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,44 @@ class CommandPUnSubscribe : public Commander {
}
};

class CommandSSubscribe : public Commander {
public:
Status Execute(Server *srv, Connection *conn, std::string *output) override {
uint16_t slot = 0;
if (srv->GetConfig()->cluster_enabled) {
slot = GetSlotIdFromKey(args_[1]);
for (unsigned int i = 2; i < args_.size(); i++) {
if (GetSlotIdFromKey(args_[i]) != slot) {
return {Status::RedisExecErr, "CROSSSLOT Keys in request don't hash to the same slot"};
}
}
}

for (unsigned int i = 1; i < args_.size(); i++) {
conn->SSubscribeChannel(args_[i], slot);
SubscribeCommandReply(output, "ssubscribe", args_[i], conn->SSubscriptionsCount());
}
return Status::OK();
}
};

class CommandSUnSubscribe : public Commander {
public:
Status Execute(Server *srv, Connection *conn, std::string *output) override {
if (args_.size() == 1) {
conn->SUnsubscribeAll([output](const std::string &sub_name, int num) {
SubscribeCommandReply(output, "sunsubscribe", sub_name, num);
});
} else {
for (size_t i = 1; i < args_.size(); i++) {
conn->SUnsubscribeChannel(args_[i], srv->GetConfig()->cluster_enabled ? GetSlotIdFromKey(args_[i]) : 0);
SubscribeCommandReply(output, "sunsubscribe", args_[i], conn->SSubscriptionsCount());
}
}
return Status::OK();
}
};

class CommandPubSub : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
Expand All @@ -146,14 +184,14 @@ class CommandPubSub : public Commander {
return Status::OK();
}

if ((subcommand_ == "numsub") && args.size() >= 2) {
if ((subcommand_ == "numsub" || subcommand_ == "shardnumsub") && args.size() >= 2) {
if (args.size() > 2) {
channels_ = std::vector<std::string>(args.begin() + 2, args.end());
}
return Status::OK();
}

if ((subcommand_ == "channels") && args.size() <= 3) {
if ((subcommand_ == "channels" || subcommand_ == "shardchannels") && args.size() <= 3) {
if (args.size() == 3) {
pattern_ = args[2];
}
Expand All @@ -169,9 +207,13 @@ class CommandPubSub : public Commander {
return Status::OK();
}

if (subcommand_ == "numsub") {
if (subcommand_ == "numsub" || subcommand_ == "shardnumsub") {
std::vector<ChannelSubscribeNum> channel_subscribe_nums;
srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums);
if (subcommand_ == "numsub") {
srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums);
} else {
srv->ListSChannelSubscribeNum(channels_, &channel_subscribe_nums);
}

output->append(redis::MultiLen(channel_subscribe_nums.size() * 2));
for (const auto &chan_subscribe_num : channel_subscribe_nums) {
Expand All @@ -182,9 +224,13 @@ class CommandPubSub : public Commander {
return Status::OK();
}

if (subcommand_ == "channels") {
if (subcommand_ == "channels" || subcommand_ == "shardchannels") {
std::vector<std::string> channels;
srv->GetChannelsByPattern(pattern_, &channels);
if (subcommand_ == "channels") {
srv->GetChannelsByPattern(pattern_, &channels);
} else {
srv->GetSChannelsByPattern(pattern_, &channels);
}
*output = redis::MultiBulkString(channels);
return Status::OK();
}
Expand All @@ -205,6 +251,8 @@ REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandUnSubscribe>("unsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandPSubscribe>("psubscribe", -2, "read-only pub-sub no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandPUnSubscribe>("punsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandSSubscribe>("ssubscribe", -2, "read-only pub-sub no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandSUnSubscribe>("sunsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandPubSub>("pubsub", -2, "read-only pub-sub no-script", 0, 0, 0), )

} // namespace redis
5 changes: 3 additions & 2 deletions src/commands/cmd_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1160,7 +1160,7 @@ class CommandAnalyze : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
if (args.size() <= 1) return {Status::RedisExecErr, errInvalidSyntax};
for (int i = 1; i < args.size(); ++i) {
for (unsigned int i = 1; i < args.size(); ++i) {
command_args_.push_back(args[i]);
}
return Status::OK();
Expand All @@ -1178,7 +1178,8 @@ class CommandAnalyze : public Commander {
cmd->SetArgs(command_args_);

int arity = cmd->GetAttributes()->arity;
if ((arity > 0 && command_args_.size() != arity) || (arity < 0 && command_args_.size() < -arity)) {
if ((arity > 0 && static_cast<int>(command_args_.size()) != arity) ||
(arity < 0 && static_cast<int>(command_args_.size()) < -arity)) {
*output = redis::Error("ERR wrong number of arguments");
return {Status::RedisExecErr, errWrongNumOfArguments};
}
Expand Down
39 changes: 39 additions & 0 deletions src/server/redis_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,45 @@ void Connection::PUnsubscribeAll(const UnsubscribeCallback &reply) {

int Connection::PSubscriptionsCount() { return static_cast<int>(subscribe_patterns_.size()); }

void Connection::SSubscribeChannel(const std::string &channel, uint16_t slot) {
for (const auto &chan : subscribe_shard_channels_) {
if (channel == chan) return;
}

subscribe_shard_channels_.emplace_back(channel);
owner_->srv->SSubscribeChannel(channel, this, slot);
}

void Connection::SUnsubscribeChannel(const std::string &channel, uint16_t slot) {
for (auto iter = subscribe_shard_channels_.begin(); iter != subscribe_shard_channels_.end(); iter++) {
if (*iter == channel) {
subscribe_shard_channels_.erase(iter);
owner_->srv->SUnsubscribeChannel(channel, this, slot);
return;
}
}
}

void Connection::SUnsubscribeAll(const UnsubscribeCallback &reply) {
if (subscribe_shard_channels_.empty()) {
if (reply) reply("", 0);
return;
}

int removed = 0;
for (const auto &chan : subscribe_shard_channels_) {
owner_->srv->SUnsubscribeChannel(chan, this,
owner_->srv->GetConfig()->cluster_enabled ? GetSlotIdFromKey(chan) : 0);
removed++;
if (reply) {
reply(chan, static_cast<int>(subscribe_shard_channels_.size() - removed));
}
}
subscribe_shard_channels_.clear();
}

int Connection::SSubscriptionsCount() { return static_cast<int>(subscribe_shard_channels_.size()); }

bool Connection::IsProfilingEnabled(const std::string &cmd) {
auto config = srv_->GetConfig();
if (config->profiling_sample_ratio == 0) return false;
Expand Down
5 changes: 5 additions & 0 deletions src/server/redis_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class Connection : public EvbufCallbackBase<Connection> {
void PUnsubscribeChannel(const std::string &pattern);
void PUnsubscribeAll(const UnsubscribeCallback &reply = nullptr);
int PSubscriptionsCount();
void SSubscribeChannel(const std::string &channel, uint16_t slot);
void SUnsubscribeChannel(const std::string &channel, uint16_t slot);
void SUnsubscribeAll(const UnsubscribeCallback &reply = nullptr);
int SSubscriptionsCount();

uint64_t GetAge() const;
uint64_t GetIdleTime() const;
Expand Down Expand Up @@ -159,6 +163,7 @@ class Connection : public EvbufCallbackBase<Connection> {

std::vector<std::string> subscribe_channels_;
std::vector<std::string> subscribe_patterns_;
std::vector<std::string> subscribe_shard_channels_;

Server *srv_;
bool in_exec_ = false;
Expand Down
61 changes: 61 additions & 0 deletions src/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ Server::Server(engine::Storage *storage, Config *config)
// Init cluster
cluster = std::make_unique<Cluster>(this, config_->binds, config_->port);

// init shard pub/sub channels
pubsub_shard_channels_.resize(config->cluster_enabled ? HASH_SLOTS_SIZE : 1);

for (int i = 0; i < config->workers; i++) {
auto worker = std::make_unique<Worker>(this, config);
// multiple workers can't listen to the same unix socket, so
Expand Down Expand Up @@ -497,6 +500,64 @@ void Server::PUnsubscribeChannel(const std::string &pattern, redis::Connection *
}
}

void Server::SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) {
assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0);
std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);

auto conn_ctx = ConnContext(conn->Owner(), conn->GetFD());
if (auto iter = pubsub_shard_channels_[slot].find(channel); iter == pubsub_shard_channels_[slot].end()) {
pubsub_shard_channels_[slot].emplace(channel, std::list<ConnContext>{conn_ctx});
} else {
iter->second.emplace_back(conn_ctx);
}
}

void Server::SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) {
assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0);
std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);

auto iter = pubsub_shard_channels_[slot].find(channel);
if (iter == pubsub_shard_channels_[slot].end()) {
return;
}

for (const auto &conn_ctx : iter->second) {
if (conn->GetFD() == conn_ctx.fd && conn->Owner() == conn_ctx.owner) {
iter->second.remove(conn_ctx);
if (iter->second.empty()) {
pubsub_shard_channels_[slot].erase(iter);
}
break;
}
}
}

void Server::GetSChannelsByPattern(const std::string &pattern, std::vector<std::string> *channels) {
std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);

for (const auto &shard_channels : pubsub_shard_channels_) {
for (const auto &iter : shard_channels) {
if (pattern.empty() || util::StringMatch(pattern, iter.first, 0)) {
channels->emplace_back(iter.first);
}
}
}
}

void Server::ListSChannelSubscribeNum(const std::vector<std::string> &channels,
std::vector<ChannelSubscribeNum> *channel_subscribe_nums) {
std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);

for (const auto &chan : channels) {
uint16_t slot = config_->cluster_enabled ? GetSlotIdFromKey(chan) : 0;
if (auto iter = pubsub_shard_channels_[slot].find(chan); iter != pubsub_shard_channels_[slot].end()) {
channel_subscribe_nums->emplace_back(ChannelSubscribeNum{iter->first, iter->second.size()});
} else {
channel_subscribe_nums->emplace_back(ChannelSubscribeNum{chan, 0});
}
}
}

void Server::BlockOnKey(const std::string &key, redis::Connection *conn) {
std::lock_guard<std::mutex> guard(blocking_keys_mu_);

Expand Down
7 changes: 7 additions & 0 deletions src/server/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ class Server {
void PSubscribeChannel(const std::string &pattern, redis::Connection *conn);
void PUnsubscribeChannel(const std::string &pattern, redis::Connection *conn);
size_t GetPubSubPatternSize() const { return pubsub_patterns_.size(); }
void SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot);
void SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot);
void GetSChannelsByPattern(const std::string &pattern, std::vector<std::string> *channels);
void ListSChannelSubscribeNum(const std::vector<std::string> &channels,
std::vector<ChannelSubscribeNum> *channel_subscribe_nums);

void BlockOnKey(const std::string &key, redis::Connection *conn);
void UnblockOnKey(const std::string &key, redis::Connection *conn);
Expand Down Expand Up @@ -351,6 +356,8 @@ class Server {
std::map<std::string, std::list<ConnContext>> pubsub_channels_;
std::map<std::string, std::list<ConnContext>> pubsub_patterns_;
std::mutex pubsub_channels_mu_;
std::vector<std::map<std::string, std::list<ConnContext>>> pubsub_shard_channels_;
std::mutex pubsub_shard_channels_mu_;
std::map<std::string, std::list<ConnContext>> blocking_keys_;
std::mutex blocking_keys_mu_;

Expand Down
Loading