@@ -9,6 +9,7 @@ use crate::{
99 error:: NttQuoterError ,
1010 state:: { Instance , RegisteredChain , RegisteredNtt , RelayRequest } ,
1111} ;
12+ use std:: result:: Result as StdResult ;
1213
1314//TODO eventually drop the released constraint and instead implement release by relayer
1415fn check_release_constraint_and_fetch_chain_id (
@@ -30,6 +31,10 @@ pub struct RequestRelay<'info> {
3031 #[ account( mut ) ]
3132 pub payer : Signer < ' info > ,
3233
34+ #[ account(
35+ constraint = instance. sol_price != 0 @
36+ NttQuoterError :: PriceCannotBeZero
37+ ) ]
3338 pub instance : Account < ' info , Instance > ,
3439
3540 #[ account(
@@ -76,20 +81,32 @@ pub struct RequestRelayArgs {
7681
7782const GWEI : u64 = u64:: pow ( 10 , 9 ) ;
7883
84+ /// Performs a multiply-divide operation on u64 inputs, using u128s as intermediate terms to prevent overflow.
85+ ///
86+ /// # Errors
87+ ///
88+ /// This function can return either a [`NttQuoterError::DivByZero`] when `denominator` is zero or a
89+ /// [`NttQuoterError::ScalingOverflow`] error if the final result cannot be converted back to a
90+ /// u64.
7991//TODO built-in u128 division likely still wastes a ton of compute units
8092// might be more efficient to use f64 or ruint crate
8193// SECURITY: Integer division is OK here. The calling code is responsible for understanding that
8294// this function returns the quotient of the operation and that the remainder will be lost.
8395#[ allow( clippy:: integer_division) ]
84- fn mul_div ( scalar : u64 , numerator : u64 , denominator : u64 ) -> u64 {
85- if scalar > 0 {
86- //avoid potentially expensive u128 division
87- ( ( scalar as u128 ) * ( numerator as u128 ) / ( denominator as u128 ) )
88- . try_into ( )
89- . unwrap ( )
90- } else {
91- 0
96+ fn mul_div ( scalar : u64 , numerator : u64 , denominator : u64 ) -> StdResult < u64 , NttQuoterError > {
97+ if denominator == 0 {
98+ return Err ( NttQuoterError :: DivByZero ) ;
99+ }
100+
101+ //avoid potentially expensive u128 division
102+ if scalar == 0 || numerator == 0 {
103+ return Ok ( 0 ) ;
92104 }
105+
106+ u64:: try_from ( u128:: from ( scalar) * u128:: from ( numerator) / u128:: from ( denominator) )
107+ . map_or ( Err ( NttQuoterError :: ScalingOverflow ) , |quotient| {
108+ Ok ( quotient)
109+ } )
93110}
94111
95112pub fn request_relay ( ctx : Context < RequestRelay > , args : RequestRelayArgs ) -> Result < ( ) > {
@@ -107,20 +124,20 @@ pub fn request_relay(ctx: Context<RequestRelay>, args: RequestRelayArgs) -> Resu
107124 accs. registered_chain . gas_price ,
108125 accs. registered_ntt . gas_cost as u64 ,
109126 GWEI ,
110- ) ;
127+ ) ? ;
111128
112129 //usd/target_native[usd, 6 decimals] * target_native[gwei, 9 decimals] = usd[usd, 6 decimals]
113130 let target_native_in_usd = mul_div (
114131 accs. registered_chain . native_price ,
115132 target_native_in_gwei,
116133 GWEI ,
117- ) ;
134+ ) ? ;
118135
119136 let total_in_usd = target_native_in_usd + accs. registered_chain . base_price ;
120137
121138 //total_fee[sol, 9 decimals] = total_usd[usd, 6 decimals] / (sol_price[usd, 6 decimals]
122139 mul_div ( total_in_usd, LAMPORTS_PER_SOL , accs. instance . sol_price )
123- } ;
140+ } ? ;
124141
125142 let rent_in_lamports = sysvar:: rent:: Rent :: get ( ) ?. minimum_balance ( 8 + RelayRequest :: INIT_SPACE ) ;
126143 let fee_minus_rent = relay_fee_in_lamports. saturating_sub ( rent_in_lamports) ;
0 commit comments