Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public enum Vehicle {
case bicycle
case car(String)
case motorbike(String, horsePower: Int64)
indirect case transformer(front: Vehicle, back: Vehicle)

public init?(name: String) {
switch name {
Expand All @@ -31,24 +32,28 @@ public enum Vehicle {
case .bicycle: "bicycle"
case .car: "car"
case .motorbike: "motorbike"
case .transformer: "transformer"
}
}

public func isFasterThan(other: Vehicle) -> Bool {
switch (self, other) {
case (.bicycle, .bicycle), (.bicycle, .car), (.bicycle, .motorbike): false
case (.bicycle, .bicycle), (.bicycle, .car), (.bicycle, .motorbike), (.bicycle, .transformer): false
case (.car, .bicycle): true
case (.car, .motorbike), (.car, .car): false
case (.car, .motorbike), (.car, .transformer), (.car, .car): false
case (.motorbike, .bicycle), (.motorbike, .car): true
case (.motorbike, .motorbike): false
case (.motorbike, .motorbike), (.motorbike, .transformer): false
case (.transformer, .bicycle), (.transformer, .car), (.transformer, .motorbike): true
case (.transformer, .transformer): false
}
}

public mutating func upgrade() {
switch self {
case .bicycle: self = .car("Unknown")
case .car: self = .motorbike("Unknown", horsePower: 0)
case .motorbike: break
case .motorbike: self = .transformer(front: .car("BMW"), back: self)
case .transformer: break
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ public Vehicle.Motorbike java_copy(BenchmarkState state, Blackhole bh) {
Vehicle.Motorbike motorbike = state.vehicle.getAsMotorbike().orElseThrow();
bh.consume(motorbike.arg0());
bh.consume(motorbike.horsePower());

return motorbike;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,16 @@ void getAsMotorbike() {
}
}

@Test
void getAsTransformer() {
try (var arena = new ConfinedSwiftMemorySession()) {
Vehicle vehicle = Vehicle.transformer(Vehicle.bicycle(arena), Vehicle.car("BMW", arena), arena);
Vehicle.Transformer transformer = vehicle.getAsTransformer(arena).orElseThrow();
assertTrue(transformer.front().getAsBicycle().isPresent());
assertEquals("BMW", transformer.back().getAsCar().orElseThrow().arg0());
}
}

@Test
void associatedValuesAreCopied() {
try (var arena = new ConfinedSwiftMemorySession()) {
Expand All @@ -127,4 +137,45 @@ void associatedValuesAreCopied() {
assertEquals("BMW", car.arg0());
}
}

@Test
void getDiscriminator() {
try (var arena = new ConfinedSwiftMemorySession()) {
assertEquals(Vehicle.Discriminator.BICYCLE, Vehicle.bicycle(arena).getDiscriminator());
assertEquals(Vehicle.Discriminator.CAR, Vehicle.car("BMW", arena).getDiscriminator());
assertEquals(Vehicle.Discriminator.MOTORBIKE, Vehicle.motorbike("Yamaha", 750, arena).getDiscriminator());
assertEquals(Vehicle.Discriminator.TRANSFORMER, Vehicle.transformer(Vehicle.bicycle(arena), Vehicle.bicycle(arena), arena).getDiscriminator());
}
}

@Test
void getCase() {
try (var arena = new ConfinedSwiftMemorySession()) {
Vehicle vehicle = Vehicle.bicycle(arena);
Vehicle.Case caseElement = vehicle.getCase(arena);
assertInstanceOf(Vehicle.Bicycle.class, caseElement);
}
}

@Test
void switchGetCase() {
try (var arena = new ConfinedSwiftMemorySession()) {
Vehicle vehicle = Vehicle.car("BMW", arena);
switch (vehicle.getCase(arena)) {
case Vehicle.Bicycle b:
fail("Was bicycle");
break;
case Vehicle.Car car:
assertEquals("BMW", car.arg0());
break;
case Vehicle.Motorbike motorbike:
fail("Was motorbike");
break;
case Vehicle.Transformer transformer:
fail("Was transformer");
break;
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,35 @@ extension JNISwift2JavaGenerator {
decl.cases.map { $0.name.uppercased() }.joined(separator: ",\n")
)
}

// TODO: Consider whether all of these "utility" functions can be printed using our existing printing logic.
printer.printBraceBlock("public Discriminator getDiscriminator()") { printer in
printer.print("return Discriminator.values()[$getDiscriminator(this.$memoryAddress())];")
}
printer.print("private static native int $getDiscriminator(long self);")
}

private func printEnumCaseInterface(_ printer: inout CodePrinter, _ decl: ImportedNominalType) {
printer.print("public sealed interface Case {}")
// TODO: Print `getCase()` method to allow for easy pattern matching.
printer.println()

let requiresSwiftArena = decl.cases.compactMap {
self.translatedEnumCase(for: $0)
}.contains(where: \.requiresSwiftArena)

printer.printBraceBlock("public Case getCase(\(requiresSwiftArena ? "SwiftArena swiftArena$" : ""))") { printer in
printer.print("Discriminator discriminator = this.getDiscriminator();")
printer.printBraceBlock("switch (discriminator)") { printer in
for enumCase in decl.cases {
guard let translatedCase = self.translatedEnumCase(for: enumCase) else {
continue
}
let arenaArgument = translatedCase.requiresSwiftArena ? "swiftArena$" : ""
printer.print("case \(enumCase.name.uppercased()): return this.getAs\(enumCase.name.firstCharacterUppercased)(\(arenaArgument)).orElseThrow();")
}
}
printer.print(#"throw new RuntimeException("Unknown discriminator value " + discriminator);"#)
}
}

private func printEnumStaticInitializers(_ printer: inout CodePrinter, _ decl: ImportedNominalType) {
Expand All @@ -233,40 +257,57 @@ extension JNISwift2JavaGenerator {

let caseName = enumCase.name.firstCharacterUppercased
let hasParameters = !enumCase.parameters.isEmpty
let requiresSwiftArena = translatedCase.requiresSwiftArena

// Print record
printer.printBraceBlock("public record \(caseName)(\(members.joined(separator: ", "))) implements Case") { printer in
if hasParameters {
let nativeResults = zip(translatedCase.translatedValues, translatedCase.conversions).map { value, conversion in
var nativeResults = zip(translatedCase.translatedValues, translatedCase.conversions).map { value, conversion in
"\(conversion.native.javaType) \(value.parameter.name)"
}
printer.print("record $NativeParameters(\(nativeResults.joined(separator: ", "))) {}")
printer.println()

printer.print(#"@SuppressWarnings("unused")"#)
printer.printBraceBlock("static \(caseName) fromJNI(\(nativeResults.joined(separator: ", ")))") { printer in
let swiftArenaParameter = requiresSwiftArena ? ", SwiftArena swiftArena$" : ""
printer.printBraceBlock("\(caseName)($NativeParameters parameters\(swiftArenaParameter))") { printer in
let memberValues = zip(translatedCase.translatedValues, translatedCase.conversions).map { (value, conversion) in
let result = conversion.translated.conversion.render(&printer, value.parameter.name)
let result = conversion.translated.conversion.render(&printer, "parameters.\(value.parameter.name)")
return result
}
printer.print("return new \(caseName)(\(memberValues.joined(separator: ", ")));")
printer.print("this(\(memberValues.joined(separator: ", ")));")
}
}
}

// TODO: Optimize when all values can just be passed directly, instead of going through "middle type"?

// Print method to get enum as case
printer.printBraceBlock("public Optional<\(caseName)> getAs\(caseName)()") { printer in
// TODO: Check that discriminator is OK
printer.printBraceBlock("public Optional<\(caseName)> getAs\(caseName)(\(requiresSwiftArena ? "SwiftArena swiftArena$" : ""))") { printer in
printer.print(
"""
if (this.getDiscriminator() != Discriminator.\(caseName.uppercased())) {
return Optional.empty();
}
"""
)
if hasParameters {
var arguments = ["$getAs\(caseName)(this.$memoryAddress())"]
if requiresSwiftArena {
arguments.append("swiftArena$")
}
printer.print(
"""
return Optional.of($getAs\(caseName)(this.$memoryAddress()));
return Optional.of(new \(caseName)(\(arguments.joined(separator: ", "))));
"""
)
} else {
printer.print("return Optional.of(new \(caseName)());")
}
}
printer.print("private static native \(caseName) $getAs\(caseName)(long self);")
if hasParameters {
printer.print("private static native \(caseName).$NativeParameters $getAs\(caseName)(long self);")
}

printer.println()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,27 +88,13 @@ extension JNISwift2JavaGenerator {
parentName: parentName
)

let conversions = try enumCase.parameters.map {
let result = SwiftResult(convention: .direct, type: $0.type)
let conversions = try enumCase.parameters.enumerated().map { idx, parameter in
let result = SwiftResult(convention: .direct, type: parameter.type)
let translatedResult = try self.translate(swiftResult: result)
let nativeResult = try nativeTranslation.translate(swiftResult: result)
let nativeResult = try nativeTranslation.translate(swiftResult: result, resultName: parameter.name ?? "arg\(idx)")
return (translatedResult, nativeResult)
}

// let nativeParameters = try nativeTranslation.translateParameters(
// enumCase.parameters.map {
// SwiftParameter(
// convention: .byValue,
// argumentLabel: $0.name,
// parameterName: $0.name,
// type: $0.type
// )
// },
// translatedParameters: translatedParameters,
// methodName: methodName,
// parentName: parentName
// )

return TranslatedEnumCase(
name: enumCase.name.firstCharacterUppercased,
enumName: enumCase.enumType.nominalTypeDecl.name,
Expand Down Expand Up @@ -572,6 +558,10 @@ extension JNISwift2JavaGenerator {
let translatedValues: [TranslatedParameter]

let conversions: [(translated: TranslatedResult, native: NativeResult)]

var requiresSwiftArena: Bool {
conversions.contains(where: \.translated.conversion.requiresSwiftArena)
}
}

struct TranslatedFunctionDecl {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,8 @@ extension JNISwift2JavaGenerator {
}

func translate(
swiftResult: SwiftResult
swiftResult: SwiftResult,
resultName: String = "result"
) throws -> NativeResult {
switch swiftResult.type {
case .nominal(let nominalType):
Expand Down Expand Up @@ -416,7 +417,7 @@ extension JNISwift2JavaGenerator {

return NativeResult(
javaType: .long,
conversion: .getJNIValue(.allocateSwiftValue(name: "result", swiftType: swiftResult.type)),
conversion: .getJNIValue(.allocateSwiftValue(name: resultName, swiftType: swiftResult.type)),
outParameters: []
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,14 @@ extension JNISwift2JavaGenerator {
printer.println()
}

for enumCase in type.cases {
printEnumCase(&printer, enumCase)
if type.swiftNominal.kind == .enum {
printEnumDiscriminator(&printer, type)
printer.println()

for enumCase in type.cases {
printEnumCase(&printer, enumCase)
printer.println()
}
}

for method in type.methods {
Expand All @@ -120,6 +125,28 @@ extension JNISwift2JavaGenerator {
printDestroyFunctionThunk(&printer, type)
}

private func printEnumDiscriminator(_ printer: inout CodePrinter, _ type: ImportedNominalType) {
let selfPointerParam = JavaParameter(name: "selfPointer", type: .long)
printCDecl(
&printer,
javaMethodName: "$getDiscriminator",
parentName: type.swiftNominal.name,
parameters: [selfPointerParam],
resultType: .int
) { printer in
let selfPointer = self.printSelfJLongToUnsafeMutablePointer(
&printer,
swiftParentName: type.swiftNominal.name,
selfPointerParam
)
printer.printBraceBlock("switch (\(selfPointer).pointee)") { printer in
for (idx, enumCase) in type.cases.enumerated() {
printer.print("case .\(enumCase.name): return \(idx)")
}
}
}
}

private func printEnumCase(_ printer: inout CodePrinter, _ enumCase: ImportedEnumCase) {
guard let translatedCase = self.translatedEnumCase(for: enumCase) else {
return
Expand All @@ -132,7 +159,8 @@ extension JNISwift2JavaGenerator {
// Print getAsCase method
if !enumCase.parameters.isEmpty {
let selfParameter = JavaParameter(name: "self", type: .long)
let resultType = JavaType.class(package: javaPackage, name: "\(translatedCase.enumName).\(translatedCase.name)")

let resultType = JavaType.class(package: javaPackage, name: "\(translatedCase.enumName).\(translatedCase.name).$NativeParameters")
printCDecl(
&printer,
javaMethodName: "$getAs\(translatedCase.name)",
Expand All @@ -149,15 +177,15 @@ extension JNISwift2JavaGenerator {
parameter.name ?? "_\(idx)"
}
let caseNamesWithLet = caseNames.map { "let \($0)" }
let methodSignature = MethodSignature(resultType: resultType, parameterTypes: translatedCase.conversions.map(\.native.javaType))
let methodSignature = MethodSignature(resultType: .void, parameterTypes: translatedCase.conversions.map(\.native.javaType))
// TODO: Caching of class and static method ID.
printer.print(
"""
guard case .\(enumCase.name)(\(caseNamesWithLet.joined(separator: ", "))) = \(selfPointer).pointee else {
fatalError("Expected enum case '\(enumCase.name)', but was '\\(self$.pointee)'!")
}
let recordClass$ = environment.interface.FindClass(environment, "\(javaPackagePath)/\(translatedCase.enumName)$\(translatedCase.name)")!
let fromJNIID$ = environment.interface.GetStaticMethodID(environment, recordClass$, "fromJNI", "\(methodSignature.mangledName)")!
let class$ = environment.interface.FindClass(environment, "\(javaPackagePath)/\(translatedCase.enumName)$\(translatedCase.name)$$NativeParameters")!
let constructorID$ = environment.interface.GetMethodID(environment, class$, "<init>", "\(methodSignature.mangledName)")!
"""
)

Expand All @@ -170,7 +198,7 @@ extension JNISwift2JavaGenerator {
printer.print(
"""
return withVaList([\(upcallArguments.joined(separator: ", "))]) {
return environment.interface.CallStaticObjectMethodV(environment, recordClass$, fromJNIID$, $0)
return environment.interface.NewObjectV(environment, class$, constructorID$, $0)
}
"""
)
Expand Down
Loading