Skip to content

Commit 1161819

Browse files
committed
value: check if pointer is null in cass_value_get_*
cpp-driver does that check for some reason.
1 parent 376b28a commit 1161819

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

scylla-rust-wrapper/src/query_result.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -915,12 +915,21 @@ pub unsafe extern "C" fn cass_value_data_type(value: *const CassValue) -> *const
915915
Arc::as_ptr(&value_from_raw.value_type)
916916
}
917917

918+
macro_rules! val_ptr_to_ref_ensure_non_null {
919+
($ptr:ident) => {{
920+
if $ptr.is_null() {
921+
return CassError::CASS_ERROR_LIB_NULL_VALUE;
922+
}
923+
ptr_to_ref($ptr)
924+
}};
925+
}
926+
918927
#[no_mangle]
919928
pub unsafe extern "C" fn cass_value_get_float(
920929
value: *const CassValue,
921930
output: *mut cass_float_t,
922931
) -> CassError {
923-
let val: &CassValue = ptr_to_ref(value);
932+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
924933
match val.value {
925934
Some(Value::RegularValue(CqlValue::Float(f))) => std::ptr::write(output, f),
926935
Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE,
@@ -935,7 +944,7 @@ pub unsafe extern "C" fn cass_value_get_double(
935944
value: *const CassValue,
936945
output: *mut cass_double_t,
937946
) -> CassError {
938-
let val: &CassValue = ptr_to_ref(value);
947+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
939948
match val.value {
940949
Some(Value::RegularValue(CqlValue::Double(d))) => std::ptr::write(output, d),
941950
Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE,
@@ -950,7 +959,7 @@ pub unsafe extern "C" fn cass_value_get_bool(
950959
value: *const CassValue,
951960
output: *mut cass_bool_t,
952961
) -> CassError {
953-
let val: &CassValue = ptr_to_ref(value);
962+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
954963
match val.value {
955964
Some(Value::RegularValue(CqlValue::Boolean(b))) => {
956965
std::ptr::write(output, b as cass_bool_t)
@@ -967,7 +976,7 @@ pub unsafe extern "C" fn cass_value_get_int8(
967976
value: *const CassValue,
968977
output: *mut cass_int8_t,
969978
) -> CassError {
970-
let val: &CassValue = ptr_to_ref(value);
979+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
971980
match val.value {
972981
Some(Value::RegularValue(CqlValue::TinyInt(i))) => std::ptr::write(output, i),
973982
Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE,
@@ -982,7 +991,7 @@ pub unsafe extern "C" fn cass_value_get_int16(
982991
value: *const CassValue,
983992
output: *mut cass_int16_t,
984993
) -> CassError {
985-
let val: &CassValue = ptr_to_ref(value);
994+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
986995
match val.value {
987996
Some(Value::RegularValue(CqlValue::SmallInt(i))) => std::ptr::write(output, i),
988997
Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE,
@@ -997,7 +1006,7 @@ pub unsafe extern "C" fn cass_value_get_uint32(
9971006
value: *const CassValue,
9981007
output: *mut cass_uint32_t,
9991008
) -> CassError {
1000-
let val: &CassValue = ptr_to_ref(value);
1009+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
10011010
match val.value {
10021011
Some(Value::RegularValue(CqlValue::Date(u))) => std::ptr::write(output, u.0), // FIXME: hack
10031012
Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE,
@@ -1012,7 +1021,7 @@ pub unsafe extern "C" fn cass_value_get_int32(
10121021
value: *const CassValue,
10131022
output: *mut cass_int32_t,
10141023
) -> CassError {
1015-
let val: &CassValue = ptr_to_ref(value);
1024+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
10161025
match val.value {
10171026
Some(Value::RegularValue(CqlValue::Int(i))) => std::ptr::write(output, i),
10181027
Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE,
@@ -1027,7 +1036,7 @@ pub unsafe extern "C" fn cass_value_get_int64(
10271036
value: *const CassValue,
10281037
output: *mut cass_int64_t,
10291038
) -> CassError {
1030-
let val: &CassValue = ptr_to_ref(value);
1039+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
10311040
match val.value {
10321041
Some(Value::RegularValue(CqlValue::BigInt(i))) => std::ptr::write(output, i),
10331042
Some(Value::RegularValue(CqlValue::Counter(i))) => {
@@ -1049,7 +1058,7 @@ pub unsafe extern "C" fn cass_value_get_uuid(
10491058
value: *const CassValue,
10501059
output: *mut CassUuid,
10511060
) -> CassError {
1052-
let val: &CassValue = ptr_to_ref(value);
1061+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
10531062
match val.value {
10541063
Some(Value::RegularValue(CqlValue::Uuid(uuid))) => std::ptr::write(output, uuid.into()),
10551064
Some(Value::RegularValue(CqlValue::Timeuuid(uuid))) => {
@@ -1067,7 +1076,7 @@ pub unsafe extern "C" fn cass_value_get_inet(
10671076
value: *const CassValue,
10681077
output: *mut CassInet,
10691078
) -> CassError {
1070-
let val: &CassValue = ptr_to_ref(value);
1079+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
10711080
match val.value {
10721081
Some(Value::RegularValue(CqlValue::Inet(inet))) => std::ptr::write(output, inet.into()),
10731082
Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE,
@@ -1083,7 +1092,7 @@ pub unsafe extern "C" fn cass_value_get_string(
10831092
output: *mut *const c_char,
10841093
output_size: *mut size_t,
10851094
) -> CassError {
1086-
let val: &CassValue = ptr_to_ref(value);
1095+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
10871096
match &val.value {
10881097
// It seems that cpp driver doesn't check the type - you can call _get_string
10891098
// on any type and get internal represenation. I don't see how to do it easily in
@@ -1109,7 +1118,7 @@ pub unsafe extern "C" fn cass_value_get_duration(
11091118
days: *mut cass_int32_t,
11101119
nanos: *mut cass_int64_t,
11111120
) -> CassError {
1112-
let val: &CassValue = ptr_to_ref(value);
1121+
let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
11131122

11141123
match &val.value {
11151124
Some(Value::RegularValue(CqlValue::Duration(duration))) => {
@@ -1130,11 +1139,7 @@ pub unsafe extern "C" fn cass_value_get_bytes(
11301139
output: *mut *const cass_byte_t,
11311140
output_size: *mut size_t,
11321141
) -> CassError {
1133-
if value.is_null() {
1134-
return CassError::CASS_ERROR_LIB_NULL_VALUE;
1135-
}
1136-
1137-
let value_from_raw: &CassValue = ptr_to_ref(value);
1142+
let value_from_raw: &CassValue = val_ptr_to_ref_ensure_non_null!(value);
11381143

11391144
// FIXME: This should be implemented for all CQL types
11401145
// Note: currently rust driver does not allow to get raw bytes of the CQL value.

0 commit comments

Comments
 (0)