@@ -37,6 +37,7 @@ use crate::vm::costs::{
37
37
analysis_typecheck_cost, cost_functions, runtime_cost, ClarityCostFunctionReference ,
38
38
CostErrors , CostOverflowingMath , CostTracker , ExecutionCost , LimitedCostTracker ,
39
39
} ;
40
+ use crate :: vm:: diagnostic:: Diagnostic ;
40
41
use crate :: vm:: functions:: define:: DefineFunctionsParsed ;
41
42
use crate :: vm:: functions:: NativeFunctions ;
42
43
use crate :: vm:: representations:: SymbolicExpressionType :: {
@@ -151,7 +152,130 @@ impl TypeChecker<'_, '_> {
151
152
152
153
pub type TypeResult = CheckResult < TypeSignature > ;
153
154
155
+ pub fn compute_typecheck_cost < T : CostTracker > (
156
+ track : & mut T ,
157
+ t1 : & TypeSignature ,
158
+ t2 : & TypeSignature ,
159
+ ) -> Result < ExecutionCost , CostErrors > {
160
+ let t1_size = t1. type_size ( ) . map_err ( |_| CostErrors :: CostOverflow ) ?;
161
+ let t2_size = t2. type_size ( ) . map_err ( |_| CostErrors :: CostOverflow ) ?;
162
+ track. compute_cost (
163
+ ClarityCostFunction :: AnalysisTypeCheck ,
164
+ & [ std:: cmp:: max ( t1_size, t2_size) . into ( ) ] ,
165
+ )
166
+ }
167
+
168
+ pub fn check_argument_len ( expected : usize , args_len : usize ) -> Result < ( ) , CheckErrors > {
169
+ if args_len != expected {
170
+ Err ( CheckErrors :: IncorrectArgumentCount ( expected, args_len) )
171
+ } else {
172
+ Ok ( ( ) )
173
+ }
174
+ }
175
+
154
176
impl FunctionType {
177
+ pub fn check_args_visitor_2_1 < T : CostTracker > (
178
+ & self ,
179
+ accounting : & mut T ,
180
+ arg_type : & TypeSignature ,
181
+ arg_index : usize ,
182
+ accumulated_type : Option < & TypeSignature > ,
183
+ ) -> (
184
+ Option < Result < ExecutionCost , CostErrors > > ,
185
+ CheckResult < Option < TypeSignature > > ,
186
+ ) {
187
+ match self {
188
+ // variadic stops checking cost at the first error...
189
+ FunctionType :: Variadic ( expected_type, _) => {
190
+ let cost = Some ( compute_typecheck_cost ( accounting, expected_type, arg_type) ) ;
191
+ let admitted = match expected_type. admits_type ( & StacksEpochId :: Epoch21 , arg_type) {
192
+ Ok ( admitted) => admitted,
193
+ Err ( e) => return ( cost, Err ( e. into ( ) ) ) ,
194
+ } ;
195
+ if !admitted {
196
+ return (
197
+ cost,
198
+ Err ( CheckErrors :: TypeError ( expected_type. clone ( ) , arg_type. clone ( ) ) . into ( ) ) ,
199
+ ) ;
200
+ }
201
+ ( cost, Ok ( None ) )
202
+ }
203
+ FunctionType :: ArithmeticVariadic => {
204
+ let cost = Some ( compute_typecheck_cost (
205
+ accounting,
206
+ & TypeSignature :: IntType ,
207
+ arg_type,
208
+ ) ) ;
209
+ if arg_index == 0 {
210
+ let return_type = match arg_type {
211
+ TypeSignature :: IntType => Ok ( Some ( TypeSignature :: IntType ) ) ,
212
+ TypeSignature :: UIntType => Ok ( Some ( TypeSignature :: UIntType ) ) ,
213
+ _ => Err ( CheckErrors :: UnionTypeError (
214
+ vec ! [ TypeSignature :: IntType , TypeSignature :: UIntType ] ,
215
+ arg_type. clone ( ) ,
216
+ )
217
+ . into ( ) ) ,
218
+ } ;
219
+ ( cost, return_type)
220
+ } else {
221
+ let return_type = accumulated_type
222
+ . ok_or_else ( || CheckErrors :: Expects ( "Failed to set accumulated type for arg indices >= 1 in variadic arithmetic" . into ( ) ) . into ( ) ) ;
223
+ let check_result = return_type. and_then ( |return_type| {
224
+ if arg_type != return_type {
225
+ Err (
226
+ CheckErrors :: TypeError ( return_type. clone ( ) , arg_type. clone ( ) )
227
+ . into ( ) ,
228
+ )
229
+ } else {
230
+ Ok ( None )
231
+ }
232
+ } ) ;
233
+ ( cost, check_result)
234
+ }
235
+ }
236
+ // For the fixed function types, the visitor will just
237
+ // tell the processor that any results greater than the args len
238
+ // do not need to be stored, because an error will occur before
239
+ // further checking anyways
240
+ FunctionType :: Fixed ( FixedFunction {
241
+ args : arg_types, ..
242
+ } ) => {
243
+ if arg_index >= arg_types. len ( ) {
244
+ // note: argument count will be wrong?
245
+ return (
246
+ None ,
247
+ Err ( CheckErrors :: IncorrectArgumentCount ( arg_types. len ( ) , arg_index) . into ( ) ) ,
248
+ ) ;
249
+ }
250
+ return ( None , Ok ( None ) ) ;
251
+ }
252
+ // For the following function types, the visitor will just
253
+ // tell the processor that any results greater than len 1 or 2
254
+ // do not need to be stored, because an error will occur before
255
+ // further checking anyways
256
+ FunctionType :: ArithmeticUnary | FunctionType :: UnionArgs ( ..) => {
257
+ if arg_index >= 1 {
258
+ return (
259
+ None ,
260
+ Err ( CheckErrors :: IncorrectArgumentCount ( 1 , arg_index) . into ( ) ) ,
261
+ ) ;
262
+ }
263
+ return ( None , Ok ( None ) ) ;
264
+ }
265
+ FunctionType :: ArithmeticBinary
266
+ | FunctionType :: ArithmeticComparison
267
+ | FunctionType :: Binary ( ..) => {
268
+ if arg_index >= 2 {
269
+ return (
270
+ None ,
271
+ Err ( CheckErrors :: IncorrectArgumentCount ( 2 , arg_index) . into ( ) ) ,
272
+ ) ;
273
+ }
274
+ return ( None , Ok ( None ) ) ;
275
+ }
276
+ }
277
+ }
278
+
155
279
pub fn check_args_2_1 < T : CostTracker > (
156
280
& self ,
157
281
accounting : & mut T ,
@@ -1017,17 +1141,23 @@ impl<'a, 'b> TypeChecker<'a, 'b> {
1017
1141
args : & [ SymbolicExpression ] ,
1018
1142
context : & TypingContext ,
1019
1143
) -> TypeResult {
1020
- let mut types_returned = self . type_check_all ( args, context) ?;
1021
-
1022
- let last_return = types_returned
1023
- . pop ( )
1024
- . ok_or ( CheckError :: new ( CheckErrors :: CheckerImplementationFailure ) ) ?;
1025
-
1026
- for type_return in types_returned. iter ( ) {
1027
- if type_return. is_response_type ( ) {
1028
- return Err ( CheckErrors :: UncheckedIntermediaryResponses . into ( ) ) ;
1144
+ let mut last_return = None ;
1145
+ let mut return_failure = Ok ( ( ) ) ;
1146
+ for ix in 0 ..args. len ( ) {
1147
+ let type_return = self . type_check ( & args[ ix] , context) ?;
1148
+ if ix + 1 < args. len ( ) {
1149
+ if type_return. is_response_type ( ) {
1150
+ return_failure = Err ( CheckErrors :: UncheckedIntermediaryResponses ) ;
1151
+ }
1152
+ } else {
1153
+ last_return = Some ( type_return) ;
1029
1154
}
1030
1155
}
1156
+
1157
+ let last_return = last_return
1158
+ . ok_or_else ( || CheckError :: new ( CheckErrors :: CheckerImplementationFailure ) ) ?;
1159
+ return_failure?;
1160
+
1031
1161
Ok ( last_return)
1032
1162
}
1033
1163
@@ -1052,8 +1182,56 @@ impl<'a, 'b> TypeChecker<'a, 'b> {
1052
1182
epoch : StacksEpochId ,
1053
1183
clarity_version : ClarityVersion ,
1054
1184
) -> TypeResult {
1055
- let typed_args = self . type_check_all ( args, context) ?;
1056
- func_type. check_args ( self , & typed_args, epoch, clarity_version)
1185
+ if epoch <= StacksEpochId :: Epoch2_05 {
1186
+ let typed_args = self . type_check_all ( args, context) ?;
1187
+ return func_type. check_args ( self , & typed_args, epoch, clarity_version) ;
1188
+ }
1189
+ // use func_type visitor pattern
1190
+ let mut accumulated_type = None ;
1191
+ let mut total_costs = vec ! [ ] ;
1192
+ let mut check_result = Ok ( ( ) ) ;
1193
+ let mut accumulated_types = Vec :: new ( ) ;
1194
+ for ( arg_ix, arg_expr) in args. iter ( ) . enumerate ( ) {
1195
+ let arg_type = self . type_check ( arg_expr, context) ?;
1196
+ if check_result. is_ok ( ) {
1197
+ let ( costs, result) = func_type. check_args_visitor_2_1 (
1198
+ self ,
1199
+ & arg_type,
1200
+ arg_ix,
1201
+ accumulated_type. as_ref ( ) ,
1202
+ ) ;
1203
+ // add the accumulated type and total cost *before*
1204
+ // checking for an error: we want the subsequent error handling
1205
+ // to account for this cost
1206
+ accumulated_types. push ( arg_type) ;
1207
+ total_costs. extend ( costs) ;
1208
+
1209
+ match result {
1210
+ Ok ( Some ( returned_type) ) => {
1211
+ accumulated_type = Some ( returned_type) ;
1212
+ }
1213
+ Ok ( None ) => { }
1214
+ Err ( e) => {
1215
+ check_result = Err ( e) ;
1216
+ }
1217
+ } ;
1218
+ }
1219
+ }
1220
+ if let Err ( mut check_error) = check_result {
1221
+ if let CheckErrors :: IncorrectArgumentCount ( expected, _actual) = check_error. err {
1222
+ check_error. err = CheckErrors :: IncorrectArgumentCount ( expected, args. len ( ) ) ;
1223
+ check_error. diagnostic = Diagnostic :: err ( & check_error. err )
1224
+ }
1225
+ // accumulate the checking costs
1226
+ // the reason we do this now (instead of within the loop) is for backwards compatibility
1227
+ for cost in total_costs. into_iter ( ) {
1228
+ self . add_cost ( cost?) ?;
1229
+ }
1230
+
1231
+ return Err ( check_error) ;
1232
+ }
1233
+ // otherwise, just invoke the normal checking routine
1234
+ func_type. check_args ( self , & accumulated_types, epoch, clarity_version)
1057
1235
}
1058
1236
1059
1237
fn get_function_type ( & self , function_name : & str ) -> Option < FunctionType > {
0 commit comments