diff --git a/src/internal/task/waiter.go b/src/internal/task/waiter.go new file mode 100644 index 0000000000..ac9d0d451d --- /dev/null +++ b/src/internal/task/waiter.go @@ -0,0 +1,99 @@ +package task + +import ( + "runtime/interrupt" + "unsafe" +) + +// A waiter waits for an interrupt to fire. Call Wait() from a goroutine to +// pause it, and call Resume() from an interrupt handler to resume the +// goroutine. +type Waiter struct { + waiting *Task + lock PMutex +} + +var resumeTaskSentinel = (*Task)(unsafe.Pointer(uintptr(2))) +var workingTaskSentinel = (*Task)(unsafe.Pointer(uintptr(1))) + +// Wait from a main goroutine until an interrupt calls Resume(). It will also +// return if the interrupt called Resume() before the Wait() call (to avoid a +// race condition). +func (w *Waiter) Wait() { + if interrupt.In() { + runtimePanic("Waiter: called Wait from interrupt") + } + + mask := interrupt.Disable() + w.lock.Lock() + switch w.waiting { + case nil, workingTaskSentinel: + w.waiting = Current() + w.lock.Unlock() + interrupt.Restore(mask) + Pause() + case resumeTaskSentinel: + // Marked as 'resume now', can return immediately. + w.waiting = workingTaskSentinel + w.lock.Unlock() + interrupt.Restore(mask) + default: + w.lock.Unlock() + interrupt.Restore(mask) + runtimePanic("Waiter: task is waiting already") + } +} + +// Resume a waiting goroutine from an interrupt handler. If there is a goroutine +// waiting, it will resume. If not, the next one that calls Wait() will +// immediately resume. +func (w *Waiter) Resume() { + if !interrupt.In() { + runtimePanic("Waiter: called Resume outside interrupt") + } + + waiting := w.waiting + switch waiting { + case nil, workingTaskSentinel: + w.waiting = resumeTaskSentinel + case resumeTaskSentinel: + // Called Resume() twice in a row, this may indicate a bug at the caller + // site. + default: + // Schedule the given task to resume. + // If it's not yet paused, it will immediately resume on the next call + // to Pause(). + w.waiting = workingTaskSentinel + scheduleTask(waiting) + } +} + +// Return true if Done has not been called after a Resume call. +// This is typically used to indicate the task outside the interrupt is still +// being worked on. +func (w *Waiter) Working() bool { + switch w.waiting { + case workingTaskSentinel, resumeTaskSentinel: + return true + default: // nil, <*task.Task> + return false + } +} + +// Done can be called outside interrupt context to indicate the task this waiter +// was waiting for is done, and Working() can return false. It is entirely +// optional, if not called Working will continue to return true until it is +// blocked in Wait() but the waiter will function normally otherwise. +func (w *Waiter) Done() { + if interrupt.In() { + runtimePanic("Waiter: called Done from interrupt") + } + + mask := interrupt.Disable() + w.lock.Lock() + if w.waiting == workingTaskSentinel { + w.waiting = nil + } + w.lock.Unlock() + interrupt.Restore(mask) +} diff --git a/src/machine/usb/msc/msc.go b/src/machine/usb/msc/msc.go index d3bf8d6e29..d214ad6aca 100644 --- a/src/machine/usb/msc/msc.go +++ b/src/machine/usb/msc/msc.go @@ -1,12 +1,12 @@ package msc import ( + "internal/task" "machine" "machine/usb" "machine/usb/descriptor" "machine/usb/msc/csw" "machine/usb/msc/scsi" - "time" ) type mscState uint8 @@ -26,14 +26,14 @@ const ( var MSC *msc type msc struct { - buf []byte // Buffer for incoming/outgoing data - blockCache []byte // Buffer for block read/write data - taskQueued bool // Flag to indicate if the buffer has a task queued - rxStalled bool // Flag to indicate if the RX endpoint is stalled - txStalled bool // Flag to indicate if the TX endpoint is stalled - maxPacketSize uint32 // Maximum packet size for the IN endpoint - respStatus csw.Status // Response status for the last command - sendZLP bool // Flag to indicate if a zero-length packet should be sent before sending CSW + buf []byte // Buffer for incoming/outgoing data + blockCache []byte // Buffer for block read/write data + taskWaiter task.Waiter // Waiter for events outside interrupt context + rxStalled bool // Flag to indicate if the RX endpoint is stalled + txStalled bool // Flag to indicate if the TX endpoint is stalled + maxPacketSize uint32 // Maximum packet size for the IN endpoint + respStatus csw.Status // Response status for the last command + sendZLP bool // Flag to indicate if a zero-length packet should be sent before sending CSW cbw *CBW // Last received Command Block Wrapper queuedBytes uint32 // Number of bytes queued for sending @@ -120,21 +120,21 @@ func newMSC(dev machine.BlockDevice) *msc { func (m *msc) processTasks() { // Process tasks that cannot be done in an interrupt context for { - if m.taskQueued { - cmd := m.cbw.SCSICmd() - switch cmd.CmdType() { - case scsi.CmdWrite: - m.scsiWrite(cmd, m.buf) - case scsi.CmdUnmap: - m.scsiUnmap(m.buf) - } - - // Acknowledge the received data from the host - m.queuedBytes = 0 - m.taskQueued = false - machine.AckUsbOutTransfer(usb.MSC_ENDPOINT_OUT) + // Wait for the next task to arrive. + m.taskWaiter.Wait() + + cmd := m.cbw.SCSICmd() + switch cmd.CmdType() { + case scsi.CmdWrite: + m.scsiWrite(cmd, m.buf) + case scsi.CmdUnmap: + m.scsiUnmap(m.buf) } - time.Sleep(100 * time.Microsecond) + + // Acknowledge the received data from the host + m.queuedBytes = 0 + m.taskWaiter.Done() + machine.AckUsbOutTransfer(usb.MSC_ENDPOINT_OUT) } } diff --git a/src/machine/usb/msc/scsi.go b/src/machine/usb/msc/scsi.go index d7266ed40f..67d6592949 100644 --- a/src/machine/usb/msc/scsi.go +++ b/src/machine/usb/msc/scsi.go @@ -250,7 +250,7 @@ func (m *msc) scsiQueueTask(cmdType scsi.CmdType, b []byte) bool { } // Save the incoming data in our buffer for processing outside of interrupt context. - if m.taskQueued { + if m.taskWaiter.Working() { // If we already have a full task queue we can't accept this data m.sendScsiError(csw.StatusFailed, scsi.SenseAbortedCommand, scsi.SenseCodeMsgReject) return true @@ -267,14 +267,14 @@ func (m *msc) scsiQueueTask(cmdType scsi.CmdType, b []byte) bool { case scsi.CmdWrite: // If we're writing data wait until we have a full write block of data that can be processed. if m.queuedBytes == uint32(cap(m.blockCache)) { - m.taskQueued = true + m.taskWaiter.Resume() } case scsi.CmdUnmap: - m.taskQueued = true + m.taskWaiter.Resume() } // Don't acknowledge the incoming data until we can process it. - return !m.taskQueued + return !m.taskWaiter.Working() } func (m *msc) sendScsiError(status csw.Status, key scsi.Sense, code scsi.SenseCode) {