Skip to content

Commit b27e22c

Browse files
committed
Fix union helper methods for fields of value type
1 parent 3e13e92 commit b27e22c

File tree

7 files changed

+70
-11
lines changed

7 files changed

+70
-11
lines changed

bindgen/ir/Struct.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -260,15 +260,8 @@ std::string Union::generateHelperClass() const {
260260
<< "(val p: native.Ptr[" << type << "]) extends AnyVal {\n";
261261
for (const auto &field : fields) {
262262
if (!field->getName().empty()) {
263-
std::string getter = handleReservedWords(field->getName());
264-
std::string setter = handleReservedWords(field->getName(), "_=");
265-
std::shared_ptr<const Type> ftype = field->getType();
266-
s << " def " << getter << ": native.Ptr[" << ftype->str()
267-
<< "] = p.cast[native.Ptr[" << ftype->str() << "]]\n";
268-
269-
s << " def " << setter << "(value: " << ftype->str()
270-
<< "): Unit = !p.cast[native.Ptr[" << ftype->str()
271-
<< "]] = value\n";
263+
s << generateGetter(field);
264+
s << generateSetter(field);
272265
}
273266
}
274267
s << " }\n";
@@ -288,3 +281,22 @@ bool Union::operator==(const Type &other) const {
288281
}
289282
return false;
290283
}
284+
285+
std::string Union::generateGetter(const std::shared_ptr<Field> &field) const {
286+
std::string getter = handleReservedWords(field->getName());
287+
std::string ftype = field->getType()->str();
288+
return " def " + getter + ": native.Ptr[" + ftype +
289+
"] = p.cast[native.Ptr[" + ftype + "]]\n";
290+
}
291+
292+
std::string Union::generateSetter(const std::shared_ptr<Field> &field) const {
293+
std::string setter = handleReservedWords(field->getName(), "_=");
294+
std::string ftype = field->getType()->str();
295+
if (isAliasForType<ArrayType>(field->getType().get()) ||
296+
isAliasForType<Struct>(field->getType().get())) {
297+
return " def " + setter + "(value: native.Ptr[" + ftype +
298+
"]): Unit = !p.cast[native.Ptr[" + ftype + "]] = !value\n";
299+
}
300+
return " def " + setter + "(value: " + ftype +
301+
"): Unit = !p.cast[native.Ptr[" + ftype + "]] = value\n";
302+
}

bindgen/ir/Struct.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ class Union : public StructOrUnion, public ArrayType {
114114
bool operator==(const Type &other) const override;
115115

116116
std::string getTypeAlias() const override;
117+
118+
private:
119+
std::string generateGetter(const std::shared_ptr<Field> &field) const;
120+
121+
std::string generateSetter(const std::shared_ptr<Field> &field) const;
117122
};
118123

119124
#endif // SCALA_NATIVE_BINDGEN_STRUCT_H

tests/samples/ReservedWords.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,6 @@ object ReservedWordsHelpers {
4949
def `forSome`: native.Ptr[`match`] = p.cast[native.Ptr[`match`]]
5050
def `forSome_=`(value: `match`): Unit = !p.cast[native.Ptr[`match`]] = value
5151
def `implicit`: native.Ptr[native.CArray[Byte, native.Nat.Digit[native.Nat._1, native.Nat._6]]] = p.cast[native.Ptr[native.CArray[Byte, native.Nat.Digit[native.Nat._1, native.Nat._6]]]]
52-
def `implicit_=`(value: native.CArray[Byte, native.Nat.Digit[native.Nat._1, native.Nat._6]]): Unit = !p.cast[native.Ptr[native.CArray[Byte, native.Nat.Digit[native.Nat._1, native.Nat._6]]]] = value
52+
def `implicit_=`(value: native.Ptr[native.CArray[Byte, native.Nat.Digit[native.Nat._1, native.Nat._6]]]): Unit = !p.cast[native.Ptr[native.CArray[Byte, native.Nat.Digit[native.Nat._1, native.Nat._6]]]] = !value
5353
}
5454
}

tests/samples/Union.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,14 @@ int union_test_string(union values *v, enum union_op op, const char *value) {
5252
case UNION_TEST:
5353
return v->s == value || !strcmp(v->s, value);
5454
}
55-
}
55+
}
56+
57+
int union_test_struct(union values *v, enum union_op op, struct s *value) {
58+
switch (op) {
59+
case UNION_SET:
60+
v->structInUnion.a = value->a;
61+
return 1;
62+
case UNION_TEST:
63+
return v->structInUnion.a == value->a;
64+
}
65+
}

tests/samples/Union.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
struct s {
2+
int a;
3+
};
4+
15
union values {
26
long l;
37
int i;
48
long long ll;
59
double d;
610
const char *s;
11+
struct s structInUnion;
712
};
813

914
enum union_op { UNION_SET, UNION_TEST };
@@ -15,3 +20,4 @@ int union_test_long(union values *v, enum union_op op, long value);
1520
int union_test_long_long(union values *v, enum union_op op, long long value);
1621
int union_test_double(union values *v, enum union_op op, double value);
1722
int union_test_string(union values *v, enum union_op op, const char *value);
23+
int union_test_struct(union values *v, enum union_op op, struct s *value);

tests/samples/Union.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import scala.scalanative.native._
66
@native.link("bindgentests")
77
@native.extern
88
object Union {
9+
type struct_s = native.CStruct1[native.CInt]
910
type union_values = native.CArray[Byte, native.Nat._8]
1011
type enum_union_op = native.CUnsignedInt
1112
def union_get_sizeof(): native.CInt = native.extern
@@ -14,6 +15,7 @@ object Union {
1415
def union_test_long_long(v: native.Ptr[union_values], op: enum_union_op, value: native.CLongLong): native.CInt = native.extern
1516
def union_test_double(v: native.Ptr[union_values], op: enum_union_op, value: native.CDouble): native.CInt = native.extern
1617
def union_test_string(v: native.Ptr[union_values], op: enum_union_op, value: native.CString): native.CInt = native.extern
18+
def union_test_struct(v: native.Ptr[union_values], op: enum_union_op, value: native.Ptr[struct_s]): native.CInt = native.extern
1719
}
1820

1921
import Union._
@@ -25,6 +27,13 @@ object UnionEnums {
2527

2628
object UnionHelpers {
2729

30+
implicit class struct_s_ops(val p: native.Ptr[struct_s]) extends AnyVal {
31+
def a: native.CInt = !p._1
32+
def a_=(value: native.CInt): Unit = !p._1 = value
33+
}
34+
35+
def struct_s()(implicit z: native.Zone): native.Ptr[struct_s] = native.alloc[struct_s]
36+
2837
implicit class union_values_pos(val p: native.Ptr[union_values]) extends AnyVal {
2938
def l: native.Ptr[native.CLong] = p.cast[native.Ptr[native.CLong]]
3039
def l_=(value: native.CLong): Unit = !p.cast[native.Ptr[native.CLong]] = value
@@ -36,5 +45,7 @@ object UnionHelpers {
3645
def d_=(value: native.CDouble): Unit = !p.cast[native.Ptr[native.CDouble]] = value
3746
def s: native.Ptr[native.CString] = p.cast[native.Ptr[native.CString]]
3847
def s_=(value: native.CString): Unit = !p.cast[native.Ptr[native.CString]] = value
48+
def structInUnion: native.Ptr[struct_s] = p.cast[native.Ptr[struct_s]]
49+
def structInUnion_=(value: native.Ptr[struct_s]): Unit = !p.cast[native.Ptr[struct_s]] = !value
3950
}
4051
}

tests/samples/src/test/scala/org/scalanative/bindgen/samples/UnionTests.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ object UnionTests extends TestSuite {
5656
UnionEnums.enum_union_op_UNION_SET,
5757
null)
5858
assert(!unionPtr.s == null)
59+
60+
val struct = alloc[Union.struct_s]
61+
struct.a = 10
62+
Union.union_test_struct(unionPtr,
63+
UnionEnums.enum_union_op_UNION_SET,
64+
struct)
65+
assert(unionPtr.structInUnion.a == 10)
5966
}
6067
}
6168

@@ -112,6 +119,14 @@ object UnionTests extends TestSuite {
112119
Union.union_test_string(unionPtr,
113120
UnionEnums.enum_union_op_UNION_TEST,
114121
null) == 1)
122+
123+
val struct = alloc[Union.struct_s]
124+
struct.a = 10
125+
unionPtr.structInUnion = struct
126+
assert(
127+
Union.union_test_struct(unionPtr,
128+
UnionEnums.enum_union_op_UNION_TEST,
129+
struct) == 1)
115130
}
116131
}
117132
}

0 commit comments

Comments
 (0)