diff --git a/core/tracing/journal_test.go b/core/tracing/journal_test.go index 4ae26df7e6..129209b2a3 100644 --- a/core/tracing/journal_test.go +++ b/core/tracing/journal_test.go @@ -19,6 +19,7 @@ package tracing import ( "errors" "math/big" + "reflect" "testing" "github.com/ethereum/go-ethereum/common" @@ -140,3 +141,97 @@ func TestJournalNestedCalls(t *testing.T) { t.Fatalf("unexpected balance: %v", tr.bal) } } + +func TestAllHooksCalled(t *testing.T) { + tracer := newTracerAllHooks() + hooks := tracer.hooks() + + wrapped, err := WrapWithJournal(hooks) + if err != nil { + t.Fatalf("failed to wrap hooks with journal: %v", err) + } + + // Get the underlying value of the wrapped hooks + wrappedValue := reflect.ValueOf(wrapped).Elem() + wrappedType := wrappedValue.Type() + + // Iterate over all fields of the wrapped hooks + for i := 0; i < wrappedType.NumField(); i++ { + field := wrappedType.Field(i) + + // Skip fields that are not function types + if field.Type.Kind() != reflect.Func { + continue + } + // Skip non-hooks, i.e. Copy + if field.Name == "Copy" { + continue + } + + // Get the method + method := wrappedValue.Field(i) + + // Call the method with zero values + params := make([]reflect.Value, method.Type().NumIn()) + for j := 0; j < method.Type().NumIn(); j++ { + params[j] = reflect.Zero(method.Type().In(j)) + } + method.Call(params) + } + + // Check if all hooks were called + if tracer.numCalled() != tracer.hooksCount() { + t.Errorf("Not all hooks were called. Expected %d, got %d", tracer.hooksCount(), tracer.numCalled()) + } + + for hookName, called := range tracer.hooksCalled { + if !called { + t.Errorf("Hook %s was not called", hookName) + } + } +} + +type tracerAllHooks struct { + hooksCalled map[string]bool +} + +func newTracerAllHooks() *tracerAllHooks { + t := &tracerAllHooks{hooksCalled: make(map[string]bool)} + // Initialize all hooks to false. We will use this to + // get total count of hooks. + hooksType := reflect.TypeOf((*Hooks)(nil)).Elem() + for i := 0; i < hooksType.NumField(); i++ { + t.hooksCalled[hooksType.Field(i).Name] = false + } + return t +} + +func (t *tracerAllHooks) hooksCount() int { + return len(t.hooksCalled) +} + +func (t *tracerAllHooks) numCalled() int { + count := 0 + for _, called := range t.hooksCalled { + if called { + count++ + } + } + return count +} + +func (t *tracerAllHooks) hooks() *Hooks { + h := &Hooks{} + // Create a function for each hook that sets the + // corresponding hooksCalled field to true. + hooksValue := reflect.ValueOf(h).Elem() + for i := 0; i < hooksValue.NumField(); i++ { + field := hooksValue.Type().Field(i) + hookMethod := reflect.MakeFunc(field.Type, func(args []reflect.Value) []reflect.Value { + t.hooksCalled[field.Name] = true + return nil + }) + hooksValue.Field(i).Set(hookMethod) + } + return h +}