Skip to content

Commit

Permalink
Fix canceling subscription in Watch (#127)
Browse files Browse the repository at this point in the history
* Fix canceling subscription in Watch

* feedback
  • Loading branch information
DanG100 authored Aug 29, 2023
1 parent 09b506c commit f1bc953
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
9 changes: 9 additions & 0 deletions ygnmi/ygnmi.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,19 +313,23 @@ func (w *Watcher[T]) Await() (*Value[T], error) {
// Calling Await on the returned Watcher waits for the subscription to complete.
// It returns the last observed value and a boolean that indicates whether that value satisfies the predicate.
func Watch[T any](ctx context.Context, c *Client, q SingletonQuery[T], pred func(*Value[T]) error, opts ...Option) *Watcher[T] {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
w := &Watcher[T]{
errCh: make(chan error, 1),
}

resolvedOpts := resolveOpts(opts)
sub, err := subscribe[T](ctx, c, q, gpb.SubscriptionList_STREAM, resolvedOpts)
if err != nil {
cancel()
w.errCh <- err
return w
}

dataCh, errCh := receiveStream[T](ctx, sub, q)
go func() {
defer cancel()
// Create an intially empty GoStruct, into which all received datapoints will be unmarshalled.
gs := q.goStruct()
for {
Expand Down Expand Up @@ -465,23 +469,28 @@ func GetAll[T any](ctx context.Context, c *Client, q WildcardQuery[T], opts ...O
// Calling Await on the returned Watcher waits for the subscription to complete.
// It returns the last observed value and a boolean that indicates whether that value satisfies the predicate.
func WatchAll[T any](ctx context.Context, c *Client, q WildcardQuery[T], pred func(*Value[T]) error, opts ...Option) *Watcher[T] {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
w := &Watcher[T]{
errCh: make(chan error, 1),
}
path, err := resolvePath(q.PathStruct())
if err != nil {
cancel()
w.errCh <- err
return w
}
resolvedOpts := resolveOpts(opts)
sub, err := subscribe[T](ctx, c, q, gpb.SubscriptionList_STREAM, resolvedOpts)
if err != nil {
cancel()
w.errCh <- err
return w
}

dataCh, errCh := receiveStream[T](ctx, sub, q)
go func() {
defer cancel()
// Create a map intially empty GoStruct, into which all received datapoints will be unmarshalled based on their path prefixes.
structs := map[string]ygot.ValidatedGoStruct{}
for {
Expand Down
53 changes: 53 additions & 0 deletions ygnmi/ygnmi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package ygnmi_test
import (
"context"
"fmt"
"net"
"strings"
"testing"
"time"
Expand All @@ -32,7 +33,9 @@ import (
"github.com/openconfig/ygnmi/ygnmi"
"github.com/openconfig/ygot/util"
"github.com/openconfig/ygot/ygot"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/local"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
Expand Down Expand Up @@ -4443,3 +4446,53 @@ func verifySubscriptionModesSent(t *testing.T, fakeGNMI *testutil.FakeGNMI, want
t.Errorf("Subscription modes (-want, +got):\n%s", diff)
}
}

type gnmiS struct {
gpb.UnimplementedGNMIServer
errCh chan error
}

func (g *gnmiS) Subscribe(srv gpb.GNMI_SubscribeServer) error {
if _, err := srv.Recv(); err != nil {
return err
}
if err := srv.Send(&gpb.SubscribeResponse{Response: &gpb.SubscribeResponse_SyncResponse{}}); err != nil {
return err
}
// This send must fail because the client will have cancelled the subscription context.
time.Sleep(time.Second)
err := srv.Send(&gpb.SubscribeResponse{Response: &gpb.SubscribeResponse_SyncResponse{}})
g.errCh <- err
return nil
}

func TestWatchCancel(t *testing.T) {
srv := &gnmiS{
errCh: make(chan error, 1),
}
s := grpc.NewServer(grpc.Creds(local.NewCredentials()))
gpb.RegisterGNMIServer(s, srv)
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
go func() {
//nolint:errcheck // Don't care about this error.
s.Serve(l)
}()
conn, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(local.NewCredentials()))
if err != nil {
t.Fatal(err)
}
c, _ := ygnmi.NewClient(gpb.NewGNMIClient(conn))

w := ygnmi.Watch(context.Background(), c, exampleocpath.Root().RemoteContainer().ALeaf().State(), func(v *ygnmi.Value[string]) error {
return nil
})
if _, err := w.Await(); err != nil {
t.Fatal(err)
}
if err := <-srv.errCh; err == nil {
t.Fatalf("Watch() unexpected error: got %v, want context.Cancel", err)
}
}

0 comments on commit f1bc953

Please sign in to comment.