diff --git a/rpc/client_test.go b/rpc/client_test.go index 5d301a07a7..224eb0c5c8 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -18,6 +18,7 @@ package rpc import ( "context" + "encoding/json" "fmt" "math/rand" "net" @@ -376,6 +377,93 @@ func TestClientCloseUnsubscribeRace(t *testing.T) { } } +// unsubscribeRecorder collects the subscription IDs of *_unsubscribe calls. +type unsubscribeRecorder struct { + ServerCodec + unsubscribes map[string]bool +} + +func (r *unsubscribeRecorder) readBatch() ([]*jsonrpcMessage, bool, error) { + if r.unsubscribes == nil { + r.unsubscribes = make(map[string]bool) + } + + msgs, batch, err := r.ServerCodec.readBatch() + for _, msg := range msgs { + if msg.isUnsubscribe() { + var params []string + if err := json.Unmarshal(msg.Params, ¶ms); err != nil { + panic("unsubscribe decode error: " + err.Error()) + } + r.unsubscribes[params[0]] = true + } + } + return msgs, batch, err +} + +// This checks that Client calls the _unsubscribe method on the server when Unsubscribe is +// called on a subscription. +func TestClientSubscriptionUnsubscribeServer(t *testing.T) { + t.Parallel() + + // Create the server. + srv := NewServer() + srv.RegisterName("nftest", new(notificationTestService)) + p1, p2 := net.Pipe() + recorder := &unsubscribeRecorder{ServerCodec: NewCodec(p1)} + go srv.ServeCodec(recorder, OptionMethodInvocation|OptionSubscriptions) + defer srv.Stop() + + // Create the client on the other end of the pipe. + client, _ := newClient(context.Background(), func(context.Context) (ServerCodec, error) { + return NewCodec(p2), nil + }) + defer client.Close() + + // Create the subscription. + ch := make(chan int) + sub, err := client.Subscribe(context.Background(), "nftest", ch, "someSubscription", 1, 1) + if err != nil { + t.Fatal(err) + } + + // Unsubscribe and check that unsubscribe was called. + sub.Unsubscribe() + if !recorder.unsubscribes[sub.subid] { + t.Fatal("client did not call unsubscribe method") + } + if _, open := <-sub.Err(); open { + t.Fatal("subscription error channel not closed after unsubscribe") + } +} + +// This checks that the subscribed channel can be closed after Unsubscribe. +// It is the reproducer for https://github.com/ethereum/go-ethereum/issues/22322 +func TestClientSubscriptionChannelClose(t *testing.T) { + t.Parallel() + + var ( + srv = NewServer() + httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) + wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + ) + defer srv.Stop() + defer httpsrv.Close() + + srv.RegisterName("nftest", new(notificationTestService)) + client, _ := Dial(wsURL) + + for i := 0; i < 100; i++ { + ch := make(chan int, 100) + sub, err := client.Subscribe(context.Background(), "nftest", ch, "someSubscription", maxClientSubscriptionBuffer-1, 1) + if err != nil { + t.Fatal(err) + } + sub.Unsubscribe() + close(ch) + } +} + // This test checks that Client doesn't lock up when a single subscriber // doesn't read subscription events. func TestClientNotificationStorm(t *testing.T) { diff --git a/rpc/handler.go b/rpc/handler.go index 23023eaca1..85f774a1b7 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -189,7 +189,7 @@ func (h *handler) cancelAllRequests(err error, inflightReq *requestOp) { } for id, sub := range h.clientSubs { delete(h.clientSubs, id) - sub.quitWithError(false, err) + sub.close(err) } } @@ -281,7 +281,7 @@ func (h *handler) handleResponse(msg *jsonrpcMessage) { return } if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil { - go op.sub.start() + go op.sub.run() h.clientSubs[op.sub.subid] = op.sub } } diff --git a/rpc/subscription.go b/rpc/subscription.go index 233215d792..942e764e5d 100644 --- a/rpc/subscription.go +++ b/rpc/subscription.go @@ -208,23 +208,37 @@ type ClientSubscription struct { channel reflect.Value namespace string subid string - in chan json.RawMessage - quitOnce sync.Once // ensures quit is closed once - quit chan struct{} // quit is closed when the subscription exits - errOnce sync.Once // ensures err is closed once - err chan error + // The in channel receives notification values from client dispatcher. + in chan json.RawMessage + + // The error channel receives the error from the forwarding loop. + // It is closed by Unsubscribe. + err chan error + errOnce sync.Once + + // Closing of the subscription is requested by sending on 'quit'. This is handled by + // the forwarding loop, which closes 'forwardDone' when it has stopped sending to + // sub.channel. Finally, 'unsubDone' is closed after unsubscribing on the server side. + quit chan error + forwardDone chan struct{} + unsubDone chan struct{} } +// This is the sentinel value sent on sub.quit when Unsubscribe is called. +var errUnsubscribed = errors.New("unsubscribed") + func newClientSubscription(c *Client, namespace string, channel reflect.Value) *ClientSubscription { sub := &ClientSubscription{ - client: c, - namespace: namespace, - etype: channel.Type().Elem(), - channel: channel, - quit: make(chan struct{}), - err: make(chan error, 1), - in: make(chan json.RawMessage), + client: c, + namespace: namespace, + etype: channel.Type().Elem(), + channel: channel, + in: make(chan json.RawMessage), + quit: make(chan error), + forwardDone: make(chan struct{}), + unsubDone: make(chan struct{}), + err: make(chan error, 1), } return sub } @@ -232,9 +246,9 @@ func newClientSubscription(c *Client, namespace string, channel reflect.Value) * // Err returns the subscription error channel. The intended use of Err is to schedule // resubscription when the client connection is closed unexpectedly. // -// The error channel receives a value when the subscription has ended due -// to an error. The received error is nil if Close has been called -// on the underlying client and no other error has occurred. +// The error channel receives a value when the subscription has ended due to an error. The +// received error is nil if Close has been called on the underlying client and no other +// error has occurred. // // The error channel is closed when Unsubscribe is called on the subscription. func (sub *ClientSubscription) Err() <-chan error { @@ -244,41 +258,63 @@ func (sub *ClientSubscription) Err() <-chan error { // Unsubscribe unsubscribes the notification and closes the error channel. // It can safely be called more than once. func (sub *ClientSubscription) Unsubscribe() { - sub.quitWithError(true, nil) - sub.errOnce.Do(func() { close(sub.err) }) -} - -func (sub *ClientSubscription) quitWithError(unsubscribeServer bool, err error) { - sub.quitOnce.Do(func() { - // The dispatch loop won't be able to execute the unsubscribe call - // if it is blocked on deliver. Close sub.quit first because it - // unblocks deliver. - close(sub.quit) - if unsubscribeServer { - sub.requestUnsubscribe() - } - if err != nil { - if err == ErrClientQuit { - err = nil // Adhere to subscription semantics. - } - sub.err <- err + sub.errOnce.Do(func() { + select { + case sub.quit <- errUnsubscribed: + <-sub.unsubDone + case <-sub.unsubDone: } + close(sub.err) }) } +// deliver is called by the client's message dispatcher to send a notification value. func (sub *ClientSubscription) deliver(result json.RawMessage) (ok bool) { select { case sub.in <- result: return true - case <-sub.quit: + case <-sub.forwardDone: return false } } -func (sub *ClientSubscription) start() { - sub.quitWithError(sub.forward()) +// close is called by the client's message dispatcher when the connection is closed. +func (sub *ClientSubscription) close(err error) { + select { + case sub.quit <- err: + case <-sub.forwardDone: + } +} + +// run is the forwarding loop of the subscription. It runs in its own goroutine and +// is launched by the client's handler after the subscription has been created. +func (sub *ClientSubscription) run() { + defer close(sub.unsubDone) + + unsubscribe, err := sub.forward() + + // The client's dispatch loop won't be able to execute the unsubscribe call if it is + // blocked in sub.deliver() or sub.close(). Closing forwardDone unblocks them. + close(sub.forwardDone) + + // Call the unsubscribe method on the server. + if unsubscribe { + sub.requestUnsubscribe() + } + + // Send the error. + if err != nil { + if err == ErrClientQuit { + // ErrClientQuit gets here when Client.Close is called. This is reported as a + // nil error because it's not an error, but we can't close sub.err here. + err = nil + } + sub.err <- err + } } +// forward is the forwarding loop. It takes in RPC notifications and sends them +// on the subscription channel. func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { cases := []reflect.SelectCase{ {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.quit)}, @@ -286,7 +322,7 @@ func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { {Dir: reflect.SelectSend, Chan: sub.channel}, } buffer := list.New() - defer buffer.Init() + for { var chosen int var recv reflect.Value @@ -301,7 +337,15 @@ func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { switch chosen { case 0: // <-sub.quit - return false, nil + if !recv.IsNil() { + err = recv.Interface().(error) + } + if err == errUnsubscribed { + // Exiting because Unsubscribe was called, unsubscribe on server. + return true, nil + } + return false, err + case 1: // <-sub.in val, err := sub.unmarshal(recv.Interface().(json.RawMessage)) if err != nil { @@ -311,6 +355,7 @@ func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { return true, ErrSubscriptionQueueOverflow } buffer.PushBack(val) + case 2: // sub.channel<- cases[2].Send = reflect.Value{} // Don't hold onto the value. buffer.Remove(buffer.Front()) diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 37ed19476f..4976853baf 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -129,11 +129,15 @@ func TestClientWebsocketPing(t *testing.T) { if err != nil { t.Fatalf("client dial error: %v", err) } + defer client.Close() + resultChan := make(chan int) sub, err := client.EthSubscribe(ctx, resultChan, "foo") if err != nil { t.Fatalf("client subscribe error: %v", err) } + // Note: Unsubscribe is not called on this subscription because the mockup + // server can't handle the request. // Wait for the context's deadline to be reached before proceeding. // This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798