diff --git a/Sources/Subprocess/AsyncBufferSequence.swift b/Sources/Subprocess/AsyncBufferSequence.swift index b184c3c7..f1e73175 100644 --- a/Sources/Subprocess/AsyncBufferSequence.swift +++ b/Sources/Subprocess/AsyncBufferSequence.swift @@ -143,6 +143,7 @@ extension AsyncBufferSequence { private var source: AsyncBufferSequence.AsyncIterator private var buffer: [Encoding.CodeUnit] private var underlyingBuffer: [Encoding.CodeUnit] + private var underlyingBufferIndex: Array.Index private var leftover: Encoding.CodeUnit? private var eofReached: Bool private let bufferingPolicy: BufferingPolicy @@ -154,6 +155,7 @@ extension AsyncBufferSequence { self.source = underlyingIterator self.buffer = [] self.underlyingBuffer = [] + self.underlyingBufferIndex = self.underlyingBuffer.startIndex self.leftover = nil self.eofReached = false self.bufferingPolicy = bufferingPolicy @@ -208,13 +210,16 @@ extension AsyncBufferSequence { } func nextFromSource() async throws -> Encoding.CodeUnit? { - if underlyingBuffer.isEmpty { + if underlyingBufferIndex >= underlyingBuffer.count { guard let buf = try await loadBuffer() else { return nil } underlyingBuffer = buf + underlyingBufferIndex = buf.startIndex } - return underlyingBuffer.removeFirst() + let result = underlyingBuffer[underlyingBufferIndex] + underlyingBufferIndex = underlyingBufferIndex.advanced(by: 1) + return result } func nextCodeUnit() async throws -> Encoding.CodeUnit? { diff --git a/Sources/Subprocess/Configuration.swift b/Sources/Subprocess/Configuration.swift index 246fbc0d..34d7f9f3 100644 --- a/Sources/Subprocess/Configuration.swift +++ b/Sources/Subprocess/Configuration.swift @@ -187,33 +187,73 @@ extension Configuration { ) throws { var possibleError: (any Swift.Error)? = nil + // To avoid closing the same descriptor multiple times, + // keep track of the list of descriptors that we have + // already closed. If a `IODescriptor.Descriptor` is + // already closed, mark that `IODescriptor` as closed + // as opposed to actually try to close it. + var remainingSet: Set = Set( + optionalSequence: [ + inputRead?.descriptor, + inputWrite?.descriptor, + outputRead?.descriptor, + outputWrite?.descriptor, + errorRead?.descriptor, + errorWrite?.descriptor, + ] + ) + do { - try inputRead?.safelyClose() + if remainingSet.tryRemove(inputRead?.descriptor) { + try inputRead?.safelyClose() + } else { + try inputRead?.markAsClosed() + } } catch { possibleError = error } do { - try inputWrite?.safelyClose() + if remainingSet.tryRemove(inputWrite?.descriptor) { + try inputWrite?.safelyClose() + } else { + try inputWrite?.markAsClosed() + } } catch { possibleError = error } do { - try outputRead?.safelyClose() + if remainingSet.tryRemove(outputRead?.descriptor) { + try outputRead?.safelyClose() + } else { + try outputRead?.markAsClosed() + } } catch { possibleError = error } do { - try outputWrite?.safelyClose() + if remainingSet.tryRemove(outputWrite?.descriptor) { + try outputWrite?.safelyClose() + } else { + try outputWrite?.markAsClosed() + } } catch { possibleError = error } do { - try errorRead?.safelyClose() + if remainingSet.tryRemove(errorRead?.descriptor) { + try errorRead?.safelyClose() + } else { + try errorRead?.markAsClosed() + } } catch { possibleError = error } do { - try errorWrite?.safelyClose() + if remainingSet.tryRemove(errorWrite?.descriptor) { + try errorWrite?.safelyClose() + } else { + try errorWrite?.markAsClosed() + } } catch { possibleError = error } @@ -733,6 +773,10 @@ internal struct IODescriptor: ~Copyable { #endif } + internal mutating func markAsClosed() throws { + self.closeWhenDone = false + } + deinit { guard self.closeWhenDone else { return @@ -1081,3 +1125,17 @@ extension _OrderedSet: Sequence { return self.elements.makeIterator() } } + +extension Set { + init(optionalSequence sequence: S) where S: Sequence, S.Element == Optional { + let sequence: [Self.Element] = sequence.compactMap(\.self) + self.init(sequence) + } + + mutating func tryRemove(_ element: Self.Element?) -> Bool { + guard let element else { + return false + } + return self.remove(element) != nil + } +} diff --git a/Sources/_SubprocessCShims/process_shims.c b/Sources/_SubprocessCShims/process_shims.c index d77d3723..9fc00e53 100644 --- a/Sources/_SubprocessCShims/process_shims.c +++ b/Sources/_SubprocessCShims/process_shims.c @@ -431,14 +431,13 @@ static int _positive_int_parse(const char *str) { } #if defined(__linux__) -// Linux-specific version that uses syscalls directly and doesn't allocate heap memory. -// Safe to use after vfork() and before execve() -static int _highest_possibly_open_fd_dir_linux(const char *fd_dir) { - int highest_fd_so_far = 0; +/// Set `FD_CLOEXEC` on all open file descriptors listed under `fd_dir` so +/// they are automatically closed upon `execve()`. +/// Safe to use after `vfork()` and before `execve()` +static void _set_cloexec_to_open_fds(const char *fd_dir) { int dir_fd = open(fd_dir, O_RDONLY); if (dir_fd < 0) { - // errno set by `open`. - return -1; + return; } // Buffer for directory entries - allocated on stack, no heap allocation @@ -450,49 +449,37 @@ static int _highest_possibly_open_fd_dir_linux(const char *fd_dir) { if (errno == EINTR) { continue; } else { - // `errno` set by _getdents64. - highest_fd_so_far = -1; close(dir_fd); - return highest_fd_so_far; + return; } } if (bytes_read == 0) { close(dir_fd); - return highest_fd_so_far; + return; } long offset = 0; while (offset < bytes_read) { struct linux_dirent64 *entry = (struct linux_dirent64 *)(buffer + offset); - // Skip "." and ".." entries if (entry->d_name[0] != '.') { - int number = _positive_int_parse(entry->d_name); - if (number > highest_fd_so_far) { - highest_fd_so_far = number; + int fd = _positive_int_parse(entry->d_name); + if (fd > STDERR_FILENO && fd != dir_fd) { + int flags = fcntl(fd, F_GETFD); + if (flags >= 0) { + // Set FD_CLOEXEC on every open fd so they are closed after exec() + fcntl(fd, F_SETFD, flags | FD_CLOEXEC); + } } } - offset += entry->d_reclen; } } - - close(dir_fd); - return highest_fd_so_far; } #endif -// This function is only used on systems with Linux kernel 5.9 or lower. -// On newer systems, `close_range` is used instead. +// This function is only used on non-Linux systems. static int _highest_possibly_open_fd(void) { -#if defined(__linux__) - int hi = _highest_possibly_open_fd_dir_linux("/dev/fd"); - if (hi < 0) { - hi = sysconf(_SC_OPEN_MAX); - } -#else - int hi = sysconf(_SC_OPEN_MAX); -#endif - return hi; + return sysconf(_SC_OPEN_MAX); } int _subprocess_fork_exec( @@ -681,8 +668,8 @@ int _subprocess_fork_exec( errno = ENOSYS; #if (__has_include() && (!defined(__ANDROID__) || __ANDROID_API__ >= 34)) || defined(__FreeBSD__) // We must NOT close pipefd[1] for writing errors - rc = close_range(STDERR_FILENO + 1, pipefd[1] - 1, 0); - rc |= close_range(pipefd[1] + 1, ~0U, 0); + rc = close_range(STDERR_FILENO + 1, pipefd[1] - 1, CLOSE_RANGE_CLOEXEC); + rc |= close_range(pipefd[1] + 1, ~0U, CLOSE_RANGE_CLOEXEC); #elif defined(__OpenBSD__) // OpenBSD Supports closefrom, but not close_range // See https://man.openbsd.org/closefrom @@ -692,13 +679,22 @@ int _subprocess_fork_exec( rc = closefrom(pipefd[1] + 1); #endif if (rc != 0) { - // close_range failed (or doesn't exist), fall back to close() - for (int fd = STDERR_FILENO + 1; fd <= _highest_possibly_open_fd(); fd++) { + #if defined(__linux__) + _set_cloexec_to_open_fds("/dev/fd"); + #else + // close_range failed (or doesn't exist), fall back to setting FD_CLOEXEC + int highest_open_fd = _highest_possibly_open_fd(); + for (int fd = STDERR_FILENO + 1; fd <= highest_open_fd; fd++) { // We must NOT close pipefd[1] for writing errors if (fd != pipefd[1]) { - close(fd); + int flags = fcntl(fd, F_GETFD); + if (flags >= 0) { + // Set FD_CLOEXEC on every open fd so they are closed after exec() + fcntl(fd, F_SETFD, flags | FD_CLOEXEC); + } } } + #endif } // Finally, exec diff --git a/Tests/SubprocessTests/IntegrationTests.swift b/Tests/SubprocessTests/IntegrationTests.swift index d1f5479a..aa2f3ac3 100644 --- a/Tests/SubprocessTests/IntegrationTests.swift +++ b/Tests/SubprocessTests/IntegrationTests.swift @@ -1337,22 +1337,19 @@ extension SubprocessIntegrationTests { options: .create, permissions: [.ownerReadWrite, .groupReadWrite] ) - let echoResult = try await outputFile.closeAfter { - let echoResult = try await _run( - setup, - input: .none, - output: .fileDescriptor( - outputFile, - closeAfterSpawningProcess: false - ), - error: .fileDescriptor( - outputFile, - closeAfterSpawningProcess: false - ) + let echoResult = try await _run( + setup, + input: .none, + output: .fileDescriptor( + outputFile, + closeAfterSpawningProcess: true + ), + error: .fileDescriptor( + outputFile, + closeAfterSpawningProcess: true ) - #expect(echoResult.terminationStatus.isSuccess) - return echoResult - } + ) + #expect(echoResult.terminationStatus.isSuccess) let outputData: Data = try Data( contentsOf: URL(filePath: outputFilePath.string) )