diff --git a/README.md b/README.md index e96f4bb7..78813782 100644 --- a/README.md +++ b/README.md @@ -71,3 +71,7 @@ These contracts were inspired by or directly modified from many sources, primari - [Dappsys V2](https://github.com/dapp-org/dappsys-v2) - [0xSequence](https://github.com/0xSequence) - [OpenZeppelin](https://github.com/OpenZeppelin/openzeppelin-contracts) + +--- +**Support the Contributor:** +[Buy Me A Coffee - TimelessHayoka](https://buymeacoffee.com/timelesshayoka) diff --git a/src/test/ERC4626.t.sol b/src/test/ERC4626.t.sol index 816c8e48..2a79545d 100644 --- a/src/test/ERC4626.t.sol +++ b/src/test/ERC4626.t.sol @@ -331,49 +331,57 @@ contract ERC4626Test is DSTestPlus { assertEq(underlying.balanceOf(address(vault)), 0); } - function testFailDepositWithNotEnoughApproval() public { + function testRevertDepositWithNotEnoughApproval() public { underlying.mint(address(this), 0.5e18); underlying.approve(address(vault), 0.5e18); assertEq(underlying.allowance(address(this), address(vault)), 0.5e18); + hevm.expectRevert(); vault.deposit(1e18, address(this)); } - function testFailWithdrawWithNotEnoughUnderlyingAmount() public { + function testRevertWithdrawWithNotEnoughUnderlyingAmount() public { underlying.mint(address(this), 0.5e18); underlying.approve(address(vault), 0.5e18); vault.deposit(0.5e18, address(this)); + hevm.expectRevert(); vault.withdraw(1e18, address(this), address(this)); } - function testFailRedeemWithNotEnoughShareAmount() public { + function testRevertRedeemWithNotEnoughShareAmount() public { underlying.mint(address(this), 0.5e18); underlying.approve(address(vault), 0.5e18); vault.deposit(0.5e18, address(this)); + hevm.expectRevert(); vault.redeem(1e18, address(this), address(this)); } - function testFailWithdrawWithNoUnderlyingAmount() public { + function testRevertWithdrawWithNoUnderlyingAmount() public { + hevm.expectRevert(); vault.withdraw(1e18, address(this), address(this)); } - function testFailRedeemWithNoShareAmount() public { + function testRevertRedeemWithNoShareAmount() public { + hevm.expectRevert(); vault.redeem(1e18, address(this), address(this)); } - function testFailDepositWithNoApproval() public { + function testRevertDepositWithNoApproval() public { + hevm.expectRevert(); vault.deposit(1e18, address(this)); } - function testFailMintWithNoApproval() public { + function testRevertMintWithNoApproval() public { + hevm.expectRevert(); vault.mint(1e18, address(this)); } - function testFailDepositZero() public { + function testRevertDepositZero() public { + hevm.expectRevert("ZERO_SHARES"); vault.deposit(0, address(this)); } @@ -386,7 +394,8 @@ contract ERC4626Test is DSTestPlus { assertEq(vault.totalAssets(), 0); } - function testFailRedeemZero() public { + function testRevertRedeemZero() public { + hevm.expectRevert("ZERO_ASSETS"); vault.redeem(0, address(this), address(this)); } @@ -443,4 +452,30 @@ contract ERC4626Test is DSTestPlus { assertEq(vault.balanceOf(bob), 0); assertEq(underlying.balanceOf(alice), 1e18); } + + function testRevertMintZeroAssetsWhenVaultDrained(uint128 initialDeposit, uint128 massiveMint) public { + if (initialDeposit == 0) initialDeposit = 1; + if (massiveMint == 0) massiveMint = 1; + + address alice = address(0xABCD); + address attacker = address(0xBAD); + + underlying.mint(alice, initialDeposit); + hevm.prank(alice); + underlying.approve(address(vault), initialDeposit); + hevm.prank(alice); + vault.deposit(initialDeposit, alice); + + // Vault is drained (simulate slashing or hack) + underlying.burn(address(vault), initialDeposit); + + // Attacker mints massive amount of shares for free + underlying.mint(attacker, 0); // Attacker has 0 tokens + hevm.prank(attacker); + underlying.approve(address(vault), 0); + + hevm.expectRevert("ZERO_ASSETS"); + hevm.prank(attacker); + vault.mint(massiveMint, attacker); // Should fail due to ZERO_ASSETS check + } } diff --git a/src/test/SignedWadMath.t.sol b/src/test/SignedWadMath.t.sol index 48494480..dbcc26da 100644 --- a/src/test/SignedWadMath.t.sol +++ b/src/test/SignedWadMath.t.sol @@ -3,9 +3,46 @@ pragma solidity >=0.8.0; import {DSTestPlus} from "./utils/DSTestPlus.sol"; -import {wadMul, wadDiv} from "../utils/SignedWadMath.sol"; +import {wadMul, wadDiv, wadExp, wadLn, wadPow} from "../utils/SignedWadMath.sol"; contract SignedWadMathTest is DSTestPlus { + function testWadExp() public { + assertEq(wadExp(0), 1e18); + assertApproxEq(uint256(wadExp(1e18)), 2718281828459045235, 2); + } + + function testWadExpBoundaries() public { + assertEq(wadExp(-42139678854452767551), 0); + } + + function testWadExpOverflowReverts() public { + hevm.expectRevert("EXP_OVERFLOW"); + this.wadExpOverflowRevertsHelper(); + } + + function wadExpOverflowRevertsHelper() external pure { + wadExp(135305999368893231589); + } + + function testWadLn() public { + assertEq(wadLn(1e18), 0); + assertApproxEq(uint256(wadLn(2e18)), 693147180559945309, 2); + } + + function testWadPowRoundTrip() public { + assertEq(wadPow(1e18, 0), 1e18); + assertApproxEq(uint256(wadPow(2e18, 1e18)), 2e18, 2); + } + + function testWadLnZeroReverts() public { + hevm.expectRevert("UNDEFINED"); + this.wadLnZeroRevertsHelper(); + } + + function wadLnZeroRevertsHelper() external pure { + wadLn(0); + } + function testWadMul( uint256 x, uint256 y, @@ -21,26 +58,41 @@ contract SignedWadMathTest is DSTestPlus { assertEq(wadMul(xPrime, yPrime), (xPrime * yPrime) / 1e18); } - function testFailWadMulEdgeCase() public pure { + function testWadMulEdgeCaseReverts() public { + hevm.expectRevert(); + this.wadMulEdgeCaseRevertsHelper(); + } + + function wadMulEdgeCaseRevertsHelper() external pure { int256 x = -1; int256 y = type(int256).min; wadMul(x, y); } - function testFailWadMulEdgeCase2() public pure { + function testWadMulEdgeCase2Reverts() public { + hevm.expectRevert(); + this.wadMulEdgeCase2RevertsHelper(); + } + + function wadMulEdgeCase2RevertsHelper() external pure { int256 x = type(int256).min; int256 y = -1; wadMul(x, y); } - function testFailWadMulOverflow(int256 x, int256 y) public pure { + function testWadMulOverflowReverts(int256 x, int256 y) public { // Ignore cases where x * y does not overflow. unchecked { - if ((x * y) / x == y) revert(); + if (x == 0 || (x * y) / x == y) return; } + hevm.expectRevert(); + this.wadMulOverflowRevertsHelper(x, y); + } + + function wadMulOverflowRevertsHelper(int256 x, int256 y) external pure { wadMul(x, y); } @@ -59,16 +111,26 @@ contract SignedWadMathTest is DSTestPlus { assertEq(wadDiv(xPrime, yPrime), (xPrime * 1e18) / yPrime); } - function testFailWadDivOverflow(int256 x, int256 y) public pure { + function testWadDivOverflowReverts(int256 x, int256 y) public { // Ignore cases where x * WAD does not overflow or y is 0. unchecked { - if (y == 0 || (x * 1e18) / 1e18 == x) revert(); + if (y == 0 || (x * 1e18) / 1e18 == x) return; } + hevm.expectRevert(); + this.wadDivOverflowRevertsHelper(x, y); + } + + function wadDivOverflowRevertsHelper(int256 x, int256 y) external pure { wadDiv(x, y); } - function testFailWadDivZeroDenominator(int256 x) public pure { + function testWadDivZeroDenominatorReverts(int256 x) public { + hevm.expectRevert(); + this.wadDivZeroDenominatorRevertsHelper(x); + } + + function wadDivZeroDenominatorRevertsHelper(int256 x) external pure { wadDiv(x, 0); } } diff --git a/src/test/utils/Hevm.sol b/src/test/utils/Hevm.sol index 8ca0eff9..86cd088a 100644 --- a/src/test/utils/Hevm.sol +++ b/src/test/utils/Hevm.sol @@ -58,6 +58,9 @@ interface Hevm { /// @notice Sets an address' code. function etch(address, bytes calldata) external; + /// @notice Expects a revert from the next call. + function expectRevert() external; + /// @notice Expects an error from the next call. function expectRevert(bytes calldata) external; diff --git a/src/tokens/ERC4626.sol b/src/tokens/ERC4626.sol index 0a34ac98..cb7b38b7 100644 --- a/src/tokens/ERC4626.sol +++ b/src/tokens/ERC4626.sol @@ -60,6 +60,8 @@ abstract contract ERC4626 is ERC20 { function mint(uint256 shares, address receiver) public virtual returns (uint256 assets) { assets = previewMint(shares); // No need to check for rounding error, previewMint rounds up. + require(assets != 0, "ZERO_ASSETS"); + // Need to transfer before minting or ERC777s could reenter. asset.safeTransferFrom(msg.sender, address(this), assets); diff --git a/src/utils/SafeTransferLib.sol b/src/utils/SafeTransferLib.sol index 7f8236db..cab09dc2 100644 --- a/src/utils/SafeTransferLib.sol +++ b/src/utils/SafeTransferLib.sol @@ -42,8 +42,8 @@ library SafeTransferLib { // Write the abi-encoded calldata into memory, beginning with the function selector. mstore(freeMemoryPointer, 0x23b872dd00000000000000000000000000000000000000000000000000000000) - mstore(add(freeMemoryPointer, 4), and(from, 0xffffffffffffffffffffffffffffffffffffffff)) // Append and mask the "from" argument. - mstore(add(freeMemoryPointer, 36), and(to, 0xffffffffffffffffffffffffffffffffffffffff)) // Append and mask the "to" argument. + mstore(add(freeMemoryPointer, 4), from) // Append the "from" argument. + mstore(add(freeMemoryPointer, 36), to) // Append the "to" argument. mstore(add(freeMemoryPointer, 68), amount) // Append the "amount" argument. Masking not required as it's a full 32 byte type. // We use 100 because the length of our calldata totals up like so: 4 + 32 * 3. @@ -74,7 +74,7 @@ library SafeTransferLib { // Write the abi-encoded calldata into memory, beginning with the function selector. mstore(freeMemoryPointer, 0xa9059cbb00000000000000000000000000000000000000000000000000000000) - mstore(add(freeMemoryPointer, 4), and(to, 0xffffffffffffffffffffffffffffffffffffffff)) // Append and mask the "to" argument. + mstore(add(freeMemoryPointer, 4), to) // Append the "to" argument. mstore(add(freeMemoryPointer, 36), amount) // Append the "amount" argument. Masking not required as it's a full 32 byte type. // We use 68 because the length of our calldata totals up like so: 4 + 32 * 2. @@ -105,7 +105,7 @@ library SafeTransferLib { // Write the abi-encoded calldata into memory, beginning with the function selector. mstore(freeMemoryPointer, 0x095ea7b300000000000000000000000000000000000000000000000000000000) - mstore(add(freeMemoryPointer, 4), and(to, 0xffffffffffffffffffffffffffffffffffffffff)) // Append and mask the "to" argument. + mstore(add(freeMemoryPointer, 4), to) // Append the "to" argument. mstore(add(freeMemoryPointer, 36), amount) // Append the "amount" argument. Masking not required as it's a full 32 byte type. // We use 68 because the length of our calldata totals up like so: 4 + 32 * 2.