Generate already lint code from procedural generation (#5060)

pull/5065/head
Hadrien Croubois 8 months ago committed by GitHub
parent a241f09905
commit dd1e8988ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      scripts/generate/run.js
  2. 149
      scripts/generate/templates/Arrays.js
  3. 49
      scripts/generate/templates/Checkpoints.js
  4. 48
      scripts/generate/templates/Checkpoints.t.js
  5. 30
      scripts/generate/templates/EnumerableMap.js
  6. 9
      scripts/generate/templates/EnumerableSet.js
  7. 71
      scripts/generate/templates/SafeCast.js
  8. 74
      scripts/generate/templates/SlotDerivation.js
  9. 118
      scripts/generate/templates/SlotDerivation.t.js
  10. 50
      scripts/generate/templates/StorageSlot.js
  11. 21
      scripts/generate/templates/StorageSlotMock.js

@ -1,6 +1,6 @@
#!/usr/bin/env node #!/usr/bin/env node
const cp = require('child_process'); // const cp = require('child_process');
const fs = require('fs'); const fs = require('fs');
const path = require('path'); const path = require('path');
const format = require('./format-lines'); const format = require('./format-lines');
@ -23,11 +23,11 @@ function generateFromTemplate(file, template, outputPrefix = '') {
...(version ? [version + ` (${file})`] : []), ...(version ? [version + ` (${file})`] : []),
`// This file was procedurally generated from ${input}.`, `// This file was procedurally generated from ${input}.`,
'', '',
require(template), require(template).trimEnd(),
); );
fs.writeFileSync(output, content); fs.writeFileSync(output, content);
cp.execFileSync('prettier', ['--write', output]); // cp.execFileSync('prettier', ['--write', output]);
} }
// Contracts // Contracts

@ -15,39 +15,39 @@ import {Math} from "./math/Math.sol";
`; `;
const sort = type => `\ const sort = type => `\
/** /**
* @dev Sort an array of ${type} (in memory) following the provided comparator function. * @dev Sort an array of ${type} (in memory) following the provided comparator function.
* *
* This function does the sorting "in place", meaning that it overrides the input. The object is returned for * This function does the sorting "in place", meaning that it overrides the input. The object is returned for
* convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
* *
* NOTE: this function's cost is \`O(n · log(n))\` in average and \`O(n²)\` in the worst case, with n the length of the * NOTE: this function's cost is \`O(n · log(n))\` in average and \`O(n²)\` in the worst case, with n the length of the
* array. Using it in view functions that are executed through \`eth_call\` is safe, but one should be very careful * array. Using it in view functions that are executed through \`eth_call\` is safe, but one should be very careful
* when executing this as part of a transaction. If the array being sorted is too large, the sort operation may * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may
* consume more gas than is available in a block, leading to potential DoS. * consume more gas than is available in a block, leading to potential DoS.
*/ */
function sort( function sort(
${type}[] memory array, ${type}[] memory array,
function(${type}, ${type}) pure returns (bool) comp function(${type}, ${type}) pure returns (bool) comp
) internal pure returns (${type}[] memory) { ) internal pure returns (${type}[] memory) {
${ ${
type === 'bytes32' type === 'bytes32'
? '_quickSort(_begin(array), _end(array), comp);' ? '_quickSort(_begin(array), _end(array), comp);'
: 'sort(_castToBytes32Array(array), _castToBytes32Comp(comp));' : 'sort(_castToBytes32Array(array), _castToBytes32Comp(comp));'
}
return array;
} }
return array;
}
/** /**
* @dev Variant of {sort} that sorts an array of ${type} in increasing order. * @dev Variant of {sort} that sorts an array of ${type} in increasing order.
*/ */
function sort(${type}[] memory array) internal pure returns (${type}[] memory) { function sort(${type}[] memory array) internal pure returns (${type}[] memory) {
${type === 'bytes32' ? 'sort(array, _defaultComp);' : 'sort(_castToBytes32Array(array), _defaultComp);'} ${type === 'bytes32' ? 'sort(array, _defaultComp);' : 'sort(_castToBytes32Array(array), _defaultComp);'}
return array; return array;
} }
`; `;
const quickSort = ` const quickSort = `\
/** /**
* @dev Performs a quick sort of a segment of memory. The segment sorted starts at \`begin\` (inclusive), and stops * @dev Performs a quick sort of a segment of memory. The segment sorted starts at \`begin\` (inclusive), and stops
* at end (exclusive). Sorting follows the \`comp\` comparator. * at end (exclusive). Sorting follows the \`comp\` comparator.
@ -123,34 +123,34 @@ function _swap(uint256 ptr1, uint256 ptr2) private pure {
} }
`; `;
const defaultComparator = ` const defaultComparator = `\
/// @dev Comparator for sorting arrays in increasing order. /// @dev Comparator for sorting arrays in increasing order.
function _defaultComp(bytes32 a, bytes32 b) private pure returns (bool) { function _defaultComp(bytes32 a, bytes32 b) private pure returns (bool) {
return a < b; return a < b;
} }
`; `;
const castArray = type => `\ const castArray = type => `\
/// @dev Helper: low level cast ${type} memory array to uint256 memory array /// @dev Helper: low level cast ${type} memory array to uint256 memory array
function _castToBytes32Array(${type}[] memory input) private pure returns (bytes32[] memory output) { function _castToBytes32Array(${type}[] memory input) private pure returns (bytes32[] memory output) {
assembly { assembly {
output := input output := input
}
} }
}
`; `;
const castComparator = type => `\ const castComparator = type => `\
/// @dev Helper: low level cast ${type} comp function to bytes32 comp function /// @dev Helper: low level cast ${type} comp function to bytes32 comp function
function _castToBytes32Comp( function _castToBytes32Comp(
function(${type}, ${type}) pure returns (bool) input function(${type}, ${type}) pure returns (bool) input
) private pure returns (function(bytes32, bytes32) pure returns (bool) output) { ) private pure returns (function(bytes32, bytes32) pure returns (bool) output) {
assembly { assembly {
output := input output := input
}
} }
}
`; `;
const search = ` const search = `\
/** /**
* @dev Searches a sorted \`array\` and returns the first index that contains * @dev Searches a sorted \`array\` and returns the first index that contains
* a value greater or equal to \`element\`. If no such index exists (i.e. all * a value greater or equal to \`element\`. If no such index exists (i.e. all
@ -319,12 +319,12 @@ function upperBoundMemory(uint256[] memory array, uint256 element) internal pure
} }
`; `;
const unsafeAccessStorage = type => ` const unsafeAccessStorage = type => `\
/** /**
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check. * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
* *
* WARNING: Only use if you are certain \`pos\` is lower than the array length. * WARNING: Only use if you are certain \`pos\` is lower than the array length.
*/ */
function unsafeAccess(${type}[] storage arr, uint256 pos) internal pure returns (StorageSlot.${capitalize( function unsafeAccess(${type}[] storage arr, uint256 pos) internal pure returns (StorageSlot.${capitalize(
type, type,
)}Slot storage) { )}Slot storage) {
@ -334,9 +334,10 @@ function unsafeAccess(${type}[] storage arr, uint256 pos) internal pure returns
slot := arr.slot slot := arr.slot
} }
return slot.deriveArray().offset(pos).get${capitalize(type)}Slot(); return slot.deriveArray().offset(pos).get${capitalize(type)}Slot();
}`; }
`;
const unsafeAccessMemory = type => ` const unsafeAccessMemory = type => `\
/** /**
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check. * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
* *
@ -349,7 +350,7 @@ function unsafeMemoryAccess(${type}[] memory arr, uint256 pos) internal pure ret
} }
`; `;
const unsafeSetLength = type => ` const unsafeSetLength = type => `\
/** /**
* @dev Helper to set the length of an dynamic array. Directly writing to \`.length\` is forbidden. * @dev Helper to set the length of an dynamic array. Directly writing to \`.length\` is forbidden.
* *
@ -360,26 +361,32 @@ function unsafeSetLength(${type}[] storage array, uint256 len) internal {
assembly { assembly {
sstore(array.slot, len) sstore(array.slot, len)
} }
}`; }
`;
// GENERATE // GENERATE
module.exports = format( module.exports = format(
header.trimEnd(), header.trimEnd(),
'library Arrays {', 'library Arrays {',
'using SlotDerivation for bytes32;', format(
'using StorageSlot for bytes32;', [].concat(
// sorting, comparator, helpers and internal 'using SlotDerivation for bytes32;',
sort('bytes32'), 'using StorageSlot for bytes32;',
TYPES.filter(type => type !== 'bytes32').map(sort), '',
quickSort, // sorting, comparator, helpers and internal
defaultComparator, sort('bytes32'),
TYPES.filter(type => type !== 'bytes32').map(castArray), TYPES.filter(type => type !== 'bytes32').map(sort),
TYPES.filter(type => type !== 'bytes32').map(castComparator), quickSort,
// lookup defaultComparator,
search, TYPES.filter(type => type !== 'bytes32').map(castArray),
// unsafe (direct) storage and memory access TYPES.filter(type => type !== 'bytes32').map(castComparator),
TYPES.map(unsafeAccessStorage), // lookup
TYPES.map(unsafeAccessMemory), search,
TYPES.map(unsafeSetLength), // unsafe (direct) storage and memory access
TYPES.map(unsafeAccessStorage),
TYPES.map(unsafeAccessMemory),
TYPES.map(unsafeSetLength),
),
).trimEnd(),
'}', '}',
); );

@ -17,10 +17,10 @@ import {Math} from "../math/Math.sol";
`; `;
const errors = `\ const errors = `\
/** /**
* @dev A value was attempted to be inserted on a past checkpoint. * @dev A value was attempted to be inserted on a past checkpoint.
*/ */
error CheckpointUnorderedInsertion(); error CheckpointUnorderedInsertion();
`; `;
const template = opts => `\ const template = opts => `\
@ -37,15 +37,11 @@ struct ${opts.checkpointTypeName} {
* @dev Pushes a (\`key\`, \`value\`) pair into a ${opts.historyTypeName} so that it is stored as the checkpoint. * @dev Pushes a (\`key\`, \`value\`) pair into a ${opts.historyTypeName} so that it is stored as the checkpoint.
* *
* Returns previous value and new value. * Returns previous value and new value.
* *
* IMPORTANT: Never accept \`key\` as a user input, since an arbitrary \`type(${opts.keyTypeName}).max\` key set will disable the * IMPORTANT: Never accept \`key\` as a user input, since an arbitrary \`type(${opts.keyTypeName}).max\` key set will disable the
* library. * library.
*/ */
function push( function push(${opts.historyTypeName} storage self, ${opts.keyTypeName} key, ${opts.valueTypeName} value) internal returns (${opts.valueTypeName}, ${opts.valueTypeName}) {
${opts.historyTypeName} storage self,
${opts.keyTypeName} key,
${opts.valueTypeName} value
) internal returns (${opts.valueTypeName}, ${opts.valueTypeName}) {
return _insert(self.${opts.checkpointFieldName}, key, value); return _insert(self.${opts.checkpointFieldName}, key, value);
} }
@ -108,15 +104,7 @@ function latest(${opts.historyTypeName} storage self) internal view returns (${o
* @dev Returns whether there is a checkpoint in the structure (i.e. it is not empty), and if so the key and value * @dev Returns whether there is a checkpoint in the structure (i.e. it is not empty), and if so the key and value
* in the most recent checkpoint. * in the most recent checkpoint.
*/ */
function latestCheckpoint(${opts.historyTypeName} storage self) function latestCheckpoint(${opts.historyTypeName} storage self) internal view returns (bool exists, ${opts.keyTypeName} ${opts.keyFieldName}, ${opts.valueTypeName} ${opts.valueFieldName}) {
internal
view
returns (
bool exists,
${opts.keyTypeName} ${opts.keyFieldName},
${opts.valueTypeName} ${opts.valueFieldName}
)
{
uint256 pos = self.${opts.checkpointFieldName}.length; uint256 pos = self.${opts.checkpointFieldName}.length;
if (pos == 0) { if (pos == 0) {
return (false, 0, 0); return (false, 0, 0);
@ -144,11 +132,7 @@ function at(${opts.historyTypeName} storage self, uint32 pos) internal view retu
* @dev Pushes a (\`key\`, \`value\`) pair into an ordered list of checkpoints, either by inserting a new checkpoint, * @dev Pushes a (\`key\`, \`value\`) pair into an ordered list of checkpoints, either by inserting a new checkpoint,
* or by updating the last one. * or by updating the last one.
*/ */
function _insert( function _insert(${opts.checkpointTypeName}[] storage self, ${opts.keyTypeName} key, ${opts.valueTypeName} value) private returns (${opts.valueTypeName}, ${opts.valueTypeName}) {
${opts.checkpointTypeName}[] storage self,
${opts.keyTypeName} key,
${opts.valueTypeName} value
) private returns (${opts.valueTypeName}, ${opts.valueTypeName}) {
uint256 pos = self.length; uint256 pos = self.length;
if (pos > 0) { if (pos > 0) {
@ -225,11 +209,10 @@ function _lowerBinaryLookup(
/** /**
* @dev Access an element of the array without performing bounds check. The position is assumed to be within bounds. * @dev Access an element of the array without performing bounds check. The position is assumed to be within bounds.
*/ */
function _unsafeAccess(${opts.checkpointTypeName}[] storage self, uint256 pos) function _unsafeAccess(
private ${opts.checkpointTypeName}[] storage self,
pure uint256 pos
returns (${opts.checkpointTypeName} storage result) ) private pure returns (${opts.checkpointTypeName} storage result) {
{
assembly { assembly {
mstore(0, self.slot) mstore(0, self.slot)
result.slot := add(keccak256(0, 0x20), pos) result.slot := add(keccak256(0, 0x20), pos)
@ -242,7 +225,11 @@ function _unsafeAccess(${opts.checkpointTypeName}[] storage self, uint256 pos)
module.exports = format( module.exports = format(
header.trimEnd(), header.trimEnd(),
'library Checkpoints {', 'library Checkpoints {',
errors, format(
OPTS.flatMap(opts => template(opts)), [].concat(
errors,
OPTS.map(opts => template(opts)),
),
).trimEnd(),
'}', '}',
); );

@ -22,18 +22,13 @@ uint8 internal constant _KEY_MAX_GAP = 64;
Checkpoints.${opts.historyTypeName} internal _ckpts; Checkpoints.${opts.historyTypeName} internal _ckpts;
// helpers // helpers
function _bound${capitalize(opts.keyTypeName)}( function _bound${capitalize(opts.keyTypeName)}(${opts.keyTypeName} x, ${opts.keyTypeName} min, ${
${opts.keyTypeName} x, opts.keyTypeName
${opts.keyTypeName} min, } max) internal pure returns (${opts.keyTypeName}) {
${opts.keyTypeName} max
) internal pure returns (${opts.keyTypeName}) {
return SafeCast.to${capitalize(opts.keyTypeName)}(bound(uint256(x), uint256(min), uint256(max))); return SafeCast.to${capitalize(opts.keyTypeName)}(bound(uint256(x), uint256(min), uint256(max)));
} }
function _prepareKeys( function _prepareKeys(${opts.keyTypeName}[] memory keys, ${opts.keyTypeName} maxSpread) internal pure {
${opts.keyTypeName}[] memory keys,
${opts.keyTypeName} maxSpread
) internal pure {
${opts.keyTypeName} lastKey = 0; ${opts.keyTypeName} lastKey = 0;
for (uint256 i = 0; i < keys.length; ++i) { for (uint256 i = 0; i < keys.length; ++i) {
${opts.keyTypeName} key = _bound${capitalize(opts.keyTypeName)}(keys[i], lastKey, lastKey + maxSpread); ${opts.keyTypeName} key = _bound${capitalize(opts.keyTypeName)}(keys[i], lastKey, lastKey + maxSpread);
@ -42,11 +37,7 @@ function _prepareKeys(
} }
} }
function _assertLatestCheckpoint( function _assertLatestCheckpoint(bool exist, ${opts.keyTypeName} key, ${opts.valueTypeName} value) internal {
bool exist,
${opts.keyTypeName} key,
${opts.valueTypeName} value
) internal {
(bool _exist, ${opts.keyTypeName} _key, ${opts.valueTypeName} _value) = _ckpts.latestCheckpoint(); (bool _exist, ${opts.keyTypeName} _key, ${opts.valueTypeName} _value) = _ckpts.latestCheckpoint();
assertEq(_exist, exist); assertEq(_exist, exist);
assertEq(_key, key); assertEq(_key, key);
@ -54,11 +45,9 @@ function _assertLatestCheckpoint(
} }
// tests // tests
function testPush( function testPush(${opts.keyTypeName}[] memory keys, ${opts.valueTypeName}[] memory values, ${
${opts.keyTypeName}[] memory keys, opts.keyTypeName
${opts.valueTypeName}[] memory values, } pastKey) public {
${opts.keyTypeName} pastKey
) public {
vm.assume(values.length > 0 && values.length <= keys.length); vm.assume(values.length > 0 && values.length <= keys.length);
_prepareKeys(keys, _KEY_MAX_GAP); _prepareKeys(keys, _KEY_MAX_GAP);
@ -71,7 +60,7 @@ function testPush(
for (uint256 i = 0; i < keys.length; ++i) { for (uint256 i = 0; i < keys.length; ++i) {
${opts.keyTypeName} key = keys[i]; ${opts.keyTypeName} key = keys[i];
${opts.valueTypeName} value = values[i % values.length]; ${opts.valueTypeName} value = values[i % values.length];
if (i > 0 && key == keys[i-1]) ++duplicates; if (i > 0 && key == keys[i - 1]) ++duplicates;
// push // push
_ckpts.push(key, value); _ckpts.push(key, value);
@ -95,14 +84,12 @@ function testPush(
// used to test reverts // used to test reverts
function push(${opts.keyTypeName} key, ${opts.valueTypeName} value) external { function push(${opts.keyTypeName} key, ${opts.valueTypeName} value) external {
_ckpts.push(key, value); _ckpts.push(key, value);
} }
function testLookup( function testLookup(${opts.keyTypeName}[] memory keys, ${opts.valueTypeName}[] memory values, ${
${opts.keyTypeName}[] memory keys, opts.keyTypeName
${opts.valueTypeName}[] memory values, } lookup) public {
${opts.keyTypeName} lookup
) public {
vm.assume(values.length > 0 && values.length <= keys.length); vm.assume(values.length > 0 && values.length <= keys.length);
_prepareKeys(keys, _KEY_MAX_GAP); _prepareKeys(keys, _KEY_MAX_GAP);
@ -124,7 +111,7 @@ function testLookup(
upper = value; upper = value;
} }
// find the first key that is not smaller than the lookup key // find the first key that is not smaller than the lookup key
if (key >= lookup && (i == 0 || keys[i-1] < lookup)) { if (key >= lookup && (i == 0 || keys[i - 1] < lookup)) {
lowerKey = key; lowerKey = key;
} }
if (key == lowerKey) { if (key == lowerKey) {
@ -142,5 +129,10 @@ function testLookup(
// GENERATE // GENERATE
module.exports = format( module.exports = format(
header, header,
...OPTS.flatMap(opts => [`contract Checkpoints${opts.historyTypeName}Test is Test {`, [template(opts)], '}']), ...OPTS.flatMap(opts => [
`contract Checkpoints${opts.historyTypeName}Test is Test {`,
[template(opts).trimEnd()],
'}',
'',
]),
); );

@ -54,7 +54,7 @@ import {EnumerableSet} from "./EnumerableSet.sol";
`; `;
/* eslint-enable max-len */ /* eslint-enable max-len */
const defaultMap = () => `\ const defaultMap = `\
// To implement this library for multiple types with as little code repetition as possible, we write it in // To implement this library for multiple types with as little code repetition as possible, we write it in
// terms of a generic Map type with bytes32 keys and values. The Map implementation uses private functions, // terms of a generic Map type with bytes32 keys and values. The Map implementation uses private functions,
// and user-facing implementations such as \`UintToAddressMap\` are just wrappers around the underlying Map. // and user-facing implementations such as \`UintToAddressMap\` are just wrappers around the underlying Map.
@ -78,11 +78,7 @@ struct Bytes32ToBytes32Map {
* Returns true if the key was added to the map, that is if it was not * Returns true if the key was added to the map, that is if it was not
* already present. * already present.
*/ */
function set( function set(Bytes32ToBytes32Map storage map, bytes32 key, bytes32 value) internal returns (bool) {
Bytes32ToBytes32Map storage map,
bytes32 key,
bytes32 value
) internal returns (bool) {
map._values[key] = value; map._values[key] = value;
return map._keys.add(key); return map._keys.add(key);
} }
@ -148,7 +144,7 @@ function tryGet(Bytes32ToBytes32Map storage map, bytes32 key) internal view retu
*/ */
function get(Bytes32ToBytes32Map storage map, bytes32 key) internal view returns (bytes32) { function get(Bytes32ToBytes32Map storage map, bytes32 key) internal view returns (bytes32) {
bytes32 value = map._values[key]; bytes32 value = map._values[key];
if(value == 0 && !contains(map, key)) { if (value == 0 && !contains(map, key)) {
revert EnumerableMapNonexistentKey(key); revert EnumerableMapNonexistentKey(key);
} }
return value; return value;
@ -181,11 +177,7 @@ struct ${name} {
* Returns true if the key was added to the map, that is if it was not * Returns true if the key was added to the map, that is if it was not
* already present. * already present.
*/ */
function set( function set(${name} storage map, ${keyType} key, ${valueType} value) internal returns (bool) {
${name} storage map,
${keyType} key,
${valueType} value
) internal returns (bool) {
return set(map._inner, ${toBytes32(keyType, 'key')}, ${toBytes32(valueType, 'value')}); return set(map._inner, ${toBytes32(keyType, 'key')}, ${toBytes32(valueType, 'value')});
} }
@ -271,11 +263,13 @@ function keys(${name} storage map) internal view returns (${keyType}[] memory) {
module.exports = format( module.exports = format(
header.trimEnd(), header.trimEnd(),
'library EnumerableMap {', 'library EnumerableMap {',
[ format(
'using EnumerableSet for EnumerableSet.Bytes32Set;', [].concat(
'', 'using EnumerableSet for EnumerableSet.Bytes32Set;',
defaultMap(), '',
TYPES.map(details => customMap(details).trimEnd()).join('\n\n'), defaultMap,
], TYPES.map(details => customMap(details)),
),
).trimEnd(),
'}', '}',
); );

@ -43,7 +43,7 @@ pragma solidity ^0.8.20;
`; `;
/* eslint-enable max-len */ /* eslint-enable max-len */
const defaultSet = () => `\ const defaultSet = `\
// To implement this library for multiple types with as little code // To implement this library for multiple types with as little code
// repetition as possible, we write it in terms of a generic Set type with // repetition as possible, we write it in terms of a generic Set type with
// bytes32 values. // bytes32 values.
@ -240,6 +240,11 @@ function values(${name} storage set) internal view returns (${type}[] memory) {
module.exports = format( module.exports = format(
header.trimEnd(), header.trimEnd(),
'library EnumerableSet {', 'library EnumerableSet {',
[defaultSet(), TYPES.map(details => customSet(details).trimEnd()).join('\n\n')], format(
[].concat(
defaultSet,
TYPES.map(details => customSet(details)),
),
).trimEnd(),
'}', '}',
); );

@ -21,25 +21,25 @@ pragma solidity ^0.8.20;
`; `;
const errors = `\ const errors = `\
/** /**
* @dev Value doesn't fit in an uint of \`bits\` size. * @dev Value doesn't fit in an uint of \`bits\` size.
*/ */
error SafeCastOverflowedUintDowncast(uint8 bits, uint256 value); error SafeCastOverflowedUintDowncast(uint8 bits, uint256 value);
/** /**
* @dev An int value doesn't fit in an uint of \`bits\` size. * @dev An int value doesn't fit in an uint of \`bits\` size.
*/ */
error SafeCastOverflowedIntToUint(int256 value); error SafeCastOverflowedIntToUint(int256 value);
/** /**
* @dev Value doesn't fit in an int of \`bits\` size. * @dev Value doesn't fit in an int of \`bits\` size.
*/ */
error SafeCastOverflowedIntDowncast(uint8 bits, int256 value); error SafeCastOverflowedIntDowncast(uint8 bits, int256 value);
/** /**
* @dev An uint value doesn't fit in an int of \`bits\` size. * @dev An uint value doesn't fit in an int of \`bits\` size.
*/ */
error SafeCastOverflowedUintToInt(uint256 value); error SafeCastOverflowedUintToInt(uint256 value);
`; `;
const toUintDownCast = length => `\ const toUintDownCast = length => `\
@ -55,7 +55,7 @@ const toUintDownCast = length => `\
*/ */
function toUint${length}(uint256 value) internal pure returns (uint${length}) { function toUint${length}(uint256 value) internal pure returns (uint${length}) {
if (value > type(uint${length}).max) { if (value > type(uint${length}).max) {
revert SafeCastOverflowedUintDowncast(${length}, value); revert SafeCastOverflowedUintDowncast(${length}, value);
} }
return uint${length}(value); return uint${length}(value);
} }
@ -77,7 +77,7 @@ const toIntDownCast = length => `\
function toInt${length}(int256 value) internal pure returns (int${length} downcasted) { function toInt${length}(int256 value) internal pure returns (int${length} downcasted) {
downcasted = int${length}(value); downcasted = int${length}(value);
if (downcasted != value) { if (downcasted != value) {
revert SafeCastOverflowedIntDowncast(${length}, value); revert SafeCastOverflowedIntDowncast(${length}, value);
} }
} }
`; `;
@ -94,7 +94,7 @@ const toInt = length => `\
function toInt${length}(uint${length} value) internal pure returns (int${length}) { function toInt${length}(uint${length} value) internal pure returns (int${length}) {
// Note: Unsafe cast below is okay because \`type(int${length}).max\` is guaranteed to be positive // Note: Unsafe cast below is okay because \`type(int${length}).max\` is guaranteed to be positive
if (value > uint${length}(type(int${length}).max)) { if (value > uint${length}(type(int${length}).max)) {
revert SafeCastOverflowedUintToInt(value); revert SafeCastOverflowedUintToInt(value);
} }
return int${length}(value); return int${length}(value);
} }
@ -110,29 +110,30 @@ const toUint = length => `\
*/ */
function toUint${length}(int${length} value) internal pure returns (uint${length}) { function toUint${length}(int${length} value) internal pure returns (uint${length}) {
if (value < 0) { if (value < 0) {
revert SafeCastOverflowedIntToUint(value); revert SafeCastOverflowedIntToUint(value);
} }
return uint${length}(value); return uint${length}(value);
} }
`; `;
const boolToUint = ` const boolToUint = `\
/** /**
* @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump. * @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump.
*/ */
function toUint(bool b) internal pure returns (uint256 u) { function toUint(bool b) internal pure returns (uint256 u) {
/// @solidity memory-safe-assembly /// @solidity memory-safe-assembly
assembly { assembly {
u := iszero(iszero(b)) u := iszero(iszero(b))
} }
} }
`; `;
// GENERATE // GENERATE
module.exports = format( module.exports = format(
header.trimEnd(), header.trimEnd(),
'library SafeCast {', 'library SafeCast {',
errors, format(
[...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256), boolToUint], [].concat(errors, LENGTHS.map(toUintDownCast), toUint(256), LENGTHS.map(toIntDownCast), toInt(256), boolToUint),
).trimEnd(),
'}', '}',
); );

@ -10,7 +10,7 @@ pragma solidity ^0.8.20;
* the solidity language / compiler. * the solidity language / compiler.
* *
* See https://docs.soliditylang.org/en/v0.8.20/internals/layout_in_storage.html#mappings-and-dynamic-arrays[Solidity docs for mappings and dynamic arrays.]. * See https://docs.soliditylang.org/en/v0.8.20/internals/layout_in_storage.html#mappings-and-dynamic-arrays[Solidity docs for mappings and dynamic arrays.].
* *
* Example usage: * Example usage:
* \`\`\`solidity * \`\`\`solidity
* contract Example { * contract Example {
@ -30,9 +30,9 @@ pragma solidity ^0.8.20;
* } * }
* } * }
* \`\`\` * \`\`\`
* *
* TIP: Consider using this library along with {StorageSlot}. * TIP: Consider using this library along with {StorageSlot}.
* *
* NOTE: This library provides a way to manipulate storage locations in a non-standard way. Tooling for checking * NOTE: This library provides a way to manipulate storage locations in a non-standard way. Tooling for checking
* upgrade safety will ignore the slots accessed through this library. * upgrade safety will ignore the slots accessed through this library.
*/ */
@ -43,11 +43,11 @@ const namespace = `\
* @dev Derive an ERC-7201 slot from a string (namespace). * @dev Derive an ERC-7201 slot from a string (namespace).
*/ */
function erc7201Slot(string memory namespace) internal pure returns (bytes32 slot) { function erc7201Slot(string memory namespace) internal pure returns (bytes32 slot) {
/// @solidity memory-safe-assembly /// @solidity memory-safe-assembly
assembly { assembly {
mstore(0x00, sub(keccak256(add(namespace, 0x20), mload(namespace)), 1)) mstore(0x00, sub(keccak256(add(namespace, 0x20), mload(namespace)), 1))
slot := and(keccak256(0x00, 0x20), not(0xff)) slot := and(keccak256(0x00, 0x20), not(0xff))
} }
} }
`; `;
@ -56,20 +56,20 @@ const array = `\
* @dev Add an offset to a slot to get the n-th element of a structure or an array. * @dev Add an offset to a slot to get the n-th element of a structure or an array.
*/ */
function offset(bytes32 slot, uint256 pos) internal pure returns (bytes32 result) { function offset(bytes32 slot, uint256 pos) internal pure returns (bytes32 result) {
unchecked { unchecked {
return bytes32(uint256(slot) + pos); return bytes32(uint256(slot) + pos);
} }
} }
/** /**
* @dev Derive the location of the first element in an array from the slot where the length is stored. * @dev Derive the location of the first element in an array from the slot where the length is stored.
*/ */
function deriveArray(bytes32 slot) internal pure returns (bytes32 result) { function deriveArray(bytes32 slot) internal pure returns (bytes32 result) {
/// @solidity memory-safe-assembly /// @solidity memory-safe-assembly
assembly { assembly {
mstore(0x00, slot) mstore(0x00, slot)
result := keccak256(0x00, 0x20) result := keccak256(0x00, 0x20)
} }
} }
`; `;
@ -78,12 +78,12 @@ const mapping = ({ type }) => `\
* @dev Derive the location of a mapping element from the key. * @dev Derive the location of a mapping element from the key.
*/ */
function deriveMapping(bytes32 slot, ${type} key) internal pure returns (bytes32 result) { function deriveMapping(bytes32 slot, ${type} key) internal pure returns (bytes32 result) {
/// @solidity memory-safe-assembly /// @solidity memory-safe-assembly
assembly { assembly {
mstore(0x00, key) mstore(0x00, key)
mstore(0x20, slot) mstore(0x20, slot)
result := keccak256(0x00, 0x40) result := keccak256(0x00, 0x40)
} }
} }
`; `;
@ -92,16 +92,16 @@ const mapping2 = ({ type }) => `\
* @dev Derive the location of a mapping element from the key. * @dev Derive the location of a mapping element from the key.
*/ */
function deriveMapping(bytes32 slot, ${type} memory key) internal pure returns (bytes32 result) { function deriveMapping(bytes32 slot, ${type} memory key) internal pure returns (bytes32 result) {
/// @solidity memory-safe-assembly /// @solidity memory-safe-assembly
assembly { assembly {
let length := mload(key) let length := mload(key)
let begin := add(key, 0x20) let begin := add(key, 0x20)
let end := add(begin, length) let end := add(begin, length)
let cache := mload(end) let cache := mload(end)
mstore(end, slot) mstore(end, slot)
result := keccak256(begin, add(length, 0x20)) result := keccak256(begin, add(length, 0x20))
mstore(end, cache) mstore(end, cache)
} }
} }
`; `;
@ -109,8 +109,12 @@ function deriveMapping(bytes32 slot, ${type} memory key) internal pure returns (
module.exports = format( module.exports = format(
header.trimEnd(), header.trimEnd(),
'library SlotDerivation {', 'library SlotDerivation {',
namespace, format(
array, [].concat(
TYPES.map(type => (type.isValueType ? mapping(type) : mapping2(type))), namespace,
array,
TYPES.map(type => (type.isValueType ? mapping(type) : mapping2(type))),
),
).trimEnd(),
'}', '}',
); );

@ -14,31 +14,31 @@ const array = `\
bytes[] private _array; bytes[] private _array;
function symbolicDeriveArray(uint256 length, uint256 offset) public { function symbolicDeriveArray(uint256 length, uint256 offset) public {
vm.assume(length > 0); vm.assume(length > 0);
vm.assume(offset < length); vm.assume(offset < length);
_assertDeriveArray(length, offset); _assertDeriveArray(length, offset);
} }
function testDeriveArray(uint256 length, uint256 offset) public { function testDeriveArray(uint256 length, uint256 offset) public {
length = bound(length, 1, type(uint256).max); length = bound(length, 1, type(uint256).max);
offset = bound(offset, 0, length - 1); offset = bound(offset, 0, length - 1);
_assertDeriveArray(length, offset); _assertDeriveArray(length, offset);
} }
function _assertDeriveArray(uint256 length, uint256 offset) public { function _assertDeriveArray(uint256 length, uint256 offset) public {
bytes32 baseSlot; bytes32 baseSlot;
assembly { assembly {
baseSlot := _array.slot baseSlot := _array.slot
sstore(baseSlot, length) // store length so solidity access does not revert sstore(baseSlot, length) // store length so solidity access does not revert
} }
bytes storage derived = _array[offset]; bytes storage derived = _array[offset];
bytes32 derivedSlot; bytes32 derivedSlot;
assembly { assembly {
derivedSlot := derived.slot derivedSlot := derived.slot
} }
assertEq(baseSlot.deriveArray().offset(offset), derivedSlot); assertEq(baseSlot.deriveArray().offset(offset), derivedSlot);
} }
`; `;
@ -46,18 +46,18 @@ const mapping = ({ type, name }) => `\
mapping(${type} => bytes) private _${type}Mapping; mapping(${type} => bytes) private _${type}Mapping;
function testSymbolicDeriveMapping${name}(${type} key) public { function testSymbolicDeriveMapping${name}(${type} key) public {
bytes32 baseSlot; bytes32 baseSlot;
assembly { assembly {
baseSlot := _${type}Mapping.slot baseSlot := _${type}Mapping.slot
} }
bytes storage derived = _${type}Mapping[key]; bytes storage derived = _${type}Mapping[key];
bytes32 derivedSlot; bytes32 derivedSlot;
assembly { assembly {
derivedSlot := derived.slot derivedSlot := derived.slot
} }
assertEq(baseSlot.deriveMapping(key), derivedSlot); assertEq(baseSlot.deriveMapping(key), derivedSlot);
} }
`; `;
@ -65,45 +65,49 @@ const boundedMapping = ({ type, name }) => `\
mapping(${type} => bytes) private _${type}Mapping; mapping(${type} => bytes) private _${type}Mapping;
function testDeriveMapping${name}(${type} memory key) public { function testDeriveMapping${name}(${type} memory key) public {
_assertDeriveMapping${name}(key); _assertDeriveMapping${name}(key);
} }
function symbolicDeriveMapping${name}() public { function symbolicDeriveMapping${name}() public {
_assertDeriveMapping${name}(svm.create${name}(256, "DeriveMapping${name}Input")); _assertDeriveMapping${name}(svm.create${name}(256, "DeriveMapping${name}Input"));
} }
function _assertDeriveMapping${name}(${type} memory key) internal { function _assertDeriveMapping${name}(${type} memory key) internal {
bytes32 baseSlot; bytes32 baseSlot;
assembly { assembly {
baseSlot := _${type}Mapping.slot baseSlot := _${type}Mapping.slot
} }
bytes storage derived = _${type}Mapping[key]; bytes storage derived = _${type}Mapping[key];
bytes32 derivedSlot; bytes32 derivedSlot;
assembly { assembly {
derivedSlot := derived.slot derivedSlot := derived.slot
} }
assertEq(baseSlot.deriveMapping(key), derivedSlot); assertEq(baseSlot.deriveMapping(key), derivedSlot);
} }
`; `;
// GENERATE // GENERATE
module.exports = format( module.exports = format(
header.trimEnd(), header,
'contract SlotDerivationTest is Test, SymTest {', 'contract SlotDerivationTest is Test, SymTest {',
'using SlotDerivation for bytes32;', format(
'',
array,
TYPES.flatMap(type =>
[].concat( [].concat(
type, 'using SlotDerivation for bytes32;',
(type.variants ?? []).map(variant => ({ '',
type: variant, array,
name: capitalize(variant), TYPES.flatMap(type =>
isValueType: type.isValueType, [].concat(
})), type,
(type.variants ?? []).map(variant => ({
type: variant,
name: capitalize(variant),
isValueType: type.isValueType,
})),
),
).map(type => (type.isValueType ? mapping(type) : boundedMapping(type))),
), ),
).map(type => (type.isValueType ? mapping(type) : boundedMapping(type))), ).trimEnd(),
'}', '}',
); );

@ -53,7 +53,7 @@ pragma solidity ^0.8.24;
const struct = ({ type, name }) => `\ const struct = ({ type, name }) => `\
struct ${name}Slot { struct ${name}Slot {
${type} value; ${type} value;
} }
`; `;
@ -62,10 +62,10 @@ const get = ({ name }) => `\
* @dev Returns an \`${name}Slot\` with member \`value\` located at \`slot\`. * @dev Returns an \`${name}Slot\` with member \`value\` located at \`slot\`.
*/ */
function get${name}Slot(bytes32 slot) internal pure returns (${name}Slot storage r) { function get${name}Slot(bytes32 slot) internal pure returns (${name}Slot storage r) {
/// @solidity memory-safe-assembly /// @solidity memory-safe-assembly
assembly { assembly {
r.slot := slot r.slot := slot
} }
} }
`; `;
@ -74,10 +74,10 @@ const getStorage = ({ type, name }) => `\
* @dev Returns an \`${name}Slot\` representation of the ${type} storage pointer \`store\`. * @dev Returns an \`${name}Slot\` representation of the ${type} storage pointer \`store\`.
*/ */
function get${name}Slot(${type} storage store) internal pure returns (${name}Slot storage r) { function get${name}Slot(${type} storage store) internal pure returns (${name}Slot storage r) {
/// @solidity memory-safe-assembly /// @solidity memory-safe-assembly
assembly { assembly {
r.slot := store.slot r.slot := store.slot
} }
} }
`; `;
@ -86,11 +86,12 @@ const udvt = ({ type, name }) => `\
* @dev UDVT that represent a slot holding a ${type}. * @dev UDVT that represent a slot holding a ${type}.
*/ */
type ${name}SlotType is bytes32; type ${name}SlotType is bytes32;
/** /**
* @dev Cast an arbitrary slot to a ${name}SlotType. * @dev Cast an arbitrary slot to a ${name}SlotType.
*/ */
function as${name}(bytes32 slot) internal pure returns (${name}SlotType) { function as${name}(bytes32 slot) internal pure returns (${name}SlotType) {
return ${name}SlotType.wrap(slot); return ${name}SlotType.wrap(slot);
} }
`; `;
@ -99,19 +100,20 @@ const transient = ({ type, name }) => `\
* @dev Load the value held at location \`slot\` in transient storage. * @dev Load the value held at location \`slot\` in transient storage.
*/ */
function tload(${name}SlotType slot) internal view returns (${type} value) { function tload(${name}SlotType slot) internal view returns (${type} value) {
/// @solidity memory-safe-assembly /// @solidity memory-safe-assembly
assembly { assembly {
value := tload(slot) value := tload(slot)
} }
} }
/** /**
* @dev Store \`value\` at location \`slot\` in transient storage. * @dev Store \`value\` at location \`slot\` in transient storage.
*/ */
function tstore(${name}SlotType slot, ${type} value) internal { function tstore(${name}SlotType slot, ${type} value) internal {
/// @solidity memory-safe-assembly /// @solidity memory-safe-assembly
assembly { assembly {
tstore(slot, value) tstore(slot, value)
} }
} }
`; `;
@ -119,9 +121,13 @@ function tstore(${name}SlotType slot, ${type} value) internal {
module.exports = format( module.exports = format(
header.trimEnd(), header.trimEnd(),
'library StorageSlot {', 'library StorageSlot {',
TYPES.map(type => struct(type)), format(
TYPES.flatMap(type => [get(type), type.isValueType ? '' : getStorage(type)]), [].concat(
TYPES.filter(type => type.isValueType).map(type => udvt(type)), TYPES.map(type => struct(type)),
TYPES.filter(type => type.isValueType).map(type => transient(type)), TYPES.flatMap(type => [get(type), !type.isValueType && getStorage(type)].filter(Boolean)),
TYPES.filter(type => type.isValueType).map(type => udvt(type)),
TYPES.filter(type => type.isValueType).map(type => transient(type)),
),
).trimEnd(),
'}', '}',
); );

@ -44,22 +44,27 @@ const transient = ({ type, name }) => `\
event ${name}Value(bytes32 slot, ${type} value); event ${name}Value(bytes32 slot, ${type} value);
function tload${name}(bytes32 slot) public { function tload${name}(bytes32 slot) public {
emit ${name}Value(slot, slot.as${name}().tload()); emit ${name}Value(slot, slot.as${name}().tload());
} }
function tstore(bytes32 slot, ${type} value) public { function tstore(bytes32 slot, ${type} value) public {
slot.as${name}().tstore(value); slot.as${name}().tstore(value);
} }
`; `;
// GENERATE // GENERATE
module.exports = format( module.exports = format(
header.trimEnd(), header,
'contract StorageSlotMock is Multicall {', 'contract StorageSlotMock is Multicall {',
'using StorageSlot for *;', format(
TYPES.filter(type => type.isValueType).map(type => storageSetValueType(type)), [].concat(
TYPES.filter(type => type.isValueType).map(type => storageGetValueType(type)), 'using StorageSlot for *;',
TYPES.filter(type => !type.isValueType).map(type => storageSetNonValueType(type)), '',
TYPES.filter(type => type.isValueType).map(type => transient(type)), TYPES.filter(type => type.isValueType).map(type => storageSetValueType(type)),
TYPES.filter(type => type.isValueType).map(type => storageGetValueType(type)),
TYPES.filter(type => !type.isValueType).map(type => storageSetNonValueType(type)),
TYPES.filter(type => type.isValueType).map(type => transient(type)),
),
).trimEnd(),
'}', '}',
); );

Loading…
Cancel
Save