diff --git a/Samples/JExtractJNISampleApp/Sources/MySwiftLibrary/Alignment.swift b/Samples/JExtractJNISampleApp/Sources/MySwiftLibrary/Alignment.swift new file mode 100644 index 00000000..760c564b --- /dev/null +++ b/Samples/JExtractJNISampleApp/Sources/MySwiftLibrary/Alignment.swift @@ -0,0 +1,18 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift.org project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift.org project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +public enum Alignment: String { + case horizontal + case vertical +} diff --git a/Samples/JExtractJNISampleApp/Sources/MySwiftLibrary/Vehicle.swift b/Samples/JExtractJNISampleApp/Sources/MySwiftLibrary/Vehicle.swift new file mode 100644 index 00000000..9d9155ad --- /dev/null +++ b/Samples/JExtractJNISampleApp/Sources/MySwiftLibrary/Vehicle.swift @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift.org project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift.org project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +public enum Vehicle { + case bicycle + case car(String, trailer: String?) + case motorbike(String, horsePower: Int64, helmets: Int32?) + indirect case transformer(front: Vehicle, back: Vehicle) + case boat(passengers: Int32?, length: Int16?) + + public init?(name: String) { + switch name { + case "bicycle": self = .bicycle + case "car": self = .car("Unknown", trailer: nil) + case "motorbike": self = .motorbike("Unknown", horsePower: 0, helmets: nil) + case "boat": self = .boat(passengers: nil, length: nil) + default: return nil + } + } + + public var name: String { + switch self { + case .bicycle: "bicycle" + case .car: "car" + case .motorbike: "motorbike" + case .transformer: "transformer" + case .boat: "boat" + } + } + + public func isFasterThan(other: Vehicle) -> Bool { + switch (self, other) { + case (.bicycle, .bicycle), (.bicycle, .car), (.bicycle, .motorbike), (.bicycle, .transformer): false + case (.car, .bicycle): true + case (.car, .motorbike), (.car, .transformer), (.car, .car): false + case (.motorbike, .bicycle), (.motorbike, .car): true + case (.motorbike, .motorbike), (.motorbike, .transformer): false + case (.transformer, .bicycle), (.transformer, .car), (.transformer, .motorbike): true + case (.transformer, .transformer): false + default: false + } + } + + public mutating func upgrade() { + switch self { + case .bicycle: self = .car("Unknown", trailer: nil) + case .car: self = .motorbike("Unknown", horsePower: 0, helmets: nil) + case .motorbike: self = .transformer(front: .car("BMW", trailer: nil), back: self) + case .transformer, .boat: break + } + } +} diff --git a/Samples/JExtractJNISampleApp/Sources/MySwiftLibrary/swift-java.config b/Samples/JExtractJNISampleApp/Sources/MySwiftLibrary/swift-java.config index bb637f34..be44c2fd 100644 --- a/Samples/JExtractJNISampleApp/Sources/MySwiftLibrary/swift-java.config +++ b/Samples/JExtractJNISampleApp/Sources/MySwiftLibrary/swift-java.config @@ -1,4 +1,5 @@ { "javaPackage": "com.example.swift", "mode": "jni", + "logLevel": ["debug"] } diff --git a/Samples/JExtractJNISampleApp/src/jmh/java/com/example/swift/EnumBenchmark.java b/Samples/JExtractJNISampleApp/src/jmh/java/com/example/swift/EnumBenchmark.java new file mode 100644 index 00000000..d3de624b --- /dev/null +++ b/Samples/JExtractJNISampleApp/src/jmh/java/com/example/swift/EnumBenchmark.java @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift.org project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift.org project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +package com.example.swift; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; +import org.swift.swiftkit.core.ClosableSwiftArena; +import org.swift.swiftkit.core.ConfinedSwiftMemorySession; +import org.swift.swiftkit.core.SwiftArena; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.AverageTime) +@Warmup(iterations = 5, time = 200, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@Fork(value = 3, jvmArgsAppend = { "--enable-native-access=ALL-UNNAMED" }) +public class EnumBenchmark { + + @State(Scope.Benchmark) + public static class BenchmarkState { + ClosableSwiftArena arena; + Vehicle vehicle; + + @Setup(Level.Trial) + public void beforeAll() { + arena = SwiftArena.ofConfined(); + vehicle = Vehicle.motorbike("Yamaha", 900, OptionalInt.empty(), arena); + } + + @TearDown(Level.Trial) + public void afterAll() { + arena.close(); + } + } + + @Benchmark + public Vehicle.Motorbike getAssociatedValues(BenchmarkState state, Blackhole bh) { + Vehicle.Motorbike motorbike = state.vehicle.getAsMotorbike().orElseThrow(); + bh.consume(motorbike.arg0()); + bh.consume(motorbike.horsePower()); + return motorbike; + } +} \ No newline at end of file diff --git a/Samples/JExtractJNISampleApp/src/test/java/com/example/swift/AlignmentEnumTest.java b/Samples/JExtractJNISampleApp/src/test/java/com/example/swift/AlignmentEnumTest.java new file mode 100644 index 00000000..6be85c75 --- /dev/null +++ b/Samples/JExtractJNISampleApp/src/test/java/com/example/swift/AlignmentEnumTest.java @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2024 Apple Inc. and the Swift.org project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift.org project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +package com.example.swift; + +import org.junit.jupiter.api.Test; +import org.swift.swiftkit.core.ConfinedSwiftMemorySession; +import org.swift.swiftkit.core.SwiftArena; + +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.*; + +public class AlignmentEnumTest { + @Test + void rawValue() { + try (var arena = SwiftArena.ofConfined()) { + Optional invalid = Alignment.init("invalid", arena); + assertFalse(invalid.isPresent()); + + Optional horizontal = Alignment.init("horizontal", arena); + assertTrue(horizontal.isPresent()); + assertEquals("horizontal", horizontal.get().getRawValue()); + + Optional vertical = Alignment.init("vertical", arena); + assertTrue(vertical.isPresent()); + assertEquals("vertical", vertical.get().getRawValue()); + } + } +} \ No newline at end of file diff --git a/Samples/JExtractJNISampleApp/src/test/java/com/example/swift/VehicleEnumTest.java b/Samples/JExtractJNISampleApp/src/test/java/com/example/swift/VehicleEnumTest.java new file mode 100644 index 00000000..c533603a --- /dev/null +++ b/Samples/JExtractJNISampleApp/src/test/java/com/example/swift/VehicleEnumTest.java @@ -0,0 +1,206 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2024 Apple Inc. and the Swift.org project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift.org project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +package com.example.swift; + +import org.junit.jupiter.api.Test; +import org.swift.swiftkit.core.ConfinedSwiftMemorySession; +import org.swift.swiftkit.core.SwiftArena; + +import java.lang.foreign.Arena; +import java.util.Optional; +import java.util.OptionalInt; + +import static org.junit.jupiter.api.Assertions.*; + +public class VehicleEnumTest { + @Test + void bicycle() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.bicycle(arena); + assertNotNull(vehicle); + } + } + + @Test + void car() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.car("Porsche 911", Optional.empty(), arena); + assertNotNull(vehicle); + } + } + + @Test + void motorbike() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.motorbike("Yamaha", 750, OptionalInt.empty(), arena); + assertNotNull(vehicle); + } + } + + @Test + void initName() { + try (var arena = SwiftArena.ofConfined()) { + assertFalse(Vehicle.init("bus", arena).isPresent()); + Optional vehicle = Vehicle.init("car", arena); + assertTrue(vehicle.isPresent()); + assertNotNull(vehicle.get()); + } + } + + @Test + void nameProperty() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.bicycle(arena); + assertEquals("bicycle", vehicle.getName()); + } + } + + @Test + void isFasterThan() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle bicycle = Vehicle.bicycle(arena); + Vehicle car = Vehicle.car("Porsche 911", Optional.empty(), arena); + assertFalse(bicycle.isFasterThan(car)); + assertTrue(car.isFasterThan(bicycle)); + } + } + + @Test + void upgrade() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.bicycle(arena); + assertEquals("bicycle", vehicle.getName()); + vehicle.upgrade(); + assertEquals("car", vehicle.getName()); + vehicle.upgrade(); + assertEquals("motorbike", vehicle.getName()); + } + } + + @Test + void getAsBicycle() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.bicycle(arena); + Vehicle.Bicycle bicycle = vehicle.getAsBicycle().orElseThrow(); + assertNotNull(bicycle); + } + } + + @Test + void getAsCar() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.car("BMW", Optional.empty(), arena); + Vehicle.Car car = vehicle.getAsCar().orElseThrow(); + assertEquals("BMW", car.arg0()); + + vehicle = Vehicle.car("BMW", Optional.of("Long trailer"), arena); + car = vehicle.getAsCar().orElseThrow(); + assertEquals("Long trailer", car.trailer().orElseThrow()); + } + } + + @Test + void getAsMotorbike() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.motorbike("Yamaha", 750, OptionalInt.empty(), arena); + Vehicle.Motorbike motorbike = vehicle.getAsMotorbike().orElseThrow(); + assertEquals("Yamaha", motorbike.arg0()); + assertEquals(750, motorbike.horsePower()); + assertEquals(OptionalInt.empty(), motorbike.helmets()); + + vehicle = Vehicle.motorbike("Yamaha", 750, OptionalInt.of(2), arena); + motorbike = vehicle.getAsMotorbike().orElseThrow(); + assertEquals(OptionalInt.of(2), motorbike.helmets()); + } + } + + @Test + void getAsTransformer() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.transformer(Vehicle.bicycle(arena), Vehicle.car("BMW", Optional.empty(), arena), arena); + Vehicle.Transformer transformer = vehicle.getAsTransformer(arena).orElseThrow(); + assertTrue(transformer.front().getAsBicycle().isPresent()); + assertEquals("BMW", transformer.back().getAsCar().orElseThrow().arg0()); + } + } + + @Test + void getAsBoat() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.boat(OptionalInt.of(10), Optional.of((short) 1), arena); + Vehicle.Boat boat = vehicle.getAsBoat().orElseThrow(); + assertEquals(OptionalInt.of(10), boat.passengers()); + assertEquals(Optional.of((short) 1), boat.length()); + } + } + + @Test + void associatedValuesAreCopied() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.car("BMW", Optional.empty(), arena); + Vehicle.Car car = vehicle.getAsCar().orElseThrow(); + assertEquals("BMW", car.arg0()); + vehicle.upgrade(); + Vehicle.Motorbike motorbike = vehicle.getAsMotorbike().orElseThrow(); + assertNotNull(motorbike); + // Motorbike should still remain + assertEquals("BMW", car.arg0()); + } + } + + @Test + void getDiscriminator() { + try (var arena = SwiftArena.ofConfined()) { + assertEquals(Vehicle.Discriminator.BICYCLE, Vehicle.bicycle(arena).getDiscriminator()); + assertEquals(Vehicle.Discriminator.CAR, Vehicle.car("BMW", Optional.empty(), arena).getDiscriminator()); + assertEquals(Vehicle.Discriminator.MOTORBIKE, Vehicle.motorbike("Yamaha", 750, OptionalInt.empty(), arena).getDiscriminator()); + assertEquals(Vehicle.Discriminator.TRANSFORMER, Vehicle.transformer(Vehicle.bicycle(arena), Vehicle.bicycle(arena), arena).getDiscriminator()); + } + } + + @Test + void getCase() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.bicycle(arena); + Vehicle.Case caseElement = vehicle.getCase(arena); + assertInstanceOf(Vehicle.Bicycle.class, caseElement); + } + } + + @Test + void switchGetCase() { + try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.car("BMW", Optional.empty(), arena); + switch (vehicle.getCase(arena)) { + case Vehicle.Bicycle b: + fail("Was bicycle"); + break; + case Vehicle.Car car: + assertEquals("BMW", car.arg0()); + break; + case Vehicle.Motorbike motorbike: + fail("Was motorbike"); + break; + case Vehicle.Transformer transformer: + fail("Was transformer"); + break; + case Vehicle.Boat b: + fail("Was boat"); + break; + } + } + } + +} \ No newline at end of file diff --git a/Sources/JExtractSwiftLib/Convenience/String+Extensions.swift b/Sources/JExtractSwiftLib/Convenience/String+Extensions.swift index 0f5cfeac..25c46366 100644 --- a/Sources/JExtractSwiftLib/Convenience/String+Extensions.swift +++ b/Sources/JExtractSwiftLib/Convenience/String+Extensions.swift @@ -24,6 +24,14 @@ extension String { return "\(f.uppercased())\(String(dropFirst()))" } + var firstCharacterLowercased: String { + guard let f = first else { + return self + } + + return "\(f.lowercased())\(String(dropFirst()))" + } + /// Returns whether the string is of the format `isX` var hasJavaBooleanNamingConvention: Bool { guard self.hasPrefix("is"), self.count > 2 else { diff --git a/Sources/JExtractSwiftLib/Convenience/SwiftSyntax+Extensions.swift b/Sources/JExtractSwiftLib/Convenience/SwiftSyntax+Extensions.swift index e71300af..3719d99d 100644 --- a/Sources/JExtractSwiftLib/Convenience/SwiftSyntax+Extensions.swift +++ b/Sources/JExtractSwiftLib/Convenience/SwiftSyntax+Extensions.swift @@ -253,6 +253,8 @@ extension DeclSyntaxProtocol { } )) .triviaSanitizedDescription + case .enumCaseDecl(let node): + node.triviaSanitizedDescription default: fatalError("unimplemented \(self.kind)") } diff --git a/Sources/JExtractSwiftLib/FFM/CDeclLowering/FFMSwift2JavaGenerator+FunctionLowering.swift b/Sources/JExtractSwiftLib/FFM/CDeclLowering/FFMSwift2JavaGenerator+FunctionLowering.swift index 085a6331..0a61708e 100644 --- a/Sources/JExtractSwiftLib/FFM/CDeclLowering/FFMSwift2JavaGenerator+FunctionLowering.swift +++ b/Sources/JExtractSwiftLib/FFM/CDeclLowering/FFMSwift2JavaGenerator+FunctionLowering.swift @@ -835,6 +835,10 @@ extension LoweredFunctionSignature { case .setter: assert(paramExprs.count == 1) resultExpr = "\(callee) = \(paramExprs[0])" + + case .enumCase: + // This should not be called, but let's fatalError. + fatalError("Enum cases are not supported with FFM.") } // Lower the result. diff --git a/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaTranslation.swift b/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaTranslation.swift index 06a4d171..57f947a8 100644 --- a/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaTranslation.swift +++ b/Sources/JExtractSwiftLib/FFM/FFMSwift2JavaGenerator+JavaTranslation.swift @@ -146,7 +146,7 @@ extension FFMSwift2JavaGenerator { let javaName = switch decl.apiKind { case .getter: decl.javaGetterName case .setter: decl.javaSetterName - case .function, .initializer: decl.name + case .function, .initializer, .enumCase: decl.name } // Signature. diff --git a/Sources/JExtractSwiftLib/ImportedDecls.swift b/Sources/JExtractSwiftLib/ImportedDecls.swift index 4c4d9c0f..5655af00 100644 --- a/Sources/JExtractSwiftLib/ImportedDecls.swift +++ b/Sources/JExtractSwiftLib/ImportedDecls.swift @@ -22,6 +22,7 @@ package enum SwiftAPIKind { case initializer case getter case setter + case enumCase } /// Describes a Swift nominal type (e.g., a class, struct, enum) that has been @@ -32,6 +33,7 @@ package class ImportedNominalType: ImportedDecl { package var initializers: [ImportedFunc] = [] package var methods: [ImportedFunc] = [] package var variables: [ImportedFunc] = [] + package var cases: [ImportedEnumCase] = [] init(swiftNominal: SwiftNominalTypeDeclaration) { self.swiftNominal = swiftNominal @@ -46,6 +48,56 @@ package class ImportedNominalType: ImportedDecl { } } +public final class ImportedEnumCase: ImportedDecl, CustomStringConvertible { + /// The case name + public var name: String + + /// The enum parameters + var parameters: [SwiftEnumCaseParameter] + + var swiftDecl: any DeclSyntaxProtocol + + var enumType: SwiftNominalType + + /// A function that represents the Swift static "initializer" for cases + var caseFunction: ImportedFunc + + init( + name: String, + parameters: [SwiftEnumCaseParameter], + swiftDecl: any DeclSyntaxProtocol, + enumType: SwiftNominalType, + caseFunction: ImportedFunc + ) { + self.name = name + self.parameters = parameters + self.swiftDecl = swiftDecl + self.enumType = enumType + self.caseFunction = caseFunction + } + + public var description: String { + """ + ImportedEnumCase { + name: \(name), + parameters: \(parameters), + swiftDecl: \(swiftDecl), + enumType: \(enumType), + caseFunction: \(caseFunction) + } + """ + } +} + +extension ImportedEnumCase: Hashable { + public func hash(into hasher: inout Hasher) { + hasher.combine(ObjectIdentifier(self)) + } + public static func == (lhs: ImportedEnumCase, rhs: ImportedEnumCase) -> Bool { + return lhs === rhs + } +} + public final class ImportedFunc: ImportedDecl, CustomStringConvertible { /// Swift module name (e.g. the target name where a type or function was declared) public var module: String @@ -113,6 +165,7 @@ public final class ImportedFunc: ImportedDecl, CustomStringConvertible { let prefix = switch self.apiKind { case .getter: "getter:" case .setter: "setter:" + case .enumCase: "case:" case .function, .initializer: "" } diff --git a/Sources/JExtractSwiftLib/JNI/JNICaching.swift b/Sources/JExtractSwiftLib/JNI/JNICaching.swift new file mode 100644 index 00000000..cdb13e3f --- /dev/null +++ b/Sources/JExtractSwiftLib/JNI/JNICaching.swift @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift.org project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift.org project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +enum JNICaching { + static func cacheName(for type: ImportedNominalType) -> String { + cacheName(for: type.swiftNominal.name) + } + + static func cacheName(for type: SwiftNominalType) -> String { + cacheName(for: type.nominalTypeDecl.name) + } + + private static func cacheName(for name: String) -> String { + "_JNI_\(name)" + } + + static func cacheMemberName(for enumCase: ImportedEnumCase) -> String { + "\(enumCase.enumType.nominalTypeDecl.name.firstCharacterLowercased)\(enumCase.name.firstCharacterUppercased)Cache" + } +} diff --git a/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+JavaBindingsPrinting.swift b/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+JavaBindingsPrinting.swift index 6d82175f..9351252e 100644 --- a/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+JavaBindingsPrinting.swift +++ b/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+JavaBindingsPrinting.swift @@ -138,6 +138,11 @@ extension JNISwift2JavaGenerator { printer.println() + if decl.swiftNominal.kind == .enum { + printEnumHelpers(&printer, decl) + printer.println() + } + for initializer in decl.initializers { printFunctionDowncallMethods(&printer, initializer) printer.println() @@ -200,6 +205,85 @@ extension JNISwift2JavaGenerator { } } + private func printEnumHelpers(_ printer: inout CodePrinter, _ decl: ImportedNominalType) { + printEnumDiscriminator(&printer, decl) + printer.println() + printEnumCaseInterface(&printer, decl) + printer.println() + printEnumStaticInitializers(&printer, decl) + printer.println() + printEnumCases(&printer, decl) + } + + private func printEnumDiscriminator(_ printer: inout CodePrinter, _ decl: ImportedNominalType) { + printer.printBraceBlock("public enum Discriminator") { printer in + printer.print( + decl.cases.map { $0.name.uppercased() }.joined(separator: ",\n") + ) + } + + // TODO: Consider whether all of these "utility" functions can be printed using our existing printing logic. + printer.printBraceBlock("public Discriminator getDiscriminator()") { printer in + printer.print("return Discriminator.values()[$getDiscriminator(this.$memoryAddress())];") + } + printer.print("private static native int $getDiscriminator(long self);") + } + + private func printEnumCaseInterface(_ printer: inout CodePrinter, _ decl: ImportedNominalType) { + printer.print("public sealed interface Case {}") + printer.println() + + let requiresSwiftArena = decl.cases.compactMap { + self.translatedEnumCase(for: $0) + }.contains(where: \.requiresSwiftArena) + + printer.printBraceBlock("public Case getCase(\(requiresSwiftArena ? "SwiftArena swiftArena$" : ""))") { printer in + printer.print("Discriminator discriminator = this.getDiscriminator();") + printer.printBraceBlock("switch (discriminator)") { printer in + for enumCase in decl.cases { + guard let translatedCase = self.translatedEnumCase(for: enumCase) else { + continue + } + let arenaArgument = translatedCase.requiresSwiftArena ? "swiftArena$" : "" + printer.print("case \(enumCase.name.uppercased()): return this.getAs\(enumCase.name.firstCharacterUppercased)(\(arenaArgument)).orElseThrow();") + } + } + printer.print(#"throw new RuntimeException("Unknown discriminator value " + discriminator);"#) + } + } + + private func printEnumStaticInitializers(_ printer: inout CodePrinter, _ decl: ImportedNominalType) { + for enumCase in decl.cases { + printFunctionDowncallMethods(&printer, enumCase.caseFunction) + } + } + + private func printEnumCases(_ printer: inout CodePrinter, _ decl: ImportedNominalType) { + for enumCase in decl.cases { + guard let translatedCase = self.translatedEnumCase(for: enumCase) else { + return + } + + let members = translatedCase.translatedValues.map { + $0.parameter.renderParameter() + } + + let caseName = enumCase.name.firstCharacterUppercased + + // Print record + printer.printBraceBlock("public record \(caseName)(\(members.joined(separator: ", "))) implements Case") { printer in + let nativeParameters = zip(translatedCase.translatedValues, translatedCase.parameterConversions).flatMap { value, conversion in + ["\(conversion.native.javaType) \(value.parameter.name)"] + } + + printer.print("record $NativeParameters(\(nativeParameters.joined(separator: ", "))) {}") + } + + self.printJavaBindingWrapperMethod(&printer, translatedCase.getAsCaseFunction) + printer.println() + } + } + private func printFunctionDowncallMethods( _ printer: inout CodePrinter, _ decl: ImportedFunc @@ -260,17 +344,23 @@ extension JNISwift2JavaGenerator { guard let translatedDecl = translatedDecl(for: decl) else { fatalError("Decl was not translated, \(decl)") } - let translatedSignature = translatedDecl.translatedFunctionSignature + printJavaBindingWrapperMethod(&printer, translatedDecl, importedFunc: decl) + } + private func printJavaBindingWrapperMethod( + _ printer: inout CodePrinter, + _ translatedDecl: TranslatedFunctionDecl, + importedFunc: ImportedFunc? = nil + ) { var modifiers = ["public"] - - if decl.isStatic || decl.isInitializer || !decl.hasParent { + if translatedDecl.isStatic { modifiers.append("static") } + let translatedSignature = translatedDecl.translatedFunctionSignature let resultType = translatedSignature.resultType.javaType var parameters = translatedDecl.translatedFunctionSignature.parameters.map { $0.parameter.renderParameter() } - let throwsClause = decl.isThrowing ? " throws Exception" : "" + let throwsClause = translatedDecl.isThrowing ? " throws Exception" : "" var annotationsStr = translatedSignature.annotations.map({ $0.render() }).joined(separator: "\n") if !annotationsStr.isEmpty { annotationsStr += "\n" } @@ -279,7 +369,9 @@ extension JNISwift2JavaGenerator { // Print default global arena variation if config.effectiveMemoryManagementMode.requiresGlobalArena && translatedSignature.requiresSwiftArena { - printDeclDocumentation(&printer, decl) + if let importedFunc { + printDeclDocumentation(&printer, importedFunc) + } printer.printBraceBlock( "\(annotationsStr)\(modifiers.joined(separator: " ")) \(resultType) \(translatedDecl.name)(\(parametersStr))\(throwsClause)" ) { printer in @@ -298,18 +390,19 @@ extension JNISwift2JavaGenerator { if translatedSignature.requiresSwiftArena { parameters.append("SwiftArena swiftArena$") } - printDeclDocumentation(&printer, decl) + if let importedFunc { + printDeclDocumentation(&printer, importedFunc) + } printer.printBraceBlock( "\(annotationsStr)\(modifiers.joined(separator: " ")) \(resultType) \(translatedDecl.name)(\(parameters.joined(separator: ", ")))\(throwsClause)" ) { printer in - printDowncall(&printer, decl) + printDowncall(&printer, translatedDecl) } - printNativeFunction(&printer, decl) + printNativeFunction(&printer, translatedDecl) } - private func printNativeFunction(_ printer: inout CodePrinter, _ decl: ImportedFunc) { - let translatedDecl = translatedDecl(for: decl)! // Will always call with valid decl + private func printNativeFunction(_ printer: inout CodePrinter, _ translatedDecl: TranslatedFunctionDecl) { let nativeSignature = translatedDecl.nativeFunctionSignature let resultType = nativeSignature.result.javaType var parameters = nativeSignature.parameters.flatMap(\.parameters) @@ -327,9 +420,8 @@ extension JNISwift2JavaGenerator { private func printDowncall( _ printer: inout CodePrinter, - _ decl: ImportedFunc + _ translatedDecl: TranslatedFunctionDecl ) { - let translatedDecl = translatedDecl(for: decl)! // We will only call this method if we can translate the decl. let translatedFunctionSignature = translatedDecl.translatedFunctionSignature // Regular parameters. diff --git a/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+JavaTranslation.swift b/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+JavaTranslation.swift index 64ae23b7..0e169ec6 100644 --- a/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+JavaTranslation.swift +++ b/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+JavaTranslation.swift @@ -41,12 +41,121 @@ extension JNISwift2JavaGenerator { return translated } + func translatedEnumCase( + for decl: ImportedEnumCase + ) -> TranslatedEnumCase? { + if let cached = translatedEnumCases[decl] { + return cached + } + + let translated: TranslatedEnumCase? + do { + let translation = JavaTranslation( + config: config, + swiftModuleName: swiftModuleName, + javaPackage: self.javaPackage, + javaClassLookupTable: self.javaClassLookupTable + ) + translated = try translation.translate(enumCase: decl) + } catch { + self.logger.debug("Failed to translate: '\(decl.swiftDecl.qualifiedNameForDebug)'; \(error)") + translated = nil + } + + translatedEnumCases[decl] = translated + return translated + } + struct JavaTranslation { let config: Configuration let swiftModuleName: String let javaPackage: String let javaClassLookupTable: JavaClassLookupTable + func translate(enumCase: ImportedEnumCase) throws -> TranslatedEnumCase { + let nativeTranslation = NativeJavaTranslation( + config: self.config, + javaPackage: self.javaPackage, + javaClassLookupTable: self.javaClassLookupTable + ) + + let methodName = "" // TODO: Used for closures, replace with better name? + let parentName = "" // TODO: Used for closures, replace with better name? + + let translatedValues = try self.translateParameters( + enumCase.parameters.map { ($0.name, $0.type) }, + methodName: methodName, + parentName: parentName + ) + + let conversions = try enumCase.parameters.enumerated().map { idx, parameter in + let resultName = parameter.name ?? "arg\(idx)" + let result = SwiftResult(convention: .direct, type: parameter.type) + var translatedResult = try self.translate(swiftResult: result, resultName: resultName) + translatedResult.conversion = .replacingPlaceholder(translatedResult.conversion, placeholder: "$nativeParameters.\(resultName)") + let nativeResult = try nativeTranslation.translate(swiftResult: result, resultName: resultName) + return (translated: translatedResult, native: nativeResult) + } + + let caseName = enumCase.name.firstCharacterUppercased + let enumName = enumCase.enumType.nominalTypeDecl.name + let nativeParametersType = JavaType.class(package: nil, name: "\(caseName).$NativeParameters") + let getAsCaseName = "getAs\(caseName)" + // If the case has no parameters, we can skip the native call. + let constructRecordConversion = JavaNativeConversionStep.method(.constant("Optional"), function: "of", arguments: [ + .constructJavaClass( + .commaSeparated(conversions.map(\.translated.conversion)), + .class(package: nil,name: caseName) + ) + ]) + let getAsCaseFunction = TranslatedFunctionDecl( + name: getAsCaseName, + isStatic: false, + isThrowing: false, + nativeFunctionName: "$\(getAsCaseName)", + parentName: enumName, + functionTypes: [], + translatedFunctionSignature: TranslatedFunctionSignature( + selfParameter: TranslatedParameter( + parameter: JavaParameter(name: "self", type: .long), + conversion: .aggregate( + [ + .ifStatement(.constant("getDiscriminator() != Discriminator.\(caseName.uppercased())"), thenExp: .constant("return Optional.empty();")), + .valueMemoryAddress(.placeholder) + ] + ) + ), + parameters: [], + resultType: TranslatedResult( + javaType: .class(package: nil, name: "Optional<\(caseName)>"), + outParameters: conversions.flatMap(\.translated.outParameters), + conversion: enumCase.parameters.isEmpty ? constructRecordConversion : .aggregate(variable: ("$nativeParameters", nativeParametersType), [constructRecordConversion]) + ) + ), + nativeFunctionSignature: NativeFunctionSignature( + selfParameter: NativeParameter( + parameters: [JavaParameter(name: "self", type: .long)], + conversion: .extractSwiftValue(.placeholder, swiftType: .nominal(enumCase.enumType), allowNil: false) + ), + parameters: [], + result: NativeResult( + javaType: nativeParametersType, + conversion: .placeholder, + outParameters: conversions.flatMap(\.native.outParameters) + ) + ) + ) + + return TranslatedEnumCase( + name: enumCase.name.firstCharacterUppercased, + enumName: enumCase.enumType.nominalTypeDecl.name, + original: enumCase, + translatedValues: translatedValues, + parameterConversions: conversions, + getAsCaseFunction: getAsCaseFunction + ) + } + func translate(_ decl: ImportedFunc) throws -> TranslatedFunctionDecl { let nativeTranslation = NativeJavaTranslation( config: self.config, @@ -61,7 +170,7 @@ extension JNISwift2JavaGenerator { let javaName = switch decl.apiKind { case .getter: decl.javaGetterName case .setter: decl.javaSetterName - case .function, .initializer: decl.name + case .function, .initializer, .enumCase: decl.name } // Swift -> Java @@ -98,6 +207,8 @@ extension JNISwift2JavaGenerator { return TranslatedFunctionDecl( name: javaName, + isStatic: decl.isStatic || !decl.hasParent || decl.isInitializer, + isThrowing: decl.isThrowing, nativeFunctionName: "$\(javaName)", parentName: parentName, functionTypes: funcTypes, @@ -136,31 +247,47 @@ extension JNISwift2JavaGenerator { methodName: String, parentName: String ) throws -> TranslatedFunctionSignature { - let parameters = try functionSignature.parameters.enumerated().map { idx, param in - let parameterName = param.parameterName ?? "arg\(idx))" + let parameters = try translateParameters( + functionSignature.parameters.map { ($0.parameterName, $0.type )}, + methodName: methodName, + parentName: parentName + ) + + // 'self' + let selfParameter = try self.translateSelfParameter(functionSignature.selfParameter, methodName: methodName, parentName: parentName) + + let resultType = try translate(swiftResult: functionSignature.result) + + return TranslatedFunctionSignature( + selfParameter: selfParameter, + parameters: parameters, + resultType: resultType + ) + } + + func translateParameters( + _ parameters: [(name: String?, type: SwiftType)], + methodName: String, + parentName: String + ) throws -> [TranslatedParameter] { + try parameters.enumerated().map { idx, param in + let parameterName = param.name ?? "arg\(idx)" return try translateParameter(swiftType: param.type, parameterName: parameterName, methodName: methodName, parentName: parentName) } + } + func translateSelfParameter(_ selfParameter: SwiftSelfParameter?, methodName: String, parentName: String) throws -> TranslatedParameter? { // 'self' - let selfParameter: TranslatedParameter? - if case .instance(let swiftSelf) = functionSignature.selfParameter { - selfParameter = try self.translateParameter( + if case .instance(let swiftSelf) = selfParameter { + return try self.translateParameter( swiftType: swiftSelf.type, parameterName: swiftSelf.parameterName ?? "self", methodName: methodName, parentName: parentName ) } else { - selfParameter = nil + return nil } - - let resultType = try translate(swiftResult: functionSignature.result) - - return TranslatedFunctionSignature( - selfParameter: selfParameter, - parameters: parameters, - resultType: resultType - ) } func translateParameter( @@ -343,7 +470,7 @@ extension JNISwift2JavaGenerator { } } - func translate(swiftResult: SwiftResult) throws -> TranslatedResult { + func translate(swiftResult: SwiftResult, resultName: String = "result") throws -> TranslatedResult { let swiftType = swiftResult.type // If the result type should cause any annotations on the method, include them here. @@ -357,7 +484,7 @@ extension JNISwift2JavaGenerator { guard let genericArgs = nominalType.genericArguments, genericArgs.count == 1 else { throw JavaTranslationError.unsupportedSwiftType(swiftType) } - return try translateOptionalResult(wrappedType: genericArgs[0]) + return try translateOptionalResult(wrappedType: genericArgs[0], resultName: resultName) default: guard let javaType = JNIJavaTypeTranslator.translate(knownType: knownType, config: self.config) else { @@ -390,7 +517,7 @@ extension JNISwift2JavaGenerator { return TranslatedResult(javaType: .void, outParameters: [], conversion: .placeholder) case .optional(let wrapped): - return try translateOptionalResult(wrappedType: wrapped) + return try translateOptionalResult(wrappedType: wrapped, resultName: resultName) case .metatype, .tuple, .function, .existential, .opaque, .genericParameter: throw JavaTranslationError.unsupportedSwiftType(swiftType) @@ -398,8 +525,11 @@ extension JNISwift2JavaGenerator { } func translateOptionalResult( - wrappedType swiftType: SwiftType + wrappedType swiftType: SwiftType, + resultName: String = "result" ) throws -> TranslatedResult { + let discriminatorName = "\(resultName)$_discriminator$" + let parameterAnnotations: [JavaAnnotation] = getTypeAnnotations(swiftType: swiftType, config: config) switch swiftType { @@ -425,6 +555,7 @@ extension JNISwift2JavaGenerator { conversion: .combinedValueToOptional( .placeholder, nextIntergralTypeWithSpaceForByte.javaType, + resultName: resultName, valueType: javaType, valueSizeInBytes: nextIntergralTypeWithSpaceForByte.valueBytes, optionalType: optionalClass @@ -437,13 +568,14 @@ extension JNISwift2JavaGenerator { javaType: .class(package: nil, name: returnType), annotations: parameterAnnotations, outParameters: [ - OutParameter(name: "result_discriminator$", type: .array(.byte), allocation: .newArray(.byte, size: 1)) + OutParameter(name: discriminatorName, type: .array(.byte), allocation: .newArray(.byte, size: 1)) ], conversion: .toOptionalFromIndirectReturn( - discriminatorName: "result_discriminator$", + discriminatorName: .combinedName(component: "discriminator$"), optionalClass: optionalClass, javaType: javaType, - toValue: .placeholder + toValue: .placeholder, + resultName: resultName ) ) } @@ -459,13 +591,14 @@ extension JNISwift2JavaGenerator { javaType: returnType, annotations: parameterAnnotations, outParameters: [ - OutParameter(name: "result_discriminator$", type: .array(.byte), allocation: .newArray(.byte, size: 1)) + OutParameter(name: discriminatorName, type: .array(.byte), allocation: .newArray(.byte, size: 1)) ], conversion: .toOptionalFromIndirectReturn( - discriminatorName: "result_discriminator$", + discriminatorName: .combinedName(component: "discriminator$"), optionalClass: "Optional", javaType: .long, - toValue: .wrapMemoryAddressUnsafe(.placeholder, .class(package: nil, name: nominalTypeName)) + toValue: .wrapMemoryAddressUnsafe(.placeholder, .class(package: nil, name: nominalTypeName)), + resultName: resultName ) ) @@ -475,10 +608,38 @@ extension JNISwift2JavaGenerator { } } + struct TranslatedEnumCase { + /// The corresponding Java case class (CamelCased) + let name: String + + /// The name of the translated enum + let enumName: String + + /// The oringinal enum case. + let original: ImportedEnumCase + + /// A list of the translated associated values + let translatedValues: [TranslatedParameter] + + /// A list of parameter conversions + let parameterConversions: [(translated: TranslatedResult, native: NativeResult)] + + let getAsCaseFunction: TranslatedFunctionDecl + + /// Returns whether the parameters require an arena + var requiresSwiftArena: Bool { + parameterConversions.contains(where: \.translated.conversion.requiresSwiftArena) + } + } + struct TranslatedFunctionDecl { /// Java function name let name: String + let isStatic: Bool + + let isThrowing: Bool + /// The name of the native function let nativeFunctionName: String @@ -532,7 +693,7 @@ extension JNISwift2JavaGenerator { let outParameters: [OutParameter] /// Represents how to convert the Java native result into a user-facing result. - let conversion: JavaNativeConversionStep + var conversion: JavaNativeConversionStep } struct OutParameter { @@ -573,6 +734,9 @@ extension JNISwift2JavaGenerator { case constant(String) + /// `input_component` + case combinedName(component: String) + // Convert the results of the inner steps to a comma separated list. indirect case commaSeparated([JavaNativeConversionStep]) @@ -582,6 +746,9 @@ extension JNISwift2JavaGenerator { /// Call `new \(Type)(\(placeholder), swiftArena$)` indirect case constructSwiftValue(JavaNativeConversionStep, JavaType) + /// Call `new \(Type)(\(placeholder))` + indirect case constructJavaClass(JavaNativeConversionStep, JavaType) + /// Call the `MyType.wrapMemoryAddressUnsafe` in order to wrap a memory address using the Java binding type indirect case wrapMemoryAddressUnsafe(JavaNativeConversionStep, JavaType) @@ -591,7 +758,7 @@ extension JNISwift2JavaGenerator { case isOptionalPresent - indirect case combinedValueToOptional(JavaNativeConversionStep, JavaType, valueType: JavaType, valueSizeInBytes: Int, optionalType: String) + indirect case combinedValueToOptional(JavaNativeConversionStep, JavaType, resultName: String, valueType: JavaType, valueSizeInBytes: Int, optionalType: String) indirect case ternary(JavaNativeConversionStep, thenExp: JavaNativeConversionStep, elseExp: JavaNativeConversionStep) @@ -600,18 +767,18 @@ extension JNISwift2JavaGenerator { indirect case subscriptOf(JavaNativeConversionStep, arguments: [JavaNativeConversionStep]) static func toOptionalFromIndirectReturn( - discriminatorName: String, + discriminatorName: JavaNativeConversionStep, optionalClass: String, javaType: JavaType, - toValue valueConversion: JavaNativeConversionStep + toValue valueConversion: JavaNativeConversionStep, + resultName: String ) -> JavaNativeConversionStep { .aggregate( - name: "result$", - type: javaType, + variable: (name: "\(resultName)$", type: javaType), [ .ternary( .equals( - .subscriptOf(.constant(discriminatorName), arguments: [.constant("0")]), + .subscriptOf(discriminatorName, arguments: [.constant("0")]), .constant("1") ), thenExp: .method(.constant(optionalClass), function: "of", arguments: [valueConversion]), @@ -622,7 +789,12 @@ extension JNISwift2JavaGenerator { } /// Perform multiple conversions using the same input. - case aggregate(name: String, type: JavaType, [JavaNativeConversionStep]) + case aggregate(variable: (name: String, type: JavaType)? = nil, [JavaNativeConversionStep]) + + indirect case ifStatement(JavaNativeConversionStep, thenExp: JavaNativeConversionStep, elseExp: JavaNativeConversionStep? = nil) + + /// Access a member of the value + indirect case replacingPlaceholder(JavaNativeConversionStep, placeholder: String) /// Returns the conversion string applied to the placeholder. func render(_ printer: inout CodePrinter, _ placeholder: String) -> String { @@ -635,6 +807,9 @@ extension JNISwift2JavaGenerator { case .constant(let value): return value + case .combinedName(let component): + return "\(placeholder)_\(component)" + case .commaSeparated(let list): return list.map({ $0.render(&printer, placeholder)}).joined(separator: ", ") @@ -649,6 +824,10 @@ extension JNISwift2JavaGenerator { let inner = inner.render(&printer, placeholder) return "\(javaType.className!).wrapMemoryAddressUnsafe(\(inner), swiftArena$)" + case .constructJavaClass(let inner, let javaType): + let inner = inner.render(&printer, placeholder) + return "new \(javaType.className!)(\(inner))" + case .call(let inner, let function): let inner = inner.render(&printer, placeholder) return "\(function)(\(inner))" @@ -663,22 +842,22 @@ extension JNISwift2JavaGenerator { let argsStr = args.joined(separator: ", ") return "\(inner).\(methodName)(\(argsStr))" - case .combinedValueToOptional(let combined, let combinedType, let valueType, let valueSizeInBytes, let optionalType): + case .combinedValueToOptional(let combined, let combinedType, let resultName, let valueType, let valueSizeInBytes, let optionalType): let combined = combined.render(&printer, placeholder) printer.print( """ - \(combinedType) combined$ = \(combined); - byte discriminator$ = (byte) (combined$ & 0xFF); + \(combinedType) \(resultName)_combined$ = \(combined); + byte \(resultName)_discriminator$ = (byte) (\(resultName)_combined$ & 0xFF); """ ) if valueType == .boolean { - printer.print("boolean value$ = ((byte) (combined$ >> 8)) != 0;") + printer.print("boolean \(resultName)_value$ = ((byte) (\(resultName)_combined$ >> 8)) != 0;") } else { - printer.print("\(valueType) value$ = (\(valueType)) (combined$ >> \(valueSizeInBytes * 8));") + printer.print("\(valueType) \(resultName)_value$ = (\(valueType)) (\(resultName)_combined$ >> \(valueSizeInBytes * 8));") } - return "discriminator$ == 1 ? \(optionalType).of(value$) : \(optionalType).empty()" + return "\(resultName)_discriminator$ == 1 ? \(optionalType).of(\(resultName)_value$) : \(optionalType).empty()" case .ternary(let cond, let thenExp, let elseExp): let cond = cond.render(&printer, placeholder) @@ -696,25 +875,50 @@ extension JNISwift2JavaGenerator { let arguments = arguments.map { $0.render(&printer, placeholder) } return "\(inner)[\(arguments.joined(separator: ", "))]" - case .aggregate(let name, let type, let steps): + case .aggregate(let variable, let steps): precondition(!steps.isEmpty, "Aggregate must contain steps") - printer.print("\(type) \(name) = \(placeholder);") + let toExplode: String + if let variable { + printer.print("\(variable.type) \(variable.name) = \(placeholder);") + toExplode = variable.name + } else { + toExplode = placeholder + } let steps = steps.map { - $0.render(&printer, name) + $0.render(&printer, toExplode) } return steps.last! + + case .ifStatement(let cond, let thenExp, let elseExp): + let cond = cond.render(&printer, placeholder) + printer.printBraceBlock("if (\(cond))") { printer in + printer.print(thenExp.render(&printer, placeholder)) + } + if let elseExp { + printer.printBraceBlock("else") { printer in + printer.print(elseExp.render(&printer, placeholder)) + } + } + + return "" + + case .replacingPlaceholder(let inner, let placeholder): + return inner.render(&printer, placeholder) } } /// Whether the conversion uses SwiftArena. var requiresSwiftArena: Bool { switch self { - case .placeholder, .constant, .isOptionalPresent: + case .placeholder, .constant, .isOptionalPresent, .combinedName: return false case .constructSwiftValue, .wrapMemoryAddressUnsafe: return true + case .constructJavaClass(let inner, _): + return inner.requiresSwiftArena + case .valueMemoryAddress(let inner): return inner.requiresSwiftArena @@ -724,7 +928,7 @@ extension JNISwift2JavaGenerator { case .method(let inner, _, let args): return inner.requiresSwiftArena || args.contains(where: \.requiresSwiftArena) - case .combinedValueToOptional(let inner, _, _, _, _): + case .combinedValueToOptional(let inner, _, _, _, _, _): return inner.requiresSwiftArena case .ternary(let cond, let thenExp, let elseExp): @@ -736,11 +940,17 @@ extension JNISwift2JavaGenerator { case .subscriptOf(let inner, _): return inner.requiresSwiftArena - case .aggregate(_, _, let steps): + case .aggregate(_, let steps): return steps.contains(where: \.requiresSwiftArena) + case .ifStatement(let cond, let thenExp, let elseExp): + return cond.requiresSwiftArena || thenExp.requiresSwiftArena || (elseExp?.requiresSwiftArena ?? false) + case .call(let inner, _): - return inner.requiresSwiftArena + return inner.requiresSwiftArena + + case .replacingPlaceholder(let inner, _): + return inner.requiresSwiftArena } } } diff --git a/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+NativeTranslation.swift b/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+NativeTranslation.swift index e7f7efbe..f2e9522b 100644 --- a/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+NativeTranslation.swift +++ b/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+NativeTranslation.swift @@ -61,6 +61,23 @@ extension JNISwift2JavaGenerator { ) } + func translateParameters( + _ parameters: [SwiftParameter], + translatedParameters: [TranslatedParameter], + methodName: String, + parentName: String + ) throws -> [NativeParameter] { + try zip(translatedParameters, parameters).map { translatedParameter, swiftParameter in + let parameterName = translatedParameter.parameter.name + return try translate( + swiftParameter: swiftParameter, + parameterName: parameterName, + methodName: methodName, + parentName: parentName + ) + } + } + func translate( swiftParameter: SwiftParameter, parameterName: String, @@ -231,8 +248,11 @@ extension JNISwift2JavaGenerator { } func translateOptionalResult( - wrappedType swiftType: SwiftType + wrappedType swiftType: SwiftType, + resultName: String = "result" ) throws -> NativeResult { + let discriminatorName = "\(resultName)_discriminator$" + switch swiftType { case .nominal(let nominalType): if let knownType = nominalType.nominalTypeDecl.knownTypeKind { @@ -249,6 +269,7 @@ extension JNISwift2JavaGenerator { conversion: .getJNIValue( .optionalRaisingWidenIntegerType( .placeholder, + resultName: resultName, valueType: javaType, combinedSwiftType: nextIntergralTypeWithSpaceForByte.swiftType, valueSizeInBytes: nextIntergralTypeWithSpaceForByte.valueBytes @@ -258,7 +279,6 @@ extension JNISwift2JavaGenerator { ) } else { // Use indirect byte array to store discriminator - let discriminatorName = "result_discriminator$" return NativeResult( javaType: javaType, @@ -284,8 +304,6 @@ extension JNISwift2JavaGenerator { } // Assume JExtract imported class - let discriminatorName = "result_discriminator$" - return NativeResult( javaType: .long, conversion: .optionalRaisingIndirectReturn( @@ -368,7 +386,8 @@ extension JNISwift2JavaGenerator { } func translate( - swiftResult: SwiftResult + swiftResult: SwiftResult, + resultName: String = "result" ) throws -> NativeResult { switch swiftResult.type { case .nominal(let nominalType): @@ -378,7 +397,7 @@ extension JNISwift2JavaGenerator { guard let genericArgs = nominalType.genericArguments, genericArgs.count == 1 else { throw JavaTranslationError.unsupportedSwiftType(swiftResult.type) } - return try translateOptionalResult(wrappedType: genericArgs[0]) + return try translateOptionalResult(wrappedType: genericArgs[0], resultName: resultName) default: guard let javaType = JNIJavaTypeTranslator.translate(knownType: knownType, config: self.config), javaType.implementsJavaValue else { @@ -399,7 +418,7 @@ extension JNISwift2JavaGenerator { return NativeResult( javaType: .long, - conversion: .getJNIValue(.allocateSwiftValue(name: "result", swiftType: swiftResult.type)), + conversion: .getJNIValue(.allocateSwiftValue(name: resultName, swiftType: swiftResult.type)), outParameters: [] ) @@ -411,7 +430,7 @@ extension JNISwift2JavaGenerator { ) case .optional(let wrapped): - return try translateOptionalResult(wrappedType: wrapped) + return try translateOptionalResult(wrappedType: wrapped, resultName: resultName) case .metatype, .tuple, .function, .existential, .opaque, .genericParameter: throw JavaTranslationError.unsupportedSwiftType(swiftResult.type) @@ -449,6 +468,9 @@ extension JNISwift2JavaGenerator { case constant(String) + /// `input_component` + case combinedName(component: String) + /// `value.getJNIValue(in:)` indirect case getJNIValue(NativeSwiftConversionStep) @@ -480,7 +502,7 @@ extension JNISwift2JavaGenerator { indirect case optionalChain(NativeSwiftConversionStep) - indirect case optionalRaisingWidenIntegerType(NativeSwiftConversionStep, valueType: JavaType, combinedSwiftType: SwiftKnownTypeDeclKind, valueSizeInBytes: Int) + indirect case optionalRaisingWidenIntegerType(NativeSwiftConversionStep, resultName: String, valueType: JavaType, combinedSwiftType: SwiftKnownTypeDeclKind, valueSizeInBytes: Int) indirect case optionalRaisingIndirectReturn(NativeSwiftConversionStep, returnType: JavaType, discriminatorParameterName: String, placeholderValue: NativeSwiftConversionStep) @@ -503,6 +525,9 @@ extension JNISwift2JavaGenerator { case .constant(let value): return value + case .combinedName(let component): + return "\(placeholder)_\(component)" + case .getJNIValue(let inner): let inner = inner.render(&printer, placeholder) return "\(inner).getJNIValue(in: environment!)" @@ -606,18 +631,18 @@ extension JNISwift2JavaGenerator { let inner = inner.render(&printer, placeholder) return "\(inner)?" - case .optionalRaisingWidenIntegerType(let inner, let valueType, let combinedSwiftType, let valueSizeInBytes): + case .optionalRaisingWidenIntegerType(let inner, let resultName, let valueType, let combinedSwiftType, let valueSizeInBytes): let inner = inner.render(&printer, placeholder) let value = valueType == .boolean ? "$0 ? 1 : 0" : "$0" let combinedSwiftTypeName = combinedSwiftType.moduleAndName.name printer.print( """ - let value$ = \(inner).map { + let \(resultName)_value$ = \(inner).map { \(combinedSwiftTypeName)(\(value)) << \(valueSizeInBytes * 8) | \(combinedSwiftTypeName)(1) } ?? 0 """ ) - return "value$" + return "\(resultName)_value$" case .optionalRaisingIndirectReturn(let inner, let returnType, let discriminatorParameterName, let placeholderValue): printer.print("let result$: \(returnType.jniTypeName)") diff --git a/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+SwiftThunkPrinting.swift b/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+SwiftThunkPrinting.swift index ca580e60..6b8c435a 100644 --- a/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+SwiftThunkPrinting.swift +++ b/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+SwiftThunkPrinting.swift @@ -80,6 +80,15 @@ extension JNISwift2JavaGenerator { } } + private func printJNICache(_ printer: inout CodePrinter, _ type: ImportedNominalType) { + printer.printBraceBlock("enum \(JNICaching.cacheName(for: type))") { printer in + for enumCase in type.cases { + guard let translatedCase = translatedEnumCase(for: enumCase) else { continue } + printer.print("static let \(JNICaching.cacheMemberName(for: enumCase)) = \(renderEnumCaseCacheInit(translatedCase))") + } + } + } + private func printGlobalSwiftThunkSources(_ printer: inout CodePrinter) throws { printHeader(&printer) @@ -97,11 +106,24 @@ extension JNISwift2JavaGenerator { private func printNominalTypeThunks(_ printer: inout CodePrinter, _ type: ImportedNominalType) throws { printHeader(&printer) + printJNICache(&printer, type) + printer.println() + for initializer in type.initializers { printSwiftFunctionThunk(&printer, initializer) printer.println() } + if type.swiftNominal.kind == .enum { + printEnumDiscriminator(&printer, type) + printer.println() + + for enumCase in type.cases { + printEnumCase(&printer, enumCase) + printer.println() + } + } + for method in type.methods { printSwiftFunctionThunk(&printer, method) printer.println() @@ -115,6 +137,92 @@ extension JNISwift2JavaGenerator { printDestroyFunctionThunk(&printer, type) } + private func printEnumDiscriminator(_ printer: inout CodePrinter, _ type: ImportedNominalType) { + let selfPointerParam = JavaParameter(name: "selfPointer", type: .long) + printCDecl( + &printer, + javaMethodName: "$getDiscriminator", + parentName: type.swiftNominal.name, + parameters: [selfPointerParam], + resultType: .int + ) { printer in + let selfPointer = self.printSelfJLongToUnsafeMutablePointer( + &printer, + swiftParentName: type.swiftNominal.name, + selfPointerParam + ) + printer.printBraceBlock("switch (\(selfPointer).pointee)") { printer in + for (idx, enumCase) in type.cases.enumerated() { + printer.print("case .\(enumCase.name): return \(idx)") + } + } + } + } + + private func printEnumCase(_ printer: inout CodePrinter, _ enumCase: ImportedEnumCase) { + guard let translatedCase = self.translatedEnumCase(for: enumCase) else { + return + } + + // Print static case initializer + printSwiftFunctionThunk(&printer, enumCase.caseFunction) + printer.println() + + // Print getAsCase method + if !translatedCase.translatedValues.isEmpty { + printEnumGetAsCaseThunk(&printer, translatedCase) + } + } + + private func renderEnumCaseCacheInit(_ enumCase: TranslatedEnumCase) -> String { + let nativeParametersClassName = "\(javaPackagePath)/\(enumCase.enumName)$\(enumCase.name)$$NativeParameters" + let methodSignature = MethodSignature(resultType: .void, parameterTypes: enumCase.parameterConversions.map(\.native.javaType)) + let methods = #"[.init(name: "", signature: "\#(methodSignature.mangledName)")]"# + + return #"_JNIMethodIDCache(environment: try! JavaVirtualMachine.shared().environment(), className: "\#(nativeParametersClassName)", methods: \#(methods))"# + } + + private func printEnumGetAsCaseThunk( + _ printer: inout CodePrinter, + _ enumCase: TranslatedEnumCase + ) { + printCDecl( + &printer, + enumCase.getAsCaseFunction + ) { printer in + let selfPointer = enumCase.getAsCaseFunction.nativeFunctionSignature.selfParameter!.conversion.render(&printer, "self") + let caseNames = enumCase.original.parameters.enumerated().map { idx, parameter in + parameter.name ?? "_\(idx)" + } + let caseNamesWithLet = caseNames.map { "let \($0)" } + let methodSignature = MethodSignature(resultType: .void, parameterTypes: enumCase.parameterConversions.map(\.native.javaType)) + printer.print( + """ + guard case .\(enumCase.original.name)(\(caseNamesWithLet.joined(separator: ", "))) = \(selfPointer).pointee else { + fatalError("Expected enum case '\(enumCase.original.name)', but was '\\(\(selfPointer).pointee)'!") + } + let cache$ = \(JNICaching.cacheName(for: enumCase.original.enumType)).\(JNICaching.cacheMemberName(for: enumCase.original)) + let class$ = cache$.javaClass + let method$ = _JNIMethodIDCache.Method(name: "", signature: "\(methodSignature.mangledName)") + let constructorID$ = cache$[method$] + """ + ) + let upcallArguments = zip(enumCase.parameterConversions, caseNames).map { conversion, caseName in + // '0' is treated the same as a null pointer. + let nullConversion = !conversion.native.javaType.isPrimitive ? " ?? 0" : "" + let result = conversion.native.conversion.render(&printer, caseName) + return "\(result)\(nullConversion)" + } + printer.print( + """ + return withVaList([\(upcallArguments.joined(separator: ", "))]) { + return environment.interface.NewObjectV(environment, class$, constructorID$, $0) + } + """ + ) + } + } + private func printSwiftFunctionThunk( _ printer: inout CodePrinter, _ decl: ImportedFunc @@ -124,21 +232,9 @@ extension JNISwift2JavaGenerator { return } - let nativeSignature = translatedDecl.nativeFunctionSignature - var parameters = nativeSignature.parameters.flatMap(\.parameters) - - if let selfParameter = nativeSignature.selfParameter { - parameters += selfParameter.parameters - } - - parameters += nativeSignature.result.outParameters - printCDecl( &printer, - javaMethodName: translatedDecl.nativeFunctionName, - parentName: translatedDecl.parentName, - parameters: parameters, - resultType: nativeSignature.result.javaType + translatedDecl ) { printer in self.printFunctionDowncall(&printer, decl) } @@ -190,6 +286,18 @@ extension JNISwift2JavaGenerator { .joined(separator: ", ") result = "\(tryClause)\(callee).\(decl.name)(\(downcallArguments))" + case .enumCase: + let downcallArguments = zip( + decl.functionSignature.parameters, + arguments + ).map { originalParam, argument in + let label = originalParam.argumentLabel.map { "\($0): " } ?? "" + return "\(label)\(argument)" + } + + let associatedValues = !downcallArguments.isEmpty ? "(\(downcallArguments.joined(separator: ", ")))" : "" + result = "\(callee).\(decl.name)\(associatedValues)" + case .getter: result = "\(tryClause)\(callee).\(decl.name)" @@ -228,6 +336,31 @@ extension JNISwift2JavaGenerator { } } + private func printCDecl( + _ printer: inout CodePrinter, + _ translatedDecl: TranslatedFunctionDecl, + _ body: (inout CodePrinter) -> Void + ) { + let nativeSignature = translatedDecl.nativeFunctionSignature + var parameters = nativeSignature.parameters.flatMap(\.parameters) + + if let selfParameter = nativeSignature.selfParameter { + parameters += selfParameter.parameters + } + + parameters += nativeSignature.result.outParameters + + printCDecl( + &printer, + javaMethodName: translatedDecl.nativeFunctionName, + parentName: translatedDecl.parentName, + parameters: parameters, + resultType: nativeSignature.result.javaType + ) { printer in + body(&printer) + } + } + private func printCDecl( _ printer: inout CodePrinter, javaMethodName: String, diff --git a/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator.swift b/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator.swift index c5f3a5b6..60ec1b8b 100644 --- a/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator.swift +++ b/Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator.swift @@ -39,6 +39,7 @@ package class JNISwift2JavaGenerator: Swift2JavaGenerator { /// Cached Java translation result. 'nil' indicates failed translation. var translatedDecls: [ImportedFunc: TranslatedFunctionDecl] = [:] + var translatedEnumCases: [ImportedEnumCase: TranslatedEnumCase] = [:] /// Because we need to write empty files for SwiftPM, keep track which files we didn't write yet, /// and write an empty file for those. diff --git a/Sources/JExtractSwiftLib/Swift2JavaVisitor.swift b/Sources/JExtractSwiftLib/Swift2JavaVisitor.swift index a933ff3e..bc6ac032 100644 --- a/Sources/JExtractSwiftLib/Swift2JavaVisitor.swift +++ b/Sources/JExtractSwiftLib/Swift2JavaVisitor.swift @@ -46,7 +46,7 @@ final class Swift2JavaVisitor { case .structDecl(let node): self.visit(nominalDecl: node, in: parent) case .enumDecl(let node): - self.visit(nominalDecl: node, in: parent) + self.visit(enumDecl: node, in: parent) case .protocolDecl(let node): self.visit(nominalDecl: node, in: parent) case .extensionDecl(let node): @@ -65,6 +65,8 @@ final class Swift2JavaVisitor { case .subscriptDecl: // TODO: Implement break + case .enumCaseDecl(let node): + self.visit(enumCaseDecl: node, in: parent) default: break @@ -83,6 +85,12 @@ final class Swift2JavaVisitor { } } + func visit(enumDecl node: EnumDeclSyntax, in parent: ImportedNominalType?) { + self.visit(nominalDecl: node, in: parent) + + self.synthesizeRawRepresentableConformance(enumDecl: node, in: parent) + } + func visit(extensionDecl node: ExtensionDeclSyntax, in parent: ImportedNominalType?) { guard parent == nil else { // 'extension' in a nominal type is invalid. Ignore @@ -131,6 +139,49 @@ final class Swift2JavaVisitor { } } + func visit(enumCaseDecl node: EnumCaseDeclSyntax, in typeContext: ImportedNominalType?) { + guard let typeContext else { + self.log.info("Enum case must be within a current type; \(node)") + return + } + + do { + for caseElement in node.elements { + self.log.debug("Import case \(caseElement.name) of enum \(node.qualifiedNameForDebug)") + + let parameters = try caseElement.parameterClause?.parameters.map { + try SwiftEnumCaseParameter($0, lookupContext: translator.lookupContext) + } + + let signature = try SwiftFunctionSignature( + caseElement, + enclosingType: typeContext.swiftType, + lookupContext: translator.lookupContext + ) + + let caseFunction = ImportedFunc( + module: translator.swiftModuleName, + swiftDecl: node, + name: caseElement.name.text, + apiKind: .enumCase, + functionSignature: signature + ) + + let importedCase = ImportedEnumCase( + name: caseElement.name.text, + parameters: parameters ?? [], + swiftDecl: node, + enumType: SwiftNominalType(nominalTypeDecl: typeContext.swiftNominal), + caseFunction: caseFunction + ) + + typeContext.cases.append(importedCase) + } + } catch { + self.log.debug("Failed to import: \(node.qualifiedNameForDebug); \(error)") + } + } + func visit(variableDecl node: VariableDeclSyntax, in typeContext: ImportedNominalType?) { guard node.shouldExtract(config: config, log: log) else { return @@ -213,6 +264,32 @@ final class Swift2JavaVisitor { typeContext.initializers.append(imported) } + + private func synthesizeRawRepresentableConformance(enumDecl node: EnumDeclSyntax, in parent: ImportedNominalType?) { + guard let imported = translator.importedNominalType(node, parent: parent) else { + return + } + + if let firstInheritanceType = imported.swiftNominal.firstInheritanceType, + let inheritanceType = try? SwiftType( + firstInheritanceType, + lookupContext: translator.lookupContext + ), + inheritanceType.isRawTypeCompatible + { + if !imported.variables.contains(where: { $0.name == "rawValue" && $0.functionSignature.result.type != inheritanceType }) { + let decl: DeclSyntax = "public var rawValue: \(raw: inheritanceType.description) { get }" + self.visit(decl: decl, in: imported) + } + + imported.variables.first?.signatureString + + if !imported.initializers.contains(where: { $0.functionSignature.parameters.count == 1 && $0.functionSignature.parameters.first?.parameterName == "rawValue" && $0.functionSignature.parameters.first?.type == inheritanceType }) { + let decl: DeclSyntax = "public init?(rawValue: \(raw: inheritanceType))" + self.visit(decl: decl, in: imported) + } + } + } } extension DeclSyntaxProtocol where Self: WithModifiersSyntax & WithAttributesSyntax { @@ -233,15 +310,6 @@ extension DeclSyntaxProtocol where Self: WithModifiersSyntax & WithAttributesSyn return false } - if let node = self.as(InitializerDeclSyntax.self) { - let isFailable = node.optionalMark != nil - - if isFailable { - log.warning("Skip import '\(self.qualifiedNameForDebug)': failable initializer") - return false - } - } - return true } } diff --git a/Sources/JExtractSwiftLib/SwiftTypes/SwiftEnumCaseParameter.swift b/Sources/JExtractSwiftLib/SwiftTypes/SwiftEnumCaseParameter.swift new file mode 100644 index 00000000..55682152 --- /dev/null +++ b/Sources/JExtractSwiftLib/SwiftTypes/SwiftEnumCaseParameter.swift @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift.org project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift.org project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import SwiftSyntax + +struct SwiftEnumCaseParameter: Equatable { + var name: String? + var type: SwiftType +} + +extension SwiftEnumCaseParameter { + init( + _ node: EnumCaseParameterSyntax, + lookupContext: SwiftTypeLookupContext + ) throws { + self.init( + name: node.firstName?.identifier?.name, + type: try SwiftType(node.type, lookupContext: lookupContext) + ) + } +} diff --git a/Sources/JExtractSwiftLib/SwiftTypes/SwiftFunctionSignature.swift b/Sources/JExtractSwiftLib/SwiftTypes/SwiftFunctionSignature.swift index 85b96110..a5b01bee 100644 --- a/Sources/JExtractSwiftLib/SwiftTypes/SwiftFunctionSignature.swift +++ b/Sources/JExtractSwiftLib/SwiftTypes/SwiftFunctionSignature.swift @@ -71,11 +71,6 @@ extension SwiftFunctionSignature { throw SwiftFunctionTranslationError.missingEnclosingType(node) } - // We do not yet support failable initializers. - if node.optionalMark != nil { - throw SwiftFunctionTranslationError.failableInitializer(node) - } - let (genericParams, genericRequirements) = try Self.translateGenericParameters( parameterClause: node.genericParameterClause, whereClause: node.genericWhereClause, @@ -86,16 +81,37 @@ extension SwiftFunctionSignature { lookupContext: lookupContext ) + let type = node.optionalMark != nil ? .optional(enclosingType) : enclosingType + self.init( selfParameter: .initializer(enclosingType), parameters: parameters, - result: SwiftResult(convention: .direct, type: enclosingType), + result: SwiftResult(convention: .direct, type: type), effectSpecifiers: effectSpecifiers, genericParameters: genericParams, genericRequirements: genericRequirements ) } + init( + _ node: EnumCaseElementSyntax, + enclosingType: SwiftType, + lookupContext: SwiftTypeLookupContext + ) throws { + let parameters = try node.parameterClause?.parameters.map { param in + try SwiftParameter(param, lookupContext: lookupContext) + } + + self.init( + selfParameter: .initializer(enclosingType), + parameters: parameters ?? [], + result: SwiftResult(convention: .direct, type: enclosingType), + effectSpecifiers: [], + genericParameters: [], + genericRequirements: [] + ) + } + init( _ node: FunctionDeclSyntax, enclosingType: SwiftType?, diff --git a/Sources/JExtractSwiftLib/SwiftTypes/SwiftNominalTypeDeclaration.swift b/Sources/JExtractSwiftLib/SwiftTypes/SwiftNominalTypeDeclaration.swift index cef4e731..335979a4 100644 --- a/Sources/JExtractSwiftLib/SwiftTypes/SwiftNominalTypeDeclaration.swift +++ b/Sources/JExtractSwiftLib/SwiftTypes/SwiftNominalTypeDeclaration.swift @@ -85,6 +85,14 @@ package class SwiftNominalTypeDeclaration: SwiftTypeDeclaration { super.init(moduleName: moduleName, name: node.name.text) } + lazy var firstInheritanceType: TypeSyntax? = { + guard let firstInheritanceType = self.syntax?.inheritanceClause?.inheritedTypes.first else { + return nil + } + + return firstInheritanceType.type + }() + /// Returns true if this type conforms to `Sendable` and therefore is "threadsafe". lazy var isSendable: Bool = { // Check if Sendable is in the inheritance list diff --git a/Sources/JExtractSwiftLib/SwiftTypes/SwiftParameter.swift b/Sources/JExtractSwiftLib/SwiftTypes/SwiftParameter.swift index 75d165e9..63f7d75b 100644 --- a/Sources/JExtractSwiftLib/SwiftTypes/SwiftParameter.swift +++ b/Sources/JExtractSwiftLib/SwiftTypes/SwiftParameter.swift @@ -57,6 +57,16 @@ enum SwiftParameterConvention: Equatable { case `inout` } +extension SwiftParameter { + init(_ node: EnumCaseParameterSyntax, lookupContext: SwiftTypeLookupContext) throws { + self.convention = .byValue + self.type = try SwiftType(node.type, lookupContext: lookupContext) + self.argumentLabel = nil + self.parameterName = node.firstName?.identifier?.name + self.argumentLabel = node.firstName?.identifier?.name + } +} + extension SwiftParameter { init(_ node: FunctionParameterSyntax, lookupContext: SwiftTypeLookupContext) throws { // Determine the convention. The default is by-value, but there are diff --git a/Sources/JExtractSwiftLib/SwiftTypes/SwiftType.swift b/Sources/JExtractSwiftLib/SwiftTypes/SwiftType.swift index 3cc14406..58bb65c3 100644 --- a/Sources/JExtractSwiftLib/SwiftTypes/SwiftType.swift +++ b/Sources/JExtractSwiftLib/SwiftTypes/SwiftType.swift @@ -103,6 +103,19 @@ enum SwiftType: Equatable { default: false } } + + var isRawTypeCompatible: Bool { + switch self { + case .nominal(let nominal): + switch nominal.nominalTypeDecl.knownTypeKind { + case .int, .uint, .int8, .uint8, .int16, .uint16, .int32, .uint32, .int64, .uint64, .float, .double, .string: + true + default: + false + } + default: false + } + } } extension SwiftType: CustomStringConvertible { diff --git a/Sources/JavaKit/Helpers/_JNIMethodIDCache.swift b/Sources/JavaKit/Helpers/_JNIMethodIDCache.swift new file mode 100644 index 00000000..a67d225f --- /dev/null +++ b/Sources/JavaKit/Helpers/_JNIMethodIDCache.swift @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2024 Apple Inc. and the Swift.org project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift.org project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// A cache used to hold references for JNI method and classes. +/// +/// This type is used internally in by the outputted JExtract wrappers +/// to improve performance of any JNI lookups. +public final class _JNIMethodIDCache: Sendable { + public struct Method: Hashable { + public let name: String + public let signature: String + + public init(name: String, signature: String) { + self.name = name + self.signature = signature + } + } + + nonisolated(unsafe) let _class: jclass? + nonisolated(unsafe) let methods: [Method: jmethodID] + + public var javaClass: jclass { + self._class! + } + + public init(environment: UnsafeMutablePointer!, className: String, methods: [Method]) { + guard let clazz = environment.interface.FindClass(environment, className) else { + fatalError("Class \(className) could not be found!") + } + self._class = environment.interface.NewGlobalRef(environment, clazz)! + self.methods = methods.reduce(into: [:]) { (result, method) in + if let methodID = environment.interface.GetMethodID(environment, clazz, method.name, method.signature) { + result[method] = methodID + } else { + fatalError("Method \(method.signature) with signature \(method.signature) not found in class \(className)") + } + } + } + + + public subscript(_ method: Method) -> jmethodID? { + methods[method] + } + + public func cleanup(environment: UnsafeMutablePointer!) { + environment.interface.DeleteGlobalRef(environment, self._class) + } +} diff --git a/Sources/JavaTypes/Mangling.swift b/Sources/JavaTypes/Mangling.swift index 74e2e0b8..f0dbd484 100644 --- a/Sources/JavaTypes/Mangling.swift +++ b/Sources/JavaTypes/Mangling.swift @@ -36,7 +36,7 @@ extension JavaType { case .void: "V" case .array(let elementType): "[" + elementType.mangledName case .class(package: let package, name: let name): - "L\(package!).\(name);".replacingPeriodsWithSlashes() + "L\(package!).\(name.replacingPeriodsWithDollars());".replacingPeriodsWithSlashes() } } } @@ -145,4 +145,9 @@ extension StringProtocol { fileprivate func replacingSlashesWithPeriods() -> String { return String(self.map { $0 == "/" ? "." as Character : $0 }) } + + /// Return the string after replacing all of the periods (".") with slashes ("$"). + fileprivate func replacingPeriodsWithDollars() -> String { + return String(self.map { $0 == "." ? "$" as Character : $0 }) + } } diff --git a/Sources/SwiftJavaDocumentation/Documentation.docc/SupportedFeatures.md b/Sources/SwiftJavaDocumentation/Documentation.docc/SupportedFeatures.md index 130c333b..c5de6c45 100644 --- a/Sources/SwiftJavaDocumentation/Documentation.docc/SupportedFeatures.md +++ b/Sources/SwiftJavaDocumentation/Documentation.docc/SupportedFeatures.md @@ -49,7 +49,8 @@ SwiftJava's `swift-java jextract` tool automates generating Java bindings from S | Initializers: `class`, `struct` | ✅ | ✅ | | Optional Initializers / Throwing Initializers | ❌ | ❌ | | Deinitializers: `class`, `struct` | ✅ | ✅ | -| `enum`, `actor` | ❌ | ❌ | +| `enum` | ❌ | ✅ | +| `actor` | ❌ | ❌ | | Global Swift `func` | ✅ | ✅ | | Class/struct member `func` | ✅ | ✅ | | Throwing functions: `func x() throws` | ❌ | ✅ | @@ -157,3 +158,113 @@ you are expected to add a Guava dependency to your Java project. | `Double` | `double` | > Note: The `wrap-guava` mode is currently only available in FFM mode of jextract. + +### Enums + +> Note: Enums are currently only supported in JNI mode. + +Swift enums are extracted into a corresponding Java `class`. To support associated values +all cases are also extracted as Java `record`s. + +Consider the following Swift enum: +```swift +public enum Vehicle { + case car(String) + case bicycle(maker: String) +} +``` +You can then instantiate a case of `Vehicle` by using one of the static methods: +```java +try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.car("BMW", arena); + Optional car = vehicle.getAsCar(); + assertEquals("BMW", car.orElseThrow().arg0()); +} +``` +As you can see above, to access the associated values of a case you can call one of the +`getAsX` methods that will return an Optional record with the associated values. +```java +try (var arena = SwiftArena.ofConfined()) { + Vehicle vehicle = Vehicle.bycicle("My Brand", arena); + Optional car = vehicle.getAsCar(); + assertFalse(car.isPresent()); + + Optional bicycle = vehicle.getAsBicycle(); + assertEquals("My Brand", bicycle.orElseThrow().maker()); +} +``` + +#### Switching and pattern matching + +If you only need to switch on the case and not access any associated values, +you can use the `getDiscriminator()` method: +```java +Vehicle vehicle = ...; +switch (vehicle.getDiscriminator()) { + case BICYCLE: + System.out.println("I am a bicycle!"); + break + case CAR: + System.out.println("I am a car!"); + break +} +``` +If you also want access to the associated values, you have various options +depending on the Java version you are using. +If you are running Java 21+ you can use [pattern matching for switch](https://openjdk.org/jeps/441): +```java +Vehicle vehicle = ...; +switch (vehicle.getCase()) { + case Vehicle.Bicycle b: + System.out.println("Bicycle maker: " + b.maker()); + break + case Vehicle.Car c: + System.out.println("Car: " + c.arg0()); + break +} +``` +For Java 16+ you can use [pattern matching for instanceof](https://openjdk.org/jeps/394) +```java +Vehicle vehicle = ...; +Vehicle.Case case = vehicle.getCase(); +if (case instanceof Vehicle.Bicycle b) { + System.out.println("Bicycle maker: " + b.maker()); +} else if(case instanceof Vehicle.Car c) { + System.out.println("Car: " + c.arg0()); +} +``` +For any previous Java versions you can resort to casting the `Case` to the expected type: +```java +Vehicle vehicle = ...; +Vehicle.Case case = vehicle.getCase(); +if (case instanceof Vehicle.Bicycle) { + Vehicle.Bicycle b = (Vehicle.Bicycle) case; + System.out.println("Bicycle maker: " + b.maker()); +} else if(case instanceof Vehicle.Car) { + Vehicle.Car c = (Vehicle.Car) case; + System.out.println("Car: " + c.arg0()); +} +``` + +#### RawRepresentable enums + +JExtract also supports extracting enums that conform to `RawRepresentable` +by giving access to an optional initializer and the `rawValue` variable. +Consider the following example: +```swift +public enum Alignment: String { + case horizontal + case vertical +} +``` +you can then initialize `Alignment` from a `String` and also retrieve back its `rawValue`: +```java +try (var arena = SwiftArena.ofConfined()) { + Optional alignment = Alignment.init("horizontal", arena); + assertEqual(HORIZONTAL, alignment.orElseThrow().getDiscriminator()); + assertEqual("horizontal", alignment.orElseThrow().getRawValue()); +} +``` + + + diff --git a/Tests/JExtractSwiftTests/JNI/JNIEnumTests.swift b/Tests/JExtractSwiftTests/JNI/JNIEnumTests.swift new file mode 100644 index 00000000..47765043 --- /dev/null +++ b/Tests/JExtractSwiftTests/JNI/JNIEnumTests.swift @@ -0,0 +1,313 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift.org project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift.org project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import JExtractSwiftLib +import Testing + +@Suite +struct JNIEnumTests { + let source = """ + public enum MyEnum { + case first + case second(String) + case third(x: Int64, y: Int32) + } + """ + + @Test + func generatesJavaClass() throws { + try assertOutput( + input: source, + .jni, .java, + expectedChunks: [ + """ + // Generated by jextract-swift + // Swift module: SwiftModule + + package com.example.swift; + + import org.swift.swiftkit.core.*; + import org.swift.swiftkit.core.util.*; + """, + """ + public final class MyEnum extends JNISwiftInstance { + static final String LIB_NAME = "SwiftModule"; + + @SuppressWarnings("unused") + private static final boolean INITIALIZED_LIBS = initializeLibs(); + static boolean initializeLibs() { + System.loadLibrary(LIB_NAME); + return true; + } + """, + """ + private MyEnum(long selfPointer, SwiftArena swiftArena) { + super(selfPointer, swiftArena); + } + """, + """ + public static MyEnum wrapMemoryAddressUnsafe(long selfPointer, SwiftArena swiftArena) { + return new MyEnum(selfPointer, swiftArena); + } + """, + """ + private static native void $destroy(long selfPointer); + """, + """ + @Override + protected Runnable $createDestroyFunction() { + long self$ = this.$memoryAddress(); + if (CallTraces.TRACE_DOWNCALLS) { + CallTraces.traceDowncall("MyEnum.$createDestroyFunction", + "this", this, + "self", self$); + } + return new Runnable() { + @Override + public void run() { + if (CallTraces.TRACE_DOWNCALLS) { + CallTraces.traceDowncall("MyEnum.$destroy", "self", self$); + } + MyEnum.$destroy(self$); + } + }; + """ + ]) + } + + @Test + func generatesDiscriminator_java() throws { + try assertOutput( + input: source, + .jni, .java, + detectChunkByInitialLines: 1, + expectedChunks: [ + """ + public enum Discriminator { + FIRST, + SECOND, + THIRD + } + """, + """ + public Discriminator getDiscriminator() { + return Discriminator.values()[$getDiscriminator(this.$memoryAddress())]; + } + """, + """ + private static native int $getDiscriminator(long self); + """ + ]) + } + + @Test + func generatesDiscriminator_swift() throws { + try assertOutput( + input: source, + .jni, .swift, + detectChunkByInitialLines: 1, + expectedChunks: [ + """ + @_cdecl("Java_com_example_swift_MyEnum__00024getDiscriminator__J") + func Java_com_example_swift_MyEnum__00024getDiscriminator__J(environment: UnsafeMutablePointer!, thisClass: jclass, selfPointer: jlong) -> jint { + ... + switch (self$.pointee) { + case .first: return 0 + case .second: return 1 + case .third: return 2 + } + } + """ + ]) + } + + @Test + func generatesCases_java() throws { + try assertOutput( + input: source, + .jni, .java, + detectChunkByInitialLines: 1, + expectedChunks: [ + """ + public sealed interface Case {} + """, + """ + public Case getCase() { + Discriminator discriminator = this.getDiscriminator(); + switch (discriminator) { + case FIRST: return this.getAsFirst().orElseThrow(); + case SECOND: return this.getAsSecond().orElseThrow(); + case THIRD: return this.getAsThird().orElseThrow(); + } + throw new RuntimeException("Unknown discriminator value " + discriminator); + } + """, + """ + public record First() implements Case { + record $NativeParameters() {} + } + """, + """ + public record Second(java.lang.String arg0) implements Case { + record $NativeParameters(java.lang.String arg0) {} + } + """, + """ + public record Third(long x, int y) implements Case { + record $NativeParameters(long x, int y) {} + } + """ + ]) + } + + @Test + func generatesCaseInitializers_java() throws { + try assertOutput( + input: source, + .jni, .java, + detectChunkByInitialLines: 1, + expectedChunks: [ + """ + public static MyEnum first(SwiftArena swiftArena$) { + return MyEnum.wrapMemoryAddressUnsafe(MyEnum.$first(), swiftArena$); + } + """, + """ + public static MyEnum second(java.lang.String arg0, SwiftArena swiftArena$) { + return MyEnum.wrapMemoryAddressUnsafe(MyEnum.$second(arg0), swiftArena$); + } + """, + """ + public static MyEnum third(long x, int y, SwiftArena swiftArena$) { + return MyEnum.wrapMemoryAddressUnsafe(MyEnum.$third(x, y), swiftArena$); + } + """ + ]) + } + + @Test + func generatesCaseInitializers_swift() throws { + try assertOutput( + input: source, + .jni, .swift, + detectChunkByInitialLines: 1, + expectedChunks: [ + """ + @_cdecl("Java_com_example_swift_MyEnum__00024first__") + func Java_com_example_swift_MyEnum__00024first__(environment: UnsafeMutablePointer!, thisClass: jclass) -> jlong { + let result$ = UnsafeMutablePointer.allocate(capacity: 1) + result$.initialize(to: MyEnum.first) + let resultBits$ = Int64(Int(bitPattern: result$)) + return resultBits$.getJNIValue(in: environment!) + } + """, + """ + @_cdecl("Java_com_example_swift_MyEnum__00024second__Ljava_lang_String_2") + func Java_com_example_swift_MyEnum__00024second__Ljava_lang_String_2(environment: UnsafeMutablePointer!, thisClass: jclass, arg0: jstring?) -> jlong { + let result$ = UnsafeMutablePointer.allocate(capacity: 1) + result$.initialize(to: MyEnum.second(String(fromJNI: arg0, in: environment!))) + let resultBits$ = Int64(Int(bitPattern: result$)) + return resultBits$.getJNIValue(in: environment!) + } + """, + """ + @_cdecl("Java_com_example_swift_MyEnum__00024third__JI") + func Java_com_example_swift_MyEnum__00024third__JI(environment: UnsafeMutablePointer!, thisClass: jclass, x: jlong, y: jint) -> jlong { + let result$ = UnsafeMutablePointer.allocate(capacity: 1) + result$.initialize(to: MyEnum.third(x: Int64(fromJNI: x, in: environment!), y: Int32(fromJNI: y, in: environment!))) + let resultBits$ = Int64(Int(bitPattern: result$)) + return resultBits$.getJNIValue(in: environment!) + } + """ + ]) + } + + @Test + func generatesGetAsCase_java() throws { + try assertOutput( + input: source, + .jni, .java, + detectChunkByInitialLines: 1, + expectedChunks: [ + """ + public Optional getAsFirst() { + if (getDiscriminator() != Discriminator.FIRST) { + return Optional.empty(); + } + return Optional.of(new First()); + } + """, + """ + public Optional getAsSecond() { + if (getDiscriminator() != Discriminator.SECOND) { + return Optional.empty(); + } + Second.$NativeParameters $nativeParameters = MyEnum.$getAsSecond(this.$memoryAddress()); + return Optional.of(new Second($nativeParameters.arg0)); + } + """, + """ + public Optional getAsThird() { + if (getDiscriminator() != Discriminator.THIRD) { + return Optional.empty(); + } + Third.$NativeParameters $nativeParameters = MyEnum.$getAsThird(this.$memoryAddress()); + return Optional.of(new Third($nativeParameters.x, $nativeParameters.y)); + } + """ + ]) + } + + @Test + func generatesGetAsCase_swift() throws { + try assertOutput( + input: source, + .jni, .swift, + detectChunkByInitialLines: 1, + expectedChunks: [ + """ + @_cdecl("Java_com_example_swift_MyEnum__00024getAsSecond__J") + func Java_com_example_swift_MyEnum__00024getAsSecond__J(environment: UnsafeMutablePointer!, thisClass: jclass, self: jlong) -> jobject? { + ... + guard case .second(let _0) = self$.pointee else { + fatalError("Expected enum case 'second', but was '\\(self$.pointee)'!") + } + let cache$ = _JNI_MyEnum.myEnumSecondCache + let class$ = cache$.javaClass + let method$ = _JNIMethodIDCache.Method(name: "", signature: "(Ljava/lang/String;)V") + let constructorID$ = cache$[method$] + return withVaList([_0.getJNIValue(in: environment!) ?? 0]) { + return environment.interface.NewObjectV(environment, class$, constructorID$, $0) + } + } + """, + """ + @_cdecl("Java_com_example_swift_MyEnum__00024getAsThird__J") + func Java_com_example_swift_MyEnum__00024getAsThird__J(environment: UnsafeMutablePointer!, thisClass: jclass, self: jlong) -> jobject? { + ... + guard case .third(let x, let y) = self$.pointee else { + fatalError("Expected enum case 'third', but was '\\(self$.pointee)'!") + } + let cache$ = _JNI_MyEnum.myEnumThirdCache + let class$ = cache$.javaClass + let method$ = _JNIMethodIDCache.Method(name: "", signature: "(JI)V") + let constructorID$ = cache$[method$] + return withVaList([x.getJNIValue(in: environment!), y.getJNIValue(in: environment!)]) { + return environment.interface.NewObjectV(environment, class$, constructorID$, $0) + } + } + """ + ]) + } +} diff --git a/Tests/JExtractSwiftTests/JNI/JNIOptionalTests.swift b/Tests/JExtractSwiftTests/JNI/JNIOptionalTests.swift index 2186e227..548a2eac 100644 --- a/Tests/JExtractSwiftTests/JNI/JNIOptionalTests.swift +++ b/Tests/JExtractSwiftTests/JNI/JNIOptionalTests.swift @@ -47,10 +47,10 @@ struct JNIOptionalTests { * } */ public static OptionalInt optionalSugar(OptionalLong arg) { - long combined$ = SwiftModule.$optionalSugar((byte) (arg.isPresent() ? 1 : 0), arg.orElse(0L)); - byte discriminator$ = (byte) (combined$ & 0xFF); - int value$ = (int) (combined$ >> 32); - return discriminator$ == 1 ? OptionalInt.of(value$) : OptionalInt.empty(); + long result_combined$ = SwiftModule.$optionalSugar((byte) (arg.isPresent() ? 1 : 0), arg.orElse(0L)); + byte result_discriminator$ = (byte) (result_combined$ & 0xFF); + int result_value$ = (int) (result_combined$ >> 32); + return result_discriminator$ == 1 ? OptionalInt.of(result_value$) : OptionalInt.empty(); } """, """ @@ -72,10 +72,10 @@ struct JNIOptionalTests { """ @_cdecl("Java_com_example_swift_SwiftModule__00024optionalSugar__BJ") func Java_com_example_swift_SwiftModule__00024optionalSugar__BJ(environment: UnsafeMutablePointer!, thisClass: jclass, arg_discriminator: jbyte, arg_value: jlong) -> jlong { - let value$ = SwiftModule.optionalSugar(arg_discriminator == 1 ? Int64(fromJNI: arg_value, in: environment!) : nil).map { + let result_value$ = SwiftModule.optionalSugar(arg_discriminator == 1 ? Int64(fromJNI: arg_value, in: environment!) : nil).map { Int64($0) << 32 | Int64(1) } ?? 0 - return value$.getJNIValue(in: environment!) + return result_value$.getJNIValue(in: environment!) } """ ] @@ -98,9 +98,9 @@ struct JNIOptionalTests { * } */ public static Optional optionalExplicit(Optional arg) { - byte[] result_discriminator$ = new byte[1]; - java.lang.String result$ = SwiftModule.$optionalExplicit((byte) (arg.isPresent() ? 1 : 0), arg.orElse(null), result_discriminator$); - return (result_discriminator$[0] == 1) ? Optional.of(result$) : Optional.empty(); + byte[] result$_discriminator$ = new byte[1]; + java.lang.String result$ = SwiftModule.$optionalExplicit((byte) (arg.isPresent() ? 1 : 0), arg.orElse(null), result$_discriminator$); + return (result$_discriminator$[0] == 1) ? Optional.of(result$) : Optional.empty(); } """, """ @@ -127,12 +127,12 @@ struct JNIOptionalTests { result$ = innerResult$.getJNIValue(in: environment!) var flag$ = Int8(1) environment.interface.SetByteArrayRegion(environment, result_discriminator$, 0, 1, &flag$) - } // render(_:_:) @ JExtractSwiftLib/JNISwift2JavaGenerator+NativeTranslation.swift:624 + } // render(_:_:) @ JExtractSwiftLib/JNISwift2JavaGenerator+NativeTranslation.swift:649 else { result$ = String.jniPlaceholderValue var flag$ = Int8(0) environment.interface.SetByteArrayRegion(environment, result_discriminator$, 0, 1, &flag$) - } // render(_:_:) @ JExtractSwiftLib/JNISwift2JavaGenerator+NativeTranslation.swift:634 + } // render(_:_:) @ JExtractSwiftLib/JNISwift2JavaGenerator+NativeTranslation.swift:659 return result$ } """ @@ -156,9 +156,9 @@ struct JNIOptionalTests { * } */ public static Optional optionalClass(Optional arg, SwiftArena swiftArena$) { - byte[] result_discriminator$ = new byte[1]; - long result$ = SwiftModule.$optionalClass(arg.map(MyClass::$memoryAddress).orElse(0L), result_discriminator$); - return (result_discriminator$[0] == 1) ? Optional.of(MyClass.wrapMemoryAddressUnsafe(result$, swiftArena$)) : Optional.empty(); + byte[] result$_discriminator$ = new byte[1]; + long result$ = SwiftModule.$optionalClass(arg.map(MyClass::$memoryAddress).orElse(0L), result$_discriminator$); + return (result$_discriminator$[0] == 1) ? Optional.of(MyClass.wrapMemoryAddressUnsafe(result$, swiftArena$)) : Optional.empty(); } """, """ @@ -190,12 +190,12 @@ struct JNIOptionalTests { result$ = _resultBits$.getJNIValue(in: environment!) var flag$ = Int8(1) environment.interface.SetByteArrayRegion(environment, result_discriminator$, 0, 1, &flag$) - } // render(_:_:) @ JExtractSwiftLib/JNISwift2JavaGenerator+NativeTranslation.swift:624 + } // render(_:_:) @ JExtractSwiftLib/JNISwift2JavaGenerator+NativeTranslation.swift:649 else { result$ = 0 var flag$ = Int8(0) environment.interface.SetByteArrayRegion(environment, result_discriminator$, 0, 1, &flag$) - } // render(_:_:) @ JExtractSwiftLib/JNISwift2JavaGenerator+NativeTranslation.swift:634 + } // render(_:_:) @ JExtractSwiftLib/JNISwift2JavaGenerator+NativeTranslation.swift:659 return result$ } """ @@ -243,7 +243,7 @@ struct JNIOptionalTests { func Java_com_example_swift_SwiftModule__00024optionalJavaKitClass__Ljava_lang_Long_2(environment: UnsafeMutablePointer!, thisClass: jclass, arg: jobject?) { SwiftModule.optionalJavaKitClass(arg.map { return JavaLong(javaThis: $0, environment: environment!) - } // render(_:_:) @ JExtractSwiftLib/JNISwift2JavaGenerator+NativeTranslation.swift:666 + } // render(_:_:) @ JExtractSwiftLib/JNISwift2JavaGenerator+NativeTranslation.swift:691 ) } """