Skip to content

Commit 4a252d5

Browse files
committed
Prefer Mutex instead of Semaphore(1)
1 parent 62711ab commit 4a252d5

File tree

5 files changed

+36
-36
lines changed

5 files changed

+36
-36
lines changed

io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ package net
2525

2626
import com.comcast.ip4s.{IpAddress, SocketAddress}
2727
import cats.effect.Async
28-
import cats.effect.std.Semaphore
28+
import cats.effect.std.Mutex
2929
import cats.syntax.all._
3030

3131
import java.net.InetSocketAddress
@@ -36,19 +36,19 @@ private[net] trait SocketCompanionPlatform {
3636
private[net] def forAsync[F[_]: Async](
3737
ch: AsynchronousSocketChannel
3838
): F[Socket[F]] =
39-
(Semaphore[F](1), Semaphore[F](1)).mapN { (readSemaphore, writeSemaphore) =>
40-
new AsyncSocket[F](ch, readSemaphore, writeSemaphore)
39+
(Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) =>
40+
new AsyncSocket[F](ch, readMutex, writeMutex)
4141
}
4242

4343
private[net] abstract class BufferedReads[F[_]](
44-
readSemaphore: Semaphore[F]
44+
readMutex: Mutex[F]
4545
)(implicit F: Async[F])
4646
extends Socket[F] {
4747
private[this] final val defaultReadSize = 8192
4848
private[this] var readBuffer: ByteBuffer = ByteBuffer.allocateDirect(defaultReadSize)
4949

5050
private def withReadBuffer[A](size: Int)(f: ByteBuffer => F[A]): F[A] =
51-
readSemaphore.permit.use { _ =>
51+
readMutex.lock.surround {
5252
F.delay {
5353
if (readBuffer.capacity() < size)
5454
readBuffer = ByteBuffer.allocateDirect(size)
@@ -105,10 +105,10 @@ private[net] trait SocketCompanionPlatform {
105105

106106
private final class AsyncSocket[F[_]](
107107
ch: AsynchronousSocketChannel,
108-
readSemaphore: Semaphore[F],
109-
writeSemaphore: Semaphore[F]
108+
readMutex: Mutex[F],
109+
writeMutex: Mutex[F]
110110
)(implicit F: Async[F])
111-
extends BufferedReads[F](readSemaphore) {
111+
extends BufferedReads[F](readMutex) {
112112

113113
protected def readChunk(buffer: ByteBuffer): F[Int] =
114114
F.async[Int] { cb =>
@@ -140,7 +140,7 @@ private[net] trait SocketCompanionPlatform {
140140
go(buff)
141141
else F.unit
142142
}
143-
writeSemaphore.permit.use { _ =>
143+
writeMutex.lock.surround {
144144
go(bytes.toByteBuffer)
145145
}
146146
}

io/jvm/src/main/scala/fs2/io/net/tls/TLSEngine.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import javax.net.ssl.{SSLEngine, SSLEngineResult}
2828

2929
import cats.Applicative
3030
import cats.effect.kernel.{Async, Sync}
31-
import cats.effect.std.Semaphore
31+
import cats.effect.std.Mutex
3232
import cats.syntax.all._
3333

3434
/** Provides the ability to establish and communicate over a TLS session.
@@ -65,9 +65,9 @@ private[tls] object TLSEngine {
6565
engine.getSession.getPacketBufferSize,
6666
engine.getSession.getApplicationBufferSize
6767
)
68-
readSemaphore <- Semaphore[F](1)
69-
writeSemaphore <- Semaphore[F](1)
70-
handshakeSemaphore <- Semaphore[F](1)
68+
readMutex <- Mutex[F]
69+
writeMutex <- Mutex[F]
70+
handshakeMutex <- Mutex[F]
7171
sslEngineTaskRunner = SSLEngineTaskRunner[F](engine)
7272
} yield new TLSEngine[F] {
7373
private val doLog: (() => String) => F[Unit] =
@@ -85,7 +85,7 @@ private[tls] object TLSEngine {
8585
def stopUnwrap = Sync[F].delay(engine.closeInbound()).attempt.void
8686

8787
def write(data: Chunk[Byte]): F[Unit] =
88-
writeSemaphore.permit.use(_ => write0(data))
88+
writeMutex.lock.surround(write0(data))
8989

9090
private def write0(data: Chunk[Byte]): F[Unit] =
9191
wrapBuffer.input(data) >> wrap
@@ -104,8 +104,8 @@ private[tls] object TLSEngine {
104104
wrapBuffer.inputRemains
105105
.flatMap(x => wrap.whenA(x > 0 && result.bytesConsumed > 0))
106106
case _ =>
107-
handshakeSemaphore.permit
108-
.use(_ => stepHandshake(result, true)) >> wrap
107+
handshakeMutex.lock
108+
.surround(stepHandshake(result, true)) >> wrap
109109
}
110110
}
111111
case SSLEngineResult.Status.BUFFER_UNDERFLOW =>
@@ -124,7 +124,7 @@ private[tls] object TLSEngine {
124124
}
125125

126126
def read(maxBytes: Int): F[Option[Chunk[Byte]]] =
127-
readSemaphore.permit.use(_ => read0(maxBytes))
127+
readMutex.lock.surround(read0(maxBytes))
128128

129129
private def initialHandshakeDone: F[Boolean] =
130130
Sync[F].delay(engine.getSession.getCipherSuite != "SSL_NULL_WITH_NULL_NULL")
@@ -168,8 +168,8 @@ private[tls] object TLSEngine {
168168
case SSLEngineResult.HandshakeStatus.FINISHED =>
169169
unwrap(maxBytes)
170170
case _ =>
171-
handshakeSemaphore.permit
172-
.use(_ => stepHandshake(result, false)) >> unwrap(
171+
handshakeMutex.lock
172+
.surround(stepHandshake(result, false)) >> unwrap(
173173
maxBytes
174174
)
175175
}

io/jvm/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ package net
2525
package tls
2626

2727
import cats.Applicative
28-
import cats.effect.std.Semaphore
28+
import cats.effect.std.Mutex
2929
import cats.effect.kernel._
3030
import cats.syntax.all._
3131

@@ -53,7 +53,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type =>
5353
engine: TLSEngine[F]
5454
): F[TLSSocket[F]] =
5555
for {
56-
readSem <- Semaphore(1)
56+
readMutex <- Mutex[F]
5757
} yield new UnsealedTLSSocket[F] {
5858
def write(bytes: Chunk[Byte]): F[Unit] =
5959
engine.write(bytes)
@@ -62,7 +62,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type =>
6262
engine.read(maxBytes)
6363

6464
def readN(numBytes: Int): F[Chunk[Byte]] =
65-
readSem.permit.use { _ =>
65+
readMutex.lock.surround {
6666
def go(acc: Chunk[Byte]): F[Chunk[Byte]] = {
6767
val toRead = numBytes - acc.size
6868
if (toRead <= 0) Applicative[F].pure(acc)
@@ -76,7 +76,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type =>
7676
}
7777

7878
def read(maxBytes: Int): F[Option[Chunk[Byte]]] =
79-
readSem.permit.use(_ => read0(maxBytes))
79+
readMutex.lock.surround(read0(maxBytes))
8080

8181
def reads: Stream[F, Byte] =
8282
Stream.repeatEval(read(8192)).unNoneTerminate.unchunks

io/jvm/src/main/scala/fs2/io/net/unixsocket/UnixSocketsPlatform.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
package fs2.io.net.unixsocket
2323

2424
import cats.effect.kernel.{Async, Resource}
25-
import cats.effect.std.Semaphore
25+
import cats.effect.std.Mutex
2626
import cats.syntax.all._
2727
import com.comcast.ip4s.{IpAddress, SocketAddress}
2828
import fs2.{Chunk, Stream}
@@ -89,17 +89,17 @@ private[unixsocket] trait UnixSocketsCompanionPlatform {
8989
ch: SocketChannel
9090
): Resource[F, Socket[F]] =
9191
Resource.make {
92-
(Semaphore[F](1), Semaphore[F](1)).mapN { (readSemaphore, writeSemaphore) =>
93-
new AsyncSocket[F](ch, readSemaphore, writeSemaphore)
92+
(Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) =>
93+
new AsyncSocket[F](ch, readMutex, writeMutex)
9494
}
9595
}(_ => Async[F].delay(if (ch.isOpen) ch.close else ()))
9696

9797
private final class AsyncSocket[F[_]](
9898
ch: SocketChannel,
99-
readSemaphore: Semaphore[F],
100-
writeSemaphore: Semaphore[F]
99+
readMutex: Mutex[F],
100+
writeMutex: Mutex[F]
101101
)(implicit F: Async[F])
102-
extends Socket.BufferedReads[F](readSemaphore) {
102+
extends Socket.BufferedReads[F](readMutex) {
103103

104104
def readChunk(buff: ByteBuffer): F[Int] =
105105
F.blocking(ch.read(buff))
@@ -110,7 +110,7 @@ private[unixsocket] trait UnixSocketsCompanionPlatform {
110110
if (buff.remaining <= 0) F.unit
111111
else go(buff)
112112
}
113-
writeSemaphore.permit.use { _ =>
113+
writeMutex.lock.surround {
114114
go(bytes.toByteBuffer)
115115
}
116116
}

io/native/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ package tls
2626

2727
import cats.effect.kernel.Async
2828
import cats.effect.kernel.Resource
29-
import cats.effect.std.Semaphore
29+
import cats.effect.std.Mutex
3030
import cats.syntax.all._
3131
import com.comcast.ip4s.IpAddress
3232
import com.comcast.ip4s.SocketAddress
@@ -49,17 +49,17 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type =>
4949
connection: S2nConnection[F]
5050
)(implicit F: Async[F]): F[TLSSocket[F]] =
5151
for {
52-
readSem <- Semaphore(1)
53-
writeSem <- Semaphore(1)
52+
readMutex <- Mutex[F]
53+
writeMutex <- Mutex[F]
5454
} yield new UnsealedTLSSocket[F] {
5555
def write(bytes: Chunk[Byte]): F[Unit] =
56-
writeSem.permit.surround(connection.write(bytes))
56+
writeMutex.lock.surround(connection.write(bytes))
5757

5858
private def read0(maxBytes: Int): F[Option[Chunk[Byte]]] =
5959
connection.read(maxBytes)
6060

6161
def readN(numBytes: Int): F[Chunk[Byte]] =
62-
readSem.permit.use { _ =>
62+
readMutex.lock.surround {
6363
def go(acc: Chunk[Byte]): F[Chunk[Byte]] = {
6464
val toRead = numBytes - acc.size
6565
if (toRead <= 0) F.pure(acc)
@@ -73,7 +73,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type =>
7373
}
7474

7575
def read(maxBytes: Int): F[Option[Chunk[Byte]]] =
76-
readSem.permit.surround(read0(maxBytes))
76+
readMutex.lock.surround(read0(maxBytes))
7777

7878
def reads: Stream[F, Byte] =
7979
Stream.repeatEval(read(8192)).unNoneTerminate.unchunks

0 commit comments

Comments
 (0)