Reviewed revision of the ethstats connWrapper patch. Changes relative to the
patch as submitted:
  * connWrapper guards reads and writes with separate mutexes instead of one
    shared mutex.
  * The dial loop in Service.loop stores the result of every attempt, so a
    failed URL no longer masks a later successful connection.
Leading whitespace was reconstructed following gofmt conventions; the blob
hashes on the "index" line predate these changes.

diff --git a/ethstats/ethstats.go b/ethstats/ethstats.go
index f100af4d11..34c9a39f3b 100644
--- a/ethstats/ethstats.go
+++ b/ethstats/ethstats.go
@@ -28,6 +28,7 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/ethereum/go-ethereum/common"
@@ -93,6 +94,56 @@ type Service struct {
 
 	pongCh chan struct{} // Pong notifications are fed into this channel
 	histCh chan []uint64 // History request block numbers are fed into this channel
+
+}
+
+// connWrapper is a wrapper to prevent concurrent-write or concurrent-read on the
+// websocket.
+// From Gorilla websocket docs:
+//   Connections support one concurrent reader and one concurrent writer.
+//   Applications are responsible for ensuring that no more than one goroutine calls the write methods
+//     - NextWriter, SetWriteDeadline, WriteMessage, WriteJSON, EnableWriteCompression, SetCompressionLevel
+//   concurrently and that no more than one goroutine calls the read methods
+//     - NextReader, SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler
+//   concurrently.
+//   The Close and WriteControl methods can be called concurrently with all other methods.
+//
+// The connWrapper uses separate mutexes for reading and writing: a read blocks
+// until a message arrives, so a lock shared with the writers would stall every
+// outgoing message while a read is pending.
+type connWrapper struct {
+	conn *websocket.Conn
+
+	rlock sync.Mutex // serializes ReadJSON calls
+	wlock sync.Mutex // serializes WriteJSON calls
+}
+
+// newConnectionWrapper wraps conn in a connWrapper.
+func newConnectionWrapper(conn *websocket.Conn) *connWrapper {
+	return &connWrapper{conn: conn}
+}
+
+// WriteJSON wraps corresponding method on the websocket but is safe for concurrent calling
+func (w *connWrapper) WriteJSON(v interface{}) error {
+	w.wlock.Lock()
+	defer w.wlock.Unlock()
+
+	return w.conn.WriteJSON(v)
+}
+
+// ReadJSON wraps corresponding method on the websocket but is safe for concurrent calling
+func (w *connWrapper) ReadJSON(v interface{}) error {
+	w.rlock.Lock()
+	defer w.rlock.Unlock()
+
+	return w.conn.ReadJSON(v)
+}
+
+// Close wraps corresponding method on the websocket but is safe for concurrent calling
+func (w *connWrapper) Close() error {
+	// The Close and WriteControl methods can be called concurrently with all other methods,
+	// so no mutex is taken here
+	return w.conn.Close()
 }
 
 // New returns a monitoring service ready for stats reporting.
@@ -204,17 +255,21 @@ func (s *Service) loop() {
 		case <-errTimer.C:
 			// Establish a websocket connection to the server on any supported URL
 			var (
-				conn *websocket.Conn
+				conn *connWrapper
 				err  error
 			)
 			dialer := websocket.Dialer{HandshakeTimeout: 5 * time.Second}
 			header := make(http.Header)
 			header.Set("origin", "http://localhost")
 			for _, url := range urls {
-				conn, _, err = dialer.Dial(url, header)
+				c, _, e := dialer.Dial(url, header)
+				// Always store this attempt's result: a failure on an earlier URL must
+				// not outlive a later successful dial.
+				err = e
 				if err == nil {
+					conn = newConnectionWrapper(c)
 					break
 				}
 			}
 			if err != nil {
 				log.Warn("Stats server unreachable", "err", err)
@@ -282,7 +337,7 @@ func (s *Service) loop() {
 // from the network socket. If any of them match an active request, it forwards
 // it, if they themselves are requests it initiates a reply, and lastly it drops
 // unknown packets.
-func (s *Service) readLoop(conn *websocket.Conn) {
+func (s *Service) readLoop(conn *connWrapper) {
 	// If the read loop exists, close the connection
 	defer conn.Close()
 
@@ -391,7 +446,7 @@ type authMsg struct {
 }
 
 // login tries to authorize the client at the remote server.
-func (s *Service) login(conn *websocket.Conn) error {
+func (s *Service) login(conn *connWrapper) error {
 	// Construct and send the login authentication
 	infos := s.server.NodeInfo()
 
@@ -436,7 +491,7 @@ func (s *Service) login(conn *websocket.Conn) error {
 // report collects all possible data to report and send it to the stats server.
 // This should only be used on reconnects or rarely to avoid overloading the
 // server. Use the individual methods for reporting subscribed events.
-func (s *Service) report(conn *websocket.Conn) error {
+func (s *Service) report(conn *connWrapper) error {
 	if err := s.reportLatency(conn); err != nil {
 		return err
 	}
@@ -454,7 +509,7 @@ func (s *Service) report(conn *websocket.Conn) error {
 
 // reportLatency sends a ping request to the server, measures the RTT time and
 // finally sends a latency update.
-func (s *Service) reportLatency(conn *websocket.Conn) error {
+func (s *Service) reportLatency(conn *connWrapper) error {
 	// Send the current time to the ethstats server
 	start := time.Now()
 
@@ -523,7 +578,7 @@ func (s uncleStats) MarshalJSON() ([]byte, error) {
 }
 
 // reportBlock retrieves the current chain head and reports it to the stats server.
-func (s *Service) reportBlock(conn *websocket.Conn, block *types.Block) error {
+func (s *Service) reportBlock(conn *connWrapper, block *types.Block) error {
 	// Gather the block details from the header or block chain
 	details := s.assembleBlockStats(block)
 
@@ -598,7 +653,7 @@ func (s *Service) assembleBlockStats(block *types.Block) *blockStats {
 
 // reportHistory retrieves the most recent batch of blocks and reports it to the
 // stats server.
-func (s *Service) reportHistory(conn *websocket.Conn, list []uint64) error {
+func (s *Service) reportHistory(conn *connWrapper, list []uint64) error {
 	// Figure out the indexes that need reporting
 	indexes := make([]uint64, 0, historyUpdateRange)
 	if len(list) > 0 {
@@ -660,7 +715,7 @@ type pendStats struct {
 
 // reportPending retrieves the current number of pending transactions and reports
 // it to the stats server.
-func (s *Service) reportPending(conn *websocket.Conn) error {
+func (s *Service) reportPending(conn *connWrapper) error {
 	// Retrieve the pending count from the local blockchain
 	pending, _ := s.backend.Stats()
 	// Assemble the transaction stats and send it to the server
@@ -691,7 +746,7 @@ type nodeStats struct {
 
 // reportStats retrieves various stats about the node at the networking and
 // mining layer and reports it to the stats server.
-func (s *Service) reportStats(conn *websocket.Conn) error {
+func (s *Service) reportStats(conn *connWrapper) error {
 	// Gather the syncing and mining infos from the local miner instance
 	var (
 		mining   bool