Skip to content

Commit

Permalink
Add support of new command: ssubscribe and sunsubscribe (#2003)
Browse files Browse the repository at this point in the history
  • Loading branch information
raffertyyu authored Jan 12, 2024
1 parent ee41959 commit 505aeb6
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 8 deletions.
60 changes: 54 additions & 6 deletions src/commands/cmd_pubsub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,44 @@ class CommandPUnSubscribe : public Commander {
}
};

class CommandSSubscribe : public Commander {
public:
Status Execute(Server *srv, Connection *conn, std::string *output) override {
uint16_t slot = 0;
if (srv->GetConfig()->cluster_enabled) {
slot = GetSlotIdFromKey(args_[1]);
for (unsigned int i = 2; i < args_.size(); i++) {
if (GetSlotIdFromKey(args_[i]) != slot) {
return {Status::RedisExecErr, "CROSSSLOT Keys in request don't hash to the same slot"};
}
}
}

for (unsigned int i = 1; i < args_.size(); i++) {
conn->SSubscribeChannel(args_[i], slot);
SubscribeCommandReply(output, "ssubscribe", args_[i], conn->SSubscriptionsCount());
}
return Status::OK();
}
};

class CommandSUnSubscribe : public Commander {
public:
Status Execute(Server *srv, Connection *conn, std::string *output) override {
if (args_.size() == 1) {
conn->SUnsubscribeAll([output](const std::string &sub_name, int num) {
SubscribeCommandReply(output, "sunsubscribe", sub_name, num);
});
} else {
for (size_t i = 1; i < args_.size(); i++) {
conn->SUnsubscribeChannel(args_[i], srv->GetConfig()->cluster_enabled ? GetSlotIdFromKey(args_[i]) : 0);
SubscribeCommandReply(output, "sunsubscribe", args_[i], conn->SSubscriptionsCount());
}
}
return Status::OK();
}
};

class CommandPubSub : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
Expand All @@ -146,14 +184,14 @@ class CommandPubSub : public Commander {
return Status::OK();
}

if ((subcommand_ == "numsub") && args.size() >= 2) {
if ((subcommand_ == "numsub" || subcommand_ == "shardnumsub") && args.size() >= 2) {
if (args.size() > 2) {
channels_ = std::vector<std::string>(args.begin() + 2, args.end());
}
return Status::OK();
}

if ((subcommand_ == "channels") && args.size() <= 3) {
if ((subcommand_ == "channels" || subcommand_ == "shardchannels") && args.size() <= 3) {
if (args.size() == 3) {
pattern_ = args[2];
}
Expand All @@ -169,9 +207,13 @@ class CommandPubSub : public Commander {
return Status::OK();
}

if (subcommand_ == "numsub") {
if (subcommand_ == "numsub" || subcommand_ == "shardnumsub") {
std::vector<ChannelSubscribeNum> channel_subscribe_nums;
srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums);
if (subcommand_ == "numsub") {
srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums);
} else {
srv->ListSChannelSubscribeNum(channels_, &channel_subscribe_nums);
}

output->append(redis::MultiLen(channel_subscribe_nums.size() * 2));
for (const auto &chan_subscribe_num : channel_subscribe_nums) {
Expand All @@ -182,9 +224,13 @@ class CommandPubSub : public Commander {
return Status::OK();
}

if (subcommand_ == "channels") {
if (subcommand_ == "channels" || subcommand_ == "shardchannels") {
std::vector<std::string> channels;
srv->GetChannelsByPattern(pattern_, &channels);
if (subcommand_ == "channels") {
srv->GetChannelsByPattern(pattern_, &channels);
} else {
srv->GetSChannelsByPattern(pattern_, &channels);
}
*output = redis::MultiBulkString(channels);
return Status::OK();
}
Expand All @@ -205,6 +251,8 @@ REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandUnSubscribe>("unsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandPSubscribe>("psubscribe", -2, "read-only pub-sub no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandPUnSubscribe>("punsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandSSubscribe>("ssubscribe", -2, "read-only pub-sub no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandSUnSubscribe>("sunsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandPubSub>("pubsub", -2, "read-only pub-sub no-script", 0, 0, 0), )

} // namespace redis
5 changes: 3 additions & 2 deletions src/commands/cmd_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1160,7 +1160,7 @@ class CommandAnalyze : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
if (args.size() <= 1) return {Status::RedisExecErr, errInvalidSyntax};
for (int i = 1; i < args.size(); ++i) {
for (unsigned int i = 1; i < args.size(); ++i) {
command_args_.push_back(args[i]);
}
return Status::OK();
Expand All @@ -1178,7 +1178,8 @@ class CommandAnalyze : public Commander {
cmd->SetArgs(command_args_);

int arity = cmd->GetAttributes()->arity;
if ((arity > 0 && command_args_.size() != arity) || (arity < 0 && command_args_.size() < -arity)) {
if ((arity > 0 && static_cast<int>(command_args_.size()) != arity) ||
(arity < 0 && static_cast<int>(command_args_.size()) < -arity)) {
*output = redis::Error("ERR wrong number of arguments");
return {Status::RedisExecErr, errWrongNumOfArguments};
}
Expand Down
39 changes: 39 additions & 0 deletions src/server/redis_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,45 @@ void Connection::PUnsubscribeAll(const UnsubscribeCallback &reply) {

int Connection::PSubscriptionsCount() { return static_cast<int>(subscribe_patterns_.size()); }

void Connection::SSubscribeChannel(const std::string &channel, uint16_t slot) {
for (const auto &chan : subscribe_shard_channels_) {
if (channel == chan) return;
}

subscribe_shard_channels_.emplace_back(channel);
owner_->srv->SSubscribeChannel(channel, this, slot);
}

void Connection::SUnsubscribeChannel(const std::string &channel, uint16_t slot) {
for (auto iter = subscribe_shard_channels_.begin(); iter != subscribe_shard_channels_.end(); iter++) {
if (*iter == channel) {
subscribe_shard_channels_.erase(iter);
owner_->srv->SUnsubscribeChannel(channel, this, slot);
return;
}
}
}

void Connection::SUnsubscribeAll(const UnsubscribeCallback &reply) {
if (subscribe_shard_channels_.empty()) {
if (reply) reply("", 0);
return;
}

int removed = 0;
for (const auto &chan : subscribe_shard_channels_) {
owner_->srv->SUnsubscribeChannel(chan, this,
owner_->srv->GetConfig()->cluster_enabled ? GetSlotIdFromKey(chan) : 0);
removed++;
if (reply) {
reply(chan, static_cast<int>(subscribe_shard_channels_.size() - removed));
}
}
subscribe_shard_channels_.clear();
}

int Connection::SSubscriptionsCount() { return static_cast<int>(subscribe_shard_channels_.size()); }

bool Connection::IsProfilingEnabled(const std::string &cmd) {
auto config = srv_->GetConfig();
if (config->profiling_sample_ratio == 0) return false;
Expand Down
5 changes: 5 additions & 0 deletions src/server/redis_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class Connection : public EvbufCallbackBase<Connection> {
void PUnsubscribeChannel(const std::string &pattern);
void PUnsubscribeAll(const UnsubscribeCallback &reply = nullptr);
int PSubscriptionsCount();
void SSubscribeChannel(const std::string &channel, uint16_t slot);
void SUnsubscribeChannel(const std::string &channel, uint16_t slot);
void SUnsubscribeAll(const UnsubscribeCallback &reply = nullptr);
int SSubscriptionsCount();

uint64_t GetAge() const;
uint64_t GetIdleTime() const;
Expand Down Expand Up @@ -159,6 +163,7 @@ class Connection : public EvbufCallbackBase<Connection> {

std::vector<std::string> subscribe_channels_;
std::vector<std::string> subscribe_patterns_;
std::vector<std::string> subscribe_shard_channels_;

Server *srv_;
bool in_exec_ = false;
Expand Down
61 changes: 61 additions & 0 deletions src/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ Server::Server(engine::Storage *storage, Config *config)
// Init cluster
cluster = std::make_unique<Cluster>(this, config_->binds, config_->port);

// init shard pub/sub channels
pubsub_shard_channels_.resize(config->cluster_enabled ? HASH_SLOTS_SIZE : 1);

for (int i = 0; i < config->workers; i++) {
auto worker = std::make_unique<Worker>(this, config);
// multiple workers can't listen to the same unix socket, so
Expand Down Expand Up @@ -497,6 +500,64 @@ void Server::PUnsubscribeChannel(const std::string &pattern, redis::Connection *
}
}

void Server::SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) {
assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0);
std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);

auto conn_ctx = ConnContext(conn->Owner(), conn->GetFD());
if (auto iter = pubsub_shard_channels_[slot].find(channel); iter == pubsub_shard_channels_[slot].end()) {
pubsub_shard_channels_[slot].emplace(channel, std::list<ConnContext>{conn_ctx});
} else {
iter->second.emplace_back(conn_ctx);
}
}

void Server::SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) {
assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0);
std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);

auto iter = pubsub_shard_channels_[slot].find(channel);
if (iter == pubsub_shard_channels_[slot].end()) {
return;
}

for (const auto &conn_ctx : iter->second) {
if (conn->GetFD() == conn_ctx.fd && conn->Owner() == conn_ctx.owner) {
iter->second.remove(conn_ctx);
if (iter->second.empty()) {
pubsub_shard_channels_[slot].erase(iter);
}
break;
}
}
}

void Server::GetSChannelsByPattern(const std::string &pattern, std::vector<std::string> *channels) {
std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);

for (const auto &shard_channels : pubsub_shard_channels_) {
for (const auto &iter : shard_channels) {
if (pattern.empty() || util::StringMatch(pattern, iter.first, 0)) {
channels->emplace_back(iter.first);
}
}
}
}

void Server::ListSChannelSubscribeNum(const std::vector<std::string> &channels,
std::vector<ChannelSubscribeNum> *channel_subscribe_nums) {
std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);

for (const auto &chan : channels) {
uint16_t slot = config_->cluster_enabled ? GetSlotIdFromKey(chan) : 0;
if (auto iter = pubsub_shard_channels_[slot].find(chan); iter != pubsub_shard_channels_[slot].end()) {
channel_subscribe_nums->emplace_back(ChannelSubscribeNum{iter->first, iter->second.size()});
} else {
channel_subscribe_nums->emplace_back(ChannelSubscribeNum{chan, 0});
}
}
}

void Server::BlockOnKey(const std::string &key, redis::Connection *conn) {
std::lock_guard<std::mutex> guard(blocking_keys_mu_);

Expand Down
7 changes: 7 additions & 0 deletions src/server/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ class Server {
void PSubscribeChannel(const std::string &pattern, redis::Connection *conn);
void PUnsubscribeChannel(const std::string &pattern, redis::Connection *conn);
size_t GetPubSubPatternSize() const { return pubsub_patterns_.size(); }
void SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot);
void SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot);
void GetSChannelsByPattern(const std::string &pattern, std::vector<std::string> *channels);
void ListSChannelSubscribeNum(const std::vector<std::string> &channels,
std::vector<ChannelSubscribeNum> *channel_subscribe_nums);

void BlockOnKey(const std::string &key, redis::Connection *conn);
void UnblockOnKey(const std::string &key, redis::Connection *conn);
Expand Down Expand Up @@ -351,6 +356,8 @@ class Server {
std::map<std::string, std::list<ConnContext>> pubsub_channels_;
std::map<std::string, std::list<ConnContext>> pubsub_patterns_;
std::mutex pubsub_channels_mu_;
std::vector<std::map<std::string, std::list<ConnContext>>> pubsub_shard_channels_;
std::mutex pubsub_shard_channels_mu_;
std::map<std::string, std::list<ConnContext>> blocking_keys_;
std::mutex blocking_keys_mu_;

Expand Down
Loading

0 comments on commit 505aeb6

Please sign in to comment.