|
|
|
@ -19,6 +19,7 @@ package tracing |
|
|
|
|
import ( |
|
|
|
|
"errors" |
|
|
|
|
"math/big" |
|
|
|
|
"reflect" |
|
|
|
|
"testing" |
|
|
|
|
|
|
|
|
|
"github.com/ethereum/go-ethereum/common" |
|
|
|
@ -140,3 +141,97 @@ func TestJournalNestedCalls(t *testing.T) { |
|
|
|
|
t.Fatalf("unexpected balance: %v", tr.bal) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestAllHooksCalled(t *testing.T) { |
|
|
|
|
tracer := newTracerAllHooks() |
|
|
|
|
hooks := tracer.hooks() |
|
|
|
|
|
|
|
|
|
wrapped, err := WrapWithJournal(hooks) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatalf("failed to wrap hooks with journal: %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Get the underlying value of the wrapped hooks
|
|
|
|
|
wrappedValue := reflect.ValueOf(wrapped).Elem() |
|
|
|
|
wrappedType := wrappedValue.Type() |
|
|
|
|
|
|
|
|
|
// Iterate over all fields of the wrapped hooks
|
|
|
|
|
for i := 0; i < wrappedType.NumField(); i++ { |
|
|
|
|
field := wrappedType.Field(i) |
|
|
|
|
|
|
|
|
|
// Skip fields that are not function types
|
|
|
|
|
if field.Type.Kind() != reflect.Func { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
// Skip non-hooks, i.e. Copy
|
|
|
|
|
if field.Name == "Copy" { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Get the method
|
|
|
|
|
method := wrappedValue.Field(i) |
|
|
|
|
|
|
|
|
|
// Call the method with zero values
|
|
|
|
|
params := make([]reflect.Value, method.Type().NumIn()) |
|
|
|
|
for j := 0; j < method.Type().NumIn(); j++ { |
|
|
|
|
params[j] = reflect.Zero(method.Type().In(j)) |
|
|
|
|
} |
|
|
|
|
method.Call(params) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Check if all hooks were called
|
|
|
|
|
if tracer.numCalled() != tracer.hooksCount() { |
|
|
|
|
t.Errorf("Not all hooks were called. Expected %d, got %d", tracer.hooksCount(), tracer.numCalled()) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for hookName, called := range tracer.hooksCalled { |
|
|
|
|
if !called { |
|
|
|
|
t.Errorf("Hook %s was not called", hookName) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type tracerAllHooks struct { |
|
|
|
|
hooksCalled map[string]bool |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func newTracerAllHooks() *tracerAllHooks { |
|
|
|
|
t := &tracerAllHooks{hooksCalled: make(map[string]bool)} |
|
|
|
|
// Initialize all hooks to false. We will use this to
|
|
|
|
|
// get total count of hooks.
|
|
|
|
|
hooksType := reflect.TypeOf((*Hooks)(nil)).Elem() |
|
|
|
|
for i := 0; i < hooksType.NumField(); i++ { |
|
|
|
|
t.hooksCalled[hooksType.Field(i).Name] = false |
|
|
|
|
} |
|
|
|
|
return t |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (t *tracerAllHooks) hooksCount() int { |
|
|
|
|
return len(t.hooksCalled) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (t *tracerAllHooks) numCalled() int { |
|
|
|
|
count := 0 |
|
|
|
|
for _, called := range t.hooksCalled { |
|
|
|
|
if called { |
|
|
|
|
count++ |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return count |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (t *tracerAllHooks) hooks() *Hooks { |
|
|
|
|
h := &Hooks{} |
|
|
|
|
// Create a function for each hook that sets the
|
|
|
|
|
// corresponding hooksCalled field to true.
|
|
|
|
|
hooksValue := reflect.ValueOf(h).Elem() |
|
|
|
|
for i := 0; i < hooksValue.NumField(); i++ { |
|
|
|
|
field := hooksValue.Type().Field(i) |
|
|
|
|
hookMethod := reflect.MakeFunc(field.Type, func(args []reflect.Value) []reflect.Value { |
|
|
|
|
t.hooksCalled[field.Name] = true |
|
|
|
|
return nil |
|
|
|
|
}) |
|
|
|
|
hooksValue.Field(i).Set(hookMethod) |
|
|
|
|
} |
|
|
|
|
return h |
|
|
|
|
} |
|
|
|
|