// Copyright 2023 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . package sync import ( "reflect" "testing" "github.com/ethereum/go-ethereum/beacon/light" "github.com/ethereum/go-ethereum/beacon/light/request" "github.com/ethereum/go-ethereum/beacon/types" ) type requestWithID struct { sid request.ServerAndID request request.Request } type TestScheduler struct { t *testing.T module request.Module events []request.Event servers []request.Server allowance map[request.Server]int sent map[int][]requestWithID testIndex int expFail map[request.Server]int // expected Server.Fail calls during next Run lastId request.ID } func NewTestScheduler(t *testing.T, module request.Module) *TestScheduler { return &TestScheduler{ t: t, module: module, allowance: make(map[request.Server]int), expFail: make(map[request.Server]int), sent: make(map[int][]requestWithID), } } func (ts *TestScheduler) Run(testIndex int, exp ...any) { expReqs := make([]requestWithID, len(exp)/2) id := ts.lastId for i := range expReqs { id++ expReqs[i] = requestWithID{ sid: request.ServerAndID{Server: exp[i*2].(request.Server), ID: id}, request: exp[i*2+1].(request.Request), } } if len(expReqs) == 0 { expReqs = nil } ts.testIndex = testIndex ts.module.Process(ts, ts.events) ts.events = nil for server, count := range ts.expFail { delete(ts.expFail, server) if count == 0 { continue } ts.t.Errorf("Missing %d Server.Fail(s) from server %s in test case #%d", count, server.Name(), testIndex) } if !reflect.DeepEqual(ts.sent[testIndex], expReqs) { ts.t.Errorf("Wrong sent requests in test case #%d (expected %v, got %v)", testIndex, expReqs, ts.sent[testIndex]) } } func (ts *TestScheduler) CanSendTo() (cs []request.Server) { for _, server := range ts.servers { if ts.allowance[server] > 0 { cs = append(cs, server) } } return } func (ts *TestScheduler) Send(server request.Server, req request.Request) request.ID { ts.lastId++ ts.sent[ts.testIndex] = append(ts.sent[ts.testIndex], requestWithID{ sid: request.ServerAndID{Server: server, ID: ts.lastId}, request: req, }) ts.allowance[server]-- return ts.lastId } func (ts *TestScheduler) Fail(server request.Server, desc string) { if ts.expFail[server] == 0 { ts.t.Errorf("Unexpected Fail from server %s in test case #%d: %s", server.Name(), ts.testIndex, desc) return } ts.expFail[server]-- } func (ts *TestScheduler) Request(testIndex, reqIndex int) requestWithID { if len(ts.sent[testIndex]) < reqIndex { ts.t.Errorf("Missing request from test case %d index %d", testIndex, reqIndex) return requestWithID{} } return ts.sent[testIndex][reqIndex-1] } func (ts *TestScheduler) ServerEvent(evType *request.EventType, server request.Server, data any) { ts.events = append(ts.events, request.Event{ Type: evType, Server: server, Data: data, }) } func (ts *TestScheduler) RequestEvent(evType *request.EventType, req requestWithID, resp request.Response) { if req.request == nil { return } ts.events = append(ts.events, request.Event{ Type: evType, Server: req.sid.Server, Data: request.RequestResponse{ ID: req.sid.ID, Request: req.request, Response: resp, }, }) } func (ts *TestScheduler) AddServer(server request.Server, allowance int) { ts.servers = append(ts.servers, server) ts.allowance[server] = allowance ts.ServerEvent(request.EvRegistered, server, nil) } func (ts *TestScheduler) RemoveServer(server request.Server) { ts.servers = append(ts.servers, server) for i, s := range ts.servers { if s == server { copy(ts.servers[i:len(ts.servers)-1], ts.servers[i+1:]) ts.servers = ts.servers[:len(ts.servers)-1] break } } delete(ts.allowance, server) ts.ServerEvent(request.EvUnregistered, server, nil) } func (ts *TestScheduler) AddAllowance(server request.Server, allowance int) { ts.allowance[server] += allowance } func (ts *TestScheduler) ExpFail(server request.Server) { ts.expFail[server]++ } type TestCommitteeChain struct { fsp, nsp uint64 init bool } func (t *TestCommitteeChain) CheckpointInit(bootstrap types.BootstrapData) error { t.fsp, t.nsp, t.init = bootstrap.Header.SyncPeriod(), bootstrap.Header.SyncPeriod()+2, true return nil } func (t *TestCommitteeChain) InsertUpdate(update *types.LightClientUpdate, nextCommittee *types.SerializedSyncCommittee) error { period := update.AttestedHeader.Header.SyncPeriod() if period < t.fsp || period > t.nsp || !t.init { return light.ErrInvalidPeriod } if period == t.nsp { t.nsp++ } return nil } func (t *TestCommitteeChain) NextSyncPeriod() (uint64, bool) { return t.nsp, t.init } func (tc *TestCommitteeChain) ExpInit(t *testing.T, ExpInit bool) { if tc.init != ExpInit { t.Errorf("Incorrect init flag (expected %v, got %v)", ExpInit, tc.init) } } func (t *TestCommitteeChain) SetNextSyncPeriod(nsp uint64) { t.init, t.nsp = true, nsp } func (tc *TestCommitteeChain) ExpNextSyncPeriod(t *testing.T, expNsp uint64) { tc.ExpInit(t, true) if tc.nsp != expNsp { t.Errorf("Incorrect NextSyncPeriod (expected %d, got %d)", expNsp, tc.nsp) } } type TestHeadTracker struct { phead types.HeadInfo validated []types.OptimisticUpdate finality types.FinalityUpdate } func (ht *TestHeadTracker) ValidateOptimistic(update types.OptimisticUpdate) (bool, error) { ht.validated = append(ht.validated, update) return true, nil } func (ht *TestHeadTracker) ValidateFinality(update types.FinalityUpdate) (bool, error) { ht.finality = update return true, nil } func (ht *TestHeadTracker) ValidatedFinality() (types.FinalityUpdate, bool) { return ht.finality, ht.finality.Attested.Header != (types.Header{}) } func (ht *TestHeadTracker) ExpValidated(t *testing.T, tci int, expHeads []types.OptimisticUpdate) { for i, expHead := range expHeads { if i >= len(ht.validated) { t.Errorf("Missing validated head in test case #%d index #%d (expected {slot %d blockRoot %x}, got none)", tci, i, expHead.Attested.Header.Slot, expHead.Attested.Header.Hash()) continue } if !reflect.DeepEqual(ht.validated[i], expHead) { vhead := ht.validated[i].Attested.Header t.Errorf("Wrong validated head in test case #%d index #%d (expected {slot %d blockRoot %x}, got {slot %d blockRoot %x})", tci, i, expHead.Attested.Header.Slot, expHead.Attested.Header.Hash(), vhead.Slot, vhead.Hash()) } } for i := len(expHeads); i < len(ht.validated); i++ { vhead := ht.validated[i].Attested.Header t.Errorf("Unexpected validated head in test case #%d index #%d (expected none, got {slot %d blockRoot %x})", tci, i, vhead.Slot, vhead.Hash()) } ht.validated = nil } func (ht *TestHeadTracker) SetPrefetchHead(head types.HeadInfo) { ht.phead = head } func (ht *TestHeadTracker) ExpPrefetch(t *testing.T, tci int, exp types.HeadInfo) { if ht.phead != exp { t.Errorf("Wrong prefetch head in test case #%d (expected {slot %d blockRoot %x}, got {slot %d blockRoot %x})", tci, exp.Slot, exp.BlockRoot, ht.phead.Slot, ht.phead.BlockRoot) } }