From bed3b100867daab55a8e3c639ead8da9cda89a3c Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Wed, 12 Oct 2022 10:34:52 +0200 Subject: [PATCH] common/math: optimized modexp (+ fuzzer) (#25525) This adds a * core/vm, tests: optimized modexp + fuzzer * common/math: modexp optimizations * core/vm: special case base 1 in big modexp * core/vm: disable fastexp --- common/math/modexp.go | 82 ++++++++++++++++++++++++++ core/vm/contracts.go | 15 ++++- oss-fuzz.sh | 2 + tests/fuzzers/modexp/modexp-fuzzer.go | 84 +++++++++++++++++++++++++++ 4 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 common/math/modexp.go create mode 100644 tests/fuzzers/modexp/modexp-fuzzer.go diff --git a/common/math/modexp.go b/common/math/modexp.go new file mode 100644 index 0000000000..b0a32e8c27 --- /dev/null +++ b/common/math/modexp.go @@ -0,0 +1,82 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package math + +import ( + "math/big" + "math/bits" + + "github.com/ethereum/go-ethereum/common" +) + +// FastExp is semantically equivalent to x.Exp(x,y, m), but is faster for even +// modulus. +func FastExp(x, y, m *big.Int) *big.Int { + // Split m = m1 × m2 where m1 = 2ⁿ + n := m.TrailingZeroBits() + m1 := new(big.Int).Lsh(common.Big1, n) + mask := new(big.Int).Sub(m1, common.Big1) + m2 := new(big.Int).Rsh(m, n) + + // We want z = x**y mod m. + // z1 = x**y mod m1 = (x**y mod m) mod m1 = z mod m1 + // z2 = x**y mod m2 = (x**y mod m) mod m2 = z mod m2 + z1 := fastExpPow2(x, y, mask) + z2 := new(big.Int).Exp(x, y, m2) + + // Reconstruct z from z1, z2 using CRT, using algorithm from paper, + // which uses only a single modInverse. + // p = (z1 - z2) * m2⁻¹ (mod m1) + // z = z2 + p * m2 + z := new(big.Int).Set(z2) + + // Compute (z1 - z2) mod m1 [m1 == 2**n] into z1. + z1 = z1.And(z1, mask) + z2 = z2.And(z2, mask) + z1 = z1.Sub(z1, z2) + if z1.Sign() < 0 { + z1 = z1.Add(z1, m1) + } + + // Reuse z2 for p = z1 * m2inv. + m2inv := new(big.Int).ModInverse(m2, m1) + z2 = z2.Mul(z1, m2inv) + z2 = z2.And(z2, mask) + + // Reuse z1 for m2 * p. + z = z.Add(z, z1.Mul(z2, m2)) + z = z.Rem(z, m) + + return z +} + +func fastExpPow2(x, y *big.Int, mask *big.Int) *big.Int { + z := big.NewInt(1) + if y.Sign() == 0 { + return z + } + p := new(big.Int).Set(x) + p = p.And(p, mask) + if p.Cmp(z) <= 0 { // p <= 1 + return p + } + if y.Cmp(mask) > 0 { + y = new(big.Int).And(y, mask) + } + t := new(big.Int) + + for _, b := range y.Bits() { + for i := 0; i < bits.UintSize; i++ { + if b&1 != 0 { + z, t = t.Mul(z, p), z + z = z.And(z, mask) + } + p, t = t.Mul(p, p), p + p = p.And(p, mask) + b >>= 1 + } + } + return z +} diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 054c3b66e7..d0e3e69139 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -380,12 +380,23 @@ func (c *bigModExp) Run(input []byte) ([]byte, error) { base = new(big.Int).SetBytes(getData(input, 0, baseLen)) exp = new(big.Int).SetBytes(getData(input, baseLen, expLen)) mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) + v []byte ) - if mod.BitLen() == 0 { + switch { + case mod.BitLen() == 0: // Modulo 0 is undefined, return zero return common.LeftPadBytes([]byte{}, int(modLen)), nil + case base.Cmp(common.Big1) == 0: + //If base == 1, then we can just return base % mod (if mod >= 1, which it is) + v = base.Mod(base, mod).Bytes() + //case mod.Bit(0) == 0: + // // Modulo is even + // v = math.FastExp(base, exp, mod).Bytes() + default: + // Modulo is odd + v = base.Exp(base, exp, mod).Bytes() } - return common.LeftPadBytes(base.Exp(base, exp, mod).Bytes(), int(modLen)), nil + return common.LeftPadBytes(v, int(modLen)), nil } // newCurvePoint unmarshals a binary blob into a bn256 elliptic curve point, diff --git a/oss-fuzz.sh b/oss-fuzz.sh index 745a5ba7c7..7f454ff307 100644 --- a/oss-fuzz.sh +++ b/oss-fuzz.sh @@ -125,5 +125,7 @@ compile_fuzzer tests/fuzzers/snap FuzzSRange fuzz_storage_range compile_fuzzer tests/fuzzers/snap FuzzByteCodes fuzz_byte_codes compile_fuzzer tests/fuzzers/snap FuzzTrieNodes fuzz_trie_nodes +compile_fuzzer tests/fuzzers/modexp Fuzz fuzzModexp + #TODO: move this to tests/fuzzers, if possible compile_fuzzer crypto/blake2b Fuzz fuzzBlake2b diff --git a/tests/fuzzers/modexp/modexp-fuzzer.go b/tests/fuzzers/modexp/modexp-fuzzer.go new file mode 100644 index 0000000000..0068c50302 --- /dev/null +++ b/tests/fuzzers/modexp/modexp-fuzzer.go @@ -0,0 +1,84 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package modexp + +import ( + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/core/vm" +) + +// The function must return +// 1 if the fuzzer should increase priority of the +// given input during subsequent fuzzing (for example, the input is lexically +// correct and was parsed successfully); +// -1 if the input must not be added to corpus even if gives new coverage; and +// 0 otherwise +// other values are reserved for future use. +func Fuzz(input []byte) int { + if len(input) <= 96 { + return -1 + } + // Abort on too expensive inputs + precomp := vm.PrecompiledContractsBerlin[common.BytesToAddress([]byte{5})] + if gas := precomp.RequiredGas(input); gas > 40_000_000 { + return 0 + } + var ( + baseLen = new(big.Int).SetBytes(getData(input, 0, 32)).Uint64() + expLen = new(big.Int).SetBytes(getData(input, 32, 32)).Uint64() + modLen = new(big.Int).SetBytes(getData(input, 64, 32)).Uint64() + ) + // Handle a special case when both the base and mod length is zero + if baseLen == 0 && modLen == 0 { + return -1 + } + input = input[96:] + // Retrieve the operands and execute the exponentiation + var ( + base = new(big.Int).SetBytes(getData(input, 0, baseLen)) + exp = new(big.Int).SetBytes(getData(input, baseLen, expLen)) + mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) + ) + if mod.BitLen() == 0 { + // Modulo 0 is undefined, return zero + return -1 + } + var a = math.FastExp(new(big.Int).Set(base), new(big.Int).Set(exp), new(big.Int).Set(mod)) + var b = base.Exp(base, exp, mod) + if a.Cmp(b) != 0 { + panic(fmt.Sprintf("Inequality %x != %x", a, b)) + } + return 1 +} + +// getData returns a slice from the data based on the start and size and pads +// up to size with zero's. This function is overflow safe. +func getData(data []byte, start uint64, size uint64) []byte { + length := uint64(len(data)) + if start > length { + start = length + } + end := start + size + if end > length { + end = length + } + return common.RightPadBytes(data[start:end], int(size)) +}