p2p/simulations: encapsulate Node.Up field so we avoid data races

The Node.Up field was accessed concurrently without "proper" locking.
There was a lock on Network and that was used sometimes to access
the  field. Other times the locking was missed and we had
a data race.

For example: https://github.com/ethereum/go-ethereum/pull/18464
The case above was solved, but there were still intermittent/hard to
reproduce races. So let's solve the issue permanently.

resolves: ethersphere/go-ethereum#1146
pull/18976/head
Ferenc Szabo 6 years ago
parent 4b2f34c824
commit 13292ee897
  1. 2
      p2p/simulations/events.go
  2. 18
      p2p/simulations/http_test.go
  3. 9
      p2p/simulations/mocker_test.go
  4. 51
      p2p/simulations/network.go
  5. 4
      swarm/network/simulation/node.go
  6. 18
      swarm/network/simulation/node_test.go
  7. 2
      swarm/network/simulation/service.go
  8. 4
      swarm/network/simulation/simulation_test.go
  9. 2
      swarm/network/simulations/overlay_test.go

@ -100,7 +100,7 @@ func ControlEvent(v interface{}) *Event {
func (e *Event) String() string { func (e *Event) String() string {
switch e.Type { switch e.Type {
case EventTypeNode: case EventTypeNode:
return fmt.Sprintf("<node-event> id: %s up: %t", e.Node.ID().TerminalString(), e.Node.Up) return fmt.Sprintf("<node-event> id: %s up: %t", e.Node.ID().TerminalString(), e.Node.Up())
case EventTypeConn: case EventTypeConn:
return fmt.Sprintf("<conn-event> nodes: %s->%s up: %t", e.Conn.One.TerminalString(), e.Conn.Other.TerminalString(), e.Conn.Up) return fmt.Sprintf("<conn-event> nodes: %s->%s up: %t", e.Conn.One.TerminalString(), e.Conn.Other.TerminalString(), e.Conn.Up)
case EventTypeMsg: case EventTypeMsg:

@ -421,14 +421,15 @@ type expectEvents struct {
} }
func (t *expectEvents) nodeEvent(id string, up bool) *Event { func (t *expectEvents) nodeEvent(id string, up bool) *Event {
node := Node{
Config: &adapters.NodeConfig{
ID: enode.HexID(id),
},
up: up,
}
return &Event{ return &Event{
Type: EventTypeNode, Type: EventTypeNode,
Node: &Node{ Node: &node,
Config: &adapters.NodeConfig{
ID: enode.HexID(id),
},
Up: up,
},
} }
} }
@ -480,6 +481,7 @@ loop:
} }
func (t *expectEvents) expect(events ...*Event) { func (t *expectEvents) expect(events ...*Event) {
t.Helper()
timeout := time.After(10 * time.Second) timeout := time.After(10 * time.Second)
i := 0 i := 0
for { for {
@ -501,8 +503,8 @@ func (t *expectEvents) expect(events ...*Event) {
if event.Node.ID() != expected.Node.ID() { if event.Node.ID() != expected.Node.ID() {
t.Fatalf("expected node event %d to have id %q, got %q", i, expected.Node.ID().TerminalString(), event.Node.ID().TerminalString()) t.Fatalf("expected node event %d to have id %q, got %q", i, expected.Node.ID().TerminalString(), event.Node.ID().TerminalString())
} }
if event.Node.Up != expected.Node.Up { if event.Node.Up() != expected.Node.Up() {
t.Fatalf("expected node event %d to have up=%t, got up=%t", i, expected.Node.Up, event.Node.Up) t.Fatalf("expected node event %d to have up=%t, got up=%t", i, expected.Node.Up(), event.Node.Up())
} }
case EventTypeConn: case EventTypeConn:

@ -90,15 +90,12 @@ func TestMocker(t *testing.T) {
for { for {
select { select {
case event := <-events: case event := <-events:
//if the event is a node Up event only if isNodeUp(event) {
if event.Node != nil && event.Node.Up {
//add the correspondent node ID to the map //add the correspondent node ID to the map
nodemap[event.Node.Config.ID] = true nodemap[event.Node.Config.ID] = true
//this means all nodes got a nodeUp event, so we can continue the test //this means all nodes got a nodeUp event, so we can continue the test
if len(nodemap) == nodeCount { if len(nodemap) == nodeCount {
nodesComplete = true nodesComplete = true
//wait for 3s as the mocker will need time to connect the nodes
//time.Sleep( 3 *time.Second)
} }
} else if event.Conn != nil && nodesComplete { } else if event.Conn != nil && nodesComplete {
connCount += 1 connCount += 1
@ -169,3 +166,7 @@ func TestMocker(t *testing.T) {
t.Fatalf("Expected empty list of nodes, got: %d", len(nodesInfo)) t.Fatalf("Expected empty list of nodes, got: %d", len(nodesInfo))
} }
} }
func isNodeUp(event *Event) bool {
return event.Node != nil && event.Node.Up()
}

@ -136,7 +136,7 @@ func (net *Network) Config() *NetworkConfig {
// StartAll starts all nodes in the network // StartAll starts all nodes in the network
func (net *Network) StartAll() error { func (net *Network) StartAll() error {
for _, node := range net.Nodes { for _, node := range net.Nodes {
if node.Up { if node.Up() {
continue continue
} }
if err := net.Start(node.ID()); err != nil { if err := net.Start(node.ID()); err != nil {
@ -149,7 +149,7 @@ func (net *Network) StartAll() error {
// StopAll stops all nodes in the network // StopAll stops all nodes in the network
func (net *Network) StopAll() error { func (net *Network) StopAll() error {
for _, node := range net.Nodes { for _, node := range net.Nodes {
if !node.Up { if !node.Up() {
continue continue
} }
if err := net.Stop(node.ID()); err != nil { if err := net.Stop(node.ID()); err != nil {
@ -174,7 +174,7 @@ func (net *Network) startWithSnapshots(id enode.ID, snapshots map[string][]byte)
net.lock.Unlock() net.lock.Unlock()
return fmt.Errorf("node %v does not exist", id) return fmt.Errorf("node %v does not exist", id)
} }
if node.Up { if node.Up() {
net.lock.Unlock() net.lock.Unlock()
return fmt.Errorf("node %v already up", id) return fmt.Errorf("node %v already up", id)
} }
@ -184,7 +184,7 @@ func (net *Network) startWithSnapshots(id enode.ID, snapshots map[string][]byte)
log.Warn("Node startup failed", "id", id, "err", err) log.Warn("Node startup failed", "id", id, "err", err)
return err return err
} }
node.Up = true node.SetUp(true)
log.Info("Started node", "id", id) log.Info("Started node", "id", id)
ev := NewEvent(node) ev := NewEvent(node)
net.lock.Unlock() net.lock.Unlock()
@ -219,7 +219,7 @@ func (net *Network) watchPeerEvents(id enode.ID, events chan *p2p.PeerEvent, sub
if node == nil { if node == nil {
return return
} }
node.Up = false node.SetUp(false)
ev := NewEvent(node) ev := NewEvent(node)
net.events.Send(ev) net.events.Send(ev)
}() }()
@ -263,17 +263,17 @@ func (net *Network) Stop(id enode.ID) error {
net.lock.Unlock() net.lock.Unlock()
return fmt.Errorf("node %v does not exist", id) return fmt.Errorf("node %v does not exist", id)
} }
if !node.Up { if !node.Up() {
net.lock.Unlock() net.lock.Unlock()
return fmt.Errorf("node %v already down", id) return fmt.Errorf("node %v already down", id)
} }
node.Up = false node.SetUp(false)
net.lock.Unlock() net.lock.Unlock()
err := node.Stop() err := node.Stop()
if err != nil { if err != nil {
net.lock.Lock() net.lock.Lock()
node.Up = true node.SetUp(true)
net.lock.Unlock() net.lock.Unlock()
return err return err
} }
@ -430,7 +430,7 @@ func (net *Network) GetRandomUpNode(excludeIDs ...enode.ID) *Node {
func (net *Network) getUpNodeIDs() (ids []enode.ID) { func (net *Network) getUpNodeIDs() (ids []enode.ID) {
for _, node := range net.Nodes { for _, node := range net.Nodes {
if node.Up { if node.Up() {
ids = append(ids, node.ID()) ids = append(ids, node.ID())
} }
} }
@ -446,7 +446,7 @@ func (net *Network) GetRandomDownNode(excludeIDs ...enode.ID) *Node {
func (net *Network) getDownNodeIDs() (ids []enode.ID) { func (net *Network) getDownNodeIDs() (ids []enode.ID) {
for _, node := range net.GetNodes() { for _, node := range net.GetNodes() {
if !node.Up { if !node.Up() {
ids = append(ids, node.ID()) ids = append(ids, node.ID())
} }
} }
@ -595,8 +595,21 @@ type Node struct {
// Config if the config used to created the node // Config if the config used to created the node
Config *adapters.NodeConfig `json:"config"` Config *adapters.NodeConfig `json:"config"`
// Up tracks whether or not the node is running // up tracks whether or not the node is running
Up bool `json:"up"` up bool `json:"up"`
upMu sync.RWMutex `json:"-"`
}
func (n *Node) Up() bool {
n.upMu.RLock()
defer n.upMu.RUnlock()
return n.up
}
func (n *Node) SetUp(up bool) {
n.upMu.Lock()
defer n.upMu.Unlock()
n.up = up
} }
// ID returns the ID of the node // ID returns the ID of the node
@ -630,7 +643,7 @@ func (n *Node) MarshalJSON() ([]byte, error) {
}{ }{
Info: n.NodeInfo(), Info: n.NodeInfo(),
Config: n.Config, Config: n.Config,
Up: n.Up, Up: n.Up(),
}) })
} }
@ -653,10 +666,10 @@ type Conn struct {
// nodesUp returns whether both nodes are currently up // nodesUp returns whether both nodes are currently up
func (c *Conn) nodesUp() error { func (c *Conn) nodesUp() error {
if !c.one.Up { if !c.one.Up() {
return fmt.Errorf("one %v is not up", c.One) return fmt.Errorf("one %v is not up", c.One)
} }
if !c.other.Up { if !c.other.Up() {
return fmt.Errorf("other %v is not up", c.Other) return fmt.Errorf("other %v is not up", c.Other)
} }
return nil return nil
@ -728,7 +741,7 @@ func (net *Network) snapshot(addServices []string, removeServices []string) (*Sn
} }
for i, node := range net.Nodes { for i, node := range net.Nodes {
snap.Nodes[i] = NodeSnapshot{Node: *node} snap.Nodes[i] = NodeSnapshot{Node: *node}
if !node.Up { if !node.Up() {
continue continue
} }
snapshots, err := node.Snapshots() snapshots, err := node.Snapshots()
@ -783,7 +796,7 @@ func (net *Network) Load(snap *Snapshot) error {
if _, err := net.NewNodeWithConfig(n.Node.Config); err != nil { if _, err := net.NewNodeWithConfig(n.Node.Config); err != nil {
return err return err
} }
if !n.Node.Up { if !n.Node.Up() {
continue continue
} }
if err := net.startWithSnapshots(n.Node.Config.ID, n.Snapshots); err != nil { if err := net.startWithSnapshots(n.Node.Config.ID, n.Snapshots); err != nil {
@ -855,7 +868,7 @@ func (net *Network) Load(snap *Snapshot) error {
// Start connecting. // Start connecting.
for _, conn := range snap.Conns { for _, conn := range snap.Conns {
if !net.GetNode(conn.One).Up || !net.GetNode(conn.Other).Up { if !net.GetNode(conn.One).Up() || !net.GetNode(conn.Other).Up() {
//in this case, at least one of the nodes of a connection is not up, //in this case, at least one of the nodes of a connection is not up,
//so it would result in the snapshot `Load` to fail //so it would result in the snapshot `Load` to fail
continue continue
@ -909,7 +922,7 @@ func (net *Network) executeControlEvent(event *Event) {
} }
func (net *Network) executeNodeEvent(e *Event) error { func (net *Network) executeNodeEvent(e *Event) error {
if !e.Node.Up { if !e.Node.Up() {
return net.Stop(e.Node.ID()) return net.Stop(e.Node.ID())
} }

@ -44,7 +44,7 @@ func (s *Simulation) NodeIDs() (ids []enode.ID) {
func (s *Simulation) UpNodeIDs() (ids []enode.ID) { func (s *Simulation) UpNodeIDs() (ids []enode.ID) {
nodes := s.Net.GetNodes() nodes := s.Net.GetNodes()
for _, node := range nodes { for _, node := range nodes {
if node.Up { if node.Up() {
ids = append(ids, node.ID()) ids = append(ids, node.ID())
} }
} }
@ -55,7 +55,7 @@ func (s *Simulation) UpNodeIDs() (ids []enode.ID) {
func (s *Simulation) DownNodeIDs() (ids []enode.ID) { func (s *Simulation) DownNodeIDs() (ids []enode.ID) {
nodes := s.Net.GetNodes() nodes := s.Net.GetNodes()
for _, node := range nodes { for _, node := range nodes {
if !node.Up { if !node.Up() {
ids = append(ids, node.ID()) ids = append(ids, node.ID())
} }
} }

@ -54,7 +54,7 @@ func TestUpDownNodeIDs(t *testing.T) {
gotIDs = sim.UpNodeIDs() gotIDs = sim.UpNodeIDs()
for _, id := range gotIDs { for _, id := range gotIDs {
if !sim.Net.GetNode(id).Up { if !sim.Net.GetNode(id).Up() {
t.Errorf("node %s should not be down", id) t.Errorf("node %s should not be down", id)
} }
} }
@ -66,7 +66,7 @@ func TestUpDownNodeIDs(t *testing.T) {
gotIDs = sim.DownNodeIDs() gotIDs = sim.DownNodeIDs()
for _, id := range gotIDs { for _, id := range gotIDs {
if sim.Net.GetNode(id).Up { if sim.Net.GetNode(id).Up() {
t.Errorf("node %s should not be up", id) t.Errorf("node %s should not be up", id)
} }
} }
@ -112,7 +112,7 @@ func TestAddNode(t *testing.T) {
t.Fatal("node not found") t.Fatal("node not found")
} }
if !n.Up { if !n.Up() {
t.Error("node not started") t.Error("node not started")
} }
} }
@ -327,7 +327,7 @@ func TestStartStopNode(t *testing.T) {
if n == nil { if n == nil {
t.Fatal("node not found") t.Fatal("node not found")
} }
if !n.Up { if !n.Up() {
t.Error("node not started") t.Error("node not started")
} }
@ -335,7 +335,7 @@ func TestStartStopNode(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if n.Up { if n.Up() {
t.Error("node not stopped") t.Error("node not stopped")
} }
@ -345,7 +345,7 @@ func TestStartStopNode(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !n.Up { if !n.Up() {
t.Error("node not started") t.Error("node not started")
} }
} }
@ -368,7 +368,7 @@ func TestStartStopRandomNode(t *testing.T) {
if n == nil { if n == nil {
t.Fatal("node not found") t.Fatal("node not found")
} }
if n.Up { if n.Up() {
t.Error("node not stopped") t.Error("node not stopped")
} }
@ -408,7 +408,7 @@ func TestStartStopRandomNodes(t *testing.T) {
if n == nil { if n == nil {
t.Fatal("node not found") t.Fatal("node not found")
} }
if n.Up { if n.Up() {
t.Error("node not stopped") t.Error("node not stopped")
} }
} }
@ -425,7 +425,7 @@ func TestStartStopRandomNodes(t *testing.T) {
if n == nil { if n == nil {
t.Fatal("node not found") t.Fatal("node not found")
} }
if !n.Up { if !n.Up() {
t.Error("node not started") t.Error("node not started")
} }
} }

@ -52,7 +52,7 @@ func (s *Simulation) Services(name string) (services map[enode.ID]node.Service)
nodes := s.Net.GetNodes() nodes := s.Net.GetNodes()
services = make(map[enode.ID]node.Service) services = make(map[enode.ID]node.Service)
for _, node := range nodes { for _, node := range nodes {
if !node.Up { if !node.Up() {
continue continue
} }
simNode, ok := node.Node.(*adapters.SimNode) simNode, ok := node.Node.(*adapters.SimNode)

@ -124,7 +124,7 @@ func TestClose(t *testing.T) {
var upNodeCount int var upNodeCount int
for _, n := range sim.Net.GetNodes() { for _, n := range sim.Net.GetNodes() {
if n.Up { if n.Up() {
upNodeCount++ upNodeCount++
} }
} }
@ -140,7 +140,7 @@ func TestClose(t *testing.T) {
upNodeCount = 0 upNodeCount = 0
for _, n := range sim.Net.GetNodes() { for _, n := range sim.Net.GetNodes() {
if n.Up { if n.Up() {
upNodeCount++ upNodeCount++
} }
} }

@ -178,7 +178,7 @@ func watchSimEvents(net *simulations.Network, ctx context.Context, trigger chan
case ev := <-events: case ev := <-events:
//only catch node up events //only catch node up events
if ev.Type == simulations.EventTypeNode { if ev.Type == simulations.EventTypeNode {
if ev.Node.Up { if ev.Node.Up() {
log.Debug("got node up event", "event", ev, "node", ev.Node.Config.ID) log.Debug("got node up event", "event", ev, "node", ev.Node.Config.ID)
select { select {
case trigger <- ev.Node.Config.ID: case trigger <- ev.Node.Config.ID:

Loading…
Cancel
Save