Skip to content

Commit

Permalink
fix(dot/network): close notifications streams (#2093)
Browse files Browse the repository at this point in the history
close notifications streams for reading/writing when outbound/inbound
respectively

Closes #2046
  • Loading branch information
kishansagathiya authored Dec 14, 2021
1 parent 8bd05d1 commit de6e7c9
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions dot/network/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,13 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode
}

// createNotificationsMessageHandler returns a function that is called by the handler of *inbound* streams.
func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
messageHandler NotificationsMessageHandler,
batchHandler NotificationsMessageBatchHandler) messageHandler {
func (s *Service) createNotificationsMessageHandler(
info *notificationsProtocol,
notificationsMessageHandler NotificationsMessageHandler,
batchHandler NotificationsMessageBatchHandler,
) messageHandler {
return func(stream libp2pnetwork.Stream, m Message) error {
if m == nil || info == nil || info.handshakeValidator == nil || messageHandler == nil {
if m == nil || info == nil || info.handshakeValidator == nil || notificationsMessageHandler == nil {
return nil
}

Expand Down Expand Up @@ -214,6 +216,10 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
}

logger.Tracef("receiver: sent handshake to peer %s using protocol %s", peer, info.protocolID)

if err := stream.CloseWrite(); err != nil {
logger.Tracef("failed to close stream for writing: %s", err)
}
}

return nil
Expand All @@ -227,7 +233,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
return nil
}

propagate, err := messageHandler(peer, msg)
propagate, err := notificationsMessageHandler(peer, msg)
if err != nil {
return err
}
Expand Down Expand Up @@ -380,6 +386,10 @@ func (s *Service) sendHandshake(peer peer.ID, hs Handshake, info *notificationsP
hsData.received = true
}

if err := stream.CloseRead(); err != nil {
logger.Tracef("failed to close stream for reading: %s", err)
}

if err = info.handshakeValidator(peer, resp); err != nil {
logger.Tracef("failed to validate handshake from peer %s using protocol %s: %s", peer, info.protocolID, err)
hsData.validated = false
Expand Down

0 comments on commit de6e7c9

Please sign in to comment.